package service

import (
	"context"
	"log"
	"time"

	"carrot_bbs/internal/cache"
	"carrot_bbs/internal/model"
	"carrot_bbs/internal/repository"

	"gorm.io/gorm"
)

// Cache TTL defaults. Several of these are used in this file as fallbacks when
// the corresponding value returned by cache.GetSettings() is not positive.
const (
	ConversationListTTL   = 60 * time.Second // conversation list cache lifetime
	ConversationDetailTTL = 60 * time.Second // conversation detail cache lifetime
	UnreadCountTTL        = 30 * time.Second // unread count cache lifetime
	ConversationNullTTL   = 5 * time.Second  // null TTL for conversation-list lookups
	UnreadNullTTL         = 5 * time.Second  // null TTL for unread-count lookups
	CacheJitterRatio      = 0.1              // default jitter ratio for cached entries
)

// MessageService provides private-message operations on top of a
// MessageRepository. Reads prefer conversationCache when it is non-nil and
// otherwise fall back to baseCache or directly to the repository.
type MessageService struct {
	// db is the database handle passed to NewMessageService.
	db *gorm.DB

	// messageRepo is the underlying message/conversation repository.
	messageRepo *repository.MessageRepository

	// conversationCache is the conversation-level cache. It is treated as
	// optional: every use below is guarded by a nil check.
	conversationCache *cache.ConversationCache

	// baseCache is the generic cache used for simple key/value caching.
	baseCache cache.Cache
}

// NewMessageService creates a MessageService backed by messageRepo.
//
// It wraps messageRepo in the cache package's repository adapters and builds a
// cache.ConversationCache on top of the shared cache returned by
// cache.GetCache(), using cache.DefaultConversationCacheSettings(). The same
// shared cache is also kept as baseCache.
func NewMessageService(db *gorm.DB, messageRepo *repository.MessageRepository) *MessageService {
	sharedCache := cache.GetCache()

	// Adapters expose messageRepo to the conversation cache (presumably so it
	// can load conversations/messages on cache misses — confirm in package cache).
	convRepoAdapter := cache.NewConversationRepositoryAdapter(messageRepo)
	msgRepoAdapter := cache.NewMessageRepositoryAdapter(messageRepo)

	conversationCache := cache.NewConversationCache(
		sharedCache,
		convRepoAdapter,
		msgRepoAdapter,
		cache.DefaultConversationCacheSettings(),
	)

	return &MessageService{
		db:                db,
		messageRepo:       messageRepo,
		conversationCache: conversationCache,
		baseCache:         sharedCache,
	}
}

// ConversationListResult is the value cached for a page of a user's
// conversation list.
type ConversationListResult struct {
	Conversations []*model.Conversation
	Total         int64
}

// SendMessage stores a new message built from segments in the private
// conversation between senderID and receiverID, creating that conversation if
// necessary. senderID and receiverID are UUID strings, consistent with the
// user_id carried in the JWT.
//
// After the message is persisted:
//   - cached message pages of the conversation are invalidated synchronously;
//   - the message is written to the conversation cache in a background
//     goroutine (failures are only logged);
//   - both users' conversation lists and the receiver's unread caches are
//     invalidated.
//
// It returns the persisted message, or the repository error.
func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID string, segments model.MessageSegments) (*model.Message, error) {
	conv, err := s.messageRepo.GetOrCreatePrivateConversation(senderID, receiverID)
	if err != nil {
		return nil, err
	}

	msg := &model.Message{
		ConversationID: conv.ID,
		SenderID:       senderID,
		Segments:       segments,
		Status:         model.MessageStatusNormal,
	}

	// Creates the message and updates the conversation seq in a transaction.
	if err := s.messageRepo.CreateMessageWithSeq(msg); err != nil {
		return nil, err
	}

	// A new message changes pagination results, so drop cached pages first to
	// avoid serving a stale list.
	if s.conversationCache != nil {
		s.conversationCache.InvalidateMessagePages(conv.ID)
	}

	// Populate the message cache asynchronously. context.Background() is used
	// instead of ctx, so the write does not depend on the caller's context.
	go func() {
		if err := s.cacheMessage(context.Background(), conv.ID, msg); err != nil {
			log.Printf("[MessageService] async cache message failed, convID=%s, msgID=%s, err=%v", conv.ID, msg.ID, err)
		}
	}()

	// The receiver now has unread content, and both users' conversation lists
	// have changed.
	cache.InvalidateUnreadConversation(s.baseCache, receiverID)
	if s.conversationCache != nil {
		s.conversationCache.InvalidateConversationList(senderID)
		s.conversationCache.InvalidateConversationList(receiverID)
		s.conversationCache.InvalidateUnreadCount(receiverID, conv.ID)
	}

	return msg, nil
}

// cacheMessage writes msg into the conversation cache for convID, bounded by a
// 5-second timeout derived from ctx. It is a no-op when conversationCache is nil.
func (s *MessageService) cacheMessage(ctx context.Context, convID string, msg *model.Message) error {
	if s.conversationCache == nil {
		return nil
	}

	asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	return s.conversationCache.CacheMessage(asyncCtx, convID, msg)
}

