feat(schedule): add course table screens and navigation

Add complete schedule functionality including:
- Schedule screen with weekly course table view
- Course detail screen with transparent modal presentation
- New ScheduleStack navigator integrated into main tab bar
- Schedule service for API interactions
- Type definitions for course entities

Also includes bug fixes for group invite/request handlers
to include required groupId parameter.
This commit is contained in:
2026-03-12 08:38:14 +08:00
parent 21293644b8
commit 0a0cbacbcc
25 changed files with 3050 additions and 260 deletions

View File

@@ -5,7 +5,10 @@ import (
"encoding/json"
"fmt"
"log"
"math"
"math/rand"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -34,6 +37,38 @@ type Cache interface {
Increment(key string) int64
// IncrementBy 增加指定值
IncrementBy(key string, value int64) int64
// ==================== Hash 操作 ====================
// HSet 设置 Hash 字段
HSet(ctx context.Context, key string, field string, value interface{}) error
// HMSet 批量设置 Hash 字段
HMSet(ctx context.Context, key string, values map[string]interface{}) error
// HGet 获取 Hash 字段值
HGet(ctx context.Context, key string, field string) (string, error)
// HMGet 批量获取 Hash 字段值
HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error)
// HGetAll 获取 Hash 所有字段
HGetAll(ctx context.Context, key string) (map[string]string, error)
// HDel 删除 Hash 字段
HDel(ctx context.Context, key string, fields ...string) error
// ==================== Sorted Set 操作 ====================
// ZAdd 添加 Sorted Set 成员
ZAdd(ctx context.Context, key string, score float64, member string) error
// ZRangeByScore 按分数范围获取成员(升序)
ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error)
// ZRevRangeByScore 按分数范围获取成员(降序)
ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error)
// ZRem 删除 Sorted Set 成员
ZRem(ctx context.Context, key string, members ...interface{}) error
// ZCard 获取 Sorted Set 成员数量
ZCard(ctx context.Context, key string) (int64, error)
// ==================== 计数器操作 ====================
// Incr 原子递增(返回新值)
Incr(ctx context.Context, key string) (int64, error)
// Expire 设置过期时间
Expire(ctx context.Context, key string, ttl time.Duration) error
}
// cacheItem 缓存项(用于内存缓存降级)
@@ -64,16 +99,16 @@ type MetricsSnapshot struct {
}
type Settings struct {
Enabled bool
KeyPrefix string
DefaultTTL time.Duration
NullTTL time.Duration
JitterRatio float64
PostListTTL time.Duration
ConversationTTL time.Duration
UnreadCountTTL time.Duration
GroupMembersTTL time.Duration
DisableFlushDB bool
Enabled bool
KeyPrefix string
DefaultTTL time.Duration
NullTTL time.Duration
JitterRatio float64
PostListTTL time.Duration
ConversationTTL time.Duration
UnreadCountTTL time.Duration
GroupMembersTTL time.Duration
DisableFlushDB bool
}
var settings = Settings{
@@ -327,6 +362,378 @@ func (c *MemoryCache) Stop() {
close(c.stopCleanup)
}
// ==================== MemoryCache Hash 操作 ====================
// hashItem Hash 存储项
type hashItem struct {
fields sync.Map
}
// HSet 设置 Hash 字段
func (c *MemoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
key = normalizeKey(key)
item, _ := c.items.Load(key)
var h *hashItem
if item == nil {
h = &hashItem{}
c.items.Store(key, &cacheItem{value: h, expiration: 0})
} else {
ci := item.(*cacheItem)
if ci.isExpired() {
h = &hashItem{}
c.items.Store(key, &cacheItem{value: h, expiration: 0})
} else {
h = ci.value.(*hashItem)
}
}
h.fields.Store(field, value)
return nil
}
// HMSet 批量设置 Hash 字段
func (c *MemoryCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
for field, value := range values {
if err := c.HSet(ctx, key, field, value); err != nil {
return err
}
}
return nil
}
// HGet 获取 Hash 字段值
func (c *MemoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return "", fmt.Errorf("key not found")
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return "", fmt.Errorf("key not found")
}
h, ok := ci.value.(*hashItem)
if !ok {
return "", fmt.Errorf("key is not a hash")
}
val, ok := h.fields.Load(field)
if !ok {
return "", fmt.Errorf("field not found")
}
switch v := val.(type) {
case string:
return v, nil
case []byte:
return string(v), nil
default:
data, _ := json.Marshal(v)
return string(data), nil
}
}
// HMGet 批量获取 Hash 字段值
func (c *MemoryCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
result := make([]interface{}, len(fields))
for i, field := range fields {
val, err := c.HGet(ctx, key, field)
if err != nil {
result[i] = nil
} else {
result[i] = val
}
}
return result, nil
}
// HGetAll 获取 Hash 所有字段
func (c *MemoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil, fmt.Errorf("key not found")
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil, fmt.Errorf("key not found")
}
h, ok := ci.value.(*hashItem)
if !ok {
return nil, fmt.Errorf("key is not a hash")
}
result := make(map[string]string)
h.fields.Range(func(k, v interface{}) bool {
keyStr := k.(string)
switch val := v.(type) {
case string:
result[keyStr] = val
case []byte:
result[keyStr] = string(val)
default:
data, _ := json.Marshal(val)
result[keyStr] = string(data)
}
return true
})
return result, nil
}
// HDel 删除 Hash 字段
func (c *MemoryCache) HDel(ctx context.Context, key string, fields ...string) error {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil
}
h, ok := ci.value.(*hashItem)
if !ok {
return nil
}
for _, field := range fields {
h.fields.Delete(field)
}
return nil
}
// ==================== MemoryCache Sorted Set 操作 ====================
// zItem Sorted Set 成员
type zItem struct {
score float64
member string
}
// zsetItem Sorted Set 存储项
type zsetItem struct {
members sync.Map // member -> *zItem
byScore *sortedSlice // 按分数排序的切片
}
// sortedSlice 简单的排序切片实现
type sortedSlice struct {
items []*zItem
mu sync.RWMutex
}
// ZAdd 添加 Sorted Set 成员
func (c *MemoryCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
key = normalizeKey(key)
item, _ := c.items.Load(key)
var z *zsetItem
if item == nil {
z = &zsetItem{byScore: &sortedSlice{}}
c.items.Store(key, &cacheItem{value: z, expiration: 0})
} else {
ci := item.(*cacheItem)
if ci.isExpired() {
z = &zsetItem{byScore: &sortedSlice{}}
c.items.Store(key, &cacheItem{value: z, expiration: 0})
} else {
z = ci.value.(*zsetItem)
}
}
z.members.Store(member, &zItem{score: score, member: member})
z.byScore.mu.Lock()
// 简单实现:重新构建排序切片
z.byScore.items = nil
z.members.Range(func(k, v interface{}) bool {
z.byScore.items = append(z.byScore.items, v.(*zItem))
return true
})
// 按分数排序
sort.Slice(z.byScore.items, func(i, j int) bool {
return z.byScore.items[i].score < z.byScore.items[j].score
})
z.byScore.mu.Unlock()
return nil
}
// ZRangeByScore 按分数范围获取成员(升序)
func (c *MemoryCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil, nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil, nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return nil, fmt.Errorf("key is not a sorted set")
}
minScore, _ := strconv.ParseFloat(min, 64)
maxScore, _ := strconv.ParseFloat(max, 64)
if min == "-inf" {
minScore = math.Inf(-1)
}
if max == "+inf" {
maxScore = math.Inf(1)
}
z.byScore.mu.RLock()
defer z.byScore.mu.RUnlock()
var result []string
var skipped int64 = 0
for _, item := range z.byScore.items {
if item.score < minScore || item.score > maxScore {
continue
}
if skipped < offset {
skipped++
continue
}
if count > 0 && int64(len(result)) >= count {
break
}
result = append(result, item.member)
}
return result, nil
}
// ZRevRangeByScore 按分数范围获取成员(降序)
func (c *MemoryCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil, nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil, nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return nil, fmt.Errorf("key is not a sorted set")
}
minScore, _ := strconv.ParseFloat(min, 64)
maxScore, _ := strconv.ParseFloat(max, 64)
if min == "-inf" {
minScore = math.Inf(-1)
}
if max == "+inf" {
maxScore = math.Inf(1)
}
z.byScore.mu.RLock()
defer z.byScore.mu.RUnlock()
var result []string
var skipped int64 = 0
// 从后往前遍历
for i := len(z.byScore.items) - 1; i >= 0; i-- {
item := z.byScore.items[i]
if item.score < minScore || item.score > maxScore {
continue
}
if skipped < offset {
skipped++
continue
}
if count > 0 && int64(len(result)) >= count {
break
}
result = append(result, item.member)
}
return result, nil
}
// ZRem 删除 Sorted Set 成员
func (c *MemoryCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return nil
}
for _, m := range members {
if member, ok := m.(string); ok {
z.members.Delete(member)
}
}
// 重建排序切片
z.byScore.mu.Lock()
z.byScore.items = nil
z.members.Range(func(k, v interface{}) bool {
z.byScore.items = append(z.byScore.items, v.(*zItem))
return true
})
sort.Slice(z.byScore.items, func(i, j int) bool {
return z.byScore.items[i].score < z.byScore.items[j].score
})
z.byScore.mu.Unlock()
return nil
}
// ZCard 获取 Sorted Set 成员数量
func (c *MemoryCache) ZCard(ctx context.Context, key string) (int64, error) {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return 0, nil
}
ci := item.(*cacheItem)
if ci.isExpired() {
c.items.Delete(key)
return 0, nil
}
z, ok := ci.value.(*zsetItem)
if !ok {
return 0, fmt.Errorf("key is not a sorted set")
}
var count int64 = 0
z.members.Range(func(k, v interface{}) bool {
count++
return true
})
return count, nil
}
// ==================== MemoryCache 计数器操作 ====================
// Incr 原子递增(返回新值)
func (c *MemoryCache) Incr(ctx context.Context, key string) (int64, error) {
return c.IncrementBy(key, 1), nil
}
// Expire 设置过期时间
func (c *MemoryCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
key = normalizeKey(key)
item, ok := c.items.Load(key)
if !ok {
return fmt.Errorf("key not found")
}
ci := item.(*cacheItem)
var expiration int64
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
c.items.Store(key, &cacheItem{
value: ci.value,
expiration: expiration,
})
return nil
}
// RedisCache Redis缓存实现
type RedisCache struct {
client *redisPkg.Client
@@ -451,6 +858,91 @@ func (c *RedisCache) IncrementBy(key string, value int64) int64 {
return result
}
// ==================== RedisCache Hash 操作 ====================
// HSet 设置 Hash 字段
func (c *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
key = normalizeKey(key)
return c.client.HSet(ctx, key, field, value)
}
// HMSet 批量设置 Hash 字段
func (c *RedisCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
key = normalizeKey(key)
return c.client.HMSet(ctx, key, values)
}
// HGet 获取 Hash 字段值
func (c *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
key = normalizeKey(key)
return c.client.HGet(ctx, key, field)
}
// HMGet 批量获取 Hash 字段值
func (c *RedisCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
key = normalizeKey(key)
return c.client.HMGet(ctx, key, fields...)
}
// HGetAll 获取 Hash 所有字段
func (c *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
key = normalizeKey(key)
return c.client.HGetAll(ctx, key)
}
// HDel 删除 Hash 字段
func (c *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
key = normalizeKey(key)
return c.client.HDel(ctx, key, fields...)
}
// ==================== RedisCache Sorted Set 操作 ====================
// ZAdd 添加 Sorted Set 成员
func (c *RedisCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
key = normalizeKey(key)
return c.client.ZAdd(ctx, key, score, member)
}
// ZRangeByScore 按分数范围获取成员(升序)
func (c *RedisCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
return c.client.ZRangeByScore(ctx, key, min, max, offset, count)
}
// ZRevRangeByScore 按分数范围获取成员(降序)
func (c *RedisCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
key = normalizeKey(key)
return c.client.ZRevRangeByScore(ctx, key, max, min, offset, count)
}
// ZRem 删除 Sorted Set 成员
func (c *RedisCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
key = normalizeKey(key)
return c.client.ZRem(ctx, key, members...)
}
// ZCard 获取 Sorted Set 成员数量
func (c *RedisCache) ZCard(ctx context.Context, key string) (int64, error) {
key = normalizeKey(key)
return c.client.ZCard(ctx, key)
}
// ==================== RedisCache 计数器操作 ====================
// Incr 原子递增(返回新值)
func (c *RedisCache) Incr(ctx context.Context, key string) (int64, error) {
key = normalizeKey(key)
return c.client.Incr(ctx, key)
}
// Expire 设置过期时间
func (c *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
key = normalizeKey(key)
_, err := c.client.Expire(ctx, key, ttl)
return err
}
// 全局缓存实例
var globalCache Cache
var once sync.Once

724
internal/cache/conversation_cache.go vendored Normal file
View File

@@ -0,0 +1,724 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"carrot_bbs/internal/model"
)
// CachedConversation 带缓存元数据的会话
type CachedConversation struct {
Data *model.Conversation // 实际数据
Version int64 // 版本号CAS 更新用)
UpdatedAt time.Time // 最后更新时间
AccessAt time.Time // 最后访问时间(用于 TTL 延长)
}
// CachedParticipant 带缓存元数据的参与者
type CachedParticipant struct {
Data *model.ConversationParticipant
Version int64
UpdatedAt time.Time
AccessAt time.Time
}
// CachedMessage 带缓存元数据的消息
type CachedMessage struct {
Data *model.Message `json:"data"` // 消息数据
Seq int64 `json:"seq"` // 消息序号
CreatedAt time.Time `json:"created_at"` // 创建时间
}
// MessageCacheData Redis 中存储的消息 Hash 结构
type MessageCacheData struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
SenderID string `json:"sender_id"`
Seq int64 `json:"seq"`
Segments json.RawMessage `json:"segments"`
ReplyToID *string `json:"reply_to_id,omitempty"`
Status string `json:"status"`
Category string `json:"category"`
SystemType string `json:"system_type,omitempty"`
ExtraData json.RawMessage `json:"extra_data,omitempty"`
MentionUsers string `json:"mention_users"`
MentionAll bool `json:"mention_all"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// PageCache 分页缓存
type PageCache struct {
Seqs []int64 `json:"seqs"` // 当前页的消息 seq 列表
Total int64 `json:"total"` // 消息总数
Page int `json:"page"` // 当前页码
PageSize int `json:"page_size"` // 每页大小
HasMore bool `json:"has_more"` // 是否有更多
UpdatedAt time.Time `json:"updated_at"` // 更新时间
}
// ConversationCacheSettings 缓存配置
type ConversationCacheSettings struct {
DetailTTL time.Duration // 会话详情 TTL (5min)
ListTTL time.Duration // 会话列表 TTL (60s)
ParticipantTTL time.Duration // 参与者 TTL (5min)
UnreadTTL time.Duration // 未读数 TTL (30s)
// 消息缓存配置
MessageDetailTTL time.Duration // 单条消息详情缓存 (30min)
MessageListTTL time.Duration // 消息分页列表缓存 (5min)
MessageIndexTTL time.Duration // 消息索引缓存 (30min)
MessageCountTTL time.Duration // 消息计数缓存 (30min)
}
// DefaultConversationCacheSettings 返回默认配置
func DefaultConversationCacheSettings() *ConversationCacheSettings {
return &ConversationCacheSettings{
DetailTTL: 5 * time.Minute,
ListTTL: 60 * time.Second,
ParticipantTTL: 5 * time.Minute,
UnreadTTL: 30 * time.Second,
MessageDetailTTL: 30 * time.Minute,
MessageListTTL: 5 * time.Minute,
MessageIndexTTL: 30 * time.Minute,
MessageCountTTL: 30 * time.Minute,
}
}
// parseSegments 将 JSON RawMessage 解析为 MessageSegments
func parseSegments(data json.RawMessage) model.MessageSegments {
if data == nil {
return nil
}
var segments model.MessageSegments
if err := json.Unmarshal(data, &segments); err != nil {
return nil
}
return segments
}
// serializeSegments 将 MessageSegments 序列化为 JSON RawMessage
func serializeSegments(segments model.MessageSegments) json.RawMessage {
if segments == nil {
return nil
}
data, err := json.Marshal(segments)
if err != nil {
return nil
}
return data
}
// ToModel 将 MessageCacheData 转换为 model.Message
func (m *MessageCacheData) ToModel() *model.Message {
return &model.Message{
ID: m.ID,
ConversationID: m.ConversationID,
SenderID: m.SenderID,
Seq: m.Seq,
Segments: parseSegments(m.Segments),
ReplyToID: m.ReplyToID,
Status: model.MessageStatus(m.Status),
Category: model.MessageCategory(m.Category),
SystemType: model.SystemMessageType(m.SystemType),
ExtraData: parseExtraData(m.ExtraData),
MentionUsers: m.MentionUsers,
MentionAll: m.MentionAll,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
// MessageCacheDataFromModel 从 model.Message 创建 MessageCacheData
func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData {
return &MessageCacheData{
ID: msg.ID,
ConversationID: msg.ConversationID,
SenderID: msg.SenderID,
Seq: msg.Seq,
Segments: serializeSegments(msg.Segments),
ReplyToID: msg.ReplyToID,
Status: string(msg.Status),
Category: string(msg.Category),
SystemType: string(msg.SystemType),
ExtraData: serializeExtraData(msg.ExtraData),
MentionUsers: msg.MentionUsers,
MentionAll: msg.MentionAll,
CreatedAt: msg.CreatedAt,
UpdatedAt: msg.UpdatedAt,
}
}
// parseExtraData 将 JSON RawMessage 解析为 ExtraData
func parseExtraData(data json.RawMessage) *model.ExtraData {
if data == nil {
return nil
}
var extraData model.ExtraData
if err := json.Unmarshal(data, &extraData); err != nil {
return nil
}
return &extraData
}
// serializeExtraData 将 ExtraData 序列化为 JSON RawMessage
func serializeExtraData(extraData *model.ExtraData) json.RawMessage {
if extraData == nil {
return nil
}
data, err := json.Marshal(extraData)
if err != nil {
return nil
}
return data
}
// ============================================================
// 缓存 Key 常量和生成函数
// ============================================================
const (
keyPrefixConv = "conv" // 会话详情
keyPrefixConvPart = "conv_part" // 参与者列表
keyPrefixConvPartUser = "conv_part_user" // 用户参与者信息
)
// ConversationKey 会话详情缓存 key
func ConversationKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixConv, convID)
}
// ParticipantListKey 参与者列表缓存 key
func ParticipantListKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixConvPart, convID)
}
// ParticipantKey 用户参与者信息缓存 key
func ParticipantKey(convID, userID string) string {
return fmt.Sprintf("%s:%s:%s", keyPrefixConvPartUser, convID, userID)
}
// ============================================================
// ConversationRepository 接口定义
// ============================================================
// ConversationRepository 会话数据仓库接口(用于依赖注入)
type ConversationRepository interface {
GetConversationByID(convID string) (*model.Conversation, error)
GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error)
GetParticipant(convID, userID string) (*model.ConversationParticipant, error)
GetParticipants(convID string) ([]*model.ConversationParticipant, error)
GetUnreadCount(userID, convID string) (int64, error)
}
// MessageRepository 消息数据仓库接口(用于依赖注入)
type MessageRepository interface {
GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error)
GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error)
GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error)
}
// ============================================================
// ConversationCache 核心实现
// ============================================================
// ConversationCache 会话缓存管理器
type ConversationCache struct {
cache Cache // 底层缓存
settings *ConversationCacheSettings // 配置
repo ConversationRepository // 数据仓库接口(用于 cache-aside 回源)
msgRepo MessageRepository // 消息数据仓库接口(用于消息缓存回源)
}
// NewConversationCache 创建会话缓存管理器
func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache {
if settings == nil {
settings = DefaultConversationCacheSettings()
}
return &ConversationCache{
cache: cache,
settings: settings,
repo: repo,
msgRepo: msgRepo,
}
}
// GetConversation 读取会话(带 TTL 滑动延长)
// 1. 尝试从缓存获取
// 2. 如果命中,更新 AccessAt 并延长 TTL
// 3. 如果未命中,从 repo 加载并写入缓存
func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) {
key := ConversationKey(convID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedConversation](c.cache, key)
if ok && cached != nil && cached.Data != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.DetailTTL)
return cached.Data, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, fmt.Errorf("repository not configured")
}
conv, err := c.repo.GetConversationByID(convID)
if err != nil {
return nil, err
}
// 写入缓存
now := time.Now()
cachedConv := &CachedConversation{
Data: conv,
Version: 0,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedConv, c.settings.DetailTTL)
return conv, nil
}
// CachedConversationList 带元数据的会话列表缓存
type CachedConversationList struct {
Conversations []*model.Conversation `json:"conversations"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
UpdatedAt time.Time `json:"updated_at"`
AccessAt time.Time `json:"access_at"`
}
// GetConversationList 获取用户会话列表(带 TTL 滑动延长)
func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
key := ConversationListKey(userID, page, pageSize)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedConversationList](c.cache, key)
if ok && cached != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.ListTTL)
return cached.Conversations, cached.Total, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, 0, fmt.Errorf("repository not configured")
}
convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize)
if err != nil {
return nil, 0, err
}
// 写入缓存
now := time.Now()
cachedList := &CachedConversationList{
Conversations: convs,
Total: total,
Page: page,
PageSize: pageSize,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedList, c.settings.ListTTL)
return convs, total, nil
}
// GetParticipant 获取参与者信息(带 TTL 滑动延长)
func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) {
key := ParticipantKey(convID, userID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedParticipant](c.cache, key)
if ok && cached != nil && cached.Data != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.ParticipantTTL)
return cached.Data, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, fmt.Errorf("repository not configured")
}
participant, err := c.repo.GetParticipant(convID, userID)
if err != nil {
return nil, err
}
// 写入缓存
now := time.Now()
cachedPart := &CachedParticipant{
Data: participant,
Version: 0,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedPart, c.settings.ParticipantTTL)
return participant, nil
}
// CachedParticipantList 带元数据的参与者列表缓存
type CachedParticipantList struct {
Participants []*model.ConversationParticipant `json:"participants"`
UpdatedAt time.Time `json:"updated_at"`
AccessAt time.Time `json:"access_at"`
}
// GetParticipants 获取会话所有参与者(带 TTL 滑动延长)
func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) {
key := ParticipantListKey(convID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedParticipantList](c.cache, key)
if ok && cached != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.ParticipantTTL)
return cached.Participants, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return nil, fmt.Errorf("repository not configured")
}
participants, err := c.repo.GetParticipants(convID)
if err != nil {
return nil, err
}
// 写入缓存
now := time.Now()
cachedList := &CachedParticipantList{
Participants: participants,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedList, c.settings.ParticipantTTL)
return participants, nil
}
// CachedUnreadCount 带元数据的未读数缓存
type CachedUnreadCount struct {
Count int64 `json:"count"`
UpdatedAt time.Time `json:"updated_at"`
AccessAt time.Time `json:"access_at"`
}
// GetUnreadCount 获取未读数(带 TTL 滑动延长)
func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) {
key := UnreadDetailKey(userID, convID)
// 1. 尝试从缓存获取
cached, ok := GetTyped[*CachedUnreadCount](c.cache, key)
if ok && cached != nil {
// 2. 命中,更新 AccessAt 并延长 TTL
cached.AccessAt = time.Now()
c.cache.Set(key, cached, c.settings.UnreadTTL)
return cached.Count, nil
}
// 3. 未命中,从 repo 加载
if c.repo == nil {
return 0, fmt.Errorf("repository not configured")
}
count, err := c.repo.GetUnreadCount(userID, convID)
if err != nil {
return 0, err
}
// 写入缓存
now := time.Now()
cachedCount := &CachedUnreadCount{
Count: count,
UpdatedAt: now,
AccessAt: now,
}
c.cache.Set(key, cachedCount, c.settings.UnreadTTL)
return count, nil
}
// ============================================================
// 缓存失效方法
// ============================================================
// InvalidateConversation 使会话缓存失效
func (c *ConversationCache) InvalidateConversation(convID string) {
c.cache.Delete(ConversationKey(convID))
}
// InvalidateConversationList 使会话列表缓存失效(按用户)
func (c *ConversationCache) InvalidateConversationList(userID string) {
c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", PrefixConversationList, userID))
}
// InvalidateParticipant 使参与者缓存失效
func (c *ConversationCache) InvalidateParticipant(convID, userID string) {
c.cache.Delete(ParticipantKey(convID, userID))
}
// InvalidateParticipantList 使参与者列表缓存失效
func (c *ConversationCache) InvalidateParticipantList(convID string) {
c.cache.Delete(ParticipantListKey(convID))
}
// InvalidateUnreadCount 使未读数缓存失效
func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) {
c.cache.Delete(UnreadDetailKey(userID, convID))
}
// ============================================================
// 消息缓存方法
// ============================================================
// GetMessages 获取会话消息(带缓存)
// 1. 尝试从分页缓存获取
// 2. 如果命中,从 Hash 中批量获取消息详情
// 3. 如果未命中,从数据库加载并写入缓存
func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) {
// 1. 尝试从缓存获取分页数据
pageKey := MessagePageKey(convID, page, pageSize)
cached, ok := GetTyped[*PageCache](c.cache, pageKey)
if ok && cached != nil {
// TTL 滑动延长
cached.UpdatedAt = time.Now()
c.cache.Set(pageKey, cached, c.settings.MessageListTTL)
// 从 Hash 中批量获取消息详情
if len(cached.Seqs) > 0 {
messages, err := c.getMessagesBySeqs(ctx, convID, cached.Seqs)
if err == nil {
return messages, cached.Total, nil
}
// Hash 获取失败,继续从数据库加载
}
}
// 2. 缓存未命中,从数据库加载
if c.msgRepo == nil {
return nil, 0, fmt.Errorf("message repository not configured")
}
messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize)
if err != nil {
return nil, 0, err
}
// 3. 写入缓存
seqs := make([]int64, len(messages))
for i, msg := range messages {
seqs[i] = msg.Seq
// 异步写入消息详情到 Hash
go c.asyncCacheMessage(context.Background(), convID, msg)
}
pageCache := &PageCache{
Seqs: seqs,
Total: total,
Page: page,
PageSize: pageSize,
HasMore: int64(page*pageSize) < total,
UpdatedAt: time.Now(),
}
c.cache.Set(pageKey, pageCache, c.settings.MessageListTTL)
return messages, total, nil
}
// GetMessagesAfterSeq 获取指定 seq 之后的消息(增量同步)
// 使用 Sorted Set 的 ZRangeByScore 实现
func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) {
indexKey := MessageIndexKey(convID)
// 1. 尝试从 Sorted Set 获取 seq 列表
members, err := c.cache.ZRangeByScore(ctx, indexKey, fmt.Sprintf("%d", afterSeq+1), "+inf", 0, int64(limit))
if err != nil {
return nil, err
}
// 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情
if len(members) > 0 {
seqs := make([]int64, 0, len(members))
for _, member := range members {
var seq int64
if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
seqs = append(seqs, seq)
}
}
return c.getMessagesBySeqs(ctx, convID, seqs)
}
// 3. Sorted Set 未命中,从数据库加载
if c.msgRepo == nil {
return nil, fmt.Errorf("message repository not configured")
}
messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit)
if err != nil {
return nil, err
}
// 4. 异步写入缓存
for _, msg := range messages {
go c.asyncCacheMessage(context.Background(), convID, msg)
}
return messages, nil
}
// GetMessagesBeforeSeq 获取指定 seq 之前的历史消息(下拉加载)
// 使用 Sorted Set 的 ZRevRangeByScore 实现
func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
indexKey := MessageIndexKey(convID)
// 1. 尝试从 Sorted Set 获取 seq 列表(降序)
members, err := c.cache.ZRevRangeByScore(ctx, indexKey, fmt.Sprintf("%d", beforeSeq-1), "-inf", 0, int64(limit))
if err != nil {
return nil, err
}
// 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情
if len(members) > 0 {
seqs := make([]int64, 0, len(members))
for _, member := range members {
var seq int64
if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
seqs = append(seqs, seq)
}
}
return c.getMessagesBySeqs(ctx, convID, seqs)
}
// 3. Sorted Set 未命中,从数据库加载
if c.msgRepo == nil {
return nil, fmt.Errorf("message repository not configured")
}
messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
if err != nil {
return nil, err
}
// 4. 异步写入缓存
for _, msg := range messages {
go c.asyncCacheMessage(context.Background(), convID, msg)
}
return messages, nil
}
// CacheMessage 缓存单条消息(立即写入缓存)
// 写入 Hash、Sorted Set、更新计数
func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error {
hashKey := MessageHashKey(convID)
indexKey := MessageIndexKey(convID)
msgData := MessageCacheDataFromModel(msg)
data, err := json.Marshal(msgData)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
// HSET 消息详情
if err := c.cache.HSet(ctx, hashKey, fmt.Sprintf("%d", msg.Seq), string(data)); err != nil {
return fmt.Errorf("failed to set hash: %w", err)
}
// ZADD 消息索引
if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), fmt.Sprintf("%d", msg.Seq)); err != nil {
return fmt.Errorf("failed to add to sorted set: %w", err)
}
// 设置 TTL
c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL)
c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL)
// INCR 消息计数
c.cache.Incr(ctx, MessageCountKey(convID))
return nil
}
// InvalidateMessageCache 使消息缓存失效
func (c *ConversationCache) InvalidateMessageCache(convID string) {
c.cache.Delete(MessageHashKey(convID))
c.cache.Delete(MessageIndexKey(convID))
c.cache.Delete(MessageCountKey(convID))
// 删除所有分页缓存
c.InvalidateMessagePages(convID)
}
// InvalidateMessagePages 仅使消息分页缓存失效
// 新消息写入后会导致分页内容和总数变化,需要清理该会话所有分页缓存。
func (c *ConversationCache) InvalidateMessagePages(convID string) {
c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID))
}
// ============================================================
// 内部辅助方法
// ============================================================
// getMessagesBySeqs 从 Hash 中批量获取消息
func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) {
if len(seqs) == 0 {
return nil, nil
}
hashKey := MessageHashKey(convID)
fields := make([]string, len(seqs))
for i, seq := range seqs {
fields[i] = fmt.Sprintf("%d", seq)
}
// 批量获取
values, err := c.cache.HMGet(ctx, hashKey, fields...)
if err != nil {
return nil, err
}
messages := make([]*model.Message, 0, len(seqs))
for _, val := range values {
if val == nil {
continue
}
var msgData MessageCacheData
switch v := val.(type) {
case string:
if err := json.Unmarshal([]byte(v), &msgData); err != nil {
continue
}
case []byte:
if err := json.Unmarshal(v, &msgData); err != nil {
continue
}
default:
continue
}
messages = append(messages, msgData.ToModel())
}
return messages, nil
}
// asyncCacheMessage 异步缓存单条消息
func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) {
if err := c.CacheMessage(ctx, convID, msg); err != nil {
log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err)
}
}

