diff --git a/configs/config.yaml b/configs/config.yaml index 6105f30..0b68bc2 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -170,3 +170,23 @@ email: use_tls: true insecure_skip_verify: false timeout: 15 + +# 会话缓存配置 +conversation_cache: + # TTL 配置 + detail_ttl: 5m # 会话详情缓存时间 + list_ttl: 60s # 会话列表缓存时间 + participant_ttl: 5m # 参与者缓存时间 + unread_ttl: 30s # 未读数缓存时间 + + # 消息缓存配置 + message_detail_ttl: 30m # 单条消息详情缓存 + message_list_ttl: 5m # 消息分页列表缓存 + message_index_ttl: 30m # 消息索引缓存 + message_count_ttl: 30m # 消息计数缓存 + + # 批量写入配置 + batch_interval: 5s # 写入间隔 + batch_threshold: 100 # 条数阈值 + batch_max_size: 500 # 单次最大批量 + buffer_max_size: 10000 # 写缓冲最大条数 diff --git a/internal/cache/cache.go b/internal/cache/cache.go index cbb225d..07e1aef 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -5,7 +5,10 @@ import ( "encoding/json" "fmt" "log" + "math" "math/rand" + "sort" + "strconv" "strings" "sync" "sync/atomic" @@ -34,6 +37,38 @@ type Cache interface { Increment(key string) int64 // IncrementBy 增加指定值 IncrementBy(key string, value int64) int64 + + // ==================== Hash 操作 ==================== + // HSet 设置 Hash 字段 + HSet(ctx context.Context, key string, field string, value interface{}) error + // HMSet 批量设置 Hash 字段 + HMSet(ctx context.Context, key string, values map[string]interface{}) error + // HGet 获取 Hash 字段值 + HGet(ctx context.Context, key string, field string) (string, error) + // HMGet 批量获取 Hash 字段值 + HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) + // HGetAll 获取 Hash 所有字段 + HGetAll(ctx context.Context, key string) (map[string]string, error) + // HDel 删除 Hash 字段 + HDel(ctx context.Context, key string, fields ...string) error + + // ==================== Sorted Set 操作 ==================== + // ZAdd 添加 Sorted Set 成员 + ZAdd(ctx context.Context, key string, score float64, member string) error + // ZRangeByScore 按分数范围获取成员(升序) + ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) + // ZRevRangeByScore 按分数范围获取成员(降序) + ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) + // ZRem 删除 Sorted Set 成员 + ZRem(ctx context.Context, key string, members ...interface{}) error + // ZCard 获取 Sorted Set 成员数量 + ZCard(ctx context.Context, key string) (int64, error) + + // ==================== 计数器操作 ==================== + // Incr 原子递增(返回新值) + Incr(ctx context.Context, key string) (int64, error) + // Expire 设置过期时间 + Expire(ctx context.Context, key string, ttl time.Duration) error } // cacheItem 缓存项(用于内存缓存降级) @@ -64,16 +99,16 @@ type MetricsSnapshot struct { } type Settings struct { - Enabled bool - KeyPrefix string - DefaultTTL time.Duration - NullTTL time.Duration - JitterRatio float64 - PostListTTL time.Duration - ConversationTTL time.Duration - UnreadCountTTL time.Duration - GroupMembersTTL time.Duration - DisableFlushDB bool + Enabled bool + KeyPrefix string + DefaultTTL time.Duration + NullTTL time.Duration + JitterRatio float64 + PostListTTL time.Duration + ConversationTTL time.Duration + UnreadCountTTL time.Duration + GroupMembersTTL time.Duration + DisableFlushDB bool } var settings = Settings{ @@ -327,6 +362,378 @@ func (c *MemoryCache) Stop() { close(c.stopCleanup) } +// ==================== MemoryCache Hash 操作 ==================== + +// hashItem Hash 存储项 +type hashItem struct { + fields sync.Map +} + +// HSet 设置 Hash 字段 +func (c *MemoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error { + key = normalizeKey(key) + item, _ := c.items.Load(key) + var h *hashItem + if item == nil { + h = &hashItem{} + c.items.Store(key, &cacheItem{value: h, expiration: 0}) + } else { + ci := item.(*cacheItem) + if ci.isExpired() { + h = &hashItem{} + c.items.Store(key, &cacheItem{value: h, expiration: 0}) + } else { + h = ci.value.(*hashItem) + } + } + h.fields.Store(field, value) + return nil +} + +// HMSet 批量设置 Hash 字段 +func (c *MemoryCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error { + for field, value := range values { + if err := c.HSet(ctx, key, field, value); err != nil { + return err + } + } + return nil +} + +// HGet 获取 Hash 字段值 +func (c *MemoryCache) HGet(ctx context.Context, key string, field string) (string, error) { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return "", fmt.Errorf("key not found") + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return "", fmt.Errorf("key not found") + } + h, ok := ci.value.(*hashItem) + if !ok { + return "", fmt.Errorf("key is not a hash") + } + val, ok := h.fields.Load(field) + if !ok { + return "", fmt.Errorf("field not found") + } + switch v := val.(type) { + case string: + return v, nil + case []byte: + return string(v), nil + default: + data, _ := json.Marshal(v) + return string(data), nil + } +} + +// HMGet 批量获取 Hash 字段值 +func (c *MemoryCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) { + result := make([]interface{}, len(fields)) + for i, field := range fields { + val, err := c.HGet(ctx, key, field) + if err != nil { + result[i] = nil + } else { + result[i] = val + } + } + return result, nil +} + +// HGetAll 获取 Hash 所有字段 +func (c *MemoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return nil, fmt.Errorf("key not found") + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return nil, fmt.Errorf("key not found") + } + h, ok := ci.value.(*hashItem) + if !ok { + return nil, fmt.Errorf("key is not a hash") + } + result := make(map[string]string) + h.fields.Range(func(k, v interface{}) bool { + keyStr := k.(string) + switch val := v.(type) { + case string: + result[keyStr] = val + case []byte: + result[keyStr] = string(val) + default: + data, _ := json.Marshal(val) + result[keyStr] = string(data) + } + return true + }) + return result, nil +} + +// HDel 删除 Hash 字段 +func (c *MemoryCache) HDel(ctx context.Context, key string, fields ...string) error { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return nil + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return nil + } + h, ok := ci.value.(*hashItem) + if !ok { + return nil + } + for _, field := range fields { + h.fields.Delete(field) + } + return nil +} + +// ==================== MemoryCache Sorted Set 操作 ==================== + +// zItem Sorted Set 成员 +type zItem struct { + score float64 + member string +} + +// zsetItem Sorted Set 存储项 +type zsetItem struct { + members sync.Map // member -> *zItem + byScore *sortedSlice // 按分数排序的切片 +} + +// sortedSlice 简单的排序切片实现 +type sortedSlice struct { + items []*zItem + mu sync.RWMutex +} + +// ZAdd 添加 Sorted Set 成员 +func (c *MemoryCache) ZAdd(ctx context.Context, key string, score float64, member string) error { + key = normalizeKey(key) + item, _ := c.items.Load(key) + var z *zsetItem + if item == nil { + z = &zsetItem{byScore: &sortedSlice{}} + c.items.Store(key, &cacheItem{value: z, expiration: 0}) + } else { + ci := item.(*cacheItem) + if ci.isExpired() { + z = &zsetItem{byScore: &sortedSlice{}} + c.items.Store(key, &cacheItem{value: z, expiration: 0}) + } else { + z = ci.value.(*zsetItem) + } + } + z.members.Store(member, &zItem{score: score, member: member}) + z.byScore.mu.Lock() + // 简单实现:重新构建排序切片 + z.byScore.items = nil + z.members.Range(func(k, v interface{}) bool { + z.byScore.items = append(z.byScore.items, v.(*zItem)) + return true + }) + // 按分数排序 + sort.Slice(z.byScore.items, func(i, j int) bool { + return z.byScore.items[i].score < z.byScore.items[j].score + }) + z.byScore.mu.Unlock() + return nil +} + +// ZRangeByScore 按分数范围获取成员(升序) +func (c *MemoryCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return nil, nil + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return nil, nil + } + z, ok := ci.value.(*zsetItem) + if !ok { + return nil, fmt.Errorf("key is not a sorted set") + } + + minScore, _ := strconv.ParseFloat(min, 64) + maxScore, _ := strconv.ParseFloat(max, 64) + if min == "-inf" { + minScore = math.Inf(-1) + } + if max == "+inf" { + maxScore = math.Inf(1) + } + + z.byScore.mu.RLock() + defer z.byScore.mu.RUnlock() + + var result []string + var skipped int64 = 0 + for _, item := range z.byScore.items { + if item.score < minScore || item.score > maxScore { + continue + } + if skipped < offset { + skipped++ + continue + } + if count > 0 && int64(len(result)) >= count { + break + } + result = append(result, item.member) + } + return result, nil +} + +// ZRevRangeByScore 按分数范围获取成员(降序) +func (c *MemoryCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return nil, nil + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return nil, nil + } + z, ok := ci.value.(*zsetItem) + if !ok { + return nil, fmt.Errorf("key is not a sorted set") + } + + minScore, _ := strconv.ParseFloat(min, 64) + maxScore, _ := strconv.ParseFloat(max, 64) + if min == "-inf" { + minScore = math.Inf(-1) + } + if max == "+inf" { + maxScore = math.Inf(1) + } + + z.byScore.mu.RLock() + defer z.byScore.mu.RUnlock() + + var result []string + var skipped int64 = 0 + // 从后往前遍历 + for i := len(z.byScore.items) - 1; i >= 0; i-- { + item := z.byScore.items[i] + if item.score < minScore || item.score > maxScore { + continue + } + if skipped < offset { + skipped++ + continue + } + if count > 0 && int64(len(result)) >= count { + break + } + result = append(result, item.member) + } + return result, nil +} + +// ZRem 删除 Sorted Set 成员 +func (c *MemoryCache) ZRem(ctx context.Context, key string, members ...interface{}) error { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return nil + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return nil + } + z, ok := ci.value.(*zsetItem) + if !ok { + return nil + } + for _, m := range members { + if member, ok := m.(string); ok { + z.members.Delete(member) + } + } + // 重建排序切片 + z.byScore.mu.Lock() + z.byScore.items = nil + z.members.Range(func(k, v interface{}) bool { + z.byScore.items = append(z.byScore.items, v.(*zItem)) + return true + }) + sort.Slice(z.byScore.items, func(i, j int) bool { + return z.byScore.items[i].score < z.byScore.items[j].score + }) + z.byScore.mu.Unlock() + return nil +} + +// ZCard 获取 Sorted Set 成员数量 +func (c *MemoryCache) ZCard(ctx context.Context, key string) (int64, error) { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return 0, nil + } + ci := item.(*cacheItem) + if ci.isExpired() { + c.items.Delete(key) + return 0, nil + } + z, ok := ci.value.(*zsetItem) + if !ok { + return 0, fmt.Errorf("key is not a sorted set") + } + var count int64 = 0 + z.members.Range(func(k, v interface{}) bool { + count++ + return true + }) + return count, nil +} + +// ==================== MemoryCache 计数器操作 ==================== + +// Incr 原子递增(返回新值) +func (c *MemoryCache) Incr(ctx context.Context, key string) (int64, error) { + return c.IncrementBy(key, 1), nil +} + +// Expire 设置过期时间 +func (c *MemoryCache) Expire(ctx context.Context, key string, ttl time.Duration) error { + key = normalizeKey(key) + item, ok := c.items.Load(key) + if !ok { + return fmt.Errorf("key not found") + } + ci := item.(*cacheItem) + var expiration int64 + if ttl > 0 { + expiration = time.Now().Add(ttl).UnixNano() + } + c.items.Store(key, &cacheItem{ + value: ci.value, + expiration: expiration, + }) + return nil +} + // RedisCache Redis缓存实现 type RedisCache struct { client *redisPkg.Client @@ -451,6 +858,91 @@ func (c *RedisCache) IncrementBy(key string, value int64) int64 { return result } +// ==================== RedisCache Hash 操作 ==================== + +// HSet 设置 Hash 字段 +func (c *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error { + key = normalizeKey(key) + return c.client.HSet(ctx, key, field, value) +} + +// HMSet 批量设置 Hash 字段 +func (c *RedisCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error { + key = normalizeKey(key) + return c.client.HMSet(ctx, key, values) +} + +// HGet 获取 Hash 字段值 +func (c *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) { + key = normalizeKey(key) + return c.client.HGet(ctx, key, field) +} + +// HMGet 批量获取 Hash 字段值 +func (c *RedisCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) { + key = normalizeKey(key) + return c.client.HMGet(ctx, key, fields...) +} + +// HGetAll 获取 Hash 所有字段 +func (c *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) { + key = normalizeKey(key) + return c.client.HGetAll(ctx, key) +} + +// HDel 删除 Hash 字段 +func (c *RedisCache) HDel(ctx context.Context, key string, fields ...string) error { + key = normalizeKey(key) + return c.client.HDel(ctx, key, fields...) +} + +// ==================== RedisCache Sorted Set 操作 ==================== + +// ZAdd 添加 Sorted Set 成员 +func (c *RedisCache) ZAdd(ctx context.Context, key string, score float64, member string) error { + key = normalizeKey(key) + return c.client.ZAdd(ctx, key, score, member) +} + +// ZRangeByScore 按分数范围获取成员(升序) +func (c *RedisCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) { + key = normalizeKey(key) + return c.client.ZRangeByScore(ctx, key, min, max, offset, count) +} + +// ZRevRangeByScore 按分数范围获取成员(降序) +func (c *RedisCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) { + key = normalizeKey(key) + return c.client.ZRevRangeByScore(ctx, key, max, min, offset, count) +} + +// ZRem 删除 Sorted Set 成员 +func (c *RedisCache) ZRem(ctx context.Context, key string, members ...interface{}) error { + key = normalizeKey(key) + return c.client.ZRem(ctx, key, members...) +} + +// ZCard 获取 Sorted Set 成员数量 +func (c *RedisCache) ZCard(ctx context.Context, key string) (int64, error) { + key = normalizeKey(key) + return c.client.ZCard(ctx, key) +} + +// ==================== RedisCache 计数器操作 ==================== + +// Incr 原子递增(返回新值) +func (c *RedisCache) Incr(ctx context.Context, key string) (int64, error) { + key = normalizeKey(key) + return c.client.Incr(ctx, key) +} + +// Expire 设置过期时间 +func (c *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error { + key = normalizeKey(key) + _, err := c.client.Expire(ctx, key, ttl) + return err +} + // 全局缓存实例 var globalCache Cache var once sync.Once diff --git a/internal/cache/conversation_cache.go b/internal/cache/conversation_cache.go new file mode 100644 index 0000000..9114c7e --- /dev/null +++ b/internal/cache/conversation_cache.go @@ -0,0 +1,724 @@ +package cache + +import ( + "context" + "encoding/json" + "fmt" + "log" + "time" + + "carrot_bbs/internal/model" +) + +// CachedConversation 带缓存元数据的会话 +type CachedConversation struct { + Data *model.Conversation // 实际数据 + Version int64 // 版本号(CAS 更新用) + UpdatedAt time.Time // 最后更新时间 + AccessAt time.Time // 最后访问时间(用于 TTL 延长) +} + +// CachedParticipant 带缓存元数据的参与者 +type CachedParticipant struct { + Data *model.ConversationParticipant + Version int64 + UpdatedAt time.Time + AccessAt time.Time +} + +// CachedMessage 带缓存元数据的消息 +type CachedMessage struct { + Data *model.Message `json:"data"` // 消息数据 + Seq int64 `json:"seq"` // 消息序号 + CreatedAt time.Time `json:"created_at"` // 创建时间 +} + +// MessageCacheData Redis 中存储的消息 Hash 结构 +type MessageCacheData struct { + ID string `json:"id"` + ConversationID string `json:"conversation_id"` + SenderID string `json:"sender_id"` + Seq int64 `json:"seq"` + Segments json.RawMessage `json:"segments"` + ReplyToID *string `json:"reply_to_id,omitempty"` + Status string `json:"status"` + Category string `json:"category"` + SystemType string `json:"system_type,omitempty"` + ExtraData json.RawMessage `json:"extra_data,omitempty"` + MentionUsers string `json:"mention_users"` + MentionAll bool `json:"mention_all"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// PageCache 分页缓存 +type PageCache struct { + Seqs []int64 `json:"seqs"` // 当前页的消息 seq 列表 + Total int64 `json:"total"` // 消息总数 + Page int `json:"page"` // 当前页码 + PageSize int `json:"page_size"` // 每页大小 + HasMore bool `json:"has_more"` // 是否有更多 + UpdatedAt time.Time `json:"updated_at"` // 更新时间 +} + +// ConversationCacheSettings 缓存配置 +type ConversationCacheSettings struct { + DetailTTL time.Duration // 会话详情 TTL (5min) + ListTTL time.Duration // 会话列表 TTL (60s) + ParticipantTTL time.Duration // 参与者 TTL (5min) + UnreadTTL time.Duration // 未读数 TTL (30s) + + // 消息缓存配置 + MessageDetailTTL time.Duration // 单条消息详情缓存 (30min) + MessageListTTL time.Duration // 消息分页列表缓存 (5min) + MessageIndexTTL time.Duration // 消息索引缓存 (30min) + MessageCountTTL time.Duration // 消息计数缓存 (30min) +} + +// DefaultConversationCacheSettings 返回默认配置 +func DefaultConversationCacheSettings() *ConversationCacheSettings { + return &ConversationCacheSettings{ + DetailTTL: 5 * time.Minute, + ListTTL: 60 * time.Second, + ParticipantTTL: 5 * time.Minute, + UnreadTTL: 30 * time.Second, + MessageDetailTTL: 30 * time.Minute, + MessageListTTL: 5 * time.Minute, + MessageIndexTTL: 30 * time.Minute, + MessageCountTTL: 30 * time.Minute, + } +} + +// parseSegments 将 JSON RawMessage 解析为 MessageSegments +func parseSegments(data json.RawMessage) model.MessageSegments { + if data == nil { + return nil + } + var segments model.MessageSegments + if err := json.Unmarshal(data, &segments); err != nil { + return nil + } + return segments +} + +// serializeSegments 将 MessageSegments 序列化为 JSON RawMessage +func serializeSegments(segments model.MessageSegments) json.RawMessage { + if segments == nil { + return nil + } + data, err := json.Marshal(segments) + if err != nil { + return nil + } + return data +} + +// ToModel 将 MessageCacheData 转换为 model.Message +func (m *MessageCacheData) ToModel() *model.Message { + return &model.Message{ + ID: m.ID, + ConversationID: m.ConversationID, + SenderID: m.SenderID, + Seq: m.Seq, + Segments: parseSegments(m.Segments), + ReplyToID: m.ReplyToID, + Status: model.MessageStatus(m.Status), + Category: model.MessageCategory(m.Category), + SystemType: model.SystemMessageType(m.SystemType), + ExtraData: parseExtraData(m.ExtraData), + MentionUsers: m.MentionUsers, + MentionAll: m.MentionAll, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + } +} + +// MessageCacheDataFromModel 从 model.Message 创建 MessageCacheData +func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData { + return &MessageCacheData{ + ID: msg.ID, + ConversationID: msg.ConversationID, + SenderID: msg.SenderID, + Seq: msg.Seq, + Segments: serializeSegments(msg.Segments), + ReplyToID: msg.ReplyToID, + Status: string(msg.Status), + Category: string(msg.Category), + SystemType: string(msg.SystemType), + ExtraData: serializeExtraData(msg.ExtraData), + MentionUsers: msg.MentionUsers, + MentionAll: msg.MentionAll, + CreatedAt: msg.CreatedAt, + UpdatedAt: msg.UpdatedAt, + } +} + +// parseExtraData 将 JSON RawMessage 解析为 ExtraData +func parseExtraData(data json.RawMessage) *model.ExtraData { + if data == nil { + return nil + } + var extraData model.ExtraData + if err := json.Unmarshal(data, &extraData); err != nil { + return nil + } + return &extraData +} + +// serializeExtraData 将 ExtraData 序列化为 JSON RawMessage +func serializeExtraData(extraData *model.ExtraData) json.RawMessage { + if extraData == nil { + return nil + } + data, err := json.Marshal(extraData) + if err != nil { + return nil + } + return data +} + +// ============================================================ +// 缓存 Key 常量和生成函数 +// ============================================================ + +const ( + keyPrefixConv = "conv" // 会话详情 + keyPrefixConvPart = "conv_part" // 参与者列表 + keyPrefixConvPartUser = "conv_part_user" // 用户参与者信息 +) + +// ConversationKey 会话详情缓存 key +func ConversationKey(convID string) string { + return fmt.Sprintf("%s:%s", keyPrefixConv, convID) +} + +// ParticipantListKey 参与者列表缓存 key +func ParticipantListKey(convID string) string { + return fmt.Sprintf("%s:%s", keyPrefixConvPart, convID) +} + +// ParticipantKey 用户参与者信息缓存 key +func ParticipantKey(convID, userID string) string { + return fmt.Sprintf("%s:%s:%s", keyPrefixConvPartUser, convID, userID) +} + +// ============================================================ +// ConversationRepository 接口定义 +// ============================================================ + +// ConversationRepository 会话数据仓库接口(用于依赖注入) +type ConversationRepository interface { + GetConversationByID(convID string) (*model.Conversation, error) + GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error) + GetParticipant(convID, userID string) (*model.ConversationParticipant, error) + GetParticipants(convID string) ([]*model.ConversationParticipant, error) + GetUnreadCount(userID, convID string) (int64, error) +} + +// MessageRepository 消息数据仓库接口(用于依赖注入) +type MessageRepository interface { + GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error) + GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error) + GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error) +} + +// ============================================================ +// ConversationCache 核心实现 +// ============================================================ + +// ConversationCache 会话缓存管理器 +type ConversationCache struct { + cache Cache // 底层缓存 + settings *ConversationCacheSettings // 配置 + repo ConversationRepository // 数据仓库接口(用于 cache-aside 回源) + msgRepo MessageRepository // 消息数据仓库接口(用于消息缓存回源) +} + +// NewConversationCache 创建会话缓存管理器 +func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache { + if settings == nil { + settings = DefaultConversationCacheSettings() + } + return &ConversationCache{ + cache: cache, + settings: settings, + repo: repo, + msgRepo: msgRepo, + } +} + +// GetConversation 读取会话(带 TTL 滑动延长) +// 1. 尝试从缓存获取 +// 2. 如果命中,更新 AccessAt 并延长 TTL +// 3. 如果未命中,从 repo 加载并写入缓存 +func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) { + key := ConversationKey(convID) + + // 1. 尝试从缓存获取 + cached, ok := GetTyped[*CachedConversation](c.cache, key) + if ok && cached != nil && cached.Data != nil { + // 2. 命中,更新 AccessAt 并延长 TTL + cached.AccessAt = time.Now() + c.cache.Set(key, cached, c.settings.DetailTTL) + return cached.Data, nil + } + + // 3. 未命中,从 repo 加载 + if c.repo == nil { + return nil, fmt.Errorf("repository not configured") + } + + conv, err := c.repo.GetConversationByID(convID) + if err != nil { + return nil, err + } + + // 写入缓存 + now := time.Now() + cachedConv := &CachedConversation{ + Data: conv, + Version: 0, + UpdatedAt: now, + AccessAt: now, + } + c.cache.Set(key, cachedConv, c.settings.DetailTTL) + + return conv, nil +} + +// CachedConversationList 带元数据的会话列表缓存 +type CachedConversationList struct { + Conversations []*model.Conversation `json:"conversations"` + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + UpdatedAt time.Time `json:"updated_at"` + AccessAt time.Time `json:"access_at"` +} + +// GetConversationList 获取用户会话列表(带 TTL 滑动延长) +func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + key := ConversationListKey(userID, page, pageSize) + + // 1. 尝试从缓存获取 + cached, ok := GetTyped[*CachedConversationList](c.cache, key) + if ok && cached != nil { + // 2. 命中,更新 AccessAt 并延长 TTL + cached.AccessAt = time.Now() + c.cache.Set(key, cached, c.settings.ListTTL) + return cached.Conversations, cached.Total, nil + } + + // 3. 未命中,从 repo 加载 + if c.repo == nil { + return nil, 0, fmt.Errorf("repository not configured") + } + + convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize) + if err != nil { + return nil, 0, err + } + + // 写入缓存 + now := time.Now() + cachedList := &CachedConversationList{ + Conversations: convs, + Total: total, + Page: page, + PageSize: pageSize, + UpdatedAt: now, + AccessAt: now, + } + c.cache.Set(key, cachedList, c.settings.ListTTL) + + return convs, total, nil +} + +// GetParticipant 获取参与者信息(带 TTL 滑动延长) +func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) { + key := ParticipantKey(convID, userID) + + // 1. 尝试从缓存获取 + cached, ok := GetTyped[*CachedParticipant](c.cache, key) + if ok && cached != nil && cached.Data != nil { + // 2. 命中,更新 AccessAt 并延长 TTL + cached.AccessAt = time.Now() + c.cache.Set(key, cached, c.settings.ParticipantTTL) + return cached.Data, nil + } + + // 3. 未命中,从 repo 加载 + if c.repo == nil { + return nil, fmt.Errorf("repository not configured") + } + + participant, err := c.repo.GetParticipant(convID, userID) + if err != nil { + return nil, err + } + + // 写入缓存 + now := time.Now() + cachedPart := &CachedParticipant{ + Data: participant, + Version: 0, + UpdatedAt: now, + AccessAt: now, + } + c.cache.Set(key, cachedPart, c.settings.ParticipantTTL) + + return participant, nil +} + +// CachedParticipantList 带元数据的参与者列表缓存 +type CachedParticipantList struct { + Participants []*model.ConversationParticipant `json:"participants"` + UpdatedAt time.Time `json:"updated_at"` + AccessAt time.Time `json:"access_at"` +} + +// GetParticipants 获取会话所有参与者(带 TTL 滑动延长) +func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) { + key := ParticipantListKey(convID) + + // 1. 尝试从缓存获取 + cached, ok := GetTyped[*CachedParticipantList](c.cache, key) + if ok && cached != nil { + // 2. 命中,更新 AccessAt 并延长 TTL + cached.AccessAt = time.Now() + c.cache.Set(key, cached, c.settings.ParticipantTTL) + return cached.Participants, nil + } + + // 3. 未命中,从 repo 加载 + if c.repo == nil { + return nil, fmt.Errorf("repository not configured") + } + + participants, err := c.repo.GetParticipants(convID) + if err != nil { + return nil, err + } + + // 写入缓存 + now := time.Now() + cachedList := &CachedParticipantList{ + Participants: participants, + UpdatedAt: now, + AccessAt: now, + } + c.cache.Set(key, cachedList, c.settings.ParticipantTTL) + + return participants, nil +} + +// CachedUnreadCount 带元数据的未读数缓存 +type CachedUnreadCount struct { + Count int64 `json:"count"` + UpdatedAt time.Time `json:"updated_at"` + AccessAt time.Time `json:"access_at"` +} + +// GetUnreadCount 获取未读数(带 TTL 滑动延长) +func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) { + key := UnreadDetailKey(userID, convID) + + // 1. 尝试从缓存获取 + cached, ok := GetTyped[*CachedUnreadCount](c.cache, key) + if ok && cached != nil { + // 2. 命中,更新 AccessAt 并延长 TTL + cached.AccessAt = time.Now() + c.cache.Set(key, cached, c.settings.UnreadTTL) + return cached.Count, nil + } + + // 3. 未命中,从 repo 加载 + if c.repo == nil { + return 0, fmt.Errorf("repository not configured") + } + + count, err := c.repo.GetUnreadCount(userID, convID) + if err != nil { + return 0, err + } + + // 写入缓存 + now := time.Now() + cachedCount := &CachedUnreadCount{ + Count: count, + UpdatedAt: now, + AccessAt: now, + } + c.cache.Set(key, cachedCount, c.settings.UnreadTTL) + + return count, nil +} + +// ============================================================ +// 缓存失效方法 +// ============================================================ + +// InvalidateConversation 使会话缓存失效 +func (c *ConversationCache) InvalidateConversation(convID string) { + c.cache.Delete(ConversationKey(convID)) +} + +// InvalidateConversationList 使会话列表缓存失效(按用户) +func (c *ConversationCache) InvalidateConversationList(userID string) { + c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", PrefixConversationList, userID)) +} + +// InvalidateParticipant 使参与者缓存失效 +func (c *ConversationCache) InvalidateParticipant(convID, userID string) { + c.cache.Delete(ParticipantKey(convID, userID)) +} + +// InvalidateParticipantList 使参与者列表缓存失效 +func (c *ConversationCache) InvalidateParticipantList(convID string) { + c.cache.Delete(ParticipantListKey(convID)) +} + +// InvalidateUnreadCount 使未读数缓存失效 +func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) { + c.cache.Delete(UnreadDetailKey(userID, convID)) +} + +// ============================================================ +// 消息缓存方法 +// ============================================================ + +// GetMessages 获取会话消息(带缓存) +// 1. 尝试从分页缓存获取 +// 2. 如果命中,从 Hash 中批量获取消息详情 +// 3. 如果未命中,从数据库加载并写入缓存 +func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) { + // 1. 尝试从缓存获取分页数据 + pageKey := MessagePageKey(convID, page, pageSize) + cached, ok := GetTyped[*PageCache](c.cache, pageKey) + if ok && cached != nil { + // TTL 滑动延长 + cached.UpdatedAt = time.Now() + c.cache.Set(pageKey, cached, c.settings.MessageListTTL) + + // 从 Hash 中批量获取消息详情 + if len(cached.Seqs) > 0 { + messages, err := c.getMessagesBySeqs(ctx, convID, cached.Seqs) + if err == nil { + return messages, cached.Total, nil + } + // Hash 获取失败,继续从数据库加载 + } + } + + // 2. 缓存未命中,从数据库加载 + if c.msgRepo == nil { + return nil, 0, fmt.Errorf("message repository not configured") + } + + messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize) + if err != nil { + return nil, 0, err + } + + // 3. 写入缓存 + seqs := make([]int64, len(messages)) + for i, msg := range messages { + seqs[i] = msg.Seq + // 异步写入消息详情到 Hash + go c.asyncCacheMessage(context.Background(), convID, msg) + } + + pageCache := &PageCache{ + Seqs: seqs, + Total: total, + Page: page, + PageSize: pageSize, + HasMore: int64(page*pageSize) < total, + UpdatedAt: time.Now(), + } + c.cache.Set(pageKey, pageCache, c.settings.MessageListTTL) + + return messages, total, nil +} + +// GetMessagesAfterSeq 获取指定 seq 之后的消息(增量同步) +// 使用 Sorted Set 的 ZRangeByScore 实现 +func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) { + indexKey := MessageIndexKey(convID) + + // 1. 尝试从 Sorted Set 获取 seq 列表 + members, err := c.cache.ZRangeByScore(ctx, indexKey, fmt.Sprintf("%d", afterSeq+1), "+inf", 0, int64(limit)) + if err != nil { + return nil, err + } + + // 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情 + if len(members) > 0 { + seqs := make([]int64, 0, len(members)) + for _, member := range members { + var seq int64 + if _, err := fmt.Sscanf(member, "%d", &seq); err == nil { + seqs = append(seqs, seq) + } + } + return c.getMessagesBySeqs(ctx, convID, seqs) + } + + // 3. Sorted Set 未命中,从数据库加载 + if c.msgRepo == nil { + return nil, fmt.Errorf("message repository not configured") + } + + messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit) + if err != nil { + return nil, err + } + + // 4. 异步写入缓存 + for _, msg := range messages { + go c.asyncCacheMessage(context.Background(), convID, msg) + } + + return messages, nil +} + +// GetMessagesBeforeSeq 获取指定 seq 之前的历史消息(下拉加载) +// 使用 Sorted Set 的 ZRevRangeByScore 实现 +func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) { + indexKey := MessageIndexKey(convID) + + // 1. 尝试从 Sorted Set 获取 seq 列表(降序) + members, err := c.cache.ZRevRangeByScore(ctx, indexKey, fmt.Sprintf("%d", beforeSeq-1), "-inf", 0, int64(limit)) + if err != nil { + return nil, err + } + + // 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情 + if len(members) > 0 { + seqs := make([]int64, 0, len(members)) + for _, member := range members { + var seq int64 + if _, err := fmt.Sscanf(member, "%d", &seq); err == nil { + seqs = append(seqs, seq) + } + } + return c.getMessagesBySeqs(ctx, convID, seqs) + } + + // 3. Sorted Set 未命中,从数据库加载 + if c.msgRepo == nil { + return nil, fmt.Errorf("message repository not configured") + } + + messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit) + if err != nil { + return nil, err + } + + // 4. 异步写入缓存 + for _, msg := range messages { + go c.asyncCacheMessage(context.Background(), convID, msg) + } + + return messages, nil +} + +// CacheMessage 缓存单条消息(立即写入缓存) +// 写入 Hash、Sorted Set、更新计数 +func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error { + hashKey := MessageHashKey(convID) + indexKey := MessageIndexKey(convID) + + msgData := MessageCacheDataFromModel(msg) + data, err := json.Marshal(msgData) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // HSET 消息详情 + if err := c.cache.HSet(ctx, hashKey, fmt.Sprintf("%d", msg.Seq), string(data)); err != nil { + return fmt.Errorf("failed to set hash: %w", err) + } + + // ZADD 消息索引 + if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), fmt.Sprintf("%d", msg.Seq)); err != nil { + return fmt.Errorf("failed to add to sorted set: %w", err) + } + + // 设置 TTL + c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL) + c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL) + + // INCR 消息计数 + c.cache.Incr(ctx, MessageCountKey(convID)) + + return nil +} + +// InvalidateMessageCache 使消息缓存失效 +func (c *ConversationCache) InvalidateMessageCache(convID string) { + c.cache.Delete(MessageHashKey(convID)) + c.cache.Delete(MessageIndexKey(convID)) + c.cache.Delete(MessageCountKey(convID)) + // 删除所有分页缓存 + c.InvalidateMessagePages(convID) +} + +// InvalidateMessagePages 仅使消息分页缓存失效 +// 新消息写入后会导致分页内容和总数变化,需要清理该会话所有分页缓存。 +func (c *ConversationCache) InvalidateMessagePages(convID string) { + c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID)) +} + +// ============================================================ +// 内部辅助方法 +// ============================================================ + +// getMessagesBySeqs 从 Hash 中批量获取消息 +func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) { + if len(seqs) == 0 { + return nil, nil + } + + hashKey := MessageHashKey(convID) + fields := make([]string, len(seqs)) + for i, seq := range seqs { + fields[i] = fmt.Sprintf("%d", seq) + } + + // 批量获取 + values, err := c.cache.HMGet(ctx, hashKey, fields...) + if err != nil { + return nil, err + } + + messages := make([]*model.Message, 0, len(seqs)) + for _, val := range values { + if val == nil { + continue + } + var msgData MessageCacheData + switch v := val.(type) { + case string: + if err := json.Unmarshal([]byte(v), &msgData); err != nil { + continue + } + case []byte: + if err := json.Unmarshal(v, &msgData); err != nil { + continue + } + default: + continue + } + messages = append(messages, msgData.ToModel()) + } + + return messages, nil +} + +// asyncCacheMessage 异步缓存单条消息 +func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) { + if err := c.CacheMessage(ctx, convID, msg); err != nil { + log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err) + } +} diff --git a/internal/cache/keys.go b/internal/cache/keys.go index 0cfd8a0..7fc08e5 100644 --- a/internal/cache/keys.go +++ b/internal/cache/keys.go @@ -26,6 +26,13 @@ const ( // 用户相关 PrefixUserInfo = "users:info" PrefixUserMe = "users:me" + + // 消息缓存相关 + keyPrefixMsgHash = "msg_hash" // 消息详情 Hash + keyPrefixMsgIndex = "msg_index" // 消息索引 Sorted Set + keyPrefixMsgCount = "msg_count" // 消息计数 + keyPrefixMsgSeq = "msg_seq" // Seq 计数器 + keyPrefixMsgPage = "msg_page" // 分页缓存 ) // PostListKey 生成帖子列表缓存键 @@ -145,3 +152,37 @@ func InvalidateUserInfo(cache Cache, userID string) { cache.Delete(UserInfoKey(userID)) cache.Delete(UserMeKey(userID)) } + +// ============================================================ +// 消息缓存 Key 生成函数 +// ============================================================ + +// MessageHashKey 消息详情 Hash key +func MessageHashKey(convID string) string { + return fmt.Sprintf("%s:%s", keyPrefixMsgHash, convID) +} + +// MessageIndexKey 消息索引 Sorted Set key +func MessageIndexKey(convID string) string { + return fmt.Sprintf("%s:%s", keyPrefixMsgIndex, convID) +} + +// MessageCountKey 消息计数 key +func MessageCountKey(convID string) string { + return fmt.Sprintf("%s:%s", keyPrefixMsgCount, convID) +} + +// MessageSeqKey Seq 计数器 key +func MessageSeqKey(convID string) string { + return fmt.Sprintf("%s:%s", keyPrefixMsgSeq, convID) +} + +// MessagePageKey 分页缓存 key +func MessagePageKey(convID string, page, pageSize int) string { + return fmt.Sprintf("%s:%s:%d:%d", keyPrefixMsgPage, convID, page, pageSize) +} + +// InvalidateMessagePages 失效会话消息分页缓存 +func InvalidateMessagePages(cache Cache, conversationID string) { + cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, conversationID)) +} diff --git a/internal/cache/repository_adapter.go b/internal/cache/repository_adapter.go new file mode 100644 index 0000000..b8d94f9 --- /dev/null +++ b/internal/cache/repository_adapter.go @@ -0,0 +1,76 @@ +package cache + +import ( + "carrot_bbs/internal/model" + "carrot_bbs/internal/repository" +) + +// ConversationRepositoryAdapter 适配 MessageRepository 到 ConversationRepository 接口 +type ConversationRepositoryAdapter struct { + repo *repository.MessageRepository +} + +// NewConversationRepositoryAdapter 创建适配器 +func NewConversationRepositoryAdapter(repo *repository.MessageRepository) ConversationRepository { + return &ConversationRepositoryAdapter{repo: repo} +} + +// GetConversationByID 实现 ConversationRepository 接口 +func (a *ConversationRepositoryAdapter) GetConversationByID(convID string) (*model.Conversation, error) { + return a.repo.GetConversation(convID) +} + +// GetConversationsByUserID 实现 ConversationRepository 接口 +func (a *ConversationRepositoryAdapter) GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + return a.repo.GetConversations(userID, page, pageSize) +} + +// GetParticipant 实现 ConversationRepository 接口 +func (a *ConversationRepositoryAdapter) GetParticipant(convID, userID string) (*model.ConversationParticipant, error) { + return a.repo.GetParticipant(convID, userID) +} + +// GetParticipants 实现 ConversationRepository 接口 +func (a *ConversationRepositoryAdapter) GetParticipants(convID string) ([]*model.ConversationParticipant, error) { + return a.repo.GetConversationParticipants(convID) +} + +// GetUnreadCount 实现 ConversationRepository 接口 +func (a *ConversationRepositoryAdapter) GetUnreadCount(userID, convID string) (int64, error) { + return a.repo.GetUnreadCount(convID, userID) +} + +// MessageRepositoryAdapter 适配 MessageRepository 到 MessageRepository 接口 +type MessageRepositoryAdapter struct { + repo *repository.MessageRepository +} + +// NewMessageRepositoryAdapter 创建适配器 +func NewMessageRepositoryAdapter(repo *repository.MessageRepository) MessageRepository { + return &MessageRepositoryAdapter{repo: repo} +} + +// GetMessages 实现 MessageRepository 接口 +func (a *MessageRepositoryAdapter) GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error) { + return a.repo.GetMessages(convID, page, pageSize) +} + +// GetMessagesAfterSeq 实现 MessageRepository 接口 +func (a *MessageRepositoryAdapter) GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error) { + return a.repo.GetMessagesAfterSeq(convID, afterSeq, limit) +} + +// GetMessagesBeforeSeq 实现 MessageRepository 接口 +func (a *MessageRepositoryAdapter) GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error) { + return a.repo.GetMessagesBeforeSeq(convID, beforeSeq, limit) +} + +// CreateMessage 实现 MessageRepository 接口 +func (a *MessageRepositoryAdapter) CreateMessage(msg *model.Message) error { + return a.repo.CreateMessage(msg) +} + +// UpdateConversationLastSeq 实现 MessageRepository 接口 +func (a *MessageRepositoryAdapter) UpdateConversationLastSeq(convID string, seq int64) error { + return a.repo.UpdateConversationLastSeq(convID, seq) +} diff --git a/internal/config/config.go b/internal/config/config.go index 056991b..b0d3f2f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,18 +15,19 @@ import ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Cache CacheConfig `mapstructure:"cache"` - S3 S3Config `mapstructure:"s3"` - JWT JWTConfig `mapstructure:"jwt"` - Log LogConfig `mapstructure:"log"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Upload UploadConfig `mapstructure:"upload"` - Gorse GorseConfig `mapstructure:"gorse"` - OpenAI OpenAIConfig `mapstructure:"openai"` - Email EmailConfig `mapstructure:"email"` + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Cache CacheConfig `mapstructure:"cache"` + S3 S3Config `mapstructure:"s3"` + JWT JWTConfig `mapstructure:"jwt"` + Log LogConfig `mapstructure:"log"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Upload UploadConfig `mapstructure:"upload"` + Gorse GorseConfig `mapstructure:"gorse"` + OpenAI OpenAIConfig `mapstructure:"openai"` + Email EmailConfig `mapstructure:"email"` + ConversationCache ConversationCacheConfig `mapstructure:"conversation_cache"` } type ServerConfig struct { @@ -173,6 +174,73 @@ type EmailConfig struct { Timeout int `mapstructure:"timeout"` } +// ConversationCacheConfig 会话缓存配置 +type ConversationCacheConfig struct { + // TTL 配置 + DetailTTL string `mapstructure:"detail_ttl"` + ListTTL string `mapstructure:"list_ttl"` + ParticipantTTL string `mapstructure:"participant_ttl"` + UnreadTTL string `mapstructure:"unread_ttl"` + + // 消息缓存配置 + MessageDetailTTL string `mapstructure:"message_detail_ttl"` + MessageListTTL string `mapstructure:"message_list_ttl"` + MessageIndexTTL string `mapstructure:"message_index_ttl"` + MessageCountTTL string `mapstructure:"message_count_ttl"` + + // 批量写入配置 + BatchInterval string `mapstructure:"batch_interval"` + BatchThreshold int `mapstructure:"batch_threshold"` + BatchMaxSize int `mapstructure:"batch_max_size"` + BufferMaxSize int `mapstructure:"buffer_max_size"` +} + +// ConversationCacheSettings 会话缓存运行时配置(用于传递给 cache 包) +type ConversationCacheSettings struct { + DetailTTL time.Duration + ListTTL time.Duration + ParticipantTTL time.Duration + UnreadTTL time.Duration + MessageDetailTTL time.Duration + MessageListTTL time.Duration + MessageIndexTTL time.Duration + MessageCountTTL time.Duration + BatchInterval time.Duration + BatchThreshold int + BatchMaxSize int + BufferMaxSize int +} + +// ToSettings 将 ConversationCacheConfig 转换为 ConversationCacheSettings +func (c *ConversationCacheConfig) ToSettings() *ConversationCacheSettings { + return &ConversationCacheSettings{ + DetailTTL: parseDuration(c.DetailTTL, 5*time.Minute), + ListTTL: parseDuration(c.ListTTL, 60*time.Second), + ParticipantTTL: parseDuration(c.ParticipantTTL, 5*time.Minute), + UnreadTTL: parseDuration(c.UnreadTTL, 30*time.Second), + MessageDetailTTL: parseDuration(c.MessageDetailTTL, 30*time.Minute), + MessageListTTL: parseDuration(c.MessageListTTL, 5*time.Minute), + MessageIndexTTL: parseDuration(c.MessageIndexTTL, 30*time.Minute), + MessageCountTTL: parseDuration(c.MessageCountTTL, 30*time.Minute), + BatchInterval: parseDuration(c.BatchInterval, 5*time.Second), + BatchThreshold: c.BatchThreshold, + BatchMaxSize: c.BatchMaxSize, + BufferMaxSize: c.BufferMaxSize, + } +} + +// parseDuration 解析持续时间字符串,如果解析失败则返回默认值 +func parseDuration(s string, defaultVal time.Duration) time.Duration { + if s == "" { + return defaultVal + } + d, err := time.ParseDuration(s) + if err != nil { + return defaultVal + } + return d +} + func Load(configPath string) (*Config, error) { viper.SetConfigFile(configPath) viper.SetConfigType("yaml") @@ -259,6 +327,19 @@ func Load(configPath string) (*Config, error) { viper.SetDefault("email.use_tls", true) viper.SetDefault("email.insecure_skip_verify", false) viper.SetDefault("email.timeout", 15) + // ConversationCache 默认值 + viper.SetDefault("conversation_cache.detail_ttl", "5m") + viper.SetDefault("conversation_cache.list_ttl", "60s") + viper.SetDefault("conversation_cache.participant_ttl", "5m") + viper.SetDefault("conversation_cache.unread_ttl", "30s") + viper.SetDefault("conversation_cache.message_detail_ttl", "30m") + viper.SetDefault("conversation_cache.message_list_ttl", "5m") + viper.SetDefault("conversation_cache.message_index_ttl", "30m") + viper.SetDefault("conversation_cache.message_count_ttl", "30m") + viper.SetDefault("conversation_cache.batch_interval", "5s") + viper.SetDefault("conversation_cache.batch_threshold", 100) + viper.SetDefault("conversation_cache.batch_max_size", 500) + viper.SetDefault("conversation_cache.buffer_max_size", 10000) if err := viper.ReadInConfig(); err != nil { return nil, fmt.Errorf("failed to read config: %w", err) diff --git a/internal/dto/schedule_converter.go b/internal/dto/schedule_converter.go new file mode 100644 index 0000000..55621c0 --- /dev/null +++ b/internal/dto/schedule_converter.go @@ -0,0 +1,35 @@ +package dto + +import ( + "encoding/json" + + "carrot_bbs/internal/model" +) + +func ConvertScheduleCourseToResponse(course *model.ScheduleCourse, weeks []int) *ScheduleCourseResponse { + if course == nil { + return nil + } + return &ScheduleCourseResponse{ + ID: course.ID, + Name: course.Name, + Teacher: course.Teacher, + Location: course.Location, + DayOfWeek: course.DayOfWeek, + StartSection: course.StartSection, + EndSection: course.EndSection, + Weeks: weeks, + Color: course.Color, + } +} + +func ParseWeeksJSON(raw string) []int { + if raw == "" { + return []int{} + } + var weeks []int + if err := json.Unmarshal([]byte(raw), &weeks); err != nil { + return []int{} + } + return weeks +} diff --git a/internal/dto/schedule_dto.go b/internal/dto/schedule_dto.go new file mode 100644 index 0000000..c5f0da5 --- /dev/null +++ b/internal/dto/schedule_dto.go @@ -0,0 +1,13 @@ +package dto + +type ScheduleCourseResponse struct { + ID string `json:"id"` + Name string `json:"name"` + Teacher string `json:"teacher,omitempty"` + Location string `json:"location,omitempty"` + DayOfWeek int `json:"day_of_week"` + StartSection int `json:"start_section"` + EndSection int `json:"end_section"` + Weeks []int `json:"weeks"` + Color string `json:"color,omitempty"` +} diff --git a/internal/handler/group_handler.go b/internal/handler/group_handler.go index 49c057b..0e8cd1f 100644 --- a/internal/handler/group_handler.go +++ b/internal/handler/group_handler.go @@ -38,12 +38,12 @@ func parseGroupID(c *gin.Context) string { // parseUserIDFromPath 从路径参数获取用户ID(UUID格式) func parseUserIDFromPath(c *gin.Context) string { - return c.Param("userId") + return c.Param("user_id") } // parseAnnouncementID 从路径参数获取公告ID func parseAnnouncementID(c *gin.Context) string { - return c.Param("announcementId") + return c.Param("announcement_id") } // ==================== 群组管理 ==================== @@ -454,7 +454,7 @@ func (h *GroupHandler) GetMembers(c *gin.Context) { // ==================== RESTful Action 端点 ==================== // HandleCreateGroup 创建群组 -// POST /api/v1/groups/create +// POST /api/v1/groups func (h *GroupHandler) HandleCreateGroup(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -478,7 +478,7 @@ func (h *GroupHandler) HandleCreateGroup(c *gin.Context) { } // HandleGetUserGroups 获取用户群组列表 -// GET /api/v1/groups/list +// GET /api/v1/groups func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -499,7 +499,6 @@ func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) { } // HandleGetMyMemberInfo 获取我在群组中的成员信息 -// GET /api/v1/groups/get_my_info?group_id=xxx // GET /api/v1/groups/:id/me func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) { userID := parseUserID(c) @@ -551,7 +550,7 @@ func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) { } // HandleDissolveGroup 解散群组 -// POST /api/v1/groups/dissolve +// DELETE /api/v1/groups/:id func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -559,18 +558,13 @@ func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) { return } - var params dto.DissolveGroupParams - if err := c.ShouldBindJSON(¶ms); err != nil { - response.BadRequest(c, err.Error()) - return - } - - if params.GroupID == "" { + groupID := parseGroupID(c) + if groupID == "" { response.BadRequest(c, "group_id is required") return } - if err := h.groupService.DissolveGroup(userID, params.GroupID); err != nil { + if err := h.groupService.DissolveGroup(userID, groupID); err != nil { if err == service.ErrNotGroupOwner { response.Forbidden(c, "只有群主可以解散群组") return @@ -587,7 +581,7 @@ func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) { } // HandleTransferOwner 转让群主 -// POST /api/v1/groups/transfer +// POST /api/v1/groups/:id/transfer func (h *GroupHandler) HandleTransferOwner(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -595,22 +589,24 @@ func (h *GroupHandler) HandleTransferOwner(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.TransferOwnerParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } if params.NewOwnerID == "" { response.BadRequest(c, "new_owner_id is required") return } - if err := h.groupService.TransferOwner(userID, params.GroupID, params.NewOwnerID); err != nil { + if err := h.groupService.TransferOwner(userID, groupID, params.NewOwnerID); err != nil { if err == service.ErrNotGroupOwner { response.Forbidden(c, "只有群主可以转让群主") return @@ -631,7 +627,7 @@ func (h *GroupHandler) HandleTransferOwner(c *gin.Context) { } // HandleInviteMembers 邀请成员加入群组 -// POST /api/v1/groups/invite_members +// POST /api/v1/groups/:id/invitations func (h *GroupHandler) HandleInviteMembers(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -639,18 +635,19 @@ func (h *GroupHandler) HandleInviteMembers(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.InviteMembersParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } - - if err := h.groupService.InviteMembers(userID, params.GroupID, params.MemberIDs); err != nil { + if err := h.groupService.InviteMembers(userID, groupID, params.MemberIDs); err != nil { if err == service.ErrNotGroupMember { response.Forbidden(c, "只有群成员可以邀请他人") return @@ -675,7 +672,7 @@ func (h *GroupHandler) HandleInviteMembers(c *gin.Context) { } // HandleJoinGroup 加入群组 -// POST /api/v1/groups/join +// POST /api/v1/groups/:id/join-requests func (h *GroupHandler) HandleJoinGroup(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -683,18 +680,13 @@ func (h *GroupHandler) HandleJoinGroup(c *gin.Context) { return } - var params dto.JoinGroupParams - if err := c.ShouldBindJSON(¶ms); err != nil { - response.BadRequest(c, err.Error()) - return - } - - if params.GroupID == "" { + groupID := parseGroupID(c) + if groupID == "" { response.BadRequest(c, "group_id is required") return } - if err := h.groupService.JoinGroup(userID, params.GroupID); err != nil { + if err := h.groupService.JoinGroup(userID, groupID); err != nil { if err == service.ErrJoinRequestPending { response.SuccessWithMessage(c, "申请已提交,等待群主/管理员审批", nil) return @@ -723,7 +715,7 @@ func (h *GroupHandler) HandleJoinGroup(c *gin.Context) { } // HandleSetNickname 设置群内昵称 -// POST /api/v1/groups/set_nickname +// PUT /api/v1/groups/:id/members/me/nickname func (h *GroupHandler) HandleSetNickname(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -731,18 +723,19 @@ func (h *GroupHandler) HandleSetNickname(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetNicknameParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } - - if err := h.groupService.SetMemberNickname(userID, params.GroupID, params.Nickname); err != nil { + if err := h.groupService.SetMemberNickname(userID, groupID, params.Nickname); err != nil { if err == service.ErrNotGroupMember { response.BadRequest(c, "不是群成员") return @@ -759,7 +752,7 @@ func (h *GroupHandler) HandleSetNickname(c *gin.Context) { } // HandleSetJoinType 设置加群方式 -// POST /api/v1/groups/set_join_type +// PUT /api/v1/groups/:id/join-type func (h *GroupHandler) HandleSetJoinType(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -767,18 +760,19 @@ func (h *GroupHandler) HandleSetJoinType(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetJoinTypeParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } - - if err := h.groupService.SetJoinType(userID, params.GroupID, params.JoinType); err != nil { + if err := h.groupService.SetJoinType(userID, groupID, params.JoinType); err != nil { if err == service.ErrNotGroupOwner { response.Forbidden(c, "只有群主可以设置加群方式") return @@ -803,7 +797,7 @@ func (h *GroupHandler) HandleSetJoinType(c *gin.Context) { } // HandleCreateAnnouncement 创建群公告 -// POST /api/v1/groups/create_announcement +// POST /api/v1/groups/:id/announcements func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -811,18 +805,19 @@ func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.CreateAnnouncementParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } - - announcement, err := h.groupService.CreateAnnouncement(userID, params.GroupID, params.Content) + announcement, err := h.groupService.CreateAnnouncement(userID, groupID, params.Content) if err != nil { if err == service.ErrNotGroupAdmin { response.Forbidden(c, "只有群主或管理员可以发布公告") @@ -840,7 +835,6 @@ func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) { } // HandleGetAnnouncements 获取群公告列表 -// GET /api/v1/groups/get_announcements?group_id=xxx // GET /api/v1/groups/:id/announcements func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) { userID := parseUserID(c) @@ -872,7 +866,7 @@ func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) { } // HandleDeleteAnnouncement 删除群公告 -// POST /api/v1/groups/delete_announcement +// DELETE /api/v1/groups/:id/announcements/:announcement_id func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -880,22 +874,18 @@ func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) { return } - var params dto.DeleteAnnouncementParams - if err := c.ShouldBindJSON(¶ms); err != nil { - response.BadRequest(c, err.Error()) - return - } - - if params.GroupID == "" { + groupID := parseGroupID(c) + if groupID == "" { response.BadRequest(c, "group_id is required") return } - if params.AnnouncementID == "" { + announcementID := parseAnnouncementID(c) + if announcementID == "" { response.BadRequest(c, "announcement_id is required") return } - if err := h.groupService.DeleteAnnouncement(userID, params.AnnouncementID); err != nil { + if err := h.groupService.DeleteAnnouncement(userID, announcementID); err != nil { if err == service.ErrNotGroupAdmin { response.Forbidden(c, "只有群主或管理员可以删除公告") return @@ -1292,7 +1282,7 @@ func (h *GroupHandler) DeleteAnnouncement(c *gin.Context) { // ==================== RESTful Action 端点 ==================== // HandleSetGroupKick 群组踢人 -// POST /api/v1/groups/set_group_kick +// POST /api/v1/groups/:id/members/kick func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1300,23 +1290,25 @@ func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetGroupKickParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } if params.UserID == "" { response.BadRequest(c, "user_id is required") return } // 使用 RemoveMember 方法 - err := h.groupService.RemoveMember(userID, params.GroupID, params.UserID) + err := h.groupService.RemoveMember(userID, groupID, params.UserID) if err != nil { if err == service.ErrNotGroupAdmin { response.Forbidden(c, "只有群主或管理员可以移除成员") @@ -1342,7 +1334,7 @@ func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) { } // HandleSetGroupBan 群组单人禁言 -// POST /api/v1/groups/set_group_ban +// POST /api/v1/groups/:id/members/ban func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1350,16 +1342,18 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetGroupBanParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } if params.UserID == "" { response.BadRequest(c, "user_id is required") return @@ -1367,8 +1361,8 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) { // duration > 0 或 duration = -1 表示禁言,duration = 0 表示解除禁言 muted := params.Duration != 0 - log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, params.GroupID, params.UserID, params.Duration, muted) - err := h.groupService.MuteMember(userID, params.GroupID, params.UserID, muted) + log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, groupID, params.UserID, params.Duration, muted) + err := h.groupService.MuteMember(userID, groupID, params.UserID, muted) if err != nil { log.Printf("[HandleSetGroupBan] 禁言操作失败: %v", err) } else { @@ -1403,7 +1397,7 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) { } // HandleSetGroupWholeBan 群组全员禁言 -// POST /api/v1/groups/set_group_whole_ban +// PUT /api/v1/groups/:id/ban func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1411,18 +1405,19 @@ func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetGroupWholeBanParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } - - err := h.groupService.SetMuteAll(userID, params.GroupID, params.Enable) + err := h.groupService.SetMuteAll(userID, groupID, params.Enable) if err != nil { if err == service.ErrNotGroupOwner { response.Forbidden(c, "只有群主可以设置全员禁言") @@ -1444,7 +1439,7 @@ func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) { } // HandleSetGroupAdmin 群组设置管理员 -// POST /api/v1/groups/set_group_admin +// PUT /api/v1/groups/:id/members/:user_id/admin func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1452,28 +1447,30 @@ func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + targetUserID := parseUserIDFromPath(c) + if targetUserID == "" { + response.BadRequest(c, "user_id is required") + return + } + var params dto.SetGroupAdminParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } - if params.UserID == "" { - response.BadRequest(c, "user_id is required") - return - } - // 根据 enable 参数设置角色 role := model.GroupRoleMember if params.Enable { role = model.GroupRoleAdmin } - err := h.groupService.SetMemberRole(userID, params.GroupID, params.UserID, role) + err := h.groupService.SetMemberRole(userID, groupID, targetUserID, role) if err != nil { if err == service.ErrNotGroupOwner { response.Forbidden(c, "只有群主可以设置管理员") @@ -1499,7 +1496,7 @@ func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) { } // HandleSetGroupName 设置群名 -// POST /api/v1/groups/set_group_name +// PUT /api/v1/groups/:id/name func (h *GroupHandler) HandleSetGroupName(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1507,16 +1504,18 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetGroupNameParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } if params.GroupName == "" { response.BadRequest(c, "group_name is required") return @@ -1526,7 +1525,7 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) { "name": params.GroupName, } - err := h.groupService.UpdateGroup(userID, params.GroupID, updates) + err := h.groupService.UpdateGroup(userID, groupID, updates) if err != nil { if err == service.ErrNotGroupAdmin { response.Forbidden(c, "没有权限修改群组信息") @@ -1541,12 +1540,12 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) { } // 获取更新后的群组信息 - group, _ := h.groupService.GetGroupByID(params.GroupID) + group, _ := h.groupService.GetGroupByID(groupID) response.Success(c, dto.GroupToResponse(group)) } // HandleSetGroupAvatar 设置群头像 -// POST /api/v1/groups/set_group_avatar +// PUT /api/v1/groups/:id/avatar func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1554,16 +1553,18 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) { return } + groupID := parseGroupID(c) + if groupID == "" { + response.BadRequest(c, "group_id is required") + return + } + var params dto.SetGroupAvatarParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } - if params.GroupID == "" { - response.BadRequest(c, "group_id is required") - return - } if params.Avatar == "" { response.BadRequest(c, "avatar is required") return @@ -1573,7 +1574,7 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) { "avatar": params.Avatar, } - err := h.groupService.UpdateGroup(userID, params.GroupID, updates) + err := h.groupService.UpdateGroup(userID, groupID, updates) if err != nil { if err == service.ErrNotGroupAdmin { response.Forbidden(c, "没有权限修改群组信息") @@ -1588,12 +1589,12 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) { } // 获取更新后的群组信息 - group, _ := h.groupService.GetGroupByID(params.GroupID) + group, _ := h.groupService.GetGroupByID(groupID) response.Success(c, dto.GroupToResponse(group)) } // HandleSetGroupLeave 退出群组 -// POST /api/v1/groups/set_group_leave +// POST /api/v1/groups/:id/leave func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1601,18 +1602,13 @@ func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) { return } - var params dto.SetGroupLeaveParams - if err := c.ShouldBindJSON(¶ms); err != nil { - response.BadRequest(c, err.Error()) - return - } - - if params.GroupID == "" { + groupID := parseGroupID(c) + if groupID == "" { response.BadRequest(c, "group_id is required") return } - err := h.groupService.LeaveGroup(userID, params.GroupID) + err := h.groupService.LeaveGroup(userID, groupID) if err != nil { if err == service.ErrNotGroupMember { response.BadRequest(c, "不是群成员") @@ -1630,7 +1626,7 @@ func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) { } // HandleSetGroupAddRequest 处理加群审批 -// POST /api/v1/groups/set_group_add_request +// POST /api/v1/groups/:id/join-requests/handle func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1678,7 +1674,7 @@ func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) { } // HandleRespondInvite 处理群邀请响应 -// POST /api/v1/groups/respond_invite +// POST /api/v1/groups/:id/join-requests/respond func (h *GroupHandler) HandleRespondInvite(c *gin.Context) { userID := parseUserID(c) if userID == "" { @@ -1725,7 +1721,6 @@ func (h *GroupHandler) HandleRespondInvite(c *gin.Context) { } // HandleGetGroupInfo 获取群信息 -// GET /api/v1/groups/get?group_id=xxx // GET /api/v1/groups/:id func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) { userID := parseUserID(c) @@ -1761,7 +1756,6 @@ func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) { } // HandleGetGroupMemberList 获取群成员列表 -// GET /api/v1/groups/get_members?group_id=xxx // GET /api/v1/groups/:id/members func (h *GroupHandler) HandleGetGroupMemberList(c *gin.Context) { userID := parseUserID(c) diff --git a/internal/handler/message_handler.go b/internal/handler/message_handler.go index fbf29d7..7e8df93 100644 --- a/internal/handler/message_handler.go +++ b/internal/handler/message_handler.go @@ -116,14 +116,14 @@ func (h *MessageHandler) HandleTyping(c *gin.Context) { response.Unauthorized(c, "") return } - var params struct { - ConversationID string `json:"conversation_id" binding:"required"` - } - if err := c.ShouldBindJSON(¶ms); err != nil { - response.BadRequest(c, err.Error()) + + conversationID := getIDParam(c, "id") + if conversationID == "" { + response.BadRequest(c, "conversation id is required") return } - h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID) + + h.chatService.SendTyping(c.Request.Context(), userID, conversationID) response.SuccessWithMessage(c, "typing sent", nil) } @@ -397,8 +397,8 @@ func (h *MessageHandler) SendMessage(c *gin.Context) { } // HandleSendMessage RESTful 风格的发送消息端点 -// POST /api/v1/conversations/send_message -// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]} +// POST /api/v1/conversations/:id/messages +// 请求体格式: {"detail_type": "private", "segments": [{"type": "text", "data": {"text": "嗨~"}}]} func (h *MessageHandler) HandleSendMessage(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { @@ -406,6 +406,12 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) { return } + conversationID := getIDParam(c, "id") + if conversationID == "" { + response.BadRequest(c, "conversation id is required") + return + } + var params dto.SendMessageParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) @@ -413,10 +419,6 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) { } // 验证参数 - if params.ConversationID == "" { - response.BadRequest(c, "conversation_id is required") - return - } if params.DetailType == "" { response.BadRequest(c, "detail_type is required") return @@ -427,7 +429,7 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) { } // 发送消息 - msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID) + msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, params.Segments, params.ReplyToID) if err != nil { response.BadRequest(c, err.Error()) return @@ -480,7 +482,7 @@ func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) { } // HandleGetConversationList 获取会话列表 -// GET /api/v1/conversations/list +// GET /api/v1/conversations func (h *MessageHandler) HandleGetConversationList(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { @@ -780,7 +782,6 @@ func (h *MessageHandler) HandleCreateConversation(c *gin.Context) { } // HandleGetConversation 获取会话详情 -// GET /api/v1/conversations/get?conversation_id=xxx // GET /api/v1/conversations/:id func (h *MessageHandler) HandleGetConversation(c *gin.Context) { userID := c.GetString("user_id") @@ -825,7 +826,6 @@ func (h *MessageHandler) HandleGetConversation(c *gin.Context) { } // HandleGetMessages 获取会话消息 -// GET /api/v1/conversations/get_messages?conversation_id=xxx // GET /api/v1/conversations/:id/messages func (h *MessageHandler) HandleGetMessages(c *gin.Context) { userID := c.GetString("user_id") @@ -913,7 +913,7 @@ func (h *MessageHandler) HandleGetMessages(c *gin.Context) { } // HandleMarkRead 标记已读 -// POST /api/v1/conversations/mark_read +// POST /api/v1/conversations/:id/read func (h *MessageHandler) HandleMarkRead(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { @@ -921,18 +921,19 @@ func (h *MessageHandler) HandleMarkRead(c *gin.Context) { return } - var params dto.MarkReadParams - if err := c.ShouldBindJSON(¶ms); err != nil { + conversationID := getIDParam(c, "id") + if conversationID == "" { + response.BadRequest(c, "conversation id is required") + return + } + + var req dto.MarkReadRequest + if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, err.Error()) return } - if params.ConversationID == "" { - response.BadRequest(c, "conversation_id is required") - return - } - - err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq) + err := h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq) if err != nil { response.BadRequest(c, err.Error()) return @@ -942,7 +943,7 @@ func (h *MessageHandler) HandleMarkRead(c *gin.Context) { } // HandleSetConversationPinned 设置会话置顶 -// POST /api/v1/conversations/set_pinned +// PUT /api/v1/conversations/:id/pinned func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { @@ -950,24 +951,27 @@ func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) { return } - var params dto.SetConversationPinnedParams - if err := c.ShouldBindJSON(¶ms); err != nil { + conversationID := getIDParam(c, "id") + if conversationID == "" { + response.BadRequest(c, "conversation id is required") + return + } + + var req struct { + IsPinned bool `json:"is_pinned"` + } + if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, err.Error()) return } - if params.ConversationID == "" { - response.BadRequest(c, "conversation_id is required") - return - } - - if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil { + if err := h.chatService.SetConversationPinned(c.Request.Context(), conversationID, userID, req.IsPinned); err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{ - "conversation_id": params.ConversationID, - "is_pinned": params.IsPinned, + "conversation_id": conversationID, + "is_pinned": req.IsPinned, }) } diff --git a/internal/handler/schedule_handler.go b/internal/handler/schedule_handler.go new file mode 100644 index 0000000..a459482 --- /dev/null +++ b/internal/handler/schedule_handler.go @@ -0,0 +1,140 @@ +package handler + +import ( + "strconv" + + "github.com/gin-gonic/gin" + + "carrot_bbs/internal/pkg/response" + "carrot_bbs/internal/service" +) + +type ScheduleHandler struct { + scheduleService service.ScheduleService +} + +func NewScheduleHandler(scheduleService service.ScheduleService) *ScheduleHandler { + return &ScheduleHandler{scheduleService: scheduleService} +} + +type createScheduleCourseRequest struct { + Name string `json:"name" binding:"required"` + Teacher string `json:"teacher"` + Location string `json:"location"` + DayOfWeek int `json:"day_of_week" binding:"required"` + StartSection int `json:"start_section" binding:"required"` + EndSection int `json:"end_section" binding:"required"` + Weeks []int `json:"weeks" binding:"required,min=1"` + Color string `json:"color"` +} + +type updateScheduleCourseRequest = createScheduleCourseRequest + +func (h *ScheduleHandler) ListCourses(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + week := 0 + if rawWeek := c.Query("week"); rawWeek != "" { + parsed, err := strconv.Atoi(rawWeek) + if err != nil { + response.BadRequest(c, "invalid week") + return + } + week = parsed + } + + list, err := h.scheduleService.ListCourses(userID, week) + if err != nil { + response.HandleError(c, err, "failed to list schedule courses") + return + } + response.Success(c, gin.H{"list": list}) +} + +func (h *ScheduleHandler) CreateCourse(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + + var req createScheduleCourseRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + created, err := h.scheduleService.CreateCourse(userID, service.CreateScheduleCourseInput{ + Name: req.Name, + Teacher: req.Teacher, + Location: req.Location, + DayOfWeek: req.DayOfWeek, + StartSection: req.StartSection, + EndSection: req.EndSection, + Weeks: req.Weeks, + Color: req.Color, + }) + if err != nil { + response.HandleError(c, err, "failed to create schedule course") + return + } + response.SuccessWithMessage(c, "course created", gin.H{"course": created}) +} + +func (h *ScheduleHandler) UpdateCourse(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + courseID := c.Param("id") + if courseID == "" { + response.BadRequest(c, "invalid course id") + return + } + + var req updateScheduleCourseRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updated, err := h.scheduleService.UpdateCourse(userID, courseID, service.CreateScheduleCourseInput{ + Name: req.Name, + Teacher: req.Teacher, + Location: req.Location, + DayOfWeek: req.DayOfWeek, + StartSection: req.StartSection, + EndSection: req.EndSection, + Weeks: req.Weeks, + Color: req.Color, + }) + if err != nil { + response.HandleError(c, err, "failed to update schedule course") + return + } + response.SuccessWithMessage(c, "course updated", gin.H{"course": updated}) +} + +func (h *ScheduleHandler) DeleteCourse(c *gin.Context) { + userID := c.GetString("user_id") + if userID == "" { + response.Unauthorized(c, "") + return + } + courseID := c.Param("id") + if courseID == "" { + response.BadRequest(c, "invalid course id") + return + } + + if err := h.scheduleService.DeleteCourse(userID, courseID); err != nil { + response.HandleError(c, err, "failed to delete schedule course") + return + } + response.SuccessWithMessage(c, "course deleted", nil) +} diff --git a/internal/model/init.go b/internal/model/init.go index 78cd78e..4e81504 100644 --- a/internal/model/init.go +++ b/internal/model/init.go @@ -143,6 +143,9 @@ func autoMigrate(db *gorm.DB) error { // 自定义表情 &UserSticker{}, + + // 课表 + &ScheduleCourse{}, ) if err != nil { return err diff --git a/internal/model/schedule_course.go b/internal/model/schedule_course.go new file mode 100644 index 0000000..1bf82cf --- /dev/null +++ b/internal/model/schedule_course.go @@ -0,0 +1,35 @@ +package model + +import ( + "time" + + "github.com/google/uuid" + "gorm.io/gorm" +) + +// ScheduleCourse 用户课表课程 +type ScheduleCourse struct { + ID string `json:"id" gorm:"type:varchar(36);primaryKey"` + UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"` + Name string `json:"name" gorm:"type:varchar(120);not null"` + Teacher string `json:"teacher" gorm:"type:varchar(80)"` + Location string `json:"location" gorm:"type:varchar(120)"` + DayOfWeek int `json:"day_of_week" gorm:"index;not null"` // 0=周一, 6=周日 + StartSection int `json:"start_section" gorm:"not null"` + EndSection int `json:"end_section" gorm:"not null"` + Weeks string `json:"weeks" gorm:"type:text;not null"` // JSON 数组字符串 + Color string `json:"color" gorm:"type:varchar(20)"` + CreatedAt time.Time + UpdatedAt time.Time +} + +func (s *ScheduleCourse) BeforeCreate(tx *gorm.DB) error { + if s.ID == "" { + s.ID = uuid.New().String() + } + return nil +} + +func (ScheduleCourse) TableName() string { + return "schedule_courses" +} diff --git a/internal/pkg/openai/client.go b/internal/pkg/openai/client.go index 06ac12e..266c2f3 100644 --- a/internal/pkg/openai/client.go +++ b/internal/pkg/openai/client.go @@ -164,10 +164,17 @@ func (c *clientImpl) moderateSingleBatch( } type chatCompletionsRequest struct { - Model string `json:"model"` - Messages []chatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen3.5思考模式控制 + ThinkingBudget *int `json:"thinking_budget,omitempty"` // 思考过程最大token数 + ResponseFormat *responseFormatConfig `json:"response_format,omitempty"` // 响应格式 +} + +type responseFormatConfig struct { + Type string `json:"type"` // "text" or "json_object" } type chatMessage struct { @@ -227,6 +234,13 @@ func (c *clientImpl) chatCompletion( Temperature: temperature, MaxTokens: maxTokens, } + // 禁用qwen3.5的思考模式,避免产生大量不必要的token消耗 + falseVal := false + reqBody.EnableThinking = &falseVal + zero := 0 + reqBody.ThinkingBudget = &zero + // 使用JSON输出格式 + reqBody.ResponseFormat = &responseFormatConfig{Type: "json_object"} data, err := json.Marshal(reqBody) if err != nil { diff --git a/internal/pkg/redis/redis.go b/internal/pkg/redis/redis.go index 8d3e926..1b7ebc2 100644 --- a/internal/pkg/redis/redis.go +++ b/internal/pkg/redis/redis.go @@ -117,3 +117,117 @@ func (c *Client) Close() error { func (c *Client) IsMiniRedis() bool { return c.isMiniRedis } + +// ==================== Hash 操作 ==================== + +// HSet 设置 Hash 字段 +func (c *Client) HSet(ctx context.Context, key string, field string, value interface{}) error { + return c.rdb.HSet(ctx, key, field, value).Err() +} + +// HMSet 批量设置 Hash 字段 +func (c *Client) HMSet(ctx context.Context, key string, values map[string]interface{}) error { + return c.rdb.HMSet(ctx, key, values).Err() +} + +// HGet 获取 Hash 字段值 +func (c *Client) HGet(ctx context.Context, key string, field string) (string, error) { + return c.rdb.HGet(ctx, key, field).Result() +} + +// HMGet 批量获取 Hash 字段值 +func (c *Client) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) { + return c.rdb.HMGet(ctx, key, fields...).Result() +} + +// HGetAll 获取 Hash 所有字段 +func (c *Client) HGetAll(ctx context.Context, key string) (map[string]string, error) { + return c.rdb.HGetAll(ctx, key).Result() +} + +// HDel 删除 Hash 字段 +func (c *Client) HDel(ctx context.Context, key string, fields ...string) error { + return c.rdb.HDel(ctx, key, fields...).Err() +} + +// HExists 检查 Hash 字段是否存在 +func (c *Client) HExists(ctx context.Context, key string, field string) (bool, error) { + return c.rdb.HExists(ctx, key, field).Result() +} + +// HLen 获取 Hash 字段数量 +func (c *Client) HLen(ctx context.Context, key string) (int64, error) { + return c.rdb.HLen(ctx, key).Result() +} + +// ==================== Sorted Set 操作 ==================== + +// ZAdd 添加 Sorted Set 成员 +func (c *Client) ZAdd(ctx context.Context, key string, score float64, member string) error { + return c.rdb.ZAdd(ctx, key, redis.Z{Score: score, Member: member}).Err() +} + +// ZAddArgs 批量添加 Sorted Set 成员 +func (c *Client) ZAddArgs(ctx context.Context, key string, members ...redis.Z) error { + return c.rdb.ZAdd(ctx, key, members...).Err() +} + +// ZRangeByScore 按分数范围获取成员(升序) +func (c *Client) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) { + return c.rdb.ZRangeByScore(ctx, key, &redis.ZRangeBy{ + Min: min, + Max: max, + Offset: offset, + Count: count, + }).Result() +} + +// ZRevRangeByScore 按分数范围获取成员(降序) +func (c *Client) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) { + return c.rdb.ZRevRangeByScore(ctx, key, &redis.ZRangeBy{ + Min: min, + Max: max, + Offset: offset, + Count: count, + }).Result() +} + +// ZRange 获取指定范围的成员(升序) +func (c *Client) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + return c.rdb.ZRange(ctx, key, start, stop).Result() +} + +// ZRevRange 获取指定范围的成员(降序) +func (c *Client) ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + return c.rdb.ZRevRange(ctx, key, start, stop).Result() +} + +// ZRem 删除 Sorted Set 成员 +func (c *Client) ZRem(ctx context.Context, key string, members ...interface{}) error { + return c.rdb.ZRem(ctx, key, members...).Err() +} + +// ZScore 获取成员分数 +func (c *Client) ZScore(ctx context.Context, key string, member string) (float64, error) { + return c.rdb.ZScore(ctx, key, member).Result() +} + +// ZCard 获取 Sorted Set 成员数量 +func (c *Client) ZCard(ctx context.Context, key string) (int64, error) { + return c.rdb.ZCard(ctx, key).Result() +} + +// ZCount 统计分数范围内的成员数量 +func (c *Client) ZCount(ctx context.Context, key string, min, max string) (int64, error) { + return c.rdb.ZCount(ctx, key, min, max).Result() +} + +// ==================== Pipeline 操作 ==================== + +// Pipeliner Pipeline 接口(使用 redis 库原生接口) +type Pipeliner = redis.Pipeliner + +// Pipeline 创建 Pipeline +func (c *Client) Pipeline() Pipeliner { + return c.rdb.Pipeline() +} diff --git a/internal/repository/message_repo.go b/internal/repository/message_repo.go index 8c6e948..aefc332 100644 --- a/internal/repository/message_repo.go +++ b/internal/repository/message_repo.go @@ -2,6 +2,9 @@ package repository import ( "carrot_bbs/internal/model" + "context" + "fmt" + "strings" "time" "gorm.io/gorm" @@ -172,7 +175,7 @@ func (r *MessageRepository) GetParticipant(conversationID string, userID string) if err == gorm.ErrRecordNotFound { // 检查会话是否存在 var conv model.Conversation - if err := r.db.First(&conv, conversationID).Error; err == nil { + if err := r.db.Where("id = ?", conversationID).First(&conv).Error; err == nil { // 会话存在,添加参与者 participant = model.ConversationParticipant{ ConversationID: conversationID, @@ -284,7 +287,7 @@ func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq // GetNextSeq 获取会话的下一个seq值 func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) { var conv model.Conversation - err := r.db.Select("last_seq").First(&conv, conversationID).Error + err := r.db.Select("last_seq").Where("id = ?", conversationID).First(&conv).Error if err != nil { return 0, err } @@ -296,7 +299,7 @@ func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error { return r.db.Transaction(func(tx *gorm.DB) error { // 获取当前seq并+1 var conv model.Conversation - if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil { + if err := tx.Select("last_seq").Where("id = ?", msg.ConversationID).First(&conv).Error; err != nil { return err } @@ -522,3 +525,117 @@ func (r *MessageRepository) HideConversationForUser(conversationID, userID strin Where("conversation_id = ? AND user_id = ?", conversationID, userID). Update("hidden_at", &now).Error } + +// ParticipantUpdate 参与者更新数据 +type ParticipantUpdate struct { + ConversationID string + UserID string + LastReadSeq int64 +} + +// BatchWriteMessages 批量写入消息 +// 使用 GORM 的 CreateInBatches 实现高效批量插入 +func (r *MessageRepository) BatchWriteMessages(ctx context.Context, messages []*model.Message) error { + if len(messages) == 0 { + return nil + } + return r.db.WithContext(ctx).CreateInBatches(messages, 100).Error +} + +// BatchUpdateParticipants 批量更新参与者(使用 CASE WHEN 优化) +// 使用单条 SQL 更新多条记录,避免循环执行 UPDATE +func (r *MessageRepository) BatchUpdateParticipants(ctx context.Context, updates []ParticipantUpdate) error { + if len(updates) == 0 { + return nil + } + + // 构建 CASE WHEN 批量更新 SQL + // UPDATE conversation_participants + // SET last_read_seq = CASE + // WHEN (conversation_id = '1' AND user_id = 'a') THEN 10 + // WHEN (conversation_id = '2' AND user_id = 'b') THEN 20 + // END, + // updated_at = ? + // WHERE (conversation_id = '1' AND user_id = 'a') + // OR (conversation_id = '2' AND user_id = 'b') + + var cases []string + var whereClauses []string + var args []interface{} + + for _, u := range updates { + cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?") + whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)") + args = append(args, u.ConversationID, u.UserID, u.LastReadSeq, u.ConversationID, u.UserID) + } + + sql := fmt.Sprintf(` + UPDATE conversation_participants + SET last_read_seq = CASE %s END, + updated_at = ? + WHERE %s + `, strings.Join(cases, " "), strings.Join(whereClauses, " OR ")) + + args = append(args, time.Now()) + + return r.db.WithContext(ctx).Exec(sql, args...).Error +} + +// UpdateConversationLastSeqWithContext 更新会话最后消息序号 +func (r *MessageRepository) UpdateConversationLastSeqWithContext(ctx context.Context, convID string, lastSeq int64, lastMsgTime time.Time) error { + return r.db.WithContext(ctx). + Model(&model.Conversation{}). + Where("id = ?", convID). + Updates(map[string]interface{}{ + "last_seq": lastSeq, + "last_msg_time": lastMsgTime, + "updated_at": time.Now(), + }).Error +} + +// BatchWriteMessagesWithTx 在事务中批量写入消息 +func (r *MessageRepository) BatchWriteMessagesWithTx(tx *gorm.DB, messages []*model.Message) error { + if len(messages) == 0 { + return nil + } + return tx.CreateInBatches(messages, 100).Error +} + +// BatchUpdateParticipantsWithTx 在事务中批量更新参与者 +func (r *MessageRepository) BatchUpdateParticipantsWithTx(tx *gorm.DB, updates []ParticipantUpdate) error { + if len(updates) == 0 { + return nil + } + + var cases []string + var whereClauses []string + var args []interface{} + + for _, u := range updates { + cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?") + whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)") + args = append(args, u.ConversationID, u.UserID, u.LastReadSeq, u.ConversationID, u.UserID) + } + + sql := fmt.Sprintf(` + UPDATE conversation_participants + SET last_read_seq = CASE %s END, + updated_at = ? + WHERE %s + `, strings.Join(cases, " "), strings.Join(whereClauses, " OR ")) + + args = append(args, time.Now()) + + return tx.Exec(sql, args...).Error +} + +// UpdateConversationLastSeqWithTx 在事务中更新会话最后消息序号 +func (r *MessageRepository) UpdateConversationLastSeqWithTx(tx *gorm.DB, convID string, lastSeq int64, lastMsgTime time.Time) error { + return tx.Model(&model.Conversation{}). + Where("id = ?", convID). + Updates(map[string]interface{}{ + "last_seq": lastSeq, + "last_msg_time": lastMsgTime, + "updated_at": time.Now(), + }).Error +} diff --git a/internal/repository/schedule_repo.go b/internal/repository/schedule_repo.go new file mode 100644 index 0000000..1de19f7 --- /dev/null +++ b/internal/repository/schedule_repo.go @@ -0,0 +1,66 @@ +package repository + +import ( + "carrot_bbs/internal/model" + + "gorm.io/gorm" +) + +type ScheduleRepository interface { + ListByUserID(userID string) ([]*model.ScheduleCourse, error) + GetByID(id string) (*model.ScheduleCourse, error) + Create(course *model.ScheduleCourse) error + Update(course *model.ScheduleCourse) error + DeleteByID(id string) error + ExistsColorByUser(userID, color, excludeID string) (bool, error) +} + +type scheduleRepository struct { + db *gorm.DB +} + +func NewScheduleRepository(db *gorm.DB) ScheduleRepository { + return &scheduleRepository{db: db} +} + +func (r *scheduleRepository) ListByUserID(userID string) ([]*model.ScheduleCourse, error) { + var courses []*model.ScheduleCourse + err := r.db. + Where("user_id = ?", userID). + Order("day_of_week ASC, start_section ASC, created_at ASC"). + Find(&courses).Error + return courses, err +} + +func (r *scheduleRepository) Create(course *model.ScheduleCourse) error { + return r.db.Create(course).Error +} + +func (r *scheduleRepository) GetByID(id string) (*model.ScheduleCourse, error) { + var course model.ScheduleCourse + if err := r.db.Where("id = ?", id).First(&course).Error; err != nil { + return nil, err + } + return &course, nil +} + +func (r *scheduleRepository) Update(course *model.ScheduleCourse) error { + return r.db.Save(course).Error +} + +func (r *scheduleRepository) DeleteByID(id string) error { + return r.db.Delete(&model.ScheduleCourse{}, "id = ?", id).Error +} + +func (r *scheduleRepository) ExistsColorByUser(userID, color, excludeID string) (bool, error) { + var count int64 + query := r.db.Model(&model.ScheduleCourse{}). + Where("user_id = ? AND LOWER(color) = LOWER(?)", userID, color) + if excludeID != "" { + query = query.Where("id <> ?", excludeID) + } + if err := query.Count(&count).Error; err != nil { + return false, err + } + return count > 0, nil +} diff --git a/internal/router/router.go b/internal/router/router.go index b7297bd..b7017ad 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -23,6 +23,7 @@ type Router struct { stickerHandler *handler.StickerHandler gorseHandler *handler.GorseHandler voteHandler *handler.VoteHandler + scheduleHandler *handler.ScheduleHandler jwtService *service.JWTService } @@ -41,6 +42,7 @@ func New( stickerHandler *handler.StickerHandler, gorseHandler *handler.GorseHandler, voteHandler *handler.VoteHandler, + scheduleHandler *handler.ScheduleHandler, ) *Router { // 设置JWT服务 userHandler.SetJWTService(jwtService) @@ -59,6 +61,7 @@ func New( stickerHandler: stickerHandler, gorseHandler: gorseHandler, voteHandler: voteHandler, + scheduleHandler: scheduleHandler, jwtService: jwtService, } @@ -160,6 +163,18 @@ func (r *Router) setupRoutes() { posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票 } + // 课表路由 + if r.scheduleHandler != nil { + schedule := v1.Group("/schedule") + schedule.Use(authMiddleware) + { + schedule.GET("/courses", r.scheduleHandler.ListCourses) + schedule.POST("/courses", r.scheduleHandler.CreateCourse) + schedule.PUT("/courses/:id", r.scheduleHandler.UpdateCourse) + schedule.DELETE("/courses/:id", r.scheduleHandler.DeleteCourse) + } + } + // 投票选项路由 voteOptions := v1.Group("/vote-options") voteOptions.Use(authMiddleware) diff --git a/internal/service/chat_service.go b/internal/service/chat_service.go index 776a74c..380b0a3 100644 --- a/internal/service/chat_service.go +++ b/internal/service/chat_service.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "log" "time" + "carrot_bbs/internal/cache" "carrot_bbs/internal/dto" "carrot_bbs/internal/model" "carrot_bbs/internal/pkg/sse" @@ -58,6 +60,9 @@ type chatServiceImpl struct { userRepo *repository.UserRepository sensitive SensitiveService sseHub *sse.Hub + + // 缓存相关字段 + conversationCache *cache.ConversationCache } // NewChatService 创建聊天服务 @@ -68,12 +73,25 @@ func NewChatService( sensitive SensitiveService, sseHub *sse.Hub, ) ChatService { + // 创建适配器 + convRepoAdapter := cache.NewConversationRepositoryAdapter(repo) + msgRepoAdapter := cache.NewMessageRepositoryAdapter(repo) + + // 创建会话缓存 + conversationCache := cache.NewConversationCache( + cache.GetCache(), + convRepoAdapter, + msgRepoAdapter, + cache.DefaultConversationCacheSettings(), + ) + return &chatServiceImpl{ - db: db, - repo: repo, - userRepo: userRepo, - sensitive: sensitive, - sseHub: sseHub, + db: db, + repo: repo, + userRepo: userRepo, + sensitive: sensitive, + sseHub: sseHub, + conversationCache: conversationCache, } } @@ -86,18 +104,33 @@ func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payl // GetOrCreateConversation 获取或创建私聊会话 func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) { - return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID) + conv, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID) + if err != nil { + return nil, err + } + + // 失效会话列表缓存 + if s.conversationCache != nil { + s.conversationCache.InvalidateConversationList(user1ID) + s.conversationCache.InvalidateConversationList(user2ID) + } + + return conv, nil } -// GetConversationList 获取用户的会话列表 +// GetConversationList 获取用户的会话列表(带缓存) func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + // 优先使用缓存 + if s.conversationCache != nil { + return s.conversationCache.GetConversationList(ctx, userID, page, pageSize) + } return s.repo.GetConversations(userID, page, pageSize) } -// GetConversationByID 获取会话详情 +// GetConversationByID 获取会话详情(带缓存) func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) { // 验证用户是否是会话参与者 - participant, err := s.repo.GetParticipant(conversationID, userID) + participant, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("conversation not found or no permission") @@ -105,21 +138,33 @@ func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationI return nil, fmt.Errorf("failed to get participant: %w", err) } - // 获取会话信息 - conv, err := s.repo.GetConversation(conversationID) + // 获取会话信息(优先使用缓存) + var conv *model.Conversation + if s.conversationCache != nil { + conv, err = s.conversationCache.GetConversation(ctx, conversationID) + } else { + conv, err = s.repo.GetConversation(conversationID) + } if err != nil { return nil, fmt.Errorf("failed to get conversation: %w", err) } - // 填充用户的已读位置信息 _ = participant // 可以用于返回已读位置等信息 return conv, nil } +// getParticipant 获取参与者信息(优先使用缓存) +func (s *chatServiceImpl) getParticipant(ctx context.Context, conversationID, userID string) (*model.ConversationParticipant, error) { + if s.conversationCache != nil { + return s.conversationCache.GetParticipant(ctx, conversationID, userID) + } + return s.repo.GetParticipant(conversationID, userID) +} + // DeleteConversationForSelf 仅自己删除会话 func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error { - participant, err := s.repo.GetParticipant(conversationID, userID) + participant, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("conversation not found or no permission") @@ -133,12 +178,18 @@ func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, convers if err := s.repo.HideConversationForUser(conversationID, userID); err != nil { return fmt.Errorf("failed to hide conversation: %w", err) } + + // 失效会话列表缓存 + if s.conversationCache != nil { + s.conversationCache.InvalidateConversationList(userID) + } + return nil } // SetConversationPinned 设置会话置顶(用户维度) func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error { - participant, err := s.repo.GetParticipant(conversationID, userID) + participant, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("conversation not found or no permission") @@ -152,13 +203,20 @@ func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversatio if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil { return fmt.Errorf("failed to update pinned status: %w", err) } + + // 失效缓存 + if s.conversationCache != nil { + s.conversationCache.InvalidateParticipant(conversationID, userID) + s.conversationCache.InvalidateConversationList(userID) + } + return nil } // SendMessage 发送消息(使用 segments) func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) { // 首先验证会话是否存在 - conv, err := s.repo.GetConversation(conversationID) + conv, err := s.getConversation(ctx, conversationID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("会话不存在,请重新创建会话") @@ -166,9 +224,9 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv return nil, fmt.Errorf("failed to get conversation: %w", err) } - // 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向 + // 拉黑限制:仅拦截"被拉黑方 -> 拉黑人"方向 if conv.Type == model.ConversationTypePrivate && s.userRepo != nil { - participants, pErr := s.repo.GetConversationParticipants(conversationID) + participants, pErr := s.getParticipants(ctx, conversationID) if pErr != nil { return nil, fmt.Errorf("failed to get participants: %w", pErr) } @@ -209,7 +267,7 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv } // 验证用户是否是会话参与者 - participant, err := s.repo.GetParticipant(conversationID, senderID) + participant, err := s.getParticipant(ctx, conversationID, senderID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("您不是该会话的参与者") @@ -231,11 +289,27 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv return nil, fmt.Errorf("failed to save message: %w", err) } + // 新消息会改变分页结果,先失效分页缓存,避免读到旧列表 + if s.conversationCache != nil { + s.conversationCache.InvalidateMessagePages(conversationID) + } + + // 异步写入缓存 + go func() { + if err := s.cacheMessage(context.Background(), conversationID, message); err != nil { + log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err) + } + }() + // 获取会话中的参与者并发送 SSE - participants, err := s.repo.GetConversationParticipants(conversationID) + participants, err := s.getParticipants(ctx, conversationID) if err == nil { targetIDs := make([]string, 0, len(participants)) for _, p := range participants { + // 私聊场景下,发送者已经从 HTTP 响应拿到消息,避免再通过 SSE 回推导致本端重复展示。 + if conv.Type == model.ConversationTypePrivate && p.UserID == senderID { + continue + } targetIDs = append(targetIDs, p.UserID) } detailType := "private" @@ -250,6 +324,10 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv if p.UserID == senderID { continue } + // 失效未读数缓存 + if s.conversationCache != nil { + s.conversationCache.InvalidateUnreadCount(p.UserID, conversationID) + } if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil { s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{ "conversation_id": conversationID, @@ -259,11 +337,46 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv } } + // 失效会话列表缓存 + if s.conversationCache != nil { + for _, p := range participants { + s.conversationCache.InvalidateConversationList(p.UserID) + } + } + _ = participant // 避免未使用变量警告 return message, nil } +// getConversation 获取会话(优先使用缓存) +func (s *chatServiceImpl) getConversation(ctx context.Context, conversationID string) (*model.Conversation, error) { + if s.conversationCache != nil { + return s.conversationCache.GetConversation(ctx, conversationID) + } + return s.repo.GetConversation(conversationID) +} + +// getParticipants 获取会话参与者(优先使用缓存) +func (s *chatServiceImpl) getParticipants(ctx context.Context, conversationID string) ([]*model.ConversationParticipant, error) { + if s.conversationCache != nil { + return s.conversationCache.GetParticipants(ctx, conversationID) + } + return s.repo.GetConversationParticipants(conversationID) +} + +// cacheMessage 缓存消息(内部方法) +func (s *chatServiceImpl) cacheMessage(ctx context.Context, convID string, msg *model.Message) error { + if s.conversationCache == nil { + return nil + } + + asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + return s.conversationCache.CacheMessage(asyncCtx, convID, msg) +} + func containsImageSegment(segments model.MessageSegments) bool { for _, seg := range segments { if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" { @@ -273,10 +386,10 @@ func containsImageSegment(segments model.MessageSegments) bool { return false } -// GetMessages 获取消息历史(分页) +// GetMessages 获取消息历史(分页,带缓存) func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) { // 验证用户是否是会话参与者 - _, err := s.repo.GetParticipant(conversationID, userID) + _, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, 0, errors.New("conversation not found or no permission") @@ -284,13 +397,18 @@ func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string return nil, 0, fmt.Errorf("failed to get participant: %w", err) } + // 优先使用缓存 + if s.conversationCache != nil { + return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize) + } + return s.repo.GetMessages(conversationID, page, pageSize) } // GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步) func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) { // 验证用户是否是会话参与者 - _, err := s.repo.GetParticipant(conversationID, userID) + _, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("conversation not found or no permission") @@ -308,7 +426,7 @@ func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationI // GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多) func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) { // 验证用户是否是会话参与者 - _, err := s.repo.GetParticipant(conversationID, userID) + _, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("conversation not found or no permission") @@ -326,7 +444,7 @@ func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversation // MarkAsRead 标记已读 func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error { // 验证用户是否是会话参与者 - _, err := s.repo.GetParticipant(conversationID, userID) + _, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("conversation not found or no permission") @@ -334,17 +452,27 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, return fmt.Errorf("failed to get participant: %w", err) } - // 更新参与者的已读位置 + // 1. 先写入DB(保证数据一致性,DB是唯一数据源) err = s.repo.UpdateLastReadSeq(conversationID, userID, seq) if err != nil { return fmt.Errorf("failed to update last read seq: %w", err) } - participants, pErr := s.repo.GetConversationParticipants(conversationID) + // 2. DB 写入成功后,失效缓存(Cache-Aside 模式) + if s.conversationCache != nil { + // 失效参与者缓存,下次读取时会从 DB 加载最新数据 + s.conversationCache.InvalidateParticipant(conversationID, userID) + // 失效未读数缓存 + s.conversationCache.InvalidateUnreadCount(userID, conversationID) + // 失效会话列表缓存 + s.conversationCache.InvalidateConversationList(userID) + } + + participants, pErr := s.getParticipants(ctx, conversationID) if pErr == nil { detailType := "private" groupID := "" - if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { + if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { detailType = "group" if conv.GroupID != nil { groupID = *conv.GroupID @@ -372,10 +500,10 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, return nil } -// GetUnreadCount 获取指定会话的未读消息数 +// GetUnreadCount 获取指定会话的未读消息数(带缓存) func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) { // 验证用户是否是会话参与者 - _, err := s.repo.GetParticipant(conversationID, userID) + _, err := s.getParticipant(ctx, conversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return 0, errors.New("conversation not found or no permission") @@ -383,6 +511,11 @@ func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID str return 0, fmt.Errorf("failed to get participant: %w", err) } + // 优先使用缓存 + if s.conversationCache != nil { + return s.conversationCache.GetUnreadCount(ctx, userID, conversationID) + } + return s.repo.GetUnreadCount(conversationID, userID) } @@ -427,10 +560,15 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u return fmt.Errorf("failed to recall message: %w", err) } - if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil { + // 失效消息缓存 + if s.conversationCache != nil { + s.conversationCache.InvalidateConversation(message.ConversationID) + } + + if participants, pErr := s.getParticipants(ctx, message.ConversationID); pErr == nil { detailType := "private" groupID := "" - if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { + if conv, convErr := s.getConversation(ctx, message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { detailType = "group" if conv.GroupID != nil { groupID = *conv.GroupID @@ -465,7 +603,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u } // 验证用户是否是会话参与者 - _, err = s.repo.GetParticipant(message.ConversationID, userID) + _, err = s.getParticipant(ctx, message.ConversationID, userID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return errors.New("no permission to delete this message") @@ -485,6 +623,11 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u return fmt.Errorf("failed to delete message: %w", err) } + // 失效消息缓存 + if s.conversationCache != nil { + s.conversationCache.InvalidateConversation(message.ConversationID) + } + return nil } @@ -495,19 +638,19 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve } // 验证用户是否是会话参与者 - _, err := s.repo.GetParticipant(conversationID, senderID) + _, err := s.getParticipant(ctx, conversationID, senderID) if err != nil { return } // 获取会话中的其他参与者 - participants, err := s.repo.GetConversationParticipants(conversationID) + participants, err := s.getParticipants(ctx, conversationID) if err != nil { return } detailType := "private" - if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { + if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup { detailType = "group" } for _, p := range participants { @@ -537,7 +680,7 @@ func (s *chatServiceImpl) IsUserOnline(userID string) bool { // 适用于群聊等由调用方自行负责推送的场景 func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) { // 验证会话是否存在 - _, err := s.repo.GetConversation(conversationID) + _, err := s.getConversation(ctx, conversationID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("会话不存在,请重新创建会话") @@ -546,7 +689,7 @@ func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conv } // 验证用户是否是会话参与者 - _, err = s.repo.GetParticipant(conversationID, senderID) + _, err = s.getParticipant(ctx, conversationID, senderID) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("您不是该会话的参与者") @@ -566,5 +709,17 @@ func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conv return nil, fmt.Errorf("failed to save message: %w", err) } + // 新消息会改变分页结果,先失效分页缓存,避免读到旧列表 + if s.conversationCache != nil { + s.conversationCache.InvalidateMessagePages(conversationID) + } + + // 异步写入缓存 + go func() { + if err := s.cacheMessage(context.Background(), conversationID, message); err != nil { + log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err) + } + }() + return message, nil } diff --git a/internal/service/group_service.go b/internal/service/group_service.go index 1c5d962..8fe7a1c 100644 --- a/internal/service/group_service.go +++ b/internal/service/group_service.go @@ -145,6 +145,45 @@ func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMess } } +// invalidateConversationCachesAfterSystemMessage 系统消息写入后失效相关缓存 +func (s *groupService) invalidateConversationCachesAfterSystemMessage(conversationID string) { + if conversationID == "" || s.messageRepo == nil { + return + } + // 新系统消息会影响消息分页列表 + cache.InvalidateMessagePages(s.cache, conversationID) + // 参与者列表可能发生变化(加群/退群)后,这里统一清理一次 + s.cache.Delete(cache.ParticipantListKey(conversationID)) + + participants, err := s.messageRepo.GetConversationParticipants(conversationID) + if err != nil { + return + } + for _, p := range participants { + if p == nil || p.UserID == "" { + continue + } + // 会话最后消息、未读数会变化,清理用户维度缓存 + cache.InvalidateConversationList(s.cache, p.UserID) + cache.InvalidateUnreadConversation(s.cache, p.UserID) + cache.InvalidateUnreadDetail(s.cache, p.UserID, conversationID) + } +} + +// invalidateConversationCachesAfterMembershipChange 成员变更后失效相关缓存 +func (s *groupService) invalidateConversationCachesAfterMembershipChange(conversationID, userID string) { + if conversationID == "" { + return + } + s.cache.Delete(cache.ParticipantListKey(conversationID)) + if userID != "" { + s.cache.Delete(cache.ParticipantKey(conversationID, userID)) + cache.InvalidateConversationList(s.cache, userID) + cache.InvalidateUnreadConversation(s.cache, userID) + cache.InvalidateUnreadDetail(s.cache, userID, conversationID) + } +} + // ==================== 群组管理 ==================== // CreateGroup 创建群组 @@ -444,6 +483,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st log.Printf("[broadcastMemberJoinNotice] 保存入群提示消息失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err) } else { savedMessage = msg + s.invalidateConversationCachesAfterSystemMessage(conv.ID) } } else { log.Printf("[broadcastMemberJoinNotice] 获取群组会话失败: groupID=%s, err=%v", groupID, err) @@ -502,6 +542,7 @@ func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userI if err := s.messageRepo.AddParticipant(conv.ID, userID); err != nil { log.Printf("[addMemberToGroupAndConversation] 添加会话参与者失败: groupID=%s, userID=%s, err=%v", group.ID, userID, err) } + s.invalidateConversationCachesAfterMembershipChange(conv.ID, userID) } } cache.InvalidateGroupMembers(s.cache, group.ID) @@ -1036,6 +1077,7 @@ func (s *groupService) LeaveGroup(userID string, groupID string) error { // 如果移除参与者失败,记录日志但不阻塞退出群流程 fmt.Printf("[WARN] LeaveGroup: failed to remove participant %s from conversation %s, error: %v\n", userID, conv.ID, err) } + s.invalidateConversationCachesAfterMembershipChange(conv.ID, userID) } // 失效群组成员缓存 @@ -1092,6 +1134,7 @@ func (s *groupService) RemoveMember(userID string, groupID string, targetUserID if err := s.messageRepo.RemoveParticipant(conv.ID, targetUserID); err != nil { log.Printf("[RemoveMember] 移除会话参与者失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err) } + s.invalidateConversationCachesAfterMembershipChange(conv.ID, targetUserID) } } @@ -1290,6 +1333,7 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st } else { savedMessage = msg log.Printf("[MuteMember] 禁言消息已保存, ID=%s, Seq=%d", msg.ID, msg.Seq) + s.invalidateConversationCachesAfterSystemMessage(conv.ID) } } else { log.Printf("[MuteMember] 获取群组会话失败: %v", err) diff --git a/internal/service/message_service.go b/internal/service/message_service.go index f8e7482..ed0f772 100644 --- a/internal/service/message_service.go +++ b/internal/service/message_service.go @@ -2,11 +2,14 @@ package service import ( "context" + "log" "time" "carrot_bbs/internal/cache" "carrot_bbs/internal/model" "carrot_bbs/internal/repository" + + "gorm.io/gorm" ) // 缓存TTL常量 @@ -21,15 +24,37 @@ const ( // MessageService 消息服务 type MessageService struct { + db *gorm.DB + + // 基础仓储 messageRepo *repository.MessageRepository - cache cache.Cache + + // 缓存相关字段 + conversationCache *cache.ConversationCache + + // 基础缓存(用于简单缓存操作) + baseCache cache.Cache } // NewMessageService 创建消息服务 -func NewMessageService(messageRepo *repository.MessageRepository) *MessageService { +func NewMessageService(db *gorm.DB, messageRepo *repository.MessageRepository) *MessageService { + // 创建适配器 + convRepoAdapter := cache.NewConversationRepositoryAdapter(messageRepo) + msgRepoAdapter := cache.NewMessageRepositoryAdapter(messageRepo) + + // 创建会话缓存 + conversationCache := cache.NewConversationCache( + cache.GetCache(), + convRepoAdapter, + msgRepoAdapter, + cache.DefaultConversationCacheSettings(), + ) + return &MessageService{ - messageRepo: messageRepo, - cache: cache.GetCache(), + db: db, + messageRepo: messageRepo, + conversationCache: conversationCache, + baseCache: cache.GetCache(), } } @@ -61,20 +86,50 @@ func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID s return nil, err } + // 新消息会改变分页结果,先失效分页缓存,避免读到旧列表 + if s.conversationCache != nil { + s.conversationCache.InvalidateMessagePages(conv.ID) + } + + // 异步写入缓存 + go func() { + if err := s.cacheMessage(context.Background(), conv.ID, msg); err != nil { + log.Printf("[MessageService] async cache message failed, convID=%s, msgID=%s, err=%v", conv.ID, msg.ID, err) + } + }() + // 失效会话列表缓存(发送者和接收者) - cache.InvalidateConversationList(s.cache, senderID) - cache.InvalidateConversationList(s.cache, receiverID) + s.conversationCache.InvalidateConversationList(senderID) + s.conversationCache.InvalidateConversationList(receiverID) // 失效未读数缓存 - cache.InvalidateUnreadConversation(s.cache, receiverID) - cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID) + cache.InvalidateUnreadConversation(s.baseCache, receiverID) + s.conversationCache.InvalidateUnreadCount(receiverID, conv.ID) return msg, nil } +// cacheMessage 缓存消息(内部方法) +func (s *MessageService) cacheMessage(ctx context.Context, convID string, msg *model.Message) error { + if s.conversationCache == nil { + return nil + } + + asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + return s.conversationCache.CacheMessage(asyncCtx, convID, msg) +} + // GetConversations 获取会话列表(带缓存) // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) { + // 优先使用 ConversationCache + if s.conversationCache != nil { + return s.conversationCache.GetConversationList(ctx, userID, page, pageSize) + } + + // 降级到基础缓存 cacheSettings := cache.GetSettings() conversationTTL := cacheSettings.ConversationTTL if conversationTTL <= 0 { @@ -92,7 +147,7 @@ func (s *MessageService) GetConversations(ctx context.Context, userID string, pa // 生成缓存键 cacheKey := cache.ConversationListKey(userID, page, pageSize) result, err := cache.GetOrLoadTyped[*ConversationListResult]( - s.cache, + s.baseCache, cacheKey, conversationTTL, jitter, @@ -117,8 +172,14 @@ func (s *MessageService) GetConversations(ctx context.Context, userID string, pa return result.Conversations, result.Total, nil } -// GetMessages 获取消息列表 +// GetMessages 获取消息列表(带缓存) func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) { + // 优先使用 ConversationCache + if s.conversationCache != nil { + return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize) + } + + // 降级到直接访问数据库 return s.messageRepo.GetMessages(conversationID, page, pageSize) } @@ -127,20 +188,25 @@ func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit) } -// MarkAsRead 标记为已读 +// MarkAsRead 标记为已读(使用 Cache-Aside 模式) // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error { + // 1. 先写入DB(保证数据一致性,DB是唯一数据源) err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq) if err != nil { return err } - // 失效未读数缓存 - cache.InvalidateUnreadConversation(s.cache, userID) - cache.InvalidateUnreadDetail(s.cache, userID, conversationID) - - // 失效会话列表缓存 - cache.InvalidateConversationList(s.cache, userID) + // 2. DB 写入成功后,失效缓存(Cache-Aside 模式) + if s.conversationCache != nil { + // 失效参与者缓存,下次读取时会从 DB 加载最新数据 + s.conversationCache.InvalidateParticipant(conversationID, userID) + // 失效未读数缓存 + s.conversationCache.InvalidateUnreadCount(userID, conversationID) + // 失效会话列表缓存 + s.conversationCache.InvalidateConversationList(userID) + } + cache.InvalidateUnreadConversation(s.baseCache, userID) return nil } @@ -148,6 +214,12 @@ func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, // GetUnreadCount 获取未读消息数(带缓存) // userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致 func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) { + // 优先使用 ConversationCache + if s.conversationCache != nil { + return s.conversationCache.GetUnreadCount(ctx, userID, conversationID) + } + + // 降级到基础缓存 cacheSettings := cache.GetSettings() unreadTTL := cacheSettings.UnreadCountTTL if unreadTTL <= 0 { @@ -166,7 +238,7 @@ func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID stri cacheKey := cache.UnreadDetailKey(userID, conversationID) return cache.GetOrLoadTyped[int64]( - s.cache, + s.baseCache, cacheKey, unreadTTL, jitter, @@ -186,14 +258,18 @@ func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, u } // 失效会话列表缓存 - cache.InvalidateConversationList(s.cache, user1ID) - cache.InvalidateConversationList(s.cache, user2ID) + s.conversationCache.InvalidateConversationList(user1ID) + s.conversationCache.InvalidateConversationList(user2ID) return conv, nil } // GetConversationParticipants 获取会话参与者列表 func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) { + // 优先使用缓存 + if s.conversationCache != nil { + return s.conversationCache.GetParticipants(context.Background(), conversationID) + } return s.messageRepo.GetConversationParticipants(conversationID) } @@ -204,12 +280,12 @@ func ParseConversationID(idStr string) (string, error) { // InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用) func (s *MessageService) InvalidateUserConversationCache(userID string) { - cache.InvalidateConversationList(s.cache, userID) - cache.InvalidateUnreadConversation(s.cache, userID) + s.conversationCache.InvalidateConversationList(userID) + cache.InvalidateUnreadConversation(s.baseCache, userID) } // InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用) func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) { - cache.InvalidateUnreadConversation(s.cache, userID) - cache.InvalidateUnreadDetail(s.cache, userID, conversationID) + cache.InvalidateUnreadConversation(s.baseCache, userID) + s.conversationCache.InvalidateUnreadCount(userID, conversationID) } diff --git a/internal/service/post_service.go b/internal/service/post_service.go index 99272fe..b4b7049 100644 --- a/internal/service/post_service.go +++ b/internal/service/post_service.go @@ -73,9 +73,20 @@ func (s *PostService) Create(ctx context.Context, userID, title, content string, } func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) { + defer func() { + if r := recover(); r != nil { + log.Printf("[ERROR] Panic in post moderation async flow, fallback publish post=%s panic=%v", postID, r) + if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil { + log.Printf("[WARN] Failed to publish post %s after panic recovery: %v", postID, err) + return + } + s.invalidatePostCaches(postID) + } + }() + // 未启用AI时,直接发布 if s.postAIService == nil || !s.postAIService.IsEnabled() { - if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil { + if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil { log.Printf("[WARN] Failed to publish post without AI moderation: %v", err) } else { s.invalidatePostCaches(postID) @@ -87,7 +98,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima if err != nil { var rejectedErr *PostModerationRejectedError if errors.As(err, &rejectedErr) { - if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { + if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr) } else { s.invalidatePostCaches(postID) @@ -97,7 +108,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima } // 规则审核不可用时,降级为发布,避免长时间pending - if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil { + if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); updateErr != nil { log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr) } else { s.invalidatePostCaches(postID) @@ -106,7 +117,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima return } - if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil { + if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "ai"); err != nil { log.Printf("[WARN] Failed to publish post %s: %v", postID, err) return } @@ -127,6 +138,26 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima } } +func (s *PostService) updateModerationStatusWithRetry(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error { + const maxAttempts = 3 + const retryDelay = 200 * time.Millisecond + + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + if err := s.postRepo.UpdateModerationStatus(postID, status, rejectReason, reviewedBy); err != nil { + lastErr = err + if attempt < maxAttempts { + log.Printf("[WARN] UpdateModerationStatus failed post=%s attempt=%d/%d err=%v", postID, attempt, maxAttempts, err) + time.Sleep(time.Duration(attempt) * retryDelay) + continue + } + } else { + return nil + } + } + return lastErr +} + func (s *PostService) invalidatePostCaches(postID string) { cache.InvalidatePostDetail(s.cache, postID) cache.InvalidatePostList(s.cache) diff --git a/internal/service/schedule_service.go b/internal/service/schedule_service.go new file mode 100644 index 0000000..4f5e44c --- /dev/null +++ b/internal/service/schedule_service.go @@ -0,0 +1,207 @@ +package service + +import ( + "encoding/json" + "errors" + "regexp" + "sort" + "strings" + + "carrot_bbs/internal/dto" + "carrot_bbs/internal/model" + "carrot_bbs/internal/repository" + + "gorm.io/gorm" +) + +var ( + ErrInvalidSchedulePayload = &ServiceError{Code: 400, Message: "invalid schedule payload"} + ErrScheduleCourseNotFound = &ServiceError{Code: 404, Message: "schedule course not found"} + ErrScheduleForbidden = &ServiceError{Code: 403, Message: "forbidden schedule operation"} + ErrScheduleColorDuplicated = &ServiceError{Code: 400, Message: "course color already used"} +) + +var hexColorRegex = regexp.MustCompile(`^#[0-9A-F]{6}$`) + +type CreateScheduleCourseInput struct { + Name string + Teacher string + Location string + DayOfWeek int + StartSection int + EndSection int + Weeks []int + Color string +} + +type ScheduleService interface { + ListCourses(userID string, week int) ([]*dto.ScheduleCourseResponse, error) + CreateCourse(userID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) + UpdateCourse(userID, courseID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) + DeleteCourse(userID, courseID string) error +} + +type scheduleService struct { + repo repository.ScheduleRepository +} + +func NewScheduleService(repo repository.ScheduleRepository) ScheduleService { + return &scheduleService{repo: repo} +} + +func (s *scheduleService) ListCourses(userID string, week int) ([]*dto.ScheduleCourseResponse, error) { + courses, err := s.repo.ListByUserID(userID) + if err != nil { + return nil, err + } + + result := make([]*dto.ScheduleCourseResponse, 0, len(courses)) + for _, item := range courses { + weeks := dto.ParseWeeksJSON(item.Weeks) + if week > 0 && !containsWeek(weeks, week) { + continue + } + result = append(result, dto.ConvertScheduleCourseToResponse(item, weeks)) + } + return result, nil +} + +func (s *scheduleService) CreateCourse(userID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) { + entity, weeks, err := buildScheduleEntity(userID, input, nil) + if err != nil { + return nil, err + } + if err := s.ensureUniqueColor(userID, entity.Color, ""); err != nil { + return nil, err + } + if err := s.repo.Create(entity); err != nil { + return nil, err + } + return dto.ConvertScheduleCourseToResponse(entity, weeks), nil +} + +func (s *scheduleService) UpdateCourse(userID, courseID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) { + existing, err := s.repo.GetByID(courseID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrScheduleCourseNotFound + } + return nil, err + } + if existing.UserID != userID { + return nil, ErrScheduleForbidden + } + + entity, weeks, err := buildScheduleEntity(userID, input, existing) + if err != nil { + return nil, err + } + if err := s.ensureUniqueColor(userID, entity.Color, entity.ID); err != nil { + return nil, err + } + if err := s.repo.Update(entity); err != nil { + return nil, err + } + return dto.ConvertScheduleCourseToResponse(entity, weeks), nil +} + +func (s *scheduleService) DeleteCourse(userID, courseID string) error { + existing, err := s.repo.GetByID(courseID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrScheduleCourseNotFound + } + return err + } + if existing.UserID != userID { + return ErrScheduleForbidden + } + return s.repo.DeleteByID(courseID) +} + +func buildScheduleEntity(userID string, input CreateScheduleCourseInput, target *model.ScheduleCourse) (*model.ScheduleCourse, []int, error) { + name := strings.TrimSpace(input.Name) + if name == "" || input.DayOfWeek < 0 || input.DayOfWeek > 6 || input.StartSection < 1 || input.EndSection < input.StartSection { + return nil, nil, ErrInvalidSchedulePayload + } + + weeks := normalizeWeeks(input.Weeks) + if len(weeks) == 0 { + return nil, nil, ErrInvalidSchedulePayload + } + weeksJSON, err := json.Marshal(weeks) + if err != nil { + return nil, nil, err + } + + entity := target + if entity == nil { + entity = &model.ScheduleCourse{ + UserID: userID, + } + } + + normalizedColor := normalizeHexColor(input.Color) + if normalizedColor == "" || !hexColorRegex.MatchString(normalizedColor) { + return nil, nil, ErrInvalidSchedulePayload + } + + entity.Name = name + entity.Teacher = strings.TrimSpace(input.Teacher) + entity.Location = strings.TrimSpace(input.Location) + entity.DayOfWeek = input.DayOfWeek + entity.StartSection = input.StartSection + entity.EndSection = input.EndSection + entity.Weeks = string(weeksJSON) + entity.Color = normalizedColor + + return entity, weeks, nil +} + +func (s *scheduleService) ensureUniqueColor(userID, color, excludeID string) error { + exists, err := s.repo.ExistsColorByUser(userID, color, excludeID) + if err != nil { + return err + } + if exists { + return ErrScheduleColorDuplicated + } + return nil +} + +func normalizeWeeks(source []int) []int { + unique := make(map[int]struct{}, len(source)) + result := make([]int, 0, len(source)) + for _, w := range source { + if w < 1 || w > 30 { + continue + } + if _, exists := unique[w]; exists { + continue + } + unique[w] = struct{}{} + result = append(result, w) + } + sort.Ints(result) + return result +} + +func containsWeek(weeks []int, target int) bool { + for _, week := range weeks { + if week == target { + return true + } + } + return false +} + +func normalizeHexColor(color string) string { + trimmed := strings.TrimSpace(color) + if trimmed == "" { + return "" + } + if strings.HasPrefix(trimmed, "#") { + return strings.ToUpper(trimmed) + } + return "#" + strings.ToUpper(trimmed) +} diff --git a/internal/service/vote_service.go b/internal/service/vote_service.go index 46482ff..10c798f 100644 --- a/internal/service/vote_service.go +++ b/internal/service/vote_service.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "strings" + "time" "carrot_bbs/internal/cache" "carrot_bbs/internal/dto" @@ -84,8 +85,17 @@ func (s *VoteService) CreateVotePost(ctx context.Context, userID string, req *dt } func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) { + defer func() { + if r := recover(); r != nil { + log.Printf("[ERROR] Panic in vote post moderation async flow, fallback publish post=%s panic=%v", postID, r) + if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil { + log.Printf("[WARN] Failed to publish vote post %s after panic recovery: %v", postID, err) + } + } + }() + if s.postAIService == nil || !s.postAIService.IsEnabled() { - if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil { + if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil { log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err) } return @@ -95,24 +105,44 @@ func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, if err != nil { var rejectedErr *PostModerationRejectedError if errors.As(err, &rejectedErr) { - if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { + if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil { log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr) } s.notifyModerationRejected(userID, rejectedErr.Reason) return } - if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil { + if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); updateErr != nil { log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr) } return } - if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil { + if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "ai"); err != nil { log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err) } } +func (s *VoteService) updateModerationStatusWithRetry(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error { + const maxAttempts = 3 + const retryDelay = 200 * time.Millisecond + + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + if err := s.postRepo.UpdateModerationStatus(postID, status, rejectReason, reviewedBy); err != nil { + lastErr = err + if attempt < maxAttempts { + log.Printf("[WARN] UpdateModerationStatus for vote post failed post=%s attempt=%d/%d err=%v", postID, attempt, maxAttempts, err) + time.Sleep(time.Duration(attempt) * retryDelay) + continue + } + } else { + return nil + } + } + return lastErr +} + func (s *VoteService) notifyModerationRejected(userID, reason string) { if s.systemMessageService == nil || strings.TrimSpace(userID) == "" { return diff --git a/scripts/test_moderation.go b/scripts/test_moderation.go new file mode 100644 index 0000000..f06b818 --- /dev/null +++ b/scripts/test_moderation.go @@ -0,0 +1,263 @@ +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const moderationSystemPrompt = `你是中文社区的内容审核助手,负责对"帖子标题、正文、配图"做联合审核。目标是平衡社区安全与正常交流:必须拦截高风险违规内容,但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。 + +审核流程: +1) 先判断是否命中硬性违规; +2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论); +3) 做文图交叉判断(文本+图片合并理解); +4) 给出 approved 与简短 reason。 + +硬性违规(命中任一项必须 approved=false): +A. 宣传对立与煽动撕裂: +- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。 +B. 严重人身攻击与网暴引导: +- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。 +C. 开盒/人肉/隐私暴露: +- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等); +- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。 +D. 其他高危违规: +- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。 + +放行规则(以下通常 approved=true): +- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽; +- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露); +- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。 + +边界判定: +- 若只是"梗文化表达"且不指向现实伤害,优先通过; +- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝; +- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。 + +reason 要求: +- approved=false 时:中文10-30字,说明核心违规点; +- approved=true 时:reason 为空字符串。 + +输出格式(严格): +仅输出一行JSON对象,不要Markdown,不要额外解释: +{"approved": true/false, "reason": "..."}` + +type chatMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` +} + +type contentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +type chatCompletionsRequest struct { + Model string `json:"model"` + Messages []chatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen3.5思考模式控制 + ThinkingBudget *int `json:"thinking_budget,omitempty"` // 思考过程最大token数 + ResponseFormat *responseFormatConfig `json:"response_format,omitempty"` // 响应格式 +} + +type responseFormatConfig struct { + Type string `json:"type"` // "text" or "json_object" +} + +type chatCompletionsResponse struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +func main() { + baseURL := flag.String("url", "https://api.littlelan.cn/", "API base URL") + apiKey := flag.String("key", "", "API key") + model := flag.String("model", "qwen3.5-plus", "Model name") + maxTokens := flag.Int("max-tokens", 220, "Max tokens for completion") + enableThinking := flag.Bool("enable-thinking", false, "Enable thinking mode for qwen3.5") + flag.Parse() + + if *apiKey == "" { + fmt.Println("Error: API key is required. Use -key flag") + return + } + + // 测试用例 + testCases := []struct { + name string + content string + }{ + { + name: "简单正常内容", + content: "帖子标题:今天天气真好\n帖子内容:出门散步,心情愉快!", + }, + { + name: "中等长度内容", + content: "帖子标题:分享我的学习经验\n帖子内容:最近在学习Go语言,发现这门语言真的很适合后端开发。并发处理特别方便,goroutine和channel的设计非常优雅。有一起学习的小伙伴吗?", + }, + { + name: "较长内容", + content: "帖子标题:关于校园生活的一些思考\n帖子内容:大学四年转眼就过去了,回想起来有很多感慨。刚入学的时候什么都不懂,现在感觉自己成长了很多。在这里想分享一些自己的经验,希望能对学弟学妹们有所帮助。首先是学习方面,一定要认真听课,做好笔记。其次是社交方面,多参加社团活动,结交志同道合的朋友。最后是规划方面,早点想清楚自己想做什么,为之努力。", + }, + } + + client := &http.Client{Timeout: 120 * time.Second} + + fmt.Println("============================================") + fmt.Printf("模型: %s\n", *model) + fmt.Printf("API URL: %s\n", *baseURL) + fmt.Printf("MaxTokens 设置: %d\n", *maxTokens) + fmt.Printf("EnableThinking: %v\n", *enableThinking) + fmt.Println("============================================") + + for _, tc := range testCases { + fmt.Printf("\n========== 测试: %s ==========\n", tc.name) + fmt.Printf("内容长度: %d 字符\n", len(tc.content)) + + userPrompt := fmt.Sprintf("%s\n图片批次:1/1(本次仅提供当前批次图片)", tc.content) + + reqBody := chatCompletionsRequest{ + Model: *model, + Messages: []chatMessage{ + {Role: "system", Content: moderationSystemPrompt}, + {Role: "user", Content: []contentPart{{Type: "text", Text: userPrompt}}}, + }, + Temperature: 0.1, + MaxTokens: *maxTokens, + } + // 设置思考模式 + if !*enableThinking { + reqBody.EnableThinking = enableThinking + // 设置思考预算为0,完全禁用思考 + zero := 0 + reqBody.ThinkingBudget = &zero + } + // 使用JSON输出格式 + reqBody.ResponseFormat = &responseFormatConfig{Type: "json_object"} + + data, err := json.Marshal(reqBody) + if err != nil { + fmt.Printf("Error marshaling request: %v\n", err) + continue + } + + endpoint := strings.TrimRight(*baseURL, "/") + "/v1/chat/completions" + if strings.HasSuffix(strings.TrimRight(*baseURL, "/"), "/v1") { + endpoint = strings.TrimRight(*baseURL, "/") + "/chat/completions" + } + + req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(data)) + if err != nil { + fmt.Printf("Error creating request: %v\n", err) + continue + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+*apiKey) + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + fmt.Printf("Error sending request: %v\n", err) + continue + } + + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + fmt.Printf("Error reading response: %v\n", err) + continue + } + + elapsed := time.Since(start) + + if resp.StatusCode >= 400 { + fmt.Printf("API Error: status=%d, body=%s\n", resp.StatusCode, string(body)) + continue + } + + var parsed chatCompletionsResponse + if err := json.Unmarshal(body, &parsed); err != nil { + fmt.Printf("Error parsing response: %v\n", err) + fmt.Printf("Raw response: %s\n", string(body)) + continue + } + + if len(parsed.Choices) == 0 { + fmt.Println("No choices in response") + fmt.Printf("Raw response: %s\n", string(body)) + continue + } + + fmt.Printf("响应时间: %v\n", elapsed) + fmt.Printf("Finish Reason: %s\n", parsed.Choices[0].FinishReason) + fmt.Printf("Token使用情况:\n") + fmt.Printf(" - PromptTokens: %d\n", parsed.Usage.PromptTokens) + fmt.Printf(" - CompletionTokens: %d\n", parsed.Usage.CompletionTokens) + fmt.Printf(" - TotalTokens: %d\n", parsed.Usage.TotalTokens) + + output := parsed.Choices[0].Message.Content + fmt.Printf("输出内容长度: %d 字符\n", len(output)) + + // 检查输出是否符合预期 + if parsed.Usage.CompletionTokens > *maxTokens { + fmt.Printf("\n⚠️ 警告: CompletionTokens (%d) 超过了 max_tokens 设置 (%d)!\n", + parsed.Usage.CompletionTokens, *maxTokens) + } + + if len(output) > 500 { + fmt.Printf("\n⚠️ 警告: 输出内容过长! 长度=%d\n", len(output)) + fmt.Printf("前500字符:\n%s...\n", output[:min(500, len(output))]) + } else { + fmt.Printf("输出内容: %s\n", output) + } + + // 尝试解析JSON + extractJSONObject := func(raw string) string { + text := strings.TrimSpace(raw) + start := strings.Index(text, "{") + end := strings.LastIndex(text, "}") + if start >= 0 && end > start { + return text[start : end+1] + } + return text + } + + jsonStr := extractJSONObject(output) + var result struct { + Approved bool `json:"approved"` + Reason string `json:"reason"` + } + if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { + fmt.Printf("\n⚠️ 警告: 无法解析JSON输出: %v\n", err) + fmt.Printf("提取的JSON: %s\n", jsonStr) + } else { + fmt.Printf("\n✓ 解析成功: approved=%v, reason=\"%s\"\n", result.Approved, result.Reason) + } + } + + fmt.Println("\n========== 测试完成 ==========") +} + +func min(a, b int) int { + if a < b { + return a + } + return b +}