Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
544 lines
17 KiB
Go
544 lines
17 KiB
Go
package repository
|
||
|
||
import (
|
||
"carrot_bbs/internal/model"
|
||
"fmt"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
)
|
||
|
||
// MessageRepository 消息仓储
|
||
type MessageRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewMessageRepository 创建消息仓储
|
||
func NewMessageRepository(db *gorm.DB) *MessageRepository {
|
||
return &MessageRepository{db: db}
|
||
}
|
||
|
||
// CreateMessage 创建消息
|
||
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
|
||
return r.db.Create(msg).Error
|
||
}
|
||
|
||
// GetConversation 获取会话
|
||
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
|
||
var conv model.Conversation
|
||
err := r.db.Preload("Group").First(&conv, "id = ?", id).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &conv, nil
|
||
}
|
||
|
||
// GetOrCreatePrivateConversation 获取或创建私聊会话
|
||
// 使用参与者关系表来管理会话
|
||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
|
||
var conv model.Conversation
|
||
|
||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: user1ID=%s, user2ID=%s\n", user1ID, user2ID)
|
||
|
||
// 查找两个用户共同参与的私聊会话
|
||
err := r.db.Table("conversations c").
|
||
Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
|
||
Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
|
||
Where("c.type = ?", model.ConversationTypePrivate).
|
||
First(&conv).Error
|
||
|
||
if err == nil {
|
||
_ = r.db.Model(&model.ConversationParticipant{}).
|
||
Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
|
||
Update("hidden_at", nil).Error
|
||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: found existing conversation, ID=%s\n", conv.ID)
|
||
return &conv, nil
|
||
}
|
||
|
||
if err != gorm.ErrRecordNotFound {
|
||
return nil, err
|
||
}
|
||
|
||
// 没找到会话,创建新会话
|
||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: no existing conversation found, creating new one\n")
|
||
conv = model.Conversation{
|
||
Type: model.ConversationTypePrivate,
|
||
}
|
||
|
||
// 使用事务创建会话和参与者
|
||
err = r.db.Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Create(&conv).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 创建参与者记录 - UserID 存储为 string (UUID)
|
||
participants := []model.ConversationParticipant{
|
||
{ConversationID: conv.ID, UserID: user1ID},
|
||
{ConversationID: conv.ID, UserID: user2ID},
|
||
}
|
||
if err := tx.Create(&participants).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
})
|
||
|
||
if err == nil {
|
||
fmt.Printf("[DEBUG] GetOrCreatePrivateConversation: created new conversation, ID=%s\n", conv.ID)
|
||
}
|
||
|
||
return &conv, err
|
||
}
|
||
|
||
// GetConversations 获取用户会话列表
|
||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||
var convs []*model.Conversation
|
||
var total int64
|
||
|
||
// 获取总数
|
||
r.db.Model(&model.ConversationParticipant{}).
|
||
Where("user_id = ? AND hidden_at IS NULL", userID).
|
||
Count(&total)
|
||
|
||
if total == 0 {
|
||
return convs, total, nil
|
||
}
|
||
|
||
offset := (page - 1) * pageSize
|
||
// 查询会话列表并预加载关联数据:
|
||
// 当前用户维度先按置顶排序,再按更新时间排序
|
||
err := r.db.Model(&model.Conversation{}).
|
||
Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
|
||
Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
|
||
Preload("Group").
|
||
Offset(offset).
|
||
Limit(pageSize).
|
||
Order("cp.is_pinned DESC").
|
||
Order("conversations.updated_at DESC").
|
||
Find(&convs).Error
|
||
|
||
return convs, total, err
|
||
}
|
||
|
||
// GetMessages 获取会话消息
|
||
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||
var messages []*model.Message
|
||
var total int64
|
||
|
||
r.db.Model(&model.Message{}).Where("conversation_id = ?", conversationID).Count(&total)
|
||
|
||
offset := (page - 1) * pageSize
|
||
err := r.db.Where("conversation_id = ?", conversationID).
|
||
Offset(offset).
|
||
Limit(pageSize).
|
||
Order("seq DESC").
|
||
Find(&messages).Error
|
||
|
||
return messages, total, err
|
||
}
|
||
|
||
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||
var messages []*model.Message
|
||
err := r.db.Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
|
||
Order("seq ASC").
|
||
Limit(limit).
|
||
Find(&messages).Error
|
||
return messages, err
|
||
}
|
||
|
||
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||
var messages []*model.Message
|
||
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: conversationID=%s, beforeSeq=%d, limit=%d\n", conversationID, beforeSeq, limit)
|
||
err := r.db.Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
|
||
Order("seq DESC"). // 降序获取最新消息在前
|
||
Limit(limit).
|
||
Find(&messages).Error
|
||
fmt.Printf("[DEBUG] GetMessagesBeforeSeq: found %d messages, seq range: ", len(messages))
|
||
for i, m := range messages {
|
||
if i < 5 || i >= len(messages)-2 {
|
||
fmt.Printf("%d ", m.Seq)
|
||
} else if i == 5 {
|
||
fmt.Printf("... ")
|
||
}
|
||
}
|
||
fmt.Println()
|
||
// 反转回正序
|
||
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
||
messages[i], messages[j] = messages[j], messages[i]
|
||
}
|
||
return messages, err
|
||
}
|
||
|
||
// GetConversationParticipants 获取会话参与者
|
||
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||
var participants []*model.ConversationParticipant
|
||
err := r.db.Where("conversation_id = ?", conversationID).Find(&participants).Error
|
||
return participants, err
|
||
}
|
||
|
||
// GetParticipant 获取用户在会话中的参与者信息
|
||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||
func (r *MessageRepository) GetParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||
var participant model.ConversationParticipant
|
||
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
|
||
if err != nil {
|
||
// 如果找不到参与者,尝试添加(修复没有参与者记录的问题)
|
||
if err == gorm.ErrRecordNotFound {
|
||
// 检查会话是否存在
|
||
var conv model.Conversation
|
||
if err := r.db.First(&conv, conversationID).Error; err == nil {
|
||
// 会话存在,添加参与者
|
||
participant = model.ConversationParticipant{
|
||
ConversationID: conversationID,
|
||
UserID: userID,
|
||
}
|
||
if err := r.db.Create(&participant).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return &participant, nil
|
||
}
|
||
}
|
||
return nil, err
|
||
}
|
||
return &participant, nil
|
||
}
|
||
|
||
// UpdateLastReadSeq 更新已读位置
|
||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
|
||
result := r.db.Model(&model.ConversationParticipant{}).
|
||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||
Update("last_read_seq", lastReadSeq)
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
// 如果没有更新任何记录,说明参与者记录不存在,需要插入
|
||
if result.RowsAffected == 0 {
|
||
// 尝试插入新记录(跨数据库 upsert)
|
||
err := r.db.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{
|
||
{Name: "conversation_id"},
|
||
{Name: "user_id"},
|
||
},
|
||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||
"last_read_seq": lastReadSeq,
|
||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||
}),
|
||
}).Create(&model.ConversationParticipant{
|
||
ConversationID: conversationID,
|
||
UserID: userID,
|
||
LastReadSeq: lastReadSeq,
|
||
}).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdatePinned 更新会话置顶状态(用户维度)
|
||
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
|
||
result := r.db.Model(&model.ConversationParticipant{}).
|
||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||
Update("is_pinned", isPinned)
|
||
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
|
||
if result.RowsAffected == 0 {
|
||
return r.db.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{
|
||
{Name: "conversation_id"},
|
||
{Name: "user_id"},
|
||
},
|
||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||
"is_pinned": isPinned,
|
||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||
}),
|
||
}).Create(&model.ConversationParticipant{
|
||
ConversationID: conversationID,
|
||
UserID: userID,
|
||
IsPinned: isPinned,
|
||
}).Error
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetUnreadCount 获取未读消息数
|
||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
|
||
var participant model.ConversationParticipant
|
||
err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&participant).Error
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
var count int64
|
||
err = r.db.Model(&model.Message{}).
|
||
Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, participant.LastReadSeq).
|
||
Count(&count).Error
|
||
return count, err
|
||
}
|
||
|
||
// UpdateConversationLastSeq 更新会话的最后消息seq和时间
|
||
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
|
||
return r.db.Model(&model.Conversation{}).
|
||
Where("id = ?", conversationID).
|
||
Updates(map[string]interface{}{
|
||
"last_seq": seq,
|
||
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
|
||
}).Error
|
||
}
|
||
|
||
// GetNextSeq 获取会话的下一个seq值
|
||
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
|
||
var conv model.Conversation
|
||
err := r.db.Select("last_seq").First(&conv, conversationID).Error
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return conv.LastSeq + 1, nil
|
||
}
|
||
|
||
// CreateMessageWithSeq 创建消息并更新seq(事务操作)
|
||
func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
// 获取当前seq并+1
|
||
var conv model.Conversation
|
||
if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
msg.Seq = conv.LastSeq + 1
|
||
|
||
// 创建消息
|
||
if err := tx.Create(msg).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 更新会话的last_seq
|
||
if err := tx.Model(&model.Conversation{}).
|
||
Where("id = ?", msg.ConversationID).
|
||
Updates(map[string]interface{}{
|
||
"last_seq": msg.Seq,
|
||
"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
|
||
}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 新消息到达后,自动恢复被“仅自己删除”的会话
|
||
if err := tx.Model(&model.ConversationParticipant{}).
|
||
Where("conversation_id = ?", msg.ConversationID).
|
||
Update("hidden_at", nil).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// GetAllUnreadCount 获取用户所有会话的未读消息总数
|
||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
|
||
var totalUnread int64
|
||
err := r.db.Table("conversation_participants AS cp").
|
||
Joins("LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL", userID).
|
||
Where("cp.user_id = ?", userID).
|
||
Select("COALESCE(COUNT(m.id), 0)").
|
||
Scan(&totalUnread).Error
|
||
return totalUnread, err
|
||
}
|
||
|
||
// GetMessageByID 根据ID获取消息
|
||
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
|
||
var message model.Message
|
||
err := r.db.First(&message, "id = ?", messageID).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &message, nil
|
||
}
|
||
|
||
// CountMessagesBySenderInConversation 统计会话中某用户已发送消息数
|
||
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
|
||
var count int64
|
||
err := r.db.Model(&model.Message{}).
|
||
Where("conversation_id = ? AND sender_id = ?", conversationID, senderID).
|
||
Count(&count).Error
|
||
return count, err
|
||
}
|
||
|
||
// UpdateMessageStatus 更新消息状态
|
||
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
|
||
return r.db.Model(&model.Message{}).
|
||
Where("id = ?", messageID).
|
||
Update("status", status).Error
|
||
}
|
||
|
||
// GetOrCreateSystemParticipant 获取或创建用户在系统会话中的参与者记录
|
||
// 系统会话是虚拟会话,但需要参与者记录来跟踪已读状态
|
||
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
|
||
var participant model.ConversationParticipant
|
||
err := r.db.Where("conversation_id = ? AND user_id = ?",
|
||
model.SystemConversationID, userID).First(&participant).Error
|
||
|
||
if err == nil {
|
||
return &participant, nil
|
||
}
|
||
|
||
if err != gorm.ErrRecordNotFound {
|
||
return nil, err
|
||
}
|
||
|
||
// 自动创建参与者记录
|
||
participant = model.ConversationParticipant{
|
||
ConversationID: model.SystemConversationID,
|
||
UserID: userID,
|
||
LastReadSeq: 0,
|
||
}
|
||
|
||
if err := r.db.Create(&participant).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &participant, nil
|
||
}
|
||
|
||
// GetSystemMessagesUnreadCount 获取系统消息未读数
|
||
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
|
||
// 获取或创建参与者记录
|
||
participant, err := r.GetOrCreateSystemParticipant(userID)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
// 计算未读数:查询 seq > last_read_seq 的消息
|
||
var count int64
|
||
err = r.db.Model(&model.Message{}).
|
||
Where("conversation_id = ? AND seq > ?",
|
||
model.SystemConversationID, participant.LastReadSeq).
|
||
Count(&count).Error
|
||
|
||
return count, err
|
||
}
|
||
|
||
// MarkAllSystemMessagesAsRead 标记所有系统消息已读
|
||
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
|
||
// 获取系统会话的最新 seq
|
||
var maxSeq int64
|
||
err := r.db.Model(&model.Message{}).
|
||
Where("conversation_id = ?", model.SystemConversationID).
|
||
Select("COALESCE(MAX(seq), 0)").
|
||
Scan(&maxSeq).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 使用跨数据库 upsert 方式更新或创建参与者记录
|
||
return r.db.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{
|
||
{Name: "conversation_id"},
|
||
{Name: "user_id"},
|
||
},
|
||
DoUpdates: clause.Assignments(map[string]interface{}{
|
||
"last_read_seq": maxSeq,
|
||
"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
|
||
}),
|
||
}).Create(&model.ConversationParticipant{
|
||
ConversationID: model.SystemConversationID,
|
||
UserID: userID,
|
||
LastReadSeq: maxSeq,
|
||
}).Error
|
||
}
|
||
|
||
// GetConversationByGroupID 通过群组ID获取会话
|
||
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
|
||
var conv model.Conversation
|
||
err := r.db.Where("group_id = ?", groupID).First(&conv).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &conv, nil
|
||
}
|
||
|
||
// RemoveParticipant 移除会话参与者
|
||
// 当用户退出群聊时,需要同时移除其在对应会话中的参与者记录
|
||
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
|
||
return r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||
Delete(&model.ConversationParticipant{}).Error
|
||
}
|
||
|
||
// AddParticipant 添加会话参与者
|
||
// 当用户加入群聊时,需要同时将其添加到对应会话的参与者记录
|
||
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
|
||
// 先检查是否已经是参与者
|
||
var count int64
|
||
err := r.db.Model(&model.ConversationParticipant{}).
|
||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||
Count(&count).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 如果已经是参与者,直接返回
|
||
if count > 0 {
|
||
return nil
|
||
}
|
||
|
||
// 添加参与者
|
||
participant := model.ConversationParticipant{
|
||
ConversationID: conversationID,
|
||
UserID: userID,
|
||
LastReadSeq: 0,
|
||
}
|
||
return r.db.Create(&participant).Error
|
||
}
|
||
|
||
// DeleteConversationByGroupID 删除群组对应的会话及其参与者
|
||
// 当解散群组时调用
|
||
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
|
||
// 获取群组对应的会话
|
||
conv, err := r.GetConversationByGroupID(groupID)
|
||
if err != nil {
|
||
// 如果会话不存在,直接返回
|
||
if err == gorm.ErrRecordNotFound {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
// 删除会话参与者
|
||
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error; err != nil {
|
||
return err
|
||
}
|
||
// 删除会话中的消息
|
||
if err := tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error; err != nil {
|
||
return err
|
||
}
|
||
// 删除会话
|
||
if err := tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error; err != nil {
|
||
return err
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// HideConversationForUser 仅对当前用户隐藏会话(私聊删除)
|
||
func (r *MessageRepository) HideConversationForUser(conversationID, userID string) error {
|
||
now := time.Now()
|
||
return r.db.Model(&model.ConversationParticipant{}).
|
||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||
Update("hidden_at", &now).Error
|
||
}
|