// Package cache provides a Cache abstraction with a Redis-backed implementation,
// an in-process memory fallback, and typed read-through helpers.
package cache

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math"
	"math/rand"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"

	redisPkg "carrot_bbs/internal/pkg/redis"
)

// Cache is the cache interface implemented by RedisCache and MemoryCache.
type Cache interface {
	// Set stores value under key with the given TTL.
	Set(key string, value interface{}, ttl time.Duration)
	// Get returns the cached value for key and whether it was found.
	Get(key string) (interface{}, bool)
	// Delete removes key.
	Delete(key string)
	// DeleteByPrefix removes every key starting with prefix.
	DeleteByPrefix(prefix string)
	// Clear removes all cached entries.
	Clear()
	// Exists reports whether key is present.
	Exists(key string) bool
	// Increment adds 1 to the counter at key and returns the new value.
	Increment(key string) int64
	// IncrementBy adds value to the counter at key and returns the new value.
	IncrementBy(key string, value int64) int64

	// ==================== Hash operations ====================

	// HSet sets a single hash field.
	HSet(ctx context.Context, key string, field string, value interface{}) error
	// HMSet sets several hash fields.
	HMSet(ctx context.Context, key string, values map[string]interface{}) error
	// HGet returns the value of one hash field.
	HGet(ctx context.Context, key string, field string) (string, error)
	// HMGet returns the values of several hash fields.
	HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error)
	// HGetAll returns every field of a hash.
	HGetAll(ctx context.Context, key string) (map[string]string, error)
	// HDel deletes hash fields.
	HDel(ctx context.Context, key string, fields ...string) error

	// ==================== Sorted set operations ====================

	// ZAdd adds (or re-scores) a sorted-set member.
	ZAdd(ctx context.Context, key string, score float64, member string) error
	// ZRangeByScore returns members whose score lies in [min, max], ascending.
	ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error)
	// ZRevRangeByScore returns members whose score lies in [min, max], descending.
	ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error)
	// ZRem removes sorted-set members.
	ZRem(ctx context.Context, key string, members ...interface{}) error
	// ZCard returns the number of sorted-set members.
	ZCard(ctx context.Context, key string) (int64, error)

	// ==================== Counter operations ====================

	// Incr increments key by one and returns the new value.
	Incr(ctx context.Context, key string) (int64, error)
	// Expire sets the TTL of key.
	Expire(ctx context.Context, key string, ttl time.Duration) error
}

// cacheItem is one entry of the memory (fallback) cache.
type cacheItem struct {
	value      interface{}
	expiration int64 // absolute expiry as UnixNano; 0 means "never expires"
}

// nullMarkerValue is stored by SetNull to cache a "no value" result.
const nullMarkerValue = "__carrot_cache_null__"

// nullMarkerJSON is nullMarkerValue as it comes back from RedisCache, which
// JSON-encodes every value in Set.
const nullMarkerJSON = `"` + nullMarkerValue + `"`

type cacheMetrics struct {
	hit         atomic.Int64
	miss        atomic.Int64
	decodeError atomic.Int64
	setError    atomic.Int64
	invalidate  atomic.Int64
}

var metrics cacheMetrics

// loadLocks maps a cache key to the *sync.Mutex serialising GetOrLoadTyped
// loaders for that key. Entries are removed once their load finishes.
var loadLocks sync.Map

// MetricsSnapshot is a point-in-time copy of the cache counters.
type MetricsSnapshot struct {
	Hit         int64
	Miss        int64
	DecodeError int64
	SetError    int64
	Invalidate  int64
}

// Settings controls cache behaviour package-wide.
type Settings struct {
	Enabled         bool
	KeyPrefix       string
	DefaultTTL      time.Duration
	NullTTL         time.Duration
	JitterRatio     float64
	PostListTTL     time.Duration
	ConversationTTL time.Duration
	UnreadCountTTL  time.Duration
	GroupMembersTTL time.Duration
	DisableFlushDB  bool
}

var settings = Settings{
	Enabled:         true,
	DefaultTTL:      30 * time.Second,
	NullTTL:         5 * time.Second,
	JitterRatio:     0.1,
	PostListTTL:     30 * time.Second,
	ConversationTTL: 60 * time.Second,
	UnreadCountTTL:  30 * time.Second,
	GroupMembersTTL: 120 * time.Second,
	DisableFlushDB:  true,
}