View File

@@ -26,6 +26,13 @@ const (
// 用户相关
PrefixUserInfo = "users:info"
PrefixUserMe = "users:me"
// 消息缓存相关
keyPrefixMsgHash = "msg_hash" // 消息详情 Hash
keyPrefixMsgIndex = "msg_index" // 消息索引 Sorted Set
keyPrefixMsgCount = "msg_count" // 消息计数
keyPrefixMsgSeq = "msg_seq" // Seq 计数器
keyPrefixMsgPage = "msg_page" // 分页缓存
)
// PostListKey 生成帖子列表缓存键
@@ -145,3 +152,37 @@ func InvalidateUserInfo(cache Cache, userID string) {
cache.Delete(UserInfoKey(userID))
cache.Delete(UserMeKey(userID))
}
// ============================================================
// 消息缓存 Key 生成函数
// ============================================================
// MessageHashKey 消息详情 Hash key
func MessageHashKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgHash, convID)
}
// MessageIndexKey 消息索引 Sorted Set key
func MessageIndexKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgIndex, convID)
}
// MessageCountKey 消息计数 key
func MessageCountKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgCount, convID)
}
// MessageSeqKey Seq 计数器 key
func MessageSeqKey(convID string) string {
return fmt.Sprintf("%s:%s", keyPrefixMsgSeq, convID)
}
// MessagePageKey 分页缓存 key
func MessagePageKey(convID string, page, pageSize int) string {
return fmt.Sprintf("%s:%s:%d:%d", keyPrefixMsgPage, convID, page, pageSize)
}
// InvalidateMessagePages 失效会话消息分页缓存
func InvalidateMessagePages(cache Cache, conversationID string) {
cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, conversationID))
}

