Files
backend/internal/repository/message_repo.go

544 lines
17 KiB
Go
Raw Normal View History

package repository
import (
"carrot_bbs/internal/model"
"fmt"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// MessageRepository is the persistence layer for conversations, their
// participants, and messages.
type MessageRepository struct {
	// db is the GORM handle every repository method issues its queries through.
	db *gorm.DB
}
// NewMessageRepository returns a MessageRepository that runs its queries
// through db.
func NewMessageRepository(db *gorm.DB) *MessageRepository {
	repo := new(MessageRepository)
	repo.db = db
	return repo
}
// CreateMessage inserts msg as a new row and returns the database error, if
// any.
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
	result := r.db.Create(msg)
	return result.Error
}
// GetConversation loads the conversation with the given id and preloads its
// Group association. On failure it returns nil and the GORM error (for
// example gorm.ErrRecordNotFound).
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
	conv := new(model.Conversation)
	if err := r.db.Preload("Group").First(conv, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return conv, nil
}
// GetOrCreatePrivateConversation returns the private conversation shared by
// user1ID and user2ID, creating it (together with both participant rows) when
// none exists.
//
// Conversations are resolved through the conversation_participants table.
// The user IDs are strings, matching the user_id carried in the JWT.
//
// When an existing conversation is found, hidden_at is cleared for both
// users so the conversation becomes visible again. On any error the returned
// conversation is nil.
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
	var conv model.Conversation

	// NOTE(review): debug output goes straight to stdout; consider routing it
	// through the project's logger.
	fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: user1ID=%s, user2ID=%s\n", user1ID, user2ID)

	// Look for a private conversation in which both users participate.
	err := r.db.Table("conversations c").
		Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
		Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
		Where("c.type = ?", model.ConversationTypePrivate).
		First(&conv).Error
	if err == nil {
		// Best effort: un-hide the conversation for both users. A failure here
		// is deliberately ignored because the conversation itself was found.
		_ = r.db.Model(&model.ConversationParticipant{}).
			Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
			Update("hidden_at", nil).Error
		fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: found existing conversation, ID=%s\n", conv.ID)
		return &conv, nil
	}
	if err != gorm.ErrRecordNotFound {
		return nil, err
	}

	// Nothing found: create a new conversation.
	fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: no existing conversation found, creating new one\n")
	conv = model.Conversation{
		Type: model.ConversationTypePrivate,
	}

	// Create the conversation and its two participants atomically.
	err = r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&conv).Error; err != nil {
			return err
		}
		participants := []model.ConversationParticipant{
			{ConversationID: conv.ID, UserID: user1ID},
			{ConversationID: conv.ID, UserID: user2ID},
		}
		return tx.Create(&participants).Error
	})
	if err != nil {
		// Do not hand back a conversation whose creation failed.
		return nil, err
	}

	fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: created new conversation, ID=%s\n", conv.ID)
	return &conv, nil
}
// GetConversations returns one page of the conversations userID participates
// in and has not hidden, plus the total number of such conversations.
//
// userID is a string, matching the user_id carried in the JWT. page is
// 1-based; pageSize is the maximum number of rows returned. Results are
// ordered with the user's pinned conversations first, then by
// conversations.updated_at descending. The Group association is preloaded.
//
// A failure while counting is returned to the caller instead of being
// reported as an empty list.
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	var convs []*model.Conversation
	var total int64

	// Count the visible conversations first so an empty result can skip the
	// heavier list query.
	if err := r.db.Model(&model.ConversationParticipant{}).
		Where("user_id = ? AND hidden_at IS NULL", userID).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if total == 0 {
		return convs, total, nil
	}

	offset := (page - 1) * pageSize

	err := r.db.Model(&model.Conversation{}).
		Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
		Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
		Preload("Group").
		Offset(offset).
		Limit(pageSize).
		Order("cp.is_pinned DESC").
		Order("conversations.updated_at DESC").
		Find(&convs).Error
	return convs, total, err
}
// GetMessages returns one page of messages in conversationID, ordered by seq
// descending, plus the total number of messages in that conversation.
//
// page is 1-based; pageSize is the maximum number of rows returned. A failure
// while counting is returned to the caller instead of being silently dropped.
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
	var messages []*model.Message
	var total int64

	if err := r.db.Model(&model.Message{}).
		Where("conversation_id = ?", conversationID).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := r.db.Where("conversation_id = ?", conversationID).
		Offset(offset).
		Limit(pageSize).
		Order("seq DESC").
		Find(&messages).Error
	return messages, total, err
}
// GetMessagesAfterSeq returns up to limit messages of conversationID whose seq
// is strictly greater than afterSeq, in ascending seq order. It is used for
// incremental sync.
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
	var newer []*model.Message
	query := r.db.
		Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
		Order("seq ASC").
		Limit(limit)
	err := query.Find(&newer).Error
	return newer, err
}
// GetMessagesBeforeSeq returns up to limit messages of conversationID whose
// seq is strictly less than beforeSeq. It is used to load older history when
// the user scrolls back.
//
// The rows are fetched in descending seq order so that the limit keeps the
// messages closest to beforeSeq, and are then reversed so the returned slice
// is in ascending seq order.
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	var messages []*model.Message

	err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
		Order("seq DESC").
		Limit(limit).
		Find(&messages).Error

	// Flip the slice in place back to ascending order.
	for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
		messages[i], messages[j] = messages[j], messages[i]
	}
	return messages, err
}
// GetConversationParticipants returns every participant row recorded for
// conversationID.
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
	var members []*model.ConversationParticipant
	result := r.db.Where("conversation_id = ?", conversationID).Find(&members)
	return members, result.Error
}
// GetParticipant returns userID's participant row in conversationID.
//
// userID is a string, matching the user_id carried in the JWT.
//
// If no participant row exists but the conversation itself does, a new
// participant row is created and returned (a repair path for conversations
// that are missing participant records). If the conversation cannot be
// loaded either, the original gorm.ErrRecordNotFound is returned.
//
// NOTE(review): the repair path adds any requested user to any existing
// conversation — confirm callers authorize the user before calling this.
func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
	var participant model.ConversationParticipant
	err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
	if err == nil {
		return &participant, nil
	}
	if err != gorm.ErrRecordNotFound {
		return nil, err
	}

	// Check that the conversation exists. The id is bound as a parameter: a
	// bare string passed to First is treated by GORM as a raw SQL condition
	// unless it is numeric.
	var conv model.Conversation
	if convErr := r.db.First(&conv, "id = ?", conversationID).Error; convErr != nil {
		// Report the participant lookup failure, not the conversation one.
		return nil, err
	}

	// The conversation exists: add the missing participant row.
	participant = model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
	}
	if createErr := r.db.Create(&participant).Error; createErr != nil {
		return nil, createErr
	}
	return &participant, nil
}
// UpdateLastReadSeq records lastReadSeq as userID's read position in
// conversationID.
//
// userID is a string, matching the user_id carried in the JWT. A plain UPDATE
// is attempted first; when it affects no rows, the participant row is
// upserted on (conversation_id, user_id) instead.
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
	updated := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Update("last_read_seq", lastReadSeq)
	switch {
	case updated.Error != nil:
		return updated.Error
	case updated.RowsAffected != 0:
		return nil
	}

	// No row was touched, which the code treats as a missing participant
	// record: insert one, or update it if a conflicting row exists.
	upsert := clause.OnConflict{
		Columns: []clause.Column{
			{Name: "conversation_id"},
			{Name: "user_id"},
		},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"last_read_seq": lastReadSeq,
			"updated_at":    gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	row := model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		LastReadSeq:    lastReadSeq,
	}
	return r.db.Clauses(upsert).Create(&row).Error
}
// UpdatePinned sets the per-user pinned flag of conversationID for userID.
//
// A plain UPDATE is attempted first; when it affects no rows, the participant
// row is upserted on (conversation_id, user_id) instead.
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
	updated := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Update("is_pinned", isPinned)
	switch {
	case updated.Error != nil:
		return updated.Error
	case updated.RowsAffected != 0:
		return nil
	}

	// Nothing was updated: insert the row, or update it on conflict.
	upsert := clause.OnConflict{
		Columns: []clause.Column{
			{Name: "conversation_id"},
			{Name: "user_id"},
		},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"is_pinned":  isPinned,
			"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	row := model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		IsPinned:       isPinned,
	}
	return r.db.Clauses(upsert).Create(&row).Error
}
// GetUnreadCount returns how many messages in conversationID were sent by
// someone other than userID and have a seq greater than the user's
// last_read_seq.
//
// userID is a string, matching the user_id carried in the JWT. The error from
// loading the participant row (including gorm.ErrRecordNotFound) is returned
// unchanged.
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
	var me model.ConversationParticipant
	lookup := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&me)
	if lookup.Error != nil {
		return 0, lookup.Error
	}

	var unread int64
	counted := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, me.LastReadSeq).
		Count(&unread)
	return unread, counted.Error
}
// UpdateConversationLastSeq stores seq as the conversation's last_seq and
// sets last_msg_time to the database's CURRENT_TIMESTAMP.
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
	changes := map[string]interface{}{
		"last_seq":      seq,
		"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
	}
	target := r.db.Model(&model.Conversation{}).Where("id = ?", conversationID)
	return target.Updates(changes).Error
}
// GetNextSeq returns the conversation's current last_seq plus one.
//
// The id is bound as a query parameter: a bare string passed to First is
// treated by GORM as a raw SQL condition unless it is numeric.
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
	var conv model.Conversation
	if err := r.db.Select("last_seq").First(&conv, "id = ?", conversationID).Error; err != nil {
		return 0, err
	}
	return conv.LastSeq + 1, nil
}
// CreateMessageWithSeq assigns msg the next seq of its conversation and
// persists it, all inside one transaction. The transaction:
//
//  1. reads the conversation's last_seq and sets msg.Seq to last_seq + 1,
//  2. inserts msg,
//  3. updates the conversation's last_seq and last_msg_time,
//  4. clears hidden_at for every participant of the conversation.
//
// NOTE(review): last_seq is read without a row lock, so concurrent senders
// could compute the same seq — confirm whether a unique index or locking
// elsewhere guards against this.
func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Read the current seq. The id is bound as a parameter: a bare string
		// passed to First is treated by GORM as a raw SQL condition unless it
		// is numeric.
		var conv model.Conversation
		if err := tx.Select("last_seq").First(&conv, "id = ?", msg.ConversationID).Error; err != nil {
			return err
		}
		msg.Seq = conv.LastSeq + 1

		if err := tx.Create(msg).Error; err != nil {
			return err
		}

		// Advance the conversation's last_seq and message timestamp.
		if err := tx.Model(&model.Conversation{}).
			Where("id = ?", msg.ConversationID).
			Updates(map[string]interface{}{
				"last_seq":      msg.Seq,
				"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
			}).Error; err != nil {
			return err
		}

		// A new message makes the conversation visible again for participants
		// who had hidden it ("delete for me only").
		return tx.Model(&model.ConversationParticipant{}).
			Where("conversation_id = ?", msg.ConversationID).
			Update("hidden_at", nil).Error
	})
}
// GetAllUnreadCount returns the total number of unread messages for userID
// across all conversations the user participates in.
//
// userID is a string, matching the user_id carried in the JWT. A message
// counts as unread when it was sent by someone else, its seq is greater than
// the participant's last_read_seq, and its deleted_at is NULL.
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
	var unread int64
	query := r.db.Table("conversation_participants AS cp").
		Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID).
		Where("cp.user_id = ?", userID).
		Select("COALESCE(COUNT(m.id), 0)")
	err := query.Scan(&unread).Error
	return unread, err
}
// GetMessageByID loads the message whose id equals messageID. On failure it
// returns nil and the GORM error.
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
	found := new(model.Message)
	if err := r.db.First(found, "id = ?", messageID).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// CountMessagesBySenderInConversation returns how many messages senderID has
