220 lines
5.7 KiB
Go
220 lines
5.7 KiB
Go
|
|
package auth
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"crypto/rand"
|
|||
|
|
"crypto/rsa"
|
|||
|
|
"crypto/x509"
|
|||
|
|
"encoding/pem"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"sync"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"github.com/golang-jwt/jwt/v5"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// YggdrasilPrivateKeyRedisKey is the Redis key under which the PEM-encoded
// Yggdrasil RSA signing key is stored.
const YggdrasilPrivateKeyRedisKey = "yggdrasil:private_key"
|
|||
|
|
|
|||
|
|
// RedisClient is the minimal Redis surface this package needs. It is an
// interface so tests can substitute a fake implementation.
type RedisClient interface {
	// Get reads the string value stored under key.
	Get(ctx context.Context, key string) (string, error)

	// Set writes value under key; expiration is passed through to the
	// implementation.
	Set(ctx context.Context, key string, value any, expiration time.Duration) error
}
|
|||
|
|
|
|||
|
|
// YggdrasilJWTService issues and verifies Yggdrasil access tokens as
// RS512-signed JWTs.
type YggdrasilJWTService struct {
	privateKey *rsa.PrivateKey // signs new tokens
	publicKey  *rsa.PublicKey  // verifies token signatures
	issuer     string          // written to the "iss" claim
}
|
|||
|
|
|
|||
|
|
// NewYggdrasilJWTService 创建新的Yggdrasil JWT服务
|
|||
|
|
func NewYggdrasilJWTService(privateKey *rsa.PrivateKey, issuer string) *YggdrasilJWTService {
|
|||
|
|
if issuer == "" {
|
|||
|
|
issuer = "carrotskin"
|
|||
|
|
}
|
|||
|
|
return &YggdrasilJWTService{
|
|||
|
|
privateKey: privateKey,
|
|||
|
|
publicKey: &privateKey.PublicKey,
|
|||
|
|
issuer: issuer,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// YggdrasilTokenClaims Yggdrasil Token声明
|
|||
|
|
type YggdrasilTokenClaims struct {
|
|||
|
|
Version int `json:"version"` // 版本号,用于失效旧Token
|
|||
|
|
UserID int64 `json:"user_id"` // 用户ID
|
|||
|
|
ProfileID string `json:"profile_id,omitempty"` // 选中的Profile UUID
|
|||
|
|
jwt.RegisteredClaims
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// StaleTokenPolicy selects how ParseAccessToken treats expired tokens.
type StaleTokenPolicy int

// Stale-token policies understood by ParseAccessToken.
const (
	// StalePolicyAllow is intended to accept expired tokens that are not yet
	// past their StaleAt time.
	// NOTE(review): YggdrasilTokenClaims carries no StaleAt claim, so this
	// bound cannot currently be enforced — confirm intended semantics.
	StalePolicyAllow StaleTokenPolicy = 0
	// StalePolicyDeny rejects expired tokens.
	StalePolicyDeny StaleTokenPolicy = 1
)
|
|||
|
|
|
|||
|
|
// GenerateAccessToken 生成AccessToken JWT
|
|||
|
|
func (j *YggdrasilJWTService) GenerateAccessToken(
|
|||
|
|
userID int64,
|
|||
|
|
clientUUID string,
|
|||
|
|
version int,
|
|||
|
|
profileID string,
|
|||
|
|
expiresAt time.Time,
|
|||
|
|
staleAt time.Time,
|
|||
|
|
) (string, error) {
|
|||
|
|
claims := YggdrasilTokenClaims{
|
|||
|
|
Version: version,
|
|||
|
|
UserID: userID,
|
|||
|
|
ProfileID: profileID,
|
|||
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|||
|
|
Subject: clientUUID,
|
|||
|
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
|||
|
|
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|||
|
|
NotBefore: jwt.NewNumericDate(time.Now()),
|
|||
|
|
Issuer: j.issuer,
|
|||
|
|
},
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims)
|
|||
|
|
return token.SignedString(j.privateKey)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ParseAccessToken 解析AccessToken JWT
|
|||
|
|
func (j *YggdrasilJWTService) ParseAccessToken(accessToken string, stalePolicy StaleTokenPolicy) (*YggdrasilTokenClaims, error) {
|
|||
|
|
token, err := jwt.ParseWithClaims(accessToken, &YggdrasilTokenClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|||
|
|
// 验证签名算法
|
|||
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|||
|
|
return nil, errors.New("不支持的签名算法,需要使用RSA")
|
|||
|
|
}
|
|||
|
|
return j.publicKey, nil
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if !token.Valid {
|
|||
|
|
return nil, errors.New("无效的token")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
claims, ok := token.Claims.(*YggdrasilTokenClaims)
|
|||
|
|
if !ok {
|
|||
|
|
return nil, errors.New("无法解析token声明")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查StaleAt(如果设置了拒绝过期策略)
|
|||
|
|
if stalePolicy == StalePolicyDeny && claims.ExpiresAt != nil {
|
|||
|
|
if time.Now().After(claims.ExpiresAt.Time) {
|
|||
|
|
return nil, errors.New("token已过期")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return claims, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetPublicKey 获取公钥
|
|||
|
|
func (j *YggdrasilJWTService) GetPublicKey() *rsa.PublicKey {
|
|||
|
|
return j.publicKey
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// YggdrasilJWTManager Yggdrasil JWT管理器,用于获取或创建JWT服务
|
|||
|
|
type YggdrasilJWTManager struct {
|
|||
|
|
redisClient RedisClient
|
|||
|
|
jwtService *YggdrasilJWTService
|
|||
|
|
privateKey *rsa.PrivateKey
|
|||
|
|
mu sync.RWMutex
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewYggdrasilJWTManager 创建Yggdrasil JWT管理器
|
|||
|
|
func NewYggdrasilJWTManager(redisClient RedisClient) *YggdrasilJWTManager {
|
|||
|
|
return &YggdrasilJWTManager{
|
|||
|
|
redisClient: redisClient,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetJWTService 获取或创建Yggdrasil JWT服务(线程安全)
|
|||
|
|
func (m *YggdrasilJWTManager) GetJWTService() (*YggdrasilJWTService, error) {
|
|||
|
|
m.mu.RLock()
|
|||
|
|
if m.jwtService != nil {
|
|||
|
|
service := m.jwtService
|
|||
|
|
m.mu.RUnlock()
|
|||
|
|
return service, nil
|
|||
|
|
}
|
|||
|
|
m.mu.RUnlock()
|
|||
|
|
|
|||
|
|
m.mu.Lock()
|
|||
|
|
defer m.mu.Unlock()
|
|||
|
|
|
|||
|
|
// 双重检查
|
|||
|
|
if m.jwtService != nil {
|
|||
|
|
return m.jwtService, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 从Redis获取私钥
|
|||
|
|
privateKey, err := m.getPrivateKeyFromRedis()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("获取私钥失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
m.privateKey = privateKey
|
|||
|
|
m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin")
|
|||
|
|
return m.jwtService, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SetPrivateKey 直接设置私钥(用于测试或直接从signatureService获取)
|
|||
|
|
func (m *YggdrasilJWTManager) SetPrivateKey(privateKey *rsa.PrivateKey) {
|
|||
|
|
m.mu.Lock()
|
|||
|
|
defer m.mu.Unlock()
|
|||
|
|
m.privateKey = privateKey
|
|||
|
|
if privateKey != nil {
|
|||
|
|
m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getPrivateKeyFromRedis 从Redis获取私钥
|
|||
|
|
func (m *YggdrasilJWTManager) getPrivateKeyFromRedis() (*rsa.PrivateKey, error) {
|
|||
|
|
if m.privateKey != nil {
|
|||
|
|
return m.privateKey, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
ctx := context.Background()
|
|||
|
|
privateKeyPEM, err := m.redisClient.Get(ctx, YggdrasilPrivateKeyRedisKey)
|
|||
|
|
if err != nil || privateKeyPEM == "" {
|
|||
|
|
return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 解析PEM格式的私钥
|
|||
|
|
block, _ := pem.Decode([]byte(privateKeyPEM))
|
|||
|
|
if block == nil {
|
|||
|
|
return nil, fmt.Errorf("解析PEM私钥失败")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("解析RSA私钥失败: %w", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return privateKey, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GenerateKeyPair creates a fresh 2048-bit RSA key pair using crypto/rand.
// It is intended for tests.
func GenerateKeyPair() (*rsa.PrivateKey, error) {
	const keyBits = 2048
	key, err := rsa.GenerateKey(rand.Reader, keyBits)
	return key, err
}
|
|||
|
|
|
|||
|
|
// EncodePrivateKeyToPEM serializes privateKey as PKCS#1 DER wrapped in a PEM
// block of type "RSA PRIVATE KEY". This is the form getPrivateKeyFromRedis
// parses. It is intended for tests.
//
// A nil key returns an error rather than panicking inside
// x509.MarshalPKCS1PrivateKey.
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) (string, error) {
	if privateKey == nil {
		return "", errors.New("私钥不能为空")
	}

	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	}
	return string(pem.EncodeToMemory(block)), nil
}
|