Files
backend/pkg/database/cache.go

496 lines
12 KiB
Go
Raw Normal View History

package database
import (
"context"
"encoding/json"
"fmt"
"time"
"carrotskin/pkg/redis"
)
// CacheConfig holds the settings used to construct a CacheManager.
type CacheConfig struct {
Prefix string // prefix prepended to every cache key (NewCacheManager defaults it to "db:")
Expiration time.Duration // default expiration (NewCacheManager defaults it to 5 minutes)
Enabled bool // whether caching is enabled
Policy CachePolicy // per-entity TTL policy (optional; unset TTLs fall back to Expiration)
}
// CachePolicy holds the default TTL for each kind of cached entity.
// NewCacheManager replaces any zero field with CacheConfig.Expiration.
type CachePolicy struct {
UserTTL time.Duration
UserEmailTTL time.Duration
ProfileTTL time.Duration
ProfileListTTL time.Duration
TextureTTL time.Duration
TextureListTTL time.Duration
}
// CacheManager reads and writes JSON-encoded cache entries through a Redis
// client. Policy is set by NewCacheManager to the configured policy after
// defaults have been filled in.
type CacheManager struct {
redis *redis.Client
config CacheConfig
Policy CachePolicy
}
// NewCacheManager builds a CacheManager around redisClient.
//
// Missing settings are defaulted on the local copy of config: an empty Prefix
// becomes "db:", a zero Expiration becomes five minutes, and every zero TTL in
// config.Policy is replaced with the (possibly defaulted) Expiration.
func NewCacheManager(redisClient *redis.Client, config CacheConfig) *CacheManager {
	if config.Prefix == "" {
		config.Prefix = "db:"
	}
	if config.Expiration == 0 {
		config.Expiration = 5 * time.Minute
	}

	// Walk every policy TTL through a pointer so unset entries can be
	// back-filled with the global expiration in one place.
	policy := &config.Policy
	ttls := []*time.Duration{
		&policy.UserTTL,
		&policy.UserEmailTTL,
		&policy.ProfileTTL,
		&policy.ProfileListTTL,
		&policy.TextureTTL,
		&policy.TextureListTTL,
	}
	for _, ttl := range ttls {
		if *ttl == 0 {
			*ttl = config.Expiration
		}
	}

	manager := &CacheManager{
		redis:  redisClient,
		config: config,
	}
	manager.Policy = config.Policy
	return manager
}
// buildKey returns key with the configured prefix prepended.
func (cm *CacheManager) buildKey(key string) string {
	prefix := cm.config.Prefix
	return prefix + key
}
// Get loads the entry stored under key and JSON-decodes it into dest.
//
// It returns a "cache not enabled" error when caching is disabled or no Redis
// client is set, a "cache miss" error when the lookup fails or yields no data,
// and otherwise whatever json.Unmarshal returns.
func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error {
	if cm.redis == nil || !cm.config.Enabled {
		return fmt.Errorf("cache not enabled")
	}

	raw, err := cm.redis.GetBytes(ctx, cm.buildKey(key))
	switch {
	case err != nil, raw == nil:
		// Lookup errors and absent data are both reported as a miss.
		return fmt.Errorf("cache miss")
	}
	return json.Unmarshal(raw, dest)
}
// TryGet wraps Get and reports whether the entry was found and decoded.
// It returns (true, nil) on a hit; on any failure of Get, including a miss,
// it returns false together with the error from Get.
func (cm *CacheManager) TryGet(ctx context.Context, key string, dest interface{}) (bool, error) {
	err := cm.Get(ctx, key, dest)
	return err == nil, err
}
// Set JSON-encodes value and stores it under key.
//
// The first element of expiration is used as the TTL when it is positive;
// otherwise the configured default expiration applies. When caching is
// disabled or no Redis client is set, Set does nothing and returns nil.
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
	if cm.redis == nil || !cm.config.Enabled {
		return nil
	}

	payload, err := json.Marshal(value)
	if err != nil {
		return err
	}

	var ttl time.Duration
	if len(expiration) == 0 || expiration[0] <= 0 {
		ttl = cm.config.Expiration
	} else {
		ttl = expiration[0]
	}

	return cm.redis.Set(ctx, cm.buildKey(key), payload, ttl)
}
// SetAsync stores value under key without blocking the caller on the Redis
// round trip.
//
// The value is JSON-encoded synchronously, before SetAsync returns, so the
// caller is free to mutate or hand out value afterwards without racing with
// the background write. Only the Redis write runs in a goroutine, and its
// error is discarded (best effort). No goroutine is started when caching is
// disabled, no Redis client is set, or the value cannot be encoded.
//
// The TTL is the first element of expiration when it is positive, otherwise
// the configured default. ctx is used by the background write, so it should
// outlive the calling request.
func (cm *CacheManager) SetAsync(ctx context.Context, key string, value interface{}, expiration ...time.Duration) {
	if !cm.config.Enabled || cm.redis == nil {
		return
	}

	data, err := json.Marshal(value)
	if err != nil {
		// Same outcome as the asynchronous path: the entry is simply not cached.
		return
	}

	exp := cm.config.Expiration
	if len(expiration) > 0 && expiration[0] > 0 {
		exp = expiration[0]
	}
	fullKey := cm.buildKey(key)

	go func() {
		// Best-effort write; a failed cache fill must not affect the caller.
		_ = cm.redis.Set(ctx, fullKey, data, exp)
	}()
}
// Delete removes the entries stored under keys (the configured prefix is
// added to each key).
//
// It returns nil without contacting Redis when caching is disabled, no Redis
// client is set, or keys is empty; a DEL command without any key is rejected
// by Redis, so the empty case is short-circuited here.
func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error {
	if !cm.config.Enabled || cm.redis == nil {
		return nil
	}
	if len(keys) == 0 {
		return nil
	}

	fullKeys := make([]string, len(keys))
	for i, key := range keys {
		fullKeys[i] = cm.buildKey(key)
	}
	return cm.redis.Del(ctx, fullKeys...)
}
// DeletePattern removes every entry whose key matches pattern (the configured
// prefix is added to the pattern).
//
// Keys are discovered incrementally with SCAN, passing 100 as the per-call
// count, and each batch is deleted before the next scan. The loop stops when
// the cursor returns to 0 or when ctx is done, in which case ctx.Err() is
// returned. Scan and delete failures are returned wrapped. When caching is
// disabled or no Redis client is set, it does nothing and returns nil.
func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error {
	if !cm.config.Enabled || cm.redis == nil {
		return nil
	}

	fullPattern := cm.buildKey(pattern)

	var cursor uint64
	for {
		keys, nextCursor, err := cm.redis.Client.Scan(ctx, cursor, fullPattern, 100).Result()
		if err != nil {
			return fmt.Errorf("扫描缓存键失败: %w", err)
		}

		// A scan step may legitimately return no keys; only issue DEL when
		// there is something to delete.
		if len(keys) > 0 {
			if err := cm.redis.Client.Del(ctx, keys...).Err(); err != nil {
				return fmt.Errorf("删除缓存键失败: %w", err)
			}
		}

		// A zero cursor means the iteration is complete.
		cursor = nextCursor
		if cursor == 0 {
			return nil
		}

		// Stop between batches once the caller has given up.
		if err := ctx.Err(); err != nil {
			return err
		}
	}
}
// GetOrSet fills dest from the cache entry stored under key, or, when Get
// fails for any reason, from the value produced by fn.
//
// On the fallback path the value returned by fn is JSON-encoded once; the
// encoded bytes are written to the cache (best effort, errors ignored) and
// then decoded into dest. An error from fn, from encoding the value, or from
// decoding into dest is returned. expiration is forwarded to Set.
func (cm *CacheManager) GetOrSet(ctx context.Context, key string, dest interface{}, fn func() (interface{}, error), expiration ...time.Duration) error {
	// Cache hit: dest has already been populated by Get.
	if err := cm.Get(ctx, key, dest); err == nil {
		return nil
	}

	result, err := fn()
	if err != nil {
		return err
	}

	data, err := json.Marshal(result)
	if err != nil {
		return err
	}

	// Store the already-encoded bytes; json.RawMessage makes Set write them
	// as-is instead of encoding result a second time. A failed cache write
	// must not fail the caller, so the error is deliberately dropped.
	_ = cm.Set(ctx, key, json.RawMessage(data), expiration...)

	// Round-trip through JSON so dest receives the same shape a cache hit
	// would have produced.
	return json.Unmarshal(data, dest)
}
// Cached wraps a single-value query with a read-through cache.
//
// It first tries to decode the entry stored under key into a new T and
// returns a pointer to it on success. Otherwise it runs queryFn, returns its
// error unchanged if there is one, and schedules an asynchronous cache write
// of the result (with context.Background(), so the write is not tied to ctx).
//
// A nil result is returned to the caller but is not cached: it would be
// stored as JSON null, and a later hit would decode that into a zero T and
// return a non-nil pointer to an empty value instead of nil.
func Cached[T any](
	ctx context.Context,
	cache *CacheManager,
	key string,
	queryFn func() (*T, error),
	expiration ...time.Duration,
) (*T, error) {
	var result T
	if err := cache.Get(ctx, key, &result); err == nil {
		return &result, nil
	}

	data, err := queryFn()
	if err != nil {
		return nil, err
	}

	if data != nil {
		cache.SetAsync(context.Background(), key, data, expiration...)
	}
	return data, nil
}
// CachedList wraps a list query with a read-through cache.
//
// A successful Get for key returns the decoded slice directly. Otherwise
// queryFn is executed; its error is returned unchanged, and a successful
// result is handed to SetAsync (with context.Background()) before being
// returned to the caller.
func CachedList[T any](
	ctx context.Context,
	cache *CacheManager,
	key string,
	queryFn func() ([]T, error),
	expiration ...time.Duration,
) ([]T, error) {
	var cached []T
	if cache.Get(ctx, key, &cached) == nil {
		return cached, nil
	}

	rows, queryErr := queryFn()
	if queryErr != nil {
		return nil, queryErr
	}

	// Fill the cache in the background so the caller is not delayed.
	cache.SetAsync(context.Background(), key, rows, expiration...)
	return rows, nil
}
// CacheInvalidator removes cache entries through a CacheManager when the
// underlying data changes. All of its methods discard deletion errors.
type CacheInvalidator struct {
	cache *CacheManager
}

