Files
backend/internal/cache/conversation_cache.go

725 lines
22 KiB
Go
Raw Normal View History

package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
)
// Cache entry wrappers and the per-message payload written to the message hash.
type (
	// CachedConversation wraps a conversation together with cache metadata.
	CachedConversation struct {
		Data      *model.Conversation // the cached conversation
		Version   int64               // version number, reserved for CAS-style updates
		UpdatedAt time.Time           // when the entry was last written
		AccessAt  time.Time           // last access time, refreshed on hits (sliding TTL)
	}

	// CachedParticipant wraps a conversation participant together with cache metadata.
	CachedParticipant struct {
		Data      *model.ConversationParticipant // the cached participant
		Version   int64                          // version number, reserved for CAS-style updates
		UpdatedAt time.Time                      // when the entry was last written
		AccessAt  time.Time                      // last access time, refreshed on hits (sliding TTL)
	}

	// CachedMessage wraps a message together with its sequence number and creation time.
	CachedMessage struct {
		Data      *model.Message `json:"data"`       // message payload
		Seq       int64          `json:"seq"`        // message sequence number
		CreatedAt time.Time      `json:"created_at"` // creation time
	}

	// MessageCacheData is the JSON document stored for one message in a
	// conversation's message hash (one hash field per message seq).
	// Segments and ExtraData are kept as raw JSON and decoded on read.
	MessageCacheData struct {
		ID             string          `json:"id"`
		ConversationID string          `json:"conversation_id"`
		SenderID       string          `json:"sender_id"`
		Seq            int64           `json:"seq"`
		Segments       json.RawMessage `json:"segments"`
		ReplyToID      *string         `json:"reply_to_id,omitempty"`
		Status         string          `json:"status"`
		Category       string          `json:"category"`
		SystemType     string          `json:"system_type,omitempty"`
		ExtraData      json.RawMessage `json:"extra_data,omitempty"`
		MentionUsers   string          `json:"mention_users"`
		MentionAll     bool            `json:"mention_all"`
		CreatedAt      time.Time       `json:"created_at"`
		UpdatedAt      time.Time       `json:"updated_at"`
	}

	// PageCache is the cached description of one page of a conversation's
	// messages: only the seqs are stored, message bodies live in the hash.
	PageCache struct {
		Seqs      []int64   `json:"seqs"`       // seqs of the messages on this page
		Total     int64     `json:"total"`      // total number of messages
		Page      int       `json:"page"`       // page number
		PageSize  int       `json:"page_size"`  // page size
		HasMore   bool      `json:"has_more"`   // whether more pages follow
		UpdatedAt time.Time `json:"updated_at"` // last write / refresh time
	}
)
// ConversationCacheSettings holds the TTLs used by ConversationCache.
// The values in parentheses are those set by DefaultConversationCacheSettings.
type ConversationCacheSettings struct {
	DetailTTL      time.Duration // conversation detail (5min)
	ListTTL        time.Duration // conversation list pages (60s)
	ParticipantTTL time.Duration // participant entries (5min)
	UnreadTTL      time.Duration // unread counts (30s)
	// Message cache settings.
	MessageDetailTTL time.Duration // message detail hash (30min)
	MessageListTTL   time.Duration // message page cache (5min)
	MessageIndexTTL  time.Duration // message seq index (30min)
	MessageCountTTL  time.Duration // message count (30min)
}

// DefaultConversationCacheSettings returns a newly allocated settings value
// filled with the default TTLs.
func DefaultConversationCacheSettings() *ConversationCacheSettings {
	s := new(ConversationCacheSettings)

	// Conversation-level defaults.
	s.DetailTTL = 5 * time.Minute
	s.ListTTL = 60 * time.Second
	s.ParticipantTTL = 5 * time.Minute
	s.UnreadTTL = 30 * time.Second

	// Message-level defaults.
	s.MessageDetailTTL = 30 * time.Minute
	s.MessageListTTL = 5 * time.Minute
	s.MessageIndexTTL = 30 * time.Minute
	s.MessageCountTTL = 30 * time.Minute

	return s
}
// parseSegments decodes cached JSON into message segments. Decoding is
// best-effort: nil input or undecodable JSON yields nil instead of an error.
func parseSegments(data json.RawMessage) model.MessageSegments {
	var segments model.MessageSegments
	if data == nil || json.Unmarshal(data, &segments) != nil {
		return nil
	}
	return segments
}

