package cache

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strconv"
	"time"

	"carrot_bbs/internal/model"
)

// CachedConversation is a conversation together with its cache metadata.
type CachedConversation struct {
	Data      *model.Conversation // the cached conversation
	Version   int64               // version number (intended for CAS updates)
	UpdatedAt time.Time           // when the entry was last written
	AccessAt  time.Time           // when the entry was last read (refreshed on every hit)
}

// CachedParticipant is a conversation participant together with its cache
// metadata.
type CachedParticipant struct {
	Data      *model.ConversationParticipant
	Version   int64
	UpdatedAt time.Time
	AccessAt  time.Time
}

// CachedMessage is a message together with its cache metadata.
type CachedMessage struct {
	Data      *model.Message `json:"data"`       // message payload
	Seq       int64          `json:"seq"`        // message sequence number
	CreatedAt time.Time      `json:"created_at"` // creation time
}

// MessageCacheData is the JSON representation of a message stored as a value
// in the per-conversation message hash (see CacheMessage).
type MessageCacheData struct {
	ID             string          `json:"id"`
	ConversationID string          `json:"conversation_id"`
	SenderID       string          `json:"sender_id"`
	Seq            int64           `json:"seq"`
	Segments       json.RawMessage `json:"segments"`
	ReplyToID      *string         `json:"reply_to_id,omitempty"`
	Status         string          `json:"status"`
	Category       string          `json:"category"`
	SystemType     string          `json:"system_type,omitempty"`
	ExtraData      json.RawMessage `json:"extra_data,omitempty"`
	MentionUsers   string          `json:"mention_users"`
	MentionAll     bool            `json:"mention_all"`
	CreatedAt      time.Time       `json:"created_at"`
	UpdatedAt      time.Time       `json:"updated_at"`
}

// PageCache is the cached description of one page of a message list. It holds
// only the sequence numbers; message bodies live in the message hash.
type PageCache struct {
	Seqs      []int64   `json:"seqs"`       // seqs of the messages on this page
	Total     int64     `json:"total"`      // total number of messages
	Page      int       `json:"page"`       // page number
	PageSize  int       `json:"page_size"`  // page size
	HasMore   bool      `json:"has_more"`   // whether more messages follow this page
	UpdatedAt time.Time `json:"updated_at"` // last update time
}

// ConversationCacheSettings holds the TTLs used by ConversationCache.
type ConversationCacheSettings struct {
	DetailTTL      time.Duration // conversation detail TTL (default 5min)
	ListTTL        time.Duration // conversation list TTL (default 60s)
	ParticipantTTL time.Duration // participant TTL (default 5min)
	UnreadTTL      time.Duration // unread count TTL (default 30s)

	// Message cache settings.
	MessageDetailTTL time.Duration // message detail hash TTL (default 30min)
	MessageListTTL   time.Duration // message page list TTL (default 5min)
	MessageIndexTTL  time.Duration // message index TTL (default 30min)
	MessageCountTTL  time.Duration // message counter TTL (default 30min)
}

// DefaultConversationCacheSettings returns the default TTL configuration.
func DefaultConversationCacheSettings() *ConversationCacheSettings {
	return &ConversationCacheSettings{
		DetailTTL:        5 * time.Minute,
		ListTTL:          60 * time.Second,
		ParticipantTTL:   5 * time.Minute,
		UnreadTTL:        30 * time.Second,
		MessageDetailTTL: 30 * time.Minute,
		MessageListTTL:   5 * time.Minute,
		MessageIndexTTL:  30 * time.Minute,
		MessageCountTTL:  30 * time.Minute,
	}
}

// isAbsentJSON reports whether data carries no value: it is empty or the JSON
// literal null.
func isAbsentJSON(data json.RawMessage) bool {
	return len(data) == 0 || string(data) == "null"
}

// parseSegments decodes raw JSON into model.MessageSegments. It returns nil
// when data is absent or cannot be decoded.
func parseSegments(data json.RawMessage) model.MessageSegments {
	if isAbsentJSON(data) {
		return nil
	}
	var segments model.MessageSegments
	if err := json.Unmarshal(data, &segments); err != nil {
		return nil
	}
	return segments
}

