Files
backend/internal/cache/conversation_cache.go
lan 0a0cbacbcc feat(schedule): add course table screens and navigation
Add complete schedule functionality including:
- Schedule screen with weekly course table view
- Course detail screen with transparent modal presentation
- New ScheduleStack navigator integrated into main tab bar
- Schedule service for API interactions
- Type definitions for course entities

Also includes bug fixes for group invite/request handlers
to include required groupId parameter.
2026-03-12 08:38:14 +08:00

725 lines
22 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
)
// CachedConversation wraps a conversation together with cache metadata.
type CachedConversation struct {
	Data      *model.Conversation // the cached conversation itself
	Version   int64               // version number (for CAS-style updates); GetConversation always writes 0
	UpdatedAt time.Time           // when this entry was written from the repository
	AccessAt  time.Time           // last access time; refreshed on cache hits (sliding TTL)
}
// CachedParticipant wraps a single conversation participant together with
// cache metadata.
type CachedParticipant struct {
	Data      *model.ConversationParticipant // the cached participant record
	Version   int64                          // version number; GetParticipant always writes 0
	UpdatedAt time.Time                      // when this entry was written from the repository
	AccessAt  time.Time                      // last access time; refreshed on cache hits (sliding TTL)
}
// CachedMessage wraps a message together with cache metadata.
type CachedMessage struct {
	Data      *model.Message `json:"data"`       // message payload
	Seq       int64          `json:"seq"`        // message sequence number
	CreatedAt time.Time      `json:"created_at"` // creation time
}
// MessageCacheData is the JSON form of a message as stored in the
// per-conversation message Hash (one Hash field per message seq; see
// CacheMessage). Convert with MessageCacheDataFromModel / ToModel.
type MessageCacheData struct {
	ID             string          `json:"id"`
	ConversationID string          `json:"conversation_id"`
	SenderID       string          `json:"sender_id"`
	Seq            int64           `json:"seq"`
	Segments       json.RawMessage `json:"segments"` // JSON-encoded model.MessageSegments
	ReplyToID      *string         `json:"reply_to_id,omitempty"`
	Status         string          `json:"status"`   // string form of model.MessageStatus
	Category       string          `json:"category"` // string form of model.MessageCategory
	SystemType     string          `json:"system_type,omitempty"` // string form of model.SystemMessageType
	ExtraData      json.RawMessage `json:"extra_data,omitempty"` // JSON-encoded *model.ExtraData
	MentionUsers   string          `json:"mention_users"` // copied verbatim from model.Message.MentionUsers
	MentionAll     bool            `json:"mention_all"`
	CreatedAt      time.Time       `json:"created_at"`
	UpdatedAt      time.Time       `json:"updated_at"`
}
// PageCache is the cached form of one page of a conversation's messages.
// Only seqs are stored here; message bodies live in the message Hash.
type PageCache struct {
	Seqs      []int64   `json:"seqs"`       // seqs of the messages on this page, in repository order
	Total     int64     `json:"total"`      // total message count reported by the repository
	Page      int       `json:"page"`       // page number
	PageSize  int       `json:"page_size"`  // page size
	HasMore   bool      `json:"has_more"`   // whether more messages exist beyond this page
	UpdatedAt time.Time `json:"updated_at"` // last write time (GetMessages also bumps it on cache hits)
}
// ConversationCacheSettings holds the TTLs used by ConversationCache.
// Defaults (see DefaultConversationCacheSettings) are given in parentheses.
type ConversationCacheSettings struct {
	DetailTTL      time.Duration // conversation detail TTL (5min)
	ListTTL        time.Duration // conversation list TTL (60s)
	ParticipantTTL time.Duration // participant TTL (5min)
	UnreadTTL      time.Duration // unread count TTL (30s)
	// Message cache settings
	MessageDetailTTL time.Duration // message Hash (message details) TTL (30min)
	MessageListTTL   time.Duration // message page-list cache TTL (5min)
	MessageIndexTTL  time.Duration // message sorted-set index TTL (30min)
	MessageCountTTL  time.Duration // message counter TTL (30min)
}
// DefaultConversationCacheSettings returns a freshly allocated settings value
// with the default TTLs: 5 minutes for conversation details, participants and
// message pages; 60 seconds for conversation lists; 30 seconds for unread
// counts; and 30 minutes for message details, the message index and counters.
func DefaultConversationCacheSettings() *ConversationCacheSettings {
	const (
		halfMinute    = 30 * time.Second
		oneMinute     = 60 * time.Second
		fiveMinutes   = 5 * time.Minute
		thirtyMinutes = 30 * time.Minute
	)

	s := new(ConversationCacheSettings)
	// Conversation-level entries.
	s.DetailTTL = fiveMinutes
	s.ListTTL = oneMinute
	s.ParticipantTTL = fiveMinutes
	s.UnreadTTL = halfMinute
	// Message-level entries.
	s.MessageDetailTTL = thirtyMinutes
	s.MessageListTTL = fiveMinutes
	s.MessageIndexTTL = thirtyMinutes
	s.MessageCountTTL = thirtyMinutes
	return s
}
// parseSegments decodes JSON-encoded message segments. It returns nil when
// data is nil or cannot be decoded into model.MessageSegments; decode errors
// are deliberately swallowed so a corrupt cache entry yields empty segments.
func parseSegments(data json.RawMessage) model.MessageSegments {
	var decoded model.MessageSegments
	if data == nil || json.Unmarshal(data, &decoded) != nil {
		return nil
	}
	return decoded
}
// serializeSegments encodes message segments as JSON. A nil input, or any
// marshal failure, yields a nil RawMessage.
func serializeSegments(segments model.MessageSegments) json.RawMessage {
	if segments == nil {
		return nil
	}
	if encoded, err := json.Marshal(segments); err == nil {
		return encoded
	}
	return nil
}
// ToModel converts the cached representation back into a *model.Message.
// Segments and ExtraData are decoded from JSON (undecodable payloads become
// nil); the string fields are cast back to their model enum types without
// validation.
func (m *MessageCacheData) ToModel() *model.Message {
	msg := &model.Message{
		ID:             m.ID,
		ConversationID: m.ConversationID,
		SenderID:       m.SenderID,
		Seq:            m.Seq,
		ReplyToID:      m.ReplyToID,
		MentionUsers:   m.MentionUsers,
		MentionAll:     m.MentionAll,
		CreatedAt:      m.CreatedAt,
		UpdatedAt:      m.UpdatedAt,
	}

	// Enum-like fields travel as plain strings in the cache.
	msg.Status = model.MessageStatus(m.Status)
	msg.Category = model.MessageCategory(m.Category)
	msg.SystemType = model.SystemMessageType(m.SystemType)

	// JSON payloads.
	msg.Segments = parseSegments(m.Segments)
	msg.ExtraData = parseExtraData(m.ExtraData)
	return msg
}
// MessageCacheDataFromModel builds the cache representation of msg. Segments
// and ExtraData are JSON-encoded (nil on encode failure); enum fields are
// stored as their string values. msg must not be nil.
func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData {
	out := &MessageCacheData{
		ID:             msg.ID,
		ConversationID: msg.ConversationID,
		SenderID:       msg.SenderID,
		Seq:            msg.Seq,
		ReplyToID:      msg.ReplyToID,
		MentionUsers:   msg.MentionUsers,
		MentionAll:     msg.MentionAll,
		CreatedAt:      msg.CreatedAt,
		UpdatedAt:      msg.UpdatedAt,
	}

	// Enum-like fields are flattened to strings.
	out.Status = string(msg.Status)
	out.Category = string(msg.Category)
	out.SystemType = string(msg.SystemType)

	// Structured payloads are stored as raw JSON.
	out.Segments = serializeSegments(msg.Segments)
	out.ExtraData = serializeExtraData(msg.ExtraData)
	return out
}
// parseExtraData decodes JSON-encoded extra data. It returns nil when data is
// nil or cannot be decoded into model.ExtraData.
func parseExtraData(data json.RawMessage) *model.ExtraData {
	if data == nil {
		return nil
	}
	extra := new(model.ExtraData)
	if err := json.Unmarshal(data, extra); err != nil {
		return nil
	}
	return extra
}
// serializeExtraData encodes extra data as JSON. A nil input, or any marshal
// failure, yields a nil RawMessage.
func serializeExtraData(extraData *model.ExtraData) json.RawMessage {
	if extraData == nil {
		return nil
	}
	if encoded, err := json.Marshal(extraData); err == nil {
		return encoded
	}
	return nil
}
// ============================================================
// 缓存 Key 常量和生成函数
// ============================================================
// Key prefixes for conversation-level cache entries.
const (
	keyPrefixConv         = "conv"           // conversation detail
	keyPrefixConvPart     = "conv_part"      // participant list of a conversation
	keyPrefixConvPartUser = "conv_part_user" // one user's participant record
)

