package auth import ( "context" "encoding/json" "fmt" "time" "carrotskin/pkg/redis" "go.uber.org/zap" ) // TokenMetadata Token元数据(存储在Redis中) type TokenMetadata struct { UserID int64 `json:"user_id"` ProfileID string `json:"profile_id"` ClientUUID string `json:"client_uuid"` ClientToken string `json:"client_token"` Version int `json:"version"` CreatedAt int64 `json:"created_at"` } // TokenStoreRedis Redis Token存储实现 type TokenStoreRedis struct { redis *redis.Client logger *zap.Logger keyPrefix string defaultTTL time.Duration staleTTL time.Duration maxTokensPerUser int } // NewTokenStoreRedis 创建Redis Token存储 func NewTokenStoreRedis( redisClient *redis.Client, logger *zap.Logger, opts ...TokenStoreOption, ) *TokenStoreRedis { options := &tokenStoreOptions{ keyPrefix: "token:", defaultTTL: 24 * time.Hour, staleTTL: 30 * 24 * time.Hour, maxTokensPerUser: 10, } for _, opt := range opts { opt(options) } return &TokenStoreRedis{ redis: redisClient, logger: logger, keyPrefix: options.keyPrefix, defaultTTL: options.defaultTTL, staleTTL: options.staleTTL, maxTokensPerUser: options.maxTokensPerUser, } } // tokenStoreOptions Token存储配置选项 type tokenStoreOptions struct { keyPrefix string defaultTTL time.Duration staleTTL time.Duration maxTokensPerUser int } // TokenStoreOption Token存储配置选项函数 type TokenStoreOption func(*tokenStoreOptions) // WithKeyPrefix 设置Key前缀 func WithKeyPrefix(prefix string) TokenStoreOption { return func(o *tokenStoreOptions) { o.keyPrefix = prefix } } // WithDefaultTTL 设置默认TTL func WithDefaultTTL(ttl time.Duration) TokenStoreOption { return func(o *tokenStoreOptions) { o.defaultTTL = ttl } } // WithStaleTTL 设置过期但可用时间 func WithStaleTTL(ttl time.Duration) TokenStoreOption { return func(o *tokenStoreOptions) { o.staleTTL = ttl } } // WithMaxTokensPerUser 设置每个用户的最大Token数 func WithMaxTokensPerUser(max int) TokenStoreOption { return func(o *tokenStoreOptions) { o.maxTokensPerUser = max } } // Store 存储Token func (s *TokenStoreRedis) Store(ctx context.Context, accessToken string, metadata *TokenMetadata, ttl time.Duration) error { if ttl <= 0 { ttl = s.defaultTTL } // 序列化元数据 data, err := json.Marshal(metadata) if err != nil { return fmt.Errorf("序列化Token元数据失败: %w", err) } // 存储Token tokenKey := s.getTokenKey(accessToken) if err := s.redis.Set(ctx, tokenKey, data, ttl); err != nil { return fmt.Errorf("存储Token失败: %w", err) } // 添加到用户Token集合 userTokensKey := s.getUserTokensKey(metadata.UserID) if err := s.redis.SAdd(ctx, userTokensKey, accessToken); err != nil { return fmt.Errorf("添加到用户Token集合失败: %w", err) } // 清理过期Token(后台执行) go s.cleanupUserTokens(context.Background(), metadata.UserID) s.logger.Debug("Token已存储", zap.String("token", accessToken[:20]+"..."), zap.Int64("userId", metadata.UserID), zap.Duration("ttl", ttl), ) return nil } // Retrieve 获取Token元数据 func (s *TokenStoreRedis) Retrieve(ctx context.Context, accessToken string) (*TokenMetadata, error) { tokenKey := s.getTokenKey(accessToken) data, err := s.redis.Get(ctx, tokenKey) if err != nil { return nil, fmt.Errorf("获取Token失败: %w", err) } var metadata TokenMetadata if err := json.Unmarshal([]byte(data), &metadata); err != nil { return nil, fmt.Errorf("解析Token元数据失败: %w", err) } return &metadata, nil } // Delete 删除Token func (s *TokenStoreRedis) Delete(ctx context.Context, accessToken string) error { tokenKey := s.getTokenKey(accessToken) // 先获取Token元数据以获取UserID metadata, err := s.Retrieve(ctx, accessToken) if err != nil { // Token可能已过期,忽略错误 return nil } // 删除Token if err := s.redis.Del(ctx, tokenKey); err != nil { return fmt.Errorf("删除Token失败: %w", err) } // 从用户Token集合中移除 userTokensKey := s.getUserTokensKey(metadata.UserID) if err := s.redis.SRem(ctx, userTokensKey, accessToken); err != nil { return fmt.Errorf("从用户Token集合移除失败: %w", err) } s.logger.Debug("Token已删除", zap.String("token", accessToken[:20]+"..."), zap.Int64("userId", metadata.UserID), ) return nil } // DeleteByUserID 删除用户的所有Token func (s *TokenStoreRedis) DeleteByUserID(ctx context.Context, userID int64) error { userTokensKey := s.getUserTokensKey(userID) // 获取用户所有Token tokens, err := s.redis.SMembers(ctx, userTokensKey) if err != nil { return fmt.Errorf("获取用户Token列表失败: %w", err) } // 删除所有Token if len(tokens) > 0 { tokenKeys := make([]string, len(tokens)) for i, token := range tokens { tokenKeys[i] = s.getTokenKey(token) } if err := s.redis.Del(ctx, tokenKeys...); err != nil { return fmt.Errorf("批量删除Token失败: %w", err) } } // 删除用户Token集合 if err := s.redis.Del(ctx, userTokensKey); err != nil { return fmt.Errorf("删除用户Token集合失败: %w", err) } s.logger.Info("用户所有Token已删除", zap.Int64("userId", userID), zap.Int("count", len(tokens)), ) return nil } // Exists 检查Token是否存在 func (s *TokenStoreRedis) Exists(ctx context.Context, accessToken string) (bool, error) { tokenKey := s.getTokenKey(accessToken) count, err := s.redis.Exists(ctx, tokenKey) if err != nil { return false, fmt.Errorf("检查Token存在失败: %w", err) } return count > 0, nil } // GetTTL 获取Token的剩余TTL func (s *TokenStoreRedis) GetTTL(ctx context.Context, accessToken string) (time.Duration, error) { tokenKey := s.getTokenKey(accessToken) return s.redis.TTL(ctx, tokenKey) } // RefreshTTL 刷新Token的TTL func (s *TokenStoreRedis) RefreshTTL(ctx context.Context, accessToken string, ttl time.Duration) error { if ttl <= 0 { ttl = s.defaultTTL } tokenKey := s.getTokenKey(accessToken) if err := s.redis.Expire(ctx, tokenKey, ttl); err != nil { return fmt.Errorf("刷新Token TTL失败: %w", err) } return nil } // GetCountByUser 获取用户的Token数量 func (s *TokenStoreRedis) GetCountByUser(ctx context.Context, userID int64) (int64, error) { userTokensKey := s.getUserTokensKey(userID) count, err := s.redis.SMembers(ctx, userTokensKey) if err != nil { return 0, fmt.Errorf("获取用户Token数量失败: %w", err) } return int64(len(count)), nil } // cleanupUserTokens 清理用户的过期Token(保留最新的N个) func (s *TokenStoreRedis) cleanupUserTokens(ctx context.Context, userID int64) { userTokensKey := s.getUserTokensKey(userID) // 获取用户所有Token tokens, err := s.redis.SMembers(ctx, userTokensKey) if err != nil { s.logger.Error("获取用户Token列表失败", zap.Error(err), zap.Int64("userId", userID)) return } // 清理过期的Token(验证它们是否仍存在) validTokens := make([]string, 0, len(tokens)) for _, token := range tokens { tokenKey := s.getTokenKey(token) exists, err := s.redis.Exists(ctx, tokenKey) if err != nil { s.logger.Error("检查Token存在失败", zap.Error(err), zap.String("token", token[:20]+"...")) continue } if exists > 0 { validTokens = append(validTokens, token) } } // 如果没有变化,直接返回 if len(validTokens) == len(tokens) { return } // 更新用户Token集合 if len(validTokens) == 0 { s.redis.Del(ctx, userTokensKey) } else { // 重新设置集合 s.redis.Del(ctx, userTokensKey) for _, token := range validTokens { s.redis.SAdd(ctx, userTokensKey, token) } } // 如果超过限制,删除最旧的Token(这里简化处理,可以根据createdAt排序) if len(validTokens) > s.maxTokensPerUser { tokensToDelete := validTokens[s.maxTokensPerUser:] for _, token := range tokensToDelete { s.Delete(ctx, token) } s.logger.Info("清理用户多余Token", zap.Int64("userId", userID), zap.Int("deleted", len(tokensToDelete)), ) } } // getTokenKey 生成Token的Redis Key func (s *TokenStoreRedis) getTokenKey(accessToken string) string { return s.keyPrefix + accessToken } // getUserTokensKey 生成用户Token集合的Redis Key func (s *TokenStoreRedis) getUserTokensKey(userID int64) string { return fmt.Sprintf("user:%d:tokens", userID) }