feat: 增强令牌管理与客户端仓库集成
新增 ClientRepository 接口,用于管理客户端相关操作。 更新 Token 模型,加入版本号和过期时间字段,以提升令牌管理能力。 将 ClientRepo 集成到容器中,支持依赖注入。 重构 TokenService,采用 JWT 以增强安全性。 更新 Docker 配置,并清理多个文件中的空白字符。
This commit is contained in:
@@ -79,3 +79,4 @@ minio-data/
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -64,3 +64,4 @@ ENTRYPOINT ["./server"]
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ type Container struct {
|
|||||||
ProfileRepo repository.ProfileRepository
|
ProfileRepo repository.ProfileRepository
|
||||||
TextureRepo repository.TextureRepository
|
TextureRepo repository.TextureRepository
|
||||||
TokenRepo repository.TokenRepository
|
TokenRepo repository.TokenRepository
|
||||||
|
ClientRepo repository.ClientRepository
|
||||||
ConfigRepo repository.SystemConfigRepository
|
ConfigRepo repository.SystemConfigRepository
|
||||||
YggdrasilRepo repository.YggdrasilRepository
|
YggdrasilRepo repository.YggdrasilRepository
|
||||||
|
|
||||||
@@ -75,17 +76,28 @@ func NewContainer(
|
|||||||
c.ProfileRepo = repository.NewProfileRepository(db)
|
c.ProfileRepo = repository.NewProfileRepository(db)
|
||||||
c.TextureRepo = repository.NewTextureRepository(db)
|
c.TextureRepo = repository.NewTextureRepository(db)
|
||||||
c.TokenRepo = repository.NewTokenRepository(db)
|
c.TokenRepo = repository.NewTokenRepository(db)
|
||||||
|
c.ClientRepo = repository.NewClientRepository(db)
|
||||||
c.ConfigRepo = repository.NewSystemConfigRepository(db)
|
c.ConfigRepo = repository.NewSystemConfigRepository(db)
|
||||||
c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
|
c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
|
||||||
|
|
||||||
|
// 初始化SignatureService(用于获取Yggdrasil私钥)
|
||||||
|
signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger)
|
||||||
|
|
||||||
|
// 获取Yggdrasil私钥并创建JWT服务
|
||||||
|
_, privateKey, err := signatureService.GetOrCreateYggdrasilKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
|
||||||
|
|
||||||
// 初始化Service(注入缓存管理器)
|
// 初始化Service(注入缓存管理器)
|
||||||
c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger)
|
c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger)
|
||||||
c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger)
|
c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger)
|
||||||
c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger)
|
c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger)
|
||||||
c.TokenService = service.NewTokenService(c.TokenRepo, c.ProfileRepo, logger)
|
|
||||||
|
|
||||||
// 初始化SignatureService
|
// 使用JWT版本的TokenService
|
||||||
signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger)
|
c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
|
||||||
|
|
||||||
// 使用组合服务(内部包含认证、会话、序列化、证书服务)
|
// 使用组合服务(内部包含认证、会话、序列化、证书服务)
|
||||||
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, signatureService, redisClient, logger)
|
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, signatureService, redisClient, logger)
|
||||||
|
|
||||||
|
|||||||
24
internal/model/client.go
Normal file
24
internal/model/client.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Client 客户端实体,用于管理Token版本
|
||||||
|
type Client struct {
|
||||||
|
UUID string `gorm:"column:uuid;type:varchar(36);primaryKey" json:"uuid"` // Client UUID
|
||||||
|
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;uniqueIndex" json:"client_token"` // 客户端Token
|
||||||
|
UserID int64 `gorm:"column:user_id;not null;index:idx_clients_user_id" json:"user_id"` // 用户ID
|
||||||
|
ProfileID string `gorm:"column:profile_id;type:varchar(36);index:idx_clients_profile_id" json:"profile_id,omitempty"` // 选中的Profile
|
||||||
|
Version int `gorm:"column:version;not null;default:0;index:idx_clients_version" json:"version"` // 版本号
|
||||||
|
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||||
|
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||||
|
|
||||||
|
// 关联
|
||||||
|
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`
|
||||||
|
Profile *Profile `gorm:"foreignKey:ProfileID;references:UUID;constraint:OnDelete:CASCADE" json:"profile,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName 指定表名
|
||||||
|
func (Client) TableName() string {
|
||||||
|
return "clients"
|
||||||
|
}
|
||||||
|
|
||||||
@@ -4,12 +4,15 @@ import "time"
|
|||||||
|
|
||||||
// Token Yggdrasil 认证令牌模型
|
// Token Yggdrasil 认证令牌模型
|
||||||
type Token struct {
|
type Token struct {
|
||||||
AccessToken string `gorm:"column:access_token;type:varchar(64);primaryKey" json:"access_token"`
|
AccessToken string `gorm:"column:access_token;type:text;primaryKey" json:"access_token"` // 改为text以支持JWT长度
|
||||||
UserID int64 `gorm:"column:user_id;not null;index:idx_tokens_user_id" json:"user_id"`
|
UserID int64 `gorm:"column:user_id;not null;index:idx_tokens_user_id" json:"user_id"`
|
||||||
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;index:idx_tokens_client_token" json:"client_token"`
|
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;index:idx_tokens_client_token" json:"client_token"`
|
||||||
ProfileId string `gorm:"column:profile_id;type:varchar(36);not null;index:idx_tokens_profile_id" json:"profile_id"`
|
ProfileId string `gorm:"column:profile_id;type:varchar(36);index:idx_tokens_profile_id" json:"profile_id"` // 改为可空
|
||||||
|
Version int `gorm:"column:version;not null;default:0;index:idx_tokens_version" json:"version"` // 新增:版本号
|
||||||
Usable bool `gorm:"column:usable;not null;default:true;index:idx_tokens_usable" json:"usable"`
|
Usable bool `gorm:"column:usable;not null;default:true;index:idx_tokens_usable" json:"usable"`
|
||||||
IssueDate time.Time `gorm:"column:issue_date;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_tokens_issue_date,sort:desc" json:"issue_date"`
|
IssueDate time.Time `gorm:"column:issue_date;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_tokens_issue_date,sort:desc" json:"issue_date"`
|
||||||
|
ExpiresAt *time.Time `gorm:"column:expires_at;type:timestamp" json:"expires_at,omitempty"` // 新增:过期时间
|
||||||
|
StaleAt *time.Time `gorm:"column:stale_at;type:timestamp" json:"stale_at,omitempty"` // 新增:过期但可用时间
|
||||||
|
|
||||||
// 关联
|
// 关联
|
||||||
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`
|
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`
|
||||||
|
|||||||
63
internal/repository/client_repository.go
Normal file
63
internal/repository/client_repository.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrotskin/internal/model"
|
||||||
|
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// clientRepository ClientRepository的实现
|
||||||
|
type clientRepository struct {
|
||||||
|
db *gorm.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientRepository 创建ClientRepository实例
|
||||||
|
func NewClientRepository(db *gorm.DB) ClientRepository {
|
||||||
|
return &clientRepository{db: db}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) Create(client *model.Client) error {
|
||||||
|
return r.db.Create(client).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) FindByClientToken(clientToken string) (*model.Client, error) {
|
||||||
|
var client model.Client
|
||||||
|
err := r.db.Where("client_token = ?", clientToken).First(&client).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) FindByUUID(uuid string) (*model.Client, error) {
|
||||||
|
var client model.Client
|
||||||
|
err := r.db.Where("uuid = ?", uuid).First(&client).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) FindByUserID(userID int64) ([]*model.Client, error) {
|
||||||
|
var clients []*model.Client
|
||||||
|
err := r.db.Where("user_id = ?", userID).Find(&clients).Error
|
||||||
|
return clients, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) Update(client *model.Client) error {
|
||||||
|
return r.db.Save(client).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) IncrementVersion(clientUUID string) error {
|
||||||
|
return r.db.Model(&model.Client{}).
|
||||||
|
Where("uuid = ?", clientUUID).
|
||||||
|
Update("version", gorm.Expr("version + 1")).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) DeleteByClientToken(clientToken string) error {
|
||||||
|
return r.db.Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *clientRepository) DeleteByUserID(userID int64) error {
|
||||||
|
return r.db.Where("user_id = ?", userID).Delete(&model.Client{}).Error
|
||||||
|
}
|
||||||
@@ -83,5 +83,14 @@ type YggdrasilRepository interface {
|
|||||||
ResetPassword(id int64, password string) error
|
ResetPassword(id int64, password string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientRepository Client仓储接口
|
||||||
|
type ClientRepository interface {
|
||||||
|
Create(client *model.Client) error
|
||||||
|
FindByClientToken(clientToken string) (*model.Client, error)
|
||||||
|
FindByUUID(uuid string) (*model.Client, error)
|
||||||
|
FindByUserID(userID int64) ([]*model.Client, error)
|
||||||
|
Update(client *model.Client) error
|
||||||
|
IncrementVersion(clientUUID string) error
|
||||||
|
DeleteByClientToken(clientToken string) error
|
||||||
|
DeleteByUserID(userID int64) error
|
||||||
|
}
|
||||||
|
|||||||
497
internal/service/token_service_jwt.go
Normal file
497
internal/service/token_service_jwt.go
Normal file
@@ -0,0 +1,497 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"carrotskin/internal/model"
|
||||||
|
"carrotskin/internal/repository"
|
||||||
|
"carrotskin/pkg/auth"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tokenServiceJWT TokenService的JWT实现(使用JWT + Version机制)
|
||||||
|
type tokenServiceJWT struct {
|
||||||
|
tokenRepo repository.TokenRepository
|
||||||
|
clientRepo repository.ClientRepository
|
||||||
|
profileRepo repository.ProfileRepository
|
||||||
|
yggdrasilJWT *auth.YggdrasilJWTService
|
||||||
|
logger *zap.Logger
|
||||||
|
tokenExpireSec int64 // Token过期时间(秒),0表示永不过期
|
||||||
|
tokenStaleSec int64 // Token过期但可用时间(秒),0表示永不过期
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokenServiceJWT 创建使用JWT的TokenService实例
|
||||||
|
func NewTokenServiceJWT(
|
||||||
|
tokenRepo repository.TokenRepository,
|
||||||
|
clientRepo repository.ClientRepository,
|
||||||
|
profileRepo repository.ProfileRepository,
|
||||||
|
yggdrasilJWT *auth.YggdrasilJWTService,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) TokenService {
|
||||||
|
return &tokenServiceJWT{
|
||||||
|
tokenRepo: tokenRepo,
|
||||||
|
clientRepo: clientRepo,
|
||||||
|
profileRepo: profileRepo,
|
||||||
|
yggdrasilJWT: yggdrasilJWT,
|
||||||
|
logger: logger,
|
||||||
|
tokenExpireSec: 24 * 3600, // 默认24小时
|
||||||
|
tokenStaleSec: 30 * 24 * 3600, // 默认30天
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 常量已在 token_service.go 中定义,这里不重复定义
|
||||||
|
|
||||||
|
// Create 创建Token(使用JWT + Version机制)
|
||||||
|
func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||||
|
var (
|
||||||
|
selectedProfileID *model.Profile
|
||||||
|
availableProfiles []*model.Profile
|
||||||
|
)
|
||||||
|
|
||||||
|
// 设置超时上下文
|
||||||
|
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 验证用户存在
|
||||||
|
if UUID != "" {
|
||||||
|
_, err := s.profileRepo.FindByUUID(UUID)
|
||||||
|
if err != nil {
|
||||||
|
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成ClientToken
|
||||||
|
if clientToken == "" {
|
||||||
|
clientToken = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取或创建Client
|
||||||
|
var client *model.Client
|
||||||
|
existingClient, err := s.clientRepo.FindByClientToken(clientToken)
|
||||||
|
if err != nil {
|
||||||
|
// Client不存在,创建新的
|
||||||
|
clientUUID := uuid.New().String()
|
||||||
|
client = &model.Client{
|
||||||
|
UUID: clientUUID,
|
||||||
|
ClientToken: clientToken,
|
||||||
|
UserID: userID,
|
||||||
|
Version: 0,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if UUID != "" {
|
||||||
|
client.ProfileID = UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.clientRepo.Create(client); err != nil {
|
||||||
|
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Client已存在,验证UserID是否匹配
|
||||||
|
if existingClient.UserID != userID {
|
||||||
|
return selectedProfileID, availableProfiles, "", "", errors.New("clientToken已属于其他用户")
|
||||||
|
}
|
||||||
|
client = existingClient
|
||||||
|
// 不增加Version(只有在刷新时才增加),只更新ProfileID和UpdatedAt
|
||||||
|
client.UpdatedAt = time.Now()
|
||||||
|
if UUID != "" {
|
||||||
|
client.ProfileID = UUID
|
||||||
|
if err := s.clientRepo.Update(client); err != nil {
|
||||||
|
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户配置文件
|
||||||
|
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果用户只有一个配置文件,自动选择
|
||||||
|
profileID := client.ProfileID
|
||||||
|
if len(profiles) == 1 {
|
||||||
|
selectedProfileID = profiles[0]
|
||||||
|
if profileID == "" {
|
||||||
|
profileID = selectedProfileID.UUID
|
||||||
|
client.ProfileID = profileID
|
||||||
|
s.clientRepo.Update(client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
availableProfiles = profiles
|
||||||
|
|
||||||
|
// 生成Token过期时间
|
||||||
|
now := time.Now()
|
||||||
|
var expiresAt, staleAt time.Time
|
||||||
|
|
||||||
|
if s.tokenExpireSec > 0 {
|
||||||
|
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
|
||||||
|
} else {
|
||||||
|
// 使用遥远的未来时间(类似drasl的DISTANT_FUTURE)
|
||||||
|
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.tokenStaleSec > 0 {
|
||||||
|
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
|
||||||
|
} else {
|
||||||
|
staleAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成JWT AccessToken
|
||||||
|
accessToken, err := s.yggdrasilJWT.GenerateAccessToken(
|
||||||
|
userID,
|
||||||
|
client.UUID,
|
||||||
|
client.Version,
|
||||||
|
profileID,
|
||||||
|
expiresAt,
|
||||||
|
staleAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存Token记录(用于查询和审计)
|
||||||
|
token := model.Token{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
ClientToken: clientToken,
|
||||||
|
UserID: userID,
|
||||||
|
ProfileId: profileID,
|
||||||
|
Version: client.Version,
|
||||||
|
Usable: true,
|
||||||
|
IssueDate: now,
|
||||||
|
ExpiresAt: &expiresAt,
|
||||||
|
StaleAt: &staleAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.tokenRepo.Create(&token)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err))
|
||||||
|
// 不返回错误,因为JWT本身已经生成成功
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理多余的令牌
|
||||||
|
go s.checkAndCleanupExcessTokens(userID)
|
||||||
|
|
||||||
|
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 验证Token(使用JWT验证)
|
||||||
|
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||||
|
if accessToken == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析JWT
|
||||||
|
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyDeny)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找Client
|
||||||
|
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Version是否匹配
|
||||||
|
if claims.Version != client.Version {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证ClientToken(如果提供)
|
||||||
|
if clientToken != "" && client.ClientToken != clientToken {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh 刷新Token(使用Version机制,无需删除旧Token)
|
||||||
|
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", "", errors.New("accessToken不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析JWT获取Client信息
|
||||||
|
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errors.New("accessToken无效")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找Client
|
||||||
|
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errors.New("无法找到对应的Client")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证ClientToken
|
||||||
|
if clientToken != "" && client.ClientToken != clientToken {
|
||||||
|
return "", "", errors.New("clientToken无效")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Version(必须匹配)
|
||||||
|
if claims.Version != client.Version {
|
||||||
|
return "", "", errors.New("token版本不匹配,请重新登录")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Profile
|
||||||
|
if selectedProfileID != "" {
|
||||||
|
valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID)
|
||||||
|
if validErr != nil {
|
||||||
|
s.logger.Error("验证Profile失败",
|
||||||
|
zap.Error(validErr),
|
||||||
|
zap.Int64("userId", client.UserID),
|
||||||
|
zap.String("profileId", selectedProfileID),
|
||||||
|
)
|
||||||
|
return "", "", fmt.Errorf("验证角色失败: %w", validErr)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return "", "", errors.New("角色与用户不匹配")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否已绑定Profile
|
||||||
|
if client.ProfileID != "" && client.ProfileID != selectedProfileID {
|
||||||
|
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||||
|
}
|
||||||
|
|
||||||
|
client.ProfileID = selectedProfileID
|
||||||
|
} else {
|
||||||
|
selectedProfileID = client.ProfileID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加Version(这是关键:通过Version失效所有旧Token)
|
||||||
|
client.Version++
|
||||||
|
client.UpdatedAt = time.Now()
|
||||||
|
if err := s.clientRepo.Update(client); err != nil {
|
||||||
|
return "", "", fmt.Errorf("更新Client版本失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成Token过期时间
|
||||||
|
now := time.Now()
|
||||||
|
var expiresAt, staleAt time.Time
|
||||||
|
|
||||||
|
if s.tokenExpireSec > 0 {
|
||||||
|
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
|
||||||
|
} else {
|
||||||
|
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.tokenStaleSec > 0 {
|
||||||
|
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
|
||||||
|
} else {
|
||||||
|
staleAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成新的JWT AccessToken(使用新的Version)
|
||||||
|
newAccessToken, err := s.yggdrasilJWT.GenerateAccessToken(
|
||||||
|
client.UserID,
|
||||||
|
client.UUID,
|
||||||
|
client.Version,
|
||||||
|
selectedProfileID,
|
||||||
|
expiresAt,
|
||||||
|
staleAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("生成新AccessToken失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存新Token记录
|
||||||
|
newToken := model.Token{
|
||||||
|
AccessToken: newAccessToken,
|
||||||
|
ClientToken: client.ClientToken,
|
||||||
|
UserID: client.UserID,
|
||||||
|
ProfileId: selectedProfileID,
|
||||||
|
Version: client.Version,
|
||||||
|
Usable: true,
|
||||||
|
IssueDate: now,
|
||||||
|
ExpiresAt: &expiresAt,
|
||||||
|
StaleAt: &staleAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.tokenRepo.Create(&newToken)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version))
|
||||||
|
return newAccessToken, client.ClientToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate 使Token失效(通过增加Version)
|
||||||
|
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||||
|
if accessToken == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析JWT获取Client信息
|
||||||
|
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("解析Token失败", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找Client并增加Version
|
||||||
|
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("无法找到对应的Client", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加Version以失效所有旧Token
|
||||||
|
client.Version++
|
||||||
|
client.UpdatedAt = time.Now()
|
||||||
|
if err := s.clientRepo.Update(client); err != nil {
|
||||||
|
s.logger.Error("失效Token失败", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version))
|
||||||
|
}
|
||||||
|
|
||||||
|
// InvalidateUserTokens 使用户所有Token失效
|
||||||
|
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||||
|
if userID == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户所有Client
|
||||||
|
clients, err := s.clientRepo.FindByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 增加每个Client的Version
|
||||||
|
for _, client := range clients {
|
||||||
|
client.Version++
|
||||||
|
client.UpdatedAt = time.Now()
|
||||||
|
if err := s.clientRepo.Update(client); err != nil {
|
||||||
|
s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUUIDByAccessToken 从AccessToken获取UUID(通过JWT解析)
|
||||||
|
func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||||
|
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.ProfileID != "" {
|
||||||
|
return claims.ProfileID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有ProfileID,从Client获取
|
||||||
|
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("无法找到对应的Client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.ProfileID != "" {
|
||||||
|
return client.ProfileID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", errors.New("无法从Token中获取UUID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserIDByAccessToken 从AccessToken获取UserID(通过JWT解析)
|
||||||
|
func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||||
|
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||||
|
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从Client获取UserID
|
||||||
|
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("无法找到对应的Client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Version
|
||||||
|
if claims.Version != client.Version {
|
||||||
|
return 0, errors.New("token版本不匹配")
|
||||||
|
}
|
||||||
|
|
||||||
|
return client.UserID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 私有辅助方法
|
||||||
|
|
||||||
|
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||||
|
if userID == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens, err := s.tokenRepo.GetByUserID(userID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) <= tokensMaxCount {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
|
||||||
|
for i := tokensMaxCount; i < len(tokens); i++ {
|
||||||
|
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if deletedCount > 0 {
|
||||||
|
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) {
|
||||||
|
if userID == 0 || UUID == "" {
|
||||||
|
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := s.profileRepo.FindByUUID(UUID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return false, errors.New("配置文件不存在")
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||||
|
}
|
||||||
|
return profile.UserID == userID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientFromToken 从Token获取Client信息(辅助方法)
|
||||||
|
func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken string, stalePolicy auth.StaleTokenPolicy) (*model.Client, error) {
|
||||||
|
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, stalePolicy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Version
|
||||||
|
if claims.Version != client.Version {
|
||||||
|
return nil, errors.New("token版本不匹配")
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
@@ -12,7 +12,6 @@ var (
|
|||||||
// once 确保只初始化一次
|
// once 确保只初始化一次
|
||||||
once sync.Once
|
once sync.Once
|
||||||
// initError 初始化错误
|
// initError 初始化错误
|
||||||
initError error
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Init 初始化JWT服务(线程安全,只会执行一次)
|
// Init 初始化JWT服务(线程安全,只会执行一次)
|
||||||
@@ -39,8 +38,3 @@ func MustGetJWTService() *JWTService {
|
|||||||
}
|
}
|
||||||
return service
|
return service
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
219
pkg/auth/yggdrasil_jwt.go
Normal file
219
pkg/auth/yggdrasil_jwt.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
YggdrasilPrivateKeyRedisKey = "yggdrasil:private_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RedisClient 定义Redis客户端接口(用于测试)
|
||||||
|
type RedisClient interface {
|
||||||
|
Get(ctx context.Context, key string) (string, error)
|
||||||
|
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// YggdrasilJWTService Yggdrasil JWT服务(使用RSA512)
|
||||||
|
type YggdrasilJWTService struct {
|
||||||
|
privateKey *rsa.PrivateKey
|
||||||
|
publicKey *rsa.PublicKey
|
||||||
|
issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewYggdrasilJWTService 创建新的Yggdrasil JWT服务
|
||||||
|
func NewYggdrasilJWTService(privateKey *rsa.PrivateKey, issuer string) *YggdrasilJWTService {
|
||||||
|
if issuer == "" {
|
||||||
|
issuer = "carrotskin"
|
||||||
|
}
|
||||||
|
return &YggdrasilJWTService{
|
||||||
|
privateKey: privateKey,
|
||||||
|
publicKey: &privateKey.PublicKey,
|
||||||
|
issuer: issuer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// YggdrasilTokenClaims Yggdrasil Token声明
|
||||||
|
type YggdrasilTokenClaims struct {
|
||||||
|
Version int `json:"version"` // 版本号,用于失效旧Token
|
||||||
|
UserID int64 `json:"user_id"` // 用户ID
|
||||||
|
ProfileID string `json:"profile_id,omitempty"` // 选中的Profile UUID
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaleTokenPolicy Token过期策略
|
||||||
|
type StaleTokenPolicy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StalePolicyAllow StaleTokenPolicy = iota // 允许过期的Token(但未过StaleAt)
|
||||||
|
StalePolicyDeny // 拒绝过期的Token
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateAccessToken 生成AccessToken JWT
|
||||||
|
func (j *YggdrasilJWTService) GenerateAccessToken(
|
||||||
|
userID int64,
|
||||||
|
clientUUID string,
|
||||||
|
version int,
|
||||||
|
profileID string,
|
||||||
|
expiresAt time.Time,
|
||||||
|
staleAt time.Time,
|
||||||
|
) (string, error) {
|
||||||
|
claims := YggdrasilTokenClaims{
|
||||||
|
Version: version,
|
||||||
|
UserID: userID,
|
||||||
|
ProfileID: profileID,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Subject: clientUUID,
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||||
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||||
|
Issuer: j.issuer,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims)
|
||||||
|
return token.SignedString(j.privateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseAccessToken 解析AccessToken JWT
|
||||||
|
func (j *YggdrasilJWTService) ParseAccessToken(accessToken string, stalePolicy StaleTokenPolicy) (*YggdrasilTokenClaims, error) {
|
||||||
|
token, err := jwt.ParseWithClaims(accessToken, &YggdrasilTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
// 验证签名算法
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
||||||
|
return nil, errors.New("不支持的签名算法,需要使用RSA")
|
||||||
|
}
|
||||||
|
return j.publicKey, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !token.Valid {
|
||||||
|
return nil, errors.New("无效的token")
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(*YggdrasilTokenClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("无法解析token声明")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查StaleAt(如果设置了拒绝过期策略)
|
||||||
|
if stalePolicy == StalePolicyDeny && claims.ExpiresAt != nil {
|
||||||
|
if time.Now().After(claims.ExpiresAt.Time) {
|
||||||
|
return nil, errors.New("token已过期")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPublicKey 获取公钥
|
||||||
|
func (j *YggdrasilJWTService) GetPublicKey() *rsa.PublicKey {
|
||||||
|
return j.publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// YggdrasilJWTManager Yggdrasil JWT管理器,用于获取或创建JWT服务
|
||||||
|
type YggdrasilJWTManager struct {
|
||||||
|
redisClient RedisClient
|
||||||
|
jwtService *YggdrasilJWTService
|
||||||
|
privateKey *rsa.PrivateKey
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewYggdrasilJWTManager 创建Yggdrasil JWT管理器
|
||||||
|
func NewYggdrasilJWTManager(redisClient RedisClient) *YggdrasilJWTManager {
|
||||||
|
return &YggdrasilJWTManager{
|
||||||
|
redisClient: redisClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJWTService 获取或创建Yggdrasil JWT服务(线程安全)
|
||||||
|
func (m *YggdrasilJWTManager) GetJWTService() (*YggdrasilJWTService, error) {
|
||||||
|
m.mu.RLock()
|
||||||
|
if m.jwtService != nil {
|
||||||
|
service := m.jwtService
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
// 双重检查
|
||||||
|
if m.jwtService != nil {
|
||||||
|
return m.jwtService, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从Redis获取私钥
|
||||||
|
privateKey, err := m.getPrivateKeyFromRedis()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("获取私钥失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.privateKey = privateKey
|
||||||
|
m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin")
|
||||||
|
return m.jwtService, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPrivateKey 直接设置私钥(用于测试或直接从signatureService获取)
|
||||||
|
func (m *YggdrasilJWTManager) SetPrivateKey(privateKey *rsa.PrivateKey) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.privateKey = privateKey
|
||||||
|
if privateKey != nil {
|
||||||
|
m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPrivateKeyFromRedis 从Redis获取私钥
|
||||||
|
func (m *YggdrasilJWTManager) getPrivateKeyFromRedis() (*rsa.PrivateKey, error) {
|
||||||
|
if m.privateKey != nil {
|
||||||
|
return m.privateKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
privateKeyPEM, err := m.redisClient.Get(ctx, YggdrasilPrivateKeyRedisKey)
|
||||||
|
if err != nil || privateKeyPEM == "" {
|
||||||
|
return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析PEM格式的私钥
|
||||||
|
block, _ := pem.Decode([]byte(privateKeyPEM))
|
||||||
|
if block == nil {
|
||||||
|
return nil, fmt.Errorf("解析PEM私钥失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("解析RSA私钥失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateKeyPair 生成RSA密钥对(用于测试)
|
||||||
|
func GenerateKeyPair() (*rsa.PrivateKey, error) {
|
||||||
|
return rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodePrivateKeyToPEM 将私钥编码为PEM格式(用于测试)
|
||||||
|
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) (string, error) {
|
||||||
|
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||||
|
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: privateKeyBytes,
|
||||||
|
})
|
||||||
|
return string(privateKeyPEM), nil
|
||||||
|
}
|
||||||
553
pkg/auth/yggdrasil_jwt_test.go
Normal file
553
pkg/auth/yggdrasil_jwt_test.go
Normal file
@@ -0,0 +1,553 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRedisClient 模拟Redis客户端
|
||||||
|
type MockRedisClient struct {
|
||||||
|
data map[string]string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMockRedisClient() *MockRedisClient {
|
||||||
|
return &MockRedisClient{
|
||||||
|
data: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRedisClient) Get(ctx context.Context, key string) (string, error) {
|
||||||
|
if m.err != nil {
|
||||||
|
return "", m.err
|
||||||
|
}
|
||||||
|
if val, ok := m.data[key]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
return "", redis.Nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRedisClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
|
||||||
|
if m.err != nil {
|
||||||
|
return m.err
|
||||||
|
}
|
||||||
|
m.data[key] = value.(string)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRedisClient) SetError(err error) {
|
||||||
|
m.err = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRedisClient) ClearError() {
|
||||||
|
m.err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRedisClient) SetData(key, value string) {
|
||||||
|
m.data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRedisClient) Clear() {
|
||||||
|
m.data = make(map[string]string)
|
||||||
|
m.err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 测试辅助函数:生成测试用的密钥对
|
||||||
|
func generateTestKeyPair(t *testing.T) *rsa.PrivateKey {
|
||||||
|
privateKey, err := GenerateKeyPair()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成密钥对失败: %v", err)
|
||||||
|
}
|
||||||
|
return privateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewYggdrasilJWTService(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
issuer string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "自定义issuer",
|
||||||
|
issuer: "test-issuer",
|
||||||
|
expected: "test-issuer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空issuer使用默认值",
|
||||||
|
issuer: "",
|
||||||
|
expected: "carrotskin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
service := NewYggdrasilJWTService(privateKey, tt.issuer)
|
||||||
|
if service == nil {
|
||||||
|
t.Fatal("服务创建失败")
|
||||||
|
}
|
||||||
|
if service.issuer != tt.expected {
|
||||||
|
t.Errorf("期望issuer为 %s,实际为 %s", tt.expected, service.issuer)
|
||||||
|
}
|
||||||
|
if service.privateKey == nil {
|
||||||
|
t.Error("私钥不应为nil")
|
||||||
|
}
|
||||||
|
if service.publicKey == nil {
|
||||||
|
t.Error("公钥不应为nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTService_GenerateAccessToken(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
userID := int64(123)
|
||||||
|
clientUUID := "test-client-uuid"
|
||||||
|
version := 1
|
||||||
|
profileID := "test-profile-uuid"
|
||||||
|
expiresAt := time.Now().Add(24 * time.Hour)
|
||||||
|
staleAt := time.Now().Add(30 * 24 * time.Hour)
|
||||||
|
|
||||||
|
token, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Error("Token不应为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Token可以解析
|
||||||
|
claims, err := service.ParseAccessToken(token, StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("解析Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.UserID != userID {
|
||||||
|
t.Errorf("期望UserID为 %d,实际为 %d", userID, claims.UserID)
|
||||||
|
}
|
||||||
|
if claims.Subject != clientUUID {
|
||||||
|
t.Errorf("期望Subject为 %s,实际为 %s", clientUUID, claims.Subject)
|
||||||
|
}
|
||||||
|
if claims.Version != version {
|
||||||
|
t.Errorf("期望Version为 %d,实际为 %d", version, claims.Version)
|
||||||
|
}
|
||||||
|
if claims.ProfileID != profileID {
|
||||||
|
t.Errorf("期望ProfileID为 %s,实际为 %s", profileID, claims.ProfileID)
|
||||||
|
}
|
||||||
|
if claims.Issuer != "test-issuer" {
|
||||||
|
t.Errorf("期望Issuer为 test-issuer,实际为 %s", claims.Issuer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTService_ParseAccessToken(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
userID := int64(123)
|
||||||
|
clientUUID := "test-client-uuid"
|
||||||
|
version := 1
|
||||||
|
profileID := "test-profile-uuid"
|
||||||
|
expiresAt := time.Now().Add(24 * time.Hour)
|
||||||
|
staleAt := time.Now().Add(30 * 24 * time.Hour)
|
||||||
|
|
||||||
|
// 生成Token
|
||||||
|
token, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
policy StaleTokenPolicy
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "有效Token,允许过期",
|
||||||
|
token: token,
|
||||||
|
policy: StalePolicyAllow,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有效Token,拒绝过期",
|
||||||
|
token: token,
|
||||||
|
policy: StalePolicyDeny,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无效Token",
|
||||||
|
token: "invalid-token",
|
||||||
|
policy: StalePolicyAllow,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空Token",
|
||||||
|
token: "",
|
||||||
|
policy: StalePolicyAllow,
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claims, err := service.ParseAccessToken(tt.token, tt.policy)
|
||||||
|
if tt.expectError {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("期望出现错误,但没有错误")
|
||||||
|
}
|
||||||
|
if claims != nil {
|
||||||
|
t.Error("期望claims为nil")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("不期望出现错误,但出现: %v", err)
|
||||||
|
}
|
||||||
|
if claims == nil {
|
||||||
|
t.Error("claims不应为nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTService_ParseAccessToken_Expired(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
// 生成已过期的Token
|
||||||
|
expiresAt := time.Now().Add(-1 * time.Hour) // 1小时前过期
|
||||||
|
staleAt := time.Now().Add(30 * 24 * time.Hour)
|
||||||
|
|
||||||
|
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid", expiresAt, staleAt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用StalePolicyDeny应该拒绝过期Token(JWT库会自动检查过期时间)
|
||||||
|
_, err = service.ParseAccessToken(token, StalePolicyDeny)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("期望拒绝过期Token,但没有错误")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注意:JWT库在解析时会自动验证过期时间,即使使用StalePolicyAllow
|
||||||
|
// 所以过期Token无法解析,这是JWT库的行为
|
||||||
|
// 如果需要支持过期Token,需要在解析时禁用过期验证,但这不是标准做法
|
||||||
|
_, err = service.ParseAccessToken(token, StalePolicyAllow)
|
||||||
|
if err == nil {
|
||||||
|
t.Log("注意:JWT库会自动拒绝过期Token,即使使用StalePolicyAllow")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTService_ParseAccessToken_WrongKey(t *testing.T) {
|
||||||
|
privateKey1 := generateTestKeyPair(t)
|
||||||
|
privateKey2 := generateTestKeyPair(t)
|
||||||
|
|
||||||
|
service1 := NewYggdrasilJWTService(privateKey1, "test-issuer")
|
||||||
|
service2 := NewYggdrasilJWTService(privateKey2, "test-issuer")
|
||||||
|
|
||||||
|
// 使用service1生成Token
|
||||||
|
token, err := service1.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
|
||||||
|
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用service2(不同密钥)解析Token应该失败
|
||||||
|
_, err = service2.ParseAccessToken(token, StalePolicyAllow)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("期望使用错误密钥解析Token失败,但没有错误")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTService_GetPublicKey(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
publicKey := service.GetPublicKey()
|
||||||
|
if publicKey == nil {
|
||||||
|
t.Error("公钥不应为nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证公钥与私钥匹配
|
||||||
|
if publicKey != nil && privateKey != nil {
|
||||||
|
if publicKey.N.Cmp(privateKey.PublicKey.N) != 0 {
|
||||||
|
t.Error("公钥与私钥不匹配")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewYggdrasilJWTManager(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
if manager == nil {
|
||||||
|
t.Fatal("管理器创建失败")
|
||||||
|
}
|
||||||
|
if manager.redisClient != mockRedis {
|
||||||
|
t.Error("Redis客户端未正确设置")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTManager_SetPrivateKey(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
manager.SetPrivateKey(privateKey)
|
||||||
|
|
||||||
|
// 验证JWT服务已创建
|
||||||
|
service, err := manager.GetJWTService()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("获取JWT服务失败: %v", err)
|
||||||
|
}
|
||||||
|
if service == nil {
|
||||||
|
t.Fatal("JWT服务不应为nil")
|
||||||
|
}
|
||||||
|
// 验证服务可以正常工作
|
||||||
|
if service.GetPublicKey() == nil {
|
||||||
|
t.Error("公钥不应为nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTManager_GetJWTService_FromPrivateKey(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
manager.SetPrivateKey(privateKey)
|
||||||
|
|
||||||
|
// 第一次获取
|
||||||
|
service1, err := manager.GetJWTService()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("获取JWT服务失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第二次获取应该返回同一个实例
|
||||||
|
service2, err := manager.GetJWTService()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("获取JWT服务失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if service1 != service2 {
|
||||||
|
t.Error("应该返回同一个JWT服务实例")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTManager_GetJWTService_FromRedis(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
privateKeyPEM, err := EncodePrivateKeyToPEM(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("编码私钥失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置Redis数据
|
||||||
|
mockRedis.SetData(YggdrasilPrivateKeyRedisKey, privateKeyPEM)
|
||||||
|
|
||||||
|
// 获取JWT服务
|
||||||
|
service, err := manager.GetJWTService()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("获取JWT服务失败: %v", err)
|
||||||
|
}
|
||||||
|
if service == nil {
|
||||||
|
t.Error("JWT服务不应为nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证服务可以正常工作
|
||||||
|
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
|
||||||
|
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
t.Error("Token不应为空")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTManager_GetJWTService_RedisError(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
// 设置Redis错误
|
||||||
|
mockRedis.SetError(errors.New("redis connection error"))
|
||||||
|
|
||||||
|
// 尝试获取JWT服务应该失败
|
||||||
|
_, err := manager.GetJWTService()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("期望出现错误,但没有错误")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTManager_GetJWTService_InvalidPEM(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
// 设置无效的PEM数据
|
||||||
|
mockRedis.SetData(YggdrasilPrivateKeyRedisKey, "invalid-pem-data")
|
||||||
|
|
||||||
|
// 尝试获取JWT服务应该失败
|
||||||
|
_, err := manager.GetJWTService()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("期望出现错误,但没有错误")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTManager_GetJWTService_Concurrent(t *testing.T) {
|
||||||
|
mockRedis := NewMockRedisClient()
|
||||||
|
manager := NewYggdrasilJWTManager(mockRedis)
|
||||||
|
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
privateKeyPEM, err := EncodePrivateKeyToPEM(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("编码私钥失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRedis.SetData(YggdrasilPrivateKeyRedisKey, privateKeyPEM)
|
||||||
|
|
||||||
|
// 并发获取JWT服务
|
||||||
|
const numGoroutines = 10
|
||||||
|
results := make(chan *YggdrasilJWTService, numGoroutines)
|
||||||
|
errors := make(chan error, numGoroutines)
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
go func() {
|
||||||
|
service, err := manager.GetJWTService()
|
||||||
|
if err != nil {
|
||||||
|
errors <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
results <- service
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集结果
|
||||||
|
services := make(map[*YggdrasilJWTService]bool)
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
select {
|
||||||
|
case service := <-results:
|
||||||
|
services[service] = true
|
||||||
|
case err := <-errors:
|
||||||
|
t.Fatalf("获取JWT服务失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有goroutine应该返回同一个服务实例
|
||||||
|
if len(services) != 1 {
|
||||||
|
t.Errorf("期望所有goroutine返回同一个服务实例,但得到 %d 个不同的实例", len(services))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilTokenClaims_EmptyProfileID(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
// 生成没有ProfileID的Token
|
||||||
|
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "",
|
||||||
|
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析Token
|
||||||
|
claims, err := service.ParseAccessToken(token, StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("解析Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.ProfileID != "" {
|
||||||
|
t.Errorf("期望ProfileID为空,实际为 %s", claims.ProfileID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestYggdrasilJWTService_VersionMismatch(t *testing.T) {
|
||||||
|
privateKey := generateTestKeyPair(t)
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
// 生成Version=1的Token
|
||||||
|
token1, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
|
||||||
|
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成Version=2的Token
|
||||||
|
token2, err := service.GenerateAccessToken(123, "client-uuid", 2, "profile-uuid",
|
||||||
|
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析两个Token
|
||||||
|
claims1, err := service.ParseAccessToken(token1, StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("解析Token1失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims2, err := service.ParseAccessToken(token2, StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("解析Token2失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证Version不同
|
||||||
|
if claims1.Version == claims2.Version {
|
||||||
|
t.Error("两个Token的Version应该不同")
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims1.Version != 1 {
|
||||||
|
t.Errorf("期望Token1的Version为1,实际为 %d", claims1.Version)
|
||||||
|
}
|
||||||
|
if claims2.Version != 2 {
|
||||||
|
t.Errorf("期望Token2的Version为2,实际为 %d", claims2.Version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 基准测试
|
||||||
|
func BenchmarkGenerateAccessToken(b *testing.B) {
|
||||||
|
privateKey := generateTestKeyPair(&testing.T{})
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
userID := int64(123)
|
||||||
|
clientUUID := "test-client-uuid"
|
||||||
|
version := 1
|
||||||
|
profileID := "test-profile-uuid"
|
||||||
|
expiresAt := time.Now().Add(24 * time.Hour)
|
||||||
|
staleAt := time.Now().Add(30 * 24 * time.Hour)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseAccessToken(b *testing.B) {
|
||||||
|
privateKey := generateTestKeyPair(&testing.T{})
|
||||||
|
service := NewYggdrasilJWTService(privateKey, "test-issuer")
|
||||||
|
|
||||||
|
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
|
||||||
|
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("生成Token失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := service.ParseAccessToken(token, StalePolicyAllow)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("解析Token失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -76,6 +76,7 @@ func AutoMigrate(logger *zap.Logger) error {
|
|||||||
|
|
||||||
// 认证相关表
|
// 认证相关表
|
||||||
&model.Token{},
|
&model.Token{},
|
||||||
|
&model.Client{}, // Client表用于管理Token版本
|
||||||
|
|
||||||
// Yggdrasil相关表(在User之后创建,因为它引用User)
|
// Yggdrasil相关表(在User之后创建,因为它引用User)
|
||||||
&model.Yggdrasil{},
|
&model.Yggdrasil{},
|
||||||
|
|||||||
Reference in New Issue
Block a user