feat(schedule): add course table screens and navigation

Add complete schedule functionality including:
- Schedule screen with weekly course table view
- Course detail screen with transparent modal presentation
- New ScheduleStack navigator integrated into main tab bar
- Schedule service for API interactions
- Type definitions for course entities

Also includes bug fixes for group invite/request handlers
to include required groupId parameter.
This commit is contained in:
2026-03-12 08:38:14 +08:00
parent 21293644b8
commit 0a0cbacbcc
25 changed files with 3050 additions and 260 deletions

View File

@@ -5,7 +5,10 @@ import (
"encoding/json"
"fmt"
"log"
"math"
"math/rand"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -34,6 +37,38 @@ type Cache interface {
Increment(key string) int64
// IncrementBy 增加指定值
IncrementBy(key string, value int64) int64
// ==================== Hash 操作 ====================
// HSet 设置 Hash 字段
HSet(ctx context.Context, key string, field string, value interface{}) error
// HMSet 批量设置 Hash 字段
HMSet(ctx context.Context, key string, values map[string]interface{}) error
// HGet 获取 Hash 字段值
HGet(ctx context.Context, key string, field string) (string, error)
// HMGet 批量获取 Hash 字段值
HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error)
// HGetAll 获取 Hash 所有字段
HGetAll(ctx context.Context, key string) (map[string]string, error)
// HDel 删除 Hash 字段
HDel(ctx context.Context, key string, fields ...string) error
// ==================== Sorted Set 操作 ====================
// ZAdd 添加 Sorted Set 成员
ZAdd(ctx context.Context, key string, score float64, member string) error
// ZRangeByScore 按分数范围获取成员(升序)
ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error)
// ZRevRangeByScore 按分数范围获取成员(降序)
ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error)
// ZRem 删除 Sorted Set 成员
ZRem(ctx context.Context, key string, members ...interface{}) error
// ZCard 获取 Sorted Set 成员数量
ZCard(ctx context.Context, key string) (int64, error)
// ==================== 计数器操作 ====================
// Incr 原子递增(返回新值)
Incr(ctx context.Context, key string) (int64, error)
// Expire 设置过期时间
Expire(ctx context.Context, key string, ttl time.Duration) error
}
// cacheItem 缓存项(用于内存缓存降级)
@@ -64,16 +99,16 @@ type MetricsSnapshot struct {
}
type Settings struct {
Enabled bool
KeyPrefix string
DefaultTTL time.Duration
NullTTL time.Duration
JitterRatio float64
PostListTTL time.Duration
ConversationTTL time.Duration
UnreadCountTTL time.Duration
GroupMembersTTL time.Duration
DisableFlushDB bool
Enabled bool
KeyPrefix string
DefaultTTL time.Duration
NullTTL time.Duration
JitterRatio float64
PostListTTL time.Duration
ConversationTTL time.Duration
UnreadCountTTL time.Duration
GroupMembersTTL time.Duration
DisableFlushDB bool
}
var settings = Settings{
@@ -327,6 +362,378 @@ func (c *MemoryCache) Stop() {
close(c.stopCleanup)
}
// ==================== MemoryCache Hash 操作 ====================
// hashItem Hash 存储项
type hashItem struct {
fields sync.Map
}
// HSet 设置 Hash 字段
func (c *MemoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
key = normalizeKey(key)
item, _ := c.items.Load(key)
var h *hashItem
if item == nil {
h = &hashItem{}
c.items.Store(key, &cacheItem{value: h, expiration: 0})
} else {
ci := item.(*cacheItem)
if ci.isExpired() {
h = &hashItem{}
c.items.Store(key, &cacheItem{value: h, expiration: 0})
} else {
h = ci.value.(*hashItem)
}
}
h.fields.Store(field, value)
return nil
}
// HMSet 批量设置 Hash 字段
func (c *MemoryCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
for field, value := range values {
if err := c.HSet(ctx, key, field, value); err != nil {
return err
}
}
return nil
}
// HGet 获取 Hash 字段值
func (c *MemoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return "", fmt.Errorf("key not found")
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return "", fmt.Errorf("key not found")
}
h, ok := ci.value.(*hashItem)
if !ok {
return "", fmt.Errorf("key is not a hash")
}
val, ok := h.fields.Load(field)
if !ok {
return "", fmt.Errorf("field not found")
}
switch v := val.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
default:
data, _ := json.Marshal(v)
return string(data), nil
}
}
// HMGet 批量获取 Hash 字段值
func (c *MemoryCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
result := make([]interface{}, len(fields))
for i, field := range fields {
val, err := c.HGet(ctx, key, field)
if err != nil {
result[i] = nil
} else {
result[i] = val
}
}
return result, nil
}
// HGetAll 获取 Hash 所有字段
func (c *MemoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil, fmt.Errorf("key not found")
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil, fmt.Errorf("key not found")
}
h, ok := ci.value.(*hashItem)
if !ok {
return nil, fmt.Errorf("key is not a hash")
}
result := make(map[string]string)
h.fields.Range(func(k, v interface{}) bool {
keyStr := k.(string)
switch val := v.(type) {
case string:
result[keyStr] = val
case []byte:
result[keyStr] = string(val)
default:
data, _ := json.Marshal(val)
result[keyStr] = string(data)
}
return true
})
return result, nil
}
// HDel 删除 Hash 字段
func (c *MemoryCache) HDel(ctx context.Context, key string, fields ...string) error {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil
}
h, ok := ci.value.(*hashItem)
if !ok {
return nil
}
for _, field := range fields {
h.fields.Delete(field)
}
return nil
}
// ==================== MemoryCache Sorted Set 操作 ====================
// zItem Sorted Set 成员
type zItem struct {
score float64
member string
}
// zsetItem Sorted Set 存储项
type zsetItem struct {
members sync.Map // member -> *zItem
byScore *sortedSlice // 按分数排序的切片
}
// sortedSlice 简单的排序切片实现
type sortedSlice struct {
items []*zItem
mu sync.RWMutex
}
// ZAdd 添加 Sorted Set 成员
func (c *MemoryCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
key = normalizeKey(key)
item, _ := c.items.Load(key)
var z *zsetItem
if item == nil {
z = &zsetItem{byScore: &sortedSlice{}}
c.items.Store(key, &cacheItem{value: z, expiration: 0})
} else {
ci := item.(*cacheItem)
if ci.isExpired() {
z = &zsetItem{byScore: &sortedSlice{}}
c.items.Store(key, &cacheItem{value: z, expiration: 0})
} else {
z = ci.value.(*zsetItem)
}
}
z.members.Store(member, &zItem{score: score, member: member})
z.byScore.mu.Lock()
// 简单实现:重新构建排序切片
z.byScore.items = nil
z.members.Range(func(k, v interface{}) bool {
z.byScore.items = append(z.byScore.items, v.(*zItem))
return true
})
// 按分数排序
sort.Slice(z.byScore.items, func(i, j int) bool {
return z.byScore.items[i].score < z.byScore.items[j].score
})
z.byScore.mu.Unlock()
return nil
}
// ZRangeByScore 按分数范围获取成员(升序)
func (c *MemoryCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil, nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil, nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return nil, fmt.Errorf("key is not a sorted set")
}
minScore, _ := strconv.ParseFloat(min, 64)
maxScore, _ := strconv.ParseFloat(max, 64)
if min == "-inf" {
minScore = math.Inf(-1)
}
if max == "+inf" {
maxScore = math.Inf(1)
}
z.byScore.mu.RLock()
defer z.byScore.mu.RUnlock()
var result []string
var skipped int64 = 0
for _, item := range z.byScore.items {
if item.score < minScore || item.score > maxScore {
continue
}
if skipped < offset {
skipped++
continue
}
if count > 0 && int64(len(result)) >= count {
break
}
result = append(result, item.member)
}
return result, nil
}
// ZRevRangeByScore 按分数范围获取成员(降序)
func (c *MemoryCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil, nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil, nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return nil, fmt.Errorf("key is not a sorted set")
}
minScore, _ := strconv.ParseFloat(min, 64)
maxScore, _ := strconv.ParseFloat(max, 64)
if min == "-inf" {
minScore = math.Inf(-1)
}
if max == "+inf" {
maxScore = math.Inf(1)
}
z.byScore.mu.RLock()
defer z.byScore.mu.RUnlock()
var result []string
var skipped int64 = 0
// 从后往前遍历
for i := len(z.byScore.items) - 1; i >= 0; i-- {
item := z.byScore.items[i]
if item.score < minScore || item.score > maxScore {
continue
}
if skipped < offset {
skipped++
continue
}
if count > 0 && int64(len(result)) >= count {
break
}
result = append(result, item.member)
}
return result, nil
}
// ZRem 删除 Sorted Set 成员
func (c *MemoryCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return nil
}
for _, m := range members {
if member, ok := m.(string); ok {
z.members.Delete(member)
}
}
// 重建排序切片
z.byScore.mu.Lock()
z.byScore.items = nil
z.members.Range(func(k, v interface{}) bool {
z.byScore.items = append(z.byScore.items, v.(*zItem))
return true
})
sort.Slice(z.byScore.items, func(i, j int) bool {
return z.byScore.items[i].score < z.byScore.items[j].score
})
z.byScore.mu.Unlock()
return nil
}
// ZCard 获取 Sorted Set 成员数量
func (c *MemoryCache) ZCard(ctx context.Context, key string) (int64, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return 0, nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return 0, nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return 0, fmt.Errorf("key is not a sorted set")
}
var count int64 = 0
z.members.Range(func(k, v interface{}) bool {
count++
return true
})
return count, nil
}
// ==================== MemoryCache 计数器操作 ====================
// Incr 原子递增(返回新值)
func (c *MemoryCache) Incr(ctx context.Context, key string) (int64, error) {
return c.IncrementBy(key, 1), nil
}
// Expire 设置过期时间
func (c *MemoryCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return fmt.Errorf("key not found")
}
ci := item.(*cacheItem)
var expiration int64
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
c.items.Store(key, &cacheItem{
value: ci.value,
expiration: expiration,
})
return nil
}
// RedisCache Redis缓存实现
type RedisCache struct {
client *redisPkg.Client
@@ -451,6 +858,91 @@ func (c *RedisCache) IncrementBy(key string, value int64) int64 {
return result
}
// ==================== RedisCache Hash 操作 ====================
// HSet 设置 Hash 字段
func (c *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
key = normalizeKey(key)
return c.client.HSet(ctx, key, field, value)
}
// HMSet 批量设置 Hash 字段
func (c *RedisCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
key = normalizeKey(key)
return c.client.HMSet(ctx, key, values)
}
// HGet 获取 Hash 字段值
func (c *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
key = normalizeKey(key)
return c.client.HGet(ctx, key, field)
}
// HMGet 批量获取 Hash 字段值
func (c *RedisCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
key = normalizeKey(key)
return c.client.HMGet(ctx, key, fields...)
}
// HGetAll 获取 Hash 所有字段
func (c *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
key = normalizeKey(key)
return c.client.HGetAll(ctx, key)
}
// HDel 删除 Hash 字段
func (c *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
key = normalizeKey(key)
return c.client.HDel(ctx, key, fields...)
}
// ==================== RedisCache Sorted Set 操作 ====================
// ZAdd 添加 Sorted Set 成员
func (c *RedisCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
key = normalizeKey(key)
return c.client.ZAdd(ctx, key, score, member)
}
// ZRangeByScore 按分数范围获取成员(升序)
func (c *RedisCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
return c.client.ZRangeByScore(ctx, key, min, max, offset, count)
}
// ZRevRangeByScore 按分数范围获取成员(降序)
func (c *RedisCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
return c.client.ZRevRangeByScore(ctx, key, max, min, offset, count)
}
// ZRem 删除 Sorted Set 成员
func (c *RedisCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
key = normalizeKey(key)
return c.client.ZRem(ctx, key, members...)
}
// ZCard 获取 Sorted Set 成员数量
func (c *RedisCache) ZCard(ctx context.Context, key string) (int64, error) {
key = normalizeKey(key)
return c.client.ZCard(ctx, key)
}
// ==================== RedisCache 计数器操作 ====================
// Incr 原子递增(返回新值)
func (c *RedisCache) Incr(ctx context.Context, key string) (int64, error) {
key = normalizeKey(key)
return c.client.Incr(ctx, key)
}
// Expire 设置过期时间
func (c *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
key = normalizeKey(key)
_, err := c.client.Expire(ctx, key, ttl)
return err
}
// 全局缓存实例
var globalCache Cache
var once sync.Once

724
internal/cache/conversation_cache.go vendored Normal file
View File

@@ -0,0 +1,724 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
)
// CachedConversation 带缓存元数据的会话
type CachedConversation struct {
Data *model.Conversation // 实际数据
Version int64 // 版本号CAS 更新用)
UpdatedAt time.Time // 最后更新时间
AccessAt time.Time // 最后访问时间(用于 TTL 延长)
}
// CachedParticipant 带缓存元数据的参与者
type CachedParticipant struct {
Data *model.ConversationParticipant
Version int64
UpdatedAt time.Time
AccessAt time.Time
}
// CachedMessage 带缓存元数据的消息
type CachedMessage struct {
Data *model.Message `json:"data"` // 消息数据
Seq int64 `json:"seq"` // 消息序号
CreatedAt time.Time `json:"created_at"` // 创建时间
}
// MessageCacheData Redis 中存储的消息 Hash 结构
type MessageCacheData struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
SenderID string `json:"sender_id"`
Seq int64 `json:"seq"`
Segments json.RawMessage `json:"segments"`
ReplyToID *string `json:"reply_to_id,omitempty"`
Status string `json:"status"`
Category string `json:"category"`
SystemType string `json:"system_type,omitempty"`
ExtraData json.RawMessage `json:"extra_data,omitempty"`
MentionUsers string `json:"mention_users"`
MentionAll bool `json:"mention_all"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PageCache 分页缓存
type PageCache struct {
Seqs []int64 `json:"seqs"` // 当前页的消息 seq 列表
Total int64 `json:"total"` // 消息总数
Page int `json:"page"` // 当前页码
PageSize int `json:"page_size"` // 每页大小
HasMore bool `json:"has_more"` // 是否有更多
UpdatedAt time.Time `json:"updated_at"` // 更新时间
}
// ConversationCacheSettings 缓存配置
type ConversationCacheSettings struct {
DetailTTL time.Duration // 会话详情 TTL (5min)
ListTTL time.Duration // 会话列表 TTL (60s)
ParticipantTTL time.Duration // 参与者 TTL (5min)
UnreadTTL time.Duration // 未读数 TTL (30s)
// 消息缓存配置
MessageDetailTTL time.Duration // 单条消息详情缓存 (30min)
MessageListTTL time.Duration // 消息分页列表缓存 (5min)
MessageIndexTTL time.Duration // 消息索引缓存 (30min)
MessageCountTTL time.Duration // 消息计数缓存 (30min)
}
// DefaultConversationCacheSettings 返回默认配置
func DefaultConversationCacheSettings() *ConversationCacheSettings {
return &ConversationCacheSettings{
DetailTTL: 5 * time.Minute,
ListTTL: 60 * time.Second,
ParticipantTTL: 5 * time.Minute,
UnreadTTL: 30 * time.Second,
MessageDetailTTL: 30 * time.Minute,
MessageListTTL: 5 * time.Minute,
MessageIndexTTL: 30 * time.Minute,
MessageCountTTL: 30 * time.Minute,
}
}
// parseSegments 将 JSON RawMessage 解析为 MessageSegments
func parseSegments(data json.RawMessage) model.MessageSegments {
if data == nil {
return nil
}
var segments model.MessageSegments
if err := json.Unmarshal(data, &segments); err != nil {
return nil
}
return segments
}
// serializeSegments 将 MessageSegments 序列化为 JSON RawMessage
func serializeSegments(segments model.MessageSegments) json.RawMessage {
if segments == nil {
return nil
}
data, err := json.Marshal(segments)
if err != nil {
return nil
}
return data
}
// ToModel 将 MessageCacheData 转换为 model.Message
func (m *MessageCacheData) ToModel() *model.Message {
return &model.Message{
ID: m.ID,
ConversationID: m.ConversationID,
SenderID: m.SenderID,
Seq: m.Seq,
Segments: parseSegments(m.Segments),
ReplyToID: m.ReplyToID,
Status: model.MessageStatus(m.Status),
Category: model.MessageCategory(m.Category),
SystemType: model.SystemMessageType(m.SystemType),
ExtraData: parseExtraData(m.ExtraData),
MentionUsers: m.MentionUsers,
MentionAll: m.MentionAll,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
// MessageCacheDataFromModel 从 model.Message 创建 MessageCacheData
func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData {
return &MessageCacheData{
ID: msg.ID,
ConversationID: msg.ConversationID,
SenderID: msg.SenderID,
Seq: msg.Seq,
Segments: serializeSegments(msg.Segments),
ReplyToID: msg.ReplyToID,
Status: string(msg.Status),
Category: string(msg.Category),
SystemType: string(msg.SystemType),
ExtraData: serializeExtraData(msg.ExtraData),
MentionUsers: msg.MentionUsers,
MentionAll: msg.MentionAll,
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
}
}
// parseExtraData 将 JSON RawMessage 解析为 ExtraData
func parseExtraData(data json.RawMessage) *model.ExtraData {
if data == nil {
return nil
}
var extraData model.ExtraData
if err := json.Unmarshal(data, &extraData); err != nil {
return nil
}
return &extraData
}
// serializeExtraData 将 ExtraData 序列化为 JSON RawMessage
func serializeExtraData(extraData *model.ExtraData) json.RawMessage {
if extraData == nil {
return nil
}
data, err := json.Marshal(extraData)
if err != nil {
return nil
}
return data
}
// ============================================================
// 缓存 Key 常量和生成函数
// ============================================================
const (
keyPrefixConv = "conv" // 会话详情
keyPrefixConvPart = "conv_part" // 参与者列表
keyPrefixConvPartUser = "conv_part_user" // 用户参与者信息
)
// ConversationKey 会话详情缓存 key
func ConversationKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixConv, convID)
}
// ParticipantListKey 参与者列表缓存 key
func ParticipantListKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixConvPart, convID)
}
// ParticipantKey 用户参与者信息缓存 key
func ParticipantKey(convID, userID string) string {
return fmt.Sprintf("%s:%s:%s", keyPrefixConvPartUser, convID, userID)
}
// ============================================================
// ConversationRepository 接口定义
// ============================================================
// ConversationRepository 会话数据仓库接口(用于依赖注入)
type ConversationRepository interface {
GetConversationByID(convID string) (*model.Conversation, error)
GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error)
GetParticipant(convID, userID string) (*model.ConversationParticipant, error)
GetParticipants(convID string) ([]*model.ConversationParticipant, error)
GetUnreadCount(userID, convID string) (int64, error)
}
// MessageRepository 消息数据仓库接口(用于依赖注入)
type MessageRepository interface {
GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error)
GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error)
GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error)
}
// ============================================================
// ConversationCache 核心实现
// ============================================================
// ConversationCache 会话缓存管理器
type ConversationCache struct {
cache Cache // 底层缓存
settings *ConversationCacheSettings // 配置
repo ConversationRepository // 数据仓库接口(用于 cache-aside 回源)
msgRepo MessageRepository // 消息数据仓库接口(用于消息缓存回源)
}
// NewConversationCache 创建会话缓存管理器
func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache {
if settings == nil {
settings = DefaultConversationCacheSettings()
}
return &ConversationCache{
cache: cache,
settings: settings,
repo: repo,
msgRepo: msgRepo,
}
}
// GetConversation 读取会话(带 TTL 滑动延长)
// 1. 尝试从缓存获取
// 2. 如果命中,更新 AccessAt 并延长 TTL
// 3. 如果未命中,从 repo 加载并写入缓存
func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) {
key := ConversationKey(convID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedConversation](c.cache, key)
if ok && cached != nil && cached.Data != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.DetailTTL)
return cached.Data, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, fmt.Errorf("repository not configured")
}
conv, err := c.repo.GetConversationByID(convID)
if err != nil {
return nil, err
}
// 写入缓存
now := time.Now()
cachedConv := &CachedConversation{
Data: conv,
Version: 0,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedConv, c.settings.DetailTTL)
return conv, nil
}
// CachedConversationList 带元数据的会话列表缓存
type CachedConversationList struct {
Conversations []*model.Conversation `json:"conversations"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
UpdatedAt time.Time `json:"updated_at"`
AccessAt time.Time `json:"access_at"`
}
// GetConversationList 获取用户会话列表(带 TTL 滑动延长)
func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
key := ConversationListKey(userID, page, pageSize)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedConversationList](c.cache, key)
if ok && cached != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.ListTTL)
return cached.Conversations, cached.Total, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, 0, fmt.Errorf("repository not configured")
}
convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize)
if err != nil {
return nil, 0, err
}
// 写入缓存
now := time.Now()
cachedList := &CachedConversationList{
Conversations: convs,
Total: total,
Page: page,
PageSize: pageSize,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedList, c.settings.ListTTL)
return convs, total, nil
}
// GetParticipant 获取参与者信息(带 TTL 滑动延长)
func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) {
key := ParticipantKey(convID, userID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedParticipant](c.cache, key)
if ok && cached != nil && cached.Data != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.ParticipantTTL)
return cached.Data, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, fmt.Errorf("repository not configured")
}
participant, err := c.repo.GetParticipant(convID, userID)
if err != nil {
return nil, err
}
// 写入缓存
now := time.Now()
cachedPart := &CachedParticipant{
Data: participant,
Version: 0,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedPart, c.settings.ParticipantTTL)
return participant, nil
}
// CachedParticipantList 带元数据的参与者列表缓存
type CachedParticipantList struct {
Participants []*model.ConversationParticipant `json:"participants"`
UpdatedAt time.Time `json:"updated_at"`
AccessAt time.Time `json:"access_at"`
}
// GetParticipants 获取会话所有参与者(带 TTL 滑动延长)
func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) {
key := ParticipantListKey(convID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedParticipantList](c.cache, key)
if ok && cached != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.ParticipantTTL)
return cached.Participants, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, fmt.Errorf("repository not configured")
}
participants, err := c.repo.GetParticipants(convID)
if err != nil {
return nil, err
}
// 写入缓存
now := time.Now()
cachedList := &CachedParticipantList{
Participants: participants,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedList, c.settings.ParticipantTTL)
return participants, nil
}
// CachedUnreadCount 带元数据的未读数缓存
type CachedUnreadCount struct {
Count int64 `json:"count"`
UpdatedAt time.Time `json:"updated_at"`
AccessAt time.Time `json:"access_at"`
}
// GetUnreadCount 获取未读数(带 TTL 滑动延长)
func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) {
key := UnreadDetailKey(userID, convID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedUnreadCount](c.cache, key)
if ok && cached != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.UnreadTTL)
return cached.Count, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return 0, fmt.Errorf("repository not configured")
}
count, err := c.repo.GetUnreadCount(userID, convID)
if err != nil {
return 0, err
}
// 写入缓存
now := time.Now()
cachedCount := &CachedUnreadCount{
Count: count,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedCount, c.settings.UnreadTTL)
return count, nil
}
// ============================================================
// 缓存失效方法
// ============================================================
// InvalidateConversation 使会话缓存失效
func (c *ConversationCache) InvalidateConversation(convID string) {
c.cache.Delete(ConversationKey(convID))
}
// InvalidateConversationList 使会话列表缓存失效(按用户)
func (c *ConversationCache) InvalidateConversationList(userID string) {
c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", PrefixConversationList, userID))
}
// InvalidateParticipant 使参与者缓存失效
func (c *ConversationCache) InvalidateParticipant(convID, userID string) {
c.cache.Delete(ParticipantKey(convID, userID))
}
// InvalidateParticipantList 使参与者列表缓存失效
func (c *ConversationCache) InvalidateParticipantList(convID string) {
c.cache.Delete(ParticipantListKey(convID))
}
// InvalidateUnreadCount 使未读数缓存失效
func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) {
c.cache.Delete(UnreadDetailKey(userID, convID))
}
// ============================================================
// 消息缓存方法
// ============================================================
// GetMessages 获取会话消息(带缓存)
// 1. 尝试从分页缓存获取
// 2. 如果命中,从 Hash 中批量获取消息详情
// 3. 如果未命中,从数据库加载并写入缓存
func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) {
// 1. 尝试从缓存获取分页数据
pageKey := MessagePageKey(convID, page, pageSize)
cached, ok := GetTyped[*PageCache](c.cache, pageKey)
if ok && cached != nil {
// TTL 滑动延长
cached.UpdatedAt = time.Now()
c.cache.Set(pageKey, cached, c.settings.MessageListTTL)
// 从 Hash 中批量获取消息详情
if len(cached.Seqs) > 0 {
messages, err := c.getMessagesBySeqs(ctx, convID, cached.Seqs)
if err == nil {
return messages, cached.Total, nil
}
// Hash 获取失败,继续从数据库加载
}
}
// 2. 缓存未命中,从数据库加载
if c.msgRepo == nil {
return nil, 0, fmt.Errorf("message repository not configured")
}
messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize)
if err != nil {
return nil, 0, err
}
// 3. 写入缓存
seqs := make([]int64, len(messages))
for i, msg := range messages {
seqs[i] = msg.Seq
// 异步写入消息详情到 Hash
go c.asyncCacheMessage(context.Background(), convID, msg)
}
pageCache := &PageCache{
Seqs: seqs,
Total: total,
Page: page,
PageSize: pageSize,
HasMore: int64(page*pageSize) < total,
UpdatedAt: time.Now(),
}
c.cache.Set(pageKey, pageCache, c.settings.MessageListTTL)
return messages, total, nil
}
// GetMessagesAfterSeq 获取指定 seq 之后的消息(增量同步)
// 使用 Sorted Set 的 ZRangeByScore 实现
func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) {
indexKey := MessageIndexKey(convID)
// 1. 尝试从 Sorted Set 获取 seq 列表
members, err := c.cache.ZRangeByScore(ctx, indexKey, fmt.Sprintf("%d", afterSeq+1), "+inf", 0, int64(limit))
if err != nil {
return nil, err
}
// 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情
if len(members) > 0 {
seqs := make([]int64, 0, len(members))
for _, member := range members {
var seq int64
if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
seqs = append(seqs, seq)
}
}
return c.getMessagesBySeqs(ctx, convID, seqs)
}
// 3. Sorted Set 未命中,从数据库加载
if c.msgRepo == nil {
return nil, fmt.Errorf("message repository not configured")
}
messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit)
if err != nil {
return nil, err
}
// 4. 异步写入缓存
for _, msg := range messages {
go c.asyncCacheMessage(context.Background(), convID, msg)
}
return messages, nil
}
// GetMessagesBeforeSeq 获取指定 seq 之前的历史消息(下拉加载)
// 使用 Sorted Set 的 ZRevRangeByScore 实现
func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
indexKey := MessageIndexKey(convID)
// 1. 尝试从 Sorted Set 获取 seq 列表(降序)
members, err := c.cache.ZRevRangeByScore(ctx, indexKey, fmt.Sprintf("%d", beforeSeq-1), "-inf", 0, int64(limit))
if err != nil {
return nil, err
}
// 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情
if len(members) > 0 {
seqs := make([]int64, 0, len(members))
for _, member := range members {
var seq int64
if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
seqs = append(seqs, seq)
}
}
return c.getMessagesBySeqs(ctx, convID, seqs)
}
// 3. Sorted Set 未命中,从数据库加载
if c.msgRepo == nil {
return nil, fmt.Errorf("message repository not configured")
}
messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
if err != nil {
return nil, err
}
// 4. 异步写入缓存
for _, msg := range messages {
go c.asyncCacheMessage(context.Background(), convID, msg)
}
return messages, nil
}
// CacheMessage 缓存单条消息(立即写入缓存)
// 写入 Hash、Sorted Set、更新计数
func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error {
hashKey := MessageHashKey(convID)
indexKey := MessageIndexKey(convID)
msgData := MessageCacheDataFromModel(msg)
data, err := json.Marshal(msgData)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// HSET 消息详情
if err := c.cache.HSet(ctx, hashKey, fmt.Sprintf("%d", msg.Seq), string(data)); err != nil {
return fmt.Errorf("failed to set hash: %w", err)
}
// ZADD 消息索引
if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), fmt.Sprintf("%d", msg.Seq)); err != nil {
return fmt.Errorf("failed to add to sorted set: %w", err)
}
// 设置 TTL
c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL)
c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL)
// INCR 消息计数
c.cache.Incr(ctx, MessageCountKey(convID))
return nil
}
// InvalidateMessageCache 使消息缓存失效
func (c *ConversationCache) InvalidateMessageCache(convID string) {
c.cache.Delete(MessageHashKey(convID))
c.cache.Delete(MessageIndexKey(convID))
c.cache.Delete(MessageCountKey(convID))
// 删除所有分页缓存
c.InvalidateMessagePages(convID)
}
// InvalidateMessagePages 仅使消息分页缓存失效
// 新消息写入后会导致分页内容和总数变化,需要清理该会话所有分页缓存。
func (c *ConversationCache) InvalidateMessagePages(convID string) {
c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID))
}
// ============================================================
// 内部辅助方法
// ============================================================
// getMessagesBySeqs 从 Hash 中批量获取消息
func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) {
if len(seqs) == 0 {
return nil, nil
}
hashKey := MessageHashKey(convID)
fields := make([]string, len(seqs))
for i, seq := range seqs {
fields[i] = fmt.Sprintf("%d", seq)
}
// 批量获取
values, err := c.cache.HMGet(ctx, hashKey, fields...)
if err != nil {
return nil, err
}
messages := make([]*model.Message, 0, len(seqs))
for _, val := range values {
if val == nil {
continue
}
var msgData MessageCacheData
switch v := val.(type) {
case string:
if err := json.Unmarshal([]byte(v), &msgData); err != nil {
continue
}
case []byte:
if err := json.Unmarshal(v, &msgData); err != nil {
continue
}
default:
continue
}
messages = append(messages, msgData.ToModel())
}
return messages, nil
}
// asyncCacheMessage 异步缓存单条消息
func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) {
if err := c.CacheMessage(ctx, convID, msg); err != nil {
log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err)
}
}