// NewCacheInvalidator returns a CacheInvalidator that operates on cache.
func NewCacheInvalidator(cache *CacheManager) *CacheInvalidator {
	inv := new(CacheInvalidator)
	inv.cache = cache
	return inv
}

// dropKeys deletes keys from the cache, ignoring any error: invalidation is
// best effort.
func (ci *CacheInvalidator) dropKeys(ctx context.Context, keys []string) {
	_ = ci.cache.Delete(ctx, keys...)
}

// OnCreate deletes the given keys after a record has been created.
func (ci *CacheInvalidator) OnCreate(ctx context.Context, keys ...string) {
	ci.dropKeys(ctx, keys)
}

// OnUpdate deletes the given keys after a record has been updated.
func (ci *CacheInvalidator) OnUpdate(ctx context.Context, keys ...string) {
	ci.dropKeys(ctx, keys)
}

// OnDelete deletes the given keys after a record has been deleted.
func (ci *CacheInvalidator) OnDelete(ctx context.Context, keys ...string) {
	ci.dropKeys(ctx, keys)
}

// BatchInvalidate deletes every key matching pattern via
// CacheManager.DeletePattern, ignoring any error.
func (ci *CacheInvalidator) BatchInvalidate(ctx context.Context, pattern string) {
	_ = ci.cache.DeletePattern(ctx, pattern)
}
// CacheKeyBuilder builds cache keys and key patterns that share a common
// prefix.
type CacheKeyBuilder struct {
	prefix string
}