76
internal/cache/repository_adapter.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
)
// ConversationRepositoryAdapter 适配 MessageRepository 到 ConversationRepository 接口
type ConversationRepositoryAdapter struct {
repo *repository.MessageRepository
}
// NewConversationRepositoryAdapter 创建适配器
func NewConversationRepositoryAdapter(repo *repository.MessageRepository) ConversationRepository {
return &ConversationRepositoryAdapter{repo: repo}
}
// GetConversationByID 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetConversationByID(convID string) (*model.Conversation, error) {
return a.repo.GetConversation(convID)
}
// GetConversationsByUserID 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
return a.repo.GetConversations(userID, page, pageSize)
}
// GetParticipant 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetParticipant(convID, userID string) (*model.ConversationParticipant, error) {
return a.repo.GetParticipant(convID, userID)
}
// GetParticipants 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetParticipants(convID string) ([]*model.ConversationParticipant, error) {
return a.repo.GetConversationParticipants(convID)
}
// GetUnreadCount 实现 ConversationRepository 接口
func (a *ConversationRepositoryAdapter) GetUnreadCount(userID, convID string) (int64, error) {
return a.repo.GetUnreadCount(convID, userID)
}
// MessageRepositoryAdapter 适配 MessageRepository 到 MessageRepository 接口
type MessageRepositoryAdapter struct {
repo *repository.MessageRepository
}
// NewMessageRepositoryAdapter 创建适配器
func NewMessageRepositoryAdapter(repo *repository.MessageRepository) MessageRepository {
return &MessageRepositoryAdapter{repo: repo}
}
// GetMessages 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error) {
return a.repo.GetMessages(convID, page, pageSize)
}
// GetMessagesAfterSeq 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error) {
return a.repo.GetMessagesAfterSeq(convID, afterSeq, limit)
}
// GetMessagesBeforeSeq 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
return a.repo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
}
// CreateMessage 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) CreateMessage(msg *model.Message) error {
return a.repo.CreateMessage(msg)
}
// UpdateConversationLastSeq 实现 MessageRepository 接口
func (a *MessageRepositoryAdapter) UpdateConversationLastSeq(convID string, seq int64) error {
return a.repo.UpdateConversationLastSeq(convID, seq)
}

View File