// serializeSegments encodes message segments for caching. nil segments, or
// segments that fail to encode, yield a nil RawMessage.
func serializeSegments(segments model.MessageSegments) json.RawMessage {
	if segments == nil {
		return nil
	}
	if encoded, err := json.Marshal(segments); err == nil {
		return encoded
	}
	return nil
}
// ToModel rebuilds a model.Message from its cached representation.
// Segments and ExtraData are decoded best-effort (undecodable JSON becomes nil).
func (d *MessageCacheData) ToModel() *model.Message {
	segments := parseSegments(d.Segments)
	extra := parseExtraData(d.ExtraData)

	return &model.Message{
		ID:             d.ID,
		ConversationID: d.ConversationID,
		SenderID:       d.SenderID,
		Seq:            d.Seq,
		Segments:       segments,
		ReplyToID:      d.ReplyToID,
		Status:         model.MessageStatus(d.Status),
		Category:       model.MessageCategory(d.Category),
		SystemType:     model.SystemMessageType(d.SystemType),
		ExtraData:      extra,
		MentionUsers:   d.MentionUsers,
		MentionAll:     d.MentionAll,
		CreatedAt:      d.CreatedAt,
		UpdatedAt:      d.UpdatedAt,
	}
}

// MessageCacheDataFromModel converts msg into the form stored in the message
// hash. Segments and ExtraData are encoded to raw JSON (nil when absent or
// when encoding fails). msg must be non-nil.
func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData {
	out := &MessageCacheData{
		ID:             msg.ID,
		ConversationID: msg.ConversationID,
		SenderID:       msg.SenderID,
		Seq:            msg.Seq,
		ReplyToID:      msg.ReplyToID,
		Status:         string(msg.Status),
		Category:       string(msg.Category),
		SystemType:     string(msg.SystemType),
		MentionUsers:   msg.MentionUsers,
		MentionAll:     msg.MentionAll,
		CreatedAt:      msg.CreatedAt,
		UpdatedAt:      msg.UpdatedAt,
	}
	out.Segments = serializeSegments(msg.Segments)
	out.ExtraData = serializeExtraData(msg.ExtraData)
	return out
}
// parseExtraData decodes cached JSON into a message's extra data.
// nil input or undecodable JSON yields nil.
func parseExtraData(data json.RawMessage) *model.ExtraData {
	if data == nil {
		return nil
	}
	extra := new(model.ExtraData)
	if err := json.Unmarshal(data, extra); err != nil {
		return nil
	}
	return extra
}

// serializeExtraData encodes a message's extra data for caching.
// nil input, or a value that fails to encode, yields a nil RawMessage.
func serializeExtraData(extraData *model.ExtraData) json.RawMessage {
	if extraData == nil {
		return nil
	}
	if encoded, err := json.Marshal(extraData); err == nil {
		return encoded
	}
	return nil
}
// ============================================================
// Cache key prefixes and key builders
// ============================================================

// Key prefixes for conversation-level cache entries.
const (
	keyPrefixConv         = "conv"           // conversation detail
	keyPrefixConvPart     = "conv_part"      // participant list of a conversation
	keyPrefixConvPartUser = "conv_part_user" // one user's participant record
)

// ConversationKey returns the cache key of a conversation's detail entry:
// "conv:<convID>".
func ConversationKey(convID string) string {
	// All operands are strings, so Sprint inserts no separators.
	return fmt.Sprint(keyPrefixConv, ":", convID)
}

// ParticipantListKey returns the cache key of a conversation's participant
// list: "conv_part:<convID>".
func ParticipantListKey(convID string) string {
	return fmt.Sprint(keyPrefixConvPart, ":", convID)
}