// Configure applies s to the package settings.
//
// Enabled and DisableFlushDB are always copied. Every other field is copied
// only when non-zero (positive for durations and JitterRatio); otherwise the
// current value is kept.
func Configure(s Settings) {
	settings.Enabled = s.Enabled
	if s.KeyPrefix != "" {
		settings.KeyPrefix = s.KeyPrefix
	}
	if s.DefaultTTL > 0 {
		settings.DefaultTTL = s.DefaultTTL
	}
	if s.NullTTL > 0 {
		settings.NullTTL = s.NullTTL
	}
	if s.JitterRatio > 0 {
		settings.JitterRatio = s.JitterRatio
	}
	if s.PostListTTL > 0 {
		settings.PostListTTL = s.PostListTTL
	}
	if s.ConversationTTL > 0 {
		settings.ConversationTTL = s.ConversationTTL
	}
	if s.UnreadCountTTL > 0 {
		settings.UnreadCountTTL = s.UnreadCountTTL
	}
	if s.GroupMembersTTL > 0 {
		settings.GroupMembersTTL = s.GroupMembersTTL
	}
	settings.DisableFlushDB = s.DisableFlushDB
}

// GetSettings returns a copy of the current settings.
func GetSettings() Settings {
	return settings
}

// normalizeKey prepends the configured KeyPrefix (joined with ":") to key.
func normalizeKey(key string) string {
	if settings.KeyPrefix == "" {
		return key
	}
	return settings.KeyPrefix + ":" + key
}

// normalizePrefix applies the same KeyPrefix namespacing to a key prefix.
func normalizePrefix(prefix string) string {
	if settings.KeyPrefix == "" {
		return prefix
	}
	return settings.KeyPrefix + ":" + prefix
}

// GetMetricsSnapshot returns the current values of all cache counters.
func GetMetricsSnapshot() MetricsSnapshot {
	return MetricsSnapshot{
		Hit:         metrics.hit.Load(),
		Miss:        metrics.miss.Load(),
		DecodeError: metrics.decodeError.Load(),
		SetError:    metrics.setError.Load(),
		Invalidate:  metrics.invalidate.Load(),
	}
}

// isExpired reports whether the item has passed its expiry time.
func (item *cacheItem) isExpired() bool {
	if item.expiration == 0 {
		return false
	}
	return time.Now().UnixNano() > item.expiration
}

// MemoryCache is the in-process Cache implementation used as a fallback.
type MemoryCache struct {
	items sync.Map // normalized key -> *cacheItem
	// cleanupInterval is how often expired entries are purged.
	cleanupInterval time.Duration
	// stopCleanup is closed by Stop to end the cleanup goroutine.
	stopCleanup chan struct{}
	// stopOnce makes Stop safe to call more than once.
	stopOnce sync.Once
}

// NewMemoryCache creates a MemoryCache and starts its background cleanup
// goroutine. Call Stop to end that goroutine.
func NewMemoryCache() *MemoryCache {
	c := &MemoryCache{
		cleanupInterval: 1 * time.Minute,
		stopCleanup:     make(chan struct{}),
	}
	go c.cleanup()
	return c
}

// loadLive returns the unexpired item stored under the already-normalized key.
// An expired item is evicted, but only if it is still the stored value, so a
// concurrent fresh write is never discarded.
func (c *MemoryCache) loadLive(key string) (*cacheItem, bool) {
	val, ok := c.items.Load(key)
	if !ok {
		return nil, false
	}
	item := val.(*cacheItem)
	if item.isExpired() {
		c.items.CompareAndDelete(key, val)
		return nil, false
	}
	return item, true
}

// loadOrCreate returns the live value stored under the already-normalized key.
// When the key is absent or expired, it atomically installs newValue() with no
// expiry. This avoids the lost-write race of a separate Load followed by Store.
func (c *MemoryCache) loadOrCreate(key string, newValue func() interface{}) interface{} {
	for {
		val, ok := c.items.Load(key)
		if !ok {
			fresh := &cacheItem{value: newValue()}
			if _, loaded := c.items.LoadOrStore(key, fresh); !loaded {
				return fresh.value
			}
			continue // someone else created it first; re-read
		}
		item := val.(*cacheItem)
		if !item.isExpired() {
			return item.value
		}
		fresh := &cacheItem{value: newValue()}
		if c.items.CompareAndSwap(key, val, fresh) {
			return fresh.value
		}
	}
}

