Files
backend/pkg/auth/yggdrasil_jwt_test.go

554 lines
14 KiB
Go
Raw Normal View History

package auth
import (
"context"
"crypto/rsa"
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9"
)
// MockRedisClient is an in-memory stand-in for the Redis client handed to
// NewYggdrasilJWTManager in these tests. Values live in a plain map, and an
// error can be injected so that every Get/Set call fails.
//
// The map is not guarded by a lock, so concurrent mutation is not safe;
// concurrent Get calls on unchanged data are fine.
type MockRedisClient struct {
	data map[string]string // stored values keyed by Redis key
	err  error             // when non-nil, returned by every Get and Set
}

// NewMockRedisClient returns an empty mock with no injected error.
func NewMockRedisClient() *MockRedisClient {
	return &MockRedisClient{
		data: make(map[string]string),
	}
}

// Get returns the value stored under key. It returns the injected error if one
// is set, and redis.Nil when the key is absent.
func (m *MockRedisClient) Get(ctx context.Context, key string) (string, error) {
	if m.err != nil {
		return "", m.err
	}
	if val, ok := m.data[key]; ok {
		return val, nil
	}
	return "", redis.Nil
}

// Set stores value under key, ignoring expiration. Strings and byte slices are
// stored as strings; any other type is rejected with an error instead of
// panicking, so an unexpected caller fails the test cleanly.
func (m *MockRedisClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error {
	if m.err != nil {
		return m.err
	}
	switch v := value.(type) {
	case string:
		m.data[key] = v
	case []byte:
		m.data[key] = string(v)
	default:
		return errors.New("mock redis: unsupported value type for Set")
	}
	return nil
}

// SetError makes every subsequent Get/Set call return err.
func (m *MockRedisClient) SetError(err error) {
	m.err = err
}

// ClearError removes any injected error.
func (m *MockRedisClient) ClearError() {
	m.err = nil
}

// SetData stores value under key directly, bypassing the injected error.
func (m *MockRedisClient) SetData(key, value string) {
	m.data[key] = value
}

// Clear drops all stored data and any injected error.
func (m *MockRedisClient) Clear() {
	m.data = make(map[string]string)
	m.err = nil
}
// generateTestKeyPair creates a fresh RSA private key via GenerateKeyPair and
// fails the calling test or benchmark immediately if generation errors.
// It accepts testing.TB so both *testing.T and *testing.B can use it, and it
// marks itself as a helper so failures are attributed to the caller's line.
func generateTestKeyPair(tb testing.TB) *rsa.PrivateKey {
	tb.Helper()
	privateKey, err := GenerateKeyPair()
	if err != nil {
		tb.Fatalf("生成密钥对失败: %v", err)
	}
	return privateKey
}
// TestNewYggdrasilJWTService checks that the constructor keeps a custom issuer,
// substitutes "carrotskin" for an empty one, and populates both key fields.
func TestNewYggdrasilJWTService(t *testing.T) {
	key := generateTestKeyPair(t)

	cases := []struct {
		name, issuer, expected string
	}{
		{name: "自定义issuer", issuer: "test-issuer", expected: "test-issuer"},
		{name: "空issuer使用默认值", issuer: "", expected: "carrotskin"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			svc := NewYggdrasilJWTService(key, tc.issuer)
			if svc == nil {
				t.Fatal("服务创建失败")
			}
			if got := svc.issuer; got != tc.expected {
				t.Errorf("期望issuer为 %s实际为 %s", tc.expected, got)
			}
			if svc.privateKey == nil {
				t.Error("私钥不应为nil")
			}
			if svc.publicKey == nil {
				t.Error("公钥不应为nil")
			}
		})
	}
}
func TestYggdrasilJWTService_GenerateAccessToken(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
userID := int64(123)
clientUUID := "test-client-uuid"
version := 1
profileID := "test-profile-uuid"
expiresAt := time.Now().Add(24 * time.Hour)
staleAt := time.Now().Add(30 * 24 * time.Hour)
token, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
if token == "" {
t.Error("Token不应为空")
}
// 验证Token可以解析
claims, err := service.ParseAccessToken(token, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token失败: %v", err)
}
if claims.UserID != userID {
t.Errorf("期望UserID为 %d实际为 %d", userID, claims.UserID)
}
if claims.Subject != clientUUID {
t.Errorf("期望Subject为 %s实际为 %s", clientUUID, claims.Subject)
}
if claims.Version != version {
t.Errorf("期望Version为 %d实际为 %d", version, claims.Version)
}
if claims.ProfileID != profileID {
t.Errorf("期望ProfileID为 %s实际为 %s", profileID, claims.ProfileID)
}
if claims.Issuer != "test-issuer" {
t.Errorf("期望Issuer为 test-issuer实际为 %s", claims.Issuer)
}
}
func TestYggdrasilJWTService_ParseAccessToken(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
userID := int64(123)
clientUUID := "test-client-uuid"
version := 1
profileID := "test-profile-uuid"
expiresAt := time.Now().Add(24 * time.Hour)
staleAt := time.Now().Add(30 * 24 * time.Hour)
// 生成Token
token, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
tests := []struct {
name string
token string
policy StaleTokenPolicy
expectError bool
}{
{
name: "有效Token允许过期",
token: token,
policy: StalePolicyAllow,
expectError: false,
},
{
name: "有效Token拒绝过期",
token: token,
policy: StalePolicyDeny,
expectError: false,
},
{
name: "无效Token",
token: "invalid-token",
policy: StalePolicyAllow,
expectError: true,
},
{
name: "空Token",
token: "",
policy: StalePolicyAllow,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims, err := service.ParseAccessToken(tt.token, tt.policy)
if tt.expectError {
if err == nil {
t.Error("期望出现错误,但没有错误")
}
if claims != nil {
t.Error("期望claims为nil")
}
} else {
if err != nil {
t.Errorf("不期望出现错误,但出现: %v", err)
}
if claims == nil {
t.Error("claims不应为nil")
}
}
})
}
}
// TestYggdrasilJWTService_ParseAccessToken_Expired issues a token whose expiry
// is an hour in the past (but whose stale time is still in the future) and
// checks that StalePolicyDeny rejects it. The StalePolicyAllow outcome is not
// asserted; if parsing fails, the reason is logged for visibility.
func TestYggdrasilJWTService_ParseAccessToken_Expired(t *testing.T) {
	privateKey := generateTestKeyPair(t)
	service := NewYggdrasilJWTService(privateKey, "test-issuer")

	// Expired one hour ago, not yet stale.
	expiresAt := time.Now().Add(-1 * time.Hour)
	staleAt := time.Now().Add(30 * 24 * time.Hour)
	token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid", expiresAt, staleAt)
	if err != nil {
		t.Fatalf("生成Token失败: %v", err)
	}

	// StalePolicyDeny must reject an expired token.
	_, err = service.ParseAccessToken(token, StalePolicyDeny)
	if err == nil {
		t.Error("期望拒绝过期Token但没有错误")
	}

	// Whether StalePolicyAllow accepts an expired token depends on whether the
	// underlying JWT library validates the exp claim during parsing. Both
	// outcomes are tolerated. The note below describes a rejection, so it is
	// logged only when parsing actually failed.
	_, err = service.ParseAccessToken(token, StalePolicyAllow)
	if err != nil {
		t.Logf("注意JWT库会自动拒绝过期Token即使使用StalePolicyAllow: %v", err)
	}
}
// TestYggdrasilJWTService_ParseAccessToken_WrongKey checks that a token signed
// by one key pair is rejected by a service holding a different key pair.
func TestYggdrasilJWTService_ParseAccessToken_WrongKey(t *testing.T) {
	signerKey, verifierKey := generateTestKeyPair(t), generateTestKeyPair(t)
	signer := NewYggdrasilJWTService(signerKey, "test-issuer")
	verifier := NewYggdrasilJWTService(verifierKey, "test-issuer")

	token, err := signer.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
		time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
	if err != nil {
		t.Fatalf("生成Token失败: %v", err)
	}

	// Signature verification with the wrong public key must fail.
	if _, err := verifier.ParseAccessToken(token, StalePolicyAllow); err == nil {
		t.Error("期望使用错误密钥解析Token失败但没有错误")
	}
}
// TestYggdrasilJWTService_GetPublicKey checks that GetPublicKey returns the
// public half of the private key the service was constructed with.
func TestYggdrasilJWTService_GetPublicKey(t *testing.T) {
	privateKey := generateTestKeyPair(t)
	service := NewYggdrasilJWTService(privateKey, "test-issuer")

	publicKey := service.GetPublicKey()
	if publicKey == nil {
		t.Fatal("公钥不应为nil")
	}
	// Compare the whole key (modulus and exponent), not just the modulus.
	if !publicKey.Equal(&privateKey.PublicKey) {
		t.Error("公钥与私钥不匹配")
	}
}
// TestNewYggdrasilJWTManager checks that the constructor returns a manager
// holding the Redis client it was given.
func TestNewYggdrasilJWTManager(t *testing.T) {
	client := NewMockRedisClient()
	m := NewYggdrasilJWTManager(client)

	switch {
	case m == nil:
		t.Fatal("管理器创建失败")
	case m.redisClient != client:
		t.Error("Redis客户端未正确设置")
	}
}
// TestYggdrasilJWTManager_SetPrivateKey checks that after SetPrivateKey the
// manager hands out a JWT service whose public key belongs to the injected
// private key.
func TestYggdrasilJWTManager_SetPrivateKey(t *testing.T) {
	mockRedis := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(mockRedis)
	privateKey := generateTestKeyPair(t)
	manager.SetPrivateKey(privateKey)

	service, err := manager.GetJWTService()
	if err != nil {
		t.Fatalf("获取JWT服务失败: %v", err)
	}
	if service == nil {
		t.Fatal("JWT服务不应为nil")
	}

	publicKey := service.GetPublicKey()
	if publicKey == nil {
		t.Fatal("公钥不应为nil")
	}
	// A non-nil check alone cannot tell whether the injected key was used;
	// compare against it explicitly.
	if !publicKey.Equal(&privateKey.PublicKey) {
		t.Error("公钥与私钥不匹配")
	}
}
func TestYggdrasilJWTManager_GetJWTService_FromPrivateKey(t *testing.T) {
mockRedis := NewMockRedisClient()
manager := NewYggdrasilJWTManager(mockRedis)
privateKey := generateTestKeyPair(t)
manager.SetPrivateKey(privateKey)
// 第一次获取
service1, err := manager.GetJWTService()
if err != nil {
t.Fatalf("获取JWT服务失败: %v", err)
}
// 第二次获取应该返回同一个实例
service2, err := manager.GetJWTService()
if err != nil {
t.Fatalf("获取JWT服务失败: %v", err)
}
if service1 != service2 {
t.Error("应该返回同一个JWT服务实例")
}
}
// TestYggdrasilJWTManager_GetJWTService_FromRedis stores a PEM-encoded private
// key under YggdrasilPrivateKeyRedisKey and checks that the manager builds a
// working JWT service from it.
func TestYggdrasilJWTManager_GetJWTService_FromRedis(t *testing.T) {
	mockRedis := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(mockRedis)
	privateKey := generateTestKeyPair(t)
	privateKeyPEM, err := EncodePrivateKeyToPEM(privateKey)
	if err != nil {
		t.Fatalf("编码私钥失败: %v", err)
	}
	mockRedis.SetData(YggdrasilPrivateKeyRedisKey, privateKeyPEM)

	service, err := manager.GetJWTService()
	if err != nil {
		t.Fatalf("获取JWT服务失败: %v", err)
	}
	// Fatal, not Error: service is dereferenced below.
	if service == nil {
		t.Fatal("JWT服务不应为nil")
	}
	// The service must be built from the key stored in Redis.
	if pub := service.GetPublicKey(); pub == nil || !pub.Equal(&privateKey.PublicKey) {
		t.Error("公钥与私钥不匹配")
	}

	// The service must be able to sign tokens.
	token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
		time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
	if err != nil {
		t.Fatalf("生成Token失败: %v", err)
	}
	if token == "" {
		t.Error("Token不应为空")
	}
}
// TestYggdrasilJWTManager_GetJWTService_RedisError checks that GetJWTService
// returns an error when every Redis call fails.
func TestYggdrasilJWTManager_GetJWTService_RedisError(t *testing.T) {
	redisMock := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(redisMock)
	redisMock.SetError(errors.New("redis connection error"))

	if _, err := manager.GetJWTService(); err == nil {
		t.Error("期望出现错误,但没有错误")
	}
}
// TestYggdrasilJWTManager_GetJWTService_InvalidPEM checks that GetJWTService
// returns an error when the value stored under YggdrasilPrivateKeyRedisKey is
// not valid PEM.
func TestYggdrasilJWTManager_GetJWTService_InvalidPEM(t *testing.T) {
	redisMock := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(redisMock)
	redisMock.SetData(YggdrasilPrivateKeyRedisKey, "invalid-pem-data")

	if _, err := manager.GetJWTService(); err == nil {
		t.Error("期望出现错误,但没有错误")
	}
}
// TestYggdrasilJWTManager_GetJWTService_Concurrent has several goroutines call
// GetJWTService at the same moment, with the key stored in Redis. It checks
// that they all receive the same non-nil service instance.
func TestYggdrasilJWTManager_GetJWTService_Concurrent(t *testing.T) {
	mockRedis := NewMockRedisClient()
	manager := NewYggdrasilJWTManager(mockRedis)
	privateKey := generateTestKeyPair(t)
	privateKeyPEM, err := EncodePrivateKeyToPEM(privateKey)
	if err != nil {
		t.Fatalf("编码私钥失败: %v", err)
	}
	mockRedis.SetData(YggdrasilPrivateKeyRedisKey, privateKeyPEM)

	const numGoroutines = 10
	// Both channels are buffered to numGoroutines, so no worker blocks on send
	// even if the collector below stops early via t.Fatalf.
	results := make(chan *YggdrasilJWTService, numGoroutines)
	errCh := make(chan error, numGoroutines) // not "errors": avoid shadowing the package
	// Closed after all workers are launched so their calls overlap instead of
	// running back-to-back as goroutines happen to be scheduled.
	start := make(chan struct{})
	for i := 0; i < numGoroutines; i++ {
		go func() {
			<-start
			service, err := manager.GetJWTService()
			if err != nil {
				errCh <- err
				return
			}
			results <- service
		}()
	}
	close(start)

	// Each worker sends exactly one value, so collect numGoroutines of them.
	services := make(map[*YggdrasilJWTService]struct{}, numGoroutines)
	for i := 0; i < numGoroutines; i++ {
		select {
		case service := <-results:
			if service == nil {
				t.Fatal("JWT服务不应为nil")
			}
			services[service] = struct{}{}
		case err := <-errCh:
			t.Fatalf("获取JWT服务失败: %v", err)
		}
	}

	if len(services) != 1 {
		t.Errorf("期望所有goroutine返回同一个服务实例但得到 %d 个不同的实例", len(services))
	}
}
func TestYggdrasilTokenClaims_EmptyProfileID(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
// 生成没有ProfileID的Token
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 解析Token
claims, err := service.ParseAccessToken(token, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token失败: %v", err)
}
if claims.ProfileID != "" {
t.Errorf("期望ProfileID为空实际为 %s", claims.ProfileID)
}
}
func TestYggdrasilJWTService_VersionMismatch(t *testing.T) {
privateKey := generateTestKeyPair(t)
service := NewYggdrasilJWTService(privateKey, "test-issuer")
// 生成Version=1的Token
token1, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 生成Version=2的Token
token2, err := service.GenerateAccessToken(123, "client-uuid", 2, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
t.Fatalf("生成Token失败: %v", err)
}
// 解析两个Token
claims1, err := service.ParseAccessToken(token1, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token1失败: %v", err)
}
claims2, err := service.ParseAccessToken(token2, StalePolicyAllow)
if err != nil {
t.Fatalf("解析Token2失败: %v", err)
}
// 验证Version不同
if claims1.Version == claims2.Version {
t.Error("两个Token的Version应该不同")
}
if claims1.Version != 1 {
t.Errorf("期望Token1的Version为1实际为 %d", claims1.Version)
}
if claims2.Version != 2 {
t.Errorf("期望Token2的Version为2实际为 %d", claims2.Version)
}
}
// 基准测试
func BenchmarkGenerateAccessToken(b *testing.B) {
privateKey := generateTestKeyPair(&testing.T{})
service := NewYggdrasilJWTService(privateKey, "test-issuer")
userID := int64(123)
clientUUID := "test-client-uuid"
version := 1
profileID := "test-profile-uuid"
expiresAt := time.Now().Add(24 * time.Hour)
staleAt := time.Now().Add(30 * 24 * time.Hour)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := service.GenerateAccessToken(userID, clientUUID, version, profileID, expiresAt, staleAt)
if err != nil {
b.Fatalf("生成Token失败: %v", err)
}
}
}
func BenchmarkParseAccessToken(b *testing.B) {
privateKey := generateTestKeyPair(&testing.T{})
service := NewYggdrasilJWTService(privateKey, "test-issuer")
token, err := service.GenerateAccessToken(123, "client-uuid", 1, "profile-uuid",
time.Now().Add(24*time.Hour), time.Now().Add(30*24*time.Hour))
if err != nil {
b.Fatalf("生成Token失败: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := service.ParseAccessToken(token, StalePolicyAllow)
if err != nil {
b.Fatalf("解析Token失败: %v", err)
}
}
}