236 lines
5.6 KiB
Go
236 lines
5.6 KiB
Go
package auth
|
||
|
||
import (
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// TestNewJWTService 测试创建JWT服务
|
||
func TestNewJWTService(t *testing.T) {
|
||
secretKey := "test-secret-key"
|
||
expireHours := 24
|
||
|
||
service := NewJWTService(secretKey, expireHours)
|
||
if service == nil {
|
||
t.Fatal("NewJWTService() 返回nil")
|
||
}
|
||
|
||
if service.secretKey != secretKey {
|
||
t.Errorf("secretKey = %q, want %q", service.secretKey, secretKey)
|
||
}
|
||
|
||
if service.expireHours != expireHours {
|
||
t.Errorf("expireHours = %d, want %d", service.expireHours, expireHours)
|
||
}
|
||
}
|
||
|
||
// TestJWTService_GenerateToken 测试生成Token
|
||
func TestJWTService_GenerateToken(t *testing.T) {
|
||
service := NewJWTService("test-secret-key", 24)
|
||
|
||
tests := []struct {
|
||
name string
|
||
userID int64
|
||
username string
|
||
role string
|
||
wantError bool
|
||
}{
|
||
{
|
||
name: "正常生成Token",
|
||
userID: 1,
|
||
username: "testuser",
|
||
role: "user",
|
||
wantError: false,
|
||
},
|
||
{
|
||
name: "空用户名",
|
||
userID: 1,
|
||
username: "",
|
||
role: "user",
|
||
wantError: false, // JWT允许空字符串
|
||
},
|
||
{
|
||
name: "空角色",
|
||
userID: 1,
|
||
username: "testuser",
|
||
role: "",
|
||
wantError: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
token, err := service.GenerateToken(tt.userID, tt.username, tt.role)
|
||
if (err != nil) != tt.wantError {
|
||
t.Errorf("GenerateToken() error = %v, wantError %v", err, tt.wantError)
|
||
return
|
||
}
|
||
if !tt.wantError {
|
||
if token == "" {
|
||
t.Error("GenerateToken() 返回的token不应为空")
|
||
}
|
||
// 验证token长度合理(JWT token通常很长)
|
||
if len(token) < 50 {
|
||
t.Errorf("GenerateToken() 返回的token长度异常: %d", len(token))
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestJWTService_ValidateToken 测试验证Token
|
||
func TestJWTService_ValidateToken(t *testing.T) {
|
||
secretKey := "test-secret-key"
|
||
service := NewJWTService(secretKey, 24)
|
||
|
||
// 生成一个有效的token
|
||
userID := int64(1)
|
||
username := "testuser"
|
||
role := "user"
|
||
token, err := service.GenerateToken(userID, username, role)
|
||
if err != nil {
|
||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
token string
|
||
wantError bool
|
||
wantUserID int64
|
||
wantUsername string
|
||
wantRole string
|
||
}{
|
||
{
|
||
name: "有效token",
|
||
token: token,
|
||
wantError: false,
|
||
wantUserID: userID,
|
||
wantUsername: username,
|
||
wantRole: role,
|
||
},
|
||
{
|
||
name: "无效token",
|
||
token: "invalid.token.here",
|
||
wantError: true,
|
||
},
|
||
{
|
||
name: "空token",
|
||
token: "",
|
||
wantError: true,
|
||
},
|
||
{
|
||
name: "使用不同密钥签名的token",
|
||
token: func() string {
|
||
otherService := NewJWTService("different-secret", 24)
|
||
token, _ := otherService.GenerateToken(1, "user", "role")
|
||
return token
|
||
}(),
|
||
wantError: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
claims, err := service.ValidateToken(tt.token)
|
||
if (err != nil) != tt.wantError {
|
||
t.Errorf("ValidateToken() error = %v, wantError %v", err, tt.wantError)
|
||
return
|
||
}
|
||
if !tt.wantError {
|
||
if claims == nil {
|
||
t.Fatal("ValidateToken() 返回的claims不应为nil")
|
||
}
|
||
if claims.UserID != tt.wantUserID {
|
||
t.Errorf("UserID = %d, want %d", claims.UserID, tt.wantUserID)
|
||
}
|
||
if claims.Username != tt.wantUsername {
|
||
t.Errorf("Username = %q, want %q", claims.Username, tt.wantUsername)
|
||
}
|
||
if claims.Role != tt.wantRole {
|
||
t.Errorf("Role = %q, want %q", claims.Role, tt.wantRole)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestJWTService_TokenRoundTrip 测试Token的完整流程
|
||
func TestJWTService_TokenRoundTrip(t *testing.T) {
|
||
service := NewJWTService("test-secret-key", 24)
|
||
|
||
userID := int64(123)
|
||
username := "testuser"
|
||
role := "admin"
|
||
|
||
// 生成token
|
||
token, err := service.GenerateToken(userID, username, role)
|
||
if err != nil {
|
||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||
}
|
||
|
||
// 验证token
|
||
claims, err := service.ValidateToken(token)
|
||
if err != nil {
|
||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||
}
|
||
|
||
// 验证claims内容
|
||
if claims.UserID != userID {
|
||
t.Errorf("UserID = %d, want %d", claims.UserID, userID)
|
||
}
|
||
if claims.Username != username {
|
||
t.Errorf("Username = %q, want %q", claims.Username, username)
|
||
}
|
||
if claims.Role != role {
|
||
t.Errorf("Role = %q, want %q", claims.Role, role)
|
||
}
|
||
}
|
||
|
||
// TestJWTService_TokenExpiration 测试Token过期时间
|
||
func TestJWTService_TokenExpiration(t *testing.T) {
|
||
expireHours := 24
|
||
service := NewJWTService("test-secret-key", expireHours)
|
||
|
||
token, err := service.GenerateToken(1, "user", "role")
|
||
if err != nil {
|
||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||
}
|
||
|
||
claims, err := service.ValidateToken(token)
|
||
if err != nil {
|
||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||
}
|
||
|
||
// 验证过期时间
|
||
if claims.ExpiresAt == nil {
|
||
t.Error("ExpiresAt 不应为nil")
|
||
} else {
|
||
expectedExpiry := time.Now().Add(time.Duration(expireHours) * time.Hour)
|
||
// 允许1分钟的误差
|
||
diff := claims.ExpiresAt.Time.Sub(expectedExpiry)
|
||
if diff < -time.Minute || diff > time.Minute {
|
||
t.Errorf("ExpiresAt 时间异常: %v, 期望约 %v", claims.ExpiresAt.Time, expectedExpiry)
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestJWTService_TokenIssuer 测试Token发行者
|
||
func TestJWTService_TokenIssuer(t *testing.T) {
|
||
service := NewJWTService("test-secret-key", 24)
|
||
|
||
token, err := service.GenerateToken(1, "user", "role")
|
||
if err != nil {
|
||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||
}
|
||
|
||
claims, err := service.ValidateToken(token)
|
||
if err != nil {
|
||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||
}
|
||
|
||
expectedIssuer := "carrotskin"
|
||
if claims.Issuer != expectedIssuer {
|
||
t.Errorf("Issuer = %q, want %q", claims.Issuer, expectedIssuer)
|
||
}
|
||
}
|