// NewCacheKeyBuilder returns a CacheKeyBuilder that prepends prefix to every
// key it builds.
func NewCacheKeyBuilder(prefix string) *CacheKeyBuilder {
	return &CacheKeyBuilder{prefix: prefix}
}

// User returns the key for a user looked up by ID: "<prefix>user:id:<userID>".
func (b *CacheKeyBuilder) User(userID int64) string {
	return fmt.Sprintf("%suser:id:%d", b.prefix, userID)
}

// UserByEmail returns the key for a user looked up by email:
// "<prefix>user:email:<email>".
func (b *CacheKeyBuilder) UserByEmail(email string) string {
	return b.prefix + "user:email:" + email
}

// UserByUsername returns the key for a user looked up by username:
// "<prefix>user:username:<username>".
func (b *CacheKeyBuilder) UserByUsername(username string) string {
	return b.prefix + "user:username:" + username
}

// Profile returns the key for a profile looked up by UUID:
// "<prefix>profile:uuid:<uuid>".
func (b *CacheKeyBuilder) Profile(uuid string) string {
	return b.prefix + "profile:uuid:" + uuid
}

// ProfileList returns the key for a user's profile list:
// "<prefix>profile:user:<userID>:list".
func (b *CacheKeyBuilder) ProfileList(userID int64) string {
	return fmt.Sprintf("%sprofile:user:%d:list", b.prefix, userID)
}

