2026-03-09 21:28:58 +08:00
package repository
import (
"carrot_bbs/internal/model"
2026-03-12 08:38:14 +08:00
"context"
"fmt"
"strings"
2026-03-09 21:28:58 +08:00
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// MessageRepository is the persistence layer for conversations, conversation
// participants and messages. Every query is issued through the wrapped
// *gorm.DB handle.
type MessageRepository struct {
	db *gorm.DB
}

// NewMessageRepository builds a MessageRepository whose queries run through db.
func NewMessageRepository(db *gorm.DB) *MessageRepository {
	repo := MessageRepository{db: db}
	return &repo
}
// CreateMessage inserts msg exactly as given. No sequence number is assigned
// here; msg.Seq is whatever the caller already set.
func (r *MessageRepository) CreateMessage(msg *model.Message) error {
	result := r.db.Create(msg)
	return result.Error
}
// GetConversation loads the conversation whose id equals id, with its Group
// association preloaded. On any error (including no matching row) it returns
// a nil conversation and the error from GORM.
func (r *MessageRepository) GetConversation(id string) (*model.Conversation, error) {
	conv := new(model.Conversation)
	if err := r.db.Preload("Group").First(conv, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return conv, nil
}
// GetOrCreatePrivateConversation returns the private conversation shared by
// user1ID and user2ID, creating it when none exists.
//
// Membership is tracked in conversation_participants. An existing conversation
// is one of type ConversationTypePrivate that has a participant row for each
// of the two users. When one is found, hidden_at is cleared for both users, so
// a conversation either of them had hidden ("deleted for me") shows up again.
// When none is found, the conversation and both participant rows are created
// in one transaction. On failure the returned conversation is nil.
//
// NOTE(review): two concurrent calls for the same pair can both miss the
// lookup and each create a conversation; nothing here prevents that. Also, if
// user1ID == user2ID the lookup matches any private conversation of that user
// and creation would insert the same participant twice — confirm callers
// reject self-chats.
//
// User IDs are strings (UUID format), matching user_id in the JWT.
func (r *MessageRepository) GetOrCreatePrivateConversation(user1ID, user2ID string) (*model.Conversation, error) {
	var conv model.Conversation

	// Look for a private conversation in which both users participate.
	err := r.db.Table("conversations c").
		Joins("INNER JOIN conversation_participants cp1 ON c.id = cp1.conversation_id AND cp1.user_id = ?", user1ID).
		Joins("INNER JOIN conversation_participants cp2 ON c.id = cp2.conversation_id AND cp2.user_id = ?", user2ID).
		Where("c.type = ?", model.ConversationTypePrivate).
		First(&conv).Error
	switch {
	case err == nil:
		// Best effort: un-hiding only affects visibility, so a failure here
		// deliberately does not stop the caller from getting the conversation.
		_ = r.db.Model(&model.ConversationParticipant{}).
			Where("conversation_id = ? AND user_id IN ?", conv.ID, []string{user1ID, user2ID}).
			Update("hidden_at", nil).Error
		return &conv, nil
	case err != gorm.ErrRecordNotFound:
		return nil, err
	}

	// Not found: create the conversation and both participant rows atomically.
	conv = model.Conversation{
		Type: model.ConversationTypePrivate,
	}
	err = r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(&conv).Error; err != nil {
			return err
		}
		participants := []model.ConversationParticipant{
			{ConversationID: conv.ID, UserID: user1ID},
			{ConversationID: conv.ID, UserID: user2ID},
		}
		return tx.Create(&participants).Error
	})
	if err != nil {
		// The transaction was rolled back; don't hand out a conversation
		// value that does not exist in the database.
		return nil, err
	}
	return &conv, nil
}
// GetConversations returns one page of the conversations userID participates
// in and has not hidden, plus the total number of such participant rows.
//
// Ordering is per user: conversations the user pinned come first, then by the
// conversation's updated_at (newest first). conversations.id is used as a
// final tie-breaker so that pages are stable when updated_at values collide.
// page is 1-based. When the total is 0 no list query is run and a nil slice is
// returned. A failing count query is now reported instead of being ignored.
//
// userID is a string (UUID format), matching user_id in the JWT.
func (r *MessageRepository) GetConversations(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	var convs []*model.Conversation
	var total int64

	// Total number of visible (non-hidden) conversations for this user.
	if err := r.db.Model(&model.ConversationParticipant{}).
		Where("user_id = ? AND hidden_at IS NULL", userID).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if total == 0 {
		return convs, total, nil
	}

	offset := (page - 1) * pageSize
	err := r.db.Model(&model.Conversation{}).
		Joins("INNER JOIN conversation_participants cp ON conversations.id = cp.conversation_id").
		Where("cp.user_id = ? AND cp.hidden_at IS NULL", userID).
		Preload("Group").
		Offset(offset).
		Limit(pageSize).
		Order("cp.is_pinned DESC").
		Order("conversations.updated_at DESC").
		Order("conversations.id DESC").
		Find(&convs).Error
	return convs, total, err
}
// GetMessages returns one page of conversationID's messages, newest first
// (seq DESC), together with the total number of messages in the conversation.
// page is 1-based. An error from the count query is now returned instead of
// being silently ignored.
func (r *MessageRepository) GetMessages(conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
	var messages []*model.Message
	var total int64

	if err := r.db.Model(&model.Message{}).
		Where("conversation_id = ?", conversationID).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := r.db.Where("conversation_id = ?", conversationID).
		Offset(offset).
		Limit(pageSize).
		Order("seq DESC").
		Find(&messages).Error
	return messages, total, err
}
// GetMessagesAfterSeq returns at most limit messages of conversationID whose
// seq is strictly greater than afterSeq, in ascending seq order. It is used
// for incremental sync.
func (r *MessageRepository) GetMessagesAfterSeq(conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
	var result []*model.Message
	query := r.db.
		Where("conversation_id = ? AND seq > ?", conversationID, afterSeq).
		Order("seq ASC").
		Limit(limit)
	err := query.Find(&result).Error
	return result, err
}
// GetMessagesBeforeSeq returns up to limit of the newest messages of
// conversationID whose seq is strictly below beforeSeq, in ascending seq
// order. It is used when the client scrolls up to load older history.
func (r *MessageRepository) GetMessagesBeforeSeq(conversationID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	var page []*model.Message
	// Fetch newest-first so LIMIT keeps the messages closest to beforeSeq...
	err := r.db.
		Where("conversation_id = ? AND seq < ?", conversationID, beforeSeq).
		Order("seq DESC").
		Limit(limit).
		Find(&page).Error

	// ...then flip the slice in place so the caller gets ascending seq.
	n := len(page)
	for i := 0; i < n/2; i++ {
		page[i], page[n-1-i] = page[n-1-i], page[i]
	}
	return page, err
}
// GetConversationParticipants lists every participant row of conversationID.
func (r *MessageRepository) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
	var list []*model.ConversationParticipant
	if err := r.db.Where("conversation_id = ?", conversationID).Find(&list).Error; err != nil {
		return list, err
	}
	return list, nil
}
// GetParticipant 获取用户在会话中的参与者信息
// userID 参数为 string 类型( UUID格式) , 与JWT中user_id保持一致
func ( r * MessageRepository ) GetParticipant ( conversationID string , userID string ) ( * model . ConversationParticipant , error ) {
var participant model . ConversationParticipant
err := r . db . Where ( "conversation_id = ? AND user_id = ?" , conversationID , userID ) . First ( & participant ) . Error
if err != nil {
// 如果找不到参与者,尝试添加(修复没有参与者记录的问题)
if err == gorm . ErrRecordNotFound {
// 检查会话是否存在
var conv model . Conversation
2026-03-12 08:38:14 +08:00
if err := r . db . Where ( "id = ?" , conversationID ) . First ( & conv ) . Error ; err == nil {
2026-03-09 21:28:58 +08:00
// 会话存在,添加参与者
participant = model . ConversationParticipant {
ConversationID : conversationID ,
UserID : userID ,
}
if err := r . db . Create ( & participant ) . Error ; err != nil {
return nil , err
}
return & participant , nil
}
}
return nil , err
}
return & participant , nil
}
// UpdateLastReadSeq sets userID's read position (last_read_seq) in
// conversationID to lastReadSeq.
//
// It first tries a plain UPDATE. If that touches no row, the participant row
// is inserted with an upsert on (conversation_id, user_id), which works across
// database dialects. If the row appeared in the meantime, the upsert sets
// last_read_seq and updated_at on it instead.
//
// userID is a string (UUID format), matching user_id in the JWT.
func (r *MessageRepository) UpdateLastReadSeq(conversationID string, userID string, lastReadSeq int64) error {
	res := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Update("last_read_seq", lastReadSeq)
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected != 0:
		return nil
	}

	upsert := clause.OnConflict{
		Columns: []clause.Column{{Name: "conversation_id"}, {Name: "user_id"}},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"last_read_seq": lastReadSeq,
			"updated_at":    gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	row := model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		LastReadSeq:    lastReadSeq,
	}
	return r.db.Clauses(upsert).Create(&row).Error
}
// UpdatePinned sets userID's personal pinned flag for conversationID.
//
// A plain UPDATE is tried first. If it matches no row, the participant row is
// upserted on (conversation_id, user_id) with IsPinned set.
func (r *MessageRepository) UpdatePinned(conversationID string, userID string, isPinned bool) error {
	res := r.db.Model(&model.ConversationParticipant{}).
		Where("conversation_id = ? AND user_id = ?", conversationID, userID).
		Update("is_pinned", isPinned)
	if res.Error != nil || res.RowsAffected != 0 {
		return res.Error
	}

	onConflict := clause.OnConflict{
		Columns: []clause.Column{{Name: "conversation_id"}, {Name: "user_id"}},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"is_pinned":  isPinned,
			"updated_at": gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	newRow := &model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		IsPinned:       isPinned,
	}
	return r.db.Clauses(onConflict).Create(newRow).Error
}
// GetUnreadCount counts messages in conversationID that were sent by someone
// other than userID and have a seq above userID's last_read_seq. It fails if
// userID has no participant row in the conversation.
//
// userID is a string (UUID format), matching user_id in the JWT.
func (r *MessageRepository) GetUnreadCount(conversationID string, userID string) (int64, error) {
	var p model.ConversationParticipant
	if err := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID).First(&p).Error; err != nil {
		return 0, err
	}

	var unread int64
	q := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND sender_id != ? AND seq > ?", conversationID, userID, p.LastReadSeq)
	if err := q.Count(&unread).Error; err != nil {
		return unread, err
	}
	return unread, nil
}
// UpdateConversationLastSeq stores seq as the conversation's last_seq and
// stamps last_msg_time with the database's CURRENT_TIMESTAMP.
func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq int64) error {
	fields := map[string]interface{}{
		"last_seq":      seq,
		"last_msg_time": gorm.Expr("CURRENT_TIMESTAMP"),
	}
	scope := r.db.Model(&model.Conversation{}).Where("id = ?", conversationID)
	return scope.Updates(fields).Error
}
// GetNextSeq 获取会话的下一个seq值
func ( r * MessageRepository ) GetNextSeq ( conversationID string ) ( int64 , error ) {
var conv model . Conversation
2026-03-12 08:38:14 +08:00
err := r . db . Select ( "last_seq" ) . Where ( "id = ?" , conversationID ) . First ( & conv ) . Error
2026-03-09 21:28:58 +08:00
if err != nil {
return 0 , err
}
return conv . LastSeq + 1 , nil
}
// CreateMessageWithSeq 创建消息并更新seq( 事务操作)
func ( r * MessageRepository ) CreateMessageWithSeq ( msg * model . Message ) error {
return r . db . Transaction ( func ( tx * gorm . DB ) error {
// 获取当前seq并+1
var conv model . Conversation
2026-03-12 08:38:14 +08:00
if err := tx . Select ( "last_seq" ) . Where ( "id = ?" , msg . ConversationID ) . First ( & conv ) . Error ; err != nil {
2026-03-09 21:28:58 +08:00
return err
}
msg . Seq = conv . LastSeq + 1
// 创建消息
if err := tx . Create ( msg ) . Error ; err != nil {
return err
}
// 更新会话的last_seq
if err := tx . Model ( & model . Conversation { } ) .
Where ( "id = ?" , msg . ConversationID ) .
Updates ( map [ string ] interface { } {
"last_seq" : msg . Seq ,
"last_msg_time" : gorm . Expr ( "CURRENT_TIMESTAMP" ) ,
} ) . Error ; err != nil {
return err
}
// 新消息到达后,自动恢复被“仅自己删除”的会话
if err := tx . Model ( & model . ConversationParticipant { } ) .
Where ( "conversation_id = ?" , msg . ConversationID ) .
Update ( "hidden_at" , nil ) . Error ; err != nil {
return err
}
return nil
} )
}
// GetAllUnreadCount totals unread messages across every conversation userID
// has a participant row in. A message counts as unread when it was sent by
// someone else, has a seq above that participant's last_read_seq, and is not
// soft-deleted (m.deleted_at IS NULL).
//
// NOTE(review): hidden conversations (hidden_at set) are not excluded by this
// query — confirm that is intended.
//
// userID is a string (UUID format), matching user_id in the JWT.
func (r *MessageRepository) GetAllUnreadCount(userID string) (int64, error) {
	const unreadJoin = "LEFT JOIN messages AS m ON m.conversation_id = cp.conversation_id AND m.sender_id <> ? AND m.seq > cp.last_read_seq AND m.deleted_at IS NULL"

	var total int64
	q := r.db.Table("conversation_participants AS cp").
		Joins(unreadJoin, userID).
		Where("cp.user_id = ?", userID).
		Select("COALESCE(COUNT(m.id), 0)")
	if err := q.Scan(&total).Error; err != nil {
		return total, err
	}
	return total, nil
}
// GetMessageByID loads the message whose id equals messageID. On any error
// (including no matching row) it returns nil and the error from GORM.
func (r *MessageRepository) GetMessageByID(messageID string) (*model.Message, error) {
	msg := &model.Message{}
	if err := r.db.First(msg, "id = ?", messageID).Error; err != nil {
		return nil, err
	}
	return msg, nil
}
// CountMessagesBySenderInConversation returns how many messages senderID has
// sent in conversationID.
func (r *MessageRepository) CountMessagesBySenderInConversation(conversationID, senderID string) (int64, error) {
	var n int64
	q := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND sender_id = ?", conversationID, senderID)
	if err := q.Count(&n).Error; err != nil {
		return n, err
	}
	return n, nil
}
// UpdateMessageStatus sets the status column of the message with the given id.
// Matching zero rows is not reported as an error.
//
// NOTE(review): messageID is int64 here while GetMessageByID takes a string
// id; if message ids are UUID strings this WHERE clause can never match —
// confirm the column type.
func (r *MessageRepository) UpdateMessageStatus(messageID int64, status model.MessageStatus) error {
	target := r.db.Model(&model.Message{}).Where("id = ?", messageID)
	return target.Update("status", status).Error
}
// GetOrCreateSystemParticipant returns userID's participant row in the system
// conversation (model.SystemConversationID), creating it with LastReadSeq 0 if
// it does not exist yet.
//
// The system conversation is virtual, but a participant row is still needed
// to track how far the user has read. Creation uses INSERT ... ON CONFLICT DO
// NOTHING on (conversation_id, user_id). Concurrent first calls for the same
// user therefore no longer fail with a duplicate-key error; the losing caller
// re-reads the stored row.
func (r *MessageRepository) GetOrCreateSystemParticipant(userID string) (*model.ConversationParticipant, error) {
	var participant model.ConversationParticipant
	err := r.db.Where("conversation_id = ? AND user_id = ?",
		model.SystemConversationID, userID).First(&participant).Error
	if err == nil {
		return &participant, nil
	}
	if err != gorm.ErrRecordNotFound {
		return nil, err
	}

	// Not found: create the participant row.
	participant = model.ConversationParticipant{
		ConversationID: model.SystemConversationID,
		UserID:         userID,
		LastReadSeq:    0,
	}
	res := r.db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "conversation_id"}, {Name: "user_id"}},
		DoNothing: true,
	}).Create(&participant)
	if res.Error != nil {
		return nil, res.Error
	}
	if res.RowsAffected == 0 {
		// Someone else inserted the row between our lookup and our insert;
		// return the stored row rather than our unsaved copy.
		var stored model.ConversationParticipant
		if err := r.db.Where("conversation_id = ? AND user_id = ?",
			model.SystemConversationID, userID).First(&stored).Error; err != nil {
			return nil, err
		}
		return &stored, nil
	}
	return &participant, nil
}
// GetSystemMessagesUnreadCount counts system-conversation messages whose seq
// is above userID's last_read_seq. The user's system participant row is
// created first if it does not exist.
func (r *MessageRepository) GetSystemMessagesUnreadCount(userID string) (int64, error) {
	p, err := r.GetOrCreateSystemParticipant(userID)
	if err != nil {
		return 0, err
	}

	var unread int64
	q := r.db.Model(&model.Message{}).
		Where("conversation_id = ? AND seq > ?", model.SystemConversationID, p.LastReadSeq)
	if countErr := q.Count(&unread).Error; countErr != nil {
		return unread, countErr
	}
	return unread, nil
}
// MarkAllSystemMessagesAsRead moves userID's read position in the system
// conversation to the highest existing system-message seq (0 if there are
// none). The participant row is created or updated with an upsert on
// (conversation_id, user_id), which works across database dialects.
func (r *MessageRepository) MarkAllSystemMessagesAsRead(userID string) error {
	var latest int64
	maxQuery := r.db.Model(&model.Message{}).
		Where("conversation_id = ?", model.SystemConversationID).
		Select("COALESCE(MAX(seq), 0)")
	if err := maxQuery.Scan(&latest).Error; err != nil {
		return err
	}

	upsert := clause.OnConflict{
		Columns: []clause.Column{{Name: "conversation_id"}, {Name: "user_id"}},
		DoUpdates: clause.Assignments(map[string]interface{}{
			"last_read_seq": latest,
			"updated_at":    gorm.Expr("CURRENT_TIMESTAMP"),
		}),
	}
	return r.db.Clauses(upsert).Create(&model.ConversationParticipant{
		ConversationID: model.SystemConversationID,
		UserID:         userID,
		LastReadSeq:    latest,
	}).Error
}
// GetConversationByGroupID returns the conversation whose group_id equals
// groupID. On any error (including no matching row) it returns nil and the
// error from GORM.
func (r *MessageRepository) GetConversationByGroupID(groupID string) (*model.Conversation, error) {
	conv := new(model.Conversation)
	if err := r.db.Where("group_id = ?", groupID).First(conv).Error; err != nil {
		return nil, err
	}
	return conv, nil
}
// RemoveParticipant deletes userID's participant row from conversationID.
// It is called when a user leaves a group chat, so that the user is also
// removed from the group's conversation.
func (r *MessageRepository) RemoveParticipant(conversationID string, userID string) error {
	cond := r.db.Where("conversation_id = ? AND user_id = ?", conversationID, userID)
	return cond.Delete(&model.ConversationParticipant{}).Error
}
// AddParticipant ensures userID has a participant row in conversationID. It is
// called when a user joins a group chat. An existing row is left untouched; a
// new row starts with LastReadSeq 0.
//
// This is a single INSERT ... ON CONFLICT DO NOTHING on (conversation_id,
// user_id). It replaces the old COUNT-then-INSERT, which let two concurrent
// joins for the same user race and fail with a duplicate-key error. It relies
// on a unique constraint over (conversation_id, user_id).
//
// NOTE(review): assumes ConversationParticipant is not soft-deleted. If it
// were, a soft-deleted row would make this a silent no-op.
func (r *MessageRepository) AddParticipant(conversationID string, userID string) error {
	participant := model.ConversationParticipant{
		ConversationID: conversationID,
		UserID:         userID,
		LastReadSeq:    0,
	}
	return r.db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "conversation_id"}, {Name: "user_id"}},
		DoNothing: true,
	}).Create(&participant).Error
}
// DeleteConversationByGroupID removes the conversation belonging to groupID,
// together with its participants and messages, in one transaction. It is
// called when a group is dissolved. If the group has no conversation, this is
// a no-op.
//
// NOTE(review): if the models carry a GORM DeletedAt field, these deletes are
// soft deletes — confirm against the model definitions.
func (r *MessageRepository) DeleteConversationByGroupID(groupID string) error {
	conv, err := r.GetConversationByGroupID(groupID)
	switch {
	case err == gorm.ErrRecordNotFound:
		return nil
	case err != nil:
		return err
	}

	return r.db.Transaction(func(tx *gorm.DB) error {
		// Children first (participants, messages), then the conversation row.
		steps := []func() error{
			func() error {
				return tx.Where("conversation_id = ?", conv.ID).Delete(&model.ConversationParticipant{}).Error
			},
			func() error {
				return tx.Where("conversation_id = ?", conv.ID).Delete(&model.Message{}).Error
			},
			func() error {
				return tx.Delete(&model.Conversation{}, "id = ?", conv.ID).Error
			},
		}
		for _, step := range steps {
			if stepErr := step(); stepErr != nil {
				return stepErr
			}
		}
		return nil
	})
}
// HideConversationForUser 仅对当前用户隐藏会话(私聊删除)
func ( r * MessageRepository ) HideConversationForUser ( conversationID , userID string ) error {
now := time . Now ( )
return r . db . Model ( & model . ConversationParticipant { } ) .
Where ( "conversation_id = ? AND user_id = ?" , conversationID , userID ) .
Update ( "hidden_at" , & now ) . Error
}
2026-03-12 08:38:14 +08:00
// ParticipantUpdate is one entry for the batch participant updaters: it sets
// last_read_seq to LastReadSeq on the conversation_participants row identified
// by (ConversationID, UserID).
type ParticipantUpdate struct {
	// ConversationID selects the participant row, together with UserID.
	ConversationID string
	// UserID is the participant's user id (a string; UUID format elsewhere in this repository).
	UserID string
	// LastReadSeq is the value written to last_read_seq.
	LastReadSeq int64
}
// BatchWriteMessages inserts messages using GORM's CreateInBatches, 100 rows
// per INSERT statement, bound to ctx. An empty slice is a no-op.
func (r *MessageRepository) BatchWriteMessages(ctx context.Context, messages []*model.Message) error {
	const batchSize = 100
	if len(messages) > 0 {
		return r.db.WithContext(ctx).CreateInBatches(messages, batchSize).Error
	}
	return nil
}
// BatchUpdateParticipants sets last_read_seq for many (conversation_id,
// user_id) pairs with a single UPDATE instead of one UPDATE per row:
//
//	UPDATE conversation_participants
//	SET last_read_seq = CASE
//	        WHEN (conversation_id = ? AND user_id = ?) THEN ?
//	        ...
//	    END,
//	    updated_at = ?
//	WHERE (conversation_id = ? AND user_id = ?) OR ...
//
// Bind arguments must follow placeholder order: all CASE triples, then
// updated_at, then all WHERE pairs. The previous version appended them
// interleaved per update. That bound updated_at to a conversation id and
// shifted every WHERE argument.
//
// When the same pair appears more than once, the last entry wins, which
// matches applying the updates one by one. Pairs without an existing
// participant row are skipped; nothing is inserted. An empty slice is a no-op.
func (r *MessageRepository) BatchUpdateParticipants(ctx context.Context, updates []ParticipantUpdate) error {
	if len(updates) == 0 {
		return nil
	}

	cases := make([]string, 0, len(updates))
	whereClauses := make([]string, 0, len(updates))
	caseArgs := make([]interface{}, 0, 3*len(updates))
	whereArgs := make([]interface{}, 0, 2*len(updates))
	seen := make(map[[2]string]struct{}, len(updates))

	// Walk backwards so that, for duplicated pairs, the last entry is kept.
	for i := len(updates) - 1; i >= 0; i-- {
		u := updates[i]
		key := [2]string{u.ConversationID, u.UserID}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}

		cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?")
		caseArgs = append(caseArgs, u.ConversationID, u.UserID, u.LastReadSeq)
		whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)")
		whereArgs = append(whereArgs, u.ConversationID, u.UserID)
	}

	query := fmt.Sprintf(
		"UPDATE conversation_participants SET last_read_seq = CASE %s END, updated_at = ? WHERE %s",
		strings.Join(cases, " "), strings.Join(whereClauses, " OR "),
	)

	// Arguments in placeholder order: CASE triples, updated_at, WHERE pairs.
	args := make([]interface{}, 0, len(caseArgs)+1+len(whereArgs))
	args = append(args, caseArgs...)
	args = append(args, time.Now())
	args = append(args, whereArgs...)

	return r.db.WithContext(ctx).Exec(query, args...).Error
}
// UpdateConversationLastSeqWithContext writes last_seq, last_msg_time and
// updated_at (now) on conversation convID, bound to ctx.
//
// NOTE(review): last_seq is overwritten unconditionally, so an out-of-order
// call can move it backwards — confirm callers only pass increasing values.
func (r *MessageRepository) UpdateConversationLastSeqWithContext(ctx context.Context, convID string, lastSeq int64, lastMsgTime time.Time) error {
	fields := map[string]interface{}{
		"last_seq":      lastSeq,
		"last_msg_time": lastMsgTime,
		"updated_at":    time.Now(),
	}
	db := r.db.WithContext(ctx)
	return db.Model(&model.Conversation{}).Where("id = ?", convID).Updates(fields).Error
}
// BatchWriteMessagesWithTx inserts messages through the caller's transaction
// tx, 100 rows per INSERT statement. An empty slice is a no-op.
func (r *MessageRepository) BatchWriteMessagesWithTx(tx *gorm.DB, messages []*model.Message) error {
	const batchSize = 100
	if len(messages) > 0 {
		return tx.CreateInBatches(messages, batchSize).Error
	}
	return nil
}
// BatchUpdateParticipantsWithTx sets last_read_seq for many (conversation_id,
// user_id) pairs inside the caller's transaction tx, using one
// UPDATE ... SET last_read_seq = CASE ... END, updated_at = ? WHERE ... OR ...
// statement.
//
// Bind arguments must follow placeholder order: all CASE triples, then
// updated_at, then all WHERE pairs. The previous version appended them
// interleaved per update. That bound updated_at to a conversation id and
// shifted every WHERE argument.
//
// When the same pair appears more than once, the last entry wins, which
// matches applying the updates one by one. Pairs without an existing
// participant row are skipped; nothing is inserted. An empty slice is a no-op.
func (r *MessageRepository) BatchUpdateParticipantsWithTx(tx *gorm.DB, updates []ParticipantUpdate) error {
	if len(updates) == 0 {
		return nil
	}

	cases := make([]string, 0, len(updates))
	whereClauses := make([]string, 0, len(updates))
	caseArgs := make([]interface{}, 0, 3*len(updates))
	whereArgs := make([]interface{}, 0, 2*len(updates))
	seen := make(map[[2]string]struct{}, len(updates))

	// Walk backwards so that, for duplicated pairs, the last entry is kept.
	for i := len(updates) - 1; i >= 0; i-- {
		u := updates[i]
		key := [2]string{u.ConversationID, u.UserID}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}

		cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?")
		caseArgs = append(caseArgs, u.ConversationID, u.UserID, u.LastReadSeq)
		whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)")
		whereArgs = append(whereArgs, u.ConversationID, u.UserID)
	}

	query := fmt.Sprintf(
		"UPDATE conversation_participants SET last_read_seq = CASE %s END, updated_at = ? WHERE %s",
		strings.Join(cases, " "), strings.Join(whereClauses, " OR "),
	)

	// Arguments in placeholder order: CASE triples, updated_at, WHERE pairs.
	args := make([]interface{}, 0, len(caseArgs)+1+len(whereArgs))
	args = append(args, caseArgs...)
	args = append(args, time.Now())
	args = append(args, whereArgs...)

	return tx.Exec(query, args...).Error
}
// UpdateConversationLastSeqWithTx writes last_seq, last_msg_time and
// updated_at (now) on conversation convID through the caller's transaction tx.
//
// NOTE(review): last_seq is overwritten unconditionally, so an out-of-order
// call can move it backwards — confirm callers only pass increasing values.
func (r *MessageRepository) UpdateConversationLastSeqWithTx(tx *gorm.DB, convID string, lastSeq int64, lastMsgTime time.Time) error {
	fields := map[string]interface{}{
		"last_seq":      lastSeq,
		"last_msg_time": lastMsgTime,
		"updated_at":    time.Now(),
	}
	scope := tx.Model(&model.Conversation{}).Where("id = ?", convID)
	return scope.Updates(fields).Error
}