448 lines
11 KiB
Go
448 lines
11 KiB
Go
package database
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"carrotskin/pkg/redis"
|
||
)
|
||
|
||
// CacheConfig configures a CacheManager.
type CacheConfig struct {
	// Prefix is prepended to every key the manager reads or writes.
	// NewCacheManager fills in "db:" when it is left empty.
	Prefix string

	// Expiration is the default TTL used by Set when no explicit expiration is given.
	// NewCacheManager fills in 5 minutes when it is left unset.
	Expiration time.Duration

	// Enabled switches caching on; when false the manager's methods do not touch Redis.
	Enabled bool
}
|
||
|
||
// CacheManager 缓存管理器
|
||
type CacheManager struct {
|
||
redis *redis.Client
|
||
config CacheConfig
|
||
}
|
||
|
||
// NewCacheManager 创建缓存管理器
|
||
func NewCacheManager(redisClient *redis.Client, config CacheConfig) *CacheManager {
|
||
if config.Prefix == "" {
|
||
config.Prefix = "db:"
|
||
}
|
||
if config.Expiration == 0 {
|
||
config.Expiration = 5 * time.Minute
|
||
}
|
||
|
||
return &CacheManager{
|
||
redis: redisClient,
|
||
config: config,
|
||
}
|
||
}
|
||
|
||
// buildKey 构建缓存键
|
||
func (cm *CacheManager) buildKey(key string) string {
|
||
return cm.config.Prefix + key
|
||
}
|
||
|
||
// Get 获取缓存
|
||
func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return fmt.Errorf("cache not enabled")
|
||
}
|
||
|
||
data, err := cm.redis.GetBytes(ctx, cm.buildKey(key))
|
||
if err != nil || data == nil {
|
||
return fmt.Errorf("cache miss")
|
||
}
|
||
|
||
return json.Unmarshal(data, dest)
|
||
}
|
||
|
||
// Set 设置缓存
|
||
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
data, err := json.Marshal(value)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
exp := cm.config.Expiration
|
||
if len(expiration) > 0 && expiration[0] > 0 {
|
||
exp = expiration[0]
|
||
}
|
||
|
||
return cm.redis.Set(ctx, cm.buildKey(key), data, exp)
|
||
}
|
||
|
||
// Delete 删除缓存
|
||
func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
fullKeys := make([]string, len(keys))
|
||
for i, key := range keys {
|
||
fullKeys[i] = cm.buildKey(key)
|
||
}
|
||
|
||
return cm.redis.Del(ctx, fullKeys...)
|
||
}
|
||
|
||
// DeletePattern 删除匹配模式的缓存
|
||
// 使用 Redis SCAN 命令安全地删除匹配的键,避免阻塞
|
||
func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
// 构建完整的匹配模式
|
||
fullPattern := cm.buildKey(pattern)
|
||
|
||
// 使用 SCAN 命令迭代查找匹配的键
|
||
var cursor uint64
|
||
var deletedCount int
|
||
|
||
for {
|
||
// 每次扫描100个键
|
||
keys, nextCursor, err := cm.redis.Client.Scan(ctx, cursor, fullPattern, 100).Result()
|
||
if err != nil {
|
||
return fmt.Errorf("扫描缓存键失败: %w", err)
|
||
}
|
||
|
||
// 批量删除找到的键
|
||
if len(keys) > 0 {
|
||
if err := cm.redis.Client.Del(ctx, keys...).Err(); err != nil {
|
||
return fmt.Errorf("删除缓存键失败: %w", err)
|
||
}
|
||
deletedCount += len(keys)
|
||
}
|
||
|
||
// 更新游标
|
||
cursor = nextCursor
|
||
|
||
// cursor == 0 表示扫描完成
|
||
if cursor == 0 {
|
||
break
|
||
}
|
||
|
||
// 检查 context 是否已取消
|
||
select {
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
default:
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetOrSet 获取缓存,如果不存在则执行回调并设置缓存
|
||
func (cm *CacheManager) GetOrSet(ctx context.Context, key string, dest interface{}, fn func() (interface{}, error), expiration ...time.Duration) error {
|
||
// 尝试从缓存获取
|
||
err := cm.Get(ctx, key, dest)
|
||
if err == nil {
|
||
return nil // 缓存命中
|
||
}
|
||
|
||
// 缓存未命中,执行回调获取数据
|
||
result, err := fn()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 设置缓存
|
||
if err := cm.Set(ctx, key, result, expiration...); err != nil {
|
||
// 缓存设置失败不影响主流程,只记录日志
|
||
// logger.Warn("failed to set cache", zap.Error(err))
|
||
}
|
||
|
||
// 将结果转换为目标类型
|
||
data, err := json.Marshal(result)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return json.Unmarshal(data, dest)
|
||
}
|
||
|
||
// Cached 缓存装饰器 - 为查询函数添加缓存
|
||
func Cached[T any](
|
||
ctx context.Context,
|
||
cache *CacheManager,
|
||
key string,
|
||
queryFn func() (*T, error),
|
||
expiration ...time.Duration,
|
||
) (*T, error) {
|
||
// 尝试从缓存获取
|
||
var result T
|
||
if err := cache.Get(ctx, key, &result); err == nil {
|
||
return &result, nil
|
||
}
|
||
|
||
// 缓存未命中,执行查询
|
||
data, err := queryFn()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 设置缓存(异步,不阻塞)
|
||
go func() {
|
||
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
defer cancel()
|
||
_ = cache.Set(cacheCtx, key, data, expiration...)
|
||
}()
|
||
|
||
return data, nil
|
||
}
|
||
|
||
// CachedList 缓存装饰器 - 为列表查询添加缓存
|
||
func CachedList[T any](
|
||
ctx context.Context,
|
||
cache *CacheManager,
|
||
key string,
|
||
queryFn func() ([]T, error),
|
||
expiration ...time.Duration,
|
||
) ([]T, error) {
|
||
// 尝试从缓存获取
|
||
var result []T
|
||
if err := cache.Get(ctx, key, &result); err == nil {
|
||
return result, nil
|
||
}
|
||
|
||
// 缓存未命中,执行查询
|
||
data, err := queryFn()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 设置缓存(异步,不阻塞)
|
||
go func() {
|
||
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||
defer cancel()
|
||
_ = cache.Set(cacheCtx, key, data, expiration...)
|
||
}()
|
||
|
||
return data, nil
|
||
}
|
||
|
||
// InvalidateCache 使缓存失效的辅助函数
|
||
type CacheInvalidator struct {
|
||
cache *CacheManager
|
||
}
|
||
|
||
// NewCacheInvalidator 创建缓存失效器
|
||
func NewCacheInvalidator(cache *CacheManager) *CacheInvalidator {
|
||
return &CacheInvalidator{cache: cache}
|
||
}
|
||
|
||
// OnCreate 创建时使缓存失效
|
||
func (ci *CacheInvalidator) OnCreate(ctx context.Context, keys ...string) {
|
||
_ = ci.cache.Delete(ctx, keys...)
|
||
}
|
||
|
||
// OnUpdate 更新时使缓存失效
|
||
func (ci *CacheInvalidator) OnUpdate(ctx context.Context, keys ...string) {
|
||
_ = ci.cache.Delete(ctx, keys...)
|
||
}
|
||
|
||
// OnDelete 删除时使缓存失效
|
||
func (ci *CacheInvalidator) OnDelete(ctx context.Context, keys ...string) {
|
||
_ = ci.cache.Delete(ctx, keys...)
|
||
}
|
||
|
||
// BatchInvalidate 批量使缓存失效(支持模式匹配)
|
||
func (ci *CacheInvalidator) BatchInvalidate(ctx context.Context, pattern string) {
|
||
_ = ci.cache.DeletePattern(ctx, pattern)
|
||
}
|
||
|
||
// CacheKeyBuilder produces the cache keys (and SCAN patterns) for users, profiles,
// textures and tokens. Every key starts with the builder's prefix.
//
// NOTE(review): CacheManager also prepends its own config.Prefix to the keys it is
// given. Confirm with callers whether both prefixes are meant to apply.
type CacheKeyBuilder struct {
	prefix string
}

// NewCacheKeyBuilder returns a builder that prepends prefix to every key it builds.
func NewCacheKeyBuilder(prefix string) *CacheKeyBuilder {
	b := new(CacheKeyBuilder)
	b.prefix = prefix
	return b
}

// withPrefix formats the key suffix and puts the builder's prefix in front of it.
func (b *CacheKeyBuilder) withPrefix(format string, args ...interface{}) string {
	return b.prefix + fmt.Sprintf(format, args...)
}

// User returns the key for a user looked up by ID.
func (b *CacheKeyBuilder) User(userID int64) string {
	return b.withPrefix("user:id:%d", userID)
}

// UserByEmail returns the key for a user looked up by email.
func (b *CacheKeyBuilder) UserByEmail(email string) string {
	return b.withPrefix("user:email:%s", email)
}

// UserByUsername returns the key for a user looked up by username.
func (b *CacheKeyBuilder) UserByUsername(username string) string {
	return b.withPrefix("user:username:%s", username)
}

// Profile returns the key for a profile looked up by UUID.
func (b *CacheKeyBuilder) Profile(uuid string) string {
	return b.withPrefix("profile:uuid:%s", uuid)
}

// ProfileList returns the key for the list of a user's profiles.
func (b *CacheKeyBuilder) ProfileList(userID int64) string {
	return b.withPrefix("profile:user:%d:list", userID)
}

// Texture returns the key for a texture looked up by ID.
func (b *CacheKeyBuilder) Texture(textureID int64) string {
	return b.withPrefix("texture:id:%d", textureID)
}

// TextureByHash returns the key for a texture looked up by content hash.
func (b *CacheKeyBuilder) TextureByHash(hash string) string {
	return b.withPrefix("texture:hash:%s", hash)
}

// TextureList returns the key for one page of a user's texture list.
func (b *CacheKeyBuilder) TextureList(userID int64, page int) string {
	return b.withPrefix("texture:user:%d:page:%d", userID, page)
}

// Token returns the key for an access token.
func (b *CacheKeyBuilder) Token(accessToken string) string {
	return b.withPrefix("token:%s", accessToken)
}

// UserPattern returns a SCAN glob for user keys whose last segment starts with userID.
//
// NOTE(review): because of the trailing '*', the pattern for ID 12 also matches
// other IDs (e.g. user:id:123) and emails/usernames that start with "12". It
// over-invalidates, but never misses a key.
func (b *CacheKeyBuilder) UserPattern(userID int64) string {
	return b.withPrefix("user:*:%d*", userID)
}

// ProfilePattern returns a SCAN glob for profile keys containing ":<userID>" followed
// by anything (e.g. the ProfileList key).
//
// NOTE(review): like UserPattern, this over-matches other IDs sharing the same
// leading digits, and UUID keys starting with those digits.
func (b *CacheKeyBuilder) ProfilePattern(userID int64) string {
	return b.withPrefix("profile:*:%d*", userID)
}
|
||
|
||
// Exists 检查缓存键是否存在
|
||
func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return false, nil
|
||
}
|
||
|
||
count, err := cm.redis.Exists(ctx, cm.buildKey(key))
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return count > 0, nil
|
||
}
|
||
|
||
// TTL 获取缓存键的剩余过期时间
|
||
func (cm *CacheManager) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return 0, fmt.Errorf("cache not enabled")
|
||
}
|
||
|
||
return cm.redis.TTL(ctx, cm.buildKey(key))
|
||
}
|
||
|
||
// Expire 设置缓存键的过期时间
|
||
func (cm *CacheManager) Expire(ctx context.Context, key string, expiration time.Duration) error {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
return cm.redis.Expire(ctx, cm.buildKey(key), expiration)
|
||
}
|
||
|
||
// MGet 批量获取多个缓存
|
||
func (cm *CacheManager) MGet(ctx context.Context, keys []string) (map[string]interface{}, error) {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return nil, fmt.Errorf("cache not enabled")
|
||
}
|
||
|
||
if len(keys) == 0 {
|
||
return make(map[string]interface{}), nil
|
||
}
|
||
|
||
// 构建完整的键
|
||
fullKeys := make([]string, len(keys))
|
||
for i, key := range keys {
|
||
fullKeys[i] = cm.buildKey(key)
|
||
}
|
||
|
||
// 批量获取
|
||
values, err := cm.redis.Client.MGet(ctx, fullKeys...).Result()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 解析结果
|
||
result := make(map[string]interface{})
|
||
for i, val := range values {
|
||
if val != nil {
|
||
result[keys[i]] = val
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// MSet 批量设置多个缓存
|
||
func (cm *CacheManager) MSet(ctx context.Context, values map[string]interface{}, expiration time.Duration) error {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
if len(values) == 0 {
|
||
return nil
|
||
}
|
||
|
||
// 逐个设置(Redis MSet 不支持过期时间)
|
||
for key, value := range values {
|
||
if err := cm.Set(ctx, key, value, expiration); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Increment 递增缓存值
|
||
func (cm *CacheManager) Increment(ctx context.Context, key string) (int64, error) {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return 0, fmt.Errorf("cache not enabled")
|
||
}
|
||
|
||
return cm.redis.Incr(ctx, cm.buildKey(key))
|
||
}
|
||
|
||
// Decrement 递减缓存值
|
||
func (cm *CacheManager) Decrement(ctx context.Context, key string) (int64, error) {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return 0, fmt.Errorf("cache not enabled")
|
||
}
|
||
|
||
return cm.redis.Decr(ctx, cm.buildKey(key))
|
||
}
|
||
|
||
// IncrementWithExpire 递增并设置过期时间
|
||
func (cm *CacheManager) IncrementWithExpire(ctx context.Context, key string, expiration time.Duration) (int64, error) {
|
||
if !cm.config.Enabled || cm.redis == nil {
|
||
return 0, fmt.Errorf("cache not enabled")
|
||
}
|
||
|
||
fullKey := cm.buildKey(key)
|
||
|
||
// 递增
|
||
val, err := cm.redis.Incr(ctx, fullKey)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
|
||
// 设置过期时间(如果是新键)
|
||
if val == 1 {
|
||
_ = cm.redis.Expire(ctx, fullKey, expiration)
|
||
}
|
||
|
||
return val, nil
|
||
}
|