// Set stores value under key. A ttl <= 0 means the entry never expires.
func (c *MemoryCache) Set(key string, value interface{}, ttl time.Duration) {
	key = normalizeKey(key)
	var expiration int64
	if ttl > 0 {
		expiration = time.Now().Add(ttl).UnixNano()
	}
	c.items.Store(key, &cacheItem{
		value:      value,
		expiration: expiration,
	})
}

// Get returns the stored value (not a copy) and whether a live entry exists.
func (c *MemoryCache) Get(key string) (interface{}, bool) {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return nil, false
	}
	return item.value, true
}

// Delete removes key and counts one invalidation.
func (c *MemoryCache) Delete(key string) {
	key = normalizeKey(key)
	metrics.invalidate.Add(1)
	c.items.Delete(key)
}

// DeleteByPrefix removes every entry whose normalized key starts with the
// normalized prefix, counting one invalidation per removed entry.
func (c *MemoryCache) DeleteByPrefix(prefix string) {
	prefix = normalizePrefix(prefix)
	c.items.Range(func(key, value interface{}) bool {
		if keyStr, ok := key.(string); ok && strings.HasPrefix(keyStr, prefix) {
			metrics.invalidate.Add(1)
			c.items.Delete(key)
		}
		return true
	})
}

// Clear removes every entry, counting one invalidation per entry.
func (c *MemoryCache) Clear() {
	c.items.Range(func(key, value interface{}) bool {
		metrics.invalidate.Add(1)
		c.items.Delete(key)
		return true
	})
}

// Exists reports whether a live entry exists for key.
func (c *MemoryCache) Exists(key string) bool {
	_, ok := c.Get(key)
	return ok
}

// Increment adds 1 to the counter at key and returns the new value.
func (c *MemoryCache) Increment(key string) int64 {
	return c.IncrementBy(key, 1)
}

// IncrementBy adds value to the int64 counter at key and returns the result.
//
// If the key is missing or expired, the counter restarts at value with no
// expiry. If the key holds a non-int64 value, that value is overwritten with
// value and the existing expiry is kept. Every write goes through
// LoadOrStore/CompareAndSwap, so concurrent increments are never lost.
func (c *MemoryCache) IncrementBy(key string, value int64) int64 {
	key = normalizeKey(key)
	for {
		val, ok := c.items.Load(key)
		if !ok {
			if _, loaded := c.items.LoadOrStore(key, &cacheItem{value: value}); !loaded {
				return value
			}
			continue // lost the creation race; retry against the winner
		}

		item := val.(*cacheItem)
		var next *cacheItem
		if item.isExpired() {
			next = &cacheItem{value: value}
		} else if current, isInt := item.value.(int64); isInt {
			next = &cacheItem{value: current + value, expiration: item.expiration}
		} else {
			next = &cacheItem{value: value, expiration: item.expiration}
		}

		if c.items.CompareAndSwap(key, val, next) {
			return next.value.(int64)
		}
		// CAS failed because of a concurrent write; retry.
	}
}

// cleanup purges expired entries every cleanupInterval until Stop is called.
func (c *MemoryCache) cleanup() {
	ticker := time.NewTicker(c.cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			c.cleanExpired()
		case <-c.stopCleanup:
			return
		}
	}
}

// cleanExpired removes expired entries and logs how many were removed. An
// entry is only removed if it has not been replaced since it was observed.
func (c *MemoryCache) cleanExpired() {
	count := 0
	c.items.Range(func(key, value interface{}) bool {
		if value.(*cacheItem).isExpired() && c.items.CompareAndDelete(key, value) {
			count++
		}
		return true
	})
	if count > 0 {
		log.Printf("[Cache] Cleaned %d expired items", count)
	}
}

// Stop ends the cleanup goroutine. It is safe to call more than once.
func (c *MemoryCache) Stop() {
	c.stopOnce.Do(func() { close(c.stopCleanup) })
}

// ==================== MemoryCache hash operations ====================

// hashItem is the value stored for a hash key.
type hashItem struct {
	fields sync.Map
}

// stringifyHashValue renders a stored hash field as a string: strings and
// byte slices verbatim, anything else as JSON. A JSON marshal failure yields ""
// (best effort, as before).
func stringifyHashValue(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case []byte:
		return string(val)
	default:
		data, _ := json.Marshal(val)
		return string(data)
	}
}