// ConversationKey returns the cache key for a conversation's details,
// in the form "conv:<convID>".
func ConversationKey(convID string) string {
	return keyPrefixConv + ":" + convID
}

// ParticipantListKey returns the cache key for a conversation's participant
// list, in the form "conv_part:<convID>".
func ParticipantListKey(convID string) string {
	return keyPrefixConvPart + ":" + convID
}

// ParticipantKey returns the cache key for one user's participant record in a
// conversation, in the form "conv_part_user:<convID>:<userID>".
func ParticipantKey(convID, userID string) string {
	return keyPrefixConvPartUser + ":" + convID + ":" + userID
}
// ============================================================
// ConversationRepository 接口定义
// ============================================================
// ConversationRepository is the conversation data source that ConversationCache
// falls back to on a cache miss (injected as a dependency).
type ConversationRepository interface {
	// GetConversationByID loads a single conversation.
	GetConversationByID(convID string) (*model.Conversation, error)
	// GetConversationsByUserID loads one page of a user's conversations plus a
	// total count (cached as CachedConversationList.Total).
	GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error)
	// GetParticipant loads one user's participant record in a conversation.
	GetParticipant(convID, userID string) (*model.ConversationParticipant, error)
	// GetParticipants loads all participants of a conversation.
	GetParticipants(convID string) ([]*model.ConversationParticipant, error)
	// GetUnreadCount loads a user's unread count for a conversation.
	// Note the argument order: userID first, then convID.
	GetUnreadCount(userID, convID string) (int64, error)
}
// MessageRepository is the message data source that ConversationCache falls
// back to when the message cache misses (injected as a dependency).
type MessageRepository interface {
	// GetMessages loads one page of messages plus a total count.
	GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error)
	// GetMessagesAfterSeq loads up to limit messages with seq greater than afterSeq.
	GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error)
	// GetMessagesBeforeSeq loads up to limit messages with seq less than beforeSeq.
	GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error)
}
// ============================================================
// ConversationCache 核心实现
// ============================================================
// ConversationCache manages cached conversations, participants, unread counts
// and messages on top of a Cache, loading from the repositories on a miss
// (cache-aside).
type ConversationCache struct {
	cache    Cache                      // underlying cache store
	settings *ConversationCacheSettings // TTL configuration
	repo     ConversationRepository     // conversation repository used as the cache-aside fallback
	msgRepo  MessageRepository          // message repository used as the message-cache fallback
}
// NewConversationCache builds a ConversationCache over cache, using repo and
// msgRepo as fallbacks on a miss. A nil settings value selects
// DefaultConversationCacheSettings().
func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache {
	cfg := settings
	if cfg == nil {
		cfg = DefaultConversationCacheSettings()
	}

	cc := &ConversationCache{
		cache:   cache,
		repo:    repo,
		msgRepo: msgRepo,
	}
	cc.settings = cfg
	return cc
}
// GetConversation returns a conversation using cache-aside with a sliding TTL.
//
// On a cache hit with non-nil data, AccessAt is refreshed and the entry is
// written back with DetailTTL. On a miss, the conversation is loaded from repo,
// stored as a new CachedConversation (Version 0), and returned. ctx is
// currently unused.
//
// NOTE(review): the hit path mutates the value returned by GetTyped. If the
// Cache implementation hands out shared pointers, concurrent readers race on
// AccessAt. Confirm this against the Cache implementation.
func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) {
	key := ConversationKey(convID)

	if hit, found := GetTyped[*CachedConversation](c.cache, key); found && hit != nil && hit.Data != nil {
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.DetailTTL)
		return hit.Data, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	conv, err := c.repo.GetConversationByID(convID)
	if err != nil {
		return nil, err
	}

	loadedAt := time.Now()
	entry := &CachedConversation{
		Data:      conv,
		Version:   0,
		UpdatedAt: loadedAt,
		AccessAt:  loadedAt,
	}
	c.cache.Set(key, entry, c.settings.DetailTTL)
	return conv, nil
}
// CachedConversationList is one cached page of a user's conversation list,
// together with cache metadata.
type CachedConversationList struct {
	Conversations []*model.Conversation `json:"conversations"` // conversations on this page
	Total         int64                 `json:"total"`         // total count reported by the repository
	Page          int                   `json:"page"`          // page number
	PageSize      int                   `json:"page_size"`     // page size
	UpdatedAt     time.Time             `json:"updated_at"`    // when the page was loaded from the repository
	AccessAt      time.Time             `json:"access_at"`     // last access time; refreshed on hits (sliding TTL)
}
// GetConversationList returns one page of a user's conversations and the total
// count, using cache-aside with a sliding TTL (ListTTL).
//
// A cache hit refreshes AccessAt and re-stores the entry. A miss loads the page
// from repo and caches it. ctx is currently unused.
func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
	key := ConversationListKey(userID, page, pageSize)

	hit, found := GetTyped[*CachedConversationList](c.cache, key)
	if found && hit != nil {
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.ListTTL)
		return hit.Conversations, hit.Total, nil
	}

	if c.repo == nil {
		return nil, 0, fmt.Errorf("repository not configured")
	}
	convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	loadedAt := time.Now()
	entry := &CachedConversationList{
		Conversations: convs,
		Total:         total,
		Page:          page,
		PageSize:      pageSize,
		UpdatedAt:     loadedAt,
		AccessAt:      loadedAt,
	}
	c.cache.Set(key, entry, c.settings.ListTTL)
	return convs, total, nil
}
// GetParticipant returns one user's participant record in a conversation,
// using cache-aside with a sliding TTL (ParticipantTTL).
//
// A hit with non-nil data refreshes AccessAt and re-stores the entry. A miss
// loads from repo and caches the result with Version 0. ctx is currently unused.
func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) {
	key := ParticipantKey(convID, userID)

	hit, found := GetTyped[*CachedParticipant](c.cache, key)
	switch {
	case found && hit != nil && hit.Data != nil:
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.ParticipantTTL)
		return hit.Data, nil
	case c.repo == nil:
		return nil, fmt.Errorf("repository not configured")
	}

	participant, err := c.repo.GetParticipant(convID, userID)
	if err != nil {
		return nil, err
	}

	loadedAt := time.Now()
	c.cache.Set(key, &CachedParticipant{
		Data:      participant,
		Version:   0,
		UpdatedAt: loadedAt,
		AccessAt:  loadedAt,
	}, c.settings.ParticipantTTL)
	return participant, nil
}
// CachedParticipantList is the cached participant list of a conversation,
// together with cache metadata.
type CachedParticipantList struct {
	Participants []*model.ConversationParticipant `json:"participants"` // all participants as loaded from the repository
	UpdatedAt    time.Time                        `json:"updated_at"`   // when the list was loaded
	AccessAt     time.Time                        `json:"access_at"`    // last access time; refreshed on hits (sliding TTL)
}
// GetParticipants returns all participants of a conversation, using
// cache-aside with a sliding TTL (ParticipantTTL).
//
// A hit refreshes AccessAt and re-stores the entry. A miss loads from repo and
// caches the list. ctx is currently unused.
func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) {
	key := ParticipantListKey(convID)

	if hit, found := GetTyped[*CachedParticipantList](c.cache, key); found && hit != nil {
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.ParticipantTTL)
		return hit.Participants, nil
	}

	if c.repo == nil {
		return nil, fmt.Errorf("repository not configured")
	}
	members, err := c.repo.GetParticipants(convID)
	if err != nil {
		return nil, err
	}

	loadedAt := time.Now()
	entry := &CachedParticipantList{
		Participants: members,
		UpdatedAt:    loadedAt,
		AccessAt:     loadedAt,
	}
	c.cache.Set(key, entry, c.settings.ParticipantTTL)
	return members, nil
}
// CachedUnreadCount is a cached per-user, per-conversation unread count,
// together with cache metadata.
type CachedUnreadCount struct {
	Count     int64     `json:"count"`      // unread count as loaded from the repository
	UpdatedAt time.Time `json:"updated_at"` // when the count was loaded
	AccessAt  time.Time `json:"access_at"`  // last access time; refreshed on hits (sliding TTL)
}
// GetUnreadCount returns a user's unread count for a conversation, using
// cache-aside with a sliding TTL (UnreadTTL).
//
// A hit refreshes AccessAt and re-stores the entry. A miss loads from repo and
// caches the count. ctx is currently unused.
func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) {
	key := UnreadDetailKey(userID, convID)

	hit, found := GetTyped[*CachedUnreadCount](c.cache, key)
	if found && hit != nil {
		hit.AccessAt = time.Now()
		c.cache.Set(key, hit, c.settings.UnreadTTL)
		return hit.Count, nil
	}

	if c.repo == nil {
		return 0, fmt.Errorf("repository not configured")
	}
	unread, err := c.repo.GetUnreadCount(userID, convID)
	if err != nil {
		return 0, err
	}

	loadedAt := time.Now()
	c.cache.Set(key, &CachedUnreadCount{
		Count:     unread,
		UpdatedAt: loadedAt,
		AccessAt:  loadedAt,
	}, c.settings.UnreadTTL)
	return unread, nil
}
// ============================================================
// 缓存失效方法
// ============================================================
// InvalidateConversation removes the cached details of a conversation.
func (c *ConversationCache) InvalidateConversation(convID string) {
	key := ConversationKey(convID)
	c.cache.Delete(key)
}
// InvalidateConversationList removes every cached conversation-list page of a
// user by deleting all keys under "<PrefixConversationList>:<userID>:".
func (c *ConversationCache) InvalidateConversationList(userID string) {
	userPrefix := fmt.Sprintf("%s:%s:", PrefixConversationList, userID)
	c.cache.DeleteByPrefix(userPrefix)
}
// InvalidateParticipant removes one user's cached participant record.
func (c *ConversationCache) InvalidateParticipant(convID, userID string) {
	key := ParticipantKey(convID, userID)
	c.cache.Delete(key)
}
// InvalidateParticipantList removes the cached participant list of a
// conversation.
func (c *ConversationCache) InvalidateParticipantList(convID string) {
	key := ParticipantListKey(convID)
	c.cache.Delete(key)
}
// InvalidateUnreadCount removes a user's cached unread count for a
// conversation.
func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) {
	key := UnreadDetailKey(userID, convID)
	c.cache.Delete(key)
}
// ============================================================
// 消息缓存方法
// ============================================================
// GetMessages returns one page of a conversation's messages plus the total
// message count, using cache-aside.
//
//  1. Look up the page cache (seqs + total). On a hit, the page entry's TTL is
//     extended and the message bodies are read from the message Hash.
//  2. The cached page is served only if every seq resolves to a message in
//     the Hash. The Hash has its own TTL and is back-filled asynchronously, so
//     it can lag behind or expire before the (sliding) page entry. Serving a
//     partial read would return a truncated page, so it is treated as a miss.
//  3. On a miss, load from msgRepo, cache the page entry, and back-fill the
//     message Hash/index in the background.
func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) {
	pageKey := MessagePageKey(convID, page, pageSize)
	if cached, ok := GetTyped[*PageCache](c.cache, pageKey); ok && cached != nil {
		// Sliding TTL: bump the timestamp and re-store the page entry.
		cached.UpdatedAt = time.Now()
		c.cache.Set(pageKey, cached, c.settings.MessageListTTL)
		if messages, complete := c.loadCachedMessagesBySeqs(ctx, convID, cached.Seqs); complete {
			return messages, cached.Total, nil
		}
	}

	if c.msgRepo == nil {
		return nil, 0, fmt.Errorf("message repository not configured")
	}
	messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	seqs := make([]int64, len(messages))
	for i, msg := range messages {
		seqs[i] = msg.Seq
	}
	c.backfillMessagesAsync(convID, messages)

	c.cache.Set(pageKey, &PageCache{
		Seqs:     seqs,
		Total:    total,
		Page:     page,
		PageSize: pageSize,
		// Multiply in int64 so large page numbers cannot overflow int.
		HasMore:   int64(page)*int64(pageSize) < total,
		UpdatedAt: time.Now(),
	}, c.settings.MessageListTTL)
	return messages, total, nil
}