View File

@@ -26,6 +26,13 @@ const (
// 用户相关
PrefixUserInfo = "users:info"
PrefixUserMe = "users:me"
// 消息缓存相关
keyPrefixMsgHash = "msg_hash" // 消息详情 Hash
keyPrefixMsgIndex = "msg_index" // 消息索引 Sorted Set
keyPrefixMsgCount = "msg_count" // 消息计数
keyPrefixMsgSeq = "msg_seq" // Seq 计数器
keyPrefixMsgPage = "msg_page" // 分页缓存
)
// PostListKey 生成帖子列表缓存键
@@ -145,3 +152,37 @@ func InvalidateUserInfo(cache Cache, userID string) {
cache.Delete(UserInfoKey(userID))
cache.Delete(UserMeKey(userID))
}
// ============================================================
// 消息缓存 Key 生成函数
// ============================================================
// MessageHashKey 消息详情 Hash key
func MessageHashKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgHash, convID)
}
// MessageIndexKey 消息索引 Sorted Set key
func MessageIndexKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgIndex, convID)
}
// MessageCountKey 消息计数 key
func MessageCountKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgCount, convID)
}
// MessageSeqKey Seq 计数器 key
func MessageSeqKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgSeq, convID)
}
// MessagePageKey 分页缓存 key
func MessagePageKey(convID string, page, pageSize int) string {
return fmt.Sprintf("%s:%s:%d:%d", keyPrefixMsgPage, convID, page, pageSize)
}
// InvalidateMessagePages 失效会话消息分页缓存
func InvalidateMessagePages(cache Cache, conversationID string) {
cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, conversationID))
}

