package service import ( "carrotskin/internal/model" "carrotskin/internal/repository" "carrotskin/pkg/auth" "context" "errors" "fmt" "strconv" "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" "go.uber.org/zap" ) // tokenServiceJWT TokenService的JWT实现(使用JWT + Version机制) type tokenServiceJWT struct { tokenRepo repository.TokenRepository clientRepo repository.ClientRepository profileRepo repository.ProfileRepository yggdrasilJWT *auth.YggdrasilJWTService logger *zap.Logger tokenExpireSec int64 // Token过期时间(秒),0表示永不过期 tokenStaleSec int64 // Token过期但可用时间(秒),0表示永不过期 } // NewTokenServiceJWT 创建使用JWT的TokenService实例 func NewTokenServiceJWT( tokenRepo repository.TokenRepository, clientRepo repository.ClientRepository, profileRepo repository.ProfileRepository, yggdrasilJWT *auth.YggdrasilJWTService, logger *zap.Logger, ) TokenService { return &tokenServiceJWT{ tokenRepo: tokenRepo, clientRepo: clientRepo, profileRepo: profileRepo, yggdrasilJWT: yggdrasilJWT, logger: logger, tokenExpireSec: 24 * 3600, // 默认24小时 tokenStaleSec: 30 * 24 * 3600, // 默认30天 } } // 常量已在 token_service.go 中定义,这里不重复定义 // Create 创建Token(使用JWT + Version机制) func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { var ( selectedProfileID *model.Profile availableProfiles []*model.Profile ) // 设置超时上下文 _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() // 验证用户存在 if UUID != "" { _, err := s.profileRepo.FindByUUID(UUID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) } } // 生成ClientToken if clientToken == "" { clientToken = uuid.New().String() } // 获取或创建Client var client *model.Client existingClient, err := s.clientRepo.FindByClientToken(clientToken) if err != nil { // Client不存在,创建新的 clientUUID := uuid.New().String() client = &model.Client{ UUID: clientUUID, ClientToken: clientToken, UserID: userID, Version: 0, CreatedAt: time.Now(), UpdatedAt: time.Now(), } if UUID != "" { client.ProfileID = UUID } if err := s.clientRepo.Create(client); err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err) } } else { // Client已存在,验证UserID是否匹配 if existingClient.UserID != userID { return selectedProfileID, availableProfiles, "", "", errors.New("clientToken已属于其他用户") } client = existingClient // 不增加Version(只有在刷新时才增加),只更新ProfileID和UpdatedAt client.UpdatedAt = time.Now() if UUID != "" { client.ProfileID = UUID if err := s.clientRepo.Update(client); err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err) } } } // 获取用户配置文件 profiles, err := s.profileRepo.FindByUserID(userID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) } // 如果用户只有一个配置文件,自动选择 profileID := client.ProfileID if len(profiles) == 1 { selectedProfileID = profiles[0] if profileID == "" { profileID = selectedProfileID.UUID client.ProfileID = profileID s.clientRepo.Update(client) } } availableProfiles = profiles // 生成Token过期时间 now := time.Now() var expiresAt, staleAt time.Time if s.tokenExpireSec > 0 { expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second) } else { // 使用遥远的未来时间(类似drasl的DISTANT_FUTURE) expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC) } if s.tokenStaleSec > 0 { staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second) } else { staleAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC) } // 生成JWT AccessToken accessToken, err := s.yggdrasilJWT.GenerateAccessToken( userID, client.UUID, client.Version, profileID, expiresAt, staleAt, ) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err) } // 保存Token记录(用于查询和审计) token := model.Token{ AccessToken: accessToken, ClientToken: clientToken, UserID: userID, ProfileId: profileID, Version: client.Version, Usable: true, IssueDate: now, ExpiresAt: &expiresAt, StaleAt: &staleAt, } err = s.tokenRepo.Create(&token) if err != nil { s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err)) // 不返回错误,因为JWT本身已经生成成功 } // 清理多余的令牌 go s.checkAndCleanupExcessTokens(userID) return selectedProfileID, availableProfiles, accessToken, clientToken, nil } // Validate 验证Token(使用JWT验证) func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool { if accessToken == "" { return false } // 解析JWT claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyDeny) if err != nil { return false } // 查找Client client, err := s.clientRepo.FindByUUID(claims.Subject) if err != nil { return false } // 验证Version是否匹配 if claims.Version != client.Version { return false } // 验证ClientToken(如果提供) if clientToken != "" && client.ClientToken != clientToken { return false } return true } // Refresh 刷新Token(使用Version机制,无需删除旧Token) func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) { if accessToken == "" { return "", "", errors.New("accessToken不能为空") } // 解析JWT获取Client信息 claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) if err != nil { return "", "", errors.New("accessToken无效") } // 查找Client client, err := s.clientRepo.FindByUUID(claims.Subject) if err != nil { return "", "", errors.New("无法找到对应的Client") } // 验证ClientToken if clientToken != "" && client.ClientToken != clientToken { return "", "", errors.New("clientToken无效") } // 验证Version(必须匹配) if claims.Version != client.Version { return "", "", errors.New("token版本不匹配,请重新登录") } // 验证Profile if selectedProfileID != "" { valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID) if validErr != nil { s.logger.Error("验证Profile失败", zap.Error(validErr), zap.Int64("userId", client.UserID), zap.String("profileId", selectedProfileID), ) return "", "", fmt.Errorf("验证角色失败: %w", validErr) } if !valid { return "", "", errors.New("角色与用户不匹配") } // 检查是否已绑定Profile if client.ProfileID != "" && client.ProfileID != selectedProfileID { return "", "", errors.New("原令牌已绑定角色,无法选择新角色") } client.ProfileID = selectedProfileID } else { selectedProfileID = client.ProfileID } // 增加Version(这是关键:通过Version失效所有旧Token) client.Version++ client.UpdatedAt = time.Now() if err := s.clientRepo.Update(client); err != nil { return "", "", fmt.Errorf("更新Client版本失败: %w", err) } // 生成Token过期时间 now := time.Now() var expiresAt, staleAt time.Time if s.tokenExpireSec > 0 { expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second) } else { expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC) } if s.tokenStaleSec > 0 { staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second) } else { staleAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC) } // 生成新的JWT AccessToken(使用新的Version) newAccessToken, err := s.yggdrasilJWT.GenerateAccessToken( client.UserID, client.UUID, client.Version, selectedProfileID, expiresAt, staleAt, ) if err != nil { return "", "", fmt.Errorf("生成新AccessToken失败: %w", err) } // 保存新Token记录 newToken := model.Token{ AccessToken: newAccessToken, ClientToken: client.ClientToken, UserID: client.UserID, ProfileId: selectedProfileID, Version: client.Version, Usable: true, IssueDate: now, ExpiresAt: &expiresAt, StaleAt: &staleAt, } err = s.tokenRepo.Create(&newToken) if err != nil { s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err)) } s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version)) return newAccessToken, client.ClientToken, nil } // Invalidate 使Token失效(通过增加Version) func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { if accessToken == "" { return } // 解析JWT获取Client信息 claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) if err != nil { s.logger.Warn("解析Token失败", zap.Error(err)) return } // 查找Client并增加Version client, err := s.clientRepo.FindByUUID(claims.Subject) if err != nil { s.logger.Warn("无法找到对应的Client", zap.Error(err)) return } // 增加Version以失效所有旧Token client.Version++ client.UpdatedAt = time.Now() if err := s.clientRepo.Update(client); err != nil { s.logger.Error("失效Token失败", zap.Error(err)) return } s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version)) } // InvalidateUserTokens 使用户所有Token失效 func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) { if userID == 0 { return } // 获取用户所有Client clients, err := s.clientRepo.FindByUserID(userID) if err != nil { s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID)) return } // 增加每个Client的Version for _, client := range clients { client.Version++ client.UpdatedAt = time.Now() if err := s.clientRepo.Update(client); err != nil { s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID)) } } s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients))) } // GetUUIDByAccessToken 从AccessToken获取UUID(通过JWT解析) func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) { claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) if err != nil { // 如果JWT解析失败,尝试从数据库查询(向后兼容) return s.tokenRepo.GetUUIDByAccessToken(accessToken) } if claims.ProfileID != "" { return claims.ProfileID, nil } // 如果没有ProfileID,从Client获取 client, err := s.clientRepo.FindByUUID(claims.Subject) if err != nil { return "", fmt.Errorf("无法找到对应的Client: %w", err) } if client.ProfileID != "" { return client.ProfileID, nil } return "", errors.New("无法从Token中获取UUID") } // GetUserIDByAccessToken 从AccessToken获取UserID(通过JWT解析) func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) { claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) if err != nil { // 如果JWT解析失败,尝试从数据库查询(向后兼容) return s.tokenRepo.GetUserIDByAccessToken(accessToken) } // 从Client获取UserID client, err := s.clientRepo.FindByUUID(claims.Subject) if err != nil { return 0, fmt.Errorf("无法找到对应的Client: %w", err) } // 验证Version if claims.Version != client.Version { return 0, errors.New("token版本不匹配") } return client.UserID, nil } // 私有辅助方法 func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) { if userID == 0 { return } tokens, err := s.tokenRepo.GetByUserID(userID) if err != nil { s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) return } if len(tokens) <= tokensMaxCount { return } tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount) for i := tokensMaxCount; i < len(tokens); i++ { tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) } deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) if err != nil { s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) return } if deletedCount > 0 { s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount)) } } func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) { if userID == 0 || UUID == "" { return false, errors.New("用户ID或配置文件ID不能为空") } profile, err := s.profileRepo.FindByUUID(UUID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return false, errors.New("配置文件不存在") } return false, fmt.Errorf("验证配置文件失败: %w", err) } return profile.UserID == userID, nil } // GetClientFromToken 从Token获取Client信息(辅助方法) func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken string, stalePolicy auth.StaleTokenPolicy) (*model.Client, error) { claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, stalePolicy) if err != nil { return nil, err } client, err := s.clientRepo.FindByUUID(claims.Subject) if err != nil { return nil, err } // 验证Version if claims.Version != client.Version { return nil, errors.New("token版本不匹配") } return client, nil }