feat(schedule): add course table screens and navigation
Add complete schedule functionality including: - Schedule screen with weekly course table view - Course detail screen with transparent modal presentation - New ScheduleStack navigator integrated into main tab bar - Schedule service for API interactions - Type definitions for course entities Also includes bug fixes for group invite/request handlers to include required groupId parameter.
This commit is contained in:
@@ -170,3 +170,23 @@ email:
|
|||||||
use_tls: true
|
use_tls: true
|
||||||
insecure_skip_verify: false
|
insecure_skip_verify: false
|
||||||
timeout: 15
|
timeout: 15
|
||||||
|
|
||||||
|
# 会话缓存配置
|
||||||
|
conversation_cache:
|
||||||
|
# TTL 配置
|
||||||
|
detail_ttl: 5m # 会话详情缓存时间
|
||||||
|
list_ttl: 60s # 会话列表缓存时间
|
||||||
|
participant_ttl: 5m # 参与者缓存时间
|
||||||
|
unread_ttl: 30s # 未读数缓存时间
|
||||||
|
|
||||||
|
# 消息缓存配置
|
||||||
|
message_detail_ttl: 30m # 单条消息详情缓存
|
||||||
|
message_list_ttl: 5m # 消息分页列表缓存
|
||||||
|
message_index_ttl: 30m # 消息索引缓存
|
||||||
|
message_count_ttl: 30m # 消息计数缓存
|
||||||
|
|
||||||
|
# 批量写入配置
|
||||||
|
batch_interval: 5s # 写入间隔
|
||||||
|
batch_threshold: 100 # 条数阈值
|
||||||
|
batch_max_size: 500 # 单次最大批量
|
||||||
|
buffer_max_size: 10000 # 写缓冲最大条数
|
||||||
|
|||||||
512
internal/cache/cache.go
vendored
512
internal/cache/cache.go
vendored
@@ -5,7 +5,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -34,6 +37,38 @@ type Cache interface {
|
|||||||
Increment(key string) int64
|
Increment(key string) int64
|
||||||
// IncrementBy 增加指定值
|
// IncrementBy 增加指定值
|
||||||
IncrementBy(key string, value int64) int64
|
IncrementBy(key string, value int64) int64
|
||||||
|
|
||||||
|
// ==================== Hash 操作 ====================
|
||||||
|
// HSet 设置 Hash 字段
|
||||||
|
HSet(ctx context.Context, key string, field string, value interface{}) error
|
||||||
|
// HMSet 批量设置 Hash 字段
|
||||||
|
HMSet(ctx context.Context, key string, values map[string]interface{}) error
|
||||||
|
// HGet 获取 Hash 字段值
|
||||||
|
HGet(ctx context.Context, key string, field string) (string, error)
|
||||||
|
// HMGet 批量获取 Hash 字段值
|
||||||
|
HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error)
|
||||||
|
// HGetAll 获取 Hash 所有字段
|
||||||
|
HGetAll(ctx context.Context, key string) (map[string]string, error)
|
||||||
|
// HDel 删除 Hash 字段
|
||||||
|
HDel(ctx context.Context, key string, fields ...string) error
|
||||||
|
|
||||||
|
// ==================== Sorted Set 操作 ====================
|
||||||
|
// ZAdd 添加 Sorted Set 成员
|
||||||
|
ZAdd(ctx context.Context, key string, score float64, member string) error
|
||||||
|
// ZRangeByScore 按分数范围获取成员(升序)
|
||||||
|
ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error)
|
||||||
|
// ZRevRangeByScore 按分数范围获取成员(降序)
|
||||||
|
ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error)
|
||||||
|
// ZRem 删除 Sorted Set 成员
|
||||||
|
ZRem(ctx context.Context, key string, members ...interface{}) error
|
||||||
|
// ZCard 获取 Sorted Set 成员数量
|
||||||
|
ZCard(ctx context.Context, key string) (int64, error)
|
||||||
|
|
||||||
|
// ==================== 计数器操作 ====================
|
||||||
|
// Incr 原子递增(返回新值)
|
||||||
|
Incr(ctx context.Context, key string) (int64, error)
|
||||||
|
// Expire 设置过期时间
|
||||||
|
Expire(ctx context.Context, key string, ttl time.Duration) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheItem 缓存项(用于内存缓存降级)
|
// cacheItem 缓存项(用于内存缓存降级)
|
||||||
@@ -64,16 +99,16 @@ type MetricsSnapshot struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Settings struct {
|
type Settings struct {
|
||||||
Enabled bool
|
Enabled bool
|
||||||
KeyPrefix string
|
KeyPrefix string
|
||||||
DefaultTTL time.Duration
|
DefaultTTL time.Duration
|
||||||
NullTTL time.Duration
|
NullTTL time.Duration
|
||||||
JitterRatio float64
|
JitterRatio float64
|
||||||
PostListTTL time.Duration
|
PostListTTL time.Duration
|
||||||
ConversationTTL time.Duration
|
ConversationTTL time.Duration
|
||||||
UnreadCountTTL time.Duration
|
UnreadCountTTL time.Duration
|
||||||
GroupMembersTTL time.Duration
|
GroupMembersTTL time.Duration
|
||||||
DisableFlushDB bool
|
DisableFlushDB bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var settings = Settings{
|
var settings = Settings{
|
||||||
@@ -327,6 +362,378 @@ func (c *MemoryCache) Stop() {
|
|||||||
close(c.stopCleanup)
|
close(c.stopCleanup)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== MemoryCache Hash 操作 ====================
|
||||||
|
|
||||||
|
// hashItem Hash 存储项
|
||||||
|
type hashItem struct {
|
||||||
|
fields sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// HSet 设置 Hash 字段
|
||||||
|
func (c *MemoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, _ := c.items.Load(key)
|
||||||
|
var h *hashItem
|
||||||
|
if item == nil {
|
||||||
|
h = &hashItem{}
|
||||||
|
c.items.Store(key, &cacheItem{value: h, expiration: 0})
|
||||||
|
} else {
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
h = &hashItem{}
|
||||||
|
c.items.Store(key, &cacheItem{value: h, expiration: 0})
|
||||||
|
} else {
|
||||||
|
h = ci.value.(*hashItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.fields.Store(field, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HMSet 批量设置 Hash 字段
|
||||||
|
func (c *MemoryCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
|
||||||
|
for field, value := range values {
|
||||||
|
if err := c.HSet(ctx, key, field, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HGet 获取 Hash 字段值
|
||||||
|
func (c *MemoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return "", fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
h, ok := ci.value.(*hashItem)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("key is not a hash")
|
||||||
|
}
|
||||||
|
val, ok := h.fields.Load(field)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("field not found")
|
||||||
|
}
|
||||||
|
switch v := val.(type) {
|
||||||
|
case string:
|
||||||
|
return v, nil
|
||||||
|
case []byte:
|
||||||
|
return string(v), nil
|
||||||
|
default:
|
||||||
|
data, _ := json.Marshal(v)
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HMGet 批量获取 Hash 字段值
|
||||||
|
func (c *MemoryCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
|
||||||
|
result := make([]interface{}, len(fields))
|
||||||
|
for i, field := range fields {
|
||||||
|
val, err := c.HGet(ctx, key, field)
|
||||||
|
if err != nil {
|
||||||
|
result[i] = nil
|
||||||
|
} else {
|
||||||
|
result[i] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HGetAll 获取 Hash 所有字段
|
||||||
|
func (c *MemoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return nil, fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
h, ok := ci.value.(*hashItem)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key is not a hash")
|
||||||
|
}
|
||||||
|
result := make(map[string]string)
|
||||||
|
h.fields.Range(func(k, v interface{}) bool {
|
||||||
|
keyStr := k.(string)
|
||||||
|
switch val := v.(type) {
|
||||||
|
case string:
|
||||||
|
result[keyStr] = val
|
||||||
|
case []byte:
|
||||||
|
result[keyStr] = string(val)
|
||||||
|
default:
|
||||||
|
data, _ := json.Marshal(val)
|
||||||
|
result[keyStr] = string(data)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HDel 删除 Hash 字段
|
||||||
|
func (c *MemoryCache) HDel(ctx context.Context, key string, fields ...string) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h, ok := ci.value.(*hashItem)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, field := range fields {
|
||||||
|
h.fields.Delete(field)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== MemoryCache Sorted Set 操作 ====================
|
||||||
|
|
||||||
|
// zItem Sorted Set 成员
|
||||||
|
type zItem struct {
|
||||||
|
score float64
|
||||||
|
member string
|
||||||
|
}
|
||||||
|
|
||||||
|
// zsetItem Sorted Set 存储项
|
||||||
|
type zsetItem struct {
|
||||||
|
members sync.Map // member -> *zItem
|
||||||
|
byScore *sortedSlice // 按分数排序的切片
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortedSlice 简单的排序切片实现
|
||||||
|
type sortedSlice struct {
|
||||||
|
items []*zItem
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZAdd 添加 Sorted Set 成员
|
||||||
|
func (c *MemoryCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, _ := c.items.Load(key)
|
||||||
|
var z *zsetItem
|
||||||
|
if item == nil {
|
||||||
|
z = &zsetItem{byScore: &sortedSlice{}}
|
||||||
|
c.items.Store(key, &cacheItem{value: z, expiration: 0})
|
||||||
|
} else {
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
z = &zsetItem{byScore: &sortedSlice{}}
|
||||||
|
c.items.Store(key, &cacheItem{value: z, expiration: 0})
|
||||||
|
} else {
|
||||||
|
z = ci.value.(*zsetItem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
z.members.Store(member, &zItem{score: score, member: member})
|
||||||
|
z.byScore.mu.Lock()
|
||||||
|
// 简单实现:重新构建排序切片
|
||||||
|
z.byScore.items = nil
|
||||||
|
z.members.Range(func(k, v interface{}) bool {
|
||||||
|
z.byScore.items = append(z.byScore.items, v.(*zItem))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
// 按分数排序
|
||||||
|
sort.Slice(z.byScore.items, func(i, j int) bool {
|
||||||
|
return z.byScore.items[i].score < z.byScore.items[j].score
|
||||||
|
})
|
||||||
|
z.byScore.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRangeByScore 按分数范围获取成员(升序)
|
||||||
|
func (c *MemoryCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
z, ok := ci.value.(*zsetItem)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key is not a sorted set")
|
||||||
|
}
|
||||||
|
|
||||||
|
minScore, _ := strconv.ParseFloat(min, 64)
|
||||||
|
maxScore, _ := strconv.ParseFloat(max, 64)
|
||||||
|
if min == "-inf" {
|
||||||
|
minScore = math.Inf(-1)
|
||||||
|
}
|
||||||
|
if max == "+inf" {
|
||||||
|
maxScore = math.Inf(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
z.byScore.mu.RLock()
|
||||||
|
defer z.byScore.mu.RUnlock()
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
var skipped int64 = 0
|
||||||
|
for _, item := range z.byScore.items {
|
||||||
|
if item.score < minScore || item.score > maxScore {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if skipped < offset {
|
||||||
|
skipped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if count > 0 && int64(len(result)) >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, item.member)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRevRangeByScore 按分数范围获取成员(降序)
|
||||||
|
func (c *MemoryCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
z, ok := ci.value.(*zsetItem)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("key is not a sorted set")
|
||||||
|
}
|
||||||
|
|
||||||
|
minScore, _ := strconv.ParseFloat(min, 64)
|
||||||
|
maxScore, _ := strconv.ParseFloat(max, 64)
|
||||||
|
if min == "-inf" {
|
||||||
|
minScore = math.Inf(-1)
|
||||||
|
}
|
||||||
|
if max == "+inf" {
|
||||||
|
maxScore = math.Inf(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
z.byScore.mu.RLock()
|
||||||
|
defer z.byScore.mu.RUnlock()
|
||||||
|
|
||||||
|
var result []string
|
||||||
|
var skipped int64 = 0
|
||||||
|
// 从后往前遍历
|
||||||
|
for i := len(z.byScore.items) - 1; i >= 0; i-- {
|
||||||
|
item := z.byScore.items[i]
|
||||||
|
if item.score < minScore || item.score > maxScore {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if skipped < offset {
|
||||||
|
skipped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if count > 0 && int64(len(result)) >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, item.member)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRem 删除 Sorted Set 成员
|
||||||
|
func (c *MemoryCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
z, ok := ci.value.(*zsetItem)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, m := range members {
|
||||||
|
if member, ok := m.(string); ok {
|
||||||
|
z.members.Delete(member)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 重建排序切片
|
||||||
|
z.byScore.mu.Lock()
|
||||||
|
z.byScore.items = nil
|
||||||
|
z.members.Range(func(k, v interface{}) bool {
|
||||||
|
z.byScore.items = append(z.byScore.items, v.(*zItem))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
sort.Slice(z.byScore.items, func(i, j int) bool {
|
||||||
|
return z.byScore.items[i].score < z.byScore.items[j].score
|
||||||
|
})
|
||||||
|
z.byScore.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZCard 获取 Sorted Set 成员数量
|
||||||
|
func (c *MemoryCache) ZCard(ctx context.Context, key string) (int64, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
if ci.isExpired() {
|
||||||
|
c.items.Delete(key)
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
z, ok := ci.value.(*zsetItem)
|
||||||
|
if !ok {
|
||||||
|
return 0, fmt.Errorf("key is not a sorted set")
|
||||||
|
}
|
||||||
|
var count int64 = 0
|
||||||
|
z.members.Range(func(k, v interface{}) bool {
|
||||||
|
count++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== MemoryCache 计数器操作 ====================
|
||||||
|
|
||||||
|
// Incr 原子递增(返回新值)
|
||||||
|
func (c *MemoryCache) Incr(ctx context.Context, key string) (int64, error) {
|
||||||
|
return c.IncrementBy(key, 1), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire 设置过期时间
|
||||||
|
func (c *MemoryCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
item, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
ci := item.(*cacheItem)
|
||||||
|
var expiration int64
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration = time.Now().Add(ttl).UnixNano()
|
||||||
|
}
|
||||||
|
c.items.Store(key, &cacheItem{
|
||||||
|
value: ci.value,
|
||||||
|
expiration: expiration,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RedisCache Redis缓存实现
|
// RedisCache Redis缓存实现
|
||||||
type RedisCache struct {
|
type RedisCache struct {
|
||||||
client *redisPkg.Client
|
client *redisPkg.Client
|
||||||
@@ -451,6 +858,91 @@ func (c *RedisCache) IncrementBy(key string, value int64) int64 {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== RedisCache Hash 操作 ====================
|
||||||
|
|
||||||
|
// HSet 设置 Hash 字段
|
||||||
|
func (c *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.HSet(ctx, key, field, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HMSet 批量设置 Hash 字段
|
||||||
|
func (c *RedisCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.HMSet(ctx, key, values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HGet 获取 Hash 字段值
|
||||||
|
func (c *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.HGet(ctx, key, field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HMGet 批量获取 Hash 字段值
|
||||||
|
func (c *RedisCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.HMGet(ctx, key, fields...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HGetAll 获取 Hash 所有字段
|
||||||
|
func (c *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.HGetAll(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HDel 删除 Hash 字段
|
||||||
|
func (c *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.HDel(ctx, key, fields...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== RedisCache Sorted Set 操作 ====================
|
||||||
|
|
||||||
|
// ZAdd 添加 Sorted Set 成员
|
||||||
|
func (c *RedisCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.ZAdd(ctx, key, score, member)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRangeByScore 按分数范围获取成员(升序)
|
||||||
|
func (c *RedisCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.ZRangeByScore(ctx, key, min, max, offset, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRevRangeByScore 按分数范围获取成员(降序)
|
||||||
|
func (c *RedisCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.ZRevRangeByScore(ctx, key, max, min, offset, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRem 删除 Sorted Set 成员
|
||||||
|
func (c *RedisCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.ZRem(ctx, key, members...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZCard 获取 Sorted Set 成员数量
|
||||||
|
func (c *RedisCache) ZCard(ctx context.Context, key string) (int64, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.ZCard(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== RedisCache 计数器操作 ====================
|
||||||
|
|
||||||
|
// Incr 原子递增(返回新值)
|
||||||
|
func (c *RedisCache) Incr(ctx context.Context, key string) (int64, error) {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
return c.client.Incr(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expire 设置过期时间
|
||||||
|
func (c *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||||
|
key = normalizeKey(key)
|
||||||
|
_, err := c.client.Expire(ctx, key, ttl)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// 全局缓存实例
|
// 全局缓存实例
|
||||||
var globalCache Cache
|
var globalCache Cache
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
|
|||||||
724
internal/cache/conversation_cache.go
vendored
Normal file
724
internal/cache/conversation_cache.go
vendored
Normal file
@@ -0,0 +1,724 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CachedConversation 带缓存元数据的会话
|
||||||
|
type CachedConversation struct {
|
||||||
|
Data *model.Conversation // 实际数据
|
||||||
|
Version int64 // 版本号(CAS 更新用)
|
||||||
|
UpdatedAt time.Time // 最后更新时间
|
||||||
|
AccessAt time.Time // 最后访问时间(用于 TTL 延长)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedParticipant 带缓存元数据的参与者
|
||||||
|
type CachedParticipant struct {
|
||||||
|
Data *model.ConversationParticipant
|
||||||
|
Version int64
|
||||||
|
UpdatedAt time.Time
|
||||||
|
AccessAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedMessage 带缓存元数据的消息
|
||||||
|
type CachedMessage struct {
|
||||||
|
Data *model.Message `json:"data"` // 消息数据
|
||||||
|
Seq int64 `json:"seq"` // 消息序号
|
||||||
|
CreatedAt time.Time `json:"created_at"` // 创建时间
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageCacheData Redis 中存储的消息 Hash 结构
|
||||||
|
type MessageCacheData struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ConversationID string `json:"conversation_id"`
|
||||||
|
SenderID string `json:"sender_id"`
|
||||||
|
Seq int64 `json:"seq"`
|
||||||
|
Segments json.RawMessage `json:"segments"`
|
||||||
|
ReplyToID *string `json:"reply_to_id,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
SystemType string `json:"system_type,omitempty"`
|
||||||
|
ExtraData json.RawMessage `json:"extra_data,omitempty"`
|
||||||
|
MentionUsers string `json:"mention_users"`
|
||||||
|
MentionAll bool `json:"mention_all"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PageCache 分页缓存
|
||||||
|
type PageCache struct {
|
||||||
|
Seqs []int64 `json:"seqs"` // 当前页的消息 seq 列表
|
||||||
|
Total int64 `json:"total"` // 消息总数
|
||||||
|
Page int `json:"page"` // 当前页码
|
||||||
|
PageSize int `json:"page_size"` // 每页大小
|
||||||
|
HasMore bool `json:"has_more"` // 是否有更多
|
||||||
|
UpdatedAt time.Time `json:"updated_at"` // 更新时间
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationCacheSettings 缓存配置
|
||||||
|
type ConversationCacheSettings struct {
|
||||||
|
DetailTTL time.Duration // 会话详情 TTL (5min)
|
||||||
|
ListTTL time.Duration // 会话列表 TTL (60s)
|
||||||
|
ParticipantTTL time.Duration // 参与者 TTL (5min)
|
||||||
|
UnreadTTL time.Duration // 未读数 TTL (30s)
|
||||||
|
|
||||||
|
// 消息缓存配置
|
||||||
|
MessageDetailTTL time.Duration // 单条消息详情缓存 (30min)
|
||||||
|
MessageListTTL time.Duration // 消息分页列表缓存 (5min)
|
||||||
|
MessageIndexTTL time.Duration // 消息索引缓存 (30min)
|
||||||
|
MessageCountTTL time.Duration // 消息计数缓存 (30min)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConversationCacheSettings 返回默认配置
|
||||||
|
func DefaultConversationCacheSettings() *ConversationCacheSettings {
|
||||||
|
return &ConversationCacheSettings{
|
||||||
|
DetailTTL: 5 * time.Minute,
|
||||||
|
ListTTL: 60 * time.Second,
|
||||||
|
ParticipantTTL: 5 * time.Minute,
|
||||||
|
UnreadTTL: 30 * time.Second,
|
||||||
|
MessageDetailTTL: 30 * time.Minute,
|
||||||
|
MessageListTTL: 5 * time.Minute,
|
||||||
|
MessageIndexTTL: 30 * time.Minute,
|
||||||
|
MessageCountTTL: 30 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseSegments 将 JSON RawMessage 解析为 MessageSegments
|
||||||
|
func parseSegments(data json.RawMessage) model.MessageSegments {
|
||||||
|
if data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var segments model.MessageSegments
|
||||||
|
if err := json.Unmarshal(data, &segments); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return segments
|
||||||
|
}
|
||||||
|
|
||||||
|
// serializeSegments 将 MessageSegments 序列化为 JSON RawMessage
|
||||||
|
func serializeSegments(segments model.MessageSegments) json.RawMessage {
|
||||||
|
if segments == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(segments)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToModel 将 MessageCacheData 转换为 model.Message
|
||||||
|
func (m *MessageCacheData) ToModel() *model.Message {
|
||||||
|
return &model.Message{
|
||||||
|
ID: m.ID,
|
||||||
|
ConversationID: m.ConversationID,
|
||||||
|
SenderID: m.SenderID,
|
||||||
|
Seq: m.Seq,
|
||||||
|
Segments: parseSegments(m.Segments),
|
||||||
|
ReplyToID: m.ReplyToID,
|
||||||
|
Status: model.MessageStatus(m.Status),
|
||||||
|
Category: model.MessageCategory(m.Category),
|
||||||
|
SystemType: model.SystemMessageType(m.SystemType),
|
||||||
|
ExtraData: parseExtraData(m.ExtraData),
|
||||||
|
MentionUsers: m.MentionUsers,
|
||||||
|
MentionAll: m.MentionAll,
|
||||||
|
CreatedAt: m.CreatedAt,
|
||||||
|
UpdatedAt: m.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageCacheDataFromModel 从 model.Message 创建 MessageCacheData
|
||||||
|
func MessageCacheDataFromModel(msg *model.Message) *MessageCacheData {
|
||||||
|
return &MessageCacheData{
|
||||||
|
ID: msg.ID,
|
||||||
|
ConversationID: msg.ConversationID,
|
||||||
|
SenderID: msg.SenderID,
|
||||||
|
Seq: msg.Seq,
|
||||||
|
Segments: serializeSegments(msg.Segments),
|
||||||
|
ReplyToID: msg.ReplyToID,
|
||||||
|
Status: string(msg.Status),
|
||||||
|
Category: string(msg.Category),
|
||||||
|
SystemType: string(msg.SystemType),
|
||||||
|
ExtraData: serializeExtraData(msg.ExtraData),
|
||||||
|
MentionUsers: msg.MentionUsers,
|
||||||
|
MentionAll: msg.MentionAll,
|
||||||
|
CreatedAt: msg.CreatedAt,
|
||||||
|
UpdatedAt: msg.UpdatedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseExtraData 将 JSON RawMessage 解析为 ExtraData
|
||||||
|
func parseExtraData(data json.RawMessage) *model.ExtraData {
|
||||||
|
if data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var extraData model.ExtraData
|
||||||
|
if err := json.Unmarshal(data, &extraData); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &extraData
|
||||||
|
}
|
||||||
|
|
||||||
|
// serializeExtraData 将 ExtraData 序列化为 JSON RawMessage
|
||||||
|
func serializeExtraData(extraData *model.ExtraData) json.RawMessage {
|
||||||
|
if extraData == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(extraData)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// 缓存 Key 常量和生成函数
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
const (
|
||||||
|
keyPrefixConv = "conv" // 会话详情
|
||||||
|
keyPrefixConvPart = "conv_part" // 参与者列表
|
||||||
|
keyPrefixConvPartUser = "conv_part_user" // 用户参与者信息
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationKey 会话详情缓存 key
|
||||||
|
func ConversationKey(convID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", keyPrefixConv, convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantListKey 参与者列表缓存 key
|
||||||
|
func ParticipantListKey(convID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", keyPrefixConvPart, convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantKey 用户参与者信息缓存 key
|
||||||
|
func ParticipantKey(convID, userID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s:%s", keyPrefixConvPartUser, convID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// ConversationRepository 接口定义
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// ConversationRepository 会话数据仓库接口(用于依赖注入)
|
||||||
|
type ConversationRepository interface {
|
||||||
|
GetConversationByID(convID string) (*model.Conversation, error)
|
||||||
|
GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error)
|
||||||
|
GetParticipant(convID, userID string) (*model.ConversationParticipant, error)
|
||||||
|
GetParticipants(convID string) ([]*model.ConversationParticipant, error)
|
||||||
|
GetUnreadCount(userID, convID string) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageRepository 消息数据仓库接口(用于依赖注入)
|
||||||
|
type MessageRepository interface {
|
||||||
|
GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error)
|
||||||
|
GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error)
|
||||||
|
GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// ConversationCache 核心实现
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// ConversationCache 会话缓存管理器
|
||||||
|
type ConversationCache struct {
|
||||||
|
cache Cache // 底层缓存
|
||||||
|
settings *ConversationCacheSettings // 配置
|
||||||
|
repo ConversationRepository // 数据仓库接口(用于 cache-aside 回源)
|
||||||
|
msgRepo MessageRepository // 消息数据仓库接口(用于消息缓存回源)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConversationCache 创建会话缓存管理器
|
||||||
|
func NewConversationCache(cache Cache, repo ConversationRepository, msgRepo MessageRepository, settings *ConversationCacheSettings) *ConversationCache {
|
||||||
|
if settings == nil {
|
||||||
|
settings = DefaultConversationCacheSettings()
|
||||||
|
}
|
||||||
|
return &ConversationCache{
|
||||||
|
cache: cache,
|
||||||
|
settings: settings,
|
||||||
|
repo: repo,
|
||||||
|
msgRepo: msgRepo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversation 读取会话(带 TTL 滑动延长)
|
||||||
|
// 1. 尝试从缓存获取
|
||||||
|
// 2. 如果命中,更新 AccessAt 并延长 TTL
|
||||||
|
// 3. 如果未命中,从 repo 加载并写入缓存
|
||||||
|
func (c *ConversationCache) GetConversation(ctx context.Context, convID string) (*model.Conversation, error) {
|
||||||
|
key := ConversationKey(convID)
|
||||||
|
|
||||||
|
// 1. 尝试从缓存获取
|
||||||
|
cached, ok := GetTyped[*CachedConversation](c.cache, key)
|
||||||
|
if ok && cached != nil && cached.Data != nil {
|
||||||
|
// 2. 命中,更新 AccessAt 并延长 TTL
|
||||||
|
cached.AccessAt = time.Now()
|
||||||
|
c.cache.Set(key, cached, c.settings.DetailTTL)
|
||||||
|
return cached.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 未命中,从 repo 加载
|
||||||
|
if c.repo == nil {
|
||||||
|
return nil, fmt.Errorf("repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := c.repo.GetConversationByID(convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入缓存
|
||||||
|
now := time.Now()
|
||||||
|
cachedConv := &CachedConversation{
|
||||||
|
Data: conv,
|
||||||
|
Version: 0,
|
||||||
|
UpdatedAt: now,
|
||||||
|
AccessAt: now,
|
||||||
|
}
|
||||||
|
c.cache.Set(key, cachedConv, c.settings.DetailTTL)
|
||||||
|
|
||||||
|
return conv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedConversationList 带元数据的会话列表缓存
|
||||||
|
type CachedConversationList struct {
|
||||||
|
Conversations []*model.Conversation `json:"conversations"`
|
||||||
|
Total int64 `json:"total"`
|
||||||
|
Page int `json:"page"`
|
||||||
|
PageSize int `json:"page_size"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
AccessAt time.Time `json:"access_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationList 获取用户会话列表(带 TTL 滑动延长)
|
||||||
|
func (c *ConversationCache) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
key := ConversationListKey(userID, page, pageSize)
|
||||||
|
|
||||||
|
// 1. 尝试从缓存获取
|
||||||
|
cached, ok := GetTyped[*CachedConversationList](c.cache, key)
|
||||||
|
if ok && cached != nil {
|
||||||
|
// 2. 命中,更新 AccessAt 并延长 TTL
|
||||||
|
cached.AccessAt = time.Now()
|
||||||
|
c.cache.Set(key, cached, c.settings.ListTTL)
|
||||||
|
return cached.Conversations, cached.Total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 未命中,从 repo 加载
|
||||||
|
if c.repo == nil {
|
||||||
|
return nil, 0, fmt.Errorf("repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
convs, total, err := c.repo.GetConversationsByUserID(userID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入缓存
|
||||||
|
now := time.Now()
|
||||||
|
cachedList := &CachedConversationList{
|
||||||
|
Conversations: convs,
|
||||||
|
Total: total,
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
UpdatedAt: now,
|
||||||
|
AccessAt: now,
|
||||||
|
}
|
||||||
|
c.cache.Set(key, cachedList, c.settings.ListTTL)
|
||||||
|
|
||||||
|
return convs, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipant 获取参与者信息(带 TTL 滑动延长)
|
||||||
|
func (c *ConversationCache) GetParticipant(ctx context.Context, convID, userID string) (*model.ConversationParticipant, error) {
|
||||||
|
key := ParticipantKey(convID, userID)
|
||||||
|
|
||||||
|
// 1. 尝试从缓存获取
|
||||||
|
cached, ok := GetTyped[*CachedParticipant](c.cache, key)
|
||||||
|
if ok && cached != nil && cached.Data != nil {
|
||||||
|
// 2. 命中,更新 AccessAt 并延长 TTL
|
||||||
|
cached.AccessAt = time.Now()
|
||||||
|
c.cache.Set(key, cached, c.settings.ParticipantTTL)
|
||||||
|
return cached.Data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 未命中,从 repo 加载
|
||||||
|
if c.repo == nil {
|
||||||
|
return nil, fmt.Errorf("repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
participant, err := c.repo.GetParticipant(convID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入缓存
|
||||||
|
now := time.Now()
|
||||||
|
cachedPart := &CachedParticipant{
|
||||||
|
Data: participant,
|
||||||
|
Version: 0,
|
||||||
|
UpdatedAt: now,
|
||||||
|
AccessAt: now,
|
||||||
|
}
|
||||||
|
c.cache.Set(key, cachedPart, c.settings.ParticipantTTL)
|
||||||
|
|
||||||
|
return participant, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedParticipantList 带元数据的参与者列表缓存
|
||||||
|
type CachedParticipantList struct {
|
||||||
|
Participants []*model.ConversationParticipant `json:"participants"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
AccessAt time.Time `json:"access_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipants 获取会话所有参与者(带 TTL 滑动延长)
|
||||||
|
func (c *ConversationCache) GetParticipants(ctx context.Context, convID string) ([]*model.ConversationParticipant, error) {
|
||||||
|
key := ParticipantListKey(convID)
|
||||||
|
|
||||||
|
// 1. 尝试从缓存获取
|
||||||
|
cached, ok := GetTyped[*CachedParticipantList](c.cache, key)
|
||||||
|
if ok && cached != nil {
|
||||||
|
// 2. 命中,更新 AccessAt 并延长 TTL
|
||||||
|
cached.AccessAt = time.Now()
|
||||||
|
c.cache.Set(key, cached, c.settings.ParticipantTTL)
|
||||||
|
return cached.Participants, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 未命中,从 repo 加载
|
||||||
|
if c.repo == nil {
|
||||||
|
return nil, fmt.Errorf("repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
participants, err := c.repo.GetParticipants(convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入缓存
|
||||||
|
now := time.Now()
|
||||||
|
cachedList := &CachedParticipantList{
|
||||||
|
Participants: participants,
|
||||||
|
UpdatedAt: now,
|
||||||
|
AccessAt: now,
|
||||||
|
}
|
||||||
|
c.cache.Set(key, cachedList, c.settings.ParticipantTTL)
|
||||||
|
|
||||||
|
return participants, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CachedUnreadCount 带元数据的未读数缓存
|
||||||
|
type CachedUnreadCount struct {
|
||||||
|
Count int64 `json:"count"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
AccessAt time.Time `json:"access_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 获取未读数(带 TTL 滑动延长)
|
||||||
|
func (c *ConversationCache) GetUnreadCount(ctx context.Context, userID, convID string) (int64, error) {
|
||||||
|
key := UnreadDetailKey(userID, convID)
|
||||||
|
|
||||||
|
// 1. 尝试从缓存获取
|
||||||
|
cached, ok := GetTyped[*CachedUnreadCount](c.cache, key)
|
||||||
|
if ok && cached != nil {
|
||||||
|
// 2. 命中,更新 AccessAt 并延长 TTL
|
||||||
|
cached.AccessAt = time.Now()
|
||||||
|
c.cache.Set(key, cached, c.settings.UnreadTTL)
|
||||||
|
return cached.Count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 未命中,从 repo 加载
|
||||||
|
if c.repo == nil {
|
||||||
|
return 0, fmt.Errorf("repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := c.repo.GetUnreadCount(userID, convID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 写入缓存
|
||||||
|
now := time.Now()
|
||||||
|
cachedCount := &CachedUnreadCount{
|
||||||
|
Count: count,
|
||||||
|
UpdatedAt: now,
|
||||||
|
AccessAt: now,
|
||||||
|
}
|
||||||
|
c.cache.Set(key, cachedCount, c.settings.UnreadTTL)
|
||||||
|
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// 缓存失效方法
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// InvalidateConversation 使会话缓存失效
|
||||||
|
func (c *ConversationCache) InvalidateConversation(convID string) {
|
||||||
|
c.cache.Delete(ConversationKey(convID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateConversationList 使会话列表缓存失效(按用户)
|
||||||
|
func (c *ConversationCache) InvalidateConversationList(userID string) {
|
||||||
|
c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", PrefixConversationList, userID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateParticipant 使参与者缓存失效
|
||||||
|
func (c *ConversationCache) InvalidateParticipant(convID, userID string) {
|
||||||
|
c.cache.Delete(ParticipantKey(convID, userID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateParticipantList 使参与者列表缓存失效
|
||||||
|
func (c *ConversationCache) InvalidateParticipantList(convID string) {
|
||||||
|
c.cache.Delete(ParticipantListKey(convID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUnreadCount 使未读数缓存失效
|
||||||
|
func (c *ConversationCache) InvalidateUnreadCount(userID, convID string) {
|
||||||
|
c.cache.Delete(UnreadDetailKey(userID, convID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// 消息缓存方法
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// GetMessages 获取会话消息(带缓存)
|
||||||
|
// 1. 尝试从分页缓存获取
|
||||||
|
// 2. 如果命中,从 Hash 中批量获取消息详情
|
||||||
|
// 3. 如果未命中,从数据库加载并写入缓存
|
||||||
|
func (c *ConversationCache) GetMessages(ctx context.Context, convID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
|
// 1. 尝试从缓存获取分页数据
|
||||||
|
pageKey := MessagePageKey(convID, page, pageSize)
|
||||||
|
cached, ok := GetTyped[*PageCache](c.cache, pageKey)
|
||||||
|
if ok && cached != nil {
|
||||||
|
// TTL 滑动延长
|
||||||
|
cached.UpdatedAt = time.Now()
|
||||||
|
c.cache.Set(pageKey, cached, c.settings.MessageListTTL)
|
||||||
|
|
||||||
|
// 从 Hash 中批量获取消息详情
|
||||||
|
if len(cached.Seqs) > 0 {
|
||||||
|
messages, err := c.getMessagesBySeqs(ctx, convID, cached.Seqs)
|
||||||
|
if err == nil {
|
||||||
|
return messages, cached.Total, nil
|
||||||
|
}
|
||||||
|
// Hash 获取失败,继续从数据库加载
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 缓存未命中,从数据库加载
|
||||||
|
if c.msgRepo == nil {
|
||||||
|
return nil, 0, fmt.Errorf("message repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, total, err := c.msgRepo.GetMessages(convID, page, pageSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 写入缓存
|
||||||
|
seqs := make([]int64, len(messages))
|
||||||
|
for i, msg := range messages {
|
||||||
|
seqs[i] = msg.Seq
|
||||||
|
// 异步写入消息详情到 Hash
|
||||||
|
go c.asyncCacheMessage(context.Background(), convID, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
pageCache := &PageCache{
|
||||||
|
Seqs: seqs,
|
||||||
|
Total: total,
|
||||||
|
Page: page,
|
||||||
|
PageSize: pageSize,
|
||||||
|
HasMore: int64(page*pageSize) < total,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
c.cache.Set(pageKey, pageCache, c.settings.MessageListTTL)
|
||||||
|
|
||||||
|
return messages, total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesAfterSeq 获取指定 seq 之后的消息(增量同步)
|
||||||
|
// 使用 Sorted Set 的 ZRangeByScore 实现
|
||||||
|
func (c *ConversationCache) GetMessagesAfterSeq(ctx context.Context, convID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
indexKey := MessageIndexKey(convID)
|
||||||
|
|
||||||
|
// 1. 尝试从 Sorted Set 获取 seq 列表
|
||||||
|
members, err := c.cache.ZRangeByScore(ctx, indexKey, fmt.Sprintf("%d", afterSeq+1), "+inf", 0, int64(limit))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情
|
||||||
|
if len(members) > 0 {
|
||||||
|
seqs := make([]int64, 0, len(members))
|
||||||
|
for _, member := range members {
|
||||||
|
var seq int64
|
||||||
|
if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
|
||||||
|
seqs = append(seqs, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.getMessagesBySeqs(ctx, convID, seqs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Sorted Set 未命中,从数据库加载
|
||||||
|
if c.msgRepo == nil {
|
||||||
|
return nil, fmt.Errorf("message repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := c.msgRepo.GetMessagesAfterSeq(convID, afterSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 异步写入缓存
|
||||||
|
for _, msg := range messages {
|
||||||
|
go c.asyncCacheMessage(context.Background(), convID, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesBeforeSeq 获取指定 seq 之前的历史消息(下拉加载)
|
||||||
|
// 使用 Sorted Set 的 ZRevRangeByScore 实现
|
||||||
|
func (c *ConversationCache) GetMessagesBeforeSeq(ctx context.Context, convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
indexKey := MessageIndexKey(convID)
|
||||||
|
|
||||||
|
// 1. 尝试从 Sorted Set 获取 seq 列表(降序)
|
||||||
|
members, err := c.cache.ZRevRangeByScore(ctx, indexKey, fmt.Sprintf("%d", beforeSeq-1), "-inf", 0, int64(limit))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 如果 Sorted Set 有数据,从 Hash 获取消息详情
|
||||||
|
if len(members) > 0 {
|
||||||
|
seqs := make([]int64, 0, len(members))
|
||||||
|
for _, member := range members {
|
||||||
|
var seq int64
|
||||||
|
if _, err := fmt.Sscanf(member, "%d", &seq); err == nil {
|
||||||
|
seqs = append(seqs, seq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.getMessagesBySeqs(ctx, convID, seqs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Sorted Set 未命中,从数据库加载
|
||||||
|
if c.msgRepo == nil {
|
||||||
|
return nil, fmt.Errorf("message repository not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := c.msgRepo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 异步写入缓存
|
||||||
|
for _, msg := range messages {
|
||||||
|
go c.asyncCacheMessage(context.Background(), convID, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheMessage 缓存单条消息(立即写入缓存)
|
||||||
|
// 写入 Hash、Sorted Set、更新计数
|
||||||
|
func (c *ConversationCache) CacheMessage(ctx context.Context, convID string, msg *model.Message) error {
|
||||||
|
hashKey := MessageHashKey(convID)
|
||||||
|
indexKey := MessageIndexKey(convID)
|
||||||
|
|
||||||
|
msgData := MessageCacheDataFromModel(msg)
|
||||||
|
data, err := json.Marshal(msgData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HSET 消息详情
|
||||||
|
if err := c.cache.HSet(ctx, hashKey, fmt.Sprintf("%d", msg.Seq), string(data)); err != nil {
|
||||||
|
return fmt.Errorf("failed to set hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZADD 消息索引
|
||||||
|
if err := c.cache.ZAdd(ctx, indexKey, float64(msg.Seq), fmt.Sprintf("%d", msg.Seq)); err != nil {
|
||||||
|
return fmt.Errorf("failed to add to sorted set: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置 TTL
|
||||||
|
c.cache.Expire(ctx, hashKey, c.settings.MessageDetailTTL)
|
||||||
|
c.cache.Expire(ctx, indexKey, c.settings.MessageIndexTTL)
|
||||||
|
|
||||||
|
// INCR 消息计数
|
||||||
|
c.cache.Incr(ctx, MessageCountKey(convID))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateMessageCache 使消息缓存失效
|
||||||
|
func (c *ConversationCache) InvalidateMessageCache(convID string) {
|
||||||
|
c.cache.Delete(MessageHashKey(convID))
|
||||||
|
c.cache.Delete(MessageIndexKey(convID))
|
||||||
|
c.cache.Delete(MessageCountKey(convID))
|
||||||
|
// 删除所有分页缓存
|
||||||
|
c.InvalidateMessagePages(convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateMessagePages 仅使消息分页缓存失效
|
||||||
|
// 新消息写入后会导致分页内容和总数变化,需要清理该会话所有分页缓存。
|
||||||
|
func (c *ConversationCache) InvalidateMessagePages(convID string) {
|
||||||
|
c.cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, convID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// 内部辅助方法
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// getMessagesBySeqs 从 Hash 中批量获取消息
|
||||||
|
func (c *ConversationCache) getMessagesBySeqs(ctx context.Context, convID string, seqs []int64) ([]*model.Message, error) {
|
||||||
|
if len(seqs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hashKey := MessageHashKey(convID)
|
||||||
|
fields := make([]string, len(seqs))
|
||||||
|
for i, seq := range seqs {
|
||||||
|
fields[i] = fmt.Sprintf("%d", seq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量获取
|
||||||
|
values, err := c.cache.HMGet(ctx, hashKey, fields...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := make([]*model.Message, 0, len(seqs))
|
||||||
|
for _, val := range values {
|
||||||
|
if val == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var msgData MessageCacheData
|
||||||
|
switch v := val.(type) {
|
||||||
|
case string:
|
||||||
|
if err := json.Unmarshal([]byte(v), &msgData); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case []byte:
|
||||||
|
if err := json.Unmarshal(v, &msgData); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messages = append(messages, msgData.ToModel())
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// asyncCacheMessage 异步缓存单条消息
|
||||||
|
func (c *ConversationCache) asyncCacheMessage(ctx context.Context, convID string, msg *model.Message) {
|
||||||
|
if err := c.CacheMessage(ctx, convID, msg); err != nil {
|
||||||
|
log.Printf("[ConversationCache] async cache message failed, convID=%s, msgID=%s, err=%v", convID, msg.ID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
41
internal/cache/keys.go
vendored
41
internal/cache/keys.go
vendored
@@ -26,6 +26,13 @@ const (
|
|||||||
// 用户相关
|
// 用户相关
|
||||||
PrefixUserInfo = "users:info"
|
PrefixUserInfo = "users:info"
|
||||||
PrefixUserMe = "users:me"
|
PrefixUserMe = "users:me"
|
||||||
|
|
||||||
|
// 消息缓存相关
|
||||||
|
keyPrefixMsgHash = "msg_hash" // 消息详情 Hash
|
||||||
|
keyPrefixMsgIndex = "msg_index" // 消息索引 Sorted Set
|
||||||
|
keyPrefixMsgCount = "msg_count" // 消息计数
|
||||||
|
keyPrefixMsgSeq = "msg_seq" // Seq 计数器
|
||||||
|
keyPrefixMsgPage = "msg_page" // 分页缓存
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostListKey 生成帖子列表缓存键
|
// PostListKey 生成帖子列表缓存键
|
||||||
@@ -145,3 +152,37 @@ func InvalidateUserInfo(cache Cache, userID string) {
|
|||||||
cache.Delete(UserInfoKey(userID))
|
cache.Delete(UserInfoKey(userID))
|
||||||
cache.Delete(UserMeKey(userID))
|
cache.Delete(UserMeKey(userID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// 消息缓存 Key 生成函数
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// MessageHashKey 消息详情 Hash key
|
||||||
|
func MessageHashKey(convID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", keyPrefixMsgHash, convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageIndexKey 消息索引 Sorted Set key
|
||||||
|
func MessageIndexKey(convID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", keyPrefixMsgIndex, convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageCountKey 消息计数 key
|
||||||
|
func MessageCountKey(convID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", keyPrefixMsgCount, convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageSeqKey Seq 计数器 key
|
||||||
|
func MessageSeqKey(convID string) string {
|
||||||
|
return fmt.Sprintf("%s:%s", keyPrefixMsgSeq, convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessagePageKey 分页缓存 key
|
||||||
|
func MessagePageKey(convID string, page, pageSize int) string {
|
||||||
|
return fmt.Sprintf("%s:%s:%d:%d", keyPrefixMsgPage, convID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateMessagePages 失效会话消息分页缓存
|
||||||
|
func InvalidateMessagePages(cache Cache, conversationID string) {
|
||||||
|
cache.DeleteByPrefix(fmt.Sprintf("%s:%s:", keyPrefixMsgPage, conversationID))
|
||||||
|
}
|
||||||
|
|||||||
76
internal/cache/repository_adapter.go
vendored
Normal file
76
internal/cache/repository_adapter.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConversationRepositoryAdapter 适配 MessageRepository 到 ConversationRepository 接口
|
||||||
|
type ConversationRepositoryAdapter struct {
|
||||||
|
repo *repository.MessageRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConversationRepositoryAdapter 创建适配器
|
||||||
|
func NewConversationRepositoryAdapter(repo *repository.MessageRepository) ConversationRepository {
|
||||||
|
return &ConversationRepositoryAdapter{repo: repo}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationByID 实现 ConversationRepository 接口
|
||||||
|
func (a *ConversationRepositoryAdapter) GetConversationByID(convID string) (*model.Conversation, error) {
|
||||||
|
return a.repo.GetConversation(convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConversationsByUserID 实现 ConversationRepository 接口
|
||||||
|
func (a *ConversationRepositoryAdapter) GetConversationsByUserID(userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
return a.repo.GetConversations(userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipant 实现 ConversationRepository 接口
|
||||||
|
func (a *ConversationRepositoryAdapter) GetParticipant(convID, userID string) (*model.ConversationParticipant, error) {
|
||||||
|
return a.repo.GetParticipant(convID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipants 实现 ConversationRepository 接口
|
||||||
|
func (a *ConversationRepositoryAdapter) GetParticipants(convID string) ([]*model.ConversationParticipant, error) {
|
||||||
|
return a.repo.GetConversationParticipants(convID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUnreadCount 实现 ConversationRepository 接口
|
||||||
|
func (a *ConversationRepositoryAdapter) GetUnreadCount(userID, convID string) (int64, error) {
|
||||||
|
return a.repo.GetUnreadCount(convID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageRepositoryAdapter 适配 MessageRepository 到 MessageRepository 接口
|
||||||
|
type MessageRepositoryAdapter struct {
|
||||||
|
repo *repository.MessageRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMessageRepositoryAdapter 创建适配器
|
||||||
|
func NewMessageRepositoryAdapter(repo *repository.MessageRepository) MessageRepository {
|
||||||
|
return &MessageRepositoryAdapter{repo: repo}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessages 实现 MessageRepository 接口
|
||||||
|
func (a *MessageRepositoryAdapter) GetMessages(convID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
|
return a.repo.GetMessages(convID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesAfterSeq 实现 MessageRepository 接口
|
||||||
|
func (a *MessageRepositoryAdapter) GetMessagesAfterSeq(convID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
return a.repo.GetMessagesAfterSeq(convID, afterSeq, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMessagesBeforeSeq 实现 MessageRepository 接口
|
||||||
|
func (a *MessageRepositoryAdapter) GetMessagesBeforeSeq(convID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||||
|
return a.repo.GetMessagesBeforeSeq(convID, beforeSeq, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateMessage 实现 MessageRepository 接口
|
||||||
|
func (a *MessageRepositoryAdapter) CreateMessage(msg *model.Message) error {
|
||||||
|
return a.repo.CreateMessage(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConversationLastSeq 实现 MessageRepository 接口
|
||||||
|
func (a *MessageRepositoryAdapter) UpdateConversationLastSeq(convID string, seq int64) error {
|
||||||
|
return a.repo.UpdateConversationLastSeq(convID, seq)
|
||||||
|
}
|
||||||
@@ -15,18 +15,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Server ServerConfig `mapstructure:"server"`
|
Server ServerConfig `mapstructure:"server"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
Redis RedisConfig `mapstructure:"redis"`
|
Redis RedisConfig `mapstructure:"redis"`
|
||||||
Cache CacheConfig `mapstructure:"cache"`
|
Cache CacheConfig `mapstructure:"cache"`
|
||||||
S3 S3Config `mapstructure:"s3"`
|
S3 S3Config `mapstructure:"s3"`
|
||||||
JWT JWTConfig `mapstructure:"jwt"`
|
JWT JWTConfig `mapstructure:"jwt"`
|
||||||
Log LogConfig `mapstructure:"log"`
|
Log LogConfig `mapstructure:"log"`
|
||||||
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
RateLimit RateLimitConfig `mapstructure:"rate_limit"`
|
||||||
Upload UploadConfig `mapstructure:"upload"`
|
Upload UploadConfig `mapstructure:"upload"`
|
||||||
Gorse GorseConfig `mapstructure:"gorse"`
|
Gorse GorseConfig `mapstructure:"gorse"`
|
||||||
OpenAI OpenAIConfig `mapstructure:"openai"`
|
OpenAI OpenAIConfig `mapstructure:"openai"`
|
||||||
Email EmailConfig `mapstructure:"email"`
|
Email EmailConfig `mapstructure:"email"`
|
||||||
|
ConversationCache ConversationCacheConfig `mapstructure:"conversation_cache"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
@@ -173,6 +174,73 @@ type EmailConfig struct {
|
|||||||
Timeout int `mapstructure:"timeout"`
|
Timeout int `mapstructure:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConversationCacheConfig 会话缓存配置
|
||||||
|
type ConversationCacheConfig struct {
|
||||||
|
// TTL 配置
|
||||||
|
DetailTTL string `mapstructure:"detail_ttl"`
|
||||||
|
ListTTL string `mapstructure:"list_ttl"`
|
||||||
|
ParticipantTTL string `mapstructure:"participant_ttl"`
|
||||||
|
UnreadTTL string `mapstructure:"unread_ttl"`
|
||||||
|
|
||||||
|
// 消息缓存配置
|
||||||
|
MessageDetailTTL string `mapstructure:"message_detail_ttl"`
|
||||||
|
MessageListTTL string `mapstructure:"message_list_ttl"`
|
||||||
|
MessageIndexTTL string `mapstructure:"message_index_ttl"`
|
||||||
|
MessageCountTTL string `mapstructure:"message_count_ttl"`
|
||||||
|
|
||||||
|
// 批量写入配置
|
||||||
|
BatchInterval string `mapstructure:"batch_interval"`
|
||||||
|
BatchThreshold int `mapstructure:"batch_threshold"`
|
||||||
|
BatchMaxSize int `mapstructure:"batch_max_size"`
|
||||||
|
BufferMaxSize int `mapstructure:"buffer_max_size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationCacheSettings 会话缓存运行时配置(用于传递给 cache 包)
|
||||||
|
type ConversationCacheSettings struct {
|
||||||
|
DetailTTL time.Duration
|
||||||
|
ListTTL time.Duration
|
||||||
|
ParticipantTTL time.Duration
|
||||||
|
UnreadTTL time.Duration
|
||||||
|
MessageDetailTTL time.Duration
|
||||||
|
MessageListTTL time.Duration
|
||||||
|
MessageIndexTTL time.Duration
|
||||||
|
MessageCountTTL time.Duration
|
||||||
|
BatchInterval time.Duration
|
||||||
|
BatchThreshold int
|
||||||
|
BatchMaxSize int
|
||||||
|
BufferMaxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToSettings 将 ConversationCacheConfig 转换为 ConversationCacheSettings
|
||||||
|
func (c *ConversationCacheConfig) ToSettings() *ConversationCacheSettings {
|
||||||
|
return &ConversationCacheSettings{
|
||||||
|
DetailTTL: parseDuration(c.DetailTTL, 5*time.Minute),
|
||||||
|
ListTTL: parseDuration(c.ListTTL, 60*time.Second),
|
||||||
|
ParticipantTTL: parseDuration(c.ParticipantTTL, 5*time.Minute),
|
||||||
|
UnreadTTL: parseDuration(c.UnreadTTL, 30*time.Second),
|
||||||
|
MessageDetailTTL: parseDuration(c.MessageDetailTTL, 30*time.Minute),
|
||||||
|
MessageListTTL: parseDuration(c.MessageListTTL, 5*time.Minute),
|
||||||
|
MessageIndexTTL: parseDuration(c.MessageIndexTTL, 30*time.Minute),
|
||||||
|
MessageCountTTL: parseDuration(c.MessageCountTTL, 30*time.Minute),
|
||||||
|
BatchInterval: parseDuration(c.BatchInterval, 5*time.Second),
|
||||||
|
BatchThreshold: c.BatchThreshold,
|
||||||
|
BatchMaxSize: c.BatchMaxSize,
|
||||||
|
BufferMaxSize: c.BufferMaxSize,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDuration 解析持续时间字符串,如果解析失败则返回默认值
|
||||||
|
func parseDuration(s string, defaultVal time.Duration) time.Duration {
|
||||||
|
if s == "" {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
d, err := time.ParseDuration(s)
|
||||||
|
if err != nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
func Load(configPath string) (*Config, error) {
|
func Load(configPath string) (*Config, error) {
|
||||||
viper.SetConfigFile(configPath)
|
viper.SetConfigFile(configPath)
|
||||||
viper.SetConfigType("yaml")
|
viper.SetConfigType("yaml")
|
||||||
@@ -259,6 +327,19 @@ func Load(configPath string) (*Config, error) {
|
|||||||
viper.SetDefault("email.use_tls", true)
|
viper.SetDefault("email.use_tls", true)
|
||||||
viper.SetDefault("email.insecure_skip_verify", false)
|
viper.SetDefault("email.insecure_skip_verify", false)
|
||||||
viper.SetDefault("email.timeout", 15)
|
viper.SetDefault("email.timeout", 15)
|
||||||
|
// ConversationCache 默认值
|
||||||
|
viper.SetDefault("conversation_cache.detail_ttl", "5m")
|
||||||
|
viper.SetDefault("conversation_cache.list_ttl", "60s")
|
||||||
|
viper.SetDefault("conversation_cache.participant_ttl", "5m")
|
||||||
|
viper.SetDefault("conversation_cache.unread_ttl", "30s")
|
||||||
|
viper.SetDefault("conversation_cache.message_detail_ttl", "30m")
|
||||||
|
viper.SetDefault("conversation_cache.message_list_ttl", "5m")
|
||||||
|
viper.SetDefault("conversation_cache.message_index_ttl", "30m")
|
||||||
|
viper.SetDefault("conversation_cache.message_count_ttl", "30m")
|
||||||
|
viper.SetDefault("conversation_cache.batch_interval", "5s")
|
||||||
|
viper.SetDefault("conversation_cache.batch_threshold", 100)
|
||||||
|
viper.SetDefault("conversation_cache.batch_max_size", 500)
|
||||||
|
viper.SetDefault("conversation_cache.buffer_max_size", 10000)
|
||||||
|
|
||||||
if err := viper.ReadInConfig(); err != nil {
|
if err := viper.ReadInConfig(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||||
|
|||||||
35
internal/dto/schedule_converter.go
Normal file
35
internal/dto/schedule_converter.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConvertScheduleCourseToResponse(course *model.ScheduleCourse, weeks []int) *ScheduleCourseResponse {
|
||||||
|
if course == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ScheduleCourseResponse{
|
||||||
|
ID: course.ID,
|
||||||
|
Name: course.Name,
|
||||||
|
Teacher: course.Teacher,
|
||||||
|
Location: course.Location,
|
||||||
|
DayOfWeek: course.DayOfWeek,
|
||||||
|
StartSection: course.StartSection,
|
||||||
|
EndSection: course.EndSection,
|
||||||
|
Weeks: weeks,
|
||||||
|
Color: course.Color,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseWeeksJSON(raw string) []int {
|
||||||
|
if raw == "" {
|
||||||
|
return []int{}
|
||||||
|
}
|
||||||
|
var weeks []int
|
||||||
|
if err := json.Unmarshal([]byte(raw), &weeks); err != nil {
|
||||||
|
return []int{}
|
||||||
|
}
|
||||||
|
return weeks
|
||||||
|
}
|
||||||
13
internal/dto/schedule_dto.go
Normal file
13
internal/dto/schedule_dto.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
type ScheduleCourseResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Teacher string `json:"teacher,omitempty"`
|
||||||
|
Location string `json:"location,omitempty"`
|
||||||
|
DayOfWeek int `json:"day_of_week"`
|
||||||
|
StartSection int `json:"start_section"`
|
||||||
|
EndSection int `json:"end_section"`
|
||||||
|
Weeks []int `json:"weeks"`
|
||||||
|
Color string `json:"color,omitempty"`
|
||||||
|
}
|
||||||
@@ -38,12 +38,12 @@ func parseGroupID(c *gin.Context) string {
|
|||||||
|
|
||||||
// parseUserIDFromPath 从路径参数获取用户ID(UUID格式)
|
// parseUserIDFromPath 从路径参数获取用户ID(UUID格式)
|
||||||
func parseUserIDFromPath(c *gin.Context) string {
|
func parseUserIDFromPath(c *gin.Context) string {
|
||||||
return c.Param("userId")
|
return c.Param("user_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseAnnouncementID 从路径参数获取公告ID
|
// parseAnnouncementID 从路径参数获取公告ID
|
||||||
func parseAnnouncementID(c *gin.Context) string {
|
func parseAnnouncementID(c *gin.Context) string {
|
||||||
return c.Param("announcementId")
|
return c.Param("announcement_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== 群组管理 ====================
|
// ==================== 群组管理 ====================
|
||||||
@@ -454,7 +454,7 @@ func (h *GroupHandler) GetMembers(c *gin.Context) {
|
|||||||
// ==================== RESTful Action 端点 ====================
|
// ==================== RESTful Action 端点 ====================
|
||||||
|
|
||||||
// HandleCreateGroup 创建群组
|
// HandleCreateGroup 创建群组
|
||||||
// POST /api/v1/groups/create
|
// POST /api/v1/groups
|
||||||
func (h *GroupHandler) HandleCreateGroup(c *gin.Context) {
|
func (h *GroupHandler) HandleCreateGroup(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -478,7 +478,7 @@ func (h *GroupHandler) HandleCreateGroup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetUserGroups 获取用户群组列表
|
// HandleGetUserGroups 获取用户群组列表
|
||||||
// GET /api/v1/groups/list
|
// GET /api/v1/groups
|
||||||
func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) {
|
func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -499,7 +499,6 @@ func (h *GroupHandler) HandleGetUserGroups(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetMyMemberInfo 获取我在群组中的成员信息
|
// HandleGetMyMemberInfo 获取我在群组中的成员信息
|
||||||
// GET /api/v1/groups/get_my_info?group_id=xxx
|
|
||||||
// GET /api/v1/groups/:id/me
|
// GET /api/v1/groups/:id/me
|
||||||
func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) {
|
func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
@@ -551,7 +550,7 @@ func (h *GroupHandler) HandleGetMyMemberInfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleDissolveGroup 解散群组
|
// HandleDissolveGroup 解散群组
|
||||||
// POST /api/v1/groups/dissolve
|
// DELETE /api/v1/groups/:id
|
||||||
func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
|
func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -559,18 +558,13 @@ func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var params dto.DissolveGroupParams
|
groupID := parseGroupID(c)
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if groupID == "" {
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
response.BadRequest(c, "group_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.groupService.DissolveGroup(userID, params.GroupID); err != nil {
|
if err := h.groupService.DissolveGroup(userID, groupID); err != nil {
|
||||||
if err == service.ErrNotGroupOwner {
|
if err == service.ErrNotGroupOwner {
|
||||||
response.Forbidden(c, "只有群主可以解散群组")
|
response.Forbidden(c, "只有群主可以解散群组")
|
||||||
return
|
return
|
||||||
@@ -587,7 +581,7 @@ func (h *GroupHandler) HandleDissolveGroup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleTransferOwner 转让群主
|
// HandleTransferOwner 转让群主
|
||||||
// POST /api/v1/groups/transfer
|
// POST /api/v1/groups/:id/transfer
|
||||||
func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
|
func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -595,22 +589,24 @@ func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.TransferOwnerParams
|
var params dto.TransferOwnerParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.NewOwnerID == "" {
|
if params.NewOwnerID == "" {
|
||||||
response.BadRequest(c, "new_owner_id is required")
|
response.BadRequest(c, "new_owner_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.groupService.TransferOwner(userID, params.GroupID, params.NewOwnerID); err != nil {
|
if err := h.groupService.TransferOwner(userID, groupID, params.NewOwnerID); err != nil {
|
||||||
if err == service.ErrNotGroupOwner {
|
if err == service.ErrNotGroupOwner {
|
||||||
response.Forbidden(c, "只有群主可以转让群主")
|
response.Forbidden(c, "只有群主可以转让群主")
|
||||||
return
|
return
|
||||||
@@ -631,7 +627,7 @@ func (h *GroupHandler) HandleTransferOwner(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleInviteMembers 邀请成员加入群组
|
// HandleInviteMembers 邀请成员加入群组
|
||||||
// POST /api/v1/groups/invite_members
|
// POST /api/v1/groups/:id/invitations
|
||||||
func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
|
func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -639,18 +635,19 @@ func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.InviteMembersParams
|
var params dto.InviteMembersParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
if err := h.groupService.InviteMembers(userID, groupID, params.MemberIDs); err != nil {
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.groupService.InviteMembers(userID, params.GroupID, params.MemberIDs); err != nil {
|
|
||||||
if err == service.ErrNotGroupMember {
|
if err == service.ErrNotGroupMember {
|
||||||
response.Forbidden(c, "只有群成员可以邀请他人")
|
response.Forbidden(c, "只有群成员可以邀请他人")
|
||||||
return
|
return
|
||||||
@@ -675,7 +672,7 @@ func (h *GroupHandler) HandleInviteMembers(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleJoinGroup 加入群组
|
// HandleJoinGroup 加入群组
|
||||||
// POST /api/v1/groups/join
|
// POST /api/v1/groups/:id/join-requests
|
||||||
func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
|
func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -683,18 +680,13 @@ func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var params dto.JoinGroupParams
|
groupID := parseGroupID(c)
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if groupID == "" {
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
response.BadRequest(c, "group_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.groupService.JoinGroup(userID, params.GroupID); err != nil {
|
if err := h.groupService.JoinGroup(userID, groupID); err != nil {
|
||||||
if err == service.ErrJoinRequestPending {
|
if err == service.ErrJoinRequestPending {
|
||||||
response.SuccessWithMessage(c, "申请已提交,等待群主/管理员审批", nil)
|
response.SuccessWithMessage(c, "申请已提交,等待群主/管理员审批", nil)
|
||||||
return
|
return
|
||||||
@@ -723,7 +715,7 @@ func (h *GroupHandler) HandleJoinGroup(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetNickname 设置群内昵称
|
// HandleSetNickname 设置群内昵称
|
||||||
// POST /api/v1/groups/set_nickname
|
// PUT /api/v1/groups/:id/members/me/nickname
|
||||||
func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
|
func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -731,18 +723,19 @@ func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetNicknameParams
|
var params dto.SetNicknameParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
if err := h.groupService.SetMemberNickname(userID, groupID, params.Nickname); err != nil {
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.groupService.SetMemberNickname(userID, params.GroupID, params.Nickname); err != nil {
|
|
||||||
if err == service.ErrNotGroupMember {
|
if err == service.ErrNotGroupMember {
|
||||||
response.BadRequest(c, "不是群成员")
|
response.BadRequest(c, "不是群成员")
|
||||||
return
|
return
|
||||||
@@ -759,7 +752,7 @@ func (h *GroupHandler) HandleSetNickname(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetJoinType 设置加群方式
|
// HandleSetJoinType 设置加群方式
|
||||||
// POST /api/v1/groups/set_join_type
|
// PUT /api/v1/groups/:id/join-type
|
||||||
func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
|
func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -767,18 +760,19 @@ func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetJoinTypeParams
|
var params dto.SetJoinTypeParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
if err := h.groupService.SetJoinType(userID, groupID, params.JoinType); err != nil {
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.groupService.SetJoinType(userID, params.GroupID, params.JoinType); err != nil {
|
|
||||||
if err == service.ErrNotGroupOwner {
|
if err == service.ErrNotGroupOwner {
|
||||||
response.Forbidden(c, "只有群主可以设置加群方式")
|
response.Forbidden(c, "只有群主可以设置加群方式")
|
||||||
return
|
return
|
||||||
@@ -803,7 +797,7 @@ func (h *GroupHandler) HandleSetJoinType(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleCreateAnnouncement 创建群公告
|
// HandleCreateAnnouncement 创建群公告
|
||||||
// POST /api/v1/groups/create_announcement
|
// POST /api/v1/groups/:id/announcements
|
||||||
func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
|
func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -811,18 +805,19 @@ func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.CreateAnnouncementParams
|
var params dto.CreateAnnouncementParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
announcement, err := h.groupService.CreateAnnouncement(userID, groupID, params.Content)
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
announcement, err := h.groupService.CreateAnnouncement(userID, params.GroupID, params.Content)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupAdmin {
|
if err == service.ErrNotGroupAdmin {
|
||||||
response.Forbidden(c, "只有群主或管理员可以发布公告")
|
response.Forbidden(c, "只有群主或管理员可以发布公告")
|
||||||
@@ -840,7 +835,6 @@ func (h *GroupHandler) HandleCreateAnnouncement(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetAnnouncements 获取群公告列表
|
// HandleGetAnnouncements 获取群公告列表
|
||||||
// GET /api/v1/groups/get_announcements?group_id=xxx
|
|
||||||
// GET /api/v1/groups/:id/announcements
|
// GET /api/v1/groups/:id/announcements
|
||||||
func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) {
|
func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
@@ -872,7 +866,7 @@ func (h *GroupHandler) HandleGetAnnouncements(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleDeleteAnnouncement 删除群公告
|
// HandleDeleteAnnouncement 删除群公告
|
||||||
// POST /api/v1/groups/delete_announcement
|
// DELETE /api/v1/groups/:id/announcements/:announcement_id
|
||||||
func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) {
|
func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -880,22 +874,18 @@ func (h *GroupHandler) HandleDeleteAnnouncement(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var params dto.DeleteAnnouncementParams
|
groupID := parseGroupID(c)
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if groupID == "" {
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
response.BadRequest(c, "group_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if params.AnnouncementID == "" {
|
announcementID := parseAnnouncementID(c)
|
||||||
|
if announcementID == "" {
|
||||||
response.BadRequest(c, "announcement_id is required")
|
response.BadRequest(c, "announcement_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.groupService.DeleteAnnouncement(userID, params.AnnouncementID); err != nil {
|
if err := h.groupService.DeleteAnnouncement(userID, announcementID); err != nil {
|
||||||
if err == service.ErrNotGroupAdmin {
|
if err == service.ErrNotGroupAdmin {
|
||||||
response.Forbidden(c, "只有群主或管理员可以删除公告")
|
response.Forbidden(c, "只有群主或管理员可以删除公告")
|
||||||
return
|
return
|
||||||
@@ -1292,7 +1282,7 @@ func (h *GroupHandler) DeleteAnnouncement(c *gin.Context) {
|
|||||||
// ==================== RESTful Action 端点 ====================
|
// ==================== RESTful Action 端点 ====================
|
||||||
|
|
||||||
// HandleSetGroupKick 群组踢人
|
// HandleSetGroupKick 群组踢人
|
||||||
// POST /api/v1/groups/set_group_kick
|
// POST /api/v1/groups/:id/members/kick
|
||||||
func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1300,23 +1290,25 @@ func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetGroupKickParams
|
var params dto.SetGroupKickParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.UserID == "" {
|
if params.UserID == "" {
|
||||||
response.BadRequest(c, "user_id is required")
|
response.BadRequest(c, "user_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 RemoveMember 方法
|
// 使用 RemoveMember 方法
|
||||||
err := h.groupService.RemoveMember(userID, params.GroupID, params.UserID)
|
err := h.groupService.RemoveMember(userID, groupID, params.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupAdmin {
|
if err == service.ErrNotGroupAdmin {
|
||||||
response.Forbidden(c, "只有群主或管理员可以移除成员")
|
response.Forbidden(c, "只有群主或管理员可以移除成员")
|
||||||
@@ -1342,7 +1334,7 @@ func (h *GroupHandler) HandleSetGroupKick(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupBan 群组单人禁言
|
// HandleSetGroupBan 群组单人禁言
|
||||||
// POST /api/v1/groups/set_group_ban
|
// POST /api/v1/groups/:id/members/ban
|
||||||
func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1350,16 +1342,18 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetGroupBanParams
|
var params dto.SetGroupBanParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.UserID == "" {
|
if params.UserID == "" {
|
||||||
response.BadRequest(c, "user_id is required")
|
response.BadRequest(c, "user_id is required")
|
||||||
return
|
return
|
||||||
@@ -1367,8 +1361,8 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
|
|||||||
|
|
||||||
// duration > 0 或 duration = -1 表示禁言,duration = 0 表示解除禁言
|
// duration > 0 或 duration = -1 表示禁言,duration = 0 表示解除禁言
|
||||||
muted := params.Duration != 0
|
muted := params.Duration != 0
|
||||||
log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, params.GroupID, params.UserID, params.Duration, muted)
|
log.Printf("[HandleSetGroupBan] 开始禁言操作: userID=%s, groupID=%s, targetUserID=%s, duration=%d, muted=%v", userID, groupID, params.UserID, params.Duration, muted)
|
||||||
err := h.groupService.MuteMember(userID, params.GroupID, params.UserID, muted)
|
err := h.groupService.MuteMember(userID, groupID, params.UserID, muted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[HandleSetGroupBan] 禁言操作失败: %v", err)
|
log.Printf("[HandleSetGroupBan] 禁言操作失败: %v", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -1403,7 +1397,7 @@ func (h *GroupHandler) HandleSetGroupBan(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupWholeBan 群组全员禁言
|
// HandleSetGroupWholeBan 群组全员禁言
|
||||||
// POST /api/v1/groups/set_group_whole_ban
|
// PUT /api/v1/groups/:id/ban
|
||||||
func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1411,18 +1405,19 @@ func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetGroupWholeBanParams
|
var params dto.SetGroupWholeBanParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
err := h.groupService.SetMuteAll(userID, groupID, params.Enable)
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := h.groupService.SetMuteAll(userID, params.GroupID, params.Enable)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupOwner {
|
if err == service.ErrNotGroupOwner {
|
||||||
response.Forbidden(c, "只有群主可以设置全员禁言")
|
response.Forbidden(c, "只有群主可以设置全员禁言")
|
||||||
@@ -1444,7 +1439,7 @@ func (h *GroupHandler) HandleSetGroupWholeBan(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupAdmin 群组设置管理员
|
// HandleSetGroupAdmin 群组设置管理员
|
||||||
// POST /api/v1/groups/set_group_admin
|
// PUT /api/v1/groups/:id/members/:user_id/admin
|
||||||
func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1452,28 +1447,30 @@ func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
targetUserID := parseUserIDFromPath(c)
|
||||||
|
if targetUserID == "" {
|
||||||
|
response.BadRequest(c, "user_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetGroupAdminParams
|
var params dto.SetGroupAdminParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.UserID == "" {
|
|
||||||
response.BadRequest(c, "user_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据 enable 参数设置角色
|
// 根据 enable 参数设置角色
|
||||||
role := model.GroupRoleMember
|
role := model.GroupRoleMember
|
||||||
if params.Enable {
|
if params.Enable {
|
||||||
role = model.GroupRoleAdmin
|
role = model.GroupRoleAdmin
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.groupService.SetMemberRole(userID, params.GroupID, params.UserID, role)
|
err := h.groupService.SetMemberRole(userID, groupID, targetUserID, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupOwner {
|
if err == service.ErrNotGroupOwner {
|
||||||
response.Forbidden(c, "只有群主可以设置管理员")
|
response.Forbidden(c, "只有群主可以设置管理员")
|
||||||
@@ -1499,7 +1496,7 @@ func (h *GroupHandler) HandleSetGroupAdmin(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupName 设置群名
|
// HandleSetGroupName 设置群名
|
||||||
// POST /api/v1/groups/set_group_name
|
// PUT /api/v1/groups/:id/name
|
||||||
func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1507,16 +1504,18 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetGroupNameParams
|
var params dto.SetGroupNameParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.GroupName == "" {
|
if params.GroupName == "" {
|
||||||
response.BadRequest(c, "group_name is required")
|
response.BadRequest(c, "group_name is required")
|
||||||
return
|
return
|
||||||
@@ -1526,7 +1525,7 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
|
|||||||
"name": params.GroupName,
|
"name": params.GroupName,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.groupService.UpdateGroup(userID, params.GroupID, updates)
|
err := h.groupService.UpdateGroup(userID, groupID, updates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupAdmin {
|
if err == service.ErrNotGroupAdmin {
|
||||||
response.Forbidden(c, "没有权限修改群组信息")
|
response.Forbidden(c, "没有权限修改群组信息")
|
||||||
@@ -1541,12 +1540,12 @@ func (h *GroupHandler) HandleSetGroupName(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取更新后的群组信息
|
// 获取更新后的群组信息
|
||||||
group, _ := h.groupService.GetGroupByID(params.GroupID)
|
group, _ := h.groupService.GetGroupByID(groupID)
|
||||||
response.Success(c, dto.GroupToResponse(group))
|
response.Success(c, dto.GroupToResponse(group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupAvatar 设置群头像
|
// HandleSetGroupAvatar 设置群头像
|
||||||
// POST /api/v1/groups/set_group_avatar
|
// PUT /api/v1/groups/:id/avatar
|
||||||
func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1554,16 +1553,18 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupID := parseGroupID(c)
|
||||||
|
if groupID == "" {
|
||||||
|
response.BadRequest(c, "group_id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SetGroupAvatarParams
|
var params dto.SetGroupAvatarParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.Avatar == "" {
|
if params.Avatar == "" {
|
||||||
response.BadRequest(c, "avatar is required")
|
response.BadRequest(c, "avatar is required")
|
||||||
return
|
return
|
||||||
@@ -1573,7 +1574,7 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
|
|||||||
"avatar": params.Avatar,
|
"avatar": params.Avatar,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.groupService.UpdateGroup(userID, params.GroupID, updates)
|
err := h.groupService.UpdateGroup(userID, groupID, updates)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupAdmin {
|
if err == service.ErrNotGroupAdmin {
|
||||||
response.Forbidden(c, "没有权限修改群组信息")
|
response.Forbidden(c, "没有权限修改群组信息")
|
||||||
@@ -1588,12 +1589,12 @@ func (h *GroupHandler) HandleSetGroupAvatar(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取更新后的群组信息
|
// 获取更新后的群组信息
|
||||||
group, _ := h.groupService.GetGroupByID(params.GroupID)
|
group, _ := h.groupService.GetGroupByID(groupID)
|
||||||
response.Success(c, dto.GroupToResponse(group))
|
response.Success(c, dto.GroupToResponse(group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupLeave 退出群组
|
// HandleSetGroupLeave 退出群组
|
||||||
// POST /api/v1/groups/set_group_leave
|
// POST /api/v1/groups/:id/leave
|
||||||
func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1601,18 +1602,13 @@ func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var params dto.SetGroupLeaveParams
|
groupID := parseGroupID(c)
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if groupID == "" {
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.GroupID == "" {
|
|
||||||
response.BadRequest(c, "group_id is required")
|
response.BadRequest(c, "group_id is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := h.groupService.LeaveGroup(userID, params.GroupID)
|
err := h.groupService.LeaveGroup(userID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == service.ErrNotGroupMember {
|
if err == service.ErrNotGroupMember {
|
||||||
response.BadRequest(c, "不是群成员")
|
response.BadRequest(c, "不是群成员")
|
||||||
@@ -1630,7 +1626,7 @@ func (h *GroupHandler) HandleSetGroupLeave(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetGroupAddRequest 处理加群审批
|
// HandleSetGroupAddRequest 处理加群审批
|
||||||
// POST /api/v1/groups/set_group_add_request
|
// POST /api/v1/groups/:id/join-requests/handle
|
||||||
func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) {
|
func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1678,7 +1674,7 @@ func (h *GroupHandler) HandleSetGroupAddRequest(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleRespondInvite 处理群邀请响应
|
// HandleRespondInvite 处理群邀请响应
|
||||||
// POST /api/v1/groups/respond_invite
|
// POST /api/v1/groups/:id/join-requests/respond
|
||||||
func (h *GroupHandler) HandleRespondInvite(c *gin.Context) {
|
func (h *GroupHandler) HandleRespondInvite(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -1725,7 +1721,6 @@ func (h *GroupHandler) HandleRespondInvite(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetGroupInfo 获取群信息
|
// HandleGetGroupInfo 获取群信息
|
||||||
// GET /api/v1/groups/get?group_id=xxx
|
|
||||||
// GET /api/v1/groups/:id
|
// GET /api/v1/groups/:id
|
||||||
func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) {
|
func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
@@ -1761,7 +1756,6 @@ func (h *GroupHandler) HandleGetGroupInfo(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetGroupMemberList 获取群成员列表
|
// HandleGetGroupMemberList 获取群成员列表
|
||||||
// GET /api/v1/groups/get_members?group_id=xxx
|
|
||||||
// GET /api/v1/groups/:id/members
|
// GET /api/v1/groups/:id/members
|
||||||
func (h *GroupHandler) HandleGetGroupMemberList(c *gin.Context) {
|
func (h *GroupHandler) HandleGetGroupMemberList(c *gin.Context) {
|
||||||
userID := parseUserID(c)
|
userID := parseUserID(c)
|
||||||
|
|||||||
@@ -116,14 +116,14 @@ func (h *MessageHandler) HandleTyping(c *gin.Context) {
|
|||||||
response.Unauthorized(c, "")
|
response.Unauthorized(c, "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var params struct {
|
|
||||||
ConversationID string `json:"conversation_id" binding:"required"`
|
conversationID := getIDParam(c, "id")
|
||||||
}
|
if conversationID == "" {
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
response.BadRequest(c, "conversation id is required")
|
||||||
response.BadRequest(c, err.Error())
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID)
|
|
||||||
|
h.chatService.SendTyping(c.Request.Context(), userID, conversationID)
|
||||||
response.SuccessWithMessage(c, "typing sent", nil)
|
response.SuccessWithMessage(c, "typing sent", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -397,8 +397,8 @@ func (h *MessageHandler) SendMessage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSendMessage RESTful 风格的发送消息端点
|
// HandleSendMessage RESTful 风格的发送消息端点
|
||||||
// POST /api/v1/conversations/send_message
|
// POST /api/v1/conversations/:id/messages
|
||||||
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
|
// 请求体格式: {"detail_type": "private", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
|
||||||
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
||||||
userID := c.GetString("user_id")
|
userID := c.GetString("user_id")
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -406,6 +406,12 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conversationID := getIDParam(c, "id")
|
||||||
|
if conversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var params dto.SendMessageParams
|
var params dto.SendMessageParams
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
@@ -413,10 +419,6 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证参数
|
// 验证参数
|
||||||
if params.ConversationID == "" {
|
|
||||||
response.BadRequest(c, "conversation_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if params.DetailType == "" {
|
if params.DetailType == "" {
|
||||||
response.BadRequest(c, "detail_type is required")
|
response.BadRequest(c, "detail_type is required")
|
||||||
return
|
return
|
||||||
@@ -427,7 +429,7 @@ func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 发送消息
|
// 发送消息
|
||||||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
|
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, params.Segments, params.ReplyToID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
@@ -480,7 +482,7 @@ func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetConversationList 获取会话列表
|
// HandleGetConversationList 获取会话列表
|
||||||
// GET /api/v1/conversations/list
|
// GET /api/v1/conversations
|
||||||
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
|
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
|
||||||
userID := c.GetString("user_id")
|
userID := c.GetString("user_id")
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -780,7 +782,6 @@ func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetConversation 获取会话详情
|
// HandleGetConversation 获取会话详情
|
||||||
// GET /api/v1/conversations/get?conversation_id=xxx
|
|
||||||
// GET /api/v1/conversations/:id
|
// GET /api/v1/conversations/:id
|
||||||
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
||||||
userID := c.GetString("user_id")
|
userID := c.GetString("user_id")
|
||||||
@@ -825,7 +826,6 @@ func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleGetMessages 获取会话消息
|
// HandleGetMessages 获取会话消息
|
||||||
// GET /api/v1/conversations/get_messages?conversation_id=xxx
|
|
||||||
// GET /api/v1/conversations/:id/messages
|
// GET /api/v1/conversations/:id/messages
|
||||||
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
||||||
userID := c.GetString("user_id")
|
userID := c.GetString("user_id")
|
||||||
@@ -913,7 +913,7 @@ func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleMarkRead 标记已读
|
// HandleMarkRead 标记已读
|
||||||
// POST /api/v1/conversations/mark_read
|
// POST /api/v1/conversations/:id/read
|
||||||
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
||||||
userID := c.GetString("user_id")
|
userID := c.GetString("user_id")
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -921,18 +921,19 @@ func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var params dto.MarkReadParams
|
conversationID := getIDParam(c, "id")
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if conversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req dto.MarkReadRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.ConversationID == "" {
|
err := h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
|
||||||
response.BadRequest(c, "conversation_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
@@ -942,7 +943,7 @@ func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HandleSetConversationPinned 设置会话置顶
|
// HandleSetConversationPinned 设置会话置顶
|
||||||
// POST /api/v1/conversations/set_pinned
|
// PUT /api/v1/conversations/:id/pinned
|
||||||
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
||||||
userID := c.GetString("user_id")
|
userID := c.GetString("user_id")
|
||||||
if userID == "" {
|
if userID == "" {
|
||||||
@@ -950,24 +951,27 @@ func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var params dto.SetConversationPinnedParams
|
conversationID := getIDParam(c, "id")
|
||||||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
if conversationID == "" {
|
||||||
|
response.BadRequest(c, "conversation id is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
IsPinned bool `json:"is_pinned"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.ConversationID == "" {
|
if err := h.chatService.SetConversationPinned(c.Request.Context(), conversationID, userID, req.IsPinned); err != nil {
|
||||||
response.BadRequest(c, "conversation_id is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
|
|
||||||
response.BadRequest(c, err.Error())
|
response.BadRequest(c, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
|
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
|
||||||
"conversation_id": params.ConversationID,
|
"conversation_id": conversationID,
|
||||||
"is_pinned": params.IsPinned,
|
"is_pinned": req.IsPinned,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
140
internal/handler/schedule_handler.go
Normal file
140
internal/handler/schedule_handler.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/pkg/response"
|
||||||
|
"carrot_bbs/internal/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ScheduleHandler struct {
|
||||||
|
scheduleService service.ScheduleService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScheduleHandler(scheduleService service.ScheduleService) *ScheduleHandler {
|
||||||
|
return &ScheduleHandler{scheduleService: scheduleService}
|
||||||
|
}
|
||||||
|
|
||||||
|
type createScheduleCourseRequest struct {
|
||||||
|
Name string `json:"name" binding:"required"`
|
||||||
|
Teacher string `json:"teacher"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
DayOfWeek int `json:"day_of_week" binding:"required"`
|
||||||
|
StartSection int `json:"start_section" binding:"required"`
|
||||||
|
EndSection int `json:"end_section" binding:"required"`
|
||||||
|
Weeks []int `json:"weeks" binding:"required,min=1"`
|
||||||
|
Color string `json:"color"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateScheduleCourseRequest = createScheduleCourseRequest
|
||||||
|
|
||||||
|
func (h *ScheduleHandler) ListCourses(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
week := 0
|
||||||
|
if rawWeek := c.Query("week"); rawWeek != "" {
|
||||||
|
parsed, err := strconv.Atoi(rawWeek)
|
||||||
|
if err != nil {
|
||||||
|
response.BadRequest(c, "invalid week")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
week = parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := h.scheduleService.ListCourses(userID, week)
|
||||||
|
if err != nil {
|
||||||
|
response.HandleError(c, err, "failed to list schedule courses")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.Success(c, gin.H{"list": list})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ScheduleHandler) CreateCourse(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req createScheduleCourseRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := h.scheduleService.CreateCourse(userID, service.CreateScheduleCourseInput{
|
||||||
|
Name: req.Name,
|
||||||
|
Teacher: req.Teacher,
|
||||||
|
Location: req.Location,
|
||||||
|
DayOfWeek: req.DayOfWeek,
|
||||||
|
StartSection: req.StartSection,
|
||||||
|
EndSection: req.EndSection,
|
||||||
|
Weeks: req.Weeks,
|
||||||
|
Color: req.Color,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.HandleError(c, err, "failed to create schedule course")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.SuccessWithMessage(c, "course created", gin.H{"course": created})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ScheduleHandler) UpdateCourse(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
courseID := c.Param("id")
|
||||||
|
if courseID == "" {
|
||||||
|
response.BadRequest(c, "invalid course id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req updateScheduleCourseRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
response.BadRequest(c, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := h.scheduleService.UpdateCourse(userID, courseID, service.CreateScheduleCourseInput{
|
||||||
|
Name: req.Name,
|
||||||
|
Teacher: req.Teacher,
|
||||||
|
Location: req.Location,
|
||||||
|
DayOfWeek: req.DayOfWeek,
|
||||||
|
StartSection: req.StartSection,
|
||||||
|
EndSection: req.EndSection,
|
||||||
|
Weeks: req.Weeks,
|
||||||
|
Color: req.Color,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
response.HandleError(c, err, "failed to update schedule course")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.SuccessWithMessage(c, "course updated", gin.H{"course": updated})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ScheduleHandler) DeleteCourse(c *gin.Context) {
|
||||||
|
userID := c.GetString("user_id")
|
||||||
|
if userID == "" {
|
||||||
|
response.Unauthorized(c, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
courseID := c.Param("id")
|
||||||
|
if courseID == "" {
|
||||||
|
response.BadRequest(c, "invalid course id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.scheduleService.DeleteCourse(userID, courseID); err != nil {
|
||||||
|
response.HandleError(c, err, "failed to delete schedule course")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
response.SuccessWithMessage(c, "course deleted", nil)
|
||||||
|
}
|
||||||
@@ -143,6 +143,9 @@ func autoMigrate(db *gorm.DB) error {
|
|||||||
|
|
||||||
// 自定义表情
|
// 自定义表情
|
||||||
&UserSticker{},
|
&UserSticker{},
|
||||||
|
|
||||||
|
// 课表
|
||||||
|
&ScheduleCourse{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
35
internal/model/schedule_course.go
Normal file
35
internal/model/schedule_course.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScheduleCourse 用户课表课程
|
||||||
|
type ScheduleCourse struct {
|
||||||
|
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||||
|
UserID string `json:"user_id" gorm:"type:varchar(36);index;not null"`
|
||||||
|
Name string `json:"name" gorm:"type:varchar(120);not null"`
|
||||||
|
Teacher string `json:"teacher" gorm:"type:varchar(80)"`
|
||||||
|
Location string `json:"location" gorm:"type:varchar(120)"`
|
||||||
|
DayOfWeek int `json:"day_of_week" gorm:"index;not null"` // 0=周一, 6=周日
|
||||||
|
StartSection int `json:"start_section" gorm:"not null"`
|
||||||
|
EndSection int `json:"end_section" gorm:"not null"`
|
||||||
|
Weeks string `json:"weeks" gorm:"type:text;not null"` // JSON 数组字符串
|
||||||
|
Color string `json:"color" gorm:"type:varchar(20)"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScheduleCourse) BeforeCreate(tx *gorm.DB) error {
|
||||||
|
if s.ID == "" {
|
||||||
|
s.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ScheduleCourse) TableName() string {
|
||||||
|
return "schedule_courses"
|
||||||
|
}
|
||||||
@@ -164,10 +164,17 @@ func (c *clientImpl) moderateSingleBatch(
|
|||||||
}
|
}
|
||||||
|
|
||||||
type chatCompletionsRequest struct {
|
type chatCompletionsRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []chatMessage `json:"messages"`
|
Messages []chatMessage `json:"messages"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen3.5思考模式控制
|
||||||
|
ThinkingBudget *int `json:"thinking_budget,omitempty"` // 思考过程最大token数
|
||||||
|
ResponseFormat *responseFormatConfig `json:"response_format,omitempty"` // 响应格式
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseFormatConfig struct {
|
||||||
|
Type string `json:"type"` // "text" or "json_object"
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatMessage struct {
|
type chatMessage struct {
|
||||||
@@ -227,6 +234,13 @@ func (c *clientImpl) chatCompletion(
|
|||||||
Temperature: temperature,
|
Temperature: temperature,
|
||||||
MaxTokens: maxTokens,
|
MaxTokens: maxTokens,
|
||||||
}
|
}
|
||||||
|
// 禁用qwen3.5的思考模式,避免产生大量不必要的token消耗
|
||||||
|
falseVal := false
|
||||||
|
reqBody.EnableThinking = &falseVal
|
||||||
|
zero := 0
|
||||||
|
reqBody.ThinkingBudget = &zero
|
||||||
|
// 使用JSON输出格式
|
||||||
|
reqBody.ResponseFormat = &responseFormatConfig{Type: "json_object"}
|
||||||
|
|
||||||
data, err := json.Marshal(reqBody)
|
data, err := json.Marshal(reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -117,3 +117,117 @@ func (c *Client) Close() error {
|
|||||||
func (c *Client) IsMiniRedis() bool {
|
func (c *Client) IsMiniRedis() bool {
|
||||||
return c.isMiniRedis
|
return c.isMiniRedis
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==================== Hash 操作 ====================
|
||||||
|
|
||||||
|
// HSet 设置 Hash 字段
|
||||||
|
func (c *Client) HSet(ctx context.Context, key string, field string, value interface{}) error {
|
||||||
|
return c.rdb.HSet(ctx, key, field, value).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HMSet 批量设置 Hash 字段
|
||||||
|
func (c *Client) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
|
||||||
|
return c.rdb.HMSet(ctx, key, values).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HGet 获取 Hash 字段值
|
||||||
|
func (c *Client) HGet(ctx context.Context, key string, field string) (string, error) {
|
||||||
|
return c.rdb.HGet(ctx, key, field).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HMGet 批量获取 Hash 字段值
|
||||||
|
func (c *Client) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
|
||||||
|
return c.rdb.HMGet(ctx, key, fields...).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HGetAll 获取 Hash 所有字段
|
||||||
|
func (c *Client) HGetAll(ctx context.Context, key string) (map[string]string, error) {
|
||||||
|
return c.rdb.HGetAll(ctx, key).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HDel 删除 Hash 字段
|
||||||
|
func (c *Client) HDel(ctx context.Context, key string, fields ...string) error {
|
||||||
|
return c.rdb.HDel(ctx, key, fields...).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HExists 检查 Hash 字段是否存在
|
||||||
|
func (c *Client) HExists(ctx context.Context, key string, field string) (bool, error) {
|
||||||
|
return c.rdb.HExists(ctx, key, field).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HLen 获取 Hash 字段数量
|
||||||
|
func (c *Client) HLen(ctx context.Context, key string) (int64, error) {
|
||||||
|
return c.rdb.HLen(ctx, key).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Sorted Set 操作 ====================
|
||||||
|
|
||||||
|
// ZAdd 添加 Sorted Set 成员
|
||||||
|
func (c *Client) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||||||
|
return c.rdb.ZAdd(ctx, key, redis.Z{Score: score, Member: member}).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZAddArgs 批量添加 Sorted Set 成员
|
||||||
|
func (c *Client) ZAddArgs(ctx context.Context, key string, members ...redis.Z) error {
|
||||||
|
return c.rdb.ZAdd(ctx, key, members...).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRangeByScore 按分数范围获取成员(升序)
|
||||||
|
func (c *Client) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
|
||||||
|
return c.rdb.ZRangeByScore(ctx, key, &redis.ZRangeBy{
|
||||||
|
Min: min,
|
||||||
|
Max: max,
|
||||||
|
Offset: offset,
|
||||||
|
Count: count,
|
||||||
|
}).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRevRangeByScore 按分数范围获取成员(降序)
|
||||||
|
func (c *Client) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
|
||||||
|
return c.rdb.ZRevRangeByScore(ctx, key, &redis.ZRangeBy{
|
||||||
|
Min: min,
|
||||||
|
Max: max,
|
||||||
|
Offset: offset,
|
||||||
|
Count: count,
|
||||||
|
}).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRange 获取指定范围的成员(升序)
|
||||||
|
func (c *Client) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
|
||||||
|
return c.rdb.ZRange(ctx, key, start, stop).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRevRange 获取指定范围的成员(降序)
|
||||||
|
func (c *Client) ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) {
|
||||||
|
return c.rdb.ZRevRange(ctx, key, start, stop).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRem 删除 Sorted Set 成员
|
||||||
|
func (c *Client) ZRem(ctx context.Context, key string, members ...interface{}) error {
|
||||||
|
return c.rdb.ZRem(ctx, key, members...).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZScore 获取成员分数
|
||||||
|
func (c *Client) ZScore(ctx context.Context, key string, member string) (float64, error) {
|
||||||
|
return c.rdb.ZScore(ctx, key, member).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZCard 获取 Sorted Set 成员数量
|
||||||
|
func (c *Client) ZCard(ctx context.Context, key string) (int64, error) {
|
||||||
|
return c.rdb.ZCard(ctx, key).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZCount 统计分数范围内的成员数量
|
||||||
|
func (c *Client) ZCount(ctx context.Context, key string, min, max string) (int64, error) {
|
||||||
|
return c.rdb.ZCount(ctx, key, min, max).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== Pipeline 操作 ====================
|
||||||
|
|
||||||
|
// Pipeliner Pipeline 接口(使用 redis 库原生接口)
|
||||||
|
type Pipeliner = redis.Pipeliner
|
||||||
|
|
||||||
|
// Pipeline 创建 Pipeline
|
||||||
|
func (c *Client) Pipeline() Pipeliner {
|
||||||
|
return c.rdb.Pipeline()
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -172,7 +175,7 @@ func (r *MessageRepository) GetParticipant(conversationID string, userID string)
|
|||||||
if err == gorm.ErrRecordNotFound {
|
if err == gorm.ErrRecordNotFound {
|
||||||
// 检查会话是否存在
|
// 检查会话是否存在
|
||||||
var conv model.Conversation
|
var conv model.Conversation
|
||||||
if err := r.db.First(&conv, conversationID).Error; err == nil {
|
if err := r.db.Where("id = ?", conversationID).First(&conv).Error; err == nil {
|
||||||
// 会话存在,添加参与者
|
// 会话存在,添加参与者
|
||||||
participant = model.ConversationParticipant{
|
participant = model.ConversationParticipant{
|
||||||
ConversationID: conversationID,
|
ConversationID: conversationID,
|
||||||
@@ -284,7 +287,7 @@ func (r *MessageRepository) UpdateConversationLastSeq(conversationID string, seq
|
|||||||
// GetNextSeq 获取会话的下一个seq值
|
// GetNextSeq 获取会话的下一个seq值
|
||||||
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
|
func (r *MessageRepository) GetNextSeq(conversationID string) (int64, error) {
|
||||||
var conv model.Conversation
|
var conv model.Conversation
|
||||||
err := r.db.Select("last_seq").First(&conv, conversationID).Error
|
err := r.db.Select("last_seq").Where("id = ?", conversationID).First(&conv).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -296,7 +299,7 @@ func (r *MessageRepository) CreateMessageWithSeq(msg *model.Message) error {
|
|||||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||||
// 获取当前seq并+1
|
// 获取当前seq并+1
|
||||||
var conv model.Conversation
|
var conv model.Conversation
|
||||||
if err := tx.Select("last_seq").First(&conv, msg.ConversationID).Error; err != nil {
|
if err := tx.Select("last_seq").Where("id = ?", msg.ConversationID).First(&conv).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -522,3 +525,117 @@ func (r *MessageRepository) HideConversationForUser(conversationID, userID strin
|
|||||||
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
Where("conversation_id = ? AND user_id = ?", conversationID, userID).
|
||||||
Update("hidden_at", &now).Error
|
Update("hidden_at", &now).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParticipantUpdate 参与者更新数据
|
||||||
|
type ParticipantUpdate struct {
|
||||||
|
ConversationID string
|
||||||
|
UserID string
|
||||||
|
LastReadSeq int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchWriteMessages 批量写入消息
|
||||||
|
// 使用 GORM 的 CreateInBatches 实现高效批量插入
|
||||||
|
func (r *MessageRepository) BatchWriteMessages(ctx context.Context, messages []*model.Message) error {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.db.WithContext(ctx).CreateInBatches(messages, 100).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUpdateParticipants 批量更新参与者(使用 CASE WHEN 优化)
|
||||||
|
// 使用单条 SQL 更新多条记录,避免循环执行 UPDATE
|
||||||
|
func (r *MessageRepository) BatchUpdateParticipants(ctx context.Context, updates []ParticipantUpdate) error {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 CASE WHEN 批量更新 SQL
|
||||||
|
// UPDATE conversation_participants
|
||||||
|
// SET last_read_seq = CASE
|
||||||
|
// WHEN (conversation_id = '1' AND user_id = 'a') THEN 10
|
||||||
|
// WHEN (conversation_id = '2' AND user_id = 'b') THEN 20
|
||||||
|
// END,
|
||||||
|
// updated_at = ?
|
||||||
|
// WHERE (conversation_id = '1' AND user_id = 'a')
|
||||||
|
// OR (conversation_id = '2' AND user_id = 'b')
|
||||||
|
|
||||||
|
var cases []string
|
||||||
|
var whereClauses []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
for _, u := range updates {
|
||||||
|
cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?")
|
||||||
|
whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)")
|
||||||
|
args = append(args, u.ConversationID, u.UserID, u.LastReadSeq, u.ConversationID, u.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql := fmt.Sprintf(`
|
||||||
|
UPDATE conversation_participants
|
||||||
|
SET last_read_seq = CASE %s END,
|
||||||
|
updated_at = ?
|
||||||
|
WHERE %s
|
||||||
|
`, strings.Join(cases, " "), strings.Join(whereClauses, " OR "))
|
||||||
|
|
||||||
|
args = append(args, time.Now())
|
||||||
|
|
||||||
|
return r.db.WithContext(ctx).Exec(sql, args...).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConversationLastSeqWithContext 更新会话最后消息序号
|
||||||
|
func (r *MessageRepository) UpdateConversationLastSeqWithContext(ctx context.Context, convID string, lastSeq int64, lastMsgTime time.Time) error {
|
||||||
|
return r.db.WithContext(ctx).
|
||||||
|
Model(&model.Conversation{}).
|
||||||
|
Where("id = ?", convID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"last_seq": lastSeq,
|
||||||
|
"last_msg_time": lastMsgTime,
|
||||||
|
"updated_at": time.Now(),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchWriteMessagesWithTx 在事务中批量写入消息
|
||||||
|
func (r *MessageRepository) BatchWriteMessagesWithTx(tx *gorm.DB, messages []*model.Message) error {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return tx.CreateInBatches(messages, 100).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchUpdateParticipantsWithTx 在事务中批量更新参与者
|
||||||
|
func (r *MessageRepository) BatchUpdateParticipantsWithTx(tx *gorm.DB, updates []ParticipantUpdate) error {
|
||||||
|
if len(updates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cases []string
|
||||||
|
var whereClauses []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
for _, u := range updates {
|
||||||
|
cases = append(cases, "WHEN (conversation_id = ? AND user_id = ?) THEN ?")
|
||||||
|
whereClauses = append(whereClauses, "(conversation_id = ? AND user_id = ?)")
|
||||||
|
args = append(args, u.ConversationID, u.UserID, u.LastReadSeq, u.ConversationID, u.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
sql := fmt.Sprintf(`
|
||||||
|
UPDATE conversation_participants
|
||||||
|
SET last_read_seq = CASE %s END,
|
||||||
|
updated_at = ?
|
||||||
|
WHERE %s
|
||||||
|
`, strings.Join(cases, " "), strings.Join(whereClauses, " OR "))
|
||||||
|
|
||||||
|
args = append(args, time.Now())
|
||||||
|
|
||||||
|
return tx.Exec(sql, args...).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateConversationLastSeqWithTx 在事务中更新会话最后消息序号
|
||||||
|
func (r *MessageRepository) UpdateConversationLastSeqWithTx(tx *gorm.DB, convID string, lastSeq int64, lastMsgTime time.Time) error {
|
||||||
|
return tx.Model(&model.Conversation{}).
|
||||||
|
Where("id = ?", convID).
|
||||||
|
Updates(map[string]interface{}{
|
||||||
|
"last_seq": lastSeq,
|
||||||
|
"last_msg_time": lastMsgTime,
|
||||||
|
"updated_at": time.Now(),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|||||||
66
internal/repository/schedule_repo.go
Normal file
66
internal/repository/schedule_repo.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ScheduleRepository interface {
|
||||||
|
ListByUserID(userID string) ([]*model.ScheduleCourse, error)
|
||||||
|
GetByID(id string) (*model.ScheduleCourse, error)
|
||||||
|
Create(course *model.ScheduleCourse) error
|
||||||
|
Update(course *model.ScheduleCourse) error
|
||||||
|
DeleteByID(id string) error
|
||||||
|
ExistsColorByUser(userID, color, excludeID string) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type scheduleRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScheduleRepository(db *gorm.DB) ScheduleRepository {
|
||||||
|
return &scheduleRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduleRepository) ListByUserID(userID string) ([]*model.ScheduleCourse, error) {
|
||||||
|
var courses []*model.ScheduleCourse
|
||||||
|
err := r.db.
|
||||||
|
Where("user_id = ?", userID).
|
||||||
|
Order("day_of_week ASC, start_section ASC, created_at ASC").
|
||||||
|
Find(&courses).Error
|
||||||
|
return courses, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduleRepository) Create(course *model.ScheduleCourse) error {
|
||||||
|
return r.db.Create(course).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduleRepository) GetByID(id string) (*model.ScheduleCourse, error) {
|
||||||
|
var course model.ScheduleCourse
|
||||||
|
if err := r.db.Where("id = ?", id).First(&course).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &course, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduleRepository) Update(course *model.ScheduleCourse) error {
|
||||||
|
return r.db.Save(course).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduleRepository) DeleteByID(id string) error {
|
||||||
|
return r.db.Delete(&model.ScheduleCourse{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *scheduleRepository) ExistsColorByUser(userID, color, excludeID string) (bool, error) {
|
||||||
|
var count int64
|
||||||
|
query := r.db.Model(&model.ScheduleCourse{}).
|
||||||
|
Where("user_id = ? AND LOWER(color) = LOWER(?)", userID, color)
|
||||||
|
if excludeID != "" {
|
||||||
|
query = query.Where("id <> ?", excludeID)
|
||||||
|
}
|
||||||
|
if err := query.Count(&count).Error; err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
@@ -23,6 +23,7 @@ type Router struct {
|
|||||||
stickerHandler *handler.StickerHandler
|
stickerHandler *handler.StickerHandler
|
||||||
gorseHandler *handler.GorseHandler
|
gorseHandler *handler.GorseHandler
|
||||||
voteHandler *handler.VoteHandler
|
voteHandler *handler.VoteHandler
|
||||||
|
scheduleHandler *handler.ScheduleHandler
|
||||||
jwtService *service.JWTService
|
jwtService *service.JWTService
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,6 +42,7 @@ func New(
|
|||||||
stickerHandler *handler.StickerHandler,
|
stickerHandler *handler.StickerHandler,
|
||||||
gorseHandler *handler.GorseHandler,
|
gorseHandler *handler.GorseHandler,
|
||||||
voteHandler *handler.VoteHandler,
|
voteHandler *handler.VoteHandler,
|
||||||
|
scheduleHandler *handler.ScheduleHandler,
|
||||||
) *Router {
|
) *Router {
|
||||||
// 设置JWT服务
|
// 设置JWT服务
|
||||||
userHandler.SetJWTService(jwtService)
|
userHandler.SetJWTService(jwtService)
|
||||||
@@ -59,6 +61,7 @@ func New(
|
|||||||
stickerHandler: stickerHandler,
|
stickerHandler: stickerHandler,
|
||||||
gorseHandler: gorseHandler,
|
gorseHandler: gorseHandler,
|
||||||
voteHandler: voteHandler,
|
voteHandler: voteHandler,
|
||||||
|
scheduleHandler: scheduleHandler,
|
||||||
jwtService: jwtService,
|
jwtService: jwtService,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,6 +163,18 @@ func (r *Router) setupRoutes() {
|
|||||||
posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票
|
posts.DELETE("/:id/vote", authMiddleware, r.voteHandler.Unvote) // 取消投票
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 课表路由
|
||||||
|
if r.scheduleHandler != nil {
|
||||||
|
schedule := v1.Group("/schedule")
|
||||||
|
schedule.Use(authMiddleware)
|
||||||
|
{
|
||||||
|
schedule.GET("/courses", r.scheduleHandler.ListCourses)
|
||||||
|
schedule.POST("/courses", r.scheduleHandler.CreateCourse)
|
||||||
|
schedule.PUT("/courses/:id", r.scheduleHandler.UpdateCourse)
|
||||||
|
schedule.DELETE("/courses/:id", r.scheduleHandler.DeleteCourse)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 投票选项路由
|
// 投票选项路由
|
||||||
voteOptions := v1.Group("/vote-options")
|
voteOptions := v1.Group("/vote-options")
|
||||||
voteOptions.Use(authMiddleware)
|
voteOptions.Use(authMiddleware)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/cache"
|
||||||
"carrot_bbs/internal/dto"
|
"carrot_bbs/internal/dto"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
"carrot_bbs/internal/pkg/sse"
|
"carrot_bbs/internal/pkg/sse"
|
||||||
@@ -58,6 +60,9 @@ type chatServiceImpl struct {
|
|||||||
userRepo *repository.UserRepository
|
userRepo *repository.UserRepository
|
||||||
sensitive SensitiveService
|
sensitive SensitiveService
|
||||||
sseHub *sse.Hub
|
sseHub *sse.Hub
|
||||||
|
|
||||||
|
// 缓存相关字段
|
||||||
|
conversationCache *cache.ConversationCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChatService 创建聊天服务
|
// NewChatService 创建聊天服务
|
||||||
@@ -68,12 +73,25 @@ func NewChatService(
|
|||||||
sensitive SensitiveService,
|
sensitive SensitiveService,
|
||||||
sseHub *sse.Hub,
|
sseHub *sse.Hub,
|
||||||
) ChatService {
|
) ChatService {
|
||||||
|
// 创建适配器
|
||||||
|
convRepoAdapter := cache.NewConversationRepositoryAdapter(repo)
|
||||||
|
msgRepoAdapter := cache.NewMessageRepositoryAdapter(repo)
|
||||||
|
|
||||||
|
// 创建会话缓存
|
||||||
|
conversationCache := cache.NewConversationCache(
|
||||||
|
cache.GetCache(),
|
||||||
|
convRepoAdapter,
|
||||||
|
msgRepoAdapter,
|
||||||
|
cache.DefaultConversationCacheSettings(),
|
||||||
|
)
|
||||||
|
|
||||||
return &chatServiceImpl{
|
return &chatServiceImpl{
|
||||||
db: db,
|
db: db,
|
||||||
repo: repo,
|
repo: repo,
|
||||||
userRepo: userRepo,
|
userRepo: userRepo,
|
||||||
sensitive: sensitive,
|
sensitive: sensitive,
|
||||||
sseHub: sseHub,
|
sseHub: sseHub,
|
||||||
|
conversationCache: conversationCache,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -86,18 +104,33 @@ func (s *chatServiceImpl) publishSSEToUsers(userIDs []string, event string, payl
|
|||||||
|
|
||||||
// GetOrCreateConversation 获取或创建私聊会话
|
// GetOrCreateConversation 获取或创建私聊会话
|
||||||
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
func (s *chatServiceImpl) GetOrCreateConversation(ctx context.Context, user1ID, user2ID string) (*model.Conversation, error) {
|
||||||
return s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
conv, err := s.repo.GetOrCreatePrivateConversation(user1ID, user2ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 失效会话列表缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateConversationList(user1ID)
|
||||||
|
s.conversationCache.InvalidateConversationList(user2ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConversationList 获取用户的会话列表
|
// GetConversationList 获取用户的会话列表(带缓存)
|
||||||
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
func (s *chatServiceImpl) GetConversationList(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
// 优先使用缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetConversationList(ctx, userID, page, pageSize)
|
||||||
|
}
|
||||||
return s.repo.GetConversations(userID, page, pageSize)
|
return s.repo.GetConversations(userID, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConversationByID 获取会话详情
|
// GetConversationByID 获取会话详情(带缓存)
|
||||||
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
|
func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationID string, userID string) (*model.Conversation, error) {
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
participant, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("conversation not found or no permission")
|
return nil, errors.New("conversation not found or no permission")
|
||||||
@@ -105,21 +138,33 @@ func (s *chatServiceImpl) GetConversationByID(ctx context.Context, conversationI
|
|||||||
return nil, fmt.Errorf("failed to get participant: %w", err)
|
return nil, fmt.Errorf("failed to get participant: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取会话信息
|
// 获取会话信息(优先使用缓存)
|
||||||
conv, err := s.repo.GetConversation(conversationID)
|
var conv *model.Conversation
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
conv, err = s.conversationCache.GetConversation(ctx, conversationID)
|
||||||
|
} else {
|
||||||
|
conv, err = s.repo.GetConversation(conversationID)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 填充用户的已读位置信息
|
|
||||||
_ = participant // 可以用于返回已读位置等信息
|
_ = participant // 可以用于返回已读位置等信息
|
||||||
|
|
||||||
return conv, nil
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getParticipant 获取参与者信息(优先使用缓存)
|
||||||
|
func (s *chatServiceImpl) getParticipant(ctx context.Context, conversationID, userID string) (*model.ConversationParticipant, error) {
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetParticipant(ctx, conversationID, userID)
|
||||||
|
}
|
||||||
|
return s.repo.GetParticipant(conversationID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteConversationForSelf 仅自己删除会话
|
// DeleteConversationForSelf 仅自己删除会话
|
||||||
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
|
func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, conversationID string, userID string) error {
|
||||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
participant, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return errors.New("conversation not found or no permission")
|
return errors.New("conversation not found or no permission")
|
||||||
@@ -133,12 +178,18 @@ func (s *chatServiceImpl) DeleteConversationForSelf(ctx context.Context, convers
|
|||||||
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
|
if err := s.repo.HideConversationForUser(conversationID, userID); err != nil {
|
||||||
return fmt.Errorf("failed to hide conversation: %w", err)
|
return fmt.Errorf("failed to hide conversation: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 失效会话列表缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateConversationList(userID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetConversationPinned 设置会话置顶(用户维度)
|
// SetConversationPinned 设置会话置顶(用户维度)
|
||||||
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
|
func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversationID string, userID string, isPinned bool) error {
|
||||||
participant, err := s.repo.GetParticipant(conversationID, userID)
|
participant, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return errors.New("conversation not found or no permission")
|
return errors.New("conversation not found or no permission")
|
||||||
@@ -152,13 +203,20 @@ func (s *chatServiceImpl) SetConversationPinned(ctx context.Context, conversatio
|
|||||||
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
|
if err := s.repo.UpdatePinned(conversationID, userID, isPinned); err != nil {
|
||||||
return fmt.Errorf("failed to update pinned status: %w", err)
|
return fmt.Errorf("failed to update pinned status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 失效缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateParticipant(conversationID, userID)
|
||||||
|
s.conversationCache.InvalidateConversationList(userID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessage 发送消息(使用 segments)
|
// SendMessage 发送消息(使用 segments)
|
||||||
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||||
// 首先验证会话是否存在
|
// 首先验证会话是否存在
|
||||||
conv, err := s.repo.GetConversation(conversationID)
|
conv, err := s.getConversation(ctx, conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("会话不存在,请重新创建会话")
|
return nil, errors.New("会话不存在,请重新创建会话")
|
||||||
@@ -166,9 +224,9 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
|||||||
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
return nil, fmt.Errorf("failed to get conversation: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 拉黑限制:仅拦截“被拉黑方 -> 拉黑人”方向
|
// 拉黑限制:仅拦截"被拉黑方 -> 拉黑人"方向
|
||||||
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
|
if conv.Type == model.ConversationTypePrivate && s.userRepo != nil {
|
||||||
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
participants, pErr := s.getParticipants(ctx, conversationID)
|
||||||
if pErr != nil {
|
if pErr != nil {
|
||||||
return nil, fmt.Errorf("failed to get participants: %w", pErr)
|
return nil, fmt.Errorf("failed to get participants: %w", pErr)
|
||||||
}
|
}
|
||||||
@@ -209,7 +267,7 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
participant, err := s.repo.GetParticipant(conversationID, senderID)
|
participant, err := s.getParticipant(ctx, conversationID, senderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("您不是该会话的参与者")
|
return nil, errors.New("您不是该会话的参与者")
|
||||||
@@ -231,11 +289,27 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
|||||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 新消息会改变分页结果,先失效分页缓存,避免读到旧列表
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateMessagePages(conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步写入缓存
|
||||||
|
go func() {
|
||||||
|
if err := s.cacheMessage(context.Background(), conversationID, message); err != nil {
|
||||||
|
log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 获取会话中的参与者并发送 SSE
|
// 获取会话中的参与者并发送 SSE
|
||||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
participants, err := s.getParticipants(ctx, conversationID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
targetIDs := make([]string, 0, len(participants))
|
targetIDs := make([]string, 0, len(participants))
|
||||||
for _, p := range participants {
|
for _, p := range participants {
|
||||||
|
// 私聊场景下,发送者已经从 HTTP 响应拿到消息,避免再通过 SSE 回推导致本端重复展示。
|
||||||
|
if conv.Type == model.ConversationTypePrivate && p.UserID == senderID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
targetIDs = append(targetIDs, p.UserID)
|
targetIDs = append(targetIDs, p.UserID)
|
||||||
}
|
}
|
||||||
detailType := "private"
|
detailType := "private"
|
||||||
@@ -250,6 +324,10 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
|||||||
if p.UserID == senderID {
|
if p.UserID == senderID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// 失效未读数缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateUnreadCount(p.UserID, conversationID)
|
||||||
|
}
|
||||||
if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
|
if totalUnread, uErr := s.repo.GetAllUnreadCount(p.UserID); uErr == nil {
|
||||||
s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
|
s.publishSSEToUsers([]string{p.UserID}, "conversation_unread", map[string]interface{}{
|
||||||
"conversation_id": conversationID,
|
"conversation_id": conversationID,
|
||||||
@@ -259,11 +337,46 @@ func (s *chatServiceImpl) SendMessage(ctx context.Context, senderID string, conv
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 失效会话列表缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
for _, p := range participants {
|
||||||
|
s.conversationCache.InvalidateConversationList(p.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_ = participant // 避免未使用变量警告
|
_ = participant // 避免未使用变量警告
|
||||||
|
|
||||||
return message, nil
|
return message, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getConversation 获取会话(优先使用缓存)
|
||||||
|
func (s *chatServiceImpl) getConversation(ctx context.Context, conversationID string) (*model.Conversation, error) {
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetConversation(ctx, conversationID)
|
||||||
|
}
|
||||||
|
return s.repo.GetConversation(conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getParticipants 获取会话参与者(优先使用缓存)
|
||||||
|
func (s *chatServiceImpl) getParticipants(ctx context.Context, conversationID string) ([]*model.ConversationParticipant, error) {
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetParticipants(ctx, conversationID)
|
||||||
|
}
|
||||||
|
return s.repo.GetConversationParticipants(conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheMessage 缓存消息(内部方法)
|
||||||
|
func (s *chatServiceImpl) cacheMessage(ctx context.Context, convID string, msg *model.Message) error {
|
||||||
|
if s.conversationCache == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return s.conversationCache.CacheMessage(asyncCtx, convID, msg)
|
||||||
|
}
|
||||||
|
|
||||||
func containsImageSegment(segments model.MessageSegments) bool {
|
func containsImageSegment(segments model.MessageSegments) bool {
|
||||||
for _, seg := range segments {
|
for _, seg := range segments {
|
||||||
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
|
if seg.Type == string(model.ContentTypeImage) || seg.Type == "image" {
|
||||||
@@ -273,10 +386,10 @@ func containsImageSegment(segments model.MessageSegments) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessages 获取消息历史(分页)
|
// GetMessages 获取消息历史(分页,带缓存)
|
||||||
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
|
func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string, userID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
_, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, 0, errors.New("conversation not found or no permission")
|
return nil, 0, errors.New("conversation not found or no permission")
|
||||||
@@ -284,13 +397,18 @@ func (s *chatServiceImpl) GetMessages(ctx context.Context, conversationID string
|
|||||||
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
|
return nil, 0, fmt.Errorf("failed to get participant: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 优先使用缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
return s.repo.GetMessages(conversationID, page, pageSize)
|
return s.repo.GetMessages(conversationID, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
// GetMessagesAfterSeq 获取指定seq之后的消息(用于增量同步)
|
||||||
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationID string, userID string, afterSeq int64, limit int) ([]*model.Message, error) {
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
_, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("conversation not found or no permission")
|
return nil, errors.New("conversation not found or no permission")
|
||||||
@@ -308,7 +426,7 @@ func (s *chatServiceImpl) GetMessagesAfterSeq(ctx context.Context, conversationI
|
|||||||
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
// GetMessagesBeforeSeq 获取指定seq之前的历史消息(用于下拉加载更多)
|
||||||
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversationID string, userID string, beforeSeq int64, limit int) ([]*model.Message, error) {
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
_, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("conversation not found or no permission")
|
return nil, errors.New("conversation not found or no permission")
|
||||||
@@ -326,7 +444,7 @@ func (s *chatServiceImpl) GetMessagesBeforeSeq(ctx context.Context, conversation
|
|||||||
// MarkAsRead 标记已读
|
// MarkAsRead 标记已读
|
||||||
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
|
func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string, userID string, seq int64) error {
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
_, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return errors.New("conversation not found or no permission")
|
return errors.New("conversation not found or no permission")
|
||||||
@@ -334,17 +452,27 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
|
|||||||
return fmt.Errorf("failed to get participant: %w", err)
|
return fmt.Errorf("failed to get participant: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新参与者的已读位置
|
// 1. 先写入DB(保证数据一致性,DB是唯一数据源)
|
||||||
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
|
err = s.repo.UpdateLastReadSeq(conversationID, userID, seq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update last read seq: %w", err)
|
return fmt.Errorf("failed to update last read seq: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
participants, pErr := s.repo.GetConversationParticipants(conversationID)
|
// 2. DB 写入成功后,失效缓存(Cache-Aside 模式)
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
// 失效参与者缓存,下次读取时会从 DB 加载最新数据
|
||||||
|
s.conversationCache.InvalidateParticipant(conversationID, userID)
|
||||||
|
// 失效未读数缓存
|
||||||
|
s.conversationCache.InvalidateUnreadCount(userID, conversationID)
|
||||||
|
// 失效会话列表缓存
|
||||||
|
s.conversationCache.InvalidateConversationList(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
participants, pErr := s.getParticipants(ctx, conversationID)
|
||||||
if pErr == nil {
|
if pErr == nil {
|
||||||
detailType := "private"
|
detailType := "private"
|
||||||
groupID := ""
|
groupID := ""
|
||||||
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||||
detailType = "group"
|
detailType = "group"
|
||||||
if conv.GroupID != nil {
|
if conv.GroupID != nil {
|
||||||
groupID = *conv.GroupID
|
groupID = *conv.GroupID
|
||||||
@@ -372,10 +500,10 @@ func (s *chatServiceImpl) MarkAsRead(ctx context.Context, conversationID string,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUnreadCount 获取指定会话的未读消息数
|
// GetUnreadCount 获取指定会话的未读消息数(带缓存)
|
||||||
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err := s.repo.GetParticipant(conversationID, userID)
|
_, err := s.getParticipant(ctx, conversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return 0, errors.New("conversation not found or no permission")
|
return 0, errors.New("conversation not found or no permission")
|
||||||
@@ -383,6 +511,11 @@ func (s *chatServiceImpl) GetUnreadCount(ctx context.Context, conversationID str
|
|||||||
return 0, fmt.Errorf("failed to get participant: %w", err)
|
return 0, fmt.Errorf("failed to get participant: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 优先使用缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetUnreadCount(ctx, userID, conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
return s.repo.GetUnreadCount(conversationID, userID)
|
return s.repo.GetUnreadCount(conversationID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -427,10 +560,15 @@ func (s *chatServiceImpl) RecallMessage(ctx context.Context, messageID string, u
|
|||||||
return fmt.Errorf("failed to recall message: %w", err)
|
return fmt.Errorf("failed to recall message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if participants, pErr := s.repo.GetConversationParticipants(message.ConversationID); pErr == nil {
|
// 失效消息缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateConversation(message.ConversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if participants, pErr := s.getParticipants(ctx, message.ConversationID); pErr == nil {
|
||||||
detailType := "private"
|
detailType := "private"
|
||||||
groupID := ""
|
groupID := ""
|
||||||
if conv, convErr := s.repo.GetConversation(message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
if conv, convErr := s.getConversation(ctx, message.ConversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||||
detailType = "group"
|
detailType = "group"
|
||||||
if conv.GroupID != nil {
|
if conv.GroupID != nil {
|
||||||
groupID = *conv.GroupID
|
groupID = *conv.GroupID
|
||||||
@@ -465,7 +603,7 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err = s.repo.GetParticipant(message.ConversationID, userID)
|
_, err = s.getParticipant(ctx, message.ConversationID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return errors.New("no permission to delete this message")
|
return errors.New("no permission to delete this message")
|
||||||
@@ -485,6 +623,11 @@ func (s *chatServiceImpl) DeleteMessage(ctx context.Context, messageID string, u
|
|||||||
return fmt.Errorf("failed to delete message: %w", err)
|
return fmt.Errorf("failed to delete message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 失效消息缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateConversation(message.ConversationID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,19 +638,19 @@ func (s *chatServiceImpl) SendTyping(ctx context.Context, senderID string, conve
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err := s.repo.GetParticipant(conversationID, senderID)
|
_, err := s.getParticipant(ctx, conversationID, senderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取会话中的其他参与者
|
// 获取会话中的其他参与者
|
||||||
participants, err := s.repo.GetConversationParticipants(conversationID)
|
participants, err := s.getParticipants(ctx, conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
detailType := "private"
|
detailType := "private"
|
||||||
if conv, convErr := s.repo.GetConversation(conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
if conv, convErr := s.getConversation(ctx, conversationID); convErr == nil && conv.Type == model.ConversationTypeGroup {
|
||||||
detailType = "group"
|
detailType = "group"
|
||||||
}
|
}
|
||||||
for _, p := range participants {
|
for _, p := range participants {
|
||||||
@@ -537,7 +680,7 @@ func (s *chatServiceImpl) IsUserOnline(userID string) bool {
|
|||||||
// 适用于群聊等由调用方自行负责推送的场景
|
// 适用于群聊等由调用方自行负责推送的场景
|
||||||
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conversationID string, segments model.MessageSegments, replyToID *string) (*model.Message, error) {
|
||||||
// 验证会话是否存在
|
// 验证会话是否存在
|
||||||
_, err := s.repo.GetConversation(conversationID)
|
_, err := s.getConversation(ctx, conversationID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("会话不存在,请重新创建会话")
|
return nil, errors.New("会话不存在,请重新创建会话")
|
||||||
@@ -546,7 +689,7 @@ func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conv
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 验证用户是否是会话参与者
|
// 验证用户是否是会话参与者
|
||||||
_, err = s.repo.GetParticipant(conversationID, senderID)
|
_, err = s.getParticipant(ctx, conversationID, senderID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errors.New("您不是该会话的参与者")
|
return nil, errors.New("您不是该会话的参与者")
|
||||||
@@ -566,5 +709,17 @@ func (s *chatServiceImpl) SaveMessage(ctx context.Context, senderID string, conv
|
|||||||
return nil, fmt.Errorf("failed to save message: %w", err)
|
return nil, fmt.Errorf("failed to save message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 新消息会改变分页结果,先失效分页缓存,避免读到旧列表
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateMessagePages(conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步写入缓存
|
||||||
|
go func() {
|
||||||
|
if err := s.cacheMessage(context.Background(), conversationID, message); err != nil {
|
||||||
|
log.Printf("[ChatService] async cache message failed, convID=%s, msgID=%s, err=%v", conversationID, message.ID, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return message, nil
|
return message, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -145,6 +145,45 @@ func (s *groupService) publishGroupNotice(groupID string, notice groupNoticeMess
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// invalidateConversationCachesAfterSystemMessage 系统消息写入后失效相关缓存
|
||||||
|
func (s *groupService) invalidateConversationCachesAfterSystemMessage(conversationID string) {
|
||||||
|
if conversationID == "" || s.messageRepo == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 新系统消息会影响消息分页列表
|
||||||
|
cache.InvalidateMessagePages(s.cache, conversationID)
|
||||||
|
// 参与者列表可能发生变化(加群/退群)后,这里统一清理一次
|
||||||
|
s.cache.Delete(cache.ParticipantListKey(conversationID))
|
||||||
|
|
||||||
|
participants, err := s.messageRepo.GetConversationParticipants(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, p := range participants {
|
||||||
|
if p == nil || p.UserID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 会话最后消息、未读数会变化,清理用户维度缓存
|
||||||
|
cache.InvalidateConversationList(s.cache, p.UserID)
|
||||||
|
cache.InvalidateUnreadConversation(s.cache, p.UserID)
|
||||||
|
cache.InvalidateUnreadDetail(s.cache, p.UserID, conversationID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateConversationCachesAfterMembershipChange 成员变更后失效相关缓存
|
||||||
|
func (s *groupService) invalidateConversationCachesAfterMembershipChange(conversationID, userID string) {
|
||||||
|
if conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.cache.Delete(cache.ParticipantListKey(conversationID))
|
||||||
|
if userID != "" {
|
||||||
|
s.cache.Delete(cache.ParticipantKey(conversationID, userID))
|
||||||
|
cache.InvalidateConversationList(s.cache, userID)
|
||||||
|
cache.InvalidateUnreadConversation(s.cache, userID)
|
||||||
|
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ==================== 群组管理 ====================
|
// ==================== 群组管理 ====================
|
||||||
|
|
||||||
// CreateGroup 创建群组
|
// CreateGroup 创建群组
|
||||||
@@ -444,6 +483,7 @@ func (s *groupService) broadcastMemberJoinNotice(groupID string, targetUserID st
|
|||||||
log.Printf("[broadcastMemberJoinNotice] 保存入群提示消息失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err)
|
log.Printf("[broadcastMemberJoinNotice] 保存入群提示消息失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err)
|
||||||
} else {
|
} else {
|
||||||
savedMessage = msg
|
savedMessage = msg
|
||||||
|
s.invalidateConversationCachesAfterSystemMessage(conv.ID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("[broadcastMemberJoinNotice] 获取群组会话失败: groupID=%s, err=%v", groupID, err)
|
log.Printf("[broadcastMemberJoinNotice] 获取群组会话失败: groupID=%s, err=%v", groupID, err)
|
||||||
@@ -502,6 +542,7 @@ func (s *groupService) addMemberToGroupAndConversation(group *model.Group, userI
|
|||||||
if err := s.messageRepo.AddParticipant(conv.ID, userID); err != nil {
|
if err := s.messageRepo.AddParticipant(conv.ID, userID); err != nil {
|
||||||
log.Printf("[addMemberToGroupAndConversation] 添加会话参与者失败: groupID=%s, userID=%s, err=%v", group.ID, userID, err)
|
log.Printf("[addMemberToGroupAndConversation] 添加会话参与者失败: groupID=%s, userID=%s, err=%v", group.ID, userID, err)
|
||||||
}
|
}
|
||||||
|
s.invalidateConversationCachesAfterMembershipChange(conv.ID, userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cache.InvalidateGroupMembers(s.cache, group.ID)
|
cache.InvalidateGroupMembers(s.cache, group.ID)
|
||||||
@@ -1036,6 +1077,7 @@ func (s *groupService) LeaveGroup(userID string, groupID string) error {
|
|||||||
// 如果移除参与者失败,记录日志但不阻塞退出群流程
|
// 如果移除参与者失败,记录日志但不阻塞退出群流程
|
||||||
fmt.Printf("[WARN] LeaveGroup: failed to remove participant %s from conversation %s, error: %v\n", userID, conv.ID, err)
|
fmt.Printf("[WARN] LeaveGroup: failed to remove participant %s from conversation %s, error: %v\n", userID, conv.ID, err)
|
||||||
}
|
}
|
||||||
|
s.invalidateConversationCachesAfterMembershipChange(conv.ID, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 失效群组成员缓存
|
// 失效群组成员缓存
|
||||||
@@ -1092,6 +1134,7 @@ func (s *groupService) RemoveMember(userID string, groupID string, targetUserID
|
|||||||
if err := s.messageRepo.RemoveParticipant(conv.ID, targetUserID); err != nil {
|
if err := s.messageRepo.RemoveParticipant(conv.ID, targetUserID); err != nil {
|
||||||
log.Printf("[RemoveMember] 移除会话参与者失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err)
|
log.Printf("[RemoveMember] 移除会话参与者失败: groupID=%s, userID=%s, err=%v", groupID, targetUserID, err)
|
||||||
}
|
}
|
||||||
|
s.invalidateConversationCachesAfterMembershipChange(conv.ID, targetUserID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1290,6 +1333,7 @@ func (s *groupService) MuteMember(userID string, groupID string, targetUserID st
|
|||||||
} else {
|
} else {
|
||||||
savedMessage = msg
|
savedMessage = msg
|
||||||
log.Printf("[MuteMember] 禁言消息已保存, ID=%s, Seq=%d", msg.ID, msg.Seq)
|
log.Printf("[MuteMember] 禁言消息已保存, ID=%s, Seq=%d", msg.ID, msg.Seq)
|
||||||
|
s.invalidateConversationCachesAfterSystemMessage(conv.ID)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Printf("[MuteMember] 获取群组会话失败: %v", err)
|
log.Printf("[MuteMember] 获取群组会话失败: %v", err)
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"carrot_bbs/internal/cache"
|
"carrot_bbs/internal/cache"
|
||||||
"carrot_bbs/internal/model"
|
"carrot_bbs/internal/model"
|
||||||
"carrot_bbs/internal/repository"
|
"carrot_bbs/internal/repository"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 缓存TTL常量
|
// 缓存TTL常量
|
||||||
@@ -21,15 +24,37 @@ const (
|
|||||||
|
|
||||||
// MessageService 消息服务
|
// MessageService 消息服务
|
||||||
type MessageService struct {
|
type MessageService struct {
|
||||||
|
db *gorm.DB
|
||||||
|
|
||||||
|
// 基础仓储
|
||||||
messageRepo *repository.MessageRepository
|
messageRepo *repository.MessageRepository
|
||||||
cache cache.Cache
|
|
||||||
|
// 缓存相关字段
|
||||||
|
conversationCache *cache.ConversationCache
|
||||||
|
|
||||||
|
// 基础缓存(用于简单缓存操作)
|
||||||
|
baseCache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMessageService 创建消息服务
|
// NewMessageService 创建消息服务
|
||||||
func NewMessageService(messageRepo *repository.MessageRepository) *MessageService {
|
func NewMessageService(db *gorm.DB, messageRepo *repository.MessageRepository) *MessageService {
|
||||||
|
// 创建适配器
|
||||||
|
convRepoAdapter := cache.NewConversationRepositoryAdapter(messageRepo)
|
||||||
|
msgRepoAdapter := cache.NewMessageRepositoryAdapter(messageRepo)
|
||||||
|
|
||||||
|
// 创建会话缓存
|
||||||
|
conversationCache := cache.NewConversationCache(
|
||||||
|
cache.GetCache(),
|
||||||
|
convRepoAdapter,
|
||||||
|
msgRepoAdapter,
|
||||||
|
cache.DefaultConversationCacheSettings(),
|
||||||
|
)
|
||||||
|
|
||||||
return &MessageService{
|
return &MessageService{
|
||||||
messageRepo: messageRepo,
|
db: db,
|
||||||
cache: cache.GetCache(),
|
messageRepo: messageRepo,
|
||||||
|
conversationCache: conversationCache,
|
||||||
|
baseCache: cache.GetCache(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,20 +86,50 @@ func (s *MessageService) SendMessage(ctx context.Context, senderID, receiverID s
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 新消息会改变分页结果,先失效分页缓存,避免读到旧列表
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
s.conversationCache.InvalidateMessagePages(conv.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 异步写入缓存
|
||||||
|
go func() {
|
||||||
|
if err := s.cacheMessage(context.Background(), conv.ID, msg); err != nil {
|
||||||
|
log.Printf("[MessageService] async cache message failed, convID=%s, msgID=%s, err=%v", conv.ID, msg.ID, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 失效会话列表缓存(发送者和接收者)
|
// 失效会话列表缓存(发送者和接收者)
|
||||||
cache.InvalidateConversationList(s.cache, senderID)
|
s.conversationCache.InvalidateConversationList(senderID)
|
||||||
cache.InvalidateConversationList(s.cache, receiverID)
|
s.conversationCache.InvalidateConversationList(receiverID)
|
||||||
|
|
||||||
// 失效未读数缓存
|
// 失效未读数缓存
|
||||||
cache.InvalidateUnreadConversation(s.cache, receiverID)
|
cache.InvalidateUnreadConversation(s.baseCache, receiverID)
|
||||||
cache.InvalidateUnreadDetail(s.cache, receiverID, conv.ID)
|
s.conversationCache.InvalidateUnreadCount(receiverID, conv.ID)
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cacheMessage 缓存消息(内部方法)
|
||||||
|
func (s *MessageService) cacheMessage(ctx context.Context, convID string, msg *model.Message) error {
|
||||||
|
if s.conversationCache == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
asyncCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return s.conversationCache.CacheMessage(asyncCtx, convID, msg)
|
||||||
|
}
|
||||||
|
|
||||||
// GetConversations 获取会话列表(带缓存)
|
// GetConversations 获取会话列表(带缓存)
|
||||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
func (s *MessageService) GetConversations(ctx context.Context, userID string, page, pageSize int) ([]*model.Conversation, int64, error) {
|
||||||
|
// 优先使用 ConversationCache
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetConversationList(ctx, userID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级到基础缓存
|
||||||
cacheSettings := cache.GetSettings()
|
cacheSettings := cache.GetSettings()
|
||||||
conversationTTL := cacheSettings.ConversationTTL
|
conversationTTL := cacheSettings.ConversationTTL
|
||||||
if conversationTTL <= 0 {
|
if conversationTTL <= 0 {
|
||||||
@@ -92,7 +147,7 @@ func (s *MessageService) GetConversations(ctx context.Context, userID string, pa
|
|||||||
// 生成缓存键
|
// 生成缓存键
|
||||||
cacheKey := cache.ConversationListKey(userID, page, pageSize)
|
cacheKey := cache.ConversationListKey(userID, page, pageSize)
|
||||||
result, err := cache.GetOrLoadTyped[*ConversationListResult](
|
result, err := cache.GetOrLoadTyped[*ConversationListResult](
|
||||||
s.cache,
|
s.baseCache,
|
||||||
cacheKey,
|
cacheKey,
|
||||||
conversationTTL,
|
conversationTTL,
|
||||||
jitter,
|
jitter,
|
||||||
@@ -117,8 +172,14 @@ func (s *MessageService) GetConversations(ctx context.Context, userID string, pa
|
|||||||
return result.Conversations, result.Total, nil
|
return result.Conversations, result.Total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMessages 获取消息列表
|
// GetMessages 获取消息列表(带缓存)
|
||||||
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
func (s *MessageService) GetMessages(ctx context.Context, conversationID string, page, pageSize int) ([]*model.Message, int64, error) {
|
||||||
|
// 优先使用 ConversationCache
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetMessages(ctx, conversationID, page, pageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级到直接访问数据库
|
||||||
return s.messageRepo.GetMessages(conversationID, page, pageSize)
|
return s.messageRepo.GetMessages(conversationID, page, pageSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,20 +188,25 @@ func (s *MessageService) GetMessagesAfterSeq(ctx context.Context, conversationID
|
|||||||
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
return s.messageRepo.GetMessagesAfterSeq(conversationID, afterSeq, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkAsRead 标记为已读
|
// MarkAsRead 标记为已读(使用 Cache-Aside 模式)
|
||||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
|
func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string, userID string, lastReadSeq int64) error {
|
||||||
|
// 1. 先写入DB(保证数据一致性,DB是唯一数据源)
|
||||||
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
|
err := s.messageRepo.UpdateLastReadSeq(conversationID, userID, lastReadSeq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 失效未读数缓存
|
// 2. DB 写入成功后,失效缓存(Cache-Aside 模式)
|
||||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
if s.conversationCache != nil {
|
||||||
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
// 失效参与者缓存,下次读取时会从 DB 加载最新数据
|
||||||
|
s.conversationCache.InvalidateParticipant(conversationID, userID)
|
||||||
// 失效会话列表缓存
|
// 失效未读数缓存
|
||||||
cache.InvalidateConversationList(s.cache, userID)
|
s.conversationCache.InvalidateUnreadCount(userID, conversationID)
|
||||||
|
// 失效会话列表缓存
|
||||||
|
s.conversationCache.InvalidateConversationList(userID)
|
||||||
|
}
|
||||||
|
cache.InvalidateUnreadConversation(s.baseCache, userID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -148,6 +214,12 @@ func (s *MessageService) MarkAsRead(ctx context.Context, conversationID string,
|
|||||||
// GetUnreadCount 获取未读消息数(带缓存)
|
// GetUnreadCount 获取未读消息数(带缓存)
|
||||||
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
// userID 参数为 string 类型(UUID格式),与JWT中user_id保持一致
|
||||||
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID string, userID string) (int64, error) {
|
||||||
|
// 优先使用 ConversationCache
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetUnreadCount(ctx, userID, conversationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 降级到基础缓存
|
||||||
cacheSettings := cache.GetSettings()
|
cacheSettings := cache.GetSettings()
|
||||||
unreadTTL := cacheSettings.UnreadCountTTL
|
unreadTTL := cacheSettings.UnreadCountTTL
|
||||||
if unreadTTL <= 0 {
|
if unreadTTL <= 0 {
|
||||||
@@ -166,7 +238,7 @@ func (s *MessageService) GetUnreadCount(ctx context.Context, conversationID stri
|
|||||||
cacheKey := cache.UnreadDetailKey(userID, conversationID)
|
cacheKey := cache.UnreadDetailKey(userID, conversationID)
|
||||||
|
|
||||||
return cache.GetOrLoadTyped[int64](
|
return cache.GetOrLoadTyped[int64](
|
||||||
s.cache,
|
s.baseCache,
|
||||||
cacheKey,
|
cacheKey,
|
||||||
unreadTTL,
|
unreadTTL,
|
||||||
jitter,
|
jitter,
|
||||||
@@ -186,14 +258,18 @@ func (s *MessageService) GetOrCreateConversation(ctx context.Context, user1ID, u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 失效会话列表缓存
|
// 失效会话列表缓存
|
||||||
cache.InvalidateConversationList(s.cache, user1ID)
|
s.conversationCache.InvalidateConversationList(user1ID)
|
||||||
cache.InvalidateConversationList(s.cache, user2ID)
|
s.conversationCache.InvalidateConversationList(user2ID)
|
||||||
|
|
||||||
return conv, nil
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConversationParticipants 获取会话参与者列表
|
// GetConversationParticipants 获取会话参与者列表
|
||||||
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
func (s *MessageService) GetConversationParticipants(conversationID string) ([]*model.ConversationParticipant, error) {
|
||||||
|
// 优先使用缓存
|
||||||
|
if s.conversationCache != nil {
|
||||||
|
return s.conversationCache.GetParticipants(context.Background(), conversationID)
|
||||||
|
}
|
||||||
return s.messageRepo.GetConversationParticipants(conversationID)
|
return s.messageRepo.GetConversationParticipants(conversationID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,12 +280,12 @@ func ParseConversationID(idStr string) (string, error) {
|
|||||||
|
|
||||||
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
|
// InvalidateUserConversationCache 失效用户会话相关缓存(供外部调用)
|
||||||
func (s *MessageService) InvalidateUserConversationCache(userID string) {
|
func (s *MessageService) InvalidateUserConversationCache(userID string) {
|
||||||
cache.InvalidateConversationList(s.cache, userID)
|
s.conversationCache.InvalidateConversationList(userID)
|
||||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
cache.InvalidateUnreadConversation(s.baseCache, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
|
// InvalidateUserUnreadCache 失效用户未读数缓存(供外部调用)
|
||||||
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
|
func (s *MessageService) InvalidateUserUnreadCache(userID, conversationID string) {
|
||||||
cache.InvalidateUnreadConversation(s.cache, userID)
|
cache.InvalidateUnreadConversation(s.baseCache, userID)
|
||||||
cache.InvalidateUnreadDetail(s.cache, userID, conversationID)
|
s.conversationCache.InvalidateUnreadCount(userID, conversationID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,9 +73,20 @@ func (s *PostService) Create(ctx context.Context, userID, title, content string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
|
func (s *PostService) reviewPostAsync(postID, userID, title, content string, images []string) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("[ERROR] Panic in post moderation async flow, fallback publish post=%s panic=%v", postID, r)
|
||||||
|
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish post %s after panic recovery: %v", postID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.invalidatePostCaches(postID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// 未启用AI时,直接发布
|
// 未启用AI时,直接发布
|
||||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||||
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
log.Printf("[WARN] Failed to publish post without AI moderation: %v", err)
|
||||||
} else {
|
} else {
|
||||||
s.invalidatePostCaches(postID)
|
s.invalidatePostCaches(postID)
|
||||||
@@ -87,7 +98,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var rejectedErr *PostModerationRejectedError
|
var rejectedErr *PostModerationRejectedError
|
||||||
if errors.As(err, &rejectedErr) {
|
if errors.As(err, &rejectedErr) {
|
||||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||||
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
log.Printf("[WARN] Failed to reject post %s: %v", postID, updateErr)
|
||||||
} else {
|
} else {
|
||||||
s.invalidatePostCaches(postID)
|
s.invalidatePostCaches(postID)
|
||||||
@@ -97,7 +108,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 规则审核不可用时,降级为发布,避免长时间pending
|
// 规则审核不可用时,降级为发布,避免长时间pending
|
||||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||||
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
log.Printf("[WARN] Failed to publish post %s after moderation error: %v", postID, updateErr)
|
||||||
} else {
|
} else {
|
||||||
s.invalidatePostCaches(postID)
|
s.invalidatePostCaches(postID)
|
||||||
@@ -106,7 +117,7 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
||||||
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
log.Printf("[WARN] Failed to publish post %s: %v", postID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -127,6 +138,26 @@ func (s *PostService) reviewPostAsync(postID, userID, title, content string, ima
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PostService) updateModerationStatusWithRetry(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||||
|
const maxAttempts = 3
|
||||||
|
const retryDelay = 200 * time.Millisecond
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
if err := s.postRepo.UpdateModerationStatus(postID, status, rejectReason, reviewedBy); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
if attempt < maxAttempts {
|
||||||
|
log.Printf("[WARN] UpdateModerationStatus failed post=%s attempt=%d/%d err=%v", postID, attempt, maxAttempts, err)
|
||||||
|
time.Sleep(time.Duration(attempt) * retryDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PostService) invalidatePostCaches(postID string) {
|
func (s *PostService) invalidatePostCaches(postID string) {
|
||||||
cache.InvalidatePostDetail(s.cache, postID)
|
cache.InvalidatePostDetail(s.cache, postID)
|
||||||
cache.InvalidatePostList(s.cache)
|
cache.InvalidatePostList(s.cache)
|
||||||
|
|||||||
207
internal/service/schedule_service.go
Normal file
207
internal/service/schedule_service.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"carrot_bbs/internal/dto"
|
||||||
|
"carrot_bbs/internal/model"
|
||||||
|
"carrot_bbs/internal/repository"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidSchedulePayload = &ServiceError{Code: 400, Message: "invalid schedule payload"}
|
||||||
|
ErrScheduleCourseNotFound = &ServiceError{Code: 404, Message: "schedule course not found"}
|
||||||
|
ErrScheduleForbidden = &ServiceError{Code: 403, Message: "forbidden schedule operation"}
|
||||||
|
ErrScheduleColorDuplicated = &ServiceError{Code: 400, Message: "course color already used"}
|
||||||
|
)
|
||||||
|
|
||||||
|
var hexColorRegex = regexp.MustCompile(`^#[0-9A-F]{6}$`)
|
||||||
|
|
||||||
|
type CreateScheduleCourseInput struct {
|
||||||
|
Name string
|
||||||
|
Teacher string
|
||||||
|
Location string
|
||||||
|
DayOfWeek int
|
||||||
|
StartSection int
|
||||||
|
EndSection int
|
||||||
|
Weeks []int
|
||||||
|
Color string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ScheduleService interface {
|
||||||
|
ListCourses(userID string, week int) ([]*dto.ScheduleCourseResponse, error)
|
||||||
|
CreateCourse(userID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error)
|
||||||
|
UpdateCourse(userID, courseID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error)
|
||||||
|
DeleteCourse(userID, courseID string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type scheduleService struct {
|
||||||
|
repo repository.ScheduleRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScheduleService(repo repository.ScheduleRepository) ScheduleService {
|
||||||
|
return &scheduleService{repo: repo}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *scheduleService) ListCourses(userID string, week int) ([]*dto.ScheduleCourseResponse, error) {
|
||||||
|
courses, err := s.repo.ListByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]*dto.ScheduleCourseResponse, 0, len(courses))
|
||||||
|
for _, item := range courses {
|
||||||
|
weeks := dto.ParseWeeksJSON(item.Weeks)
|
||||||
|
if week > 0 && !containsWeek(weeks, week) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, dto.ConvertScheduleCourseToResponse(item, weeks))
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *scheduleService) CreateCourse(userID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) {
|
||||||
|
entity, weeks, err := buildScheduleEntity(userID, input, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := s.ensureUniqueColor(userID, entity.Color, ""); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := s.repo.Create(entity); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dto.ConvertScheduleCourseToResponse(entity, weeks), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *scheduleService) UpdateCourse(userID, courseID string, input CreateScheduleCourseInput) (*dto.ScheduleCourseResponse, error) {
|
||||||
|
existing, err := s.repo.GetByID(courseID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, ErrScheduleCourseNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if existing.UserID != userID {
|
||||||
|
return nil, ErrScheduleForbidden
|
||||||
|
}
|
||||||
|
|
||||||
|
entity, weeks, err := buildScheduleEntity(userID, input, existing)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := s.ensureUniqueColor(userID, entity.Color, entity.ID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := s.repo.Update(entity); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dto.ConvertScheduleCourseToResponse(entity, weeks), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *scheduleService) DeleteCourse(userID, courseID string) error {
|
||||||
|
existing, err := s.repo.GetByID(courseID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return ErrScheduleCourseNotFound
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if existing.UserID != userID {
|
||||||
|
return ErrScheduleForbidden
|
||||||
|
}
|
||||||
|
return s.repo.DeleteByID(courseID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildScheduleEntity(userID string, input CreateScheduleCourseInput, target *model.ScheduleCourse) (*model.ScheduleCourse, []int, error) {
|
||||||
|
name := strings.TrimSpace(input.Name)
|
||||||
|
if name == "" || input.DayOfWeek < 0 || input.DayOfWeek > 6 || input.StartSection < 1 || input.EndSection < input.StartSection {
|
||||||
|
return nil, nil, ErrInvalidSchedulePayload
|
||||||
|
}
|
||||||
|
|
||||||
|
weeks := normalizeWeeks(input.Weeks)
|
||||||
|
if len(weeks) == 0 {
|
||||||
|
return nil, nil, ErrInvalidSchedulePayload
|
||||||
|
}
|
||||||
|
weeksJSON, err := json.Marshal(weeks)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
entity := target
|
||||||
|
if entity == nil {
|
||||||
|
entity = &model.ScheduleCourse{
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedColor := normalizeHexColor(input.Color)
|
||||||
|
if normalizedColor == "" || !hexColorRegex.MatchString(normalizedColor) {
|
||||||
|
return nil, nil, ErrInvalidSchedulePayload
|
||||||
|
}
|
||||||
|
|
||||||
|
entity.Name = name
|
||||||
|
entity.Teacher = strings.TrimSpace(input.Teacher)
|
||||||
|
entity.Location = strings.TrimSpace(input.Location)
|
||||||
|
entity.DayOfWeek = input.DayOfWeek
|
||||||
|
entity.StartSection = input.StartSection
|
||||||
|
entity.EndSection = input.EndSection
|
||||||
|
entity.Weeks = string(weeksJSON)
|
||||||
|
entity.Color = normalizedColor
|
||||||
|
|
||||||
|
return entity, weeks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *scheduleService) ensureUniqueColor(userID, color, excludeID string) error {
|
||||||
|
exists, err := s.repo.ExistsColorByUser(userID, color, excludeID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return ErrScheduleColorDuplicated
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeWeeks(source []int) []int {
|
||||||
|
unique := make(map[int]struct{}, len(source))
|
||||||
|
result := make([]int, 0, len(source))
|
||||||
|
for _, w := range source {
|
||||||
|
if w < 1 || w > 30 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := unique[w]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
unique[w] = struct{}{}
|
||||||
|
result = append(result, w)
|
||||||
|
}
|
||||||
|
sort.Ints(result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsWeek(weeks []int, target int) bool {
|
||||||
|
for _, week := range weeks {
|
||||||
|
if week == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeHexColor(color string) string {
|
||||||
|
trimmed := strings.TrimSpace(color)
|
||||||
|
if trimmed == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, "#") {
|
||||||
|
return strings.ToUpper(trimmed)
|
||||||
|
}
|
||||||
|
return "#" + strings.ToUpper(trimmed)
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"carrot_bbs/internal/cache"
|
"carrot_bbs/internal/cache"
|
||||||
"carrot_bbs/internal/dto"
|
"carrot_bbs/internal/dto"
|
||||||
@@ -84,8 +85,17 @@ func (s *VoteService) CreateVotePost(ctx context.Context, userID string, req *dt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) {
|
func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string, images []string) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Printf("[ERROR] Panic in vote post moderation async flow, fallback publish post=%s panic=%v", postID, r)
|
||||||
|
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||||
|
log.Printf("[WARN] Failed to publish vote post %s after panic recovery: %v", postID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
if s.postAIService == nil || !s.postAIService.IsEnabled() {
|
||||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); err != nil {
|
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); err != nil {
|
||||||
log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err)
|
log.Printf("[WARN] Failed to publish vote post without AI moderation: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@@ -95,24 +105,44 @@ func (s *VoteService) reviewVotePostAsync(postID, userID, title, content string,
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
var rejectedErr *PostModerationRejectedError
|
var rejectedErr *PostModerationRejectedError
|
||||||
if errors.As(err, &rejectedErr) {
|
if errors.As(err, &rejectedErr) {
|
||||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusRejected, rejectedErr.UserMessage(), "ai"); updateErr != nil {
|
||||||
log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr)
|
log.Printf("[WARN] Failed to reject vote post %s: %v", postID, updateErr)
|
||||||
}
|
}
|
||||||
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
s.notifyModerationRejected(userID, rejectedErr.Reason)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateErr := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
if updateErr := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "system"); updateErr != nil {
|
||||||
log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr)
|
log.Printf("[WARN] Failed to publish vote post %s after moderation error: %v", postID, updateErr)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.postRepo.UpdateModerationStatus(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
if err := s.updateModerationStatusWithRetry(postID, model.PostStatusPublished, "", "ai"); err != nil {
|
||||||
log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err)
|
log.Printf("[WARN] Failed to publish vote post %s: %v", postID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *VoteService) updateModerationStatusWithRetry(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||||
|
const maxAttempts = 3
|
||||||
|
const retryDelay = 200 * time.Millisecond
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
if err := s.postRepo.UpdateModerationStatus(postID, status, rejectReason, reviewedBy); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
if attempt < maxAttempts {
|
||||||
|
log.Printf("[WARN] UpdateModerationStatus for vote post failed post=%s attempt=%d/%d err=%v", postID, attempt, maxAttempts, err)
|
||||||
|
time.Sleep(time.Duration(attempt) * retryDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
func (s *VoteService) notifyModerationRejected(userID, reason string) {
|
func (s *VoteService) notifyModerationRejected(userID, reason string) {
|
||||||
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
if s.systemMessageService == nil || strings.TrimSpace(userID) == "" {
|
||||||
return
|
return
|
||||||
|
|||||||
263
scripts/test_moderation.go
Normal file
263
scripts/test_moderation.go
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const moderationSystemPrompt = `你是中文社区的内容审核助手,负责对"帖子标题、正文、配图"做联合审核。目标是平衡社区安全与正常交流:必须拦截高风险违规内容,但不要误伤正常玩梗、二创、吐槽和轻度调侃。请只输出指定JSON。
|
||||||
|
|
||||||
|
审核流程:
|
||||||
|
1) 先判断是否命中硬性违规;
|
||||||
|
2) 再判断语境(玩笑/自嘲/朋友间互动/作品讨论);
|
||||||
|
3) 做文图交叉判断(文本+图片合并理解);
|
||||||
|
4) 给出 approved 与简短 reason。
|
||||||
|
|
||||||
|
硬性违规(命中任一项必须 approved=false):
|
||||||
|
A. 宣传对立与煽动撕裂:
|
||||||
|
- 明确煽动群体对立、地域对立、性别对立、民族宗教对立,鼓动仇恨、排斥、报复。
|
||||||
|
B. 严重人身攻击与网暴引导:
|
||||||
|
- 持续性侮辱贬损、羞辱人格、号召围攻/骚扰/挂人/线下冲突。
|
||||||
|
C. 开盒/人肉/隐私暴露:
|
||||||
|
- 故意公开、拼接、索取他人可识别隐私信息(姓名+联系方式、身份证号、住址、学校单位、车牌、定位轨迹等);
|
||||||
|
- 图片/截图中出现可识别隐私信息并伴随曝光意图,也按违规处理。
|
||||||
|
D. 其他高危违规:
|
||||||
|
- 违法犯罪、暴力威胁、极端仇恨、色情低俗、诈骗引流、恶意广告等。
|
||||||
|
|
||||||
|
放行规则(以下通常 approved=true):
|
||||||
|
- 正常玩梗、表情包、谐音梗、二次创作、无恶意的吐槽;
|
||||||
|
- 非定向、轻度口语化吐槽(无明确攻击对象、无网暴号召、无隐私暴露);
|
||||||
|
- 对社会事件/作品的理性讨论、观点争论(即使语气尖锐,但未煽动对立或人身攻击)。
|
||||||
|
|
||||||
|
边界判定:
|
||||||
|
- 若只是"梗文化表达"且不指向现实伤害,优先通过;
|
||||||
|
- 若存在明确伤害意图(煽动、围攻、曝光隐私),必须拒绝;
|
||||||
|
- 对模糊内容不因个别粗口直接拒绝,需结合对象、意图、号召性和可执行性综合判断。
|
||||||
|
|
||||||
|
reason 要求:
|
||||||
|
- approved=false 时:中文10-30字,说明核心违规点;
|
||||||
|
- approved=true 时:reason 为空字符串。
|
||||||
|
|
||||||
|
输出格式(严格):
|
||||||
|
仅输出一行JSON对象,不要Markdown,不要额外解释:
|
||||||
|
{"approved": true/false, "reason": "..."}`
|
||||||
|
|
||||||
|
type chatMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content interface{} `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type contentPart struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatCompletionsRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []chatMessage `json:"messages"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
EnableThinking *bool `json:"enable_thinking,omitempty"` // qwen3.5思考模式控制
|
||||||
|
ThinkingBudget *int `json:"thinking_budget,omitempty"` // 思考过程最大token数
|
||||||
|
ResponseFormat *responseFormatConfig `json:"response_format,omitempty"` // 响应格式
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseFormatConfig struct {
|
||||||
|
Type string `json:"type"` // "text" or "json_object"
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatCompletionsResponse struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
baseURL := flag.String("url", "https://api.littlelan.cn/", "API base URL")
|
||||||
|
apiKey := flag.String("key", "", "API key")
|
||||||
|
model := flag.String("model", "qwen3.5-plus", "Model name")
|
||||||
|
maxTokens := flag.Int("max-tokens", 220, "Max tokens for completion")
|
||||||
|
enableThinking := flag.Bool("enable-thinking", false, "Enable thinking mode for qwen3.5")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if *apiKey == "" {
|
||||||
|
fmt.Println("Error: API key is required. Use -key flag")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试用例
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "简单正常内容",
|
||||||
|
content: "帖子标题:今天天气真好\n帖子内容:出门散步,心情愉快!",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "中等长度内容",
|
||||||
|
content: "帖子标题:分享我的学习经验\n帖子内容:最近在学习Go语言,发现这门语言真的很适合后端开发。并发处理特别方便,goroutine和channel的设计非常优雅。有一起学习的小伙伴吗?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "较长内容",
|
||||||
|
content: "帖子标题:关于校园生活的一些思考\n帖子内容:大学四年转眼就过去了,回想起来有很多感慨。刚入学的时候什么都不懂,现在感觉自己成长了很多。在这里想分享一些自己的经验,希望能对学弟学妹们有所帮助。首先是学习方面,一定要认真听课,做好笔记。其次是社交方面,多参加社团活动,结交志同道合的朋友。最后是规划方面,早点想清楚自己想做什么,为之努力。",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 120 * time.Second}
|
||||||
|
|
||||||
|
fmt.Println("============================================")
|
||||||
|
fmt.Printf("模型: %s\n", *model)
|
||||||
|
fmt.Printf("API URL: %s\n", *baseURL)
|
||||||
|
fmt.Printf("MaxTokens 设置: %d\n", *maxTokens)
|
||||||
|
fmt.Printf("EnableThinking: %v\n", *enableThinking)
|
||||||
|
fmt.Println("============================================")
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
fmt.Printf("\n========== 测试: %s ==========\n", tc.name)
|
||||||
|
fmt.Printf("内容长度: %d 字符\n", len(tc.content))
|
||||||
|
|
||||||
|
userPrompt := fmt.Sprintf("%s\n图片批次:1/1(本次仅提供当前批次图片)", tc.content)
|
||||||
|
|
||||||
|
reqBody := chatCompletionsRequest{
|
||||||
|
Model: *model,
|
||||||
|
Messages: []chatMessage{
|
||||||
|
{Role: "system", Content: moderationSystemPrompt},
|
||||||
|
{Role: "user", Content: []contentPart{{Type: "text", Text: userPrompt}}},
|
||||||
|
},
|
||||||
|
Temperature: 0.1,
|
||||||
|
MaxTokens: *maxTokens,
|
||||||
|
}
|
||||||
|
// 设置思考模式
|
||||||
|
if !*enableThinking {
|
||||||
|
reqBody.EnableThinking = enableThinking
|
||||||
|
// 设置思考预算为0,完全禁用思考
|
||||||
|
zero := 0
|
||||||
|
reqBody.ThinkingBudget = &zero
|
||||||
|
}
|
||||||
|
// 使用JSON输出格式
|
||||||
|
reqBody.ResponseFormat = &responseFormatConfig{Type: "json_object"}
|
||||||
|
|
||||||
|
data, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error marshaling request: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := strings.TrimRight(*baseURL, "/") + "/v1/chat/completions"
|
||||||
|
if strings.HasSuffix(strings.TrimRight(*baseURL, "/"), "/v1") {
|
||||||
|
endpoint = strings.TrimRight(*baseURL, "/") + "/chat/completions"
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error creating request: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+*apiKey)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error sending request: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error reading response: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
fmt.Printf("API Error: status=%d, body=%s\n", resp.StatusCode, string(body))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed chatCompletionsResponse
|
||||||
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||||
|
fmt.Printf("Error parsing response: %v\n", err)
|
||||||
|
fmt.Printf("Raw response: %s\n", string(body))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parsed.Choices) == 0 {
|
||||||
|
fmt.Println("No choices in response")
|
||||||
|
fmt.Printf("Raw response: %s\n", string(body))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("响应时间: %v\n", elapsed)
|
||||||
|
fmt.Printf("Finish Reason: %s\n", parsed.Choices[0].FinishReason)
|
||||||
|
fmt.Printf("Token使用情况:\n")
|
||||||
|
fmt.Printf(" - PromptTokens: %d\n", parsed.Usage.PromptTokens)
|
||||||
|
fmt.Printf(" - CompletionTokens: %d\n", parsed.Usage.CompletionTokens)
|
||||||
|
fmt.Printf(" - TotalTokens: %d\n", parsed.Usage.TotalTokens)
|
||||||
|
|
||||||
|
output := parsed.Choices[0].Message.Content
|
||||||
|
fmt.Printf("输出内容长度: %d 字符\n", len(output))
|
||||||
|
|
||||||
|
// 检查输出是否符合预期
|
||||||
|
if parsed.Usage.CompletionTokens > *maxTokens {
|
||||||
|
fmt.Printf("\n⚠️ 警告: CompletionTokens (%d) 超过了 max_tokens 设置 (%d)!\n",
|
||||||
|
parsed.Usage.CompletionTokens, *maxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(output) > 500 {
|
||||||
|
fmt.Printf("\n⚠️ 警告: 输出内容过长! 长度=%d\n", len(output))
|
||||||
|
fmt.Printf("前500字符:\n%s...\n", output[:min(500, len(output))])
|
||||||
|
} else {
|
||||||
|
fmt.Printf("输出内容: %s\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试解析JSON
|
||||||
|
extractJSONObject := func(raw string) string {
|
||||||
|
text := strings.TrimSpace(raw)
|
||||||
|
start := strings.Index(text, "{")
|
||||||
|
end := strings.LastIndex(text, "}")
|
||||||
|
if start >= 0 && end > start {
|
||||||
|
return text[start : end+1]
|
||||||
|
}
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonStr := extractJSONObject(output)
|
||||||
|
var result struct {
|
||||||
|
Approved bool `json:"approved"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
||||||
|
fmt.Printf("\n⚠️ 警告: 无法解析JSON输出: %v\n", err)
|
||||||
|
fmt.Printf("提取的JSON: %s\n", jsonStr)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("\n✓ 解析成功: approved=%v, reason=\"%s\"\n", result.Approved, result.Reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n========== 测试完成 ==========")
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user