// Texture returns the key for a texture looked up by ID:
// "<prefix>texture:id:<textureID>".
func (b *CacheKeyBuilder) Texture(textureID int64) string {
	return fmt.Sprintf("%stexture:id:%d", b.prefix, textureID)
}

// TextureByHash returns the key for a texture looked up by hash:
// "<prefix>texture:hash:<hash>".
func (b *CacheKeyBuilder) TextureByHash(hash string) string {
	return b.prefix + "texture:hash:" + hash
}

// TextureList returns the key for one page of a user's texture list:
// "<prefix>texture:user:<userID>:page:<page>".
func (b *CacheKeyBuilder) TextureList(userID int64, page int) string {
	return fmt.Sprintf("%stexture:user:%d:page:%d", b.prefix, userID, page)
}

// TextureListPattern returns a glob pattern covering every texture-list key
// of a user, for batch invalidation: "<prefix>texture:user:<userID>:*".
func (b *CacheKeyBuilder) TextureListPattern(userID int64) string {
	return fmt.Sprintf("%stexture:user:%d:*", b.prefix, userID)
}

// Token returns the key for an access token: "<prefix>token:<accessToken>".
func (b *CacheKeyBuilder) Token(accessToken string) string {
	return b.prefix + "token:" + accessToken
}

// UserPattern returns the glob pattern "<prefix>user:*:<userID>*".
//
// NOTE(review): the trailing "*" also matches IDs that merely start with the
// same digits (for example 1 matches 10), and the pattern cannot match the
// email/username keys unless they contain the ID; confirm this is intended.
func (b *CacheKeyBuilder) UserPattern(userID int64) string {
	return fmt.Sprintf("%suser:*:%d*", b.prefix, userID)
}

