Initial backend repository commit.

Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
This commit is contained in:
2026-03-09 21:28:58 +08:00
commit 4d8f2ec997
102 changed files with 25022 additions and 0 deletions

604
internal/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,604 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
redisPkg "carrot_bbs/internal/pkg/redis"
)
// Cache 缓存接口
type Cache interface {
// Set 设置缓存值支持TTL
Set(key string, value interface{}, ttl time.Duration)
// Get 获取缓存值
Get(key string) (interface{}, bool)
// Delete 删除缓存
Delete(key string)
// DeleteByPrefix 根据前缀删除缓存
DeleteByPrefix(prefix string)
// Clear 清空所有缓存
Clear()
// Exists 检查键是否存在
Exists(key string) bool
// Increment 增加计数器的值
Increment(key string) int64
// IncrementBy 增加指定值
IncrementBy(key string, value int64) int64
}
// cacheItem 缓存项(用于内存缓存降级)
type cacheItem struct {
value interface{}
expiration int64 // 过期时间戳(纳秒)
}
const nullMarkerValue = "__carrot_cache_null__"
type cacheMetrics struct {
hit atomic.Int64
miss atomic.Int64
decodeError atomic.Int64
setError atomic.Int64
invalidate atomic.Int64
}
var metrics cacheMetrics
var loadLocks sync.Map
type MetricsSnapshot struct {
Hit int64
Miss int64
DecodeError int64
SetError int64
Invalidate int64
}
type Settings struct {
Enabled bool
KeyPrefix string
DefaultTTL time.Duration
NullTTL time.Duration
JitterRatio float64
PostListTTL time.Duration
ConversationTTL time.Duration
UnreadCountTTL time.Duration
GroupMembersTTL time.Duration
DisableFlushDB bool
}
var settings = Settings{
Enabled: true,
DefaultTTL: 30 * time.Second,
NullTTL: 5 * time.Second,
JitterRatio: 0.1,
PostListTTL: 30 * time.Second,
ConversationTTL: 60 * time.Second,
UnreadCountTTL: 30 * time.Second,
GroupMembersTTL: 120 * time.Second,
DisableFlushDB: true,
}
func Configure(s Settings) {
settings.Enabled = s.Enabled
if s.KeyPrefix != "" {
settings.KeyPrefix = s.KeyPrefix
}
if s.DefaultTTL > 0 {
settings.DefaultTTL = s.DefaultTTL
}
if s.NullTTL > 0 {
settings.NullTTL = s.NullTTL
}
if s.JitterRatio > 0 {
settings.JitterRatio = s.JitterRatio
}
if s.PostListTTL > 0 {
settings.PostListTTL = s.PostListTTL
}
if s.ConversationTTL > 0 {
settings.ConversationTTL = s.ConversationTTL
}
if s.UnreadCountTTL > 0 {
settings.UnreadCountTTL = s.UnreadCountTTL
}
if s.GroupMembersTTL > 0 {
settings.GroupMembersTTL = s.GroupMembersTTL
}
settings.DisableFlushDB = s.DisableFlushDB
}
func GetSettings() Settings {
return settings
}
func normalizeKey(key string) string {
if settings.KeyPrefix == "" {
return key
}
return settings.KeyPrefix + ":" + key
}
func normalizePrefix(prefix string) string {
if settings.KeyPrefix == "" {
return prefix
}
return settings.KeyPrefix + ":" + prefix
}
func GetMetricsSnapshot() MetricsSnapshot {
return MetricsSnapshot{
Hit: metrics.hit.Load(),
Miss: metrics.miss.Load(),
DecodeError: metrics.decodeError.Load(),
SetError: metrics.setError.Load(),
Invalidate: metrics.invalidate.Load(),
}
}
// isExpired 检查是否过期
func (item *cacheItem) isExpired() bool {
if item.expiration == 0 {
return false
}
return time.Now().UnixNano() > item.expiration
}
// MemoryCache 内存缓存实现(降级使用)
type MemoryCache struct {
items sync.Map
// cleanupInterval 清理过期缓存的间隔
cleanupInterval time.Duration
// stopCleanup 停止清理协程的通道
stopCleanup chan struct{}
}
// NewMemoryCache 创建内存缓存
func NewMemoryCache() *MemoryCache {
c := &MemoryCache{
cleanupInterval: 1 * time.Minute,
stopCleanup: make(chan struct{}),
}
// 启动后台清理协程
go c.cleanup()
return c
}
// Set 设置缓存值
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
key = normalizeKey(key)
var expiration int64
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
c.items.Store(key, &cacheItem{
value: value,
expiration: expiration,
})
}
// Get 获取缓存值
func (c *MemoryCache) Get(key string) (interface{}, bool) {
key = normalizeKey(key)
val, ok := c.items.Load(key)
if !ok {
return nil, false
}
item := val.(*cacheItem)
if item.isExpired() {
c.items.Delete(key)
return nil, false
}
return item.value, true
}
// Delete 删除缓存
func (c *MemoryCache) Delete(key string) {
key = normalizeKey(key)
metrics.invalidate.Add(1)
c.items.Delete(key)
}
// DeleteByPrefix 根据前缀删除缓存
func (c *MemoryCache) DeleteByPrefix(prefix string) {
prefix = normalizePrefix(prefix)
c.items.Range(func(key, value interface{}) bool {
if keyStr, ok := key.(string); ok {
if strings.HasPrefix(keyStr, prefix) {
metrics.invalidate.Add(1)
c.items.Delete(key)
}
}
return true
})
}
// Clear 清空所有缓存
func (c *MemoryCache) Clear() {
c.items.Range(func(key, value interface{}) bool {
metrics.invalidate.Add(1)
c.items.Delete(key)
return true
})
}
// Exists 检查键是否存在
func (c *MemoryCache) Exists(key string) bool {
_, ok := c.Get(key)
return ok
}
// Increment 增加计数器的值
func (c *MemoryCache) Increment(key string) int64 {
return c.IncrementBy(key, 1)
}
// IncrementBy 增加指定值
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
key = normalizeKey(key)
for {
val, ok := c.items.Load(key)
if !ok {
// 键不存在,创建新值
c.items.Store(key, &cacheItem{
value: value,
expiration: 0,
})
return value
}
item := val.(*cacheItem)
if item.isExpired() {
// 已过期,创建新值
c.items.Store(key, &cacheItem{
value: value,
expiration: 0,
})
return value
}
// 尝试更新
currentValue, ok := item.value.(int64)
if !ok {
// 类型不匹配,覆盖为新值
c.items.Store(key, &cacheItem{
value: value,
expiration: item.expiration,
})
return value
}
newValue := currentValue + value
// 使用 CAS 操作确保并发安全
if c.items.CompareAndSwap(key, val, &cacheItem{
value: newValue,
expiration: item.expiration,
}) {
return newValue
}
// CAS 失败,重试
}
}
// cleanup 定期清理过期缓存
func (c *MemoryCache) cleanup() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.cleanExpired()
case <-c.stopCleanup:
return
}
}
}
// cleanExpired 清理过期缓存
func (c *MemoryCache) cleanExpired() {
count := 0
c.items.Range(func(key, value interface{}) bool {
item := value.(*cacheItem)
if item.isExpired() {
c.items.Delete(key)
count++
}
return true
})
if count > 0 {
log.Printf("[Cache] Cleaned %d expired items", count)
}
}
// Stop 停止缓存清理协程
func (c *MemoryCache) Stop() {
close(c.stopCleanup)
}
// RedisCache Redis缓存实现
type RedisCache struct {
client *redisPkg.Client
ctx context.Context
}
// NewRedisCache 创建Redis缓存
func NewRedisCache(client *redisPkg.Client) *RedisCache {
return &RedisCache{
client: client,
ctx: context.Background(),
}
}
// Set 设置缓存值
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
key = normalizeKey(key)
// 将值序列化为JSON
data, err := json.Marshal(value)
if err != nil {
metrics.setError.Add(1)
log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err)
return
}
if err := c.client.Set(c.ctx, key, data, ttl); err != nil {
metrics.setError.Add(1)
log.Printf("[RedisCache] Failed to set key %s: %v", key, err)
}
}
// Get 获取缓存值
func (c *RedisCache) Get(key string) (interface{}, bool) {
key = normalizeKey(key)
data, err := c.client.Get(c.ctx, key)
if err != nil {
if err == redis.Nil {
return nil, false
}
log.Printf("[RedisCache] Failed to get key %s: %v", key, err)
return nil, false
}
// 返回原始字符串,由调用侧决定如何解码为目标类型
return data, true
}
// Delete 删除缓存
func (c *RedisCache) Delete(key string) {
key = normalizeKey(key)
metrics.invalidate.Add(1)
if err := c.client.Del(c.ctx, key); err != nil {
log.Printf("[RedisCache] Failed to delete key %s: %v", key, err)
}
}
// DeleteByPrefix 根据前缀删除缓存
func (c *RedisCache) DeleteByPrefix(prefix string) {
prefix = normalizePrefix(prefix)
// 使用原生客户端执行SCAN命令
rdb := c.client.GetClient()
var cursor uint64
for {
keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result()
if err != nil {
log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
return
}
if len(keys) > 0 {
metrics.invalidate.Add(int64(len(keys)))
if err := c.client.Del(c.ctx, keys...); err != nil {
log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
}
// Clear 清空所有缓存
func (c *RedisCache) Clear() {
if settings.DisableFlushDB {
log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
return
}
metrics.invalidate.Add(1)
rdb := c.client.GetClient()
if err := rdb.FlushDB(c.ctx).Err(); err != nil {
log.Printf("[RedisCache] Failed to clear cache: %v", err)
}
}
// Exists 检查键是否存在
func (c *RedisCache) Exists(key string) bool {
key = normalizeKey(key)
n, err := c.client.Exists(c.ctx, key)
if err != nil {
log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err)
return false
}
return n > 0
}
// Increment 增加计数器的值
func (c *RedisCache) Increment(key string) int64 {
return c.IncrementBy(key, 1)
}
// IncrementBy 增加指定值
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
key = normalizeKey(key)
rdb := c.client.GetClient()
result, err := rdb.IncrBy(c.ctx, key, value).Result()
if err != nil {
log.Printf("[RedisCache] Failed to increment key %s: %v", key, err)
return 0
}
return result
}
// 全局缓存实例
var globalCache Cache
var once sync.Once
// InitCache 初始化全局缓存实例使用Redis
func InitCache(redisClient *redisPkg.Client) {
once.Do(func() {
if redisClient != nil {
globalCache = NewRedisCache(redisClient)
log.Println("[Cache] Initialized Redis cache")
} else {
globalCache = NewMemoryCache()
log.Println("[Cache] Initialized Memory cache (Redis not available)")
}
})
}
// GetCache 获取全局缓存实例
func GetCache() Cache {
if globalCache == nil {
// 如果未初始化,返回内存缓存作为降级
log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
return NewMemoryCache()
}
return globalCache
}
// GetRedisClient 从缓存中获取Redis客户端仅在Redis模式下有效
func GetRedisClient() (*redisPkg.Client, error) {
if redisCache, ok := globalCache.(*RedisCache); ok {
return redisCache.client, nil
}
return nil, fmt.Errorf("cache is not using Redis backend")
}
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
if !settings.Enabled {
return
}
c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio))
}
func SetNull(c Cache, key string, ttl time.Duration) {
if !settings.Enabled {
return
}
c.Set(key, nullMarkerValue, ttl)
}
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
if ttl <= 0 || jitterRatio <= 0 {
return ttl
}
if jitterRatio > 1 {
jitterRatio = 1
}
maxJitter := int64(float64(ttl) * jitterRatio)
if maxJitter <= 0 {
return ttl
}
delta := rand.Int63n(maxJitter + 1)
return ttl + time.Duration(delta)
}
func GetTyped[T any](c Cache, key string) (T, bool) {
var zero T
if !settings.Enabled {
return zero, false
}
raw, ok := c.Get(key)
if !ok {
metrics.miss.Add(1)
return zero, false
}
if str, ok := raw.(string); ok && str == nullMarkerValue {
metrics.hit.Add(1)
return zero, false
}
if typed, ok := raw.(T); ok {
metrics.hit.Add(1)
return typed, true
}
var out T
switch v := raw.(type) {
case string:
if err := json.Unmarshal([]byte(v), &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
case []byte:
if err := json.Unmarshal(v, &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
default:
data, err := json.Marshal(v)
if err != nil {
metrics.decodeError.Add(1)
return zero, false
}
if err := json.Unmarshal(data, &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
}
}
func GetOrLoadTyped[T any](
c Cache,
key string,
ttl time.Duration,
jitterRatio float64,
nullTTL time.Duration,
loader func() (T, error),
) (T, error) {
if cached, ok := GetTyped[T](c, key); ok {
return cached, nil
}
lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
lock := lockValue.(*sync.Mutex)
lock.Lock()
defer lock.Unlock()
if cached, ok := GetTyped[T](c, key); ok {
return cached, nil
}
loaded, err := loader()
if err != nil {
var zero T
return zero, err
}
encoded, marshalErr := json.Marshal(loaded)
if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
SetNull(c, key, nullTTL)
return loaded, nil
}
SetWithJitter(c, key, loaded, ttl, jitterRatio)
return loaded, nil
}

147
internal/cache/keys.go vendored Normal file
View File

@@ -0,0 +1,147 @@
package cache
import (
"fmt"
)
// 缓存键前缀常量
const (
// 帖子相关
PrefixPostList = "posts:list"
PrefixPost = "posts:detail"
// 会话相关
PrefixConversationList = "conversations:list"
PrefixConversationDetail = "conversations:detail"
// 群组相关
PrefixGroupMembers = "groups:members"
PrefixGroupInfo = "groups:info"
// 未读数相关
PrefixUnreadSystem = "unread:system"
PrefixUnreadConversation = "unread:conversation"
PrefixUnreadDetail = "unread:detail"
// 用户相关
PrefixUserInfo = "users:info"
PrefixUserMe = "users:me"
)
// PostListKey 生成帖子列表缓存键
// postType: 帖子类型 (recommend, hot, follow, latest)
// page: 页码
// pageSize: 每页数量
// userID: 用户维度(仅在个性化列表如 follow 场景使用)
func PostListKey(postType string, userID string, page, pageSize int) string {
if userID == "" {
return fmt.Sprintf("%s:%s:%d:%d", PrefixPostList, postType, page, pageSize)
}
return fmt.Sprintf("%s:%s:%s:%d:%d", PrefixPostList, postType, userID, page, pageSize)
}
// PostDetailKey 生成帖子详情缓存键
func PostDetailKey(postID string) string {
return fmt.Sprintf("%s:%s", PrefixPost, postID)
}
// ConversationListKey 生成会话列表缓存键
func ConversationListKey(userID string, page, pageSize int) string {
return fmt.Sprintf("%s:%s:%d:%d", PrefixConversationList, userID, page, pageSize)
}
// ConversationDetailKey 生成会话详情缓存键
func ConversationDetailKey(conversationID, userID string) string {
return fmt.Sprintf("%s:%s:%s", PrefixConversationDetail, conversationID, userID)
}
// GroupMembersKey 生成群组成员缓存键
func GroupMembersKey(groupID string, page, pageSize int) string {
return fmt.Sprintf("%s:%s:page:%d:size:%d", PrefixGroupMembers, groupID, page, pageSize)
}
// GroupMembersAllKey 生成群组全量成员ID列表缓存键
func GroupMembersAllKey(groupID string) string {
return fmt.Sprintf("%s:all:%s", PrefixGroupMembers, groupID)
}
// GroupInfoKey 生成群组信息缓存键
func GroupInfoKey(groupID string) string {
return fmt.Sprintf("%s:%s", PrefixGroupInfo, groupID)
}
// UnreadSystemKey 生成系统消息未读数缓存键
func UnreadSystemKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUnreadSystem, userID)
}
// UnreadConversationKey 生成会话未读总数缓存键
func UnreadConversationKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUnreadConversation, userID)
}
// UnreadDetailKey 生成单个会话未读数缓存键
func UnreadDetailKey(userID, conversationID string) string {
return fmt.Sprintf("%s:%s:%s", PrefixUnreadDetail, userID, conversationID)
}
// UserInfoKey 生成用户信息缓存键
func UserInfoKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUserInfo, userID)
}
// UserMeKey 生成当前用户信息缓存键
func UserMeKey(userID string) string {
return fmt.Sprintf("%s:%s", PrefixUserMe, userID)
}
// InvalidatePostList 失效帖子列表缓存
func InvalidatePostList(cache Cache) {
cache.DeleteByPrefix(PrefixPostList)
}
// InvalidatePostDetail 失效帖子详情缓存
func InvalidatePostDetail(cache Cache, postID string) {
cache.Delete(PostDetailKey(postID))
}
// InvalidateConversationList 失效会话列表缓存
func InvalidateConversationList(cache Cache, userID string) {
cache.DeleteByPrefix(PrefixConversationList + ":" + userID + ":")
}
// InvalidateConversationDetail 失效会话详情缓存
func InvalidateConversationDetail(cache Cache, conversationID, userID string) {
cache.Delete(ConversationDetailKey(conversationID, userID))
}
// InvalidateGroupMembers 失效群组成员缓存
func InvalidateGroupMembers(cache Cache, groupID string) {
cache.DeleteByPrefix(PrefixGroupMembers + ":" + groupID)
}
// InvalidateGroupInfo 失效群组信息缓存
func InvalidateGroupInfo(cache Cache, groupID string) {
cache.Delete(GroupInfoKey(groupID))
}
// InvalidateUnreadSystem 失效系统消息未读数缓存
func InvalidateUnreadSystem(cache Cache, userID string) {
cache.Delete(UnreadSystemKey(userID))
}
// InvalidateUnreadConversation 失效会话未读数缓存
func InvalidateUnreadConversation(cache Cache, userID string) {
cache.Delete(UnreadConversationKey(userID))
}
// InvalidateUnreadDetail 失效单个会话未读数缓存
func InvalidateUnreadDetail(cache Cache, userID, conversationID string) {
cache.Delete(UnreadDetailKey(userID, conversationID))
}
// InvalidateUserInfo 失效用户信息缓存
func InvalidateUserInfo(cache Cache, userID string) {
cache.Delete(UserInfoKey(userID))
cache.Delete(UserMeKey(userID))
}