Files
backend/pkg/auth/yggdrasil_jwt_test.go
lan 4824a997dd feat: 增强令牌管理与客户端仓库集成
新增 ClientRepository 接口,用于管理客户端相关操作。
更新 Token 模型,加入版本号和过期时间字段,以提升令牌管理能力。
将 ClientRepo 集成到容器中,支持依赖注入。
重构 TokenService,采用 JWT 以增强安全性。
更新 Docker 配置,并清理多个文件中的空白字符。
2025-12-03 14:43:38 +08:00

554 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package auth
import (
	"context"
	"crypto/rsa"
	"errors"
	"fmt"
	"testing"
	"time"

	"github.com/redis/go-redis/v9"
)
// MockRedisClient is an in-memory stand-in for the Redis client handed to
// NewYggdrasilJWTManager in these tests. Values live in a plain map, and a
// non-nil err is returned by Get/Set to simulate a failing connection.
type MockRedisClient struct {
	data map[string]string // key/value store backing Get/Set
	err  error             // injected failure; nil means calls succeed
}

// NewMockRedisClient returns a mock with an empty store and no injected error.
func NewMockRedisClient() *MockRedisClient {
	m := new(MockRedisClient)
	m.data = map[string]string{}
	return m
}
// Get returns the value stored under key. An injected error takes precedence,
// and a missing key yields redis.Nil, as the real go-redis client does.
func (m *MockRedisClient) Get(ctx context.Context, key string) (string, error) {
	if err := m.err; err != nil {
		return "", err
	}
	val, found := m.data[key]
	if !found {
		return "", redis.Nil
	}
	return val, nil
}
// Set stores value under key. The expiration argument is ignored: entries in
// the mock never expire. If an error has been injected, it is returned and
// nothing is stored.
//
// Only string and []byte values are accepted, since those map losslessly onto
// the mock's string store. Any other type yields a descriptive error instead
// of the runtime panic a bare type assertion would cause.
func (m *MockRedisClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
	if m.err != nil {
		return m.err
	}
	switch v := value.(type) {
	case string:
		m.data[key] = v
	case []byte:
		m.data[key] = string(v)
	default:
		return fmt.Errorf("mock redis: unsupported value type %T for key %q", value, key)
	}
	return nil
}
// SetError injects err so that subsequent Get/Set calls fail with it until it
// is removed by ClearError or Clear.
func (m *MockRedisClient) SetError(err error) { m.err = err }

// ClearError removes any injected error so calls succeed again.
func (m *MockRedisClient) ClearError() { m.err = nil }
// SetData seeds the store directly, bypassing any injected error.
func (m *MockRedisClient) SetData(key, value string) {
	m.data[key] = value
}