// GetMessagesAfterSeq returns up to limit messages with seq > afterSeq
// (incremental sync).
//
// Seqs are looked up in the sorted-set index (score = seq, ascending) and the
// bodies are read from the message Hash. The cached result is used only when
// every indexed seq resolves. An index read error or an incomplete Hash read
// falls back to msgRepo, and the loaded messages are back-filled in the
// background. If msgRepo is not configured, the index error (if any) is
// returned.
//
// NOTE(review): the index only holds messages that were cached, so a
// non-empty but gappy index can still yield fewer messages than the
// repository would. Consider a contiguity check if seqs are gap-free.
func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) {
	indexKey := MessageIndexKey(convID)
	members, indexErr := c.cache.ZRangeByScore(ctx, indexKey, fmt.Sprintf("%d", afterSeq+1), "+inf", 0, int64(limit))
	if indexErr != nil {
		log.Printf("[ConversationCache] read message index failed, falling back to repository, convID=%s, err=%v", convID, indexErr)
	} else if messages, complete := c.loadCachedMessagesBySeqs(ctx, convID, parseMessageIndexMembers(members)); complete {
		return messages, nil
	}

	if c.msgRepo == nil {
		if indexErr != nil {
			return nil, indexErr
		}
		return nil, fmt.Errorf("message repository not configured")
	}
	messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit)
	if err != nil {
		return nil, err
	}
	c.backfillMessagesAsync(convID, messages)
	return messages, nil
}