@@ -15,18 +15,19 @@ import (
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Cache CacheConfig `mapstructure:"cache"`
S3 S3Config `mapstructure:"s3"`
JWT JWTConfig `mapstructure:"jwt"`
Log LogConfig `mapstructure:"log"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Upload UploadConfig `mapstructure:"upload"`
Gorse GorseConfig `mapstructure:"gorse"`
OpenAI OpenAIConfig `mapstructure:"openai"`
Email EmailConfig `mapstructure:"email"`
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Cache CacheConfig `mapstructure:"cache"`
S3 S3Config `mapstructure:"s3"`
JWT JWTConfig `mapstructure:"jwt"`
Log LogConfig `mapstructure:"log"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
Upload UploadConfig `mapstructure:"upload"`
Gorse GorseConfig `mapstructure:"gorse"`
OpenAI OpenAIConfig `mapstructure:"openai"`
Email EmailConfig `mapstructure:"email"`
ConversationCache ConversationCacheConfig `mapstructure:"conversation_cache"`
}
type ServerConfig struct {
@@ -173,6 +174,73 @@ type EmailConfig struct {
Timeout int `mapstructure:"timeout"`
}
// ConversationCacheConfig 会话缓存配置
type ConversationCacheConfig struct {
// TTL 配置
DetailTTL string `mapstructure:"detail_ttl"`
ListTTL string `mapstructure:"list_ttl"`
ParticipantTTL string `mapstructure:"participant_ttl"`
UnreadTTL string `mapstructure:"unread_ttl"`
// 消息缓存配置
MessageDetailTTL string `mapstructure:"message_detail_ttl"`
MessageListTTL string `mapstructure:"message_list_ttl"`
MessageIndexTTL string `mapstructure:"message_index_ttl"`
MessageCountTTL string `mapstructure:"message_count_ttl"`
// 批量写入配置
BatchInterval string `mapstructure:"batch_interval"`
BatchThreshold int `mapstructure:"batch_threshold"`
BatchMaxSize int `mapstructure:"batch_max_size"`
BufferMaxSize int `mapstructure:"buffer_max_size"`
}
// ConversationCacheSettings 会话缓存运行时配置(用于传递给 cache 包)
type ConversationCacheSettings struct {
DetailTTL time.Duration
ListTTL time.Duration
ParticipantTTL time.Duration
UnreadTTL time.Duration
MessageDetailTTL time.Duration
MessageListTTL time.Duration
MessageIndexTTL time.Duration
MessageCountTTL time.Duration
BatchInterval time.Duration
BatchThreshold int
BatchMaxSize int
BufferMaxSize int
}
// ToSettings 将 ConversationCacheConfig 转换为 ConversationCacheSettings
func (c *ConversationCacheConfig) ToSettings() *ConversationCacheSettings {
return &ConversationCacheSettings{
DetailTTL: parseDuration(c.DetailTTL, 5*time.Minute),
ListTTL: parseDuration(c.ListTTL, 60*time.Second),
ParticipantTTL: parseDuration(c.ParticipantTTL, 5*time.Minute),
UnreadTTL: parseDuration(c.UnreadTTL, 30*time.Second),
MessageDetailTTL: parseDuration(c.MessageDetailTTL, 30*time.Minute),
MessageListTTL: parseDuration(c.MessageListTTL, 5*time.Minute),
MessageIndexTTL: parseDuration(c.MessageIndexTTL, 30*time.Minute),
MessageCountTTL: parseDuration(c.MessageCountTTL, 30*time.Minute),
BatchInterval: parseDuration(c.BatchInterval, 5*time.Second),
BatchThreshold: c.BatchThreshold,
BatchMaxSize: c.BatchMaxSize,
BufferMaxSize: c.BufferMaxSize,
}
}
// parseDuration 解析持续时间字符串,如果解析失败则返回默认值
func parseDuration(s string, defaultVal time.Duration) time.Duration {
if s == "" {
return defaultVal
}
d, err := time.ParseDuration(s)
if err != nil {
return defaultVal
}
return d
}
func Load(configPath string) (*Config, error) {
viper.SetConfigFile(configPath)
viper.SetConfigType("yaml")
@@ -259,6 +327,19 @@ func Load(configPath string) (*Config, error) {
viper.SetDefault("email.use_tls", true)
viper.SetDefault("email.insecure_skip_verify", false)
viper.SetDefault("email.timeout", 15)
// ConversationCache 默认值
viper.SetDefault("conversation_cache.detail_ttl", "5m")
viper.SetDefault("conversation_cache.list_ttl", "60s")
viper.SetDefault("conversation_cache.participant_ttl", "5m")
viper.SetDefault("conversation_cache.unread_ttl", "30s")
viper.SetDefault("conversation_cache.message_detail_ttl", "30m")
viper.SetDefault("conversation_cache.message_list_ttl", "5m")
viper.SetDefault("conversation_cache.message_index_ttl", "30m")
viper.SetDefault("conversation_cache.message_count_ttl", "30m")
viper.SetDefault("conversation_cache.batch_interval", "5s")
viper.SetDefault("conversation_cache.batch_threshold", 100)
viper.SetDefault("conversation_cache.batch_max_size", 500)
viper.SetDefault("conversation_cache.buffer_max_size", 10000)
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)

View File

@@ -0,0 +1,35 @@
package dto
import (
"encoding/json"
"carrot_bbs/internal/model"
)
func ConvertScheduleCourseToResponse(course *model.ScheduleCourse, weeks []int) *ScheduleCourseResponse {
if course == nil {
return nil
}
return &ScheduleCourseResponse{
ID: course.ID,
Name: course.Name,
Teacher: course.Teacher,
Location: course.Location,
DayOfWeek: course.DayOfWeek,
StartSection: course.StartSection,
EndSection: course.EndSection,
Weeks: weeks,
Color: course.Color,
}
}
func ParseWeeksJSON(raw string) []int {
if raw == "" {
return []int{}
}
var weeks []int
if err := json.Unmarshal([]byte(raw), &weeks); err != nil {
return []int{}
}
return weeks
}

View File

@@ -0,0 +1,13 @@
package dto
type ScheduleCourseResponse struct {
ID string `json:"id"`
Name string `json:"name"`
Teacher string `json:"teacher,omitempty"`
Location string `json:"location,omitempty"`
DayOfWeek int `json:"day_of_week"`
StartSection int `json:"start_section"`
EndSection int `json:"end_section"`
Weeks []int `json:"weeks"`
Color string `json:"color,omitempty"`
}

View File

@@ -38,12 +38,12 @@ func parseGroupID(c *gin.Context) string {
// parseUserIDFromPath 从路径参数获取用户IDUUID格式
func parseUserIDFromPath(c *gin.Context) string {
return c.Param("userId")
return c.Param("user_id")
}
// parseAnnouncementID 从路径参数获取公告ID
func parseAnnouncementID(c *gin.Context) string {
return c.Param("announcementId")
return c.Param("announcement_id")
}
// ==================== 群组管理 ====================
@@ -454,7 +454,7 @@ func (h *GroupHandler) GetMembers(c *gin.Context) {
// ==================== RESTful Action 端点 ====================
// HandleCreateGroup 创建群组
// POST /api/v1/groups/create
// POST /api/v1/groups
func (h *GroupHandler) HandleCreateGroup(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -478,7 +478,7 @@ func (h *GroupHandler) HandleCreateGroup(c *gin.Context) {
}
// HandleGetUserGroups 获取用户群组列表
// GET /api/v1/groups/list
// GET /api/v1/groups
func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -499,7 +499,6 @@ func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) {
}
// HandleGetMyMemberInfo 获取我在群组中的成员信息
// GET /api/v1/groups/get_my_info?group_id=xxx
// GET /api/v1/groups/:id/me
func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) {
userID := parseUserID(c)
@@ -551,7 +550,7 @@ func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) {
}
// HandleDissolveGroup 解散群组
// POST /api/v1/groups/dissolve
// DELETE /api/v1/groups/:id
func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -559,18 +558,13 @@ func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
return
}
var params dto.DissolveGroupParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if err := h.groupService.DissolveGroup(userID, params.GroupID); err != nil {
if err := h.groupService.DissolveGroup(userID, groupID); err != nil {
if err == service.ErrNotGroupOwner {
response.Forbidden(c, "只有群主可以解散群组")
return
@@ -587,7 +581,7 @@ func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
}
// HandleTransferOwner 转让群主
// POST /api/v1/groups/transfer
// POST /api/v1/groups/:id/transfer
func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -595,22 +589,24 @@ func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.TransferOwnerParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.NewOwnerID == "" {
response.BadRequest(c, "new_owner_id is required")
return
}
if err := h.groupService.TransferOwner(userID, params.GroupID, params.NewOwnerID); err != nil {
if err := h.groupService.TransferOwner(userID, groupID, params.NewOwnerID); err != nil {
if err == service.ErrNotGroupOwner {
response.Forbidden(c, "只有群主可以转让群主")
return
@@ -631,7 +627,7 @@ func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
}
// HandleInviteMembers 邀请成员加入群组
// POST /api/v1/groups/invite_members
// POST /api/v1/groups/:id/invitations
func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -639,18 +635,19 @@ func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.InviteMembersParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if err := h.groupService.InviteMembers(userID, params.GroupID, params.MemberIDs); err != nil {
if err := h.groupService.InviteMembers(userID, groupID, params.MemberIDs); err != nil {
if err == service.ErrNotGroupMember {
response.Forbidden(c, "只有群成员可以邀请他人")
return
@@ -675,7 +672,7 @@ func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
}
// HandleJoinGroup 加入群组
// POST /api/v1/groups/join
// POST /api/v1/groups/:id/join-requests
func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -683,18 +680,13 @@ func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
return
}
var params dto.JoinGroupParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if err := h.groupService.JoinGroup(userID, params.GroupID); err != nil {
if err := h.groupService.JoinGroup(userID, groupID); err != nil {
if err == service.ErrJoinRequestPending {
response.SuccessWithMessage(c, "申请已提交,等待群主/管理员审批", nil)
return
@@ -723,7 +715,7 @@ func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
}
// HandleSetNickname 设置群内昵称
// POST /api/v1/groups/set_nickname
// PUT /api/v1/groups/:id/members/me/nickname
func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -731,18 +723,19 @@ func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetNicknameParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if err := h.groupService.SetMemberNickname(userID, params.GroupID, params.Nickname); err != nil {
if err := h.groupService.SetMemberNickname(userID, groupID, params.Nickname); err != nil {
if err == service.ErrNotGroupMember {
response.BadRequest(c, "不是群成员")
return
@@ -759,7 +752,7 @@ func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
}
// HandleSetJoinType 设置加群方式
// POST /api/v1/groups/set_join_type
// PUT /api/v1/groups/:id/join-type
func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -767,18 +760,19 @@ func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetJoinTypeParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if err := h.groupService.SetJoinType(userID, params.GroupID, params.JoinType); err != nil {
if err := h.groupService.SetJoinType(userID, groupID, params.JoinType); err != nil {
if err == service.ErrNotGroupOwner {
response.Forbidden(c, "只有群主可以设置加群方式")
return
@@ -803,7 +797,7 @@ func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
}
// HandleCreateAnnouncement 创建群公告
// POST /api/v1/groups/create_announcement
// POST /api/v1/groups/:id/announcements
func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -811,18 +805,19 @@ func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.CreateAnnouncementParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
announcement, err := h.groupService.CreateAnnouncement(userID, params.GroupID, params.Content)
announcement, err := h.groupService.CreateAnnouncement(userID, groupID, params.Content)
if err != nil {
if err == service.ErrNotGroupAdmin {
response.Forbidden(c, "只有群主或管理员可以发布公告")
@@ -840,7 +835,6 @@ func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
}
// HandleGetAnnouncements 获取群公告列表
// GET /api/v1/groups/get_announcements?group_id=xxx
// GET /api/v1/groups/:id/announcements
func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) {
userID := parseUserID(c)
@@ -872,7 +866,7 @@ func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) {
}
// HandleDeleteAnnouncement 删除群公告
// POST /api/v1/groups/delete_announcement
// DELETE /api/v1/groups/:id/announcements/:announcement_id
func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -880,22 +874,18 @@ func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) {
return
}
var params dto.DeleteAnnouncementParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.AnnouncementID == "" {
announcementID := parseAnnouncementID(c)
if announcementID == "" {
response.BadRequest(c, "announcement_id is required")
return
}
if err := h.groupService.DeleteAnnouncement(userID, params.AnnouncementID); err != nil {
if err := h.groupService.DeleteAnnouncement(userID, announcementID); err != nil {
if err == service.ErrNotGroupAdmin {
response.Forbidden(c, "只有群主或管理员可以删除公告")
return
@@ -1292,7 +1282,7 @@ func (h *GroupHandler) DeleteAnnouncement(c *gin.Context) {
// ==================== RESTful Action 端点 ====================
// HandleSetGroupKick 群组踢人
// POST /api/v1/groups/set_group_kick
// POST /api/v1/groups/:id/members/kick
func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1300,23 +1290,25 @@ func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetGroupKickParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.UserID == "" {
response.BadRequest(c, "user_id is required")
return
}
// 使用 RemoveMember 方法
err := h.groupService.RemoveMember(userID, params.GroupID, params.UserID)
err := h.groupService.RemoveMember(userID, groupID, params.UserID)
if err != nil {
if err == service.ErrNotGroupAdmin {
response.Forbidden(c, "只有群主或管理员可以移除成员")
@@ -1342,7 +1334,7 @@ func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
}
// HandleSetGroupBan 群组单人禁言
// POST /api/v1/groups/set_group_ban
// POST /api/v1/groups/:id/members/ban
func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1350,16 +1342,18 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetGroupBanParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.UserID == "" {
response.BadRequest(c, "user_id is required")
return
@@ -1367,8 +1361,8 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
// duration > 0 或 duration = -1 表示禁言duration = 0 表示解除禁言
muted := params.Duration != 0
log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, params.GroupID, params.UserID, params.Duration, muted)
err := h.groupService.MuteMember(userID, params.GroupID, params.UserID, muted)
log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, groupID, params.UserID, params.Duration, muted)
err := h.groupService.MuteMember(userID, groupID, params.UserID, muted)
if err != nil {
log.Printf("[HandleSetGroupBan] 禁言操作失败: %v", err)
} else {
@@ -1403,7 +1397,7 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
}
// HandleSetGroupWholeBan 群组全员禁言
// POST /api/v1/groups/set_group_whole_ban
// PUT /api/v1/groups/:id/ban
func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1411,18 +1405,19 @@ func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetGroupWholeBanParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
err := h.groupService.SetMuteAll(userID, params.GroupID, params.Enable)
err := h.groupService.SetMuteAll(userID, groupID, params.Enable)
if err != nil {
if err == service.ErrNotGroupOwner {
response.Forbidden(c, "只有群主可以设置全员禁言")
@@ -1444,7 +1439,7 @@ func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
}
// HandleSetGroupAdmin 群组设置管理员
// POST /api/v1/groups/set_group_admin
// PUT /api/v1/groups/:id/members/:user_id/admin
func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1452,28 +1447,30 @@ func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
targetUserID := parseUserIDFromPath(c)
if targetUserID == "" {
response.BadRequest(c, "user_id is required")
return
}
var params dto.SetGroupAdminParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.UserID == "" {
response.BadRequest(c, "user_id is required")
return
}
// 根据 enable 参数设置角色
role := model.GroupRoleMember
if params.Enable {
role = model.GroupRoleAdmin
}
err := h.groupService.SetMemberRole(userID, params.GroupID, params.UserID, role)
err := h.groupService.SetMemberRole(userID, groupID, targetUserID, role)
if err != nil {
if err == service.ErrNotGroupOwner {
response.Forbidden(c, "只有群主可以设置管理员")
@@ -1499,7 +1496,7 @@ func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
}
// HandleSetGroupName 设置群名
// POST /api/v1/groups/set_group_name
// PUT /api/v1/groups/:id/name
func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1507,16 +1504,18 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetGroupNameParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.GroupName == "" {
response.BadRequest(c, "group_name is required")
return
@@ -1526,7 +1525,7 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
"name": params.GroupName,
}
err := h.groupService.UpdateGroup(userID, params.GroupID, updates)
err := h.groupService.UpdateGroup(userID, groupID, updates)
if err != nil {
if err == service.ErrNotGroupAdmin {
response.Forbidden(c, "没有权限修改群组信息")
@@ -1541,12 +1540,12 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
}
// 获取更新后的群组信息
group, _ := h.groupService.GetGroupByID(params.GroupID)
group, _ := h.groupService.GetGroupByID(groupID)
response.Success(c, dto.GroupToResponse(group))
}
// HandleSetGroupAvatar 设置群头像
// POST /api/v1/groups/set_group_avatar
// PUT /api/v1/groups/:id/avatar
func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1554,16 +1553,18 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
return
}
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
var params dto.SetGroupAvatarParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
response.BadRequest(c, "group_id is required")
return
}
if params.Avatar == "" {
response.BadRequest(c, "avatar is required")
return
@@ -1573,7 +1574,7 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
"avatar": params.Avatar,
}
err := h.groupService.UpdateGroup(userID, params.GroupID, updates)
err := h.groupService.UpdateGroup(userID, groupID, updates)
if err != nil {
if err == service.ErrNotGroupAdmin {
response.Forbidden(c, "没有权限修改群组信息")
@@ -1588,12 +1589,12 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
}
// 获取更新后的群组信息
group, _ := h.groupService.GetGroupByID(params.GroupID)
group, _ := h.groupService.GetGroupByID(groupID)
response.Success(c, dto.GroupToResponse(group))
}
// HandleSetGroupLeave 退出群组
// POST /api/v1/groups/set_group_leave
// POST /api/v1/groups/:id/leave
func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1601,18 +1602,13 @@ func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
return
}
var params dto.SetGroupLeaveParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.GroupID == "" {
groupID := parseGroupID(c)
if groupID == "" {
response.BadRequest(c, "group_id is required")
return
}
err := h.groupService.LeaveGroup(userID, params.GroupID)
err := h.groupService.LeaveGroup(userID, groupID)
if err != nil {
if err == service.ErrNotGroupMember {
response.BadRequest(c, "不是群成员")
@@ -1630,7 +1626,7 @@ func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
}
// HandleSetGroupAddRequest 处理加群审批
// POST /api/v1/groups/set_group_add_request
// POST /api/v1/groups/:id/join-requests/handle
func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1678,7 +1674,7 @@ func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) {
}
// HandleRespondInvite 处理群邀请响应
// POST /api/v1/groups/respond_invite
// POST /api/v1/groups/:id/join-requests/respond
func (h *GroupHandler) HandleRespondInvite(c *gin.Context) {
userID := parseUserID(c)
if userID == "" {
@@ -1725,7 +1721,6 @@ func (h *GroupHandler) HandleRespondInvite(c *gin.Context) {
}
// HandleGetGroupInfo 获取群信息
// GET /api/v1/groups/get?group_id=xxx
// GET /api/v1/groups/:id
func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) {
userID := parseUserID(c)
@@ -1761,7 +1756,6 @@ func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) {
}
// HandleGetGroupMemberList 获取群成员列表
// GET /api/v1/groups/get_members?group_id=xxx
// GET /api/v1/groups/:id/members
func (h *GroupHandler) HandleGetGroupMemberList(c *gin.Context) {
userID := parseUserID(c)

View File

@@ -116,14 +116,14 @@ func (h *MessageHandler) HandleTyping(c *gin.Context) {
response.Unauthorized(c, "")
return
}
var params struct {
ConversationID string `json:"conversation_id" binding:"required"`
}
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
conversationID := getIDParam(c, "id")
if conversationID == "" {
response.BadRequest(c, "conversation id is required")
return
}
h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID)
h.chatService.SendTyping(c.Request.Context(), userID, conversationID)
response.SuccessWithMessage(c, "typing sent", nil)
}
@@ -397,8 +397,8 @@ func (h *MessageHandler) SendMessage(c *gin.Context) {
}
// HandleSendMessage RESTful 风格的发送消息端点
// POST /api/v1/conversations/send_message
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
// POST /api/v1/conversations/:id/messages
// 请求体格式: {"detail_type": "private", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
@@ -406,6 +406,12 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
return
}
conversationID := getIDParam(c, "id")
if conversationID == "" {
response.BadRequest(c, "conversation id is required")
return
}
var params dto.SendMessageParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
@@ -413,10 +419,6 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
}
// 验证参数
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
if params.DetailType == "" {
response.BadRequest(c, "detail_type is required")
return
@@ -427,7 +429,7 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
}
// 发送消息
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, params.Segments, params.ReplyToID)
if err != nil {
response.BadRequest(c, err.Error())
return
@@ -480,7 +482,7 @@ func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
}
// HandleGetConversationList 获取会话列表
// GET /api/v1/conversations/list
// GET /api/v1/conversations
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
@@ -780,7 +782,6 @@ func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
}
// HandleGetConversation 获取会话详情
// GET /api/v1/conversations/get?conversation_id=xxx
// GET /api/v1/conversations/:id
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
userID := c.GetString("user_id")
@@ -825,7 +826,6 @@ func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
}
// HandleGetMessages 获取会话消息
// GET /api/v1/conversations/get_messages?conversation_id=xxx
// GET /api/v1/conversations/:id/messages
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
userID := c.GetString("user_id")
@@ -913,7 +913,7 @@ func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
}
// HandleMarkRead 标记已读
// POST /api/v1/conversations/mark_read
// POST /api/v1/conversations/:id/read
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
@@ -921,18 +921,19 @@ func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
return
}
var params dto.MarkReadParams
if err := c.ShouldBindJSON(&params); err != nil {
conversationID := getIDParam(c, "id")
if conversationID == "" {
response.BadRequest(c, "conversation id is required")
return
}
var req dto.MarkReadRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
err := h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
if err != nil {
response.BadRequest(c, err.Error())
return
@@ -942,7 +943,7 @@ func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
}
// HandleSetConversationPinned 设置会话置顶
// POST /api/v1/conversations/set_pinned
// PUT /api/v1/conversations/:id/pinned
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
@@ -950,24 +951,27 @@ func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
return
}
var params dto.SetConversationPinnedParams
if err := c.ShouldBindJSON(&params); err != nil {
conversationID := getIDParam(c, "id")
if conversationID == "" {
response.BadRequest(c, "conversation id is required")
return
}
var req struct {
IsPinned bool `json:"is_pinned"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
if err := h.chatService.SetConversationPinned(c.Request.Context(), conversationID, userID, req.IsPinned); err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
"conversation_id": params.ConversationID,
"is_pinned": params.IsPinned,
"conversation_id": conversationID,
"is_pinned": req.IsPinned,
})
}

View File

@@ -0,0 +1,140 @@
package handler
import (
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
type ScheduleHandler struct {
scheduleService service.ScheduleService
}
func NewScheduleHandler(scheduleService service.ScheduleService) *ScheduleHandler {
return &ScheduleHandler{scheduleService: scheduleService}
}
type createScheduleCourseRequest struct {
Name string `json:"name" binding:"required"`
Teacher string `json:"teacher"`
Location string `json:"location"`
DayOfWeek int `json:"day_of_week" binding:"required"`
StartSection int `json:"start_section" binding:"required"`
EndSection int `json:"end_section" binding:"required"`
Weeks []int `json:"weeks" binding:"required,min=1"`
Color string `json:"color"`
}
type updateScheduleCourseRequest = createScheduleCourseRequest
func (h *ScheduleHandler) ListCourses(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
week := 0
if rawWeek := c.Query("week"); rawWeek != "" {
parsed, err := strconv.Atoi(rawWeek)
if err != nil {
response.BadRequest(c, "invalid week")
return
}
week = parsed
}
list, err := h.scheduleService.ListCourses(userID, week)
if err != nil {
response.HandleError(c, err, "failed to list schedule courses")
return
}
response.Success(c, gin.H{"list": list})
}
func (h *ScheduleHandler) CreateCourse(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req createScheduleCourseRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
created, err := h.scheduleService.CreateCourse(userID, service.CreateScheduleCourseInput{
Name: req.Name,
Teacher: req.Teacher,
Location: req.Location,
DayOfWeek: req.DayOfWeek,
StartSection: req.StartSection,
EndSection: req.EndSection,
Weeks: req.Weeks,
Color: req.Color,
})
if err != nil {
response.HandleError(c, err, "failed to create schedule course")
return
}
response.SuccessWithMessage(c, "course created", gin.H{"course": created})
}
func (h *ScheduleHandler) UpdateCourse(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
courseID := c.Param("id")
if courseID == "" {
response.BadRequest(c, "invalid course id")
return
}
var req updateScheduleCourseRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
updated, err := h.scheduleService.UpdateCourse(userID, courseID, service.CreateScheduleCourseInput{
Name: req.Name,
Teacher: req.Teacher,
Location: req.Location,
DayOfWeek: req.DayOfWeek,
StartSection: req.StartSection,
EndSection: req.EndSection,
Weeks: req.Weeks,
Color: req.Color,
})
if err != nil {
response.HandleError(c, err, "failed to update schedule course")
return
}
response.SuccessWithMessage(c, "course updated", gin.H{"course": updated})
}
func (h *ScheduleHandler) DeleteCourse(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
courseID := c.Param("id")
if courseID == "" {
response.BadRequest(c, "invalid course id")
return
}
if err := h.scheduleService.DeleteCourse(userID, courseID); err != nil {
response.HandleError(c, err, "failed to delete schedule course")
return
}
response.SuccessWithMessage(c, "course deleted", nil)
}

View File

@@ -143,6 +143,9 @@ func autoMigrate(db *gorm.DB) error {
// 自定义表情
&UserSticker{},
// 课表
&ScheduleCourse{},
)
if err != nil {
return err

View File

@@ -0,0 +1,35 @@
package model
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// ScheduleCourse 用户课表课程
type ScheduleCourse struct {
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
Name string `json:"name" gorm:"type:varchar(120);not null"`
Teacher string `json:"teacher" gorm:"type:varchar(80)"`
Location string `json:"location" gorm:"type:varchar(120)"`
DayOfWeek int `json:"day_of_week" gorm:"index;not null"` // 0=周一, 6=周日
StartSection int `json:"start_section" gorm:"not null"`
EndSection int `json:"end_section" gorm:"not null"`
Weeks string `json:"weeks" gorm:"type:text;not null"` // JSON 数组字符串
Color string `json:"color" gorm:"type:varchar(20)"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (s *ScheduleCourse) BeforeCreate(tx *gorm.DB) error {
if s.ID == "" {
s.ID = uuid.New().String()
}
return nil
}
func (ScheduleCourse) TableName() string {
return "schedule_courses"
}

View File

@@ -164,10 +164,17 @@ func (c *clientImpl) moderateSingleBatch(
}
type chatCompletionsRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen3.5思考模式控制
ThinkingBudget *int `json:"thinking_budget,omitempty"` // 思考过程最大token数
ResponseFormat *responseFormatConfig `json:"response_format,omitempty"` // 响应格式
}
type responseFormatConfig struct {
Type string `json:"type"` // "text" or "json_object"
}
type chatMessage struct {
@@ -227,6 +234,13 @@ func (c *clientImpl) chatCompletion(
Temperature: temperature,
MaxTokens: maxTokens,
}
// 禁用qwen3.5的思考模式避免产生大量不必要的token消耗
falseVal := false
reqBody.EnableThinking = &falseVal
zero := 0
reqBody.ThinkingBudget = &zero
// 使用JSON输出格式
reqBody.ResponseFormat = &responseFormatConfig{Type: "json_object"}
data, err := json.Marshal(reqBody)
if err != nil {

View File

@@ -117,3 +117,117 @@ func (c *Client) Close() error {
func (c *Client) IsMiniRedis() bool {
return c.isMiniRedis
}
// ==================== Hash 操作 ====================
// HSet 设置 Hash 字段
func (c *Client) HSet(ctx context.Context, key string, field string, value interface{}) error {
return c.rdb.HSet(ctx, key, field, value).Err()
}
// HMSet 批量设置 Hash 字段
func (c *Client) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
return c.rdb.HMSet(ctx, key, values).Err()
}
// HGet 获取 Hash 字段值
func (c *Client) HGet(ctx context.Context, key string, field string) (string, error) {
return c.rdb.HGet(ctx, key, field).Result()
}
// HMGet 批量获取 Hash 字段值
func (c *Client) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
return c.rdb.HMGet(ctx, key, fields...).Result()
}
// HGetAll 获取 Hash 所有字段
func (c *Client) HGetAll(ctx context.Context, key string) (map[string]string, error) {
return c.rdb.HGetAll(ctx, key).Result()
}
// HDel 删除 Hash 字段
func (c *Client) HDel(ctx context.Context, key string, fields ...string) error {
return c.rdb.HDel(ctx, key, fields...).Err()
}
// HExists 检查 Hash 字段是否存在
func (c *Client) HExists(ctx context.Context, key string, field string) (bool, error) {
return c.rdb.HExists(ctx, key, field).Result()
}
// HLen 获取 Hash 字段数量
func (c *Client) HLen(ctx context.Context, key string) (int64, error) {
return c.rdb.HLen(ctx, key).Result()
}
// ==================== Sorted Set 操作 ====================
// ZAdd 添加 Sorted Set 成员
func (c *Client) ZAdd(ctx context.Context, key string, score float64, member string) error {
return c.rdb.ZAdd(ctx, key, redis.Z{Score: score, Member: member}).Err()
}
// ZAddArgs 批量添加 Sorted Set 成员
func (c *Client) ZAddArgs(ctx context.Context, key string, members ...redis.Z) error {
return c.rdb.ZAdd(ctx, key, members...).Err()
}
// ZRangeByScore 按分数范围获取成员(升序)
func (c *Client) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
return c.rdb.ZRangeByScore(ctx, key, &redis.ZRangeBy{
Min: min,
Max: max,
Offset: offset,
Count: count,
}).Result()
}
// ZRevRangeByScore 按分数范围获取成员(降序)
func (c *Client) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
return c.rdb.ZRevRangeByScore(ctx, key, &redis.ZRangeBy{
Min: min,
Max: max,
Offset: offset,
Count: count,
}).Result()
}
// ZRange 获取指定范围的成员(升序)
func (c *Client) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return c.rdb.ZRange(ctx, key, start, stop).Result()
}
// ZRevRange 获取指定范围的成员(降序)
func (c *Client) ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
return c.rdb.ZRevRange(ctx, key, start, stop).Result()
}
// ZRem 删除 Sorted Set 成员
func (c *Client) ZRem(ctx context.Context, key string, members ...interface{}) error {
return c.rdb.ZRem(ctx, key, members...).Err()
}
// ZScore 获取成员分数
func (c *Client) ZScore(ctx context.Context, key string, member string) (float64, error) {
return c.rdb.ZScore(ctx, key, member).Result()
}
// ZCard 获取 Sorted Set 成员数量
func (c *Client) ZCard(ctx context.Context, key string) (int64, error) {
return c.rdb.ZCard(ctx, key).Result()
}
// ZCount 统计分数范围内的成员数量
func (c *Client) ZCount(ctx context.Context, key string, min, max string) (int64, error) {
return c.rdb.ZCount(ctx, key, min, max).Result()
}
// ==================== Pipeline 操作 ====================
// Pipeliner Pipeline 接口(使用 redis 库原生接口)
type Pipeliner = redis.Pipeliner
// Pipeline 创建 Pipeline
func (c *Client) Pipeline() Pipeliner {
return c.rdb.Pipeline()
}

View File

@@ -2,6 +2,9 @@ package repository
import (
"carrot_bbs/internal/model"
"context"
"fmt"
"strings"
"time"
"gorm.io/gorm"
@@ -172,7 +175,7 @@ func (r *MessageRepository) GetParticipant(conversationID string, userID string)
if err == gorm.ErrRecordNotFound {
// 检查会话是否存在
var conv model.Conversation
if err := r.db.First(&conv, conversationID).Error; err == nil {
if err := r.db.Where("id = ?", conversationID).First(&conv).Error; err == nil {
// 会话存在,添加参与者
participant = model.ConversationParticipant{
ConversationID: conversationID,
@@ -284,7 +287,7 @@ func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq
// GetNextSeq 获取会话的下一个seq值
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
var conv model.Conversation
err := r.db.Select("last_seq").First(&conv, conversationID).Error
err := r.db.Select("last_seq").Where("id = ?", conversationID).First(&conv).Error
if err != nil {
return 0, err
}
@@ -296,7 +299,7 @@ func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
return r.db.Transaction(func(tx *gorm.DB) error {
// 获取当前seq并+1
var conv model.Conversation
if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil {
if err := tx.Select("last_seq").Where("id = ?", msg.ConversationID).First(&conv).Error; err != nil {
return err
}
@@ -522,3 +525,117 @@ func (r *MessageRepository) HideConversationForUser(conversationID, userID strin
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
Update("hidden_at", &now).Error
}
// ParticipantUpdate 参与者更新数据
type ParticipantUpdate struct {
ConversationID string
UserID string
LastReadSeq int64
}
// BatchWriteMessages 批量写入消息
// 使用 GORM 的 CreateInBatches 实现高效批量插入
func (r *MessageRepository) BatchWriteMessages(ctx context.Context, messages []*model.Message) error {
if len(messages) == 0 {
return nil
}
return r.db.WithContext(ctx).CreateInBatches(messages, 100).Error
}
// BatchUpdateParticipants 批量更新参与者(使用 CASE WHEN 优化)
// 使用单条 SQL 更新多条记录,避免循环执行 UPDATE
func (r *MessageRepository) BatchUpdateParticipants(ctx context.Context, updates []ParticipantUpdate) error {
if len(updates) == 0 {
return nil
}
// 构建 CASE WHEN 批量更新 SQL
// UPDATE conversation_participants
// SET last_read_seq = CASE
// WHEN (conversation_id = '1' AND user_id = 'a') THEN 10
// WHEN (conversation_id = '2' AND user_id = 'b') THEN 20
// END,
// updated_at = ?
// WHERE (conversation_id = '1' AND user_id = 'a')
// OR (conversation_id = '2' AND user_id = 'b')
var cases []string
var whereClauses []string
var args []interface{}
for _, u := range updates {
cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?")
whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)")
args = append(args, u.ConversationID, u.UserID, u.LastReadSeq, u.ConversationID, u.UserID)
}
sql := fmt.Sprintf(`
UPDATE conversation_participants
SET last_read_seq = CASE %s END,
updated_at = ?
WHERE %s
`, strings.Join(cases, " "), strings.Join(whereClauses, " OR "))
args = append(args, time.Now())
return r.db.WithContext(ctx).Exec(sql, args...).Error
}
// UpdateConversationLastSeqWithContext 更新会话最后消息序号
func (r *MessageRepository) UpdateConversationLastSeqWithContext(ctx context.Context, convID string, lastSeq int64, lastMsgTime time.Time) error {
return r.db.WithContext(ctx).
Model(&model.Conversation{}).
Where("id = ?", convID).
Updates(map[string]interface{}{
"last_seq": lastSeq,
"last_msg_time": lastMsgTime,
"updated_at": time.Now(),
}).Error
}
// BatchWriteMessagesWithTx 在事务中批量写入消息
func (r *MessageRepository) BatchWriteMessagesWithTx(tx *gorm.DB, messages []*model.Message) error {
if len(messages) == 0 {
return nil
}
return tx.CreateInBatches(messages, 100).Error
}
// BatchUpdateParticipantsWithTx 在事务中批量更新参与者
func (r *MessageRepository) BatchUpdateParticipantsWithTx(tx *gorm.DB, updates []ParticipantUpdate) error {
if len(updates) == 0 {
return nil
}
var cases []string
var whereClauses []string
var args []interface{}
for _, u := range updates {
cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?")
whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)")
args = append(args, u.ConversationID, u.UserID, u.LastReadSeq, u.ConversationID, u.UserID)
}
sql := fmt.Sprintf(`
UPDATE conversation_participants
SET last_read_seq = CASE %s END,
updated_at = ?
WHERE %s
`, strings.Join(cases, " "), strings.Join(whereClauses, " OR "))
args = append(args, time.Now())
return tx.Exec(sql, args...).Error
}
// UpdateConversationLastSeqWithTx 在事务中更新会话最后消息序号
func (r *MessageRepository) UpdateConversationLastSeqWithTx(tx *gorm.DB, convID string, lastSeq int64, lastMsgTime time.Time) error {
return tx.Model(&model.Conversation{}).
Where("id = ?", convID).
Updates(map[string]interface{}{
"last_seq": lastSeq,
"last_msg_time": lastMsgTime,
"updated_at": time.Now(),
}).Error
}

View File

@@ -0,0 +1,66 @@
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
type ScheduleRepository interface {
ListByUserID(userID string) ([]*model.ScheduleCourse, error)
GetByID(id string) (*model.ScheduleCourse, error)
Create(course *model.ScheduleCourse) error
Update(course *model.ScheduleCourse) error
DeleteByID(id string) error
ExistsColorByUser(userID, color, excludeID string) (bool, error)
}
type scheduleRepository struct {
db *gorm.DB
}
func NewScheduleRepository(db *gorm.DB) ScheduleRepository {
return &scheduleRepository{db: db}
}
func (r *scheduleRepository) ListByUserID(userID string) ([]*model.ScheduleCourse, error) {
var courses []*model.ScheduleCourse
err := r.db.
Where("user_id = ?", userID).
Order("day_of_week ASC, start_section ASC, created_at ASC").
Find(&courses).Error
return courses, err
}
func (r *scheduleRepository) Create(course *model.ScheduleCourse) error {
return r.db.Create(course).Error
}
func (r *scheduleRepository) GetByID(id string) (*model.ScheduleCourse, error) {
var course model.ScheduleCourse
if err := r.db.Where("id = ?", id).First(&course).Error; err != nil {
return nil, err
}
return &course, nil
}
func (r *scheduleRepository) Update(course *model.ScheduleCourse) error {
return r.db.Save(course).Error
}
func (r *scheduleRepository) DeleteByID(id string) error {
return r.db.Delete(&model.ScheduleCourse{}, "id = ?", id).Error
}
func (r *scheduleRepository) ExistsColorByUser(userID, color, excludeID string) (bool, error) {
var count int64
query := r.db.Model(&model.ScheduleCourse{}).
Where("user_id = ? AND LOWER(color) = LOWER(?)", userID, color)
if excludeID != "" {
query = query.Where("id <> ?", excludeID)
}
if err := query.Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}

View File

@@ -23,6 +23,7 @@ type Router struct {
stickerHandler *handler.StickerHandler
gorseHandler *handler.GorseHandler
voteHandler *handler.VoteHandler
scheduleHandler *handler.ScheduleHandler
jwtService *service.JWTService
}
@@ -41,6 +42,7 @@ func New(
stickerHandler *handler.StickerHandler,
gorseHandler *handler.GorseHandler,
voteHandler *handler.VoteHandler,
scheduleHandler *handler.ScheduleHandler,
) *Router {
// 设置JWT服务
userHandler.SetJWTService(jwtService)
@@ -59,6 +61,7 @@ func New(
stickerHandler: stickerHandler,
gorseHandler: gorseHandler,
voteHandler: voteHandler,
scheduleHandler: scheduleHandler,
jwtService: jwtService,
}
@@ -160,6 +163,18 @@ func (r *Router) setupRoutes() {
posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票
}
// 课表路由
if r.scheduleHandler != nil {
schedule := v1.Group("/schedule")
schedule.Use(authMiddleware)
{
schedule.GET("/courses", r.scheduleHandler.ListCourses)
schedule.POST("/courses", r.scheduleHandler.CreateCourse)
schedule.PUT("/courses/:id", r.scheduleHandler.UpdateCourse)
schedule.DELETE("/courses/:id", r.scheduleHandler.DeleteCourse)
}
}
// 投票选项路由
voteOptions := v1.Group("/vote-options")
voteOptions.Use(authMiddleware)

View File

@@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"log"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/sse"
@@ -58,6 +60,9 @@ type chatServiceImpl struct {
userRepo *repository.UserRepository
sensitive SensitiveService
sseHub *sse.Hub
// 缓存相关字段
conversationCache *cache.ConversationCache
}
// NewChatService 创建聊天服务
@@ -68,12 +73,25 @@ func NewChatService(
sensitive SensitiveService,
sseHub *sse.Hub,
) ChatService {
// 创建适配器
convRepoAdapter := cache.NewConversationRepositoryAdapter(repo)
msgRepoAdapter := cache.NewMessageRepositoryAdapter(repo)
// 创建会话缓存
conversationCache := cache.NewConversationCache(
cache.GetCache(),
convRepoAdapter,
msgRepoAdapter,
cache.DefaultConversationCacheSettings(),
)
return &chatServiceImpl{
db: db,
repo: repo,
userRepo: userRepo,
sensitive: sensitive,
sseHub: sseHub,
db: db,
repo: repo,
userRepo: userRepo,
sensitive: sensitive,
sseHub: sseHub,
conversationCache: conversationCache,
}
}
@@ -86,18 +104,33 @@ func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payl
// GetOrCreateConversation 获取或创建私聊会话
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
conv, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
if err != nil {
return nil, err
}
// 失效会话列表缓存
if s.conversationCache != nil {
s.conversationCache.InvalidateConversationList(user1ID)
s.conversationCache.InvalidateConversationList(user2ID)
}
return conv, nil
}
// GetConversationList 获取用户的会话列表
// GetConversationList 获取用户的会话列表(带缓存)
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
// 优先使用缓存
if s.conversationCache != nil {
return s.conversationCache.GetConversationList(ctx, userID, page, pageSize)
}
return s.repo.GetConversations(userID, page, pageSize)
}
// GetConversationByID 获取会话详情
// GetConversationByID 获取会话详情(带缓存)
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
// 验证用户是否是会话参与者
participant, err := s.repo.GetParticipant(conversationID, userID)
participant, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
@@ -105,21 +138,33 @@ func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationI
return nil, fmt.Errorf("failed to get participant: %w", err)
}
// 获取会话信息
conv, err := s.repo.GetConversation(conversationID)
// 获取会话信息(优先使用缓存)
var conv *model.Conversation
if s.conversationCache != nil {
conv, err = s.conversationCache.GetConversation(ctx, conversationID)
} else {
conv, err = s.repo.GetConversation(conversationID)
}
if err != nil {
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 填充用户的已读位置信息
_ = participant // 可以用于返回已读位置等信息
return conv, nil
}
// getParticipant 获取参与者信息(优先使用缓存)
func (s *chatServiceImpl) getParticipant(ctx context.Context, conversationID, userID string) (*model.ConversationParticipant, error) {
if s.conversationCache != nil {
return s.conversationCache.GetParticipant(ctx, conversationID, userID)
}
return s.repo.GetParticipant(conversationID, userID)
}
// DeleteConversationForSelf 仅自己删除会话
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
participant, err := s.repo.GetParticipant(conversationID, userID)
participant, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
@@ -133,12 +178,18 @@ func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, convers
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
return fmt.Errorf("failed to hide conversation: %w", err)
}
// 失效会话列表缓存
if s.conversationCache != nil {
s.conversationCache.InvalidateConversationList(userID)
}
return nil
}
// SetConversationPinned 设置会话置顶(用户维度)
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
participant, err := s.repo.GetParticipant(conversationID, userID)
participant, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
@@ -152,13 +203,20 @@ func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversatio
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
return fmt.Errorf("failed to update pinned status: %w", err)
}
// 失效缓存
if s.conversationCache != nil {
s.conversationCache.InvalidateParticipant(conversationID, userID)
s.conversationCache.InvalidateConversationList(userID)
}
return nil
}
// SendMessage 发送消息(使用 segments
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 首先验证会话是否存在
conv, err := s.repo.GetConversation(conversationID)
conv, err := s.getConversation(ctx, conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("会话不存在,请重新创建会话")
@@ -166,9 +224,9 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
return nil, fmt.Errorf("failed to get conversation: %w", err)
}
// 拉黑限制:仅拦截被拉黑方 -> 拉黑人方向
// 拉黑限制:仅拦截"被拉黑方 -> 拉黑人"方向
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
participants, pErr := s.repo.GetConversationParticipants(conversationID)
participants, pErr := s.getParticipants(ctx, conversationID)
if pErr != nil {
return nil, fmt.Errorf("failed to get participants: %w", pErr)
}
@@ -209,7 +267,7 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
}
// 验证用户是否是会话参与者
participant, err := s.repo.GetParticipant(conversationID, senderID)
participant, err := s.getParticipant(ctx, conversationID, senderID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("您不是该会话的参与者")
@@ -231,11 +289,27 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
return nil, fmt.Errorf("failed to save message: %w", err)
}
// 新消息会改变分页结果,先失效分页缓存,避免读到旧列表
if s.conversationCache != nil {
s.conversationCache.InvalidateMessagePages(conversationID)
}
// 异步写入缓存
go func() {
if err := s.cacheMessage(context.Background(), conversationID, message); err != nil {
log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err)
}
}()
// 获取会话中的参与者并发送 SSE
participants, err := s.repo.GetConversationParticipants(conversationID)
participants, err := s.getParticipants(ctx, conversationID)
if err == nil {
targetIDs := make([]string, 0, len(participants))
for _, p := range participants {
// 私聊场景下,发送者已经从 HTTP 响应拿到消息,避免再通过 SSE 回推导致本端重复展示。
if conv.Type == model.ConversationTypePrivate && p.UserID == senderID {
continue
}
targetIDs = append(targetIDs, p.UserID)
}
detailType := "private"
@@ -250,6 +324,10 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
if p.UserID == senderID {
continue
}
// 失效未读数缓存
if s.conversationCache != nil {
s.conversationCache.InvalidateUnreadCount(p.UserID, conversationID)
}
if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
"conversation_id": conversationID,
@@ -259,11 +337,46 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
}
}
// 失效会话列表缓存
if s.conversationCache != nil {
for _, p := range participants {
s.conversationCache.InvalidateConversationList(p.UserID)
}
}
_ = participant // 避免未使用变量警告
return message, nil
}
// getConversation 获取会话(优先使用缓存)
func (s *chatServiceImpl) getConversation(ctx context.Context, conversationID string) (*model.Conversation, error) {
if s.conversationCache != nil {
return s.conversationCache.GetConversation(ctx, conversationID)
}
return s.repo.GetConversation(conversationID)
}
// getParticipants 获取会话参与者(优先使用缓存)
func (s *chatServiceImpl) getParticipants(ctx context.Context, conversationID string) ([]*model.ConversationParticipant, error) {
if s.conversationCache != nil {
return s.conversationCache.GetParticipants(ctx, conversationID)
}
return s.repo.GetConversationParticipants(conversationID)
}
// cacheMessage 缓存消息(内部方法)
func (s *chatServiceImpl) cacheMessage(ctx context.Context, convID string, msg *model.Message) error {
if s.conversationCache == nil {
return nil
}
asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return s.conversationCache.CacheMessage(asyncCtx, convID, msg)
}
func containsImageSegment(segments model.MessageSegments) bool {
for _, seg := range segments {
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
@@ -273,10 +386,10 @@ func containsImageSegment(segments model.MessageSegments) bool {
return false
}
// GetMessages 获取消息历史(分页)
// GetMessages 获取消息历史(分页,带缓存
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, 0, errors.New("conversation not found or no permission")
@@ -284,13 +397,18 @@ func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
}
// 优先使用缓存
if s.conversationCache != nil {
return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize)
}
return s.repo.GetMessages(conversationID, page, pageSize)
}
// GetMessagesAfterSeq 获取指定seq之后的消息用于增量同步
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
@@ -308,7 +426,7 @@ func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationI
// GetMessagesBeforeSeq 获取指定seq之前的历史消息用于下拉加载更多
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("conversation not found or no permission")
@@ -326,7 +444,7 @@ func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversation
// MarkAsRead 标记已读
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("conversation not found or no permission")
@@ -334,17 +452,27 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
return fmt.Errorf("failed to get participant: %w", err)
}
// 更新参与者的已读位置
// 1. 先写入DB保证数据一致性DB是唯一数据源
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
if err != nil {
return fmt.Errorf("failed to update last read seq: %w", err)
}
participants, pErr := s.repo.GetConversationParticipants(conversationID)
// 2. DB 写入成功后失效缓存Cache-Aside 模式)
if s.conversationCache != nil {
// 失效参与者缓存,下次读取时会从 DB 加载最新数据
s.conversationCache.InvalidateParticipant(conversationID, userID)
// 失效未读数缓存
s.conversationCache.InvalidateUnreadCount(userID, conversationID)
// 失效会话列表缓存
s.conversationCache.InvalidateConversationList(userID)
}
participants, pErr := s.getParticipants(ctx, conversationID)
if pErr == nil {
detailType := "private"
groupID := ""
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
detailType = "group"
if conv.GroupID != nil {
groupID = *conv.GroupID
@@ -372,10 +500,10 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
return nil
}
// GetUnreadCount 获取指定会话的未读消息数
// GetUnreadCount 获取指定会话的未读消息数(带缓存)
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, userID)
_, err := s.getParticipant(ctx, conversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return 0, errors.New("conversation not found or no permission")
@@ -383,6 +511,11 @@ func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID str
return 0, fmt.Errorf("failed to get participant: %w", err)
}
// 优先使用缓存
if s.conversationCache != nil {
return s.conversationCache.GetUnreadCount(ctx, userID, conversationID)
}
return s.repo.GetUnreadCount(conversationID, userID)
}
@@ -427,10 +560,15 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u
return fmt.Errorf("failed to recall message: %w", err)
}
if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil {
// 失效消息缓存
if s.conversationCache != nil {
s.conversationCache.InvalidateConversation(message.ConversationID)
}
if participants, pErr := s.getParticipants(ctx, message.ConversationID); pErr == nil {
detailType := "private"
groupID := ""
if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
if conv, convErr := s.getConversation(ctx, message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
detailType = "group"
if conv.GroupID != nil {
groupID = *conv.GroupID
@@ -465,7 +603,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
}
// 验证用户是否是会话参与者
_, err = s.repo.GetParticipant(message.ConversationID, userID)
_, err = s.getParticipant(ctx, message.ConversationID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("no permission to delete this message")
@@ -485,6 +623,11 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
return fmt.Errorf("failed to delete message: %w", err)
}
// 失效消息缓存
if s.conversationCache != nil {
s.conversationCache.InvalidateConversation(message.ConversationID)
}
return nil
}
@@ -495,19 +638,19 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve
}
// 验证用户是否是会话参与者
_, err := s.repo.GetParticipant(conversationID, senderID)
_, err := s.getParticipant(ctx, conversationID, senderID)
if err != nil {
return
}
// 获取会话中的其他参与者
participants, err := s.repo.GetConversationParticipants(conversationID)
participants, err := s.getParticipants(ctx, conversationID)
if err != nil {
return
}
detailType := "private"
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
detailType = "group"
}
for _, p := range participants {
@@ -537,7 +680,7 @@ func (s *chatServiceImpl) IsUserOnline(userID string) bool {
// 适用于群聊等由调用方自行负责推送的场景
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
// 验证会话是否存在
_, err := s.repo.GetConversation(conversationID)
_, err := s.getConversation(ctx, conversationID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("会话不存在,请重新创建会话")
@@ -546,7 +689,7 @@ func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conv
}
// 验证用户是否是会话参与者
_, err = s.repo.GetParticipant(conversationID, senderID)
_, err = s.getParticipant(ctx, conversationID, senderID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("您不是该会话的参与者")
@@ -566,5 +709,17 @@ func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conv
return nil, fmt.Errorf("failed to save message: %w", err)
}
// 新消息会改变分页结果,先失效分页缓存,避免读到旧列表
if s.conversationCache != nil {
s.conversationCache.InvalidateMessagePages(conversationID)
}
// 异步写入缓存
go func() {
if err := s.cacheMessage(context.Background(), conversationID, message); err != nil {
log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err)
}
}()
return message, nil
}