// sent in conversationID.
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
	var sent int64
	result := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND sender_id = ?", conversationID, senderID).
		Count(&sent)
	return sent, result.Error
}
// UpdateMessageStatus sets the status column of the message identified by
// messageID.
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
	target := r.db.Model(&model.Message{}).Where("id = ?", messageID)
	return target.Update("status", status).Error
}
// GetOrCreateSystemParticipant returns userID's participant row in the system
// conversation (model.SystemConversationID), creating it with a read position
// of zero when it does not exist yet.
//
// The system conversation is virtual, but a participant row is still needed
// to track the user's read state.
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
	var participant model.ConversationParticipant
	err := r.db.
		Where("conversation_id = ? AND user_id = ?", model.SystemConversationID, userID).
		First(&participant).Error
	switch err {
	case nil:
		return &participant, nil
	case gorm.ErrRecordNotFound:
		// No row yet: create one below.
	default:
		return nil, err
	}

	participant = model.ConversationParticipant{
		ConversationID: model.SystemConversationID,
		UserID:         userID,
		LastReadSeq:    0,
	}
	if createErr := r.db.Create(&participant).Error; createErr != nil {
		return nil, createErr
	}
	return &participant, nil
}
// GetSystemMessagesUnreadCount returns the number of system-conversation
// messages whose seq is greater than userID's last_read_seq. The user's
// participant row is created on demand.
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
	me, err := r.GetOrCreateSystemParticipant(userID)
	if err != nil {
		return 0, err
	}

	var unread int64
	result := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND seq > ?", model.SystemConversationID, me.LastReadSeq).
		Count(&unread)
	return unread, result.Error
}
// MarkAllSystemMessagesAsRead moves userID's read position in the system
// conversation to the highest existing seq (0 when there are no messages).
// The participant row is upserted on (conversation_id, user_id).
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
	// Find the newest seq in the system conversation.
	var latest int64
	lookup := r.db.Model(&model.Message{}).
		Where("conversation_id = ?", model.SystemConversationID).
		Select("COALESCE(MAX(seq), 0)").
		Scan(&latest)
	if lookup.Error != nil {
		return lookup.Error
	}

	// Insert the participant row, or update it when it already exists.
	upsert := clause.OnConflict{
		Columns: []clause.Column{
			{Name: "conversation_id"},
			{Name: "user_id"},
		},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"last_read_seq": latest,
			"updated_at":    gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	row := model.ConversationParticipant{
		ConversationID: model.SystemConversationID,
		UserID:         userID,
		LastReadSeq:    latest,
	}
	return r.db.Clauses(upsert).Create(&row).Error
}
// GetConversationByGroupID loads the conversation whose group_id equals
// groupID. On failure it returns nil and the GORM error.
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
	conv := new(model.Conversation)
	if err := r.db.Where("group_id = ?", groupID).First(conv).Error; err != nil {
		return nil, err
	}
	return conv, nil
}
// RemoveParticipant deletes userID's participant row from conversationID.
// It is meant to be called when a user leaves a group chat, so the matching
// conversation membership is removed as well.
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
	scope := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID)
	return scope.Delete(&model.ConversationParticipant{}).Error
}
// AddParticipant adds userID to conversationID with a read position of zero.
// It is meant to be called when a user joins a group chat. If the user is
// already a participant, nothing is written and nil is returned.
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
	// Check for an existing membership first.
	var existing int64
	lookup := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Count(&existing)
	if lookup.Error != nil {
		return lookup.Error
	}
	if existing > 0 {
		return nil
	}

	member := model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		LastReadSeq:    0,
	}
	return r.db.Create(&member).Error
}
// DeleteConversationByGroupID removes the conversation that belongs to
// groupID together with its participants and messages. It is meant to be
// called when a group is disbanded.
//
// If the group has no conversation, nil is returned. The three deletes run in
// a single transaction.
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
	conv, err := r.GetConversationByGroupID(groupID)
	switch {
	case err == gorm.ErrRecordNotFound:
		// Nothing to delete.
		return nil
	case err != nil:
		return err
	}

	return r.db.Transaction(func(tx *gorm.DB) error {
		// Participants first.
		if delErr := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; delErr != nil {
			return delErr
		}
		// Then the conversation's messages.
		if delErr := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; delErr != nil {
			return delErr
		}
		// Finally the conversation itself.
		return tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error
	})
}
// HideConversationForUser hides conversationID for userID only by stamping
// the participant row's hidden_at with the current time (used when a user
// deletes a private chat on their side).
func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error {
	hiddenAt := time.Now()
	target := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID)
	return target.Update("hidden_at", &hiddenAt).Error
}