Files
backend/internal/repository/message_repo.go
lan 4c0177149a Clean backend debug logging and standardize error reporting.
This removes verbose trace output in handlers/services and keeps only actionable error-level logs.
2026-03-09 22:20:44 +08:00

525 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"carrot_bbs/internal/model"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// MessageRepository is the persistence layer for conversations, their
// participants and their messages. Every query goes through the single gorm
// handle it wraps.
type MessageRepository struct {
	db *gorm.DB // handle used by all repository queries
}
// NewMessageRepository returns a MessageRepository that runs all of its
// queries against db.
func NewMessageRepository(db *gorm.DB) *MessageRepository {
	repo := &MessageRepository{db: db}
	return repo
}
// CreateMessage inserts msg as a new row exactly as given. Unlike
// CreateMessageWithSeq it does not assign a sequence number or touch the
// owning conversation.
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
	result := r.db.Create(msg)
	return result.Error
}
// GetConversation loads the conversation whose primary key is id and eagerly
// loads its Group association. On failure the gorm error (for example
// gorm.ErrRecordNotFound) is returned unchanged with a nil conversation.
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
	conv := new(model.Conversation)
	if err := r.db.Preload("Group").First(conv, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return conv, nil
}
// GetOrCreatePrivateConversation returns the private (one-to-one)
// conversation shared by user1ID and user2ID, creating it if none exists.
//
// Membership is modelled through conversation_participants. The lookup joins
// that table once per user. If an existing conversation is found, it is
// un-hidden for both users. If none exists, the conversation row and both
// participant rows are created in one transaction. On error the returned
// conversation is nil.
//
// user1ID and user2ID are strings (UUIDs) to match the user_id in the JWT.
//
// NOTE(review): the lookup and the creation are not atomic, so two
// concurrent calls for the same pair can each create a conversation.
// Preventing that needs a unique key on the user pair or a lock. TODO: confirm
// the schema.
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
	var conv model.Conversation

	// Look for a private conversation in which both users participate.
	err := r.db.Table("conversations c").
		Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
		Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
		Where("c.type = ?", model.ConversationTypePrivate).
		First(&conv).Error
	if err == nil {
		// Un-hide the conversation for both users. This is deliberately best
		// effort: a failure here must not stop the caller from getting the
		// conversation it asked for.
		_ = r.db.Model(&model.ConversationParticipant{}).
			Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
			Update("hidden_at", nil).Error
		return &conv, nil
	}
	if err != gorm.ErrRecordNotFound {
		return nil, err
	}

	// Not found: create the conversation and both participant rows atomically.
	conv = model.Conversation{Type: model.ConversationTypePrivate}
	err = r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&conv).Error; err != nil {
			return err
		}
		// UserID is stored as a string (UUID).
		participants := []model.ConversationParticipant{
			{ConversationID: conv.ID, UserID: user1ID},
			{ConversationID: conv.ID, UserID: user2ID},
		}
		return tx.Create(&participants).Error
	})
	if err != nil {
		// The insert was rolled back; do not hand out a conversation value
		// that does not exist in the database.
		return nil, err
	}
	return &conv, nil
}
// GetConversations returns one page of the conversations visible to userID,
// together with the total number of visible conversations.
//
// A conversation is visible when the user's participant row has hidden_at
// NULL. Results are ordered per user: conversations the user pinned come
// first, then the most recently updated ones. page is 1-based. userID is a
// string (UUID) to match the user_id in the JWT.
//
// A failure of the count query is returned as an error instead of being
// reported as an empty list.
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	var convs []*model.Conversation
	var total int64

	// Total number of non-hidden conversations for this user.
	if err := r.db.Model(&model.ConversationParticipant{}).
		Where("user_id = ? AND hidden_at IS NULL", userID).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if total == 0 {
		return convs, total, nil
	}

	offset := (page - 1) * pageSize

	// Load the page with the Group association preloaded: pinned first (for
	// this user), then by last update time.
	err := r.db.Model(&model.Conversation{}).
		Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
		Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
		Preload("Group").
		Offset(offset).
		Limit(pageSize).
		Order("cp.is_pinned DESC").
		Order("conversations.updated_at DESC").
		Find(&convs).Error
	return convs, total, err
}
// GetMessages returns one page of messages in conversationID, newest first
// (seq descending), together with the total message count of the
// conversation. page is 1-based.
//
// A failure of the count query is returned as an error instead of being
// silently reported as a total of 0.
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
	var total int64
	if err := r.db.Model(&model.Message{}).
		Where("conversation_id = ?", conversationID).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var messages []*model.Message
	offset := (page - 1) * pageSize
	err := r.db.Where("conversation_id = ?", conversationID).
		Offset(offset).
		Limit(pageSize).
		Order("seq DESC").
		Find(&messages).Error
	return messages, total, err
}
// GetMessagesAfterSeq returns up to limit messages of conversationID whose
// seq is strictly greater than afterSeq, oldest first. It is used for
// incremental sync.
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
	var messages []*model.Message
	query := r.db.
		Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
		Order("seq ASC").
		Limit(limit)
	err := query.Find(&messages).Error
	return messages, err
}
// GetMessagesBeforeSeq returns up to limit messages of conversationID whose
// seq is strictly less than beforeSeq. It is used for "pull to load more"
// history.
//
// The rows are fetched newest-first so that LIMIT keeps the messages closest
// to beforeSeq. They are then flipped so the caller receives them in
// ascending seq order.
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	var messages []*model.Message
	err := r.db.
		Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
		Order("seq DESC").
		Limit(limit).
		Find(&messages).Error

	// Reverse in place: newest-first -> oldest-first.
	for head, tail := 0, len(messages)-1; head < tail; head, tail = head+1, tail-1 {
		messages[head], messages[tail] = messages[tail], messages[head]
	}
	return messages, err
}
// GetConversationParticipants returns every participant row of
// conversationID.
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
	var participants []*model.ConversationParticipant
	if err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error; err != nil {
		return participants, err
	}
	return participants, nil
}
// GetParticipant returns userID's participant record in conversationID.
//
// If the record is missing but the conversation exists, a participant row is
// created on the fly. This repairs conversations that were stored without one.
// If the conversation does not exist, gorm.ErrRecordNotFound is returned.
// Other database errors are returned as-is.
//
// userID is a string (UUID) to match the user_id carried in the JWT.
func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
	var participant model.ConversationParticipant
	err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
	if err == nil {
		return &participant, nil
	}
	if err != gorm.ErrRecordNotFound {
		return nil, err
	}

	// Only self-heal when the conversation really exists. The id must be
	// bound as a parameter. A bare string passed as gorm's inline condition is
	// used as raw SQL unless it is numeric, so a UUID (or hostile input) would
	// be spliced straight into the WHERE clause.
	var conv model.Conversation
	if err := r.db.First(&conv, "id = ?", conversationID).Error; err != nil {
		return nil, err
	}

	participant = model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
	}
	// ON CONFLICT DO NOTHING: a concurrent request may insert the same
	// membership between our read and this insert.
	res := r.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&participant)
	if res.Error != nil {
		return nil, res.Error
	}
	if res.RowsAffected == 0 {
		// Lost the race: return the row the other request stored.
		var existing model.ConversationParticipant
		if err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&existing).Error; err != nil {
			return nil, err
		}
		return &existing, nil
	}
	return &participant, nil
}
// UpdateLastReadSeq sets userID's read position in conversationID to
// lastReadSeq.
//
// It first tries a plain UPDATE. If that touches no row, it falls back to an
// insert-or-update (ON CONFLICT on conversation_id, user_id), so a missing
// participant row is created with the given read position.
// userID is a string (UUID) to match the user_id in the JWT.
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
	res := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Update("last_read_seq", lastReadSeq)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected > 0:
		return nil
	}

	// No row was updated: upsert the participant record (portable across
	// databases via gorm's OnConflict clause).
	upsert := clause.OnConflict{
		Columns: []clause.Column{
			{Name: "conversation_id"},
			{Name: "user_id"},
		},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"last_read_seq": lastReadSeq,
			"updated_at":    gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	row := &model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		LastReadSeq:    lastReadSeq,
	}
	return r.db.Clauses(upsert).Create(row).Error
}
// UpdatePinned sets whether conversationID is pinned for userID. Pinning is
// per user and stored on the participant row.
//
// A plain UPDATE is tried first. If no row was affected, the participant row
// is upserted (ON CONFLICT on conversation_id, user_id) with the requested
// pin state.
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
	res := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Update("is_pinned", isPinned)
	if res.Error != nil || res.RowsAffected > 0 {
		return res.Error
	}

	row := &model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		IsPinned:       isPinned,
	}
	onConflict := clause.OnConflict{
		Columns: []clause.Column{
			{Name: "conversation_id"},
			{Name: "user_id"},
		},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"is_pinned":  isPinned,
			"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	return r.db.Clauses(onConflict).Create(row).Error
}
// GetUnreadCount returns how many messages in conversationID are unread by
// userID. These are messages sent by someone else with a seq above the user's
// last_read_seq.
//
// The user's participant row must exist; otherwise the lookup error (e.g.
// gorm.ErrRecordNotFound) is returned with a count of 0.
// userID is a string (UUID) to match the user_id in the JWT.
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
	var participant model.ConversationParticipant
	if err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error; err != nil {
		return 0, err
	}

	var unread int64
	err := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq).
		Count(&unread).Error
	return unread, err
}
// UpdateConversationLastSeq records seq as the latest message sequence of
// conversationID. It also stamps last_msg_time with the database's current
// timestamp.
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
	fields := map[string]interface{}{
		"last_seq":      seq,
		"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
	}
	return r.db.Model(&model.Conversation{}).
		Where("id = ?", conversationID).
		Updates(fields).Error
}
// GetNextSeq returns the sequence number the next message in conversationID
// would receive (the conversation's last_seq + 1).
//
// This is a plain, unlocked read. A concurrent writer may claim the value
// before the caller uses it, so prefer CreateMessageWithSeq, which assigns the
// seq inside its insert transaction.
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
	var conv model.Conversation
	// Bind the id as a parameter. A bare string inline condition is treated
	// by gorm as raw SQL unless it is numeric.
	if err := r.db.Select("last_seq").First(&conv, "id = ?", conversationID).Error; err != nil {
		return 0, err
	}
	return conv.LastSeq + 1, nil
}
// CreateMessageWithSeq inserts msg with the next per-conversation sequence
// number. In the same transaction it advances the conversation's last_seq and
// last_msg_time, and clears hidden_at for every participant of the
// conversation. msg.Seq is overwritten with the assigned value.
//
// The conversation row is read with SELECT ... FOR UPDATE. This makes
// concurrent senders in the same conversation serialize on it. Without the
// lock, two transactions could read the same last_seq and assign duplicate
// seq values. NOTE(review): dialects without row locks (e.g. SQLite) are
// expected to drop the locking clause; verify for the driver in use.
func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Read and lock the current seq. The id is bound as a parameter: a
		// bare string inline condition would be treated as raw SQL.
		var conv model.Conversation
		if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Select("last_seq").
			First(&conv, "id = ?", msg.ConversationID).Error; err != nil {
			return err
		}
		msg.Seq = conv.LastSeq + 1

		if err := tx.Create(msg).Error; err != nil {
			return err
		}

		// Advance the conversation's last_seq and last message time.
		if err := tx.Model(&model.Conversation{}).
			Where("id = ?", msg.ConversationID).
			Updates(map[string]interface{}{
				"last_seq":      msg.Seq,
				"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
			}).Error; err != nil {
			return err
		}

		// A new message brings the conversation back for participants who
		// hid it ("delete for me only").
		return tx.Model(&model.ConversationParticipant{}).
			Where("conversation_id = ?", msg.ConversationID).
			Update("hidden_at", nil).Error
	})
}
// GetAllUnreadCount returns the total number of unread messages for userID
// across all conversations the user participates in.
//
// A message counts as unread when it was sent by someone else, its seq is
// above the user's last_read_seq for that conversation, and it is not
// soft-deleted (deleted_at IS NULL). The query does not filter on hidden_at,
// so hidden conversations are included.
// userID is a string (UUID) to match the user_id in the JWT.
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
	const unreadJoin = "LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL"

	var total int64
	query := r.db.Table("conversation_participants AS cp").
		Joins(unreadJoin, userID).
		Where("cp.user_id = ?", userID).
		Select("COALESCE(COUNT(m.id), 0)")
	err := query.Scan(&total).Error
	return total, err
}
// GetMessageByID loads a single message by primary key. On failure the gorm
// error (e.g. gorm.ErrRecordNotFound) is returned with a nil message.
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
	msg := new(model.Message)
	if err := r.db.First(msg, "id = ?", messageID).Error; err != nil {
		return nil, err
	}
	return msg, nil
}
// CountMessagesBySenderInConversation returns how many messages senderID has
// sent in conversationID.
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
	var sent int64
	query := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND sender_id = ?", conversationID, senderID)
	err := query.Count(&sent).Error
	return sent, err
}
// UpdateMessageStatus sets the status column of the message with the given
// id. Updating a non-existent id is not reported as an error.
//
// NOTE(review): messageID is int64 here while GetMessageByID takes a string
// id; confirm which representation the messages table uses.
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
	res := r.db.Model(&model.Message{}).
		Where("id = ?", messageID).
		Update("status", status)
	return res.Error
}
// GetOrCreateSystemParticipant returns userID's participant record in the
// system conversation (model.SystemConversationID), creating it with
// LastReadSeq 0 if it does not exist yet.
//
// The system conversation is virtual, but it still needs participant rows to
// track each user's read position. Creation is race-tolerant: if a concurrent
// call inserts the same row first, the stored row is returned instead of
// failing on the duplicate.
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
	var participant model.ConversationParticipant
	err := r.db.Where("conversation_id = ? AND user_id = ?",
		model.SystemConversationID, userID).First(&participant).Error
	if err == nil {
		return &participant, nil
	}
	if err != gorm.ErrRecordNotFound {
		return nil, err
	}

	// Create the participant record on first use.
	participant = model.ConversationParticipant{
		ConversationID: model.SystemConversationID,
		UserID:         userID,
		LastReadSeq:    0,
	}
	res := r.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&participant)
	if res.Error != nil {
		return nil, res.Error
	}
	if res.RowsAffected == 0 {
		// Another request inserted the row between our read and insert.
		var existing model.ConversationParticipant
		if err := r.db.Where("conversation_id = ? AND user_id = ?",
			model.SystemConversationID, userID).First(&existing).Error; err != nil {
			return nil, err
		}
		return &existing, nil
	}
	return &participant, nil
}
// GetSystemMessagesUnreadCount returns how many system-conversation messages
// userID has not read yet, i.e. those with seq above the user's last_read_seq.
// The user's system participant row is created on demand. Unlike the
// per-conversation unread counts, no sender filter is applied.
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
	p, err := r.GetOrCreateSystemParticipant(userID)
	if err != nil {
		return 0, err
	}

	var unread int64
	if err := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND seq > ?",
			model.SystemConversationID, p.LastReadSeq).
		Count(&unread).Error; err != nil {
		return unread, err
	}
	return unread, nil
}
// MarkAllSystemMessagesAsRead moves userID's read position in the system
// conversation to the highest existing message seq (0 when there are no
// system messages).
//
// The participant row is written with an insert-or-update (ON CONFLICT on
// conversation_id, user_id), so it is created if missing.
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
	// Latest seq in the system conversation.
	var maxSeq int64
	if err := r.db.Model(&model.Message{}).
		Where("conversation_id = ?", model.SystemConversationID).
		Select("COALESCE(MAX(seq), 0)").
		Scan(&maxSeq).Error; err != nil {
		return err
	}

	upsert := clause.OnConflict{
		Columns: []clause.Column{
			{Name: "conversation_id"},
			{Name: "user_id"},
		},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"last_read_seq": maxSeq,
			"updated_at":    gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	row := &model.ConversationParticipant{
		ConversationID: model.SystemConversationID,
		UserID:         userID,
		LastReadSeq:    maxSeq,
	}
	return r.db.Clauses(upsert).Create(row).Error
}
// GetConversationByGroupID returns the conversation attached to groupID. On
// failure the gorm error (e.g. gorm.ErrRecordNotFound) is returned with a nil
// conversation.
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
	conv := new(model.Conversation)
	if err := r.db.Where("group_id = ?", groupID).First(conv).Error; err != nil {
		return nil, err
	}
	return conv, nil
}
// RemoveParticipant deletes userID's participant row from conversationID.
// It is called when a user leaves a group chat, so that the membership in the
// corresponding conversation goes away as well. Deleting a non-existent row is
// not an error.
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
	res := r.db.
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Delete(&model.ConversationParticipant{})
	return res.Error
}
// AddParticipant adds userID as a participant of conversationID with
// LastReadSeq 0. It is called when a user joins a group chat. It is a no-op
// when the user is already a participant.
//
// NOTE(review): the existence check and the insert are separate statements,
// so two concurrent calls can both try to insert. Depending on the table's
// unique keys this yields a duplicate row or a duplicate-key error; confirm
// the schema.
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
	var existing int64
	err := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Count(&existing).Error
	switch {
	case err != nil:
		return err
	case existing > 0:
		// Already a member; nothing to do.
		return nil
	}

	return r.db.Create(&model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		LastReadSeq:    0,
	}).Error
}
// DeleteConversationByGroupID deletes the conversation attached to groupID,
// together with its participant rows and messages, in one transaction. It is
// called when a group is dissolved. A group without a conversation is not an
// error.
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
	conv, err := r.GetConversationByGroupID(groupID)
	switch {
	case err == gorm.ErrRecordNotFound:
		// No conversation for this group: nothing to delete.
		return nil
	case err != nil:
		return err
	}

	return r.db.Transaction(func(tx *gorm.DB) error {
		// Remove dependent rows first, then the conversation itself.
		if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil {
			return err
		}
		if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil {
			return err
		}
		return tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error
	})
}
// HideConversationForUser hides conversationID for userID only (the "delete
// for me" action on a private chat) by stamping hidden_at on that user's
// participant row with the current time. Other participants are unaffected.
func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error {
	hiddenAt := time.Now()
	query := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID)
	return query.Update("hidden_at", &hiddenAt).Error
}