Files
backend/internal/cache/cache.go
lan 4d8f2ec997 Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
2026-03-09 21:28:58 +08:00

605 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package cache
import (
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/redis/go-redis/v9"
redisPkg "carrot_bbs/internal/pkg/redis"
)
// Cache is the key/value store abstraction used by the rest of the
// application. MemoryCache and RedisCache in this package both implement it.
type Cache interface {
	// Set stores value under key, with an optional time-to-live.
	Set(key string, value any, ttl time.Duration)
	// Get returns the value stored under key and whether it was found.
	Get(key string) (any, bool)
	// Delete removes key.
	Delete(key string)
	// DeleteByPrefix removes every key beginning with prefix.
	DeleteByPrefix(prefix string)
	// Clear empties the cache.
	Clear()
	// Exists reports whether key is present.
	Exists(key string) bool
	// Increment adds one to the counter stored under key and returns the total.
	Increment(key string) int64
	// IncrementBy adds value to the counter stored under key and returns the total.
	IncrementBy(key string, value int64) int64
}
// cacheItem is one entry held by MemoryCache (the in-memory fallback cache).
type cacheItem struct {
	value      interface{}
	expiration int64 // expiry instant in Unix nanoseconds; 0 means the item never expires
}
// nullMarkerValue is the sentinel SetNull stores under a key to remember that
// the loaded value was empty (negative caching). GetTyped reports it as
// "not found" rather than returning it as data.
const nullMarkerValue = "__carrot_cache_null__"
// cacheMetrics holds the package-wide cache counters; GetMetricsSnapshot
// reads them.
type cacheMetrics struct {
	hit         atomic.Int64 // GetTyped found a usable value or the SetNull marker
	miss        atomic.Int64 // GetTyped found no entry for the key
	decodeError atomic.Int64 // GetTyped could not convert the stored value to the requested type
	setError    atomic.Int64 // RedisCache.Set failed to encode or write a value
	invalidate  atomic.Int64 // removals by Delete/DeleteByPrefix/Clear (one per RedisCache FLUSHDB)
}
var (
	// metrics accumulates the package-wide counters exposed through
	// GetMetricsSnapshot.
	metrics cacheMetrics

	// loadLocks maps a cache key to the *sync.Mutex that GetOrLoadTyped uses
	// to serialize loads of that key.
	loadLocks sync.Map
)
// MetricsSnapshot is a plain-value copy of the cache counters, as returned by
// GetMetricsSnapshot. Each field mirrors the cacheMetrics counter of the same
// name.
type MetricsSnapshot struct {
	Hit         int64
	Miss        int64
	DecodeError int64
	SetError    int64
	Invalidate  int64
}
// Settings configures the cache package. Configure merges a Settings value
// into the package defaults; GetSettings returns the active copy.
type Settings struct {
	// Enabled gates SetWithJitter, SetNull and GetTyped: when false the setters
	// do nothing and GetTyped returns (zero, false).
	Enabled bool
	// KeyPrefix, when non-empty, namespaces every key as "<KeyPrefix>:<key>".
	KeyPrefix string
	// The TTL and jitter fields below are not read elsewhere in this file;
	// presumably callers fetch them through GetSettings — confirm.
	DefaultTTL      time.Duration
	NullTTL         time.Duration
	JitterRatio     float64
	PostListTTL     time.Duration
	ConversationTTL time.Duration
	UnreadCountTTL  time.Duration
	GroupMembersTTL time.Duration
	// DisableFlushDB makes RedisCache.Clear log and skip FLUSHDB.
	DisableFlushDB bool
}
// settings is the active cache configuration. The literal below supplies the
// defaults; Configure overrides individual fields.
var settings = Settings{
	// Caching is on and FLUSHDB is refused unless Configure says otherwise.
	Enabled:        true,
	DisableFlushDB: true,

	// Expiry defaults.
	DefaultTTL:      30 * time.Second,
	NullTTL:         5 * time.Second,
	JitterRatio:     0.1,
	PostListTTL:     30 * time.Second,
	ConversationTTL: time.Minute,
	UnreadCountTTL:  30 * time.Second,
	GroupMembersTTL: 2 * time.Minute,
}
// Configure merges s into the active settings.
//
// Enabled and DisableFlushDB are always copied from s. KeyPrefix is copied
// only when non-empty. JitterRatio and every TTL are copied only when
// positive, so zero values in s keep the current setting.
func Configure(s Settings) {
	// Boolean switches are taken from s unconditionally.
	settings.Enabled = s.Enabled
	settings.DisableFlushDB = s.DisableFlushDB

	if s.KeyPrefix != "" {
		settings.KeyPrefix = s.KeyPrefix
	}
	if s.JitterRatio > 0 {
		settings.JitterRatio = s.JitterRatio
	}

	// A zero or negative duration means "keep the current value".
	for _, o := range []struct {
		dst *time.Duration
		src time.Duration
	}{
		{&settings.DefaultTTL, s.DefaultTTL},
		{&settings.NullTTL, s.NullTTL},
		{&settings.PostListTTL, s.PostListTTL},
		{&settings.ConversationTTL, s.ConversationTTL},
		{&settings.UnreadCountTTL, s.UnreadCountTTL},
		{&settings.GroupMembersTTL, s.GroupMembersTTL},
	} {
		if o.src > 0 {
			*o.dst = o.src
		}
	}
}
// GetSettings returns a copy of the active cache settings. Settings holds
// only value fields, so modifying the result does not affect the package;
// use Configure for that.
func GetSettings() Settings {
	current := settings
	return current
}
// normalizeKey places key in the configured namespace, producing
// "<KeyPrefix>:<key>". The key is returned unchanged when no KeyPrefix is set.
func normalizeKey(key string) string {
	if p := settings.KeyPrefix; p != "" {
		return p + ":" + key
	}
	return key
}
func normalizePrefix(prefix string) string {
if settings.KeyPrefix == "" {
return prefix
}
return settings.KeyPrefix + ":" + prefix
}
// GetMetricsSnapshot returns the current value of every cache counter. Each
// counter is loaded atomically, but the set as a whole is not captured as one
// atomic unit.
func GetMetricsSnapshot() MetricsSnapshot {
	var snap MetricsSnapshot
	snap.Hit = metrics.hit.Load()
	snap.Miss = metrics.miss.Load()
	snap.DecodeError = metrics.decodeError.Load()
	snap.SetError = metrics.setError.Load()
	snap.Invalidate = metrics.invalidate.Load()
	return snap
}
// isExpired reports whether the item's deadline has passed. An item with a
// zero expiration never expires, and no clock read is made for it.
func (item *cacheItem) isExpired() bool {
	return item.expiration != 0 && time.Now().UnixNano() > item.expiration
}
// MemoryCache is an in-process Cache backed by a sync.Map. InitCache uses it
// when no Redis client is supplied, and GetCache falls back to it before
// InitCache has run.
type MemoryCache struct {
	// items maps normalized keys to *cacheItem entries.
	items sync.Map
	// cleanupInterval is how often the background goroutine purges expired entries.
	cleanupInterval time.Duration
	// stopCleanup is closed by Stop to end the background cleanup goroutine.
	stopCleanup chan struct{}
}
// NewMemoryCache builds an empty MemoryCache and starts its background
// goroutine, which purges expired entries once a minute until Stop is called.
func NewMemoryCache() *MemoryCache {
	mc := new(MemoryCache)
	mc.cleanupInterval = time.Minute
	mc.stopCleanup = make(chan struct{})
	go mc.cleanup()
	return mc
}
// Set stores value under key. A positive ttl makes the entry expire after
// that duration; a ttl <= 0 stores it without expiry.
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
	item := &cacheItem{value: value}
	if ttl > 0 {
		item.expiration = time.Now().Add(ttl).UnixNano()
	}
	c.items.Store(normalizeKey(key), item)
}
// Get returns the value stored under key and true. It returns (nil, false)
// when the key is absent or its entry has expired; expired entries are
// removed lazily here.
func (c *MemoryCache) Get(key string) (interface{}, bool) {
	key = normalizeKey(key)
	val, ok := c.items.Load(key)
	if !ok {
		return nil, false
	}
	item := val.(*cacheItem)
	if item.isExpired() {
		// Remove only the exact entry we saw expire. A plain Delete could
		// discard a fresh value another goroutine stored after our Load.
		c.items.CompareAndDelete(key, val)
		return nil, false
	}
	return item.value, true
}
// Delete removes key from the cache. It counts one invalidation whether or
// not the key was present.
func (c *MemoryCache) Delete(key string) {
	normalized := normalizeKey(key)
	metrics.invalidate.Add(1)
	c.items.Delete(normalized)
}
// DeleteByPrefix removes every entry whose (namespaced) key starts with
// prefix, counting one invalidation per removed key.
func (c *MemoryCache) DeleteByPrefix(prefix string) {
	target := normalizePrefix(prefix)
	c.items.Range(func(k, _ interface{}) bool {
		name, isString := k.(string)
		if !isString || !strings.HasPrefix(name, target) {
			return true
		}
		metrics.invalidate.Add(1)
		c.items.Delete(k)
		return true
	})
}
// Clear removes every entry, counting one invalidation per entry visited.
func (c *MemoryCache) Clear() {
	c.items.Range(func(k, _ interface{}) bool {
		metrics.invalidate.Add(1)
		c.items.Delete(k)
		return true
	})
}
// Exists reports whether key currently holds a live (unexpired) entry. It is
// implemented via Get, so an expired entry found here is also removed.
func (c *MemoryCache) Exists(key string) bool {
	if _, found := c.Get(key); found {
		return true
	}
	return false
}
// Increment adds one to the counter stored under key and returns the new
// total; it is shorthand for IncrementBy(key, 1).
func (c *MemoryCache) Increment(key string) int64 {
	const step int64 = 1
	return c.IncrementBy(key, step)
}
// IncrementBy adds value to the int64 counter stored under key and returns
// the new total.
//
// Behaviour by case:
//   - A missing or expired key starts a fresh counter at value, with no expiry.
//   - A live entry that does not hold an int64 is overwritten with value, and
//     its expiry is kept.
//
// Every write goes through LoadOrStore or CompareAndSwap and retries on
// conflict. A plain Store would let racing callers overwrite each other and
// lose increments.
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
	key = normalizeKey(key)
	for {
		current, ok := c.items.Load(key)
		if !ok {
			if _, loaded := c.items.LoadOrStore(key, &cacheItem{value: value}); !loaded {
				return value
			}
			// Another goroutine created the entry first; retry against it.
			continue
		}

		item := current.(*cacheItem)
		n, isInt := item.value.(int64)

		var result int64
		var next *cacheItem
		switch {
		case item.isExpired():
			// Expired: restart the counter without an expiry.
			result = value
			next = &cacheItem{value: value}
		case isInt:
			result = n + value
			next = &cacheItem{value: result, expiration: item.expiration}
		default:
			// Not a counter: overwrite with value but keep the expiry.
			result = value
			next = &cacheItem{value: value, expiration: item.expiration}
		}

		if c.items.CompareAndSwap(key, current, next) {
			return result
		}
		// The entry changed or was removed since Load; retry.
	}
}
// cleanup is the background loop started by NewMemoryCache. It calls
// cleanExpired every cleanupInterval and returns once stopCleanup is closed.
func (c *MemoryCache) cleanup() {
	tick := time.NewTicker(c.cleanupInterval)
	defer tick.Stop()
	for {
		select {
		case <-c.stopCleanup:
			return
		case <-tick.C:
			c.cleanExpired()
		}
	}
}
// cleanExpired removes every expired entry and logs how many were removed.
//
// Each entry is deleted only if it still holds the item that was seen to
// expire. A concurrent Set of the same key after the check is therefore left
// intact and is not counted.
func (c *MemoryCache) cleanExpired() {
	removed := 0
	c.items.Range(func(key, value interface{}) bool {
		if value.(*cacheItem).isExpired() && c.items.CompareAndDelete(key, value) {
			removed++
		}
		return true
	})
	if removed > 0 {
		log.Printf("[Cache] Cleaned %d expired items", removed)
	}
}
// Stop terminates the background cleanup goroutine started by NewMemoryCache.
//
// Calling Stop again after it has returned is a no-op instead of a panic from
// closing an already-closed channel. Concurrent calls to Stop are still not
// synchronized with each other.
func (c *MemoryCache) Stop() {
	select {
	case <-c.stopCleanup:
		// Already stopped.
	default:
		close(c.stopCleanup)
	}
}
// RedisCache implements Cache on top of the project's Redis client wrapper.
// Set stores values JSON-encoded, and Get returns the stored payload
// undecoded, leaving decoding to the caller.
type RedisCache struct {
	client *redisPkg.Client
	// ctx is used for every Redis call. NewRedisCache sets it to
	// context.Background(), so calls made through this type are never cancelled.
	ctx context.Context
}
// NewRedisCache wraps client in a RedisCache whose calls all use
// context.Background().
func NewRedisCache(client *redisPkg.Client) *RedisCache {
	rc := &RedisCache{client: client}
	rc.ctx = context.Background()
	return rc
}
// Set JSON-encodes value and writes it to Redis under the namespaced key,
// passing ttl straight to the client wrapper.
//
// Encoding and write failures are counted in the set-error metric and logged.
// Nothing is reported to the caller.
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
	fullKey := normalizeKey(key)
	payload, marshalErr := json.Marshal(value)
	if marshalErr != nil {
		metrics.setError.Add(1)
		log.Printf("[RedisCache] Failed to marshal value for key %s: %v", fullKey, marshalErr)
		return
	}
	if setErr := c.client.Set(c.ctx, fullKey, payload, ttl); setErr != nil {
		metrics.setError.Add(1)
		log.Printf("[RedisCache] Failed to set key %s: %v", fullKey, setErr)
	}
}
// Get fetches key from Redis and returns the payload as the client wrapper
// returned it, undecoded. For values written by Set that payload is JSON
// text; callers such as GetTyped decode it.
//
// A missing key (redis.Nil) yields (nil, false) silently. Any other error is
// logged and also yields (nil, false).
func (c *RedisCache) Get(key string) (interface{}, bool) {
	fullKey := normalizeKey(key)
	payload, err := c.client.Get(c.ctx, fullKey)
	switch {
	case err == nil:
		return payload, true
	case err == redis.Nil:
		return nil, false
	default:
		log.Printf("[RedisCache] Failed to get key %s: %v", fullKey, err)
		return nil, false
	}
}
// Delete removes key from Redis. It counts one invalidation up front and logs,
// rather than returns, any error.
func (c *RedisCache) Delete(key string) {
	fullKey := normalizeKey(key)
	metrics.invalidate.Add(1)
	err := c.client.Del(c.ctx, fullKey)
	if err == nil {
		return
	}
	log.Printf("[RedisCache] Failed to delete key %s: %v", fullKey, err)
}
// DeleteByPrefix removes every Redis key that starts with prefix, after the
// configured KeyPrefix namespace has been applied.
//
// Keys are enumerated with SCAN (COUNT hint 100) and deleted batch by batch.
// A SCAN error aborts the sweep; a DEL error is logged and the sweep
// continues.
//
// The prefix is escaped before it is used as the SCAN MATCH glob. Otherwise
// characters such as '*', '?' or '[' inside it would act as wildcards and the
// sweep could delete unrelated keys.
func (c *RedisCache) DeleteByPrefix(prefix string) {
	prefix = normalizePrefix(prefix)
	pattern := escapeRedisGlob(prefix) + "*"
	// SCAN is not exposed by the wrapper, so use the underlying client.
	rdb := c.client.GetClient()
	var cursor uint64
	for {
		keys, next, err := rdb.Scan(c.ctx, cursor, pattern, 100).Result()
		if err != nil {
			log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
			return
		}
		if len(keys) > 0 {
			metrics.invalidate.Add(int64(len(keys)))
			if err := c.client.Del(c.ctx, keys...); err != nil {
				log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
			}
		}
		// A zero cursor means the full iteration is complete.
		if next == 0 {
			return
		}
		cursor = next
	}
}

