feat: 增强令牌管理与客户端仓库集成

新增 ClientRepository 接口,用于管理客户端相关操作。
更新 Token 模型,加入版本号和过期时间字段,以提升令牌管理能力。
将 ClientRepo 集成到容器中,支持依赖注入。
重构 TokenService,采用 JWT 以增强安全性。
更新 Docker 配置,并清理多个文件中的空白字符。
This commit is contained in:
lan
2025-12-03 14:43:38 +08:00
parent e873c58af9
commit 4824a997dd
12 changed files with 1394 additions and 17 deletions

View File

@@ -30,6 +30,7 @@ type Container struct {
ProfileRepo repository.ProfileRepository
TextureRepo repository.TextureRepository
TokenRepo repository.TokenRepository
ClientRepo repository.ClientRepository
ConfigRepo repository.SystemConfigRepository
YggdrasilRepo repository.YggdrasilRepository
@@ -75,17 +76,28 @@ func NewContainer(
c.ProfileRepo = repository.NewProfileRepository(db)
c.TextureRepo = repository.NewTextureRepository(db)
c.TokenRepo = repository.NewTokenRepository(db)
c.ClientRepo = repository.NewClientRepository(db)
c.ConfigRepo = repository.NewSystemConfigRepository(db)
c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
// 初始化SignatureService用于获取Yggdrasil私钥
signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger)
// 获取Yggdrasil私钥并创建JWT服务
_, privateKey, err := signatureService.GetOrCreateYggdrasilKeyPair()
if err != nil {
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
}
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
// 初始化Service注入缓存管理器
c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger)
c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger)
c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger)
c.TokenService = service.NewTokenService(c.TokenRepo, c.ProfileRepo, logger)
// 使用JWT版本的TokenService
c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
// 初始化SignatureService
signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger)
// 使用组合服务(内部包含认证、会话、序列化、证书服务)
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, signatureService, redisClient, logger)

24
internal/model/client.go Normal file
View File

