package database import ( "context" "encoding/json" "fmt" "time" "carrotskin/pkg/redis" ) // CacheConfig 缓存配置 type CacheConfig struct { Prefix string // 缓存键前缀 Expiration time.Duration // 过期时间 Enabled bool // 是否启用缓存 Policy CachePolicy // 缓存策略(可选,不配置则回落到 Expiration) } // CachePolicy 缓存策略,用于为不同实体设置默认 TTL type CachePolicy struct { UserTTL time.Duration UserEmailTTL time.Duration ProfileTTL time.Duration ProfileListTTL time.Duration TextureTTL time.Duration TextureListTTL time.Duration } // CacheManager 缓存管理器 type CacheManager struct { redis *redis.Client config CacheConfig Policy CachePolicy } // NewCacheManager 创建缓存管理器 func NewCacheManager(redisClient *redis.Client, config CacheConfig) *CacheManager { if config.Prefix == "" { config.Prefix = "db:" } if config.Expiration == 0 { config.Expiration = 5 * time.Minute } // 填充默认策略(未配置时退回全局过期时间) applyPolicyDefaults := func(p *CachePolicy) { if p.UserTTL == 0 { p.UserTTL = config.Expiration } if p.UserEmailTTL == 0 { p.UserEmailTTL = config.Expiration } if p.ProfileTTL == 0 { p.ProfileTTL = config.Expiration } if p.ProfileListTTL == 0 { p.ProfileListTTL = config.Expiration } if p.TextureTTL == 0 { p.TextureTTL = config.Expiration } if p.TextureListTTL == 0 { p.TextureListTTL = config.Expiration } } applyPolicyDefaults(&config.Policy) return &CacheManager{ redis: redisClient, config: config, Policy: config.Policy, } } // buildKey 构建缓存键 func (cm *CacheManager) buildKey(key string) string { return cm.config.Prefix + key } // Get 获取缓存 func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error { if !cm.config.Enabled || cm.redis == nil { return fmt.Errorf("cache not enabled") } data, err := cm.redis.GetBytes(ctx, cm.buildKey(key)) if err != nil || data == nil { return fmt.Errorf("cache miss") } return json.Unmarshal(data, dest) } // TryGet 获取缓存,命中时返回 true,不视为错误 func (cm *CacheManager) TryGet(ctx context.Context, key string, dest interface{}) (bool, error) { if err := cm.Get(ctx, key, dest); err != nil { return false, err } return true, nil } // Set 设置缓存 func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error { if !cm.config.Enabled || cm.redis == nil { return nil } data, err := json.Marshal(value) if err != nil { return err } exp := cm.config.Expiration if len(expiration) > 0 && expiration[0] > 0 { exp = expiration[0] } return cm.redis.Set(ctx, cm.buildKey(key), data, exp) } // SetAsync 异步设置缓存,避免在主请求链路阻塞 func (cm *CacheManager) SetAsync(ctx context.Context, key string, value interface{}, expiration ...time.Duration) { go func() { _ = cm.Set(ctx, key, value, expiration...) }() } // Delete 删除缓存 func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error { if !cm.config.Enabled || cm.redis == nil { return nil } fullKeys := make([]string, len(keys)) for i, key := range keys { fullKeys[i] = cm.buildKey(key) } return cm.redis.Del(ctx, fullKeys...) } // DeletePattern 删除匹配模式的缓存 // 使用 Redis SCAN 命令安全地删除匹配的键,避免阻塞 func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error { if !cm.config.Enabled || cm.redis == nil { return nil } // 构建完整的匹配模式 fullPattern := cm.buildKey(pattern) // 使用 SCAN 命令迭代查找匹配的键 var cursor uint64 var deletedCount int for { // 每次扫描100个键 keys, nextCursor, err := cm.redis.Client.Scan(ctx, cursor, fullPattern, 100).Result() if err != nil { return fmt.Errorf("扫描缓存键失败: %w", err) } // 批量删除找到的键 if len(keys) > 0 { if err := cm.redis.Client.Del(ctx, keys...).Err(); err != nil { return fmt.Errorf("删除缓存键失败: %w", err) } deletedCount += len(keys) } // 更新游标 cursor = nextCursor // cursor == 0 表示扫描完成 if cursor == 0 { break } // 检查 context 是否已取消 select { case <-ctx.Done(): return ctx.Err() default: } } return nil } // GetOrSet 获取缓存,如果不存在则执行回调并设置缓存 func (cm *CacheManager) GetOrSet(ctx context.Context, key string, dest interface{}, fn func() (interface{}, error), expiration ...time.Duration) error { // 尝试从缓存获取 err := cm.Get(ctx, key, dest) if err == nil { return nil // 缓存命中 } // 缓存未命中,执行回调获取数据 result, err := fn() if err != nil { return err } // 设置缓存 if err := cm.Set(ctx, key, result, expiration...); err != nil { // 缓存设置失败不影响主流程,只记录日志 // logger.Warn("failed to set cache", zap.Error(err)) } // 将结果转换为目标类型 data, err := json.Marshal(result) if err != nil { return err } return json.Unmarshal(data, dest) } // Cached 缓存装饰器 - 为查询函数添加缓存 func Cached[T any]( ctx context.Context, cache *CacheManager, key string, queryFn func() (*T, error), expiration ...time.Duration, ) (*T, error) { // 尝试从缓存获取 var result T if err := cache.Get(ctx, key, &result); err == nil { return &result, nil } // 缓存未命中,执行查询 data, err := queryFn() if err != nil { return nil, err } // 设置缓存(异步,不阻塞) cache.SetAsync(context.Background(), key, data, expiration...) return data, nil } // CachedList 缓存装饰器 - 为列表查询添加缓存 func CachedList[T any]( ctx context.Context, cache *CacheManager, key string, queryFn func() ([]T, error), expiration ...time.Duration, ) ([]T, error) { // 尝试从缓存获取 var result []T if err := cache.Get(ctx, key, &result); err == nil { return result, nil } // 缓存未命中,执行查询 data, err := queryFn() if err != nil { return nil, err } // 设置缓存(异步,不阻塞) cache.SetAsync(context.Background(), key, data, expiration...) return data, nil } // InvalidateCache 使缓存失效的辅助函数 type CacheInvalidator struct { cache *CacheManager } // NewCacheInvalidator 创建缓存失效器 func NewCacheInvalidator(cache *CacheManager) *CacheInvalidator { return &CacheInvalidator{cache: cache} } // OnCreate 创建时使缓存失效 func (ci *CacheInvalidator) OnCreate(ctx context.Context, keys ...string) { _ = ci.cache.Delete(ctx, keys...) } // OnUpdate 更新时使缓存失效 func (ci *CacheInvalidator) OnUpdate(ctx context.Context, keys ...string) { _ = ci.cache.Delete(ctx, keys...) } // OnDelete 删除时使缓存失效 func (ci *CacheInvalidator) OnDelete(ctx context.Context, keys ...string) { _ = ci.cache.Delete(ctx, keys...) } // BatchInvalidate 批量使缓存失效(支持模式匹配) func (ci *CacheInvalidator) BatchInvalidate(ctx context.Context, pattern string) { _ = ci.cache.DeletePattern(ctx, pattern) } // CacheKeyBuilder 缓存键构建器 type CacheKeyBuilder struct { prefix string } // NewCacheKeyBuilder 创建缓存键构建器 func NewCacheKeyBuilder(prefix string) *CacheKeyBuilder { return &CacheKeyBuilder{prefix: prefix} } // User 构建用户相关缓存键 func (b *CacheKeyBuilder) User(userID int64) string { return fmt.Sprintf("%suser:id:%d", b.prefix, userID) } // UserByEmail 构建邮箱查询缓存键 func (b *CacheKeyBuilder) UserByEmail(email string) string { return fmt.Sprintf("%suser:email:%s", b.prefix, email) } // UserByUsername 构建用户名查询缓存键 func (b *CacheKeyBuilder) UserByUsername(username string) string { return fmt.Sprintf("%suser:username:%s", b.prefix, username) } // Profile 构建档案缓存键 func (b *CacheKeyBuilder) Profile(uuid string) string { return fmt.Sprintf("%sprofile:uuid:%s", b.prefix, uuid) } // ProfileList 构建用户档案列表缓存键 func (b *CacheKeyBuilder) ProfileList(userID int64) string { return fmt.Sprintf("%sprofile:user:%d:list", b.prefix, userID) } // Texture 构建材质缓存键 func (b *CacheKeyBuilder) Texture(textureID int64) string { return fmt.Sprintf("%stexture:id:%d", b.prefix, textureID) } // TextureByHash 构建材质hash缓存键 func (b *CacheKeyBuilder) TextureByHash(hash string) string { return fmt.Sprintf("%stexture:hash:%s", b.prefix, hash) } // TextureList 构建材质列表缓存键 func (b *CacheKeyBuilder) TextureList(userID int64, page int) string { return fmt.Sprintf("%stexture:user:%d:page:%d", b.prefix, userID, page) } // TextureListPattern 构建材质列表缓存键模式(用于批量失效) func (b *CacheKeyBuilder) TextureListPattern(userID int64) string { return fmt.Sprintf("%stexture:user:%d:*", b.prefix, userID) } // Token 构建令牌缓存键 func (b *CacheKeyBuilder) Token(accessToken string) string { return fmt.Sprintf("%stoken:%s", b.prefix, accessToken) } // UserPattern 用户相关的所有缓存键模式 func (b *CacheKeyBuilder) UserPattern(userID int64) string { return fmt.Sprintf("%suser:*:%d*", b.prefix, userID) } // ProfilePattern 档案相关的所有缓存键模式 func (b *CacheKeyBuilder) ProfilePattern(userID int64) string { return fmt.Sprintf("%sprofile:*:%d*", b.prefix, userID) } // Exists 检查缓存键是否存在 func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) { if !cm.config.Enabled || cm.redis == nil { return false, nil } count, err := cm.redis.Exists(ctx, cm.buildKey(key)) if err != nil { return false, err } return count > 0, nil } // TTL 获取缓存键的剩余过期时间 func (cm *CacheManager) TTL(ctx context.Context, key string) (time.Duration, error) { if !cm.config.Enabled || cm.redis == nil { return 0, fmt.Errorf("cache not enabled") } return cm.redis.TTL(ctx, cm.buildKey(key)) } // Expire 设置缓存键的过期时间 func (cm *CacheManager) Expire(ctx context.Context, key string, expiration time.Duration) error { if !cm.config.Enabled || cm.redis == nil { return nil } return cm.redis.Expire(ctx, cm.buildKey(key), expiration) } // MGet 批量获取多个缓存 func (cm *CacheManager) MGet(ctx context.Context, keys []string) (map[string]interface{}, error) { if !cm.config.Enabled || cm.redis == nil { return nil, fmt.Errorf("cache not enabled") } if len(keys) == 0 { return make(map[string]interface{}), nil } // 构建完整的键 fullKeys := make([]string, len(keys)) for i, key := range keys { fullKeys[i] = cm.buildKey(key) } // 批量获取 values, err := cm.redis.Client.MGet(ctx, fullKeys...).Result() if err != nil { return nil, err } // 解析结果 result := make(map[string]interface{}) for i, val := range values { if val != nil { result[keys[i]] = val } } return result, nil } // MSet 批量设置多个缓存 func (cm *CacheManager) MSet(ctx context.Context, values map[string]interface{}, expiration time.Duration) error { if !cm.config.Enabled || cm.redis == nil { return nil } if len(values) == 0 { return nil } // 逐个设置(Redis MSet 不支持过期时间) for key, value := range values { if err := cm.Set(ctx, key, value, expiration); err != nil { return err } } return nil } // Increment 递增缓存值 func (cm *CacheManager) Increment(ctx context.Context, key string) (int64, error) { if !cm.config.Enabled || cm.redis == nil { return 0, fmt.Errorf("cache not enabled") } return cm.redis.Incr(ctx, cm.buildKey(key)) } // Decrement 递减缓存值 func (cm *CacheManager) Decrement(ctx context.Context, key string) (int64, error) { if !cm.config.Enabled || cm.redis == nil { return 0, fmt.Errorf("cache not enabled") } return cm.redis.Decr(ctx, cm.buildKey(key)) } // IncrementWithExpire 递增并设置过期时间 func (cm *CacheManager) IncrementWithExpire(ctx context.Context, key string, expiration time.Duration) (int64, error) { if !cm.config.Enabled || cm.redis == nil { return 0, fmt.Errorf("cache not enabled") } fullKey := cm.buildKey(key) // 递增 val, err := cm.redis.Incr(ctx, fullKey) if err != nil { return 0, err } // 设置过期时间(如果是新键) if val == 1 { _ = cm.redis.Expire(ctx, fullKey, expiration) } return val, nil }