Files
backend/internal/service/token_service_redis.go
lan 6ddcf92ce3 refactor: Remove Token management and integrate Redis for authentication
- Deleted the Token model and its repository, transitioning to a Redis-based token management system.
- Updated the service layer to utilize Redis for token storage, enhancing performance and scalability.
- Refactored the container to remove TokenRepository and integrate the new token service.
- Cleaned up the Dockerfile and other files by removing unnecessary whitespace and comments.
- Enhanced error handling and logging for Redis initialization and usage.
2025-12-24 16:03:46 +08:00

471 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"context"
"errors"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
)
// tokenServiceRedis is the Redis-backed implementation of TokenService.
//
// Access tokens are JWTs produced by yggdrasilJWT and carry the owning client's
// Version. clientRepo persists that Version, so bumping it (Refresh, Invalidate,
// InvalidateUserTokens) makes every token previously issued to the client fail
// the Version checks. tokenStore keeps per-token metadata in Redis.
type tokenServiceRedis struct {
// tokenStore holds per-access-token metadata (user, profile, client, version).
tokenStore *auth.TokenStoreRedis
// clientRepo persists client records, including their Version counter.
clientRepo repository.ClientRepository
// profileRepo is used for profile ownership checks and profile listings.
profileRepo repository.ProfileRepository
// yggdrasilJWT generates and parses the access-token JWTs.
yggdrasilJWT *auth.YggdrasilJWTService
logger *zap.Logger
tokenExpireSec int64 // token lifetime in seconds; values <= 0 mean "never expires" (a far-future date is used)
tokenStaleSec int64 // seconds until the token turns stale (expired but still usable); values <= 0 mean "never"
}
// NewTokenServiceRedis builds a TokenService that keeps token metadata in Redis.
//
// Issued tokens expire after 24 hours and turn stale after 30 days by default.
func NewTokenServiceRedis(
	tokenStore *auth.TokenStoreRedis,
	clientRepo repository.ClientRepository,
	profileRepo repository.ProfileRepository,
	yggdrasilJWT *auth.YggdrasilJWTService,
	logger *zap.Logger,
) TokenService {
	const (
		defaultExpireSec int64 = 24 * 60 * 60      // 24 hours
		defaultStaleSec  int64 = 30 * 24 * 60 * 60 // 30 days
	)

	svc := new(tokenServiceRedis)
	svc.tokenStore = tokenStore
	svc.clientRepo = clientRepo
	svc.profileRepo = profileRepo
	svc.yggdrasilJWT = yggdrasilJWT
	svc.logger = logger
	svc.tokenExpireSec = defaultExpireSec
	svc.tokenStaleSec = defaultStaleSec
	return svc
}
// Create issues a new access token for userID (JWT + Redis metadata).
//
// Parameters:
//   - UUID: optional profile UUID to bind the token to. When non-empty the
//     profile must exist and belong to userID.
//   - clientToken: optional client identifier; a random UUID is generated when
//     empty. An existing client with this token is reused (its Version is not
//     bumped) provided it belongs to userID; otherwise a new client is created.
//
// When the client is not bound to a profile yet and the user owns exactly one
// profile, that profile is bound automatically.
//
// Returns the profile the token is bound to (nil when it is not bound to one of
// the user's profiles), all of the user's profiles, the access token and the
// effective client token.
func (s *tokenServiceRedis) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
	ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
	defer cancel()

	// The requested profile must exist AND belong to this user. Checking
	// existence alone would let a user bind a token to someone else's profile.
	if UUID != "" {
		profile, err := s.profileRepo.FindByUUID(ctx, UUID)
		if err != nil {
			return nil, nil, "", "", fmt.Errorf("获取用户信息失败: %w", err)
		}
		if profile.UserID != userID {
			return nil, nil, "", "", errors.New("角色与用户不匹配")
		}
	}

	if clientToken == "" {
		clientToken = uuid.New().String()
	}

	client, err := s.clientRepo.FindByClientToken(ctx, clientToken)
	if err != nil || client == nil {
		// Unknown client token: register a new client.
		// NOTE(review): every lookup error (not only "not found") ends up here;
		// narrow this once the repository exposes a not-found sentinel.
		now := time.Now()
		client = &model.Client{
			UUID:        uuid.New().String(),
			ClientToken: clientToken,
			UserID:      userID,
			ProfileID:   UUID,
			Version:     0,
			CreatedAt:   now,
			UpdatedAt:   now,
		}
		if err := s.clientRepo.Create(ctx, client); err != nil {
			return nil, nil, "", "", fmt.Errorf("创建Client失败: %w", err)
		}
	} else {
		if client.UserID != userID {
			return nil, nil, "", "", errors.New("clientToken已属于其他用户")
		}
		// Reusing a client keeps its Version (only Refresh/Invalidate bump it),
		// so tokens already issued to it stay valid; only rebind the profile.
		client.UpdatedAt = time.Now()
		if UUID != "" {
			client.ProfileID = UUID
			if err := s.clientRepo.Update(ctx, client); err != nil {
				return nil, nil, "", "", fmt.Errorf("更新Client失败: %w", err)
			}
		}
	}

	profiles, err := s.profileRepo.FindByUserID(ctx, userID)
	if err != nil {
		return nil, nil, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
	}

	var selected *model.Profile
	profileID := client.ProfileID
	switch {
	case profileID != "":
		// Report the profile the token is actually bound to, so the returned
		// selected profile agrees with the profile embedded in the token.
		for _, p := range profiles {
			if p.UUID == profileID {
				selected = p
				break
			}
		}
	case len(profiles) == 1:
		// Exactly one profile and nothing bound yet: bind it automatically.
		selected = profiles[0]
		profileID = selected.UUID
		client.ProfileID = profileID
		if err := s.clientRepo.Update(ctx, client); err != nil {
			// Best effort: the token generated below still embeds profileID.
			s.logger.Warn("自动绑定角色失败", zap.Error(err), zap.String("clientUUID", client.UUID))
		}
	}

	now := time.Now()
	// Non-positive lifetimes mean "never expires": use a fixed far-future date.
	farFuture := time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
	expiresAt, staleAt := farFuture, farFuture
	if s.tokenExpireSec > 0 {
		expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
	}
	if s.tokenStaleSec > 0 {
		staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
	}

	accessToken, err := s.yggdrasilJWT.GenerateAccessToken(
		userID,
		client.UUID,
		client.Version,
		profileID,
		expiresAt,
		staleAt,
	)
	if err != nil {
		return nil, nil, "", "", fmt.Errorf("生成AccessToken失败: %w", err)
	}

	metadata := &auth.TokenMetadata{
		UserID:      userID,
		ProfileID:   profileID,
		ClientUUID:  client.UUID,
		ClientToken: client.ClientToken,
		Version:     client.Version,
		CreatedAt:   now.Unix(),
	}
	if err := s.tokenStore.Store(ctx, accessToken, metadata, expiresAt.Sub(now)); err != nil {
		// Not treated as fatal because the JWT itself was generated.
		// NOTE(review): Validate rejects tokens missing from the store, so a
		// token whose Store failed will not pass Validate.
		s.logger.Warn("存储Token到Redis失败", zap.Error(err))
	}

	return selected, profiles, accessToken, clientToken, nil
}
// Validate reports whether accessToken is currently usable.
//
// The token must be non-empty and parse under auth.StalePolicyDeny. Its
// metadata must still be present in the Redis store, and its owning client
// must exist with a Version equal to the one in the token. When clientToken is
// non-empty it must also equal the client token recorded in the stored
// metadata.
func (s *tokenServiceRedis) Validate(ctx context.Context, accessToken, clientToken string) bool {
	if accessToken == "" {
		return false
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, DefaultTimeout)
	defer cancel()

	claims, parseErr := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyDeny)
	if parseErr != nil {
		return false
	}

	// A missing entry means the token expired from Redis or was deleted.
	meta, storeErr := s.tokenStore.Retrieve(timeoutCtx, accessToken)
	if storeErr != nil {
		return false
	}

	owner, lookupErr := s.clientRepo.FindByUUID(timeoutCtx, claims.Subject)
	if lookupErr != nil {
		return false
	}

	versionOK := claims.Version == owner.Version
	clientTokenOK := clientToken == "" || meta.ClientToken == clientToken
	return versionOK && clientTokenOK
}
// Refresh exchanges accessToken for a new access token (Version mechanism,
// Redis storage).
//
// accessToken may be stale (auth.StalePolicyAllow) but must carry the client's
// current Version. When clientToken is non-empty, it must match the client's
// token.
//
// When selectedProfileID is non-empty, it must belong to the client's user and
// binds the new token to that profile. A client already bound to a different
// profile is rejected. When selectedProfileID is empty, the client's existing
// binding is kept.
//
// On success the client's Version is bumped, which invalidates every token
// issued to that client, and the old token is removed from Redis.
//
// Returns the new access token and the client token.
func (s *tokenServiceRedis) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
	ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
	defer cancel()
	if accessToken == "" {
		return "", "", errors.New("accessToken不能为空")
	}

	claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
	if err != nil {
		return "", "", errors.New("accessToken无效")
	}
	client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
	if err != nil {
		return "", "", errors.New("无法找到对应的Client")
	}
	if clientToken != "" && client.ClientToken != clientToken {
		return "", "", errors.New("clientToken无效")
	}
	// A mismatched Version means the token was already refreshed or revoked.
	if claims.Version != client.Version {
		return "", "", errors.New("token版本不匹配请重新登录")
	}

	if selectedProfileID != "" {
		valid, validErr := s.validateProfileByUserID(ctx, client.UserID, selectedProfileID)
		if validErr != nil {
			s.logger.Error("验证Profile失败",
				zap.Error(validErr),
				zap.Int64("userId", client.UserID),
				zap.String("profileId", selectedProfileID),
			)
			return "", "", fmt.Errorf("验证角色失败: %w", validErr)
		}
		if !valid {
			return "", "", errors.New("角色与用户不匹配")
		}
		// A token already bound to a profile cannot switch to another one.
		if client.ProfileID != "" && client.ProfileID != selectedProfileID {
			return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
		}
		client.ProfileID = selectedProfileID
	} else {
		selectedProfileID = client.ProfileID
	}

	now := time.Now()
	// Non-positive lifetimes mean "never expires": use a fixed far-future date.
	farFuture := time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
	expiresAt, staleAt := farFuture, farFuture
	if s.tokenExpireSec > 0 {
		expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
	}
	if s.tokenStaleSec > 0 {
		staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
	}

	// Sign the new token BEFORE persisting the bumped Version. If signing
	// fails, the client is left untouched and the old token keeps working,
	// instead of being revoked with no replacement issued.
	newVersion := client.Version + 1
	newAccessToken, err := s.yggdrasilJWT.GenerateAccessToken(
		client.UserID,
		client.UUID,
		newVersion,
		selectedProfileID,
		expiresAt,
		staleAt,
	)
	if err != nil {
		return "", "", fmt.Errorf("生成新AccessToken失败: %w", err)
	}

	// Persisting the new Version is what invalidates all older tokens.
	client.Version = newVersion
	client.UpdatedAt = now
	if err := s.clientRepo.Update(ctx, client); err != nil {
		return "", "", fmt.Errorf("更新Client版本失败: %w", err)
	}

	if err := s.tokenStore.Delete(ctx, accessToken); err != nil {
		s.logger.Warn("删除旧Token失败", zap.Error(err))
	}

	metadata := &auth.TokenMetadata{
		UserID:      client.UserID,
		ProfileID:   selectedProfileID,
		ClientUUID:  client.UUID,
		ClientToken: client.ClientToken,
		Version:     client.Version,
		CreatedAt:   now.Unix(),
	}
	if err := s.tokenStore.Store(ctx, newAccessToken, metadata, expiresAt.Sub(now)); err != nil {
		s.logger.Warn("存储新Token到Redis失败", zap.Error(err))
	}

	s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version))
	return newAccessToken, client.ClientToken, nil
}
// Invalidate revokes accessToken and removes it from Redis.
//
// The owning client's Version is bumped, which makes every token issued to
// that client fail the Version checks done elsewhere in this service. The
// token is then deleted from Redis. Stale tokens are accepted for parsing.
// Failures are logged rather than returned, and an empty token is ignored.
func (s *tokenServiceRedis) Invalidate(ctx context.Context, accessToken string) {
	if accessToken == "" {
		return
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, DefaultTimeout)
	defer cancel()

	claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
	if err != nil {
		s.logger.Warn("解析Token失败", zap.Error(err))
		return
	}

	owner, err := s.clientRepo.FindByUUID(timeoutCtx, claims.Subject)
	if err != nil {
		s.logger.Warn("无法找到对应的Client", zap.Error(err))
		return
	}

	owner.Version++
	owner.UpdatedAt = time.Now()
	if err = s.clientRepo.Update(timeoutCtx, owner); err != nil {
		s.logger.Error("失效Token失败", zap.Error(err))
		return
	}

	if err = s.tokenStore.Delete(timeoutCtx, accessToken); err != nil {
		s.logger.Warn("从Redis删除Token失败", zap.Error(err))
		return
	}

	s.logger.Info("成功失效Token", zap.String("clientUUID", owner.UUID), zap.Int("version", owner.Version))
}
// InvalidateUserTokens revokes every token belonging to userID.
//
// Each of the user's clients has its Version bumped. A failed update is logged
// and does not stop the remaining clients. All of the user's entries are then
// deleted from Redis; if that delete fails, the success log is skipped.
// Nothing is returned, and a zero userID is ignored.
func (s *tokenServiceRedis) InvalidateUserTokens(ctx context.Context, userID int64) {
	if userID == 0 {
		return
	}

	timeoutCtx, cancel := context.WithTimeout(ctx, DefaultTimeout)
	defer cancel()

	clients, err := s.clientRepo.FindByUserID(timeoutCtx, userID)
	if err != nil {
		s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID))
		return
	}

	for i := range clients {
		c := clients[i]
		c.Version++
		c.UpdatedAt = time.Now()
		if updateErr := s.clientRepo.Update(timeoutCtx, c); updateErr != nil {
			s.logger.Error("失效用户Token失败", zap.Error(updateErr), zap.Int64("userId", userID))
		}
	}

	if err = s.tokenStore.DeleteByUserID(timeoutCtx, userID); err != nil {
		s.logger.Error("从Redis删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
		return
	}

	s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients)))
}
// GetUUIDByAccessToken resolves the profile UUID bound to accessToken.
//
// The token may be stale (auth.StalePolicyAllow) but must carry the owning
// client's current Version. Tokens revoked by Refresh, Invalidate or
// InvalidateUserTokens are therefore rejected, the same rule
// GetUserIDByAccessToken applies. Previously the JWT-embedded profile ID was
// returned without this check, so revoked tokens still resolved a profile.
//
// The profile ID embedded in the JWT takes precedence; otherwise the client's
// bound profile is used. An error is returned when neither is set.
func (s *tokenServiceRedis) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
	claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
	if err != nil {
		return "", errors.New("accessToken无效")
	}

	client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
	if err != nil {
		return "", fmt.Errorf("无法找到对应的Client: %w", err)
	}
	if claims.Version != client.Version {
		return "", errors.New("token版本不匹配")
	}

	if claims.ProfileID != "" {
		return claims.ProfileID, nil
	}
	if client.ProfileID != "" {
		return client.ProfileID, nil
	}
	return "", errors.New("无法从Token中获取UUID")
}
// GetUserIDByAccessToken returns the ID of the user owning accessToken.
//
// The token may be stale (auth.StalePolicyAllow). The user ID is taken from
// the owning client record, and the token's Version must equal the client's
// current Version.
func (s *tokenServiceRedis) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
	claims, parseErr := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
	if parseErr != nil {
		return 0, errors.New("accessToken无效")
	}

	owner, lookupErr := s.clientRepo.FindByUUID(ctx, claims.Subject)
	switch {
	case lookupErr != nil:
		return 0, fmt.Errorf("无法找到对应的Client: %w", lookupErr)
	case claims.Version != owner.Version:
		return 0, errors.New("token版本不匹配")
	default:
		return owner.UserID, nil
	}
}
// validateProfileByUserID reports whether the profile identified by UUID
// belongs to userID.
//
// A zero userID or empty UUID is an error. A missing profile (pgx.ErrNoRows)
// and any other lookup failure are also returned as errors. Otherwise the
// result is whether the profile's owner equals userID.
func (s *tokenServiceRedis) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
	if userID == 0 || UUID == "" {
		return false, errors.New("用户ID或配置文件ID不能为空")
	}

	profile, lookupErr := s.profileRepo.FindByUUID(ctx, UUID)
	switch {
	case errors.Is(lookupErr, pgx.ErrNoRows):
		return false, errors.New("配置文件不存在")
	case lookupErr != nil:
		return false, fmt.Errorf("验证配置文件失败: %w", lookupErr)
	}

	return profile.UserID == userID, nil
}