76
internal/cache/repository_adapter.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// ConversationRepositoryAdapter 适配 MessageRepository 到 ConversationRepository 接口
type ConversationRepositoryAdapter struct {
repo *repository.MessageRepository
}
// NewConversationRepositoryAdapter 创建适配器
func NewConversationRepositoryAdapter(repo *repository.MessageRepository) ConversationRepository {
return &ConversationRepositoryAdapter{repo: repo}
}
// GetConversationByID 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetConversationByID(convID string) (*model.Conversation, error) {
return a.repo.GetConversation(convID)
}
// GetConversationsByUserID 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
return a.repo.GetConversations(userID, page, pageSize)
}
// GetParticipant 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetParticipant(convID, userID string) (*model.ConversationParticipant, error) {
return a.repo.GetParticipant(convID, userID)
}
// GetParticipants 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetParticipants(convID string) ([]*model.ConversationParticipant, error) {
return a.repo.GetConversationParticipants(convID)
}
// GetUnreadCount 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetUnreadCount(userID, convID string) (int64, error) {
return a.repo.GetUnreadCount(convID, userID)
}
// MessageRepositoryAdapter 适配 MessageRepository 到 MessageRepository 接口
type MessageRepositoryAdapter struct {
repo *repository.MessageRepository
}
// NewMessageRepositoryAdapter 创建适配器
func NewMessageRepositoryAdapter(repo *repository.MessageRepository) MessageRepository {
return &MessageRepositoryAdapter{repo: repo}
}
// GetMessages 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error) {
return a.repo.GetMessages(convID, page, pageSize)
}
// GetMessagesAfterSeq 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error) {
return a.repo.GetMessagesAfterSeq(convID, afterSeq, limit)
}
// GetMessagesBeforeSeq 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
return a.repo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
}
// CreateMessage 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) CreateMessage(msg *model.Message) error {
return a.repo.CreateMessage(msg)
}
// UpdateConversationLastSeq 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) UpdateConversationLastSeq(convID string, seq int64) error {
return a.repo.UpdateConversationLastSeq(convID, seq)
}