Files
backend/pkg/auth/yggdrasil_jwt.go

220 lines
5.7 KiB
Go
Raw Normal View History

package auth
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
// YggdrasilPrivateKeyRedisKey is the Redis key holding the PEM-encoded RSA
// private key used to sign Yggdrasil access tokens.
const YggdrasilPrivateKeyRedisKey = "yggdrasil:private_key"
// RedisClient is the small subset of Redis operations this package relies on.
// Declaring it as an interface lets tests substitute an in-memory fake.
type RedisClient interface {
	// Get reads the string stored under key (semantics defined by the implementation).
	Get(ctx context.Context, key string) (string, error)
	// Set stores value under key with the given expiration (semantics defined by the implementation).
	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
}
// YggdrasilJWTService issues and verifies Yggdrasil access-token JWTs using an
// RSA key pair (tokens are signed with RS512).
type YggdrasilJWTService struct {
	privateKey *rsa.PrivateKey // signs newly issued tokens
	publicKey  *rsa.PublicKey  // verifies token signatures during parsing
	issuer     string          // written into the "iss" claim of issued tokens
}
// NewYggdrasilJWTService returns a service that signs with privateKey and
// verifies with its public half. An empty issuer defaults to "carrotskin".
//
// privateKey must be non-nil: taking the address of its PublicKey field
// panics on a nil pointer.
func NewYggdrasilJWTService(privateKey *rsa.PrivateKey, issuer string) *YggdrasilJWTService {
	svc := &YggdrasilJWTService{
		privateKey: privateKey,
		issuer:     issuer,
	}
	if svc.issuer == "" {
		svc.issuer = "carrotskin"
	}
	svc.publicKey = &privateKey.PublicKey
	return svc
}
// YggdrasilTokenClaims holds the claims carried by a Yggdrasil access-token JWT.
type YggdrasilTokenClaims struct {
	Version   int    `json:"version"`              // version number, used to invalidate older tokens
	UserID    int64  `json:"user_id"`              // user ID
	ProfileID string `json:"profile_id,omitempty"` // UUID of the selected profile
	// StaleAt is the hard deadline after which the token is rejected even
	// under StalePolicyAllow. It is omitted when no deadline was supplied
	// (including tokens issued before this claim existed); such tokens are
	// unusable once "exp" passes.
	StaleAt *jwt.NumericDate `json:"stale_at,omitempty"`
	jwt.RegisteredClaims
}

// StaleTokenPolicy controls how ParseAccessToken treats a token whose
// expiration time ("exp") has already passed.
type StaleTokenPolicy int

const (
	// StalePolicyAllow accepts a token past "exp" as long as its StaleAt
	// deadline has not been reached yet.
	StalePolicyAllow StaleTokenPolicy = iota
	// StalePolicyDeny rejects any token past "exp".
	StalePolicyDeny
)

// GenerateAccessToken signs an RS512 access-token JWT.
//
// expiresAt becomes the standard "exp" claim. staleAt becomes the custom
// "stale_at" claim, which bounds how long an expired token is still accepted
// under StalePolicyAllow; a zero staleAt omits the claim.
func (j *YggdrasilJWTService) GenerateAccessToken(
	userID int64,
	clientUUID string,
	version int,
	profileID string,
	expiresAt time.Time,
	staleAt time.Time,
) (string, error) {
	now := time.Now()
	claims := YggdrasilTokenClaims{
		Version:   version,
		UserID:    userID,
		ProfileID: profileID,
		RegisteredClaims: jwt.RegisteredClaims{
			Subject:   clientUUID,
			IssuedAt:  jwt.NewNumericDate(now),
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			NotBefore: jwt.NewNumericDate(now),
			Issuer:    j.issuer,
		},
	}
	if !staleAt.IsZero() {
		claims.StaleAt = jwt.NewNumericDate(staleAt)
	}
	token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims)
	return token.SignedString(j.privateKey)
}