// escapeRedisGlob backslash-escapes the characters Redis treats as glob
// syntax in MATCH/KEYS patterns, so s is matched literally. The metacharacters
// are ASCII, so byte-wise processing is safe for UTF-8 input.
func escapeRedisGlob(s string) string {
	const special = `*?[]\`
	if !strings.ContainsAny(s, special) {
		return s
	}
	var b strings.Builder
	b.Grow(len(s) + 4)
	for i := 0; i < len(s); i++ {
		if strings.IndexByte(special, s[i]) >= 0 {
			b.WriteByte('\\')
		}
		b.WriteByte(s[i])
	}
	return b.String()
}
// Clear wipes the Redis database with FLUSHDB.
//
// When settings.DisableFlushDB is set (the default), it only logs and returns.
// FLUSHDB takes no key pattern, so it is not limited to keys under
// settings.KeyPrefix.
func (c *RedisCache) Clear() {
	if !settings.DisableFlushDB {
		metrics.invalidate.Add(1)
		if err := c.client.GetClient().FlushDB(c.ctx).Err(); err != nil {
			log.Printf("[RedisCache] Failed to clear cache: %v", err)
		}
		return
	}
	log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
}
// Exists reports whether key is present in Redis. A Redis error is logged and
// reported as false.
func (c *RedisCache) Exists(key string) bool {
	fullKey := normalizeKey(key)
	count, err := c.client.Exists(c.ctx, fullKey)
	if err == nil {
		return count > 0
	}
	log.Printf("[RedisCache] Failed to check existence of key %s: %v", fullKey, err)
	return false
}
// Increment adds one to the counter stored under key and returns the new
// total; it is shorthand for IncrementBy(key, 1).
func (c *RedisCache) Increment(key string) int64 {
	const step int64 = 1
	return c.IncrementBy(key, step)
}
// IncrementBy atomically adds value to the counter under key using INCRBY and
// returns the new total. On a Redis error it logs and returns 0.
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
	fullKey := normalizeKey(key)
	total, err := c.client.GetClient().IncrBy(c.ctx, fullKey, value).Result()
	if err == nil {
		return total
	}
	log.Printf("[RedisCache] Failed to increment key %s: %v", fullKey, err)
	return 0
}
var (
	// globalCache is the process-wide Cache installed by InitCache.
	globalCache Cache
	// once makes InitCache take effect only on its first call.
	once sync.Once
)
// InitCache installs the global cache once. It uses a RedisCache when
// redisClient is non-nil and a MemoryCache otherwise. Subsequent calls are
// ignored.
func InitCache(redisClient *redisPkg.Client) {
	once.Do(func() {
		if redisClient == nil {
			globalCache = NewMemoryCache()
			log.Println("[Cache] Initialized Memory cache (Redis not available)")
			return
		}
		globalCache = NewRedisCache(redisClient)
		log.Println("[Cache] Initialized Redis cache")
	})
}
// fallbackCache is the single MemoryCache that GetCache hands out while
// InitCache has not yet installed globalCache. fallbackOnce guards its
// creation.
var (
	fallbackCache Cache
	fallbackOnce  sync.Once
)

// GetCache returns the process-wide cache installed by InitCache.
//
// Before InitCache has run, it returns a shared in-memory fallback instead.
// The fallback is created once, and the warning is logged once. Creating a
// fresh MemoryCache on every call would start a new cleanup goroutine each
// time. It would also give every caller an empty private cache, so nothing
// written through it could ever be read back.
func GetCache() Cache {
	if globalCache != nil {
		return globalCache
	}
	fallbackOnce.Do(func() {
		log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
		fallbackCache = NewMemoryCache()
	})
	return fallbackCache
}
// GetRedisClient returns the Redis client behind the global cache. It returns
// an error when the global cache is not a RedisCache, which includes the case
// where InitCache has not run.
func GetRedisClient() (*redisPkg.Client, error) {
	rc, isRedis := globalCache.(*RedisCache)
	if !isRedis {
		return nil, fmt.Errorf("cache is not using Redis backend")
	}
	return rc.client, nil
}
// SetWithJitter stores value under key with ttl extended by
// ApplyTTLJitter(ttl, jitterRatio). It does nothing when caching is disabled.
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
	if settings.Enabled {
		jittered := ApplyTTLJitter(ttl, jitterRatio)
		c.Set(key, value, jittered)
	}
}
// SetNull records a negative-cache entry: it stores nullMarkerValue under key
// for exactly ttl, with no jitter. It does nothing when caching is disabled.
func SetNull(c Cache, key string, ttl time.Duration) {
	if settings.Enabled {
		c.Set(key, nullMarkerValue, ttl)
	}
}
// ApplyTTLJitter returns ttl extended by a random amount in
// [0, ttl*jitterRatio], so that keys written together do not all expire at
// the same instant.
//
// ttl is returned unchanged when:
//   - ttl is not positive,
//   - jitterRatio is not positive or is NaN, or
//   - the computed jitter rounds to zero.
//
// jitterRatio is capped at 1. The jitter is clamped so the result never
// overflows time.Duration into a negative value.
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
	// !(x > 0) also rejects NaN, which a plain x <= 0 test would let through.
	if ttl <= 0 || !(jitterRatio > 0) {
		return ttl
	}
	if jitterRatio > 1 {
		jitterRatio = 1
	}

	// Largest delta keeping ttl+delta within int64. Because ttl >= 1 it is at
	// most MaxInt64-1, so the +1 passed to Int63n below cannot overflow.
	const maxDuration = time.Duration(1<<63 - 1)
	headroom := int64(maxDuration - ttl)

	// Compare in float64 before converting: converting a float64 beyond the
	// int64 range is implementation-defined in Go.
	jitter := float64(ttl) * jitterRatio
	var maxJitter int64
	if jitter >= float64(headroom) {
		maxJitter = headroom
	} else {
		maxJitter = int64(jitter)
	}
	if maxJitter <= 0 {
		return ttl
	}
	delta := rand.Int63n(maxJitter + 1)
	return ttl + time.Duration(delta)
}
func GetTyped[T any](c Cache, key string) (T, bool) {
var zero T
if !settings.Enabled {
return zero, false
}
raw, ok := c.Get(key)
if !ok {
metrics.miss.Add(1)
return zero, false
}
if str, ok := raw.(string); ok && str == nullMarkerValue {
metrics.hit.Add(1)
return zero, false
}
if typed, ok := raw.(T); ok {
metrics.hit.Add(1)
return typed, true
}
var out T
switch v := raw.(type) {
case string:
if err := json.Unmarshal([]byte(v), &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
case []byte:
if err := json.Unmarshal(v, &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
default:
data, err := json.Marshal(v)
if err != nil {
metrics.decodeError.Add(1)
return zero, false
}
if err := json.Unmarshal(data, &out); err != nil {
metrics.decodeError.Add(1)
return zero, false
}
metrics.hit.Add(1)
return out, true
}
}
func GetOrLoadTyped[T any](
c Cache,
key string,
ttl time.Duration,
jitterRatio float64,
nullTTL time.Duration,
loader func() (T, error),
) (T, error) {
if cached, ok := GetTyped[T](c, key); ok {
return cached, nil
}
lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
lock := lockValue.(*sync.Mutex)
lock.Lock()
defer lock.Unlock()
if cached, ok := GetTyped[T](c, key); ok {
return cached, nil
}
loaded, err := loader()
if err != nil {
var zero T
return zero, err
}
encoded, marshalErr := json.Marshal(loaded)
if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
SetNull(c, key, nullTTL)
return loaded, nil
}
SetWithJitter(c, key, loaded, ttl, jitterRatio)
return loaded, nil
}