View File

@@ -145,6 +145,45 @@ func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMess
}
}
// invalidateConversationCachesAfterSystemMessage 系统消息写入后失效相关缓存
func (s *groupService) invalidateConversationCachesAfterSystemMessage(conversationID string) {
if conversationID == "" || s.messageRepo == nil {
return
}
// 新系统消息会影响消息分页列表
cache.InvalidateMessagePages(s.cache, conversationID)
// 参与者列表可能发生变化(加群/退群)后,这里统一清理一次
s.cache.Delete(cache.ParticipantListKey(conversationID))
participants, err := s.messageRepo.GetConversationParticipants(conversationID)
if err != nil {
return
}
for _, p := range participants {
if p == nil || p.UserID == "" {
continue
}
// 会话最后消息、未读数会变化,清理用户维度缓存
cache.InvalidateConversationList(s.cache, p.UserID)
cache.InvalidateUnreadConversation(s.cache, p.UserID)
cache.InvalidateUnreadDetail(s.cache, p.UserID, conversationID)
}
}
// invalidateConversationCachesAfterMembershipChange 成员变更后失效相关缓存
func (s *groupService) invalidateConversationCachesAfterMembershipChange(conversationID, userID string) {
if conversationID == "" {
return
}
s.cache.Delete(cache.ParticipantListKey(conversationID))
if userID != "" {
s.cache.Delete(cache.ParticipantKey(conversationID, userID))
cache.InvalidateConversationList(s.cache, userID)
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
}
}
// ==================== 群组管理 ====================
// CreateGroup 创建群组
@@ -444,6 +483,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
log.Printf("[broadcastMemberJoinNotice] 保存入群提示消息失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err)
} else {
savedMessage = msg
s.invalidateConversationCachesAfterSystemMessage(conv.ID)
}
} else {
log.Printf("[broadcastMemberJoinNotice] 获取群组会话失败: groupID=%s, err=%v", groupID, err)
@@ -502,6 +542,7 @@ func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userI
if err := s.messageRepo.AddParticipant(conv.ID, userID); err != nil {
log.Printf("[addMemberToGroupAndConversation] 添加会话参与者失败: groupID=%s, userID=%s, err=%v", group.ID, userID, err)
}
s.invalidateConversationCachesAfterMembershipChange(conv.ID, userID)
}
}
cache.InvalidateGroupMembers(s.cache, group.ID)
@@ -1036,6 +1077,7 @@ func (s *groupService) LeaveGroup(userID string, groupID string) error {
// 如果移除参与者失败,记录日志但不阻塞退出群流程
fmt.Printf("[WARN] LeaveGroup: failed to remove participant %s from conversation %s, error: %v\n", userID, conv.ID, err)
}
s.invalidateConversationCachesAfterMembershipChange(conv.ID, userID)
}
// 失效群组成员缓存
@@ -1092,6 +1134,7 @@ func (s *groupService) RemoveMember(userID string, groupID string, targetUserID
if err := s.messageRepo.RemoveParticipant(conv.ID, targetUserID); err != nil {
log.Printf("[RemoveMember] 移除会话参与者失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err)
}
s.invalidateConversationCachesAfterMembershipChange(conv.ID, targetUserID)
}
}
@@ -1290,6 +1333,7 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st
} else {
savedMessage = msg
log.Printf("[MuteMember] 禁言消息已保存, ID=%s, Seq=%d", msg.ID, msg.Seq)
s.invalidateConversationCachesAfterSystemMessage(conv.ID)
}
} else {
log.Printf("[MuteMember] 获取群组会话失败: %v", err)

View File

@@ -2,11 +2,14 @@ package service
import (
"context"
"log"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
// 缓存TTL常量
@@ -21,15 +24,37 @@ const (
// MessageService 消息服务
type MessageService struct {
db *gorm.DB
// 基础仓储
messageRepo *repository.MessageRepository
cache cache.Cache
// 缓存相关字段
conversationCache *cache.ConversationCache
// 基础缓存(用于简单缓存操作)
baseCache cache.Cache
}
// NewMessageService 创建消息服务
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
func NewMessageService(db *gorm.DB, messageRepo *repository.MessageRepository) *MessageService {
// 创建适配器
convRepoAdapter := cache.NewConversationRepositoryAdapter(messageRepo)
msgRepoAdapter := cache.NewMessageRepositoryAdapter(messageRepo)
// 创建会话缓存
conversationCache := cache.NewConversationCache(
cache.GetCache(),
convRepoAdapter,
msgRepoAdapter,
cache.DefaultConversationCacheSettings(),
)
return &MessageService{
messageRepo: messageRepo,
cache: cache.GetCache(),
db: db,
messageRepo: messageRepo,
conversationCache: conversationCache,
baseCache: cache.GetCache(),
}
}
@@ -61,20 +86,50 @@ func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID s
return nil, err
}
// 新消息会改变分页结果,先失效分页缓存,避免读到旧列表
if s.conversationCache != nil {
s.conversationCache.InvalidateMessagePages(conv.ID)
}
// 异步写入缓存
go func() {
if err := s.cacheMessage(context.Background(), conv.ID, msg); err != nil {
log.Printf("[MessageService] async cache message failed, convID=%s, msgID=%s, err=%v", conv.ID, msg.ID, err)
}
}()
// 失效会话列表缓存(发送者和接收者)
cache.InvalidateConversationList(s.cache, senderID)
cache.InvalidateConversationList(s.cache, receiverID)
s.conversationCache.InvalidateConversationList(senderID)
s.conversationCache.InvalidateConversationList(receiverID)
// 失效未读数缓存
cache.InvalidateUnreadConversation(s.cache, receiverID)
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
cache.InvalidateUnreadConversation(s.baseCache, receiverID)
s.conversationCache.InvalidateUnreadCount(receiverID, conv.ID)
return msg, nil
}
// cacheMessage 缓存消息(内部方法)
func (s *MessageService) cacheMessage(ctx context.Context, convID string, msg *model.Message) error {
if s.conversationCache == nil {
return nil
}
asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
return s.conversationCache.CacheMessage(asyncCtx, convID, msg)
}
// GetConversations 获取会话列表(带缓存)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
// 优先使用 ConversationCache
if s.conversationCache != nil {
return s.conversationCache.GetConversationList(ctx, userID, page, pageSize)
}
// 降级到基础缓存
cacheSettings := cache.GetSettings()
conversationTTL := cacheSettings.ConversationTTL
if conversationTTL <= 0 {
@@ -92,7 +147,7 @@ func (s *MessageService) GetConversations(ctx context.Context, userID string, pa
// 生成缓存键
cacheKey := cache.ConversationListKey(userID, page, pageSize)
result, err := cache.GetOrLoadTyped[*ConversationListResult](
s.cache,
s.baseCache,
cacheKey,
conversationTTL,
jitter,
@@ -117,8 +172,14 @@ func (s *MessageService) GetConversations(ctx context.Context, userID string, pa
return result.Conversations, result.Total, nil
}
// GetMessages 获取消息列表
// GetMessages 获取消息列表(带缓存)
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
// 优先使用 ConversationCache
if s.conversationCache != nil {
return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize)
}
// 降级到直接访问数据库
return s.messageRepo.GetMessages(conversationID, page, pageSize)
}
@@ -127,20 +188,25 @@ func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
}
// MarkAsRead 标记为已读
// MarkAsRead 标记为已读(使用 Cache-Aside 模式)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
// 1. 先写入DB保证数据一致性DB是唯一数据源
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
if err != nil {
return err
}
// 失效未读数缓存
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
// 失效会话列表缓存
cache.InvalidateConversationList(s.cache, userID)
// 2. DB 写入成功后失效缓存Cache-Aside 模式)
if s.conversationCache != nil {
// 失效参与者缓存,下次读取时会从 DB 加载最新数据
s.conversationCache.InvalidateParticipant(conversationID, userID)
// 失效未读数缓存
s.conversationCache.InvalidateUnreadCount(userID, conversationID)
// 失效会话列表缓存
s.conversationCache.InvalidateConversationList(userID)
}
cache.InvalidateUnreadConversation(s.baseCache, userID)
return nil
}
@@ -148,6 +214,12 @@ func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string,
// GetUnreadCount 获取未读消息数(带缓存)
// userID 参数为 string 类型UUID格式与JWT中user_id保持一致
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
// 优先使用 ConversationCache
if s.conversationCache != nil {
return s.conversationCache.GetUnreadCount(ctx, userID, conversationID)
}
// 降级到基础缓存
cacheSettings := cache.GetSettings()
unreadTTL := cacheSettings.UnreadCountTTL
if unreadTTL <= 0 {
@@ -166,7 +238,7 @@ func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID stri
cacheKey := cache.UnreadDetailKey(userID, conversationID)
return cache.GetOrLoadTyped[int64](
s.cache,
s.baseCache,
cacheKey,
unreadTTL,
jitter,
@@ -186,14 +258,18 @@ func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, u
}
// 失效会话列表缓存
cache.InvalidateConversationList(s.cache, user1ID)
cache.InvalidateConversationList(s.cache, user2ID)
s.conversationCache.InvalidateConversationList(user1ID)
s.conversationCache.InvalidateConversationList(user2ID)
return conv, nil
}
// GetConversationParticipants 获取会话参与者列表
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
// 优先使用缓存
if s.conversationCache != nil {
return s.conversationCache.GetParticipants(context.Background(), conversationID)
}
return s.messageRepo.GetConversationParticipants(conversationID)
}
@@ -204,12 +280,12 @@ func ParseConversationID(idStr string) (string, error) {
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
func (s *MessageService) InvalidateUserConversationCache(userID string) {
cache.InvalidateConversationList(s.cache, userID)
cache.InvalidateUnreadConversation(s.cache, userID)
s.conversationCache.InvalidateConversationList(userID)
cache.InvalidateUnreadConversation(s.baseCache, userID)
}
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
cache.InvalidateUnreadConversation(s.cache, userID)
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
cache.InvalidateUnreadConversation(s.baseCache, userID)
s.conversationCache.InvalidateUnreadCount(userID, conversationID)
}