// HSet sets a hash field, creating the hash (with no expiry) when the key is
// missing or expired. It returns an error if the key holds a non-hash value.
func (c *MemoryCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
	key = normalizeKey(key)
	h, ok := c.loadOrCreate(key, func() interface{} { return &hashItem{} }).(*hashItem)
	if !ok {
		return fmt.Errorf("key is not a hash")
	}
	h.fields.Store(field, value)
	return nil
}

// HMSet sets several hash fields via HSet, stopping at the first error.
func (c *MemoryCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
	for field, value := range values {
		if err := c.HSet(ctx, key, field, value); err != nil {
			return err
		}
	}
	return nil
}

// HGet returns one hash field rendered as a string.
func (c *MemoryCache) HGet(ctx context.Context, key string, field string) (string, error) {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return "", fmt.Errorf("key not found")
	}
	h, ok := item.value.(*hashItem)
	if !ok {
		return "", fmt.Errorf("key is not a hash")
	}
	val, ok := h.fields.Load(field)
	if !ok {
		return "", fmt.Errorf("field not found")
	}
	return stringifyHashValue(val), nil
}

// HMGet returns each requested field's value, or nil for any field HGet could
// not read. It never returns an error.
func (c *MemoryCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
	result := make([]interface{}, len(fields))
	for i, field := range fields {
		if val, err := c.HGet(ctx, key, field); err == nil {
			result[i] = val
		}
	}
	return result, nil
}

// HGetAll returns every field of the hash, rendered as strings.
func (c *MemoryCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return nil, fmt.Errorf("key not found")
	}
	h, ok := item.value.(*hashItem)
	if !ok {
		return nil, fmt.Errorf("key is not a hash")
	}
	result := make(map[string]string)
	h.fields.Range(func(k, v interface{}) bool {
		result[k.(string)] = stringifyHashValue(v)
		return true
	})
	return result, nil
}

// HDel deletes hash fields. A missing, expired, or non-hash key is a no-op.
func (c *MemoryCache) HDel(ctx context.Context, key string, fields ...string) error {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return nil
	}
	h, ok := item.value.(*hashItem)
	if !ok {
		return nil
	}
	for _, field := range fields {
		h.fields.Delete(field)
	}
	return nil
}

// ==================== MemoryCache sorted set operations ====================

// zItem is one sorted-set member.
type zItem struct {
	score  float64
	member string
}

// zsetItem is the value stored for a sorted-set key.
type zsetItem struct {
	members sync.Map     // member -> *zItem
	byScore *sortedSlice // members ordered by score
}

// sortedSlice is a score-ordered snapshot of the members, guarded by mu.
type sortedSlice struct {
	items []*zItem
	mu    sync.RWMutex
}

// rebuild re-derives the ordered slice from the member map. Ordering is by
// score, then member, matching Redis' tie-break for equal scores.
func (z *zsetItem) rebuild() {
	z.byScore.mu.Lock()
	defer z.byScore.mu.Unlock()
	z.byScore.items = nil
	z.members.Range(func(_, v interface{}) bool {
		z.byScore.items = append(z.byScore.items, v.(*zItem))
		return true
	})
	sort.Slice(z.byScore.items, func(i, j int) bool {
		a, b := z.byScore.items[i], z.byScore.items[j]
		if a.score != b.score {
			return a.score < b.score
		}
		return a.member < b.member
	})
}

// scoreRange is a parsed (min, max) score interval; each end may be exclusive.
type scoreRange struct {
	min, max           float64
	minExcl, maxExcl   bool
}

// contains reports whether score lies inside the range.
func (r scoreRange) contains(score float64) bool {
	if score < r.min || (r.minExcl && score == r.min) {
		return false
	}
	if score > r.max || (r.maxExcl && score == r.max) {
		return false
	}
	return true
}

// parseScoreBound parses a Redis-style score bound: a float, "-inf", "+inf" or
// "inf", optionally prefixed with "(" to make it exclusive.
func parseScoreBound(s string) (float64, bool, error) {
	exclusive := strings.HasPrefix(s, "(")
	v := strings.TrimPrefix(s, "(")
	switch v {
	case "-inf":
		return math.Inf(-1), exclusive, nil
	case "+inf", "inf":
		return math.Inf(1), exclusive, nil
	}
	f, err := strconv.ParseFloat(v, 64)
	if err != nil {
		return 0, false, fmt.Errorf("min or max is not a float: %q", s)
	}
	return f, exclusive, nil
}

