package auth import ( "context" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "fmt" "sync" "time" "github.com/golang-jwt/jwt/v5" ) const ( YggdrasilPrivateKeyRedisKey = "yggdrasil:private_key" ) // RedisClient 定义Redis客户端接口(用于测试) type RedisClient interface { Get(ctx context.Context, key string) (string, error) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error } // YggdrasilJWTService Yggdrasil JWT服务(使用RSA512) type YggdrasilJWTService struct { privateKey *rsa.PrivateKey publicKey *rsa.PublicKey issuer string } // NewYggdrasilJWTService 创建新的Yggdrasil JWT服务 func NewYggdrasilJWTService(privateKey *rsa.PrivateKey, issuer string) *YggdrasilJWTService { if issuer == "" { issuer = "carrotskin" } return &YggdrasilJWTService{ privateKey: privateKey, publicKey: &privateKey.PublicKey, issuer: issuer, } } // YggdrasilTokenClaims Yggdrasil Token声明 type YggdrasilTokenClaims struct { Version int `json:"version"` // 版本号,用于失效旧Token UserID int64 `json:"user_id"` // 用户ID ProfileID string `json:"profile_id,omitempty"` // 选中的Profile UUID jwt.RegisteredClaims } // StaleTokenPolicy Token过期策略 type StaleTokenPolicy int const ( StalePolicyAllow StaleTokenPolicy = iota // 允许过期的Token(但未过StaleAt) StalePolicyDeny // 拒绝过期的Token ) // GenerateAccessToken 生成AccessToken JWT func (j *YggdrasilJWTService) GenerateAccessToken( userID int64, clientUUID string, version int, profileID string, expiresAt time.Time, staleAt time.Time, ) (string, error) { claims := YggdrasilTokenClaims{ Version: version, UserID: userID, ProfileID: profileID, RegisteredClaims: jwt.RegisteredClaims{ Subject: clientUUID, IssuedAt: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(expiresAt), NotBefore: jwt.NewNumericDate(time.Now()), Issuer: j.issuer, }, } token := jwt.NewWithClaims(jwt.SigningMethodRS512, claims) return token.SignedString(j.privateKey) } // ParseAccessToken 解析AccessToken JWT func (j *YggdrasilJWTService) ParseAccessToken(accessToken string, stalePolicy StaleTokenPolicy) (*YggdrasilTokenClaims, error) { token, err := jwt.ParseWithClaims(accessToken, &YggdrasilTokenClaims{}, func(token *jwt.Token) (interface{}, error) { // 验证签名算法 if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, errors.New("不支持的签名算法,需要使用RSA") } return j.publicKey, nil }) if err != nil { return nil, err } if !token.Valid { return nil, errors.New("无效的token") } claims, ok := token.Claims.(*YggdrasilTokenClaims) if !ok { return nil, errors.New("无法解析token声明") } // 检查StaleAt(如果设置了拒绝过期策略) if stalePolicy == StalePolicyDeny && claims.ExpiresAt != nil { if time.Now().After(claims.ExpiresAt.Time) { return nil, errors.New("token已过期") } } return claims, nil } // GetPublicKey 获取公钥 func (j *YggdrasilJWTService) GetPublicKey() *rsa.PublicKey { return j.publicKey } // YggdrasilJWTManager Yggdrasil JWT管理器,用于获取或创建JWT服务 type YggdrasilJWTManager struct { redisClient RedisClient jwtService *YggdrasilJWTService privateKey *rsa.PrivateKey mu sync.RWMutex } // NewYggdrasilJWTManager 创建Yggdrasil JWT管理器 func NewYggdrasilJWTManager(redisClient RedisClient) *YggdrasilJWTManager { return &YggdrasilJWTManager{ redisClient: redisClient, } } // GetJWTService 获取或创建Yggdrasil JWT服务(线程安全) func (m *YggdrasilJWTManager) GetJWTService() (*YggdrasilJWTService, error) { m.mu.RLock() if m.jwtService != nil { service := m.jwtService m.mu.RUnlock() return service, nil } m.mu.RUnlock() m.mu.Lock() defer m.mu.Unlock() // 双重检查 if m.jwtService != nil { return m.jwtService, nil } // 从Redis获取私钥 privateKey, err := m.getPrivateKeyFromRedis() if err != nil { return nil, fmt.Errorf("获取私钥失败: %w", err) } m.privateKey = privateKey m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin") return m.jwtService, nil } // SetPrivateKey 直接设置私钥(用于测试或直接从signatureService获取) func (m *YggdrasilJWTManager) SetPrivateKey(privateKey *rsa.PrivateKey) { m.mu.Lock() defer m.mu.Unlock() m.privateKey = privateKey if privateKey != nil { m.jwtService = NewYggdrasilJWTService(privateKey, "carrotskin") } } // getPrivateKeyFromRedis 从Redis获取私钥 func (m *YggdrasilJWTManager) getPrivateKeyFromRedis() (*rsa.PrivateKey, error) { if m.privateKey != nil { return m.privateKey, nil } ctx := context.Background() privateKeyPEM, err := m.redisClient.Get(ctx, YggdrasilPrivateKeyRedisKey) if err != nil || privateKeyPEM == "" { return nil, fmt.Errorf("从Redis获取私钥失败: %w", err) } // 解析PEM格式的私钥 block, _ := pem.Decode([]byte(privateKeyPEM)) if block == nil { return nil, fmt.Errorf("解析PEM私钥失败") } privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("解析RSA私钥失败: %w", err) } return privateKey, nil } // GenerateKeyPair 生成RSA密钥对(用于测试) func GenerateKeyPair() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, 2048) } // EncodePrivateKeyToPEM 将私钥编码为PEM格式(用于测试) func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) (string, error) { privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) privateKeyPEM := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes, }) return string(privateKeyPEM), nil }