View File

@@ -73,9 +73,20 @@ func (s *PostService) Create(ctx context.Context, userID, title, content string,
}
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
defer func() {
if r := recover(); r != nil {
log.Printf("[ERROR] Panic in post moderation async flow, fallback publish post=%s panic=%v", postID, r)
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish post %s after panic recovery: %v", postID, err)
return
}
s.invalidatePostCaches(postID)
}
}()
// 未启用AI时直接发布
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
} else {
s.invalidatePostCaches(postID)
@@ -87,7 +98,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
if err != nil {
var rejectedErr *PostModerationRejectedError
if errors.As(err, &rejectedErr) {
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
} else {
s.invalidatePostCaches(postID)
@@ -97,7 +108,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
}
// 规则审核不可用时降级为发布避免长时间pending
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
} else {
s.invalidatePostCaches(postID)
@@ -106,7 +117,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
return
}
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "ai"); err != nil {
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
return
}
@@ -127,6 +138,26 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
}
}
func (s *PostService) updateModerationStatusWithRetry(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
const maxAttempts = 3
const retryDelay = 200 * time.Millisecond
var lastErr error
for attempt := 1; attempt <= maxAttempts; attempt++ {
if err := s.postRepo.UpdateModerationStatus(postID, status, rejectReason, reviewedBy); err != nil {
lastErr = err
if attempt < maxAttempts {
log.Printf("[WARN] UpdateModerationStatus failed post=%s attempt=%d/%d err=%v", postID, attempt, maxAttempts, err)
time.Sleep(time.Duration(attempt) * retryDelay)
continue
}
} else {
return nil
}
}
return lastErr
}
func (s *PostService) invalidatePostCaches(postID string) {
cache.InvalidatePostDetail(s.cache, postID)
cache.InvalidatePostList(s.cache)

View File

@@ -0,0 +1,207 @@
package service
import (
"encoding/json"
"errors"
"regexp"
"sort"
"strings"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/repository"
"gorm.io/gorm"
)
var (
ErrInvalidSchedulePayload = &ServiceError{Code: 400, Message: "invalid schedule payload"}
ErrScheduleCourseNotFound = &ServiceError{Code: 404, Message: "schedule course not found"}
ErrScheduleForbidden = &ServiceError{Code: 403, Message: "forbidden schedule operation"}
ErrScheduleColorDuplicated = &ServiceError{Code: 400, Message: "course color already used"}
)
var hexColorRegex = regexp.MustCompile(`^#[0-9A-F]{6}$`)
type CreateScheduleCourseInput struct {
Name string
Teacher string
Location string
DayOfWeek int
StartSection int
EndSection int
Weeks []int
Color string
}
type ScheduleService interface {
ListCourses(userID string, week int) ([]*dto.ScheduleCourseResponse, error)
CreateCourse(userID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error)
UpdateCourse(userID, courseID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error)
DeleteCourse(userID, courseID string) error
}
type scheduleService struct {
repo repository.ScheduleRepository
}
func NewScheduleService(repo repository.ScheduleRepository) ScheduleService {
return &scheduleService{repo: repo}
}
func (s *scheduleService) ListCourses(userID string, week int) ([]*dto.ScheduleCourseResponse, error) {
courses, err := s.repo.ListByUserID(userID)
if err != nil {
return nil, err
}
result := make([]*dto.ScheduleCourseResponse, 0, len(courses))
for _, item := range courses {
weeks := dto.ParseWeeksJSON(item.Weeks)
if week > 0 && !containsWeek(weeks, week) {
continue
}
result = append(result, dto.ConvertScheduleCourseToResponse(item, weeks))
}
return result, nil
}
func (s *scheduleService) CreateCourse(userID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) {
entity, weeks, err := buildScheduleEntity(userID, input, nil)
if err != nil {
return nil, err
}
if err := s.ensureUniqueColor(userID, entity.Color, ""); err != nil {
return nil, err
}
if err := s.repo.Create(entity); err != nil {
return nil, err
}
return dto.ConvertScheduleCourseToResponse(entity, weeks), nil
}
func (s *scheduleService) UpdateCourse(userID, courseID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) {
existing, err := s.repo.GetByID(courseID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrScheduleCourseNotFound
}
return nil, err
}
if existing.UserID != userID {
return nil, ErrScheduleForbidden
}
entity, weeks, err := buildScheduleEntity(userID, input, existing)
if err != nil {
return nil, err
}
if err := s.ensureUniqueColor(userID, entity.Color, entity.ID); err != nil {
return nil, err
}
if err := s.repo.Update(entity); err != nil {
return nil, err
}
return dto.ConvertScheduleCourseToResponse(entity, weeks), nil
}
func (s *scheduleService) DeleteCourse(userID, courseID string) error {
existing, err := s.repo.GetByID(courseID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrScheduleCourseNotFound
}
return err
}
if existing.UserID != userID {
return ErrScheduleForbidden
}
return s.repo.DeleteByID(courseID)
}
func buildScheduleEntity(userID string, input CreateScheduleCourseInput, target *model.ScheduleCourse) (*model.ScheduleCourse, []int, error) {
name := strings.TrimSpace(input.Name)
if name == "" || input.DayOfWeek < 0 || input.DayOfWeek > 6 || input.StartSection < 1 || input.EndSection < input.StartSection {
return nil, nil, ErrInvalidSchedulePayload
}
weeks := normalizeWeeks(input.Weeks)
if len(weeks) == 0 {
return nil, nil, ErrInvalidSchedulePayload
}
weeksJSON, err := json.Marshal(weeks)
if err != nil {
return nil, nil, err
}
entity := target
if entity == nil {
entity = &model.ScheduleCourse{
UserID: userID,
}
}
normalizedColor := normalizeHexColor(input.Color)
if normalizedColor == "" || !hexColorRegex.MatchString(normalizedColor) {
return nil, nil, ErrInvalidSchedulePayload
}
entity.Name = name
entity.Teacher = strings.TrimSpace(input.Teacher)
entity.Location = strings.TrimSpace(input.Location)
entity.DayOfWeek = input.DayOfWeek
entity.StartSection = input.StartSection
entity.EndSection = input.EndSection
entity.Weeks = string(weeksJSON)
entity.Color = normalizedColor
return entity, weeks, nil
}
func (s *scheduleService) ensureUniqueColor(userID, color, excludeID string) error {
exists, err := s.repo.ExistsColorByUser(userID, color, excludeID)
if err != nil {
return err
}
if exists {
return ErrScheduleColorDuplicated
}
return nil
}
func normalizeWeeks(source []int) []int {
unique := make(map[int]struct{}, len(source))
result := make([]int, 0, len(source))
for _, w := range source {
if w < 1 || w > 30 {
continue
}
if _, exists := unique[w]; exists {
continue
}
unique[w] = struct{}{}
result = append(result, w)
}
sort.Ints(result)
return result
}
func containsWeek(weeks []int, target int) bool {
for _, week := range weeks {
if week == target {
return true
}
}
return false
}
func normalizeHexColor(color string) string {
trimmed := strings.TrimSpace(color)
if trimmed == "" {
return ""
}
if strings.HasPrefix(trimmed, "#") {
return strings.ToUpper(trimmed)
}
return "#" + strings.ToUpper(trimmed)
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"log"
"strings"
"time"
"carrot_bbs/internal/cache"
"carrot_bbs/internal/dto"
@@ -84,8 +85,17 @@ func (s *VoteService) CreateVotePost(ctx context.Context, userID string, req *dt
}
func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) {
defer func() {
if r := recover(); r != nil {
log.Printf("[ERROR] Panic in vote post moderation async flow, fallback publish post=%s panic=%v", postID, r)
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish vote post %s after panic recovery: %v", postID, err)
}
}
}()
if s.postAIService == nil || !s.postAIService.IsEnabled() {
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err)
}
return
@@ -95,24 +105,44 @@ func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string,
if err != nil {
var rejectedErr *PostModerationRejectedError
if errors.As(err, &rejectedErr) {
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr)
}
s.notifyModerationRejected(userID, rejectedErr.Reason)
return
}
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr)
}
return
}
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "ai"); err != nil {
log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err)
}
}
func (s *VoteService) updateModerationStatusWithRetry(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
const maxAttempts = 3
const retryDelay = 200 * time.Millisecond
var lastErr error
for attempt := 1; attempt <= maxAttempts; attempt++ {
if err := s.postRepo.UpdateModerationStatus(postID, status, rejectReason, reviewedBy); err != nil {
lastErr = err
if attempt < maxAttempts {
log.Printf("[WARN] UpdateModerationStatus for vote post failed post=%s attempt=%d/%d err=%v", postID, attempt, maxAttempts, err)
time.Sleep(time.Duration(attempt) * retryDelay)
continue
}
} else {
return nil
}
}
return lastErr
}
func (s *VoteService) notifyModerationRejected(userID, reason string) {
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
return