// ParseAccessToken verifies the signature of accessToken and returns its claims.
//
// Time-based checks:
//   - "nbf" in the future: rejected (error wraps jwt.ErrTokenNotValidYet).
//   - "exp" passed: rejected (error wraps jwt.ErrTokenExpired), unless
//     stalePolicy is StalePolicyAllow and the token's StaleAt is still in
//     the future.
func (j *YggdrasilJWTService) ParseAccessToken(accessToken string, stalePolicy StaleTokenPolicy) (*YggdrasilTokenClaims, error) {
	// The library's default claim validation is skipped on purpose. It would
	// reject every token past "exp", which makes StalePolicyAllow unreachable.
	// The signature is still verified; the time checks are done by hand below.
	token, err := jwt.ParseWithClaims(accessToken, &YggdrasilTokenClaims{}, func(t *jwt.Token) (interface{}, error) {
		// Only RSA methods are accepted, so the RSA public key is never
		// handed to a different (e.g. HMAC or "none") verifier.
		if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, errors.New("不支持的签名算法需要使用RSA")
		}
		return j.publicKey, nil
	}, jwt.WithoutClaimsValidation())
	if err != nil {
		return nil, err
	}
	if !token.Valid {
		return nil, errors.New("无效的token")
	}
	claims, ok := token.Claims.(*YggdrasilTokenClaims)
	if !ok {
		return nil, errors.New("无法解析token声明")
	}

	now := time.Now()
	if claims.NotBefore != nil && now.Before(claims.NotBefore.Time) {
		return nil, fmt.Errorf("token尚未生效: %w", jwt.ErrTokenNotValidYet)
	}
	// Match the library's rule: a token is expired once now is not before exp.
	if claims.ExpiresAt != nil && !now.Before(claims.ExpiresAt.Time) {
		withinGrace := stalePolicy == StalePolicyAllow &&
			claims.StaleAt != nil && now.Before(claims.StaleAt.Time)
		if !withinGrace {
			return nil, fmt.Errorf("token已过期: %w", jwt.ErrTokenExpired)
		}
	}
	return claims, nil
}
// GetPublicKey returns the RSA public key this service uses to verify
// token signatures.
func (j *YggdrasilJWTService) GetPublicKey() *rsa.PublicKey {
	verifier := j.publicKey
	return verifier
}
// YggdrasilJWTManager lazily builds and caches a YggdrasilJWTService. The
// signing key is either installed directly or loaded from Redis on first use.
type YggdrasilJWTManager struct {
	redisClient RedisClient          // source of the PEM-encoded private key
	jwtService  *YggdrasilJWTService // cached service; nil until built
	privateKey  *rsa.PrivateKey      // key backing jwtService, if known
	mu          sync.RWMutex         // guards jwtService and privateKey
}
// NewYggdrasilJWTManager returns a manager that will load its signing key
// through redisClient the first time a service is requested.
func NewYggdrasilJWTManager(redisClient RedisClient) *YggdrasilJWTManager {
	mgr := new(YggdrasilJWTManager)
	mgr.redisClient = redisClient
	return mgr
}
// GetJWTService returns the cached service, building it on first use from
// the private key (one installed via SetPrivateKey, or else the one stored
// in Redis). It is safe for concurrent use. The common path takes only a
// read lock; the build path re-checks under the write lock so concurrent
// callers do not build twice. A failed load is not cached, so the next call
// retries.
func (m *YggdrasilJWTManager) GetJWTService() (*YggdrasilJWTService, error) {
	m.mu.RLock()
	cached := m.jwtService
	m.mu.RUnlock()
	if cached != nil {
		return cached, nil
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	// Another goroutine may have built the service while we waited for the lock.
	if m.jwtService == nil {
		key, err := m.getPrivateKeyFromRedis()
		if err != nil {
			return nil, fmt.Errorf("获取私钥失败: %w", err)
		}
		m.privateKey = key
		m.jwtService = NewYggdrasilJWTService(key, "carrotskin")
	}
	return m.jwtService, nil
}
// SetPrivateKey installs privateKey directly, e.g. in tests or when the key
// comes from signatureService. It rebuilds the cached service from the new key.
//
// NOTE(review): passing nil clears m.privateKey but keeps any previously
// built service, so GetJWTService keeps returning it. Confirm this is intended.
func (m *YggdrasilJWTManager) SetPrivateKey(privateKey *rsa.PrivateKey) {
	m.mu.Lock()
	defer m.mu.Unlock()

	m.privateKey = privateKey
	if privateKey == nil {
		return
	}
	m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin")
}
// getPrivateKeyFromRedis returns m.privateKey if it is already set.
// Otherwise it loads the PEM text stored under YggdrasilPrivateKeyRedisKey
// and parses it as an RSA private key. PKCS#1 ("RSA PRIVATE KEY") is tried
// first; PKCS#8 ("PRIVATE KEY") holding an RSA key is also accepted.
//
// It reads m.privateKey without locking. The visible caller (GetJWTService)
// holds m.mu's write lock.
func (m *YggdrasilJWTManager) getPrivateKeyFromRedis() (*rsa.PrivateKey, error) {
	if m.privateKey != nil {
		return m.privateKey, nil
	}
	if m.redisClient == nil {
		return nil, errors.New("从Redis获取私钥失败: Redis客户端未配置")
	}

	ctx := context.Background()
	privateKeyPEM, err := m.redisClient.Get(ctx, YggdrasilPrivateKeyRedisKey)
	if err != nil {
		return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
	}
	// Handled separately: wrapping a nil err with %w would yield "%!w(<nil>)".
	if privateKeyPEM == "" {
		return nil, errors.New("从Redis获取私钥失败: 私钥为空")
	}

	block, _ := pem.Decode([]byte(privateKeyPEM))
	if block == nil {
		return nil, errors.New("解析PEM私钥失败")
	}
	privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err == nil {
		return privateKey, nil
	}
	// Fall back to PKCS#8, accepting only RSA keys.
	if parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes); pkcs8Err == nil {
		if rsaKey, ok := parsed.(*rsa.PrivateKey); ok {
			return rsaKey, nil
		}
	}
	return nil, fmt.Errorf("解析RSA私钥失败: %w", err)
}
// GenerateKeyPair creates a fresh 2048-bit RSA private key (its public half is
// embedded). Intended for tests.
func GenerateKeyPair() (*rsa.PrivateKey, error) {
	const bits = 2048
	key, err := rsa.GenerateKey(rand.Reader, bits)
	if err != nil {
		return nil, err
	}
	return key, nil
}
// EncodePrivateKeyToPEM encodes privateKey as a PKCS#1 PEM block of type
// "RSA PRIVATE KEY". Intended for tests.
//
// A nil key returns an error; without this check, x509.MarshalPKCS1PrivateKey
// would panic.
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) (string, error) {
	if privateKey == nil {
		return "", errors.New("私钥不能为空")
	}
	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	}
	return string(pem.EncodeToMemory(block)), nil
}