// GetMessagesBeforeSeq returns up to limit messages with seq < beforeSeq, in
// descending seq order when served from the index (history pull-down).
//
// It uses the same cache/fallback rules as GetMessagesAfterSeq, reading the
// sorted-set index in reverse score order.
func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
	indexKey := MessageIndexKey(convID)
	members, indexErr := c.cache.ZRevRangeByScore(ctx, indexKey, fmt.Sprintf("%d", beforeSeq-1), "-inf", 0, int64(limit))
	if indexErr != nil {
		log.Printf("[ConversationCache] read message index failed, falling back to repository, convID=%s, err=%v", convID, indexErr)
	} else if messages, complete := c.loadCachedMessagesBySeqs(ctx, convID, parseMessageIndexMembers(members)); complete {
		return messages, nil
	}

	if c.msgRepo == nil {
		if indexErr != nil {
			return nil, indexErr
		}
		return nil, fmt.Errorf("message repository not configured")
	}
	messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
	if err != nil {
		return nil, err
	}
	c.backfillMessagesAsync(convID, messages)
	return messages, nil
}

// loadCachedMessagesBySeqs reads the messages for seqs from the message Hash.
// The bool is true only when seqs is non-empty, the read succeeded, and every
// seq resolved to a decodable message. Callers treat anything else as a miss.
func (c *ConversationCache) loadCachedMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, bool) {
	if len(seqs) == 0 {
		return nil, false
	}
	messages, err := c.getMessagesBySeqs(ctx, convID, seqs)
	if err != nil || len(messages) != len(seqs) {
		return nil, false
	}
	return messages, true
}

