Files
backend/pkg/auth/token_redis.go
lan 6ddcf92ce3 refactor: Remove Token management and integrate Redis for authentication
- Deleted the Token model and its repository, transitioning to a Redis-based token management system.
- Updated the service layer to utilize Redis for token storage, enhancing performance and scalability.
- Refactored the container to remove TokenRepository and integrate the new token service.
- Cleaned up the Dockerfile and other files by removing unnecessary whitespace and comments.
- Enhanced error handling and logging for Redis initialization and usage.
2025-12-24 16:03:46 +08:00

321 lines
8.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package auth
import (
"context"
"encoding/json"
"fmt"
"time"
"carrotskin/pkg/redis"
"go.uber.org/zap"
)
// TokenMetadata is the per-token record kept in Redis. Store encodes it as
// JSON and writes it under the token's key; Retrieve decodes it again.
type TokenMetadata struct {
// UserID is the owning user. Store and Delete use it to locate the user's token set.
UserID int64 `json:"user_id"`
// ProfileID, ClientUUID, ClientToken and Version are carried through
// unchanged; no method of the store in this file reads them.
ProfileID string `json:"profile_id"`
ClientUUID string `json:"client_uuid"`
ClientToken string `json:"client_token"`
Version int `json:"version"`
// CreatedAt is supplied by the caller.
// NOTE(review): presumably a Unix timestamp; the unit is not established in this file.
CreatedAt int64 `json:"created_at"`
}
// TokenStoreRedis is a token store backed by Redis.
type TokenStoreRedis struct {
// redis is the client every store operation goes through.
redis *redis.Client
// logger receives the store's debug, info and error records.
logger *zap.Logger
// keyPrefix is prepended to an access token to form its Redis key (see getTokenKey).
keyPrefix string
// defaultTTL is applied by Store and RefreshTTL when the caller passes a non-positive ttl.
defaultTTL time.Duration
// staleTTL is copied from the options but is not read by any method shown in this file.
staleTTL time.Duration
// maxTokensPerUser is the per-user token limit consulted by cleanupUserTokens.
maxTokensPerUser int
}
// NewTokenStoreRedis builds a Redis-backed token store.
//
// The store starts from these defaults and then applies opts in order, so a
// later option overrides an earlier one:
//   - key prefix "token:"
//   - default TTL of 24 hours
//   - stale TTL of 30 days
//   - at most 10 tokens per user
func NewTokenStoreRedis(
	redisClient *redis.Client,
	logger *zap.Logger,
	opts ...TokenStoreOption,
) *TokenStoreRedis {
	cfg := tokenStoreOptions{
		keyPrefix:        "token:",
		defaultTTL:       24 * time.Hour,
		staleTTL:         30 * 24 * time.Hour,
		maxTokensPerUser: 10,
	}
	for i := range opts {
		opts[i](&cfg)
	}

	store := new(TokenStoreRedis)
	store.redis = redisClient
	store.logger = logger
	store.keyPrefix = cfg.keyPrefix
	store.defaultTTL = cfg.defaultTTL
	store.staleTTL = cfg.staleTTL
	store.maxTokensPerUser = cfg.maxTokensPerUser
	return store
}
// tokenStoreOptions collects the configurable settings of a TokenStoreRedis.
// NewTokenStoreRedis fills it with defaults, lets each TokenStoreOption
// modify it, and copies the result into the store.
type tokenStoreOptions struct {
// keyPrefix is the prefix of every token key.
keyPrefix string
// defaultTTL is the TTL used when a caller does not supply a positive one.
defaultTTL time.Duration
// staleTTL is set through WithStaleTTL; no method in this file reads it.
staleTTL time.Duration
// maxTokensPerUser is the per-user token limit.
maxTokensPerUser int
}
// TokenStoreOption is a functional option: a function that modifies the
// settings NewTokenStoreRedis builds a store from.
type TokenStoreOption func(*tokenStoreOptions)
// WithKeyPrefix returns an option that sets the prefix of token keys to prefix.
func WithKeyPrefix(prefix string) TokenStoreOption {
	var apply TokenStoreOption = func(cfg *tokenStoreOptions) {
		cfg.keyPrefix = prefix
	}
	return apply
}
// WithDefaultTTL returns an option that sets the TTL used when a caller does
// not supply a positive one.
func WithDefaultTTL(ttl time.Duration) TokenStoreOption {
	var apply TokenStoreOption = func(cfg *tokenStoreOptions) {
		cfg.defaultTTL = ttl
	}
	return apply
}
// WithStaleTTL returns an option that sets the stale TTL, described by the
// original author as the period during which a token is expired but still
// usable.
func WithStaleTTL(ttl time.Duration) TokenStoreOption {
	var apply TokenStoreOption = func(cfg *tokenStoreOptions) {
		cfg.staleTTL = ttl
	}
	return apply
}
// WithMaxTokensPerUser returns an option that sets how many tokens a single
// user may hold.
func WithMaxTokensPerUser(max int) TokenStoreOption {
	var apply TokenStoreOption = func(cfg *tokenStoreOptions) {
		cfg.maxTokensPerUser = max
	}
	return apply
}
// Store saves metadata for accessToken and registers the token with its user.
//
// The metadata is JSON-encoded and written under the token's key with the
// given ttl; a non-positive ttl selects the store's default TTL. The access
// token is then added to the owning user's token set, and a background
// goroutine prunes that set via cleanupUserTokens. The goroutine runs on
// context.Background(), so it is not cancelled when ctx is.
//
// It returns an error if metadata is nil, if encoding fails, or if either
// Redis write fails. When adding to the user's set fails, the token key
// written just before is left in place.
func (s *TokenStoreRedis) Store(ctx context.Context, accessToken string, metadata *TokenMetadata, ttl time.Duration) error {
	// A nil pointer marshals to "null" without error, so it must be rejected
	// here; otherwise "null" would be written and metadata.UserID would panic.
	if metadata == nil {
		return fmt.Errorf("存储Token失败: 元数据为空")
	}
	if ttl <= 0 {
		ttl = s.defaultTTL
	}

	data, err := json.Marshal(metadata)
	if err != nil {
		return fmt.Errorf("序列化Token元数据失败: %w", err)
	}

	tokenKey := s.getTokenKey(accessToken)
	if err := s.redis.Set(ctx, tokenKey, data, ttl); err != nil {
		return fmt.Errorf("存储Token失败: %w", err)
	}

	userTokensKey := s.getUserTokensKey(metadata.UserID)
	if err := s.redis.SAdd(ctx, userTokensKey, accessToken); err != nil {
		return fmt.Errorf("添加到用户Token集合失败: %w", err)
	}

	// Prune stale entries and enforce the per-user limit off the request path.
	go s.cleanupUserTokens(context.Background(), metadata.UserID)

	// Log only a prefix of the token. Slicing is guarded because a token
	// shorter than 20 bytes would make accessToken[:20] panic.
	preview := accessToken
	if len(preview) > 20 {
		preview = preview[:20]
	}
	s.logger.Debug("Token已存储",
		zap.String("token", preview+"..."),
		zap.Int64("userId", metadata.UserID),
		zap.Duration("ttl", ttl),
	)
	return nil
}
// Retrieve loads and decodes the metadata stored for accessToken.
//
// It returns a wrapped error when the Redis read fails or when the stored
// value is not valid JSON for TokenMetadata.
func (s *TokenStoreRedis) Retrieve(ctx context.Context, accessToken string) (*TokenMetadata, error) {
	raw, err := s.redis.Get(ctx, s.getTokenKey(accessToken))
	if err != nil {
		return nil, fmt.Errorf("获取Token失败: %w", err)
	}

	metadata := new(TokenMetadata)
	if err = json.Unmarshal([]byte(raw), metadata); err != nil {
		return nil, fmt.Errorf("解析Token元数据失败: %w", err)
	}
	return metadata, nil
}
// Delete removes accessToken and its entry in the owning user's token set.
//
// The metadata is read first because the user ID is needed to find the set.
// If that read fails, Delete returns nil without touching Redis: the original
// author treats this as "token already expired".
// NOTE(review): every Retrieve error is swallowed here, including transport
// failures, so a caller cannot tell "already gone" from "Redis unreachable".
//
// It returns a wrapped error when deleting the key or updating the set fails.
func (s *TokenStoreRedis) Delete(ctx context.Context, accessToken string) error {
	metadata, err := s.Retrieve(ctx, accessToken)
	if err != nil {
		// Deliberate best effort: an unreadable token is treated as already gone.
		return nil
	}

	if err := s.redis.Del(ctx, s.getTokenKey(accessToken)); err != nil {
		return fmt.Errorf("删除Token失败: %w", err)
	}

	userTokensKey := s.getUserTokensKey(metadata.UserID)
	if err := s.redis.SRem(ctx, userTokensKey, accessToken); err != nil {
		return fmt.Errorf("从用户Token集合移除失败: %w", err)
	}

	// Log only a prefix of the token. Slicing is guarded because a token
	// shorter than 20 bytes would make accessToken[:20] panic.
	preview := accessToken
	if len(preview) > 20 {
		preview = preview[:20]
	}
	s.logger.Debug("Token已删除",
		zap.String("token", preview+"..."),
		zap.Int64("userId", metadata.UserID),
	)
	return nil
}
// DeleteByUserID removes every token recorded for userID, then the user's
// token set itself.
//
// The token keys are deleted in one call, which is skipped when the set is
// empty. The set key is deleted afterwards in a separate call. It returns a
// wrapped error from the first Redis operation that fails.
func (s *TokenStoreRedis) DeleteByUserID(ctx context.Context, userID int64) error {
	setKey := s.getUserTokensKey(userID)

	members, err := s.redis.SMembers(ctx, setKey)
	if err != nil {
		return fmt.Errorf("获取用户Token列表失败: %w", err)
	}

	if len(members) != 0 {
		keys := make([]string, 0, len(members))
		for _, member := range members {
			keys = append(keys, s.getTokenKey(member))
		}
		if delErr := s.redis.Del(ctx, keys...); delErr != nil {
			return fmt.Errorf("批量删除Token失败: %w", delErr)
		}
	}

	if delErr := s.redis.Del(ctx, setKey); delErr != nil {
		return fmt.Errorf("删除用户Token集合失败: %w", delErr)
	}

	s.logger.Info("用户所有Token已删除",
		zap.Int64("userId", userID),
		zap.Int("count", len(members)),
	)
	return nil
}
// Exists reports whether a key for accessToken is currently present in Redis.
// A failed lookup yields false together with a wrapped error.
func (s *TokenStoreRedis) Exists(ctx context.Context, accessToken string) (bool, error) {
	n, err := s.redis.Exists(ctx, s.getTokenKey(accessToken))
	if err != nil {
		return false, fmt.Errorf("检查Token存在失败: %w", err)
	}

	found := n > 0
	return found, nil
}
// GetTTL returns the remaining time to live of accessToken's key, exactly as
// reported by the Redis client; the client's error is passed through unwrapped.
func (s *TokenStoreRedis) GetTTL(ctx context.Context, accessToken string) (time.Duration, error) {
	return s.redis.TTL(ctx, s.getTokenKey(accessToken))
}
// RefreshTTL resets the expiry of accessToken's key to ttl. A non-positive ttl
// selects the store's default TTL. It returns a wrapped error when the Redis
// call fails.
func (s *TokenStoreRedis) RefreshTTL(ctx context.Context, accessToken string, ttl time.Duration) error {
	effective := ttl
	if effective <= 0 {
		effective = s.defaultTTL
	}

	err := s.redis.Expire(ctx, s.getTokenKey(accessToken), effective)
	if err != nil {
		return fmt.Errorf("刷新Token TTL失败: %w", err)
	}
	return nil
}
// GetCountByUser returns how many entries the token set of userID holds.
//
// The count is taken by fetching every member of the set and measuring the
// result, so it includes entries whose token key has since expired and has
// not yet been pruned.
func (s *TokenStoreRedis) GetCountByUser(ctx context.Context, userID int64) (int64, error) {
	members, err := s.redis.SMembers(ctx, s.getUserTokensKey(userID))
	if err != nil {
		return 0, fmt.Errorf("获取用户Token数量失败: %w", err)
	}

	total := int64(len(members))
	return total, nil
}
// cleanupUserTokens prunes the token set of userID and enforces the per-user
// token limit.
//
// It works in two phases:
//  1. Every member of the user's set whose token key no longer exists is
//     removed from the set, one member at a time. The set is never deleted
//     and rebuilt, so a token added concurrently by Store is not lost. Redis
//     drops a set by itself once its last member is removed.
//  2. If more than maxTokensPerUser live tokens remain, the tokens with the
//     smallest CreatedAt are deleted until the limit is met. A non-positive
//     maxTokensPerUser disables this phase.
//
// The function is best effort: it logs failures and returns nothing. When
// Redis cannot confirm whether a token exists, or a live token's metadata
// cannot be read, the run stops instead of acting on incomplete information;
// the next Store triggers another attempt.
func (s *TokenStoreRedis) cleanupUserTokens(ctx context.Context, userID int64) {
	userTokensKey := s.getUserTokensKey(userID)

	tokens, err := s.redis.SMembers(ctx, userTokensKey)
	if err != nil {
		s.logger.Error("获取用户Token列表失败", zap.Error(err), zap.Int64("userId", userID))
		return
	}

	// Phase 1: drop set members whose token key has expired.
	live := make([]string, 0, len(tokens))
	for _, token := range tokens {
		exists, err := s.redis.Exists(ctx, s.getTokenKey(token))
		if err != nil {
			// Guarded slice: a token shorter than 20 bytes must not panic.
			preview := token
			if len(preview) > 20 {
				preview = preview[:20]
			}
			s.logger.Error("检查Token存在失败", zap.Error(err), zap.String("token", preview+"..."))
			// An unknown state must not be mistaken for "expired".
			return
		}
		if exists > 0 {
			live = append(live, token)
			continue
		}
		if err := s.redis.SRem(ctx, userTokensKey, token); err != nil {
			s.logger.Error("从用户Token集合移除失败", zap.Error(err), zap.Int64("userId", userID))
		}
	}

	// Phase 2: enforce the per-user limit, evicting the oldest tokens first.
	excess := len(live) - s.maxTokensPerUser
	if s.maxTokensPerUser <= 0 || excess <= 0 {
		return
	}

	type candidate struct {
		token     string
		createdAt int64
	}
	candidates := make([]candidate, 0, len(live))
	for _, token := range live {
		metadata, err := s.Retrieve(ctx, token)
		if err != nil {
			s.logger.Error("读取Token元数据失败", zap.Error(err), zap.Int64("userId", userID))
			return
		}
		candidates = append(candidates, candidate{token: token, createdAt: metadata.CreatedAt})
	}

	// Only the `excess` oldest entries are needed, so pick the minimum
	// repeatedly instead of ordering the whole slice. excess never exceeds
	// len(candidates) because the limit is positive here.
	deleted := 0
	for ; excess > 0; excess-- {
		oldest := 0
		for i := range candidates {
			if candidates[i].createdAt < candidates[oldest].createdAt {
				oldest = i
			}
		}
		if err := s.Delete(ctx, candidates[oldest].token); err != nil {
			s.logger.Error("删除Token失败", zap.Error(err), zap.Int64("userId", userID))
		} else {
			deleted++
		}
		candidates = append(candidates[:oldest], candidates[oldest+1:]...)
	}

	s.logger.Info("清理用户多余Token",
		zap.Int64("userId", userID),
		zap.Int("deleted", deleted),
	)
}
// getTokenKey builds the Redis key of an access token: the store's key prefix
// followed by the token itself.
func (s *TokenStoreRedis) getTokenKey(accessToken string) string {
	key := s.keyPrefix + accessToken
	return key
}
// getUserTokensKey builds the Redis key of the set that lists a user's tokens.
// The key has the fixed form "user:<id>:tokens" and does not include the
// store's key prefix.
func (s *TokenStoreRedis) getUserTokensKey(userID int64) string {
	const layout = "user:%d:tokens"
	return fmt.Sprintf(layout, userID)
}