// ParticipantKey returns the cache key of one user's participant record in a
// conversation: "conv_part_user:<convID>:<userID>".
func ParticipantKey(convID, userID string) string {
	return fmt.Sprint(keyPrefixConvPartUser, ":", convID, ":", userID)
}
// ============================================================
// Repository interfaces
// ============================================================

// Data sources injected into ConversationCache and consulted on cache misses.
type (
	// ConversationRepository loads conversations, participants and unread
	// counts (injected dependency).
	ConversationRepository interface {
		GetConversationByID(convID string) (*model.Conversation, error)
		GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error)
		GetParticipant(convID, userID string) (*model.ConversationParticipant, error)
		GetParticipants(convID string) ([]*model.ConversationParticipant, error)
		GetUnreadCount(userID, convID string) (int64, error)
	}

	// MessageRepository loads messages by page or relative to a seq
	// (injected dependency).
	MessageRepository interface {
		GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error)
		GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error)
		GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error)
	}
)
// ============================================================
// ConversationCache core
// ============================================================

// ConversationCache is a cache-aside layer for conversations, participants,
// unread counts and messages.
type ConversationCache struct {
	cache    Cache                      // underlying cache store
	settings *ConversationCacheSettings // TTL configuration
	repo     ConversationRepository     // source for conversation-level cache misses
	msgRepo  MessageRepository          // source for message cache misses
}