// zrangeByScore implements ZRangeByScore (reverse=false) and ZRevRangeByScore
// (reverse=true). It skips `offset` matching members, then returns at most
// `count` members (all remaining members when count <= 0).
func (c *MemoryCache) zrangeByScore(key, minBound, maxBound string, offset, count int64, reverse bool) ([]string, error) {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return nil, nil
	}
	z, ok := item.value.(*zsetItem)
	if !ok {
		return nil, fmt.Errorf("key is not a sorted set")
	}

	var r scoreRange
	var err error
	if r.min, r.minExcl, err = parseScoreBound(minBound); err != nil {
		return nil, err
	}
	if r.max, r.maxExcl, err = parseScoreBound(maxBound); err != nil {
		return nil, err
	}

	z.byScore.mu.RLock()
	defer z.byScore.mu.RUnlock()

	items := z.byScore.items
	n := len(items)
	var result []string
	var skipped int64
	for i := 0; i < n; i++ {
		it := items[i]
		if reverse {
			it = items[n-1-i]
		}
		if !r.contains(it.score) {
			continue
		}
		if skipped < offset {
			skipped++
			continue
		}
		if count > 0 && int64(len(result)) >= count {
			break
		}
		result = append(result, it.member)
	}
	return result, nil
}

// ZAdd adds or re-scores member, creating the set (with no expiry) when the
// key is missing or expired. It returns an error if the key holds a non-zset.
func (c *MemoryCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
	key = normalizeKey(key)
	z, ok := c.loadOrCreate(key, func() interface{} {
		return &zsetItem{byScore: &sortedSlice{}}
	}).(*zsetItem)
	if !ok {
		return fmt.Errorf("key is not a sorted set")
	}
	z.members.Store(member, &zItem{score: score, member: member})
	// Simple implementation: rebuild the whole ordered slice on every write.
	z.rebuild()
	return nil
}

// ZRangeByScore returns members with min <= score <= max in ascending order.
// Bounds accept "-inf", "+inf" and a "(" prefix for exclusive ends.
func (c *MemoryCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
	return c.zrangeByScore(key, min, max, offset, count, false)
}

// ZRevRangeByScore returns members with min <= score <= max in descending order.
// Bounds accept "-inf", "+inf" and a "(" prefix for exclusive ends.
func (c *MemoryCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
	return c.zrangeByScore(key, min, max, offset, count, true)
}

// ZRem removes members. Non-string members are ignored, and a missing,
// expired or non-zset key is a no-op.
func (c *MemoryCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return nil
	}
	z, ok := item.value.(*zsetItem)
	if !ok {
		return nil
	}
	for _, m := range members {
		if member, ok := m.(string); ok {
			z.members.Delete(member)
		}
	}
	z.rebuild()
	return nil
}

// ZCard returns the number of members (0 for a missing or expired key).
func (c *MemoryCache) ZCard(ctx context.Context, key string) (int64, error) {
	item, ok := c.loadLive(normalizeKey(key))
	if !ok {
		return 0, nil
	}
	z, ok := item.value.(*zsetItem)
	if !ok {
		return 0, fmt.Errorf("key is not a sorted set")
	}
	var count int64
	z.members.Range(func(_, _ interface{}) bool {
		count++
		return true
	})
	return count, nil
}

// ==================== MemoryCache counter operations ====================

// Incr increments key by one via IncrementBy. It never returns an error.
func (c *MemoryCache) Incr(ctx context.Context, key string) (int64, error) {
	return c.IncrementBy(key, 1), nil
}

// Expire sets key's TTL (ttl <= 0 removes the expiry). A missing or
// already-expired key yields "key not found". The update uses CompareAndSwap
// so it cannot overwrite a concurrent IncrementBy.
func (c *MemoryCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
	key = normalizeKey(key)
	for {
		val, ok := c.items.Load(key)
		if !ok {
			return fmt.Errorf("key not found")
		}
		item := val.(*cacheItem)
		if item.isExpired() {
			c.items.CompareAndDelete(key, val)
			return fmt.Errorf("key not found")
		}
		var expiration int64
		if ttl > 0 {
			expiration = time.Now().Add(ttl).UnixNano()
		}
		if c.items.CompareAndSwap(key, val, &cacheItem{value: item.value, expiration: expiration}) {
			return nil
		}
	}
}

// RedisCache is the Redis-backed Cache implementation.
type RedisCache struct {
	client *redisPkg.Client
	// ctx is used by the context-free methods (Set, Get, Delete, ...).
	ctx context.Context
}