// Clear resets the mock to an empty store with no injected error.
func (m *MockRedisClient) Clear() {
	m.data = map[string]string{}
	m.err = nil
}
// generateTestKeyPair creates a fresh RSA private key via GenerateKeyPair and
// aborts the calling test (or benchmark) with Fatalf if generation fails.
//
// It accepts testing.TB, which both *testing.T and *testing.B satisfy, so
// existing *testing.T callers are unaffected. It marks itself as a helper so
// that a failure is reported at the caller's line rather than here.
func generateTestKeyPair(t testing.TB) *rsa.PrivateKey {
	t.Helper()
	privateKey, err := GenerateKeyPair()
	if err != nil {
		t.Fatalf("生成密钥对失败: %v", err)
	}
	return privateKey
}
// TestNewYggdrasilJWTService checks the constructor. A custom issuer is kept,
// an empty issuer is replaced by "carrotskin", and both the private and public
// key fields are populated.
func TestNewYggdrasilJWTService(t *testing.T) {
	key := generateTestKeyPair(t)

	cases := []struct {
		name, issuer, wantIssuer string
	}{
		{name: "自定义issuer", issuer: "test-issuer", wantIssuer: "test-issuer"},
		{name: "空issuer使用默认值", issuer: "", wantIssuer: "carrotskin"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			svc := NewYggdrasilJWTService(key, tc.issuer)
			if svc == nil {
				t.Fatal("服务创建失败")
			}
			if got := svc.issuer; got != tc.wantIssuer {
				t.Errorf("期望issuer为 %s实际为 %s", tc.wantIssuer, got)
			}
			if svc.privateKey == nil {
				t.Error("私钥不应为nil")
			}
			if svc.publicKey == nil {
				t.Error("公钥不应为nil")
			}
		})
	}
}
func TestYggdrasilJWTService_GenerateAccessToken(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
userID := int64(123)
clientUUID := "test-client-uuid"
version := 1
profileID := "test-profile-uuid"
expiresAt := time.Now().Add(24 * time.Hour)
staleAt := time.Now().Add(30 * 24 * time.Hour)
token, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
if token == "" {
t.Error("Token不应为空")
}
// 验证Token可以解析
claims, err := service.ParseAccessToken(token, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token失败: %v", err)
}
if claims.UserID != userID {
t.Errorf("期望UserID为 %d实际为 %d", userID, claims.UserID)
}
if claims.Subject != clientUUID {
t.Errorf("期望Subject为 %s实际为 %s", clientUUID, claims.Subject)
}
if claims.Version != version {
t.Errorf("期望Version为 %d实际为 %d", version, claims.Version)
}
if claims.ProfileID != profileID {
t.Errorf("期望ProfileID为 %s实际为 %s", profileID, claims.ProfileID)
}
if claims.Issuer != "test-issuer" {
t.Errorf("期望Issuer为 test-issuer实际为 %s", claims.Issuer)
}
}
func TestYggdrasilJWTService_ParseAccessToken(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
userID := int64(123)
clientUUID := "test-client-uuid"
version := 1
profileID := "test-profile-uuid"
expiresAt := time.Now().Add(24 * time.Hour)
staleAt := time.Now().Add(30 * 24 * time.Hour)
// 生成Token
token, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
tests := []struct {
name string
token string
policy StaleTokenPolicy
expectError bool
}{
{
name: "有效Token允许过期",
token: token,
policy: StalePolicyAllow,
expectError: false,
},
{
name: "有效Token拒绝过期",
token: token,
policy: StalePolicyDeny,
expectError: false,
},
{
name: "无效Token",
token: "invalid-token",
policy: StalePolicyAllow,
expectError: true,
},
{
name: "空Token",
token: "",
policy: StalePolicyAllow,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims, err := service.ParseAccessToken(tt.token, tt.policy)
if tt.expectError {
if err == nil {
t.Error("期望出现错误,但没有错误")
}
if claims != nil {
t.Error("期望claims为nil")
}
} else {
if err != nil {
t.Errorf("不期望出现错误,但出现: %v", err)
}
if claims == nil {
t.Error("claims不应为nil")
}
}
})
}
}
// TestYggdrasilJWTService_ParseAccessToken_Expired checks how ParseAccessToken
// treats a token whose expiry is already in the past. The token's staleAt is
// still 30 days ahead.
func TestYggdrasilJWTService_ParseAccessToken_Expired(t *testing.T) {
	privateKey := generateTestKeyPair(t)
	service := NewYggdrasilJWTService(privateKey, "test-issuer")

	// Issue a token that expired one hour ago.
	expiresAt := time.Now().Add(-1 * time.Hour)
	staleAt := time.Now().Add(30 * 24 * time.Hour)
	token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid", expiresAt, staleAt)
	if err != nil {
		t.Fatalf("生成Token失败: %v", err)
	}

	// StalePolicyDeny must reject the expired token.
	_, err = service.ParseAccessToken(token, StalePolicyDeny)
	if err == nil {
		t.Error("期望拒绝过期Token但没有错误")
	}

	// The original authors note that the JWT library validates expiry during
	// parsing regardless of the stale policy, so StalePolicyAllow is also
	// expected to fail here. This check is informational only, not an
	// assertion. Log only when parsing fails, since that is the case the note
	// describes.
	_, err = service.ParseAccessToken(token, StalePolicyAllow)
	if err != nil {
		t.Log("注意JWT库会自动拒绝过期Token即使使用StalePolicyAllow", err)
	}
}
func TestYggdrasilJWTService_ParseAccessToken_WrongKey(t *testing.T) {
privateKey1 := generateTestKeyPair(t)
privateKey2 := generateTestKeyPair(t)
service1 := NewYggdrasilJWTService(privateKey1, "test-issuer")
service2 := NewYggdrasilJWTService(privateKey2, "test-issuer")
// 使用service1生成Token
token, err := service1.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 使用service2不同密钥解析Token应该失败
_, err = service2.ParseAccessToken(token, StalePolicyAllow)
if err == nil {
t.Error("期望使用错误密钥解析Token失败但没有错误")
}
}
// TestYggdrasilJWTService_GetPublicKey checks that GetPublicKey returns the
// public half of the private key the service was constructed with.
func TestYggdrasilJWTService_GetPublicKey(t *testing.T) {
	privateKey := generateTestKeyPair(t)
	service := NewYggdrasilJWTService(privateKey, "test-issuer")

	publicKey := service.GetPublicKey()
	if publicKey == nil {
		t.Fatal("公钥不应为nil")
	}
	// An RSA public key is the pair (N, E). Compare both via Equal rather than
	// only the modulus N.
	if !privateKey.PublicKey.Equal(publicKey) {
		t.Error("公钥与私钥不匹配")
	}
}
// TestNewYggdrasilJWTManager checks that the constructor returns a manager
// wired to the supplied Redis client.
func TestNewYggdrasilJWTManager(t *testing.T) {
	redisMock := NewMockRedisClient()
	m := NewYggdrasilJWTManager(redisMock)
	switch {
	case m == nil:
		t.Fatal("管理器创建失败")
	case m.redisClient != redisMock:
		t.Error("Redis客户端未正确设置")
	}
}
func TestYggdrasilJWTManager_SetPrivateKey(t *testing.T) {
mockRedis := NewMockRedisClient()
manager := NewYggdrasilJWTManager(mockRedis)
privateKey := generateTestKeyPair(t)
manager.SetPrivateKey(privateKey)
// 验证JWT服务已创建
service, err := manager.GetJWTService()
if err != nil {
t.Fatalf("获取JWT服务失败: %v", err)
}
if service == nil {
t.Fatal("JWT服务不应为nil")
}
// 验证服务可以正常工作
if service.GetPublicKey() == nil {
t.Error("公钥不应为nil")
}
}
func TestYggdrasilJWTManager_GetJWTService_FromPrivateKey(t *testing.T) {
mockRedis := NewMockRedisClient()
manager := NewYggdrasilJWTManager(mockRedis)
privateKey := generateTestKeyPair(t)
manager.SetPrivateKey(privateKey)
// 第一次获取
service1, err := manager.GetJWTService()
if err != nil {
t.Fatalf("获取JWT服务失败: %v", err)
}
// 第二次获取应该返回同一个实例
service2, err := manager.GetJWTService()
if err != nil {
t.Fatalf("获取JWT服务失败: %v", err)
}
if service1 != service2 {
t.Error("应该返回同一个JWT服务实例")
}
}
// TestYggdrasilJWTManager_GetJWTService_FromRedis checks that, when no key has
// been set on the manager, GetJWTService builds a working service from the
// PEM-encoded private key stored in Redis under YggdrasilPrivateKeyRedisKey.
func TestYggdrasilJWTManager_GetJWTService_FromRedis(t *testing.T) {
	mockRedis := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(mockRedis)
	privateKey := generateTestKeyPair(t)
	privateKeyPEM, err := EncodePrivateKeyToPEM(privateKey)
	if err != nil {
		t.Fatalf("编码私钥失败: %v", err)
	}
	// Seed Redis with the PEM-encoded key.
	mockRedis.SetData(YggdrasilPrivateKeyRedisKey, privateKeyPEM)

	service, err := manager.GetJWTService()
	if err != nil {
		t.Fatalf("获取JWT服务失败: %v", err)
	}
	// Fatal rather than Error: service is dereferenced below.
	if service == nil {
		t.Fatal("JWT服务不应为nil")
	}

	// The service must be built from the key in Redis, not from some other key.
	if pub := service.GetPublicKey(); pub == nil || !privateKey.PublicKey.Equal(pub) {
		t.Error("公钥与私钥不匹配")
	}

	// The loaded service must be able to issue tokens.
	token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
		time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
	if err != nil {
		t.Fatalf("生成Token失败: %v", err)
	}
	if token == "" {
		t.Error("Token不应为空")
	}
}
// TestYggdrasilJWTManager_GetJWTService_RedisError checks that a Redis failure,
// with no key set on the manager, surfaces as an error from GetJWTService.
func TestYggdrasilJWTManager_GetJWTService_RedisError(t *testing.T) {
	redisMock := NewMockRedisClient()
	m := NewYggdrasilJWTManager(redisMock)

	// Make every Redis call fail.
	redisMock.SetError(errors.New("redis connection error"))

	if _, err := m.GetJWTService(); err == nil {
		t.Error("期望出现错误,但没有错误")
	}
}
// TestYggdrasilJWTManager_GetJWTService_InvalidPEM checks that GetJWTService
// fails when Redis holds data that is not a valid PEM private key.
func TestYggdrasilJWTManager_GetJWTService_InvalidPEM(t *testing.T) {
	redisMock := NewMockRedisClient()
	m := NewYggdrasilJWTManager(redisMock)

	// Store garbage where the PEM-encoded key is expected.
	redisMock.SetData(YggdrasilPrivateKeyRedisKey, "invalid-pem-data")

	if _, err := m.GetJWTService(); err == nil {
		t.Error("期望出现错误,但没有错误")
	}
}
// TestYggdrasilJWTManager_GetJWTService_Concurrent seeds the private key only
// in Redis, calls GetJWTService from several goroutines at once, and checks
// that every caller receives the same service instance.
func TestYggdrasilJWTManager_GetJWTService_Concurrent(t *testing.T) {
	mockRedis := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(mockRedis)
	privateKey := generateTestKeyPair(t)
	privateKeyPEM, err := EncodePrivateKeyToPEM(privateKey)
	if err != nil {
		t.Fatalf("编码私钥失败: %v", err)
	}
	mockRedis.SetData(YggdrasilPrivateKeyRedisKey, privateKeyPEM)

	// Fan out concurrent lookups. Both channels are buffered to numGoroutines,
	// so no worker blocks on send even if the test stops early via Fatalf.
	// The error channel is named errCh so it does not shadow the errors package.
	const numGoroutines = 10
	results := make(chan *YggdrasilJWTService, numGoroutines)
	errCh := make(chan error, numGoroutines)
	for i := 0; i < numGoroutines; i++ {
		go func() {
			service, err := manager.GetJWTService()
			if err != nil {
				errCh <- err
				return
			}
			results <- service
		}()
	}

	// Collect exactly one outcome per goroutine.
	services := make(map[*YggdrasilJWTService]bool)
	for i := 0; i < numGoroutines; i++ {
		select {
		case service := <-results:
			services[service] = true
		case err := <-errCh:
			t.Fatalf("获取JWT服务失败: %v", err)
		}
	}

	// Every goroutine should have observed the same instance.
	if len(services) != 1 {
		t.Errorf("期望所有goroutine返回同一个服务实例但得到 %d 个不同的实例", len(services))
	}
}
func TestYggdrasilTokenClaims_EmptyProfileID(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
// 生成没有ProfileID的Token
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 解析Token
claims, err := service.ParseAccessToken(token, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token失败: %v", err)
}
if claims.ProfileID != "" {
t.Errorf("期望ProfileID为空实际为 %s", claims.ProfileID)
}
}
func TestYggdrasilJWTService_VersionMismatch(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
// 生成Version=1的Token
token1, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 生成Version=2的Token
token2, err := service.GenerateAccessToken(123, "client-uuid", 2, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 解析两个Token
claims1, err := service.ParseAccessToken(token1, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token1失败: %v", err)
}
claims2, err := service.ParseAccessToken(token2, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token2失败: %v", err)
}
// 验证Version不同
if claims1.Version == claims2.Version {
t.Error("两个Token的Version应该不同")
}
if claims1.Version != 1 {
t.Errorf("期望Token1的Version为1实际为 %d", claims1.Version)
}
if claims2.Version != 2 {
t.Errorf("期望Token2的Version为2实际为 %d", claims2.Version)
}
}
// 基准测试
func BenchmarkGenerateAccessToken(b *testing.B) {
privateKey := generateTestKeyPair(&testing.T{})
service := NewYggdrasilJWTService(privateKey, "test-issuer")
userID := int64(123)
clientUUID := "test-client-uuid"
version := 1
profileID := "test-profile-uuid"
expiresAt := time.Now().Add(24 * time.Hour)
staleAt := time.Now().Add(30 * 24 * time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
if err != nil {
b.Fatalf("生成Token失败: %v", err)
}
}
}
func BenchmarkParseAccessToken(b *testing.B) {
privateKey := generateTestKeyPair(&testing.T{})
service := NewYggdrasilJWTService(privateKey, "test-issuer")
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
b.Fatalf("生成Token失败: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := service.ParseAccessToken(token, StalePolicyAllow)
if err != nil {
b.Fatalf("解析Token失败: %v", err)
}
}
}