// NewConversationCache builds a ConversationCache. A nil settings value is
// replaced by DefaultConversationCacheSettings(). repo and msgRepo may be nil,
// in which case cache misses that need them return a "not configured" error.
func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache {
	cc := &ConversationCache{
		cache:    cache,
		repo:     repo,
		msgRepo:  msgRepo,
		settings: settings,
	}
	if cc.settings == nil {
		cc.settings = DefaultConversationCacheSettings()
	}
	return cc
}
// GetConversation returns conversation convID, reading through the cache.
//
// On a hit with a non-nil payload, AccessAt is refreshed and the entry is
// stored again with DetailTTL to extend its lifetime (sliding TTL). On a miss
// the conversation is loaded from the repository and cached with Version 0.
// ctx is currently unused.
//
// NOTE(review): the hit path mutates the cached value in place; if Cache
// hands out shared pointers, concurrent hits race on AccessAt — confirm.
func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) {
	key := ConversationKey(convID)

	if entry, found := GetTyped[*CachedConversation](c.cache, key); found && entry != nil && entry.Data != nil {
		entry.AccessAt = time.Now()
		c.cache.Set(key, entry, c.settings.DetailTTL)
		return entry.Data, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	conv, err := c.repo.GetConversationByID(convID)
	if err != nil {
		return nil, err
	}

	loadedAt := time.Now()
	c.cache.Set(key, &CachedConversation{
		Data:      conv,
		Version:   0,
		UpdatedAt: loadedAt,
		AccessAt:  loadedAt,
	}, c.settings.DetailTTL)
	return conv, nil
}
// CachedConversationList is one cached page of a user's conversation list
// together with cache metadata.
type CachedConversationList struct {
	Conversations []*model.Conversation `json:"conversations"` // conversations on this page
	Total         int64                 `json:"total"`         // total reported by the repository
	Page          int                   `json:"page"`          // page this entry was loaded for
	PageSize      int                   `json:"page_size"`     // page size this entry was loaded for
	UpdatedAt     time.Time             `json:"updated_at"`    // when the entry was written
	AccessAt      time.Time             `json:"access_at"`     // last cache hit
}

// GetConversationList returns one page of userID's conversations plus the
// total count, reading through the cache.
//
// Any non-nil cached entry counts as a hit (an empty page included): its
// AccessAt is refreshed and it is stored again with ListTTL (sliding TTL).
// On a miss the page is loaded from the repository and cached. ctx is
// currently unused.
//
// NOTE(review): the hit path mutates the cached value in place; if Cache
// hands out shared pointers, concurrent hits race on AccessAt — confirm.
func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	key := ConversationListKey(userID, page, pageSize)

	hit, ok := GetTyped[*CachedConversationList](c.cache, key)
	switch {
	case ok && hit != nil:
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.ListTTL)
		return hit.Conversations, hit.Total, nil
	case c.repo == nil:
		return nil, 0, fmt.Errorf("repository not configured")
	}

	convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	stamp := time.Now()
	entry := &CachedConversationList{
		Conversations: convs,
		Total:         total,
		Page:          page,
		PageSize:      pageSize,
		UpdatedAt:     stamp,
		AccessAt:      stamp,
	}
	c.cache.Set(key, entry, c.settings.ListTTL)
	return convs, total, nil
}
// GetParticipant returns userID's participant record in conversation convID,
// reading through the cache.
//
// A hit with a non-nil payload refreshes AccessAt and is stored again with
// ParticipantTTL (sliding TTL). A miss loads the record from the repository
// and caches it with Version 0. ctx is currently unused.
//
// NOTE(review): the hit path mutates the cached value in place; if Cache
// hands out shared pointers, concurrent hits race on AccessAt — confirm.
func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) {
	key := ParticipantKey(convID, userID)

	cached, hit := GetTyped[*CachedParticipant](c.cache, key)
	hit = hit && cached != nil && cached.Data != nil
	if hit {
		cached.AccessAt = time.Now()
		c.cache.Set(key, cached, c.settings.ParticipantTTL)
		return cached.Data, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	p, err := c.repo.GetParticipant(convID, userID)
	if err != nil {
		return nil, err
	}

	entry := &CachedParticipant{Data: p, Version: 0}
	entry.UpdatedAt = time.Now()
	entry.AccessAt = entry.UpdatedAt
	c.cache.Set(key, entry, c.settings.ParticipantTTL)
	return p, nil
}
// CachedParticipantList is a cached participant list together with cache
// metadata.
type CachedParticipantList struct {
	Participants []*model.ConversationParticipant `json:"participants"` // all participants of the conversation
	UpdatedAt    time.Time                        `json:"updated_at"`   // when the entry was written
	AccessAt     time.Time                        `json:"access_at"`    // last cache hit
}

// GetParticipants returns every participant of conversation convID, reading
// through the cache.
//
// Any non-nil cached entry is a hit: AccessAt is refreshed and the entry is
// stored again with ParticipantTTL (sliding TTL). A miss loads the list from
// the repository and caches it. ctx is currently unused.
//
// NOTE(review): the hit path mutates the cached value in place; if Cache
// hands out shared pointers, concurrent hits race on AccessAt — confirm.
func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) {
	key := ParticipantListKey(convID)

	cached, ok := GetTyped[*CachedParticipantList](c.cache, key)
	if !ok || cached == nil {
		// Miss: load from the repository and populate the cache.
		if c.repo == nil {
			return nil, fmt.Errorf("repository not configured")
		}
		participants, err := c.repo.GetParticipants(convID)
		if err != nil {
			return nil, err
		}
		now := time.Now()
		c.cache.Set(key, &CachedParticipantList{
			Participants: participants,
			UpdatedAt:    now,
			AccessAt:     now,
		}, c.settings.ParticipantTTL)
		return participants, nil
	}

	// Hit: refresh the access time and extend the TTL.
	cached.AccessAt = time.Now()
	c.cache.Set(key, cached, c.settings.ParticipantTTL)
	return cached.Participants, nil
}
// CachedUnreadCount is a cached unread count together with cache metadata.
type CachedUnreadCount struct {
	Count     int64     `json:"count"`      // unread message count
	UpdatedAt time.Time `json:"updated_at"` // when the entry was written
	AccessAt  time.Time `json:"access_at"`  // last cache hit
}