// NewRedisCache wraps client in a RedisCache that uses context.Background()
// for its context-free methods.
func NewRedisCache(client *redisPkg.Client) *RedisCache {
	return &RedisCache{
		client: client,
		ctx:    context.Background(),
	}
}

// Set JSON-encodes value and stores it with the given TTL. Failures are logged
// and counted in the SetError metric.
func (c *RedisCache) Set(key string, value interface{}, ttl time.Duration) {
	key = normalizeKey(key)
	data, err := json.Marshal(value)
	if err != nil {
		metrics.setError.Add(1)
		log.Printf("[RedisCache] Failed to marshal value for key %s: %v", key, err)
		return
	}
	if err := c.client.Set(c.ctx, key, data, ttl); err != nil {
		metrics.setError.Add(1)
		log.Printf("[RedisCache] Failed to set key %s: %v", key, err)
	}
}

// Get returns the stored (JSON) text. The caller decides how to decode it.
// A missing key or any error reports (nil, false); errors other than
// redis.Nil are logged.
func (c *RedisCache) Get(key string) (interface{}, bool) {
	key = normalizeKey(key)
	data, err := c.client.Get(c.ctx, key)
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, false
		}
		log.Printf("[RedisCache] Failed to get key %s: %v", key, err)
		return nil, false
	}
	return data, true
}

// Delete removes key, logging any error, and counts one invalidation.
func (c *RedisCache) Delete(key string) {
	key = normalizeKey(key)
	metrics.invalidate.Add(1)
	if err := c.client.Del(c.ctx, key); err != nil {
		log.Printf("[RedisCache] Failed to delete key %s: %v", key, err)
	}
}

// DeleteByPrefix SCANs for prefix* in batches of 100 and deletes each batch.
// It stops at the first scan error.
func (c *RedisCache) DeleteByPrefix(prefix string) {
	prefix = normalizePrefix(prefix)
	rdb := c.client.GetClient()
	var cursor uint64
	for {
		keys, nextCursor, err := rdb.Scan(c.ctx, cursor, prefix+"*", 100).Result()
		if err != nil {
			log.Printf("[RedisCache] Failed to scan keys with prefix %s: %v", prefix, err)
			return
		}
		if len(keys) > 0 {
			metrics.invalidate.Add(int64(len(keys)))
			if err := c.client.Del(c.ctx, keys...); err != nil {
				log.Printf("[RedisCache] Failed to delete keys with prefix %s: %v", prefix, err)
			}
		}
		cursor = nextCursor
		if cursor == 0 {
			break
		}
	}
}

// Clear runs FLUSHDB unless settings.DisableFlushDB is set.
func (c *RedisCache) Clear() {
	if settings.DisableFlushDB {
		log.Printf("[RedisCache] Skip FlushDB because cache.disable_flushdb=true")
		return
	}
	metrics.invalidate.Add(1)
	rdb := c.client.GetClient()
	if err := rdb.FlushDB(c.ctx).Err(); err != nil {
		log.Printf("[RedisCache] Failed to clear cache: %v", err)
	}
}

// Exists reports whether key exists; errors are logged and reported as false.
func (c *RedisCache) Exists(key string) bool {
	key = normalizeKey(key)
	n, err := c.client.Exists(c.ctx, key)
	if err != nil {
		log.Printf("[RedisCache] Failed to check existence of key %s: %v", key, err)
		return false
	}
	return n > 0
}

// Increment adds 1 to the counter at key and returns the new value.
func (c *RedisCache) Increment(key string) int64 {
	return c.IncrementBy(key, 1)
}

// IncrementBy runs INCRBY and returns the new value, or 0 on error (logged).
func (c *RedisCache) IncrementBy(key string, value int64) int64 {
	key = normalizeKey(key)
	rdb := c.client.GetClient()
	result, err := rdb.IncrBy(c.ctx, key, value).Result()
	if err != nil {
		log.Printf("[RedisCache] Failed to increment key %s: %v", key, err)
		return 0
	}
	return result
}

// ==================== RedisCache hash operations ====================

// HSet sets one hash field.
func (c *RedisCache) HSet(ctx context.Context, key string, field string, value interface{}) error {
	return c.client.HSet(ctx, normalizeKey(key), field, value)
}

// HMSet sets several hash fields.
func (c *RedisCache) HMSet(ctx context.Context, key string, values map[string]interface{}) error {
	return c.client.HMSet(ctx, normalizeKey(key), values)
}

