321 lines
8.4 KiB
Go
321 lines
8.4 KiB
Go
|
|
package auth
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"carrotskin/pkg/redis"
|
|||
|
|
|
|||
|
|
"go.uber.org/zap"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// TokenMetadata Token元数据(存储在Redis中)
|
|||
|
|
type TokenMetadata struct {
|
|||
|
|
UserID int64 `json:"user_id"`
|
|||
|
|
ProfileID string `json:"profile_id"`
|
|||
|
|
ClientUUID string `json:"client_uuid"`
|
|||
|
|
ClientToken string `json:"client_token"`
|
|||
|
|
Version int `json:"version"`
|
|||
|
|
CreatedAt int64 `json:"created_at"`
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TokenStoreRedis Redis Token存储实现
|
|||
|
|
type TokenStoreRedis struct {
|
|||
|
|
redis *redis.Client
|
|||
|
|
logger *zap.Logger
|
|||
|
|
keyPrefix string
|
|||
|
|
defaultTTL time.Duration
|
|||
|
|
staleTTL time.Duration
|
|||
|
|
maxTokensPerUser int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewTokenStoreRedis 创建Redis Token存储
|
|||
|
|
func NewTokenStoreRedis(
|
|||
|
|
redisClient *redis.Client,
|
|||
|
|
logger *zap.Logger,
|
|||
|
|
opts ...TokenStoreOption,
|
|||
|
|
) *TokenStoreRedis {
|
|||
|
|
options := &tokenStoreOptions{
|
|||
|
|
keyPrefix: "token:",
|
|||
|
|
defaultTTL: 24 * time.Hour,
|
|||
|
|
staleTTL: 30 * 24 * time.Hour,
|
|||
|
|
maxTokensPerUser: 10,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, opt := range opts {
|
|||
|
|
opt(options)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &TokenStoreRedis{
|
|||
|
|
redis: redisClient,
|
|||
|
|
logger: logger,
|
|||
|
|
keyPrefix: options.keyPrefix,
|
|||
|
|
defaultTTL: options.defaultTTL,
|
|||
|
|
staleTTL: options.staleTTL,
|
|||
|
|
maxTokensPerUser: options.maxTokensPerUser,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// tokenStoreOptions Token存储配置选项
|
|||
|
|
type tokenStoreOptions struct {
|
|||
|
|
keyPrefix string
|
|||
|
|
defaultTTL time.Duration
|
|||
|
|
staleTTL time.Duration
|
|||
|
|
maxTokensPerUser int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TokenStoreOption Token存储配置选项函数
|
|||
|
|
type TokenStoreOption func(*tokenStoreOptions)
|
|||
|
|
|
|||
|
|
// WithKeyPrefix 设置Key前缀
|
|||
|
|
func WithKeyPrefix(prefix string) TokenStoreOption {
|
|||
|
|
return func(o *tokenStoreOptions) {
|
|||
|
|
o.keyPrefix = prefix
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WithDefaultTTL 设置默认TTL
|
|||
|
|
func WithDefaultTTL(ttl time.Duration) TokenStoreOption {
|
|||
|
|
return func(o *tokenStoreOptions) {
|
|||
|
|
o.defaultTTL = ttl
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WithStaleTTL 设置过期但可用时间
|
|||
|
|
func WithStaleTTL(ttl time.Duration) TokenStoreOption {
|
|||
|
|
return func(o *tokenStoreOptions) {
|
|||
|
|
o.staleTTL = ttl
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WithMaxTokensPerUser 设置每个用户的最大Token数
|
|||
|
|
func WithMaxTokensPerUser(max int) TokenStoreOption {
|
|||
|
|
return func(o *tokenStoreOptions) {
|
|||
|
|
o.maxTokensPerUser = max
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Store 存储Token
|
|||
|
|
func (s *TokenStoreRedis) Store(ctx context.Context, accessToken string, metadata *TokenMetadata, ttl time.Duration) error {
|
|||
|
|
if ttl <= 0 {
|
|||
|
|
ttl = s.defaultTTL
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 序列化元数据
|
|||
|
|
data, err := json.Marshal(metadata)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("序列化Token元数据失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 存储Token
|
|||
|
|
tokenKey := s.getTokenKey(accessToken)
|
|||
|
|
if err := s.redis.Set(ctx, tokenKey, data, ttl); err != nil {
|
|||
|
|
return fmt.Errorf("存储Token失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 添加到用户Token集合
|
|||
|
|
userTokensKey := s.getUserTokensKey(metadata.UserID)
|
|||
|
|
if err := s.redis.SAdd(ctx, userTokensKey, accessToken); err != nil {
|
|||
|
|
return fmt.Errorf("添加到用户Token集合失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 清理过期Token(后台执行)
|
|||
|
|
go s.cleanupUserTokens(context.Background(), metadata.UserID)
|
|||
|
|
|
|||
|
|
s.logger.Debug("Token已存储",
|
|||
|
|
zap.String("token", accessToken[:20]+"..."),
|
|||
|
|
zap.Int64("userId", metadata.UserID),
|
|||
|
|
zap.Duration("ttl", ttl),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Retrieve 获取Token元数据
|
|||
|
|
func (s *TokenStoreRedis) Retrieve(ctx context.Context, accessToken string) (*TokenMetadata, error) {
|
|||
|
|
tokenKey := s.getTokenKey(accessToken)
|
|||
|
|
data, err := s.redis.Get(ctx, tokenKey)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("获取Token失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var metadata TokenMetadata
|
|||
|
|
if err := json.Unmarshal([]byte(data), &metadata); err != nil {
|
|||
|
|
return nil, fmt.Errorf("解析Token元数据失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return &metadata, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Delete 删除Token
|
|||
|
|
func (s *TokenStoreRedis) Delete(ctx context.Context, accessToken string) error {
|
|||
|
|
tokenKey := s.getTokenKey(accessToken)
|
|||
|
|
|
|||
|
|
// 先获取Token元数据以获取UserID
|
|||
|
|
metadata, err := s.Retrieve(ctx, accessToken)
|
|||
|
|
if err != nil {
|
|||
|
|
// Token可能已过期,忽略错误
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 删除Token
|
|||
|
|
if err := s.redis.Del(ctx, tokenKey); err != nil {
|
|||
|
|
return fmt.Errorf("删除Token失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 从用户Token集合中移除
|
|||
|
|
userTokensKey := s.getUserTokensKey(metadata.UserID)
|
|||
|
|
if err := s.redis.SRem(ctx, userTokensKey, accessToken); err != nil {
|
|||
|
|
return fmt.Errorf("从用户Token集合移除失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
s.logger.Debug("Token已删除",
|
|||
|
|
zap.String("token", accessToken[:20]+"..."),
|
|||
|
|
zap.Int64("userId", metadata.UserID),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteByUserID 删除用户的所有Token
|
|||
|
|
func (s *TokenStoreRedis) DeleteByUserID(ctx context.Context, userID int64) error {
|
|||
|
|
userTokensKey := s.getUserTokensKey(userID)
|
|||
|
|
|
|||
|
|
// 获取用户所有Token
|
|||
|
|
tokens, err := s.redis.SMembers(ctx, userTokensKey)
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("获取用户Token列表失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 删除所有Token
|
|||
|
|
if len(tokens) > 0 {
|
|||
|
|
tokenKeys := make([]string, len(tokens))
|
|||
|
|
for i, token := range tokens {
|
|||
|
|
tokenKeys[i] = s.getTokenKey(token)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := s.redis.Del(ctx, tokenKeys...); err != nil {
|
|||
|
|
return fmt.Errorf("批量删除Token失败: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 删除用户Token集合
|
|||
|
|
if err := s.redis.Del(ctx, userTokensKey); err != nil {
|
|||
|
|
return fmt.Errorf("删除用户Token集合失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
s.logger.Info("用户所有Token已删除",
|
|||
|
|
zap.Int64("userId", userID),
|
|||
|
|
zap.Int("count", len(tokens)),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Exists 检查Token是否存在
|
|||
|
|
func (s *TokenStoreRedis) Exists(ctx context.Context, accessToken string) (bool, error) {
|
|||
|
|
tokenKey := s.getTokenKey(accessToken)
|
|||
|
|
count, err := s.redis.Exists(ctx, tokenKey)
|
|||
|
|
if err != nil {
|
|||
|
|
return false, fmt.Errorf("检查Token存在失败: %w", err)
|
|||
|
|
}
|
|||
|
|
return count > 0, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetTTL 获取Token的剩余TTL
|
|||
|
|
func (s *TokenStoreRedis) GetTTL(ctx context.Context, accessToken string) (time.Duration, error) {
|
|||
|
|
tokenKey := s.getTokenKey(accessToken)
|
|||
|
|
return s.redis.TTL(ctx, tokenKey)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RefreshTTL 刷新Token的TTL
|
|||
|
|
func (s *TokenStoreRedis) RefreshTTL(ctx context.Context, accessToken string, ttl time.Duration) error {
|
|||
|
|
if ttl <= 0 {
|
|||
|
|
ttl = s.defaultTTL
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
tokenKey := s.getTokenKey(accessToken)
|
|||
|
|
if err := s.redis.Expire(ctx, tokenKey, ttl); err != nil {
|
|||
|
|
return fmt.Errorf("刷新Token TTL失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetCountByUser 获取用户的Token数量
|
|||
|
|
func (s *TokenStoreRedis) GetCountByUser(ctx context.Context, userID int64) (int64, error) {
|
|||
|
|
userTokensKey := s.getUserTokensKey(userID)
|
|||
|
|
count, err := s.redis.SMembers(ctx, userTokensKey)
|
|||
|
|
if err != nil {
|
|||
|
|
return 0, fmt.Errorf("获取用户Token数量失败: %w", err)
|
|||
|
|
}
|
|||
|
|
return int64(len(count)), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// cleanupUserTokens 清理用户的过期Token(保留最新的N个)
|
|||
|
|
func (s *TokenStoreRedis) cleanupUserTokens(ctx context.Context, userID int64) {
|
|||
|
|
userTokensKey := s.getUserTokensKey(userID)
|
|||
|
|
|
|||
|
|
// 获取用户所有Token
|
|||
|
|
tokens, err := s.redis.SMembers(ctx, userTokensKey)
|
|||
|
|
if err != nil {
|
|||
|
|
s.logger.Error("获取用户Token列表失败", zap.Error(err), zap.Int64("userId", userID))
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 清理过期的Token(验证它们是否仍存在)
|
|||
|
|
validTokens := make([]string, 0, len(tokens))
|
|||
|
|
for _, token := range tokens {
|
|||
|
|
tokenKey := s.getTokenKey(token)
|
|||
|
|
exists, err := s.redis.Exists(ctx, tokenKey)
|
|||
|
|
if err != nil {
|
|||
|
|
s.logger.Error("检查Token存在失败", zap.Error(err), zap.String("token", token[:20]+"..."))
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if exists > 0 {
|
|||
|
|
validTokens = append(validTokens, token)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 如果没有变化,直接返回
|
|||
|
|
if len(validTokens) == len(tokens) {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新用户Token集合
|
|||
|
|
if len(validTokens) == 0 {
|
|||
|
|
s.redis.Del(ctx, userTokensKey)
|
|||
|
|
} else {
|
|||
|
|
// 重新设置集合
|
|||
|
|
s.redis.Del(ctx, userTokensKey)
|
|||
|
|
for _, token := range validTokens {
|
|||
|
|
s.redis.SAdd(ctx, userTokensKey, token)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 如果超过限制,删除最旧的Token(这里简化处理,可以根据createdAt排序)
|
|||
|
|
if len(validTokens) > s.maxTokensPerUser {
|
|||
|
|
tokensToDelete := validTokens[s.maxTokensPerUser:]
|
|||
|
|
for _, token := range tokensToDelete {
|
|||
|
|
s.Delete(ctx, token)
|
|||
|
|
}
|
|||
|
|
s.logger.Info("清理用户多余Token",
|
|||
|
|
zap.Int64("userId", userID),
|
|||
|
|
zap.Int("deleted", len(tokensToDelete)),
|
|||
|
|
)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getTokenKey 生成Token的Redis Key
|
|||
|
|
func (s *TokenStoreRedis) getTokenKey(accessToken string) string {
|
|||
|
|
return s.keyPrefix + accessToken
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getUserTokensKey 生成用户Token集合的Redis Key
|
|||
|
|
func (s *TokenStoreRedis) getUserTokensKey(userID int64) string {
|
|||
|
|
return fmt.Sprintf("user:%d:tokens", userID)
|
|||
|
|
}
|