// GetUnreadCount returns userID's unread count in conversation convID,
// reading through the cache.
//
// Any non-nil cached entry is a hit: AccessAt is refreshed and the entry is
// stored again with UnreadTTL (sliding TTL). A miss loads the count from the
// repository and caches it. ctx is currently unused.
//
// NOTE(review): the hit path mutates the cached value in place; if Cache
// hands out shared pointers, concurrent hits race on AccessAt — confirm.
func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) {
	key := UnreadDetailKey(userID, convID)

	if hit, ok := GetTyped[*CachedUnreadCount](c.cache, key); ok && hit != nil {
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.UnreadTTL)
		return hit.Count, nil
	}

	if c.repo == nil {
		return 0, fmt.Errorf("repository not configured")
	}
	count, err := c.repo.GetUnreadCount(userID, convID)
	if err != nil {
		return 0, err
	}

	entry := &CachedUnreadCount{Count: count}
	entry.UpdatedAt = time.Now()
	entry.AccessAt = entry.UpdatedAt
	c.cache.Set(key, entry, c.settings.UnreadTTL)
	return count, nil
}
// ============================================================
// Cache invalidation
// ============================================================

// InvalidateConversation drops the cached detail entry of convID.
func (c *ConversationCache) InvalidateConversation(convID string) {
	key := ConversationKey(convID)
	c.cache.Delete(key)
}

// InvalidateConversationList drops every cached conversation-list page of
// userID by deleting all keys under the user's list prefix.
// NOTE(review): presumably this prefix covers every
// ConversationListKey(userID, page, pageSize) — confirm against that helper.
func (c *ConversationCache) InvalidateConversationList(userID string) {
	prefix := fmt.Sprintf("%s:%s:", PrefixConversationList, userID)
	c.cache.DeleteByPrefix(prefix)
}

// InvalidateParticipant drops userID's cached participant record in convID.
func (c *ConversationCache) InvalidateParticipant(convID, userID string) {
	key := ParticipantKey(convID, userID)
	c.cache.Delete(key)
}

// InvalidateParticipantList drops the cached participant list of convID.
func (c *ConversationCache) InvalidateParticipantList(convID string) {
	key := ParticipantListKey(convID)
	c.cache.Delete(key)
}

// InvalidateUnreadCount drops userID's cached unread count for convID.
func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) {
	key := UnreadDetailKey(userID, convID)
	c.cache.Delete(key)
}
// ============================================================
// Message cache methods
// ============================================================

// GetMessages returns one page of convID's messages plus the total message
// count, using a cache-aside scheme:
//
//  1. Look up the page cache, which stores only the page's seqs and total.
//  2. On a hit, fetch the message bodies for those seqs from the message hash.
//     The hit is served only if every seq resolves to a message. The page
//     cache and the hash have independent TTLs, and bodies are backfilled
//     asynchronously, so the hash can be missing some or all of the page;
//     serving such a hit would return a truncated page as if it were complete.
//  3. Otherwise load the page from msgRepo, store the page cache
//     synchronously and backfill the message bodies in background goroutines.
//
// An error is returned only for repository failures or a missing msgRepo.
func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) {
	pageKey := MessagePageKey(convID, page, pageSize)

	if cached, ok := GetTyped[*PageCache](c.cache, pageKey); ok && cached != nil && len(cached.Seqs) > 0 {
		messages, err := c.getMessagesBySeqs(ctx, convID, cached.Seqs)
		if err == nil && len(messages) == len(cached.Seqs) {
			// Sliding TTL: re-store the page entry only when it was usable.
			cached.UpdatedAt = time.Now()
			c.cache.Set(pageKey, cached, c.settings.MessageListTTL)
			return messages, cached.Total, nil
		}
		// Hash read failed or was incomplete: reload the page below.
	}

	if c.msgRepo == nil {
		return nil, 0, fmt.Errorf("message repository not configured")
	}
	messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	seqs := make([]int64, len(messages))
	for i, msg := range messages {
		seqs[i] = msg.Seq
		// Backfill the body with a detached context so the write is not
		// cancelled together with the caller's request.
		go c.asyncCacheMessage(context.Background(), convID, msg)
	}

	c.cache.Set(pageKey, &PageCache{
		Seqs:     seqs,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
		// NOTE(review): assumes 1-based page numbers — confirm against callers.
		HasMore:   int64(page*pageSize) < total,
		UpdatedAt: time.Now(),
	}, c.settings.MessageListTTL)
	return messages, total, nil
}
// GetMessagesAfterSeq returns up to limit messages whose seq is greater than
// afterSeq (incremental sync), using the conversation's sorted-set index
// (ZRangeByScore) when possible.
//
// The index only contains messages that have been written to the cache, so
// it can have holes. A cached result is used only when its seqs form the
// gap-free run afterSeq+1, afterSeq+2, ... and every one of them has a body in
// the message hash. Otherwise, including on cache errors, the messages are
// loaded from the repository and backfilled asynchronously.
//
// NOTE(review): the gap check assumes seqs within a conversation are
// consecutive integers. If they are not, this method simply always falls
// back to the repository (correct, but uncached) — confirm against how seqs
// are allocated.
func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) {
	members, err := c.cache.ZRangeByScore(ctx, MessageIndexKey(convID), fmt.Sprintf("%d", afterSeq+1), "+inf", 0, int64(limit))
	if err != nil {
		// Cache failures degrade to the repository instead of failing the request.
		log.Printf("[ConversationCache] message index lookup failed, falling back to repository, convID=%s, err=%v", convID, err)
	} else if seqs, ok := seqRunFromMembers(members, afterSeq+1, 1); ok {
		if msgs, complete := c.messagesFromHashIfComplete(ctx, convID, seqs); complete {
			return msgs, nil
		}
	}

	if c.msgRepo == nil {
		return nil, fmt.Errorf("message repository not configured")
	}
	messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit)
	if err != nil {
		return nil, err
	}
	for _, msg := range messages {
		go c.asyncCacheMessage(context.Background(), convID, msg)
	}
	return messages, nil
}

// GetMessagesBeforeSeq returns up to limit messages whose seq is less than
// beforeSeq (loading older history), using the sorted-set index in
// descending order (ZRevRangeByScore) when possible.
//
// A cached result must be a gap-free run beforeSeq-1, beforeSeq-2, ... whose
// bodies are all present in the message hash. It must also be a full page of
// limit messages, because a short page may only mean that older messages were
// never cached, and callers may treat a short page as the start of history.
// Anything else, including cache errors and limit <= 0, goes to the
// repository, whose results are backfilled asynchronously.
//
// NOTE(review): like GetMessagesAfterSeq, the gap check assumes consecutive
// per-conversation seqs; non-consecutive seqs only disable the cache here.
func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	members, err := c.cache.ZRevRangeByScore(ctx, MessageIndexKey(convID), fmt.Sprintf("%d", beforeSeq-1), "-inf", 0, int64(limit))
	if err != nil {
		log.Printf("[ConversationCache] message index lookup failed, falling back to repository, convID=%s, err=%v", convID, err)
	} else if limit > 0 && len(members) >= limit {
		if seqs, ok := seqRunFromMembers(members, beforeSeq-1, -1); ok {
			if msgs, complete := c.messagesFromHashIfComplete(ctx, convID, seqs); complete {
				return msgs, nil
			}
		}
	}

	if c.msgRepo == nil {
		return nil, fmt.Errorf("message repository not configured")
	}
	messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
	if err != nil {
		return nil, err
	}
	for _, msg := range messages {
		go c.asyncCacheMessage(context.Background(), convID, msg)
	}
	return messages, nil
}

// seqRunFromMembers parses sorted-set members (decimal seqs) and reports
// whether they form the non-empty run first, first+step, first+2*step, ...
// An empty input, an unparseable member or any gap makes the result unusable.
func seqRunFromMembers(members []string, first, step int64) ([]int64, bool) {
	if len(members) == 0 {
		return nil, false
	}
	seqs := make([]int64, len(members))
	for i, member := range members {
		var seq int64
		if _, err := fmt.Sscanf(member, "%d", &seq); err != nil || seq != first+int64(i)*step {
			return nil, false
		}
		seqs[i] = seq
	}
	return seqs, true
}

// messagesFromHashIfComplete loads the cached bodies for seqs and reports
// whether every seq was found; a read error or any missing body reports false.
func (c *ConversationCache) messagesFromHashIfComplete(ctx context.Context, convID string, seqs []int64) ([]*model.Message, bool) {
	msgs, err := c.getMessagesBySeqs(ctx, convID, seqs)
	if err != nil || len(msgs) != len(seqs) {
		return nil, false
	}
	return msgs, true
}
// CacheMessage writes a newly created message into the cache immediately:
// the body goes into the conversation's message hash (field = seq), the seq
// goes into the sorted-set index (score = seq), both keys get their TTLs
// refreshed, and the conversation's message counter is incremented.
//
// Call it only for messages that are new to the conversation: every call
// increments the counter. Backfilling messages that were read from the
// repository goes through asyncCacheMessage, which leaves the counter alone.
func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error {
	if err := c.storeMessage(ctx, convID, msg); err != nil {
		return err
	}
	// Best effort: the counter result is not checked.
	// NOTE(review): MessageCountTTL is never applied to the counter key here,
	// and with Redis INCR semantics a missing key starts at 1 rather than the
	// real total — confirm how the count key is seeded and expired.
	c.cache.Incr(ctx, MessageCountKey(convID))
	return nil
}

// storeMessage writes msg into the message hash and the seq index and
// refreshes both keys' TTLs (TTL refresh is best effort). It does not touch
// the message counter, so repeating it for the same message only rewrites
// the same hash field and index member.
func (c *ConversationCache) storeMessage(ctx context.Context, convID string, msg *model.Message) error {
	hashKey := MessageHashKey(convID)
	indexKey := MessageIndexKey(convID)

	data, err := json.Marshal(MessageCacheDataFromModel(msg))
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}
	field := fmt.Sprintf("%d", msg.Seq)

	// Write the body before the index entry so an indexed seq normally has a body.
	if err := c.cache.HSet(ctx, hashKey, field, string(data)); err != nil {
		return fmt.Errorf("failed to set hash: %w", err)
	}
	if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), field); err != nil {
		return fmt.Errorf("failed to add to sorted set: %w", err)
	}

	c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL)
	c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL)
	return nil
}

// InvalidateMessageCache drops every message-related cache entry of convID:
// the message hash, the seq index, the message counter and all page caches.
func (c *ConversationCache) InvalidateMessageCache(convID string) {
	c.cache.Delete(MessageHashKey(convID))
	c.cache.Delete(MessageIndexKey(convID))
	c.cache.Delete(MessageCountKey(convID))
	c.InvalidateMessagePages(convID)
}

// InvalidateMessagePages drops only the message page caches of convID.
// A new message changes both page contents and totals, so every cached page
// of the conversation is cleared.
func (c *ConversationCache) InvalidateMessagePages(convID string) {
	c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID))
}

// ============================================================
// Internal helpers
// ============================================================

// getMessagesBySeqs fetches the cached bodies for seqs from the message hash
// with one HMGet call. Seqs with no hash field, or whose value is not a
// decodable string/[]byte, are skipped, so the result can be shorter than
// seqs; callers that need a complete set must compare lengths. An empty seqs
// returns (nil, nil) without touching the cache.
func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) {
	if len(seqs) == 0 {
		return nil, nil
	}

	fields := make([]string, len(seqs))
	for i, seq := range seqs {
		fields[i] = fmt.Sprintf("%d", seq)
	}
	values, err := c.cache.HMGet(ctx, MessageHashKey(convID), fields...)
	if err != nil {
		return nil, err
	}

	messages := make([]*model.Message, 0, len(seqs))
	for _, val := range values {
		var raw []byte
		switch v := val.(type) {
		case string:
			raw = []byte(v)
		case []byte:
			raw = v
		default:
			continue // nil (field missing) or an unexpected type
		}
		var msgData MessageCacheData // fresh per entry so no fields leak between messages
		if err := json.Unmarshal(raw, &msgData); err != nil {
			continue
		}
		messages = append(messages, msgData.ToModel())
	}
	return messages, nil
}

// asyncCacheMessage backfills one message loaded from the repository into
// the message hash and seq index. It is meant to run in its own goroutine,
// so failures are logged rather than returned. Unlike CacheMessage it does
// not increment the message counter: the message already exists, and the
// same message is backfilled again every time a read misses the cache.
func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) {
	if err := c.storeMessage(ctx, convID, msg); err != nil {
		log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err)
	}
}