// ProfilePattern returns the glob pattern "<prefix>profile:*:<userID>*".
//
// NOTE(review): as with UserPattern, the trailing "*" also matches IDs that
// share the same leading digits; confirm this is intended.
func (b *CacheKeyBuilder) ProfilePattern(userID int64) string {
	return fmt.Sprintf("%sprofile:*:%d*", b.prefix, userID)
}
// Exists reports whether an entry is stored under key.
// It returns (false, nil) when caching is disabled or no Redis client is
// set, and (false, err) when the lookup fails.
func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
	if cm.redis == nil || !cm.config.Enabled {
		return false, nil
	}

	n, err := cm.redis.Exists(ctx, cm.buildKey(key))
	return err == nil && n > 0, err
}
// TTL returns the remaining time to live reported by the Redis client for
// key. It returns a "cache not enabled" error when caching is disabled or no
// Redis client is set.
func (cm *CacheManager) TTL(ctx context.Context, key string) (time.Duration, error) {
	if cm.config.Enabled && cm.redis != nil {
		return cm.redis.TTL(ctx, cm.buildKey(key))
	}
	return 0, fmt.Errorf("cache not enabled")
}
// Expire sets the expiration of the entry stored under key.
// It does nothing and returns nil when caching is disabled or no Redis
// client is set.
func (cm *CacheManager) Expire(ctx context.Context, key string, expiration time.Duration) error {
	if cm.config.Enabled && cm.redis != nil {
		return cm.redis.Expire(ctx, cm.buildKey(key), expiration)
	}
	return nil
}
// MGet fetches several entries in one round trip.
//
// The returned map is keyed by the caller's unprefixed keys and contains only
// the keys for which Redis returned a non-nil value; values are stored as
// returned by the client and are not JSON-decoded. It returns a
// "cache not enabled" error when caching is disabled or no Redis client is
// set, and an empty map when keys is empty.
func (cm *CacheManager) MGet(ctx context.Context, keys []string) (map[string]interface{}, error) {
	if cm.redis == nil || !cm.config.Enabled {
		return nil, fmt.Errorf("cache not enabled")
	}

	found := make(map[string]interface{})
	if len(keys) == 0 {
		return found, nil
	}

	// Apply the manager prefix to every requested key.
	prefixed := make([]string, 0, len(keys))
	for _, k := range keys {
		prefixed = append(prefixed, cm.buildKey(k))
	}

	raw, err := cm.redis.Client.MGet(ctx, prefixed...).Result()
	if err != nil {
		return nil, err
	}

	// Results are positional: raw[idx] belongs to keys[idx].
	for idx := range raw {
		if raw[idx] == nil {
			continue
		}
		found[keys[idx]] = raw[idx]
	}
	return found, nil
}
// MSet stores every entry of values with the given expiration.
//
// Entries are written one at a time through Set (the original author notes
// that Redis MSET does not support expirations), and the first error stops
// the loop and is returned. It does nothing and returns nil when caching is
// disabled, no Redis client is set, or values is empty.
func (cm *CacheManager) MSet(ctx context.Context, values map[string]interface{}, expiration time.Duration) error {
	if cm.redis == nil || !cm.config.Enabled || len(values) == 0 {
		return nil
	}

	for k, v := range values {
		err := cm.Set(ctx, k, v, expiration)
		if err != nil {
			return err
		}
	}
	return nil
}
// Increment increments the counter stored under key and returns the value
// reported by the Redis client. It returns a "cache not enabled" error when
// caching is disabled or no Redis client is set.
func (cm *CacheManager) Increment(ctx context.Context, key string) (int64, error) {
	if cm.config.Enabled && cm.redis != nil {
		return cm.redis.Incr(ctx, cm.buildKey(key))
	}
	return 0, fmt.Errorf("cache not enabled")
}
// Decrement decrements the counter stored under key and returns the value
// reported by the Redis client. It returns a "cache not enabled" error when
// caching is disabled or no Redis client is set.
func (cm *CacheManager) Decrement(ctx context.Context, key string) (int64, error) {
	if cm.config.Enabled && cm.redis != nil {
		return cm.redis.Decr(ctx, cm.buildKey(key))
	}
	return 0, fmt.Errorf("cache not enabled")
}
// IncrementWithExpire increments the counter stored under key and, when the
// increment returns 1 (treated here as a newly created key), sets its
// expiration. The error from setting the expiration is discarded.
// It returns a "cache not enabled" error when caching is disabled or no
// Redis client is set, and (0, err) when the increment itself fails.
func (cm *CacheManager) IncrementWithExpire(ctx context.Context, key string, expiration time.Duration) (int64, error) {
	if cm.redis == nil || !cm.config.Enabled {
		return 0, fmt.Errorf("cache not enabled")
	}

	redisKey := cm.buildKey(key)
	count, err := cm.redis.Incr(ctx, redisKey)
	switch {
	case err != nil:
		return 0, err
	case count == 1:
		// First increment: attach the expiration, best effort.
		_ = cm.redis.Expire(ctx, redisKey, expiration)
	}
	return count, nil
}