package repository import ( "carrot_bbs/internal/model" "time" "gorm.io/gorm" "gorm.io/gorm/clause" ) // MessageRepository 消息仓储 type MessageRepository struct { db *gorm.DB } // NewMessageRepository 创建消息仓储 func NewMessageRepository(db *gorm.DB) *MessageRepository { return &MessageRepository{db: db} } // CreateMessage 创建消息 func (r *MessageRepository) CreateMessage(msg *model.Message) error { return r.db.Create(msg).Error } // GetConversation 获取会话 func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) { var conv model.Conversation err := r.db.Preload("Group").First(&conv, "id = ?", id).Error if err != nil { return nil, err } return &conv, nil } // GetOrCreatePrivateConversation 获取或创建私聊会话 // 使用参与者关系表来管理会话 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) { var conv model.Conversation // 查找两个用户共同参与的私聊会话 err := r.db.Table("conversations c"). Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID). Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID). Where("c.type = ?", model.ConversationTypePrivate). First(&conv).Error if err == nil { _ = r.db.Model(&model.ConversationParticipant{}). Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}). Update("hidden_at", nil).Error return &conv, nil } if err != gorm.ErrRecordNotFound { return nil, err } // 没找到会话,创建新会话 conv = model.Conversation{ Type: model.ConversationTypePrivate, } // 使用事务创建会话和参与者 err = r.db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(&conv).Error; err != nil { return err } // 创建参与者记录 - UserID 存储为 string (UUID) participants := []model.ConversationParticipant{ {ConversationID: conv.ID, UserID: user1ID}, {ConversationID: conv.ID, UserID: user2ID}, } if err := tx.Create(&participants).Error; err != nil { return err } return nil }) return &conv, err } // GetConversations 获取用户会话列表 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) { var convs []*model.Conversation var total int64 // 获取总数 r.db.Model(&model.ConversationParticipant{}). Where("user_id = ? AND hidden_at IS NULL", userID). Count(&total) if total == 0 { return convs, total, nil } offset := (page - 1) * pageSize // 查询会话列表并预加载关联数据: // 当前用户维度先按置顶排序,再按更新时间排序 err := r.db.Model(&model.Conversation{}). Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id"). Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID). Preload("Group"). Offset(offset). Limit(pageSize). Order("cp.is_pinned DESC"). Order("conversations.updated_at DESC"). Find(&convs).Error return convs, total, err } // GetMessages 获取会话消息 func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) { var messages []*model.Message var total int64 r.db.Model(&model.Message{}).Where("conversation_id = ?", conversationID).Count(&total) offset := (page - 1) * pageSize err := r.db.Where("conversation_id = ?", conversationID). Offset(offset). Limit(pageSize). Order("seq DESC"). Find(&messages).Error return messages, total, err } // GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步) func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) { var messages []*model.Message err := r.db.Where("conversation_id = ? AND seq > ?", conversationID, afterSeq). Order("seq ASC"). Limit(limit). Find(&messages).Error return messages, err } // GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多) func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) { var messages []*model.Message err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq). Order("seq DESC"). // 降序获取最新消息在前 Limit(limit). Find(&messages).Error // 反转回正序 for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { messages[i], messages[j] = messages[j], messages[i] } return messages, err } // GetConversationParticipants 获取会话参与者 func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) { var participants []*model.ConversationParticipant err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error return participants, err } // GetParticipant 获取用户在会话中的参与者信息 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) { var participant model.ConversationParticipant err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error if err != nil { // 如果找不到参与者,尝试添加(修复没有参与者记录的问题) if err == gorm.ErrRecordNotFound { // 检查会话是否存在 var conv model.Conversation if err := r.db.First(&conv, conversationID).Error; err == nil { // 会话存在,添加参与者 participant = model.ConversationParticipant{ ConversationID: conversationID, UserID: userID, } if err := r.db.Create(&participant).Error; err != nil { return nil, err } return &participant, nil } } return nil, err } return &participant, nil } // UpdateLastReadSeq 更新已读位置 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error { result := r.db.Model(&model.ConversationParticipant{}). Where("conversation_id = ? AND user_id = ?", conversationID, userID). Update("last_read_seq", lastReadSeq) if result.Error != nil { return result.Error } // 如果没有更新任何记录,说明参与者记录不存在,需要插入 if result.RowsAffected == 0 { // 尝试插入新记录(跨数据库 upsert) err := r.db.Clauses(clause.OnConflict{ Columns: []clause.Column{ {Name: "conversation_id"}, {Name: "user_id"}, }, DoUpdates: clause.Assignments(map[string]interface{}{ "last_read_seq": lastReadSeq, "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), }), }).Create(&model.ConversationParticipant{ ConversationID: conversationID, UserID: userID, LastReadSeq: lastReadSeq, }).Error if err != nil { return err } } return nil } // UpdatePinned 更新会话置顶状态(用户维度) func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error { result := r.db.Model(&model.ConversationParticipant{}). Where("conversation_id = ? AND user_id = ?", conversationID, userID). Update("is_pinned", isPinned) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return r.db.Clauses(clause.OnConflict{ Columns: []clause.Column{ {Name: "conversation_id"}, {Name: "user_id"}, }, DoUpdates: clause.Assignments(map[string]interface{}{ "is_pinned": isPinned, "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), }), }).Create(&model.ConversationParticipant{ ConversationID: conversationID, UserID: userID, IsPinned: isPinned, }).Error } return nil } // GetUnreadCount 获取未读消息数 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) { var participant model.ConversationParticipant err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error if err != nil { return 0, err } var count int64 err = r.db.Model(&model.Message{}). Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq). Count(&count).Error return count, err } // UpdateConversationLastSeq 更新会话的最后消息seq和时间 func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error { return r.db.Model(&model.Conversation{}). Where("id = ?", conversationID). Updates(map[string]interface{}{ "last_seq": seq, "last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"), }).Error } // GetNextSeq 获取会话的下一个seq值 func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) { var conv model.Conversation err := r.db.Select("last_seq").First(&conv, conversationID).Error if err != nil { return 0, err } return conv.LastSeq + 1, nil } // CreateMessageWithSeq 创建消息并更新seq(事务操作) func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error { return r.db.Transaction(func(tx *gorm.DB) error { // 获取当前seq并+1 var conv model.Conversation if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil { return err } msg.Seq = conv.LastSeq + 1 // 创建消息 if err := tx.Create(msg).Error; err != nil { return err } // 更新会话的last_seq if err := tx.Model(&model.Conversation{}). Where("id = ?", msg.ConversationID). Updates(map[string]interface{}{ "last_seq": msg.Seq, "last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"), }).Error; err != nil { return err } // 新消息到达后,自动恢复被“仅自己删除”的会话 if err := tx.Model(&model.ConversationParticipant{}). Where("conversation_id = ?", msg.ConversationID). Update("hidden_at", nil).Error; err != nil { return err } return nil }) } // GetAllUnreadCount 获取用户所有会话的未读消息总数 // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) { var totalUnread int64 err := r.db.Table("conversation_participants AS cp"). Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID). Where("cp.user_id = ?", userID). Select("COALESCE(COUNT(m.id), 0)"). Scan(&totalUnread).Error return totalUnread, err } // GetMessageByID 根据ID获取消息 func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) { var message model.Message err := r.db.First(&message, "id = ?", messageID).Error if err != nil { return nil, err } return &message, nil } // CountMessagesBySenderInConversation 统计会话中某用户已发送消息数 func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) { var count int64 err := r.db.Model(&model.Message{}). Where("conversation_id = ? AND sender_id = ?", conversationID, senderID). Count(&count).Error return count, err } // UpdateMessageStatus 更新消息状态 func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error { return r.db.Model(&model.Message{}). Where("id = ?", messageID). Update("status", status).Error } // GetOrCreateSystemParticipant 获取或创建用户在系统会话中的参与者记录 // 系统会话是虚拟会话,但需要参与者记录来跟踪已读状态 func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) { var participant model.ConversationParticipant err := r.db.Where("conversation_id = ? AND user_id = ?", model.SystemConversationID, userID).First(&participant).Error if err == nil { return &participant, nil } if err != gorm.ErrRecordNotFound { return nil, err } // 自动创建参与者记录 participant = model.ConversationParticipant{ ConversationID: model.SystemConversationID, UserID: userID, LastReadSeq: 0, } if err := r.db.Create(&participant).Error; err != nil { return nil, err } return &participant, nil } // GetSystemMessagesUnreadCount 获取系统消息未读数 func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) { // 获取或创建参与者记录 participant, err := r.GetOrCreateSystemParticipant(userID) if err != nil { return 0, err } // 计算未读数:查询 seq > last_read_seq 的消息 var count int64 err = r.db.Model(&model.Message{}). Where("conversation_id = ? AND seq > ?", model.SystemConversationID, participant.LastReadSeq). Count(&count).Error return count, err } // MarkAllSystemMessagesAsRead 标记所有系统消息已读 func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error { // 获取系统会话的最新 seq var maxSeq int64 err := r.db.Model(&model.Message{}). Where("conversation_id = ?", model.SystemConversationID). Select("COALESCE(MAX(seq), 0)"). Scan(&maxSeq).Error if err != nil { return err } // 使用跨数据库 upsert 方式更新或创建参与者记录 return r.db.Clauses(clause.OnConflict{ Columns: []clause.Column{ {Name: "conversation_id"}, {Name: "user_id"}, }, DoUpdates: clause.Assignments(map[string]interface{}{ "last_read_seq": maxSeq, "updated_at": gorm.Expr("CURRENT_TIMESTAMP"), }), }).Create(&model.ConversationParticipant{ ConversationID: model.SystemConversationID, UserID: userID, LastReadSeq: maxSeq, }).Error } // GetConversationByGroupID 通过群组ID获取会话 func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) { var conv model.Conversation err := r.db.Where("group_id = ?", groupID).First(&conv).Error if err != nil { return nil, err } return &conv, nil } // RemoveParticipant 移除会话参与者 // 当用户退出群聊时,需要同时移除其在对应会话中的参与者记录 func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error { return r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID). Delete(&model.ConversationParticipant{}).Error } // AddParticipant 添加会话参与者 // 当用户加入群聊时,需要同时将其添加到对应会话的参与者记录 func (r *MessageRepository) AddParticipant(conversationID string, userID string) error { // 先检查是否已经是参与者 var count int64 err := r.db.Model(&model.ConversationParticipant{}). Where("conversation_id = ? AND user_id = ?", conversationID, userID). Count(&count).Error if err != nil { return err } // 如果已经是参与者,直接返回 if count > 0 { return nil } // 添加参与者 participant := model.ConversationParticipant{ ConversationID: conversationID, UserID: userID, LastReadSeq: 0, } return r.db.Create(&participant).Error } // DeleteConversationByGroupID 删除群组对应的会话及其参与者 // 当解散群组时调用 func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error { // 获取群组对应的会话 conv, err := r.GetConversationByGroupID(groupID) if err != nil { // 如果会话不存在,直接返回 if err == gorm.ErrRecordNotFound { return nil } return err } return r.db.Transaction(func(tx *gorm.DB) error { // 删除会话参与者 if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil { return err } // 删除会话中的消息 if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil { return err } // 删除会话 if err := tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error; err != nil { return err } return nil }) } // HideConversationForUser 仅对当前用户隐藏会话(私聊删除) func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error { now := time.Now() return r.db.Model(&model.ConversationParticipant{}). Where("conversation_id = ? AND user_id = ?", conversationID, userID). Update("hidden_at", &now).Error }