Merge remote-tracking branch 'origin/feature/redis-auth-integration' into dev
# Conflicts: # go.mod # go.sum # internal/container/container.go # internal/repository/interfaces.go # internal/service/mocks_test.go # internal/service/texture_service_test.go # internal/service/token_service_test.go # pkg/redis/manager.go
This commit is contained in:
320
pkg/auth/token_redis.go
Normal file
320
pkg/auth/token_redis.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/redis"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TokenMetadata Token元数据(存储在Redis中)
|
||||
type TokenMetadata struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
ClientUUID string `json:"client_uuid"`
|
||||
ClientToken string `json:"client_token"`
|
||||
Version int `json:"version"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
// TokenStoreRedis Redis Token存储实现
|
||||
type TokenStoreRedis struct {
|
||||
redis *redis.Client
|
||||
logger *zap.Logger
|
||||
keyPrefix string
|
||||
defaultTTL time.Duration
|
||||
staleTTL time.Duration
|
||||
maxTokensPerUser int
|
||||
}
|
||||
|
||||
// NewTokenStoreRedis 创建Redis Token存储
|
||||
func NewTokenStoreRedis(
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
opts ...TokenStoreOption,
|
||||
) *TokenStoreRedis {
|
||||
options := &tokenStoreOptions{
|
||||
keyPrefix: "token:",
|
||||
defaultTTL: 24 * time.Hour,
|
||||
staleTTL: 30 * 24 * time.Hour,
|
||||
maxTokensPerUser: 10,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
return &TokenStoreRedis{
|
||||
redis: redisClient,
|
||||
logger: logger,
|
||||
keyPrefix: options.keyPrefix,
|
||||
defaultTTL: options.defaultTTL,
|
||||
staleTTL: options.staleTTL,
|
||||
maxTokensPerUser: options.maxTokensPerUser,
|
||||
}
|
||||
}
|
||||
|
||||
// tokenStoreOptions Token存储配置选项
|
||||
type tokenStoreOptions struct {
|
||||
keyPrefix string
|
||||
defaultTTL time.Duration
|
||||
staleTTL time.Duration
|
||||
maxTokensPerUser int
|
||||
}
|
||||
|
||||
// TokenStoreOption Token存储配置选项函数
|
||||
type TokenStoreOption func(*tokenStoreOptions)
|
||||
|
||||
// WithKeyPrefix 设置Key前缀
|
||||
func WithKeyPrefix(prefix string) TokenStoreOption {
|
||||
return func(o *tokenStoreOptions) {
|
||||
o.keyPrefix = prefix
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultTTL 设置默认TTL
|
||||
func WithDefaultTTL(ttl time.Duration) TokenStoreOption {
|
||||
return func(o *tokenStoreOptions) {
|
||||
o.defaultTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithStaleTTL 设置过期但可用时间
|
||||
func WithStaleTTL(ttl time.Duration) TokenStoreOption {
|
||||
return func(o *tokenStoreOptions) {
|
||||
o.staleTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxTokensPerUser 设置每个用户的最大Token数
|
||||
func WithMaxTokensPerUser(max int) TokenStoreOption {
|
||||
return func(o *tokenStoreOptions) {
|
||||
o.maxTokensPerUser = max
|
||||
}
|
||||
}
|
||||
|
||||
// Store 存储Token
|
||||
func (s *TokenStoreRedis) Store(ctx context.Context, accessToken string, metadata *TokenMetadata, ttl time.Duration) error {
|
||||
if ttl <= 0 {
|
||||
ttl = s.defaultTTL
|
||||
}
|
||||
|
||||
// 序列化元数据
|
||||
data, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化Token元数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储Token
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
if err := s.redis.Set(ctx, tokenKey, data, ttl); err != nil {
|
||||
return fmt.Errorf("存储Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 添加到用户Token集合
|
||||
userTokensKey := s.getUserTokensKey(metadata.UserID)
|
||||
if err := s.redis.SAdd(ctx, userTokensKey, accessToken); err != nil {
|
||||
return fmt.Errorf("添加到用户Token集合失败: %w", err)
|
||||
}
|
||||
|
||||
// 清理过期Token(后台执行)
|
||||
go s.cleanupUserTokens(context.Background(), metadata.UserID)
|
||||
|
||||
s.logger.Debug("Token已存储",
|
||||
zap.String("token", accessToken[:20]+"..."),
|
||||
zap.Int64("userId", metadata.UserID),
|
||||
zap.Duration("ttl", ttl),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Retrieve 获取Token元数据
|
||||
func (s *TokenStoreRedis) Retrieve(ctx context.Context, accessToken string) (*TokenMetadata, error) {
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
data, err := s.redis.Get(ctx, tokenKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取Token失败: %w", err)
|
||||
}
|
||||
|
||||
var metadata TokenMetadata
|
||||
if err := json.Unmarshal([]byte(data), &metadata); err != nil {
|
||||
return nil, fmt.Errorf("解析Token元数据失败: %w", err)
|
||||
}
|
||||
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// Delete 删除Token
|
||||
func (s *TokenStoreRedis) Delete(ctx context.Context, accessToken string) error {
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
|
||||
// 先获取Token元数据以获取UserID
|
||||
metadata, err := s.Retrieve(ctx, accessToken)
|
||||
if err != nil {
|
||||
// Token可能已过期,忽略错误
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除Token
|
||||
if err := s.redis.Del(ctx, tokenKey); err != nil {
|
||||
return fmt.Errorf("删除Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 从用户Token集合中移除
|
||||
userTokensKey := s.getUserTokensKey(metadata.UserID)
|
||||
if err := s.redis.SRem(ctx, userTokensKey, accessToken); err != nil {
|
||||
return fmt.Errorf("从用户Token集合移除失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Debug("Token已删除",
|
||||
zap.String("token", accessToken[:20]+"..."),
|
||||
zap.Int64("userId", metadata.UserID),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户的所有Token
|
||||
func (s *TokenStoreRedis) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
userTokensKey := s.getUserTokensKey(userID)
|
||||
|
||||
// 获取用户所有Token
|
||||
tokens, err := s.redis.SMembers(ctx, userTokensKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取用户Token列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除所有Token
|
||||
if len(tokens) > 0 {
|
||||
tokenKeys := make([]string, len(tokens))
|
||||
for i, token := range tokens {
|
||||
tokenKeys[i] = s.getTokenKey(token)
|
||||
}
|
||||
|
||||
if err := s.redis.Del(ctx, tokenKeys...); err != nil {
|
||||
return fmt.Errorf("批量删除Token失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除用户Token集合
|
||||
if err := s.redis.Del(ctx, userTokensKey); err != nil {
|
||||
return fmt.Errorf("删除用户Token集合失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("用户所有Token已删除",
|
||||
zap.Int64("userId", userID),
|
||||
zap.Int("count", len(tokens)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查Token是否存在
|
||||
func (s *TokenStoreRedis) Exists(ctx context.Context, accessToken string) (bool, error) {
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
count, err := s.redis.Exists(ctx, tokenKey)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("检查Token存在失败: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetTTL 获取Token的剩余TTL
|
||||
func (s *TokenStoreRedis) GetTTL(ctx context.Context, accessToken string) (time.Duration, error) {
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
return s.redis.TTL(ctx, tokenKey)
|
||||
}
|
||||
|
||||
// RefreshTTL 刷新Token的TTL
|
||||
func (s *TokenStoreRedis) RefreshTTL(ctx context.Context, accessToken string, ttl time.Duration) error {
|
||||
if ttl <= 0 {
|
||||
ttl = s.defaultTTL
|
||||
}
|
||||
|
||||
tokenKey := s.getTokenKey(accessToken)
|
||||
if err := s.redis.Expire(ctx, tokenKey, ttl); err != nil {
|
||||
return fmt.Errorf("刷新Token TTL失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCountByUser 获取用户的Token数量
|
||||
func (s *TokenStoreRedis) GetCountByUser(ctx context.Context, userID int64) (int64, error) {
|
||||
userTokensKey := s.getUserTokensKey(userID)
|
||||
count, err := s.redis.SMembers(ctx, userTokensKey)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("获取用户Token数量失败: %w", err)
|
||||
}
|
||||
return int64(len(count)), nil
|
||||
}
|
||||
|
||||
// cleanupUserTokens 清理用户的过期Token(保留最新的N个)
|
||||
func (s *TokenStoreRedis) cleanupUserTokens(ctx context.Context, userID int64) {
|
||||
userTokensKey := s.getUserTokensKey(userID)
|
||||
|
||||
// 获取用户所有Token
|
||||
tokens, err := s.redis.SMembers(ctx, userTokensKey)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token列表失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
}
|
||||
|
||||
// 清理过期的Token(验证它们是否仍存在)
|
||||
validTokens := make([]string, 0, len(tokens))
|
||||
for _, token := range tokens {
|
||||
tokenKey := s.getTokenKey(token)
|
||||
exists, err := s.redis.Exists(ctx, tokenKey)
|
||||
if err != nil {
|
||||
s.logger.Error("检查Token存在失败", zap.Error(err), zap.String("token", token[:20]+"..."))
|
||||
continue
|
||||
}
|
||||
|
||||
if exists > 0 {
|
||||
validTokens = append(validTokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有变化,直接返回
|
||||
if len(validTokens) == len(tokens) {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新用户Token集合
|
||||
if len(validTokens) == 0 {
|
||||
s.redis.Del(ctx, userTokensKey)
|
||||
} else {
|
||||
// 重新设置集合
|
||||
s.redis.Del(ctx, userTokensKey)
|
||||
for _, token := range validTokens {
|
||||
s.redis.SAdd(ctx, userTokensKey, token)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果超过限制,删除最旧的Token(这里简化处理,可以根据createdAt排序)
|
||||
if len(validTokens) > s.maxTokensPerUser {
|
||||
tokensToDelete := validTokens[s.maxTokensPerUser:]
|
||||
for _, token := range tokensToDelete {
|
||||
s.Delete(ctx, token)
|
||||
}
|
||||
s.logger.Info("清理用户多余Token",
|
||||
zap.Int64("userId", userID),
|
||||
zap.Int("deleted", len(tokensToDelete)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// getTokenKey 生成Token的Redis Key
|
||||
func (s *TokenStoreRedis) getTokenKey(accessToken string) string {
|
||||
return s.keyPrefix + accessToken
|
||||
}
|
||||
|
||||
// getUserTokensKey 生成用户Token集合的Redis Key
|
||||
func (s *TokenStoreRedis) getUserTokensKey(userID int64) string {
|
||||
return fmt.Sprintf("user:%d:tokens", userID)
|
||||
}
|
||||
47
pkg/config/config_load_test.go
Normal file
47
pkg/config/config_load_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// 重置 viper,避免测试间干扰
|
||||
func resetViper() {
|
||||
viper.Reset()
|
||||
}
|
||||
|
||||
func TestLoad_DefaultsAndBucketsOverride(t *testing.T) {
|
||||
resetViper()
|
||||
// 设置部分环境变量覆盖
|
||||
_ = os.Setenv("RUSTFS_BUCKET_TEXTURES", "tex-bkt")
|
||||
_ = os.Setenv("RUSTFS_BUCKET_AVATARS", "ava-bkt")
|
||||
_ = os.Setenv("DATABASE_MAX_IDLE_CONNS", "20")
|
||||
_ = os.Setenv("DATABASE_MAX_OPEN_CONNS", "50")
|
||||
_ = os.Setenv("DATABASE_CONN_MAX_LIFETIME", "2h")
|
||||
_ = os.Setenv("DATABASE_CONN_MAX_IDLE_TIME", "30m")
|
||||
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load err: %v", err)
|
||||
}
|
||||
|
||||
// 默认值检查
|
||||
if cfg.Server.Port == "" || cfg.Database.Driver == "" || cfg.Redis.Host == "" {
|
||||
t.Fatalf("expected defaults filled: %+v", cfg)
|
||||
}
|
||||
|
||||
// 覆盖检查
|
||||
if cfg.RustFS.Buckets["textures"] != "tex-bkt" || cfg.RustFS.Buckets["avatars"] != "ava-bkt" {
|
||||
t.Fatalf("buckets override failed: %+v", cfg.RustFS.Buckets)
|
||||
}
|
||||
if cfg.Database.MaxIdleConns != 20 || cfg.Database.MaxOpenConns != 50 {
|
||||
t.Fatalf("db pool override failed: %+v", cfg.Database)
|
||||
}
|
||||
if cfg.Database.ConnMaxLifetime.String() != "2h0m0s" || cfg.Database.ConnMaxIdleTime.String() != "30m0s" {
|
||||
t.Fatalf("db duration override failed: %v %v", cfg.Database.ConnMaxLifetime, cfg.Database.ConnMaxIdleTime)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,12 +14,24 @@ type CacheConfig struct {
|
||||
Prefix string // 缓存键前缀
|
||||
Expiration time.Duration // 过期时间
|
||||
Enabled bool // 是否启用缓存
|
||||
Policy CachePolicy // 缓存策略(可选,不配置则回落到 Expiration)
|
||||
}
|
||||
|
||||
// CachePolicy 缓存策略,用于为不同实体设置默认 TTL
|
||||
type CachePolicy struct {
|
||||
UserTTL time.Duration
|
||||
UserEmailTTL time.Duration
|
||||
ProfileTTL time.Duration
|
||||
ProfileListTTL time.Duration
|
||||
TextureTTL time.Duration
|
||||
TextureListTTL time.Duration
|
||||
}
|
||||
|
||||
// CacheManager 缓存管理器
|
||||
type CacheManager struct {
|
||||
redis *redis.Client
|
||||
config CacheConfig
|
||||
Policy CachePolicy
|
||||
}
|
||||
|
||||
// NewCacheManager 创建缓存管理器
|
||||
@@ -31,9 +43,33 @@ func NewCacheManager(redisClient *redis.Client, config CacheConfig) *CacheManage
|
||||
config.Expiration = 5 * time.Minute
|
||||
}
|
||||
|
||||
// 填充默认策略(未配置时退回全局过期时间)
|
||||
applyPolicyDefaults := func(p *CachePolicy) {
|
||||
if p.UserTTL == 0 {
|
||||
p.UserTTL = config.Expiration
|
||||
}
|
||||
if p.UserEmailTTL == 0 {
|
||||
p.UserEmailTTL = config.Expiration
|
||||
}
|
||||
if p.ProfileTTL == 0 {
|
||||
p.ProfileTTL = config.Expiration
|
||||
}
|
||||
if p.ProfileListTTL == 0 {
|
||||
p.ProfileListTTL = config.Expiration
|
||||
}
|
||||
if p.TextureTTL == 0 {
|
||||
p.TextureTTL = config.Expiration
|
||||
}
|
||||
if p.TextureListTTL == 0 {
|
||||
p.TextureListTTL = config.Expiration
|
||||
}
|
||||
}
|
||||
applyPolicyDefaults(&config.Policy)
|
||||
|
||||
return &CacheManager{
|
||||
redis: redisClient,
|
||||
config: config,
|
||||
Policy: config.Policy,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +92,14 @@ func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) e
|
||||
return json.Unmarshal(data, dest)
|
||||
}
|
||||
|
||||
// TryGet 获取缓存,命中时返回 true,不视为错误
|
||||
func (cm *CacheManager) TryGet(ctx context.Context, key string, dest interface{}) (bool, error) {
|
||||
if err := cm.Get(ctx, key, dest); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
|
||||
if !cm.config.Enabled || cm.redis == nil {
|
||||
@@ -75,6 +119,13 @@ func (cm *CacheManager) Set(ctx context.Context, key string, value interface{},
|
||||
return cm.redis.Set(ctx, cm.buildKey(key), data, exp)
|
||||
}
|
||||
|
||||
// SetAsync 异步设置缓存,避免在主请求链路阻塞
|
||||
func (cm *CacheManager) SetAsync(ctx context.Context, key string, value interface{}, expiration ...time.Duration) {
|
||||
go func() {
|
||||
_ = cm.Set(ctx, key, value, expiration...)
|
||||
}()
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error {
|
||||
if !cm.config.Enabled || cm.redis == nil {
|
||||
@@ -187,11 +238,7 @@ func Cached[T any](
|
||||
}
|
||||
|
||||
// 设置缓存(异步,不阻塞)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_ = cache.Set(cacheCtx, key, data, expiration...)
|
||||
}()
|
||||
cache.SetAsync(context.Background(), key, data, expiration...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
@@ -217,11 +264,7 @@ func CachedList[T any](
|
||||
}
|
||||
|
||||
// 设置缓存(异步,不阻塞)
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_ = cache.Set(cacheCtx, key, data, expiration...)
|
||||
}()
|
||||
cache.SetAsync(context.Background(), key, data, expiration...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
@@ -306,6 +349,11 @@ func (b *CacheKeyBuilder) TextureList(userID int64, page int) string {
|
||||
return fmt.Sprintf("%stexture:user:%d:page:%d", b.prefix, userID, page)
|
||||
}
|
||||
|
||||
// TextureListPattern 构建材质列表缓存键模式(用于批量失效)
|
||||
func (b *CacheKeyBuilder) TextureListPattern(userID int64) string {
|
||||
return fmt.Sprintf("%stexture:user:%d:*", b.prefix, userID)
|
||||
}
|
||||
|
||||
// Token 构建令牌缓存键
|
||||
func (b *CacheKeyBuilder) Token(accessToken string) string {
|
||||
return fmt.Sprintf("%stoken:%s", b.prefix, accessToken)
|
||||
|
||||
184
pkg/database/cache_test.go
Normal file
184
pkg/database/cache_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pkgRedis "carrotskin/pkg/redis"
|
||||
|
||||
miniredis "github.com/alicebob/miniredis/v2"
|
||||
goRedis "github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
func newCacheWithMiniRedis(t *testing.T) (*CacheManager, func()) {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start miniredis: %v", err)
|
||||
}
|
||||
|
||||
rdb := goRedis.NewClient(&goRedis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
client := &pkgRedis.Client{Client: rdb}
|
||||
|
||||
cache := NewCacheManager(client, CacheConfig{
|
||||
Prefix: "t:",
|
||||
Expiration: time.Minute,
|
||||
Enabled: true,
|
||||
Policy: CachePolicy{
|
||||
UserTTL: 2 * time.Minute,
|
||||
UserEmailTTL: 3 * time.Minute,
|
||||
ProfileTTL: 2 * time.Minute,
|
||||
ProfileListTTL: 90 * time.Second,
|
||||
TextureTTL: 2 * time.Minute,
|
||||
TextureListTTL: 45 * time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
cleanup := func() {
|
||||
_ = rdb.Close()
|
||||
mr.Close()
|
||||
}
|
||||
return cache, cleanup
|
||||
}
|
||||
|
||||
func TestCacheManager_GetSet_TryGet(t *testing.T) {
|
||||
cache, cleanup := newCacheWithMiniRedis(t)
|
||||
defer cleanup()
|
||||
|
||||
ctx := context.Background()
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
u := User{ID: 1, Name: "alice"}
|
||||
if err := cache.Set(ctx, "user:1", u, 10*time.Second); err != nil {
|
||||
t.Fatalf("Set err: %v", err)
|
||||
}
|
||||
|
||||
var got User
|
||||
if err := cache.Get(ctx, "user:1", &got); err != nil {
|
||||
t.Fatalf("Get err: %v", err)
|
||||
}
|
||||
if got != u {
|
||||
t.Fatalf("unexpected value: %+v", got)
|
||||
}
|
||||
|
||||
var got2 User
|
||||
ok, err := cache.TryGet(ctx, "user:1", &got2)
|
||||
if err != nil || !ok {
|
||||
t.Fatalf("TryGet failed, ok=%v err=%v", ok, err)
|
||||
}
|
||||
if got2 != u {
|
||||
t.Fatalf("unexpected TryGet: %+v", got2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheManager_DeletePattern(t *testing.T) {
|
||||
cache, cleanup := newCacheWithMiniRedis(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
_ = cache.Set(ctx, "user:1", "a", 0)
|
||||
_ = cache.Set(ctx, "user:2", "b", 0)
|
||||
_ = cache.Set(ctx, "profile:1", "c", 0)
|
||||
|
||||
// 删除 user:* 键
|
||||
if err := cache.DeletePattern(ctx, "user:*"); err != nil {
|
||||
t.Fatalf("DeletePattern err: %v", err)
|
||||
}
|
||||
|
||||
var v string
|
||||
ok, _ := cache.TryGet(ctx, "user:1", &v)
|
||||
if ok {
|
||||
t.Fatalf("expected user:1 deleted")
|
||||
}
|
||||
ok, _ = cache.TryGet(ctx, "user:2", &v)
|
||||
if ok {
|
||||
t.Fatalf("expected user:2 deleted")
|
||||
}
|
||||
ok, _ = cache.TryGet(ctx, "profile:1", &v)
|
||||
if !ok {
|
||||
t.Fatalf("expected profile:1 kept")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedAndCachedList(t *testing.T) {
|
||||
cache, cleanup := newCacheWithMiniRedis(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
callCount := 0
|
||||
result, err := Cached(ctx, cache, "key1", func() (*string, error) {
|
||||
callCount++
|
||||
val := "hello"
|
||||
return &val, nil
|
||||
}, cache.Policy.UserTTL)
|
||||
if err != nil || *result != "hello" || callCount != 1 {
|
||||
t.Fatalf("Cached first call failed")
|
||||
}
|
||||
// 等待缓存写入完成
|
||||
for i := 0; i < 10; i++ {
|
||||
var tmp string
|
||||
if ok, _ := cache.TryGet(ctx, "key1", &tmp); ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// 第二次应命中缓存
|
||||
_, err = Cached(ctx, cache, "key1", func() (*string, error) {
|
||||
callCount++
|
||||
val := "world"
|
||||
return &val, nil
|
||||
}, cache.Policy.UserTTL)
|
||||
if err != nil || callCount != 1 {
|
||||
t.Fatalf("Cached should hit cache, callCount=%d err=%v", callCount, err)
|
||||
}
|
||||
|
||||
listCall := 0
|
||||
_, err = CachedList(ctx, cache, "list", func() ([]string, error) {
|
||||
listCall++
|
||||
return []string{"a", "b"}, nil
|
||||
}, cache.Policy.ProfileListTTL)
|
||||
if err != nil || listCall != 1 {
|
||||
t.Fatalf("CachedList first call failed")
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
var tmp []string
|
||||
if ok, _ := cache.TryGet(ctx, "list", &tmp); ok {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
_, err = CachedList(ctx, cache, "list", func() ([]string, error) {
|
||||
listCall++
|
||||
return []string{"c"}, nil
|
||||
}, cache.Policy.ProfileListTTL)
|
||||
if err != nil || listCall != 1 {
|
||||
t.Fatalf("CachedList should hit cache, calls=%d err=%v", listCall, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementWithExpire(t *testing.T) {
|
||||
cache, cleanup := newCacheWithMiniRedis(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
val, err := cache.IncrementWithExpire(ctx, "counter", time.Second)
|
||||
if err != nil || val != 1 {
|
||||
t.Fatalf("first increment failed, val=%d err=%v", val, err)
|
||||
}
|
||||
val, err = cache.IncrementWithExpire(ctx, "counter", time.Second)
|
||||
if err != nil || val != 2 {
|
||||
t.Fatalf("second increment failed, val=%d err=%v", val, err)
|
||||
}
|
||||
ttl, err := cache.TTL(ctx, "counter")
|
||||
if err != nil || ttl <= 0 {
|
||||
t.Fatalf("TTL not set: ttl=%v err=%v", ttl, err)
|
||||
}
|
||||
}
|
||||
@@ -75,7 +75,6 @@ func AutoMigrate(logger *zap.Logger) error {
|
||||
&model.TextureDownloadLog{},
|
||||
|
||||
// 认证相关表
|
||||
&model.Token{},
|
||||
&model.Client{}, // Client表用于管理Token版本
|
||||
|
||||
// Yggdrasil相关表(在User之后创建,因为它引用User)
|
||||
|
||||
24
pkg/database/manager_sqlite_test.go
Normal file
24
pkg/database/manager_sqlite_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 使用内存 sqlite 验证 AutoMigrate 关键路径,无需真实 Postgres
|
||||
func TestAutoMigrate_WithSQLite(t *testing.T) {
|
||||
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite err: %v", err)
|
||||
}
|
||||
dbInstance = db
|
||||
defer func() { dbInstance = nil }()
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
if err := AutoMigrate(logger); err != nil {
|
||||
t.Fatalf("AutoMigrate sqlite err: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
|
||||
// TestGetDB_NotInitialized 测试未初始化时获取数据库实例
|
||||
func TestGetDB_NotInitialized(t *testing.T) {
|
||||
dbInstance = nil
|
||||
_, err := GetDB()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
|
||||
expectedError := "数据库未初始化,请先调用 database.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
@@ -22,17 +23,19 @@ func TestGetDB_NotInitialized(t *testing.T) {
|
||||
|
||||
// TestMustGetDB_Panic 测试MustGetDB在未初始化时panic
|
||||
func TestMustGetDB_Panic(t *testing.T) {
|
||||
dbInstance = nil
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetDB 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
_ = MustGetDB()
|
||||
}
|
||||
|
||||
// TestInit_Database 测试数据库初始化逻辑
|
||||
func TestInit_Database(t *testing.T) {
|
||||
dbInstance = nil
|
||||
cfg := config.DatabaseConfig{
|
||||
Driver: "postgres",
|
||||
Host: "localhost",
|
||||
@@ -46,21 +49,21 @@ func TestInit_Database(t *testing.T) {
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
|
||||
// 验证Init函数存在且可调用
|
||||
// 注意:实际连接可能失败,这是可以接受的
|
||||
err := Init(cfg, logger)
|
||||
if err != nil {
|
||||
t.Logf("Init() 返回错误(可能正常,如果数据库未运行): %v", err)
|
||||
t.Skipf("数据库未运行,跳过连接测试: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoMigrate_ErrorHandling 测试AutoMigrate的错误处理逻辑
|
||||
func TestAutoMigrate_ErrorHandling(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
|
||||
// 测试未初始化时的错误处理
|
||||
err := AutoMigrate(logger)
|
||||
if err == nil {
|
||||
@@ -82,4 +85,3 @@ func TestClose_NotInitialized(t *testing.T) {
|
||||
t.Errorf("Close() 在未初始化时应该返回nil,实际返回: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
56
pkg/email/email_test.go
Normal file
56
pkg/email/email_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func resetEmailOnce() {
|
||||
serviceInstance = nil
|
||||
once = sync.Once{}
|
||||
}
|
||||
|
||||
func TestEmailManager_Disabled(t *testing.T) {
|
||||
resetEmailOnce()
|
||||
cfg := config.EmailConfig{Enabled: false}
|
||||
if err := Init(cfg, zap.NewNop()); err != nil {
|
||||
t.Fatalf("Init disabled err: %v", err)
|
||||
}
|
||||
svc := MustGetService()
|
||||
if err := svc.SendVerificationCode("to@test.com", "123456", "email_verification"); err == nil {
|
||||
t.Fatalf("expected error when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailManager_SendFailsWithInvalidSMTP(t *testing.T) {
|
||||
resetEmailOnce()
|
||||
cfg := config.EmailConfig{
|
||||
Enabled: true,
|
||||
SMTPHost: "127.0.0.1",
|
||||
SMTPPort: 1, // invalid/closed port to trigger error quickly
|
||||
Username: "user",
|
||||
Password: "pwd",
|
||||
FromName: "name",
|
||||
}
|
||||
_ = Init(cfg, zap.NewNop())
|
||||
svc := MustGetService()
|
||||
if err := svc.SendVerificationCode("to@test.com", "123456", "reset_password"); err == nil {
|
||||
t.Fatalf("expected send error with invalid smtp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmailManager_SubjectAndBody(t *testing.T) {
|
||||
svc := &Service{cfg: config.EmailConfig{FromName: "name", Username: "user"}, logger: zap.NewNop()}
|
||||
if subj := svc.getSubject("email_verification"); subj == "" {
|
||||
t.Fatalf("subject empty")
|
||||
}
|
||||
body := svc.getBody("123456", "change_email")
|
||||
if !strings.Contains(body, "123456") || !strings.Contains(body, "更换邮箱") {
|
||||
t.Fatalf("body content mismatch")
|
||||
}
|
||||
}
|
||||
@@ -2,18 +2,25 @@ package email
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
func resetEmail() {
|
||||
serviceInstance = nil
|
||||
once = sync.Once{}
|
||||
}
|
||||
|
||||
// TestGetService_NotInitialized 测试未初始化时获取邮件服务
|
||||
func TestGetService_NotInitialized(t *testing.T) {
|
||||
resetEmail()
|
||||
_, err := GetService()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
|
||||
expectedError := "邮件服务未初始化,请先调用 email.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
@@ -22,33 +29,35 @@ func TestGetService_NotInitialized(t *testing.T) {
|
||||
|
||||
// TestMustGetService_Panic 测试MustGetService在未初始化时panic
|
||||
func TestMustGetService_Panic(t *testing.T) {
|
||||
resetEmail()
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetService 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
_ = MustGetService()
|
||||
}
|
||||
|
||||
// TestInit_Email 测试邮件服务初始化
|
||||
func TestInit_Email(t *testing.T) {
|
||||
resetEmail()
|
||||
cfg := config.EmailConfig{
|
||||
Enabled: false,
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
FromName: "noreply@example.com",
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
FromName: "noreply@example.com",
|
||||
}
|
||||
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
|
||||
err := Init(cfg, logger)
|
||||
if err != nil {
|
||||
t.Errorf("Init() 错误 = %v, want nil", err)
|
||||
}
|
||||
|
||||
|
||||
// 验证可以获取服务
|
||||
service, err := GetService()
|
||||
if err != nil {
|
||||
@@ -58,4 +67,3 @@ func TestInit_Email(t *testing.T) {
|
||||
t.Error("GetService() 返回的服务不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,11 @@ package redis
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
redis9 "github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -15,19 +18,69 @@ var (
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
// miniredisInstance 用于测试/开发环境
|
||||
miniredisInstance *miniredis.Miniredis
|
||||
)
|
||||
|
||||
// Init 初始化Redis客户端(线程安全,只会执行一次)
|
||||
// 如果Redis连接失败且环境为测试/开发,则回退到miniredis
|
||||
func Init(cfg config.RedisConfig, logger *zap.Logger) error {
|
||||
var err error
|
||||
once.Do(func() {
|
||||
clientInstance, initError = New(cfg, logger)
|
||||
if initError != nil {
|
||||
return
|
||||
// 尝试连接真实Redis
|
||||
clientInstance, err = New(cfg, logger)
|
||||
if err != nil {
|
||||
logger.Warn("Redis连接失败,尝试使用miniredis回退", zap.Error(err))
|
||||
|
||||
// 检查是否允许回退到miniredis(仅开发/测试环境)
|
||||
if allowFallbackToMiniRedis() {
|
||||
clientInstance, err = initMiniRedis(logger)
|
||||
if err != nil {
|
||||
initError = fmt.Errorf("Redis和miniredis都初始化失败: %w", err)
|
||||
logger.Error("miniredis初始化失败", zap.Error(initError))
|
||||
return
|
||||
}
|
||||
logger.Info("已回退到miniredis用于开发/测试环境")
|
||||
} else {
|
||||
initError = fmt.Errorf("Redis连接失败且不允许回退: %w", err)
|
||||
logger.Error("Redis连接失败", zap.Error(initError))
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// allowFallbackToMiniRedis 检查是否允许回退到miniredis
|
||||
func allowFallbackToMiniRedis() bool {
|
||||
// 检查环境变量
|
||||
env := os.Getenv("ENVIRONMENT")
|
||||
return env == "development" || env == "test" || env == "dev" ||
|
||||
os.Getenv("USE_MINIREDIS") == "true"
|
||||
}
|
||||
|
||||
// initMiniRedis 初始化miniredis(用于开发/测试环境)
|
||||
func initMiniRedis(logger *zap.Logger) (*Client, error) {
|
||||
var err error
|
||||
miniredisInstance, err = miniredis.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("启动miniredis失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建Redis客户端连接到miniredis
|
||||
redisClient := redis9.NewClient(&redis9.Options{
|
||||
Addr: miniredisInstance.Addr(),
|
||||
})
|
||||
|
||||
client := &Client{
|
||||
Client: redisClient,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
logger.Info("miniredis已启动", zap.String("addr", miniredisInstance.Addr()))
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetClient 获取Redis客户端实例(线程安全)
|
||||
func GetClient() (*Client, error) {
|
||||
if clientInstance == nil {
|
||||
@@ -45,13 +98,21 @@ func MustGetClient() *Client {
|
||||
return client
|
||||
}
|
||||
|
||||
// Close 关闭Redis连接(包括miniredis如果使用了)
|
||||
func Close() error {
|
||||
var err error
|
||||
if miniredisInstance != nil {
|
||||
miniredisInstance.Close()
|
||||
miniredisInstance = nil
|
||||
}
|
||||
if clientInstance != nil {
|
||||
err = clientInstance.Close()
|
||||
clientInstance = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// IsUsingMiniRedis 检查是否使用了miniredis
|
||||
func IsUsingMiniRedis() bool {
|
||||
return miniredisInstance != nil
|
||||
}
|
||||
|
||||
71
pkg/storage/minio_test.go
Normal file
71
pkg/storage/minio_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
)
|
||||
|
||||
// 使用 nil client 仅测试纯函数和错误分支
|
||||
func TestStorage_GetBucketAndBuildURL(t *testing.T) {
|
||||
s := &StorageClient{
|
||||
client: (*minio.Client)(nil),
|
||||
buckets: map[string]string{"textures": "tex-bkt"},
|
||||
publicURL: "http://localhost:9000",
|
||||
}
|
||||
|
||||
if b, err := s.GetBucket("textures"); err != nil || b != "tex-bkt" {
|
||||
t.Fatalf("GetBucket mismatch: %v %s", err, b)
|
||||
}
|
||||
if _, err := s.GetBucket("missing"); err == nil {
|
||||
t.Fatalf("expected error for missing bucket")
|
||||
}
|
||||
|
||||
if url := s.BuildFileURL("tex-bkt", "obj"); url != "http://localhost:9000/tex-bkt/obj" {
|
||||
t.Fatalf("BuildFileURL mismatch: %s", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStorage_SkipConnectWhenNoCreds(t *testing.T) {
|
||||
// 当 AccessKey/Secret 为空时跳过 ListBuckets 测试,避免真实依赖
|
||||
cfg := config.RustFSConfig{
|
||||
Endpoint: "127.0.0.1:9000",
|
||||
Buckets: map[string]string{"avatars": "ava", "textures": "tex"},
|
||||
UseSSL: false,
|
||||
}
|
||||
if _, err := NewStorage(cfg); err != nil {
|
||||
t.Fatalf("NewStorage should not error when creds empty: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPresignedHelpers_WithNilClient(t *testing.T) {
|
||||
s := &StorageClient{
|
||||
client: (*minio.Client)(nil),
|
||||
buckets: map[string]string{"textures": "tex-bkt"},
|
||||
publicURL: "http://localhost:9000",
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 预期会panic(nil client),用recover捕获
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatalf("GeneratePresignedURL expected panic with nil client")
|
||||
}
|
||||
}()
|
||||
_, _ = s.GeneratePresignedURL(ctx, "tex-bkt", "obj", time.Minute)
|
||||
}()
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatalf("GeneratePresignedPostURL expected panic with nil client")
|
||||
}
|
||||
}()
|
||||
_, _ = s.GeneratePresignedPostURL(ctx, "tex-bkt", "obj", 0, 10, time.Minute)
|
||||
}()
|
||||
}
|
||||
Reference in New Issue
Block a user