@@ -0,0 +1,24 @@
package model
import "time"
// Client 客户端实体用于管理Token版本
type Client struct {
UUID string `gorm:"column:uuid;type:varchar(36);primaryKey" json:"uuid"` // Client UUID
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;uniqueIndex" json:"client_token"` // 客户端Token
UserID int64 `gorm:"column:user_id;not null;index:idx_clients_user_id" json:"user_id"` // 用户ID
ProfileID string `gorm:"column:profile_id;type:varchar(36);index:idx_clients_profile_id" json:"profile_id,omitempty"` // 选中的Profile
Version int `gorm:"column:version;not null;default:0;index:idx_clients_version" json:"version"` // 版本号
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
// 关联
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`
Profile *Profile `gorm:"foreignKey:ProfileID;references:UUID;constraint:OnDelete:CASCADE" json:"profile,omitempty"`
}
// TableName 指定表名
func (Client) TableName() string {
return "clients"
}

View File

@@ -4,12 +4,15 @@ import "time"
// Token Yggdrasil 认证令牌模型
type Token struct {
AccessToken string `gorm:"column:access_token;type:varchar(64);primaryKey" json:"access_token"`
UserID int64 `gorm:"column:user_id;not null;index:idx_tokens_user_id" json:"user_id"`
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;index:idx_tokens_client_token" json:"client_token"`
ProfileId string `gorm:"column:profile_id;type:varchar(36);not null;index:idx_tokens_profile_id" json:"profile_id"`
Usable bool `gorm:"column:usable;not null;default:true;index:idx_tokens_usable" json:"usable"`
IssueDate time.Time `gorm:"column:issue_date;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_tokens_issue_date,sort:desc" json:"issue_date"`
AccessToken string `gorm:"column:access_token;type:text;primaryKey" json:"access_token"` // 改为text以支持JWT长度
UserID int64 `gorm:"column:user_id;not null;index:idx_tokens_user_id" json:"user_id"`
ClientToken string `gorm:"column:client_token;type:varchar(64);not null;index:idx_tokens_client_token" json:"client_token"`
ProfileId string `gorm:"column:profile_id;type:varchar(36);index:idx_tokens_profile_id" json:"profile_id"` // 改为可空
Version int `gorm:"column:version;not null;default:0;index:idx_tokens_version" json:"version"` // 新增:版本号
Usable bool `gorm:"column:usable;not null;default:true;index:idx_tokens_usable" json:"usable"`
IssueDate time.Time `gorm:"column:issue_date;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_tokens_issue_date,sort:desc" json:"issue_date"`
ExpiresAt *time.Time `gorm:"column:expires_at;type:timestamp" json:"expires_at,omitempty"` // 新增:过期时间
StaleAt *time.Time `gorm:"column:stale_at;type:timestamp" json:"stale_at,omitempty"` // 新增:过期但可用时间
// 关联
User *User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE" json:"user,omitempty"`

View File

@@ -0,0 +1,63 @@
package repository
import (
"carrotskin/internal/model"
"gorm.io/gorm"
)
// clientRepository ClientRepository的实现
type clientRepository struct {
db *gorm.DB
}
// NewClientRepository 创建ClientRepository实例
func NewClientRepository(db *gorm.DB) ClientRepository {
return &clientRepository{db: db}
}
func (r *clientRepository) Create(client *model.Client) error {
return r.db.Create(client).Error
}
func (r *clientRepository) FindByClientToken(clientToken string) (*model.Client, error) {
var client model.Client
err := r.db.Where("client_token = ?", clientToken).First(&client).Error
if err != nil {
return nil, err
}
return &client, nil
}
func (r *clientRepository) FindByUUID(uuid string) (*model.Client, error) {
var client model.Client
err := r.db.Where("uuid = ?", uuid).First(&client).Error
if err != nil {
return nil, err
}
return &client, nil
}
func (r *clientRepository) FindByUserID(userID int64) ([]*model.Client, error) {
var clients []*model.Client
err := r.db.Where("user_id = ?", userID).Find(&clients).Error
return clients, err
}
func (r *clientRepository) Update(client *model.Client) error {
return r.db.Save(client).Error
}
func (r *clientRepository) IncrementVersion(clientUUID string) error {
return r.db.Model(&model.Client{}).
Where("uuid = ?", clientUUID).
Update("version", gorm.Expr("version + 1")).Error
}
func (r *clientRepository) DeleteByClientToken(clientToken string) error {
return r.db.Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
}
func (r *clientRepository) DeleteByUserID(userID int64) error {
return r.db.Where("user_id = ?", userID).Delete(&model.Client{}).Error
}

View File

@@ -83,5 +83,14 @@ type YggdrasilRepository interface {
ResetPassword(id int64, password string) error
}
// ClientRepository Client仓储接口
type ClientRepository interface {
Create(client *model.Client) error
FindByClientToken(clientToken string) (*model.Client, error)
FindByUUID(uuid string) (*model.Client, error)
FindByUserID(userID int64) ([]*model.Client, error)
Update(client *model.Client) error
IncrementVersion(clientUUID string) error
DeleteByClientToken(clientToken string) error
DeleteByUserID(userID int64) error
}

View File

@@ -0,0 +1,497 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"context"
"errors"
"fmt"
"strconv"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
)
// tokenServiceJWT TokenService的JWT实现使用JWT + Version机制
type tokenServiceJWT struct {
tokenRepo repository.TokenRepository
clientRepo repository.ClientRepository
profileRepo repository.ProfileRepository
yggdrasilJWT *auth.YggdrasilJWTService
logger *zap.Logger
tokenExpireSec int64 // Token过期时间0表示永不过期
tokenStaleSec int64 // Token过期但可用时间0表示永不过期
}
// NewTokenServiceJWT 创建使用JWT的TokenService实例
func NewTokenServiceJWT(
tokenRepo repository.TokenRepository,
clientRepo repository.ClientRepository,
profileRepo repository.ProfileRepository,
yggdrasilJWT *auth.YggdrasilJWTService,
logger *zap.Logger,
) TokenService {
return &tokenServiceJWT{
tokenRepo: tokenRepo,
clientRepo: clientRepo,
profileRepo: profileRepo,
yggdrasilJWT: yggdrasilJWT,
logger: logger,
tokenExpireSec: 24 * 3600, // 默认24小时
tokenStaleSec: 30 * 24 * 3600, // 默认30天
}
}
// 常量已在 token_service.go 中定义,这里不重复定义
// Create 创建Token使用JWT + Version机制
func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
var (
selectedProfileID *model.Profile
availableProfiles []*model.Profile
)
// 设置超时上下文
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
// 验证用户存在
if UUID != "" {
_, err := s.profileRepo.FindByUUID(UUID)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
}
}
// 生成ClientToken
if clientToken == "" {
clientToken = uuid.New().String()
}
// 获取或创建Client
var client *model.Client
existingClient, err := s.clientRepo.FindByClientToken(clientToken)
if err != nil {
// Client不存在创建新的
clientUUID := uuid.New().String()
client = &model.Client{
UUID: clientUUID,
ClientToken: clientToken,
UserID: userID,
Version: 0,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if UUID != "" {
client.ProfileID = UUID
}
if err := s.clientRepo.Create(client); err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
}
} else {
// Client已存在验证UserID是否匹配
if existingClient.UserID != userID {
return selectedProfileID, availableProfiles, "", "", errors.New("clientToken已属于其他用户")
}
client = existingClient
// 不增加Version只有在刷新时才增加只更新ProfileID和UpdatedAt
client.UpdatedAt = time.Now()
if UUID != "" {
client.ProfileID = UUID
if err := s.clientRepo.Update(client); err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
}
}
}
// 获取用户配置文件
profiles, err := s.profileRepo.FindByUserID(userID)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
}
// 如果用户只有一个配置文件,自动选择
profileID := client.ProfileID
if len(profiles) == 1 {
selectedProfileID = profiles[0]
if profileID == "" {
profileID = selectedProfileID.UUID
client.ProfileID = profileID
s.clientRepo.Update(client)
}
}
availableProfiles = profiles
// 生成Token过期时间
now := time.Now()
var expiresAt, staleAt time.Time
if s.tokenExpireSec > 0 {
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
} else {
// 使用遥远的未来时间类似drasl的DISTANT_FUTURE
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
}
if s.tokenStaleSec > 0 {
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
} else {
staleAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
}
// 生成JWT AccessToken
accessToken, err := s.yggdrasilJWT.GenerateAccessToken(
userID,
client.UUID,
client.Version,
profileID,
expiresAt,
staleAt,
)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("生成AccessToken失败: %w", err)
}
// 保存Token记录用于查询和审计
token := model.Token{
AccessToken: accessToken,
ClientToken: clientToken,
UserID: userID,
ProfileId: profileID,
Version: client.Version,
Usable: true,
IssueDate: now,
ExpiresAt: &expiresAt,
StaleAt: &staleAt,
}
err = s.tokenRepo.Create(&token)
if err != nil {
s.logger.Warn("保存Token记录失败但JWT已生成", zap.Error(err))
// 不返回错误因为JWT本身已经生成成功
}
// 清理多余的令牌
go s.checkAndCleanupExcessTokens(userID)
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
}
// Validate 验证Token使用JWT验证
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
if accessToken == "" {
return false
}
// 解析JWT
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyDeny)
if err != nil {
return false
}
// 查找Client
client, err := s.clientRepo.FindByUUID(claims.Subject)
if err != nil {
return false
}
// 验证Version是否匹配
if claims.Version != client.Version {
return false
}
// 验证ClientToken如果提供
if clientToken != "" && client.ClientToken != clientToken {
return false
}
return true
}
// Refresh 刷新Token使用Version机制无需删除旧Token
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
if accessToken == "" {
return "", "", errors.New("accessToken不能为空")
}
// 解析JWT获取Client信息
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil {
return "", "", errors.New("accessToken无效")
}
// 查找Client
client, err := s.clientRepo.FindByUUID(claims.Subject)
if err != nil {
return "", "", errors.New("无法找到对应的Client")
}
// 验证ClientToken
if clientToken != "" && client.ClientToken != clientToken {
return "", "", errors.New("clientToken无效")
}
// 验证Version必须匹配
if claims.Version != client.Version {
return "", "", errors.New("token版本不匹配请重新登录")
}
// 验证Profile
if selectedProfileID != "" {
valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID)
if validErr != nil {
s.logger.Error("验证Profile失败",
zap.Error(validErr),
zap.Int64("userId", client.UserID),
zap.String("profileId", selectedProfileID),
)
return "", "", fmt.Errorf("验证角色失败: %w", validErr)
}
if !valid {
return "", "", errors.New("角色与用户不匹配")
}
// 检查是否已绑定Profile
if client.ProfileID != "" && client.ProfileID != selectedProfileID {
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
}
client.ProfileID = selectedProfileID
} else {
selectedProfileID = client.ProfileID
}
// 增加Version这是关键通过Version失效所有旧Token
client.Version++
client.UpdatedAt = time.Now()
if err := s.clientRepo.Update(client); err != nil {
return "", "", fmt.Errorf("更新Client版本失败: %w", err)
}
// 生成Token过期时间
now := time.Now()
var expiresAt, staleAt time.Time
if s.tokenExpireSec > 0 {
expiresAt = now.Add(time.Duration(s.tokenExpireSec) * time.Second)
} else {
expiresAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
}
if s.tokenStaleSec > 0 {
staleAt = now.Add(time.Duration(s.tokenStaleSec) * time.Second)
} else {
staleAt = time.Date(2038, 1, 1, 0, 0, 0, 0, time.UTC)
}
// 生成新的JWT AccessToken使用新的Version
newAccessToken, err := s.yggdrasilJWT.GenerateAccessToken(
client.UserID,
client.UUID,
client.Version,
selectedProfileID,
expiresAt,
staleAt,
)
if err != nil {
return "", "", fmt.Errorf("生成新AccessToken失败: %w", err)
}
// 保存新Token记录
newToken := model.Token{
AccessToken: newAccessToken,
ClientToken: client.ClientToken,
UserID: client.UserID,
ProfileId: selectedProfileID,
Version: client.Version,
Usable: true,
IssueDate: now,
ExpiresAt: &expiresAt,
StaleAt: &staleAt,
}
err = s.tokenRepo.Create(&newToken)
if err != nil {
s.logger.Warn("保存新Token记录失败但JWT已生成", zap.Error(err))
}
s.logger.Info("成功刷新Token", zap.Int64("userId", client.UserID), zap.Int("version", client.Version))
return newAccessToken, client.ClientToken, nil
}
// Invalidate 使Token失效通过增加Version
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
if accessToken == "" {
return
}
// 解析JWT获取Client信息
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil {
s.logger.Warn("解析Token失败", zap.Error(err))
return
}
// 查找Client并增加Version
client, err := s.clientRepo.FindByUUID(claims.Subject)
if err != nil {
s.logger.Warn("无法找到对应的Client", zap.Error(err))
return
}
// 增加Version以失效所有旧Token
client.Version++
client.UpdatedAt = time.Now()
if err := s.clientRepo.Update(client); err != nil {
s.logger.Error("失效Token失败", zap.Error(err))
return
}
s.logger.Info("成功失效Token", zap.String("clientUUID", client.UUID), zap.Int("version", client.Version))
}
// InvalidateUserTokens 使用户所有Token失效
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
if userID == 0 {
return
}
// 获取用户所有Client
clients, err := s.clientRepo.FindByUserID(userID)
if err != nil {
s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID))
return
}
// 增加每个Client的Version
for _, client := range clients {
client.Version++
client.UpdatedAt = time.Now()
if err := s.clientRepo.Update(client); err != nil {
s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID))
}
}
s.logger.Info("成功失效用户所有Token", zap.Int64("userId", userID), zap.Int("clientCount", len(clients)))
}
// GetUUIDByAccessToken 从AccessToken获取UUID通过JWT解析
func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil {
// 如果JWT解析失败尝试从数据库查询向后兼容
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
}
if claims.ProfileID != "" {
return claims.ProfileID, nil
}
// 如果没有ProfileID从Client获取
client, err := s.clientRepo.FindByUUID(claims.Subject)
if err != nil {
return "", fmt.Errorf("无法找到对应的Client: %w", err)
}
if client.ProfileID != "" {
return client.ProfileID, nil
}
return "", errors.New("无法从Token中获取UUID")
}
// GetUserIDByAccessToken 从AccessToken获取UserID通过JWT解析
func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil {
// 如果JWT解析失败尝试从数据库查询向后兼容
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
}
// 从Client获取UserID
client, err := s.clientRepo.FindByUUID(claims.Subject)
if err != nil {
return 0, fmt.Errorf("无法找到对应的Client: %w", err)
}
// 验证Version
if claims.Version != client.Version {
return 0, errors.New("token版本不匹配")
}
return client.UserID, nil
}
// 私有辅助方法
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
if userID == 0 {
return
}
tokens, err := s.tokenRepo.GetByUserID(userID)
if err != nil {
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return
}
if len(tokens) <= tokensMaxCount {
return
}
tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount)
for i := tokensMaxCount; i < len(tokens); i++ {
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
}
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
if err != nil {
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return
}
if deletedCount > 0 {
s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount))
}
}
func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) {
if userID == 0 || UUID == "" {
return false, errors.New("用户ID或配置文件ID不能为空")
}
profile, err := s.profileRepo.FindByUUID(UUID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return false, errors.New("配置文件不存在")
}
return false, fmt.Errorf("验证配置文件失败: %w", err)
}
return profile.UserID == userID, nil
}
// GetClientFromToken 从Token获取Client信息辅助方法
func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken string, stalePolicy auth.StaleTokenPolicy) (*model.Client, error) {
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, stalePolicy)
if err != nil {
return nil, err
}
client, err := s.clientRepo.FindByUUID(claims.Subject)
if err != nil {
return nil, err
}
// 验证Version
if claims.Version != client.Version {
return nil, errors.New("token版本不匹配")
}
return client, nil
}