// GetConversations returns one page of userID's conversations and the total
// count. userID is a UUID string, consistent with the user_id in the JWT.
//
// It prefers conversationCache. Without it, it falls back to
// cache.GetOrLoadTyped on baseCache, loading from the repository via the
// supplied loader. TTL, null TTL and jitter come from cache.GetSettings(), and
// the package defaults are used for non-positive values. A nil cached result is
// returned as an empty, non-nil slice with total 0.
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	if s.conversationCache != nil {
		return s.conversationCache.GetConversationList(ctx, userID, page, pageSize)
	}

	cacheSettings := cache.GetSettings()

	conversationTTL := cacheSettings.ConversationTTL
	if conversationTTL <= 0 {
		conversationTTL = ConversationListTTL
	}
	nullTTL := cacheSettings.NullTTL
	if nullTTL <= 0 {
		nullTTL = ConversationNullTTL
	}
	jitter := cacheSettings.JitterRatio
	if jitter <= 0 {
		jitter = CacheJitterRatio
	}

	cacheKey := cache.ConversationListKey(userID, page, pageSize)

	result, err := cache.GetOrLoadTyped[*ConversationListResult](
		s.baseCache,
		cacheKey,
		conversationTTL,
		jitter,
		nullTTL,
		func() (*ConversationListResult, error) {
			conversations, total, err := s.messageRepo.GetConversations(userID, page, pageSize)
			if err != nil {
				return nil, err
			}
			return &ConversationListResult{
				Conversations: conversations,
				Total:         total,
			}, nil
		},
	)
	if err != nil {
		return nil, 0, err
	}
	if result == nil {
		return []*model.Conversation{}, 0, nil
	}
	return result.Conversations, result.Total, nil
}

// GetMessages returns one page of messages in conversationID plus the total
// count. It prefers conversationCache and otherwise queries the repository
// directly.
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
	if s.conversationCache != nil {
		return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize)
	}
	return s.messageRepo.GetMessages(conversationID, page, pageSize)
}

// GetMessagesAfterSeq returns up to limit messages in conversationID whose seq
// follows afterSeq (incremental sync). It always reads from the repository.
func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID string, afterSeq int64, limit int) ([]*model.Message, error) {
	return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}

// MarkAsRead records lastReadSeq as userID's read position in conversationID
// (cache-aside). userID is a UUID string, consistent with the user_id in the JWT.
//
// The database is written first. Only on success are the dependent caches
// invalidated, so the next read reloads from the database.
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
	if err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq); err != nil {
		return err
	}

	if s.conversationCache != nil {
		s.conversationCache.InvalidateParticipant(conversationID, userID)
		s.conversationCache.InvalidateUnreadCount(userID, conversationID)
		s.conversationCache.InvalidateConversationList(userID)
	}
	cache.InvalidateUnreadConversation(s.baseCache, userID)

	return nil
}

// GetUnreadCount returns userID's unread message count in conversationID.
// userID is a UUID string, consistent with the user_id in the JWT.
//
// It prefers conversationCache. Without it, it falls back to
// cache.GetOrLoadTyped on baseCache keyed by cache.UnreadDetailKey. Non-positive
// settings from cache.GetSettings() are replaced by the package defaults.
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
	if s.conversationCache != nil {
		// Note: the conversation cache takes (userID, conversationID).
		return s.conversationCache.GetUnreadCount(ctx, userID, conversationID)
	}

	cacheSettings := cache.GetSettings()

	unreadTTL := cacheSettings.UnreadCountTTL
	if unreadTTL <= 0 {
		unreadTTL = UnreadCountTTL
	}
	nullTTL := cacheSettings.NullTTL
	if nullTTL <= 0 {
		nullTTL = UnreadNullTTL
	}
	jitter := cacheSettings.JitterRatio
	if jitter <= 0 {
		jitter = CacheJitterRatio
	}

	cacheKey := cache.UnreadDetailKey(userID, conversationID)

	return cache.GetOrLoadTyped[int64](
		s.baseCache,
		cacheKey,
		unreadTTL,
		jitter,
		nullTTL,
		func() (int64, error) {
			return s.messageRepo.GetUnreadCount(conversationID, userID)
		},
	)
}

// GetOrCreateConversation returns the private conversation between user1ID and
// user2ID, creating it if necessary. Both IDs are UUID strings, consistent with
// the user_id in the JWT. Both users' cached conversation lists are invalidated
// afterwards.
func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
	conv, err := s.messageRepo.GetOrCreatePrivateConversation(user1ID, user2ID)
	if err != nil {
		return nil, err
	}

	if s.conversationCache != nil {
		s.conversationCache.InvalidateConversationList(user1ID)
		s.conversationCache.InvalidateConversationList(user2ID)
	}

	return conv, nil
}

// GetConversationParticipants returns the participants of conversationID,
// preferring conversationCache when available. The cache lookup uses
// context.Background() because this method has no ctx parameter.
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
	if s.conversationCache != nil {
		return s.conversationCache.GetParticipants(context.Background(), conversationID)
	}
	return s.messageRepo.GetConversationParticipants(conversationID)
}

// ParseConversationID returns idStr unchanged. Conversation IDs are already
// strings; it never returns an error.
func ParseConversationID(idStr string) (string, error) {
	return idStr, nil
}

// InvalidateUserConversationCache invalidates userID's cached conversation list
// and unread-conversation entry. Intended for callers outside this service.
func (s *MessageService) InvalidateUserConversationCache(userID string) {
	if s.conversationCache != nil {
		s.conversationCache.InvalidateConversationList(userID)
	}
	cache.InvalidateUnreadConversation(s.baseCache, userID)
}

// InvalidateUserUnreadCache invalidates userID's unread-conversation entry and
// the unread count for conversationID. Intended for callers outside this service.
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
	cache.InvalidateUnreadConversation(s.baseCache, userID)
	if s.conversationCache != nil {
		s.conversationCache.InvalidateUnreadCount(userID, conversationID)
	}
}