// HGet returns one hash field value.
func (c *RedisCache) HGet(ctx context.Context, key string, field string) (string, error) {
	return c.client.HGet(ctx, normalizeKey(key), field)
}

// HMGet returns several hash field values.
func (c *RedisCache) HMGet(ctx context.Context, key string, fields ...string) ([]interface{}, error) {
	return c.client.HMGet(ctx, normalizeKey(key), fields...)
}

// HGetAll returns every field of a hash.
func (c *RedisCache) HGetAll(ctx context.Context, key string) (map[string]string, error) {
	return c.client.HGetAll(ctx, normalizeKey(key))
}

// HDel deletes hash fields.
func (c *RedisCache) HDel(ctx context.Context, key string, fields ...string) error {
	return c.client.HDel(ctx, normalizeKey(key), fields...)
}

// ==================== RedisCache sorted set operations ====================

// ZAdd adds a sorted-set member.
func (c *RedisCache) ZAdd(ctx context.Context, key string, score float64, member string) error {
	return c.client.ZAdd(ctx, normalizeKey(key), score, member)
}

// ZRangeByScore returns members by score, ascending.
func (c *RedisCache) ZRangeByScore(ctx context.Context, key string, min, max string, offset, count int64) ([]string, error) {
	return c.client.ZRangeByScore(ctx, normalizeKey(key), min, max, offset, count)
}

// ZRevRangeByScore returns members by score, descending.
func (c *RedisCache) ZRevRangeByScore(ctx context.Context, key string, max, min string, offset, count int64) ([]string, error) {
	return c.client.ZRevRangeByScore(ctx, normalizeKey(key), max, min, offset, count)
}

// ZRem removes sorted-set members.
func (c *RedisCache) ZRem(ctx context.Context, key string, members ...interface{}) error {
	return c.client.ZRem(ctx, normalizeKey(key), members...)
}

// ZCard returns the number of sorted-set members.
func (c *RedisCache) ZCard(ctx context.Context, key string) (int64, error) {
	return c.client.ZCard(ctx, normalizeKey(key))
}

// ==================== RedisCache counter operations ====================

// Incr increments key by one and returns the new value.
func (c *RedisCache) Incr(ctx context.Context, key string) (int64, error) {
	return c.client.Incr(ctx, normalizeKey(key))
}

// Expire sets key's TTL; the wrapper's boolean result is discarded.
func (c *RedisCache) Expire(ctx context.Context, key string, ttl time.Duration) error {
	_, err := c.client.Expire(ctx, normalizeKey(key), ttl)
	return err
}

// Global cache instance.
var globalCache Cache
var once sync.Once

// fallbackCache is the single MemoryCache handed out by GetCache before
// InitCache has run. It is created at most once so callers share data and only
// one cleanup goroutine is started.
var (
	fallbackCache     *MemoryCache
	fallbackCacheOnce sync.Once
)

// InitCache initialises the global cache once: Redis when redisClient is
// non-nil, otherwise an in-memory cache.
func InitCache(redisClient *redisPkg.Client) {
	once.Do(func() {
		if redisClient != nil {
			globalCache = NewRedisCache(redisClient)
			log.Println("[Cache] Initialized Redis cache")
		} else {
			globalCache = NewMemoryCache()
			log.Println("[Cache] Initialized Memory cache (Redis not available)")
		}
	})
}

// GetCache returns the global cache. Before InitCache it returns a shared
// in-memory fallback; the warning is logged when that fallback is created.
func GetCache() Cache {
	if globalCache == nil {
		fallbackCacheOnce.Do(func() {
			log.Println("[Cache] Warning: Cache not initialized, using Memory cache")
			fallbackCache = NewMemoryCache()
		})
		return fallbackCache
	}
	return globalCache
}

// GetRedisClient returns the underlying Redis client when the global cache is
// Redis-backed, or an error otherwise.
func GetRedisClient() (*redisPkg.Client, error) {
	if redisCache, ok := globalCache.(*RedisCache); ok {
		return redisCache.client, nil
	}
	return nil, fmt.Errorf("cache is not using Redis backend")
}

// SetWithJitter stores value with ApplyTTLJitter(ttl, jitterRatio) when
// caching is enabled.
func SetWithJitter(c Cache, key string, value interface{}, ttl time.Duration, jitterRatio float64) {
	if !settings.Enabled {
		return
	}
	c.Set(key, value, ApplyTTLJitter(ttl, jitterRatio))
}