// parseMessageIndexMembers converts sorted-set members (decimal seq strings)
// into seqs, skipping members that do not parse.
func parseMessageIndexMembers(members []string) []int64 {
	seqs := make([]int64, 0, len(members))
	for _, member := range members {
		var seq int64
		if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
			seqs = append(seqs, seq)
		}
	}
	return seqs
}

// backfillMessagesAsync caches repository-loaded messages on a single
// background goroutine, rather than one goroutine per message. Failures are
// logged by asyncCacheMessage and otherwise ignored.
func (c *ConversationCache) backfillMessagesAsync(convID string, messages []*model.Message) {
	if len(messages) == 0 {
		return
	}
	go func() {
		ctx := context.Background()
		for _, msg := range messages {
			c.asyncCacheMessage(ctx, convID, msg)
		}
	}()
}
// CacheMessage writes one message into the conversation's message cache
// immediately:
//   - the JSON-encoded message is stored in the message Hash (field = seq),
//   - the seq is added to the sorted-set index (score and member = seq),
//   - both keys get their TTLs refreshed, and
//   - the message counter (MessageCountKey) is incremented.
//
// It returns an error if msg is nil, or if encoding, HSET or ZADD fails. Any
// results of the Expire and Incr calls are ignored.
//
// NOTE(review): the counter is incremented on every call, including the
// repository back-fill paths and re-caching of a seq that is already present,
// so it can exceed the number of distinct messages. Confirm what readers of
// MessageCountKey expect.
func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error {
	if msg == nil {
		// Previously this panicked inside MessageCacheDataFromModel.
		return fmt.Errorf("cannot cache nil message, convID=%s", convID)
	}

	hashKey := MessageHashKey(convID)
	indexKey := MessageIndexKey(convID)
	// The seq doubles as the Hash field name and the sorted-set member.
	member := fmt.Sprintf("%d", msg.Seq)

	data, err := json.Marshal(MessageCacheDataFromModel(msg))
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	// HSET message details.
	if err := c.cache.HSet(ctx, hashKey, member, string(data)); err != nil {
		return fmt.Errorf("failed to set hash: %w", err)
	}
	// ZADD message index.
	if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), member); err != nil {
		return fmt.Errorf("failed to add to sorted set: %w", err)
	}

	// Refresh TTLs.
	c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL)
	c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL)
	// INCR message counter.
	c.cache.Incr(ctx, MessageCountKey(convID))
	return nil
}
// InvalidateMessageCache drops every message-related cache entry of a
// conversation: the message Hash, the sorted-set index, the message counter,
// and all cached message pages.
func (c *ConversationCache) InvalidateMessageCache(convID string) {
	hashKey := MessageHashKey(convID)
	indexKey := MessageIndexKey(convID)
	countKey := MessageCountKey(convID)

	c.cache.Delete(hashKey)
	c.cache.Delete(indexKey)
	c.cache.Delete(countKey)

	// Page entries are keyed per (page, pageSize), so they are removed by prefix.
	c.InvalidateMessagePages(convID)
}
// InvalidateMessagePages drops only the cached message pages of a conversation.
// A newly written message changes both page contents and the total count, so
// every page of the conversation is removed by key prefix.
func (c *ConversationCache) InvalidateMessagePages(convID string) {
	pagePrefix := fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID)
	c.cache.DeleteByPrefix(pagePrefix)
}
// ============================================================
// 内部辅助方法
// ============================================================
// getMessagesBySeqs reads the messages for seqs from the conversation's message
// Hash with a single HMGET.
//
// It returns (nil, nil) for an empty seqs slice and the HMGET error if the read
// fails. Missing fields, values that are neither string nor []byte, and values
// that fail to decode are skipped silently. The result can therefore be shorter
// than seqs; it otherwise keeps the order of seqs.
func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) {
	if len(seqs) == 0 {
		return nil, nil
	}

	hashKey := MessageHashKey(convID)
	fields := make([]string, 0, len(seqs))
	for _, seq := range seqs {
		fields = append(fields, fmt.Sprintf("%d", seq))
	}

	values, err := c.cache.HMGet(ctx, hashKey, fields...)
	if err != nil {
		return nil, err
	}

	result := make([]*model.Message, 0, len(seqs))
	for _, raw := range values {
		var payload []byte
		switch v := raw.(type) {
		case string:
			payload = []byte(v)
		case []byte:
			payload = v
		default:
			// nil (missing field) or an unexpected type.
			continue
		}

		var decoded MessageCacheData
		if json.Unmarshal(payload, &decoded) != nil {
			continue
		}
		result = append(result, decoded.ToModel())
	}
	return result, nil
}
// asyncCacheMessage caches msg via CacheMessage and logs, rather than returns,
// any failure. The call itself is synchronous; callers in this file invoke it
// from a goroutine.
func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) {
	err := c.CacheMessage(ctx, convID, msg)
	if err == nil {
		return
	}
	log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err)
}