// serializeSegments encodes segments as raw JSON. It returns nil when segments
// is nil or cannot be encoded.
func serializeSegments(segments model.MessageSegments) json.RawMessage {
	if segments == nil {
		return nil
	}
	data, err := json.Marshal(segments)
	if err != nil {
		return nil
	}
	return data
}

// ToModel converts the cached representation back into a model.Message.
// A nil receiver yields nil.
func (m *MessageCacheData) ToModel() *model.Message {
	if m == nil {
		return nil
	}
	return &model.Message{
		ID:             m.ID,
		ConversationID: m.ConversationID,
		SenderID:       m.SenderID,
		Seq:            m.Seq,
		Segments:       parseSegments(m.Segments),
		ReplyToID:      m.ReplyToID,
		Status:         model.MessageStatus(m.Status),
		Category:       model.MessageCategory(m.Category),
		SystemType:     model.SystemMessageType(m.SystemType),
		ExtraData:      parseExtraData(m.ExtraData),
		MentionUsers:   m.MentionUsers,
		MentionAll:     m.MentionAll,
		CreatedAt:      m.CreatedAt,
		UpdatedAt:      m.UpdatedAt,
	}
}

// MessageCacheDataFromModel builds the cached representation of msg.
// A nil msg yields nil.
func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData {
	if msg == nil {
		return nil
	}
	return &MessageCacheData{
		ID:             msg.ID,
		ConversationID: msg.ConversationID,
		SenderID:       msg.SenderID,
		Seq:            msg.Seq,
		Segments:       serializeSegments(msg.Segments),
		ReplyToID:      msg.ReplyToID,
		Status:         string(msg.Status),
		Category:       string(msg.Category),
		SystemType:     string(msg.SystemType),
		ExtraData:      serializeExtraData(msg.ExtraData),
		MentionUsers:   msg.MentionUsers,
		MentionAll:     msg.MentionAll,
		CreatedAt:      msg.CreatedAt,
		UpdatedAt:      msg.UpdatedAt,
	}
}

// parseExtraData decodes raw JSON into *model.ExtraData. It returns nil when
// data is absent (empty or JSON null) or cannot be decoded, so that a missing
// value never turns into a non-nil zero value.
func parseExtraData(data json.RawMessage) *model.ExtraData {
	if isAbsentJSON(data) {
		return nil
	}
	var extraData model.ExtraData
	if err := json.Unmarshal(data, &extraData); err != nil {
		return nil
	}
	return &extraData
}

// serializeExtraData encodes extraData as raw JSON. It returns nil when
// extraData is nil or cannot be encoded.
func serializeExtraData(extraData *model.ExtraData) json.RawMessage {
	if extraData == nil {
		return nil
	}
	data, err := json.Marshal(extraData)
	if err != nil {
		return nil
	}
	return data
}

// ============================================================
// Cache key constants and builders
// ============================================================

const (
	keyPrefixConv         = "conv"           // conversation detail
	keyPrefixConvPart     = "conv_part"      // participant list
	keyPrefixConvPartUser = "conv_part_user" // single participant
)

// ConversationKey returns the cache key of a conversation's detail entry.
func ConversationKey(convID string) string {
	return keyPrefixConv + ":" + convID
}

// ParticipantListKey returns the cache key of a conversation's participant
// list.
func ParticipantListKey(convID string) string {
	return keyPrefixConvPart + ":" + convID
}

// ParticipantKey returns the cache key of one user's participant record in a
// conversation.
func ParticipantKey(convID, userID string) string {
	return keyPrefixConvPartUser + ":" + convID + ":" + userID
}

// ============================================================
// Repository interfaces
// ============================================================

// ConversationRepository is the data source ConversationCache falls back to on
// a cache miss (injected as a dependency).
type ConversationRepository interface {
	GetConversationByID(convID string) (*model.Conversation, error)
	GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error)
	GetParticipant(convID, userID string) (*model.ConversationParticipant, error)
	GetParticipants(convID string) ([]*model.ConversationParticipant, error)
	GetUnreadCount(userID, convID string) (int64, error)
}

// MessageRepository is the message data source ConversationCache falls back to
// on a cache miss (injected as a dependency).
type MessageRepository interface {
	GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error)
	GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error)
	GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error)
}

// ============================================================
// ConversationCache
// ============================================================

// ConversationCache is a cache-aside layer for conversations, participants,
// unread counts and messages.
type ConversationCache struct {
	cache    Cache                      // underlying cache
	settings *ConversationCacheSettings // TTL configuration
	repo     ConversationRepository     // source of truth for conversation data
	msgRepo  MessageRepository          // source of truth for message data
}

// NewConversationCache creates a ConversationCache. A nil settings selects
// DefaultConversationCacheSettings.
func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache {
	if settings == nil {
		settings = DefaultConversationCacheSettings()
	}
	return &ConversationCache{
		cache:    cache,
		settings: settings,
		repo:     repo,
		msgRepo:  msgRepo,
	}
}

// GetConversation returns the conversation with the given ID.
//
// On a cache hit the entry's AccessAt is refreshed and the entry is written
// back with a fresh DetailTTL (sliding expiration). On a miss the conversation
// is loaded from the repository and cached. It returns an error when no
// repository is configured or the repository call fails.
func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) {
	key := ConversationKey(convID)

	if cached, ok := GetTyped[*CachedConversation](c.cache, key); ok && cached != nil && cached.Data != nil {
		cached.AccessAt = time.Now()
		c.cache.Set(key, cached, c.settings.DetailTTL)
		return cached.Data, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	conv, err := c.repo.GetConversationByID(convID)
	if err != nil {
		return nil, err
	}

	now := time.Now()
	entry := &CachedConversation{
		Data:      conv,
		Version:   0,
		UpdatedAt: now,
		AccessAt:  now,
	}
	c.cache.Set(key, entry, c.settings.DetailTTL)
	return conv, nil
}

// CachedConversationList is one cached page of a user's conversation list
// together with its cache metadata.
type CachedConversationList struct {
	Conversations []*model.Conversation `json:"conversations"`
	Total         int64                 `json:"total"`
	Page          int                   `json:"page"`
	PageSize      int                   `json:"page_size"`
	UpdatedAt     time.Time             `json:"updated_at"`
	AccessAt      time.Time             `json:"access_at"`
}

// GetConversationList returns one page of the user's conversations and the
// total count.
//
// Hits refresh AccessAt and slide the entry's ListTTL; misses load from the
// repository and cache the result.
func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	key := ConversationListKey(userID, page, pageSize)

	if cached, ok := GetTyped[*CachedConversationList](c.cache, key); ok && cached != nil {
		cached.AccessAt = time.Now()
		c.cache.Set(key, cached, c.settings.ListTTL)
		return cached.Conversations, cached.Total, nil
	}

	if c.repo == nil {
		return nil, 0, fmt.Errorf("repository not configured")
	}
	convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	now := time.Now()
	entry := &CachedConversationList{
		Conversations: convs,
		Total:         total,
		Page:          page,
		PageSize:      pageSize,
		UpdatedAt:     now,
		AccessAt:      now,
	}
	c.cache.Set(key, entry, c.settings.ListTTL)
	return convs, total, nil
}

// GetParticipant returns the participant record of userID in convID.
//
// Hits refresh AccessAt and slide the entry's ParticipantTTL; misses load from
// the repository and cache the result.
func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) {
	key := ParticipantKey(convID, userID)

	if cached, ok := GetTyped[*CachedParticipant](c.cache, key); ok && cached != nil && cached.Data != nil {
		cached.AccessAt = time.Now()
		c.cache.Set(key, cached, c.settings.ParticipantTTL)
		return cached.Data, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	participant, err := c.repo.GetParticipant(convID, userID)
	if err != nil {
		return nil, err
	}

	now := time.Now()
	entry := &CachedParticipant{
		Data:      participant,
		Version:   0,
		UpdatedAt: now,
		AccessAt:  now,
	}
	c.cache.Set(key, entry, c.settings.ParticipantTTL)
	return participant, nil
}

// CachedParticipantList is a cached participant list together with its cache
// metadata.
type CachedParticipantList struct {
	Participants []*model.ConversationParticipant `json:"participants"`
	UpdatedAt    time.Time                        `json:"updated_at"`
	AccessAt     time.Time                        `json:"access_at"`
}

// GetParticipants returns all participants of a conversation.
//
// Hits refresh AccessAt and slide the entry's ParticipantTTL; misses load from
// the repository and cache the result.
func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) {
	key := ParticipantListKey(convID)

	if cached, ok := GetTyped[*CachedParticipantList](c.cache, key); ok && cached != nil {
		cached.AccessAt = time.Now()
		c.cache.Set(key, cached, c.settings.ParticipantTTL)
		return cached.Participants, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	participants, err := c.repo.GetParticipants(convID)
	if err != nil {
		return nil, err
	}

	now := time.Now()
	entry := &CachedParticipantList{
		Participants: participants,
		UpdatedAt:    now,
		AccessAt:     now,
	}
	c.cache.Set(key, entry, c.settings.ParticipantTTL)
	return participants, nil
}

// CachedUnreadCount is a cached unread counter together with its cache
// metadata.
type CachedUnreadCount struct {
	Count     int64     `json:"count"`
	UpdatedAt time.Time `json:"updated_at"`
	AccessAt  time.Time `json:"access_at"`
}

// GetUnreadCount returns the number of unread messages userID has in convID.
//
// Hits refresh AccessAt and slide the entry's UnreadTTL; misses load from the
// repository and cache the result.
func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) {
	key := UnreadDetailKey(userID, convID)

	if cached, ok := GetTyped[*CachedUnreadCount](c.cache, key); ok && cached != nil {
		cached.AccessAt = time.Now()
		c.cache.Set(key, cached, c.settings.UnreadTTL)
		return cached.Count, nil
	}

	if c.repo == nil {
		return 0, fmt.Errorf("repository not configured")
	}
	count, err := c.repo.GetUnreadCount(userID, convID)
	if err != nil {
		return 0, err
	}

	now := time.Now()
	entry := &CachedUnreadCount{
		Count:     count,
		UpdatedAt: now,
		AccessAt:  now,
	}
	c.cache.Set(key, entry, c.settings.UnreadTTL)
	return count, nil
}

// ============================================================
// Invalidation
// ============================================================

// InvalidateConversation removes the cached detail entry of a conversation.
func (c *ConversationCache) InvalidateConversation(convID string) {
	c.cache.Delete(ConversationKey(convID))
}

// InvalidateConversationList removes every cached conversation-list page of a
// user.
func (c *ConversationCache) InvalidateConversationList(userID string) {
	c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", PrefixConversationList, userID))
}

// InvalidateParticipant removes the cached participant record of userID in
// convID.
func (c *ConversationCache) InvalidateParticipant(convID, userID string) {
	c.cache.Delete(ParticipantKey(convID, userID))
}

// InvalidateParticipantList removes the cached participant list of a
// conversation.
func (c *ConversationCache) InvalidateParticipantList(convID string) {
	c.cache.Delete(ParticipantListKey(convID))
}

// InvalidateUnreadCount removes the cached unread counter of userID in convID.
func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) {
	c.cache.Delete(UnreadDetailKey(userID, convID))
}

// ============================================================
// Message cache
// ============================================================

// GetMessages returns one page of a conversation's messages and the total
// message count.
//
// The page cache stores only sequence numbers; message bodies are read from
// the message hash. The page entry is written synchronously while bodies are
// written in the background and expire independently, so a page hit is only
// served from cache when every referenced body is present. Otherwise the page
// is reloaded from the repository and both caches are refreshed.
func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) {
	pageKey := MessagePageKey(convID, page, pageSize)

	if cached, ok := GetTyped[*PageCache](c.cache, pageKey); ok && cached != nil {
		// Sliding expiration for the page entry.
		cached.UpdatedAt = time.Now()
		c.cache.Set(pageKey, cached, c.settings.MessageListTTL)

		// An empty cached page falls through to the repository, as before.
		if len(cached.Seqs) > 0 {
			messages, err := c.getMessagesBySeqs(ctx, convID, cached.Seqs)
			// getMessagesBySeqs silently skips missing bodies; a short result
			// would be a truncated page, so treat it as a miss.
			if err == nil && len(messages) == len(cached.Seqs) {
				return messages, cached.Total, nil
			}
		}
	}

	if c.msgRepo == nil {
		return nil, 0, fmt.Errorf("message repository not configured")
	}
	messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	seqs := make([]int64, len(messages))
	for i, msg := range messages {
		seqs[i] = msg.Seq
	}
	c.cacheMessagesAsync(convID, messages)

	pageCache := &PageCache{
		Seqs:     seqs,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
		// Widen before multiplying so large page numbers cannot overflow int.
		HasMore:   int64(page)*int64(pageSize) < total,
		UpdatedAt: time.Now(),
	}
	c.cache.Set(pageKey, pageCache, c.settings.MessageListTTL)
	return messages, total, nil
}

// GetMessagesAfterSeq returns up to limit messages with a sequence number
// greater than afterSeq (incremental sync), in ascending index order.
//
// The message index is consulted first. The result is served from cache only
// when the index lookup succeeds and every indexed message body is readable;
// any cache failure or partial read falls back to the repository.
//
// NOTE(review): the index only contains messages that were cached at some
// point, so a cache hit may skip messages that exist in the repository but
// were never indexed — confirm that writers keep the index contiguous.
func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) {
	members, cacheErr := c.cache.ZRangeByScore(ctx, MessageIndexKey(convID), strconv.FormatInt(afterSeq+1, 10), "+inf", 0, int64(limit))
	if cacheErr == nil {
		if messages, ok := c.messagesFromIndexMembers(ctx, convID, members); ok {
			return messages, nil
		}
	}

	if c.msgRepo == nil {
		if cacheErr != nil {
			return nil, cacheErr
		}
		return nil, fmt.Errorf("message repository not configured")
	}
	messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit)
	if err != nil {
		return nil, err
	}

	c.cacheMessagesAsync(convID, messages)
	return messages, nil
}

// GetMessagesBeforeSeq returns up to limit messages with a sequence number
// lower than beforeSeq (history paging), in descending index order.
//
// The message index is consulted first. The result is served from cache only
// when the index lookup succeeds and every indexed message body is readable;
// any cache failure or partial read falls back to the repository.
//
// NOTE(review): the index only contains messages that were cached at some
// point, so a cache hit may skip messages that exist in the repository but
// were never indexed — confirm that writers keep the index contiguous.
func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	members, cacheErr := c.cache.ZRevRangeByScore(ctx, MessageIndexKey(convID), strconv.FormatInt(beforeSeq-1, 10), "-inf", 0, int64(limit))
	if cacheErr == nil {
		if messages, ok := c.messagesFromIndexMembers(ctx, convID, members); ok {
			return messages, nil
		}
	}

	if c.msgRepo == nil {
		if cacheErr != nil {
			return nil, cacheErr
		}
		return nil, fmt.Errorf("message repository not configured")
	}
	messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
	if err != nil {
		return nil, err
	}

	c.cacheMessagesAsync(convID, messages)
	return messages, nil
}

// CacheMessage writes a single message to the cache synchronously: the body
// goes into the conversation's message hash, the seq into the message index,
// TTLs are refreshed and the message counter is incremented.
//
// It returns an error when msg is nil, when the message cannot be marshalled,
// or when the hash or index write fails.
//
// NOTE(review): the counter is incremented on every call, including when a
// message that is already cached is written again (for example after a
// repository reload), so it can drift above the real count — confirm how the
// counter is consumed.
func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error {
	if msg == nil {
		return fmt.Errorf("cannot cache nil message")
	}

	hashKey := MessageHashKey(convID)
	indexKey := MessageIndexKey(convID)
	countKey := MessageCountKey(convID)

	data, err := json.Marshal(MessageCacheDataFromModel(msg))
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	// The decimal seq is both the hash field and the index member.
	seqField := strconv.FormatInt(msg.Seq, 10)

	if err := c.cache.HSet(ctx, hashKey, seqField, string(data)); err != nil {
		return fmt.Errorf("failed to set hash: %w", err)
	}
	if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), seqField); err != nil {
		return fmt.Errorf("failed to add to sorted set: %w", err)
	}

	// Best-effort bookkeeping: results are intentionally not checked so that a
	// failure here does not undo a successful hash/index write.
	c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL)
	c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL)

	c.cache.Incr(ctx, countKey)
	// Bound the counter's lifetime with the configured MessageCountTTL so a
	// counter created by Incr does not outlive the data it counts.
	c.cache.Expire(ctx, countKey, c.settings.MessageCountTTL)

	return nil
}

// InvalidateMessageCache removes every message-related cache entry of a
// conversation: message hash, message index, counter and all cached pages.
func (c *ConversationCache) InvalidateMessageCache(convID string) {
	c.cache.Delete(MessageHashKey(convID))
	c.cache.Delete(MessageIndexKey(convID))
	c.cache.Delete(MessageCountKey(convID))
	c.InvalidateMessagePages(convID)
}

// InvalidateMessagePages removes only the cached message pages of a
// conversation. A new message changes both page contents and totals, so all
// pages of the conversation must be dropped.
func (c *ConversationCache) InvalidateMessagePages(convID string) {
	c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID))
}

// ============================================================
// Internal helpers
// ============================================================

// messagesFromIndexMembers resolves message-index members (decimal seqs) to
// messages read from the message hash. ok is false when there is nothing
// usable in the cache: no parsable members, a hash read error, or at least one
// missing or undecodable message body.
func (c *ConversationCache) messagesFromIndexMembers(ctx context.Context, convID string, members []string) (messages []*model.Message, ok bool) {
	seqs := make([]int64, 0, len(members))
	for _, member := range members {
		seq, err := strconv.ParseInt(member, 10, 64)
		if err != nil {
			continue
		}
		seqs = append(seqs, seq)
	}
	if len(seqs) == 0 {
		return nil, false
	}

	messages, err := c.getMessagesBySeqs(ctx, convID, seqs)
	if err != nil || len(messages) != len(seqs) {
		return nil, false
	}
	return messages, true
}

// getMessagesBySeqs reads the messages with the given seqs from the
// conversation's message hash in one batch, preserving the order of seqs.
//
// Fields that are missing, of an unexpected type or not decodable are skipped,
// so the result may be shorter than seqs; callers that need a complete result
// must compare lengths. An empty seqs yields (nil, nil).
func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) {
	if len(seqs) == 0 {
		return nil, nil
	}

	fields := make([]string, len(seqs))
	for i, seq := range seqs {
		fields[i] = strconv.FormatInt(seq, 10)
	}

	values, err := c.cache.HMGet(ctx, MessageHashKey(convID), fields...)
	if err != nil {
		return nil, err
	}

	messages := make([]*model.Message, 0, len(seqs))
	for _, val := range values {
		var raw []byte
		switch v := val.(type) {
		case string:
			raw = []byte(v)
		case []byte:
			raw = v
		default:
			// nil (field not present) or an unexpected value type.
			continue
		}

		var msgData MessageCacheData
		if err := json.Unmarshal(raw, &msgData); err != nil {
			continue
		}
		messages = append(messages, msgData.ToModel())
	}
	return messages, nil
}

// cacheMessagesAsync caches messages in a single background goroutine instead
// of one goroutine per message, which keeps the number of goroutines bounded
// per call. Failures are logged by asyncCacheMessage.
func (c *ConversationCache) cacheMessagesAsync(convID string, messages []*model.Message) {
	if len(messages) == 0 {
		return
	}
	// Copy the slice header's contents: the caller keeps using (and may
	// modify) the slice it receives.
	batch := append([]*model.Message(nil), messages...)
	go func() {
		// Detached from the request context on purpose: the write should
		// complete even after the request has finished.
		ctx := context.Background()
		for _, msg := range batch {
			c.asyncCacheMessage(ctx, convID, msg)
		}
	}()
}

// asyncCacheMessage caches a single message and logs, rather than returns, any
// failure. It is meant to run off the request path. A nil msg is ignored.
func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) {
	if msg == nil {
		return
	}
	if err := c.CacheMessage(ctx, convID, msg); err != nil {
		log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err)
	}
}