// SetNull caches the "no value" marker for key when caching is enabled.
func SetNull(c Cache, key string, ttl time.Duration) {
	if !settings.Enabled {
		return
	}
	c.Set(key, nullMarkerValue, ttl)
}

// ApplyTTLJitter returns ttl plus a random extra in [0, ttl*jitterRatio].
// jitterRatio is capped at 1. A non-positive ttl or ratio returns ttl unchanged.
func ApplyTTLJitter(ttl time.Duration, jitterRatio float64) time.Duration {
	if ttl <= 0 || jitterRatio <= 0 {
		return ttl
	}
	if jitterRatio > 1 {
		jitterRatio = 1
	}
	maxJitter := int64(float64(ttl) * jitterRatio)
	if maxJitter <= 0 {
		return ttl
	}
	delta := rand.Int63n(maxJitter + 1)
	return ttl + time.Duration(delta)
}

// isNullMarker reports whether raw is the null marker. It recognises both the
// raw form (MemoryCache) and the JSON-encoded form (RedisCache).
func isNullMarker(raw interface{}) bool {
	s, ok := raw.(string)
	return ok && (s == nullMarkerValue || s == nullMarkerJSON)
}

// decodeAs converts raw into T. Strings and byte slices are parsed as JSON;
// any other value is round-tripped through JSON.
func decodeAs[T any](raw interface{}) (T, error) {
	var out T
	var data []byte
	switch v := raw.(type) {
	case string:
		data = []byte(v)
	case []byte:
		data = v
	default:
		encoded, err := json.Marshal(v)
		if err != nil {
			return out, err
		}
		data = encoded
	}
	err := json.Unmarshal(data, &out)
	return out, err
}

// GetTyped reads key and converts it to T.
//
// It returns (zero, false) when caching is disabled, the key is missing
// (counted as a miss), the null marker is stored (counted as a hit), or
// decoding fails (counted as a decode error).
//
// RedisCache always returns JSON text, so its values are always JSON-decoded.
// Otherwise a "T=string" read would return the JSON-quoted text. For other
// backends a value already of type T is returned as-is.
func GetTyped[T any](c Cache, key string) (T, bool) {
	var zero T
	if !settings.Enabled {
		return zero, false
	}
	raw, ok := c.Get(key)
	if !ok {
		metrics.miss.Add(1)
		return zero, false
	}
	if isNullMarker(raw) {
		metrics.hit.Add(1)
		return zero, false
	}
	if _, jsonBackend := c.(*RedisCache); !jsonBackend {
		if typed, ok := raw.(T); ok {
			metrics.hit.Add(1)
			return typed, true
		}
	}
	out, err := decodeAs[T](raw)
	if err != nil {
		metrics.decodeError.Add(1)
		return zero, false
	}
	metrics.hit.Add(1)
	return out, true
}

// GetOrLoadTyped returns the cached T for key, or runs loader and caches its
// result. Concurrent callers for the same key are serialised so that usually
// only one loader runs.
//
// A result whose JSON form is "null" is cached with SetNull for nullTTL (when
// nullTTL > 0). Anything else is cached with SetWithJitter(ttl, jitterRatio).
// Loader errors are returned without caching.
func GetOrLoadTyped[T any](
	c Cache,
	key string,
	ttl time.Duration,
	jitterRatio float64,
	nullTTL time.Duration,
	loader func() (T, error),
) (T, error) {
	if cached, ok := GetTyped[T](c, key); ok {
		return cached, nil
	}

	lockValue, _ := loadLocks.LoadOrStore(key, &sync.Mutex{})
	lock := lockValue.(*sync.Mutex)
	lock.Lock()
	defer func() {
		// Drop the entry (only if it is still this mutex) so loadLocks holds
		// in-flight keys only instead of growing forever. Callers already queued
		// on this mutex still acquire it and re-check the cache below.
		loadLocks.CompareAndDelete(key, lock)
		lock.Unlock()
	}()

	if cached, ok := GetTyped[T](c, key); ok {
		return cached, nil
	}

	loaded, err := loader()
	if err != nil {
		var zero T
		return zero, err
	}

	encoded, marshalErr := json.Marshal(loaded)
	if marshalErr == nil && string(encoded) == "null" && nullTTL > 0 {
		SetNull(c, key, nullTTL)
		return loaded, nil
	}
	SetWithJitter(c, key, loaded, ttl, jitterRatio)
	return loaded, nil
}