chore: 初始化仓库,排除二进制文件和覆盖率文件
This commit is contained in:
70
pkg/auth/jwt.go
Normal file
70
pkg/auth/jwt.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// JWTService JWT服务
|
||||
type JWTService struct {
|
||||
secretKey string
|
||||
expireHours int
|
||||
}
|
||||
|
||||
// NewJWTService 创建新的JWT服务
|
||||
func NewJWTService(secretKey string, expireHours int) *JWTService {
|
||||
return &JWTService{
|
||||
secretKey: secretKey,
|
||||
expireHours: expireHours,
|
||||
}
|
||||
}
|
||||
|
||||
// Claims JWT声明
|
||||
type Claims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT Token (使用UserID和基本信息)
|
||||
func (j *JWTService) GenerateToken(userID int64, username, role string) (string, error) {
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Role: role,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(j.expireHours) * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
NotBefore: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "carrotskin",
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString([]byte(j.secretKey))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tokenString, nil
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT Token
|
||||
func (j *JWTService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(j.secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("无效的token")
|
||||
}
|
||||
235
pkg/auth/jwt_test.go
Normal file
235
pkg/auth/jwt_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewJWTService 测试创建JWT服务
|
||||
func TestNewJWTService(t *testing.T) {
|
||||
secretKey := "test-secret-key"
|
||||
expireHours := 24
|
||||
|
||||
service := NewJWTService(secretKey, expireHours)
|
||||
if service == nil {
|
||||
t.Fatal("NewJWTService() 返回nil")
|
||||
}
|
||||
|
||||
if service.secretKey != secretKey {
|
||||
t.Errorf("secretKey = %q, want %q", service.secretKey, secretKey)
|
||||
}
|
||||
|
||||
if service.expireHours != expireHours {
|
||||
t.Errorf("expireHours = %d, want %d", service.expireHours, expireHours)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_GenerateToken 测试生成Token
|
||||
func TestJWTService_GenerateToken(t *testing.T) {
|
||||
service := NewJWTService("test-secret-key", 24)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
username string
|
||||
role string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "正常生成Token",
|
||||
userID: 1,
|
||||
username: "testuser",
|
||||
role: "user",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空用户名",
|
||||
userID: 1,
|
||||
username: "",
|
||||
role: "user",
|
||||
wantError: false, // JWT允许空字符串
|
||||
},
|
||||
{
|
||||
name: "空角色",
|
||||
userID: 1,
|
||||
username: "testuser",
|
||||
role: "",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := service.GenerateToken(tt.userID, tt.username, tt.role)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("GenerateToken() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if token == "" {
|
||||
t.Error("GenerateToken() 返回的token不应为空")
|
||||
}
|
||||
// 验证token长度合理(JWT token通常很长)
|
||||
if len(token) < 50 {
|
||||
t.Errorf("GenerateToken() 返回的token长度异常: %d", len(token))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_ValidateToken 测试验证Token
|
||||
func TestJWTService_ValidateToken(t *testing.T) {
|
||||
secretKey := "test-secret-key"
|
||||
service := NewJWTService(secretKey, 24)
|
||||
|
||||
// 生成一个有效的token
|
||||
userID := int64(1)
|
||||
username := "testuser"
|
||||
role := "user"
|
||||
token, err := service.GenerateToken(userID, username, role)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantError bool
|
||||
wantUserID int64
|
||||
wantUsername string
|
||||
wantRole string
|
||||
}{
|
||||
{
|
||||
name: "有效token",
|
||||
token: token,
|
||||
wantError: false,
|
||||
wantUserID: userID,
|
||||
wantUsername: username,
|
||||
wantRole: role,
|
||||
},
|
||||
{
|
||||
name: "无效token",
|
||||
token: "invalid.token.here",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "空token",
|
||||
token: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "使用不同密钥签名的token",
|
||||
token: func() string {
|
||||
otherService := NewJWTService("different-secret", 24)
|
||||
token, _ := otherService.GenerateToken(1, "user", "role")
|
||||
return token
|
||||
}(),
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := service.ValidateToken(tt.token)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("ValidateToken() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if claims == nil {
|
||||
t.Fatal("ValidateToken() 返回的claims不应为nil")
|
||||
}
|
||||
if claims.UserID != tt.wantUserID {
|
||||
t.Errorf("UserID = %d, want %d", claims.UserID, tt.wantUserID)
|
||||
}
|
||||
if claims.Username != tt.wantUsername {
|
||||
t.Errorf("Username = %q, want %q", claims.Username, tt.wantUsername)
|
||||
}
|
||||
if claims.Role != tt.wantRole {
|
||||
t.Errorf("Role = %q, want %q", claims.Role, tt.wantRole)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_TokenRoundTrip 测试Token的完整流程
|
||||
func TestJWTService_TokenRoundTrip(t *testing.T) {
|
||||
service := NewJWTService("test-secret-key", 24)
|
||||
|
||||
userID := int64(123)
|
||||
username := "testuser"
|
||||
role := "admin"
|
||||
|
||||
// 生成token
|
||||
token, err := service.GenerateToken(userID, username, role)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := service.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证claims内容
|
||||
if claims.UserID != userID {
|
||||
t.Errorf("UserID = %d, want %d", claims.UserID, userID)
|
||||
}
|
||||
if claims.Username != username {
|
||||
t.Errorf("Username = %q, want %q", claims.Username, username)
|
||||
}
|
||||
if claims.Role != role {
|
||||
t.Errorf("Role = %q, want %q", claims.Role, role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_TokenExpiration 测试Token过期时间
|
||||
func TestJWTService_TokenExpiration(t *testing.T) {
|
||||
expireHours := 24
|
||||
service := NewJWTService("test-secret-key", expireHours)
|
||||
|
||||
token, err := service.GenerateToken(1, "user", "role")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
claims, err := service.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if claims.ExpiresAt == nil {
|
||||
t.Error("ExpiresAt 不应为nil")
|
||||
} else {
|
||||
expectedExpiry := time.Now().Add(time.Duration(expireHours) * time.Hour)
|
||||
// 允许1分钟的误差
|
||||
diff := claims.ExpiresAt.Time.Sub(expectedExpiry)
|
||||
if diff < -time.Minute || diff > time.Minute {
|
||||
t.Errorf("ExpiresAt 时间异常: %v, 期望约 %v", claims.ExpiresAt.Time, expectedExpiry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTService_TokenIssuer 测试Token发行者
|
||||
func TestJWTService_TokenIssuer(t *testing.T) {
|
||||
service := NewJWTService("test-secret-key", 24)
|
||||
|
||||
token, err := service.GenerateToken(1, "user", "role")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
claims, err := service.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken() 失败: %v", err)
|
||||
}
|
||||
|
||||
expectedIssuer := "carrotskin"
|
||||
if claims.Issuer != expectedIssuer {
|
||||
t.Errorf("Issuer = %q, want %q", claims.Issuer, expectedIssuer)
|
||||
}
|
||||
}
|
||||
45
pkg/auth/manager.go
Normal file
45
pkg/auth/manager.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
// jwtServiceInstance 全局JWT服务实例
|
||||
jwtServiceInstance *JWTService
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化JWT服务(线程安全,只会执行一次)
|
||||
func Init(cfg config.JWTConfig) error {
|
||||
once.Do(func() {
|
||||
jwtServiceInstance = NewJWTService(cfg.Secret, cfg.ExpireHours)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJWTService 获取JWT服务实例(线程安全)
|
||||
func GetJWTService() (*JWTService, error) {
|
||||
if jwtServiceInstance == nil {
|
||||
return nil, fmt.Errorf("JWT服务未初始化,请先调用 auth.Init()")
|
||||
}
|
||||
return jwtServiceInstance, nil
|
||||
}
|
||||
|
||||
// MustGetJWTService 获取JWT服务实例,如果未初始化则panic
|
||||
func MustGetJWTService() *JWTService {
|
||||
service, err := GetJWTService()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
86
pkg/auth/manager_test.go
Normal file
86
pkg/auth/manager_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetJWTService_NotInitialized 测试未初始化时获取JWT服务
|
||||
func TestGetJWTService_NotInitialized(t *testing.T) {
|
||||
_, err := GetJWTService()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "JWT服务未初始化,请先调用 auth.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetJWTService_Panic 测试MustGetJWTService在未初始化时panic
|
||||
func TestMustGetJWTService_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetJWTService 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetJWTService()
|
||||
}
|
||||
|
||||
// TestInit_JWTService 测试JWT服务初始化
|
||||
func TestInit_JWTService(t *testing.T) {
|
||||
cfg := config.JWTConfig{
|
||||
Secret: "test-secret-key",
|
||||
ExpireHours: 24,
|
||||
}
|
||||
|
||||
err := Init(cfg)
|
||||
if err != nil {
|
||||
t.Errorf("Init() 错误 = %v, want nil", err)
|
||||
}
|
||||
|
||||
// 验证可以获取服务
|
||||
service, err := GetJWTService()
|
||||
if err != nil {
|
||||
t.Errorf("GetJWTService() 错误 = %v, want nil", err)
|
||||
}
|
||||
if service == nil {
|
||||
t.Error("GetJWTService() 返回的服务不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInit_JWTService_Once 测试Init只执行一次
|
||||
func TestInit_JWTService_Once(t *testing.T) {
|
||||
cfg := config.JWTConfig{
|
||||
Secret: "test-secret-key-1",
|
||||
ExpireHours: 24,
|
||||
}
|
||||
|
||||
// 第一次初始化
|
||||
err1 := Init(cfg)
|
||||
if err1 != nil {
|
||||
t.Fatalf("第一次Init() 错误 = %v", err1)
|
||||
}
|
||||
|
||||
service1, _ := GetJWTService()
|
||||
|
||||
// 第二次初始化(应该不会改变服务)
|
||||
cfg2 := config.JWTConfig{
|
||||
Secret: "test-secret-key-2",
|
||||
ExpireHours: 48,
|
||||
}
|
||||
err2 := Init(cfg2)
|
||||
if err2 != nil {
|
||||
t.Fatalf("第二次Init() 错误 = %v", err2)
|
||||
}
|
||||
|
||||
service2, _ := GetJWTService()
|
||||
|
||||
// 验证是同一个实例(sync.Once保证)
|
||||
if service1 != service2 {
|
||||
t.Error("Init应该只执行一次,返回同一个实例")
|
||||
}
|
||||
}
|
||||
|
||||
20
pkg/auth/password.go
Normal file
20
pkg/auth/password.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// HashPassword 使用bcrypt加密密码
|
||||
func HashPassword(password string) (string, error) {
|
||||
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hashedBytes), nil
|
||||
}
|
||||
|
||||
// CheckPassword 验证密码是否匹配
|
||||
func CheckPassword(hashedPassword, password string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
145
pkg/auth/password_test.go
Normal file
145
pkg/auth/password_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestHashPassword 测试密码加密
|
||||
func TestHashPassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "正常密码",
|
||||
password: "testpassword123",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空密码",
|
||||
password: "",
|
||||
wantError: false, // bcrypt允许空密码
|
||||
},
|
||||
{
|
||||
name: "长密码",
|
||||
password: "thisisaverylongpasswordthatexceedsnormallength",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "包含特殊字符的密码",
|
||||
password: "P@ssw0rd!#$%",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hashed, err := HashPassword(tt.password)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("HashPassword() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
// 验证哈希值不为空
|
||||
if hashed == "" {
|
||||
t.Error("HashPassword() 返回的哈希值不应为空")
|
||||
}
|
||||
// 验证哈希值与原密码不同
|
||||
if hashed == tt.password {
|
||||
t.Error("HashPassword() 返回的哈希值不应与原密码相同")
|
||||
}
|
||||
// 验证哈希值长度合理(bcrypt哈希通常是60个字符)
|
||||
if len(hashed) < 50 {
|
||||
t.Errorf("HashPassword() 返回的哈希值长度异常: %d", len(hashed))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckPassword 测试密码验证
|
||||
func TestCheckPassword(t *testing.T) {
|
||||
// 先加密一个密码
|
||||
password := "testpassword123"
|
||||
hashed, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() 失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hashedPassword string
|
||||
password string
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
name: "密码匹配",
|
||||
hashedPassword: hashed,
|
||||
password: password,
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "密码不匹配",
|
||||
hashedPassword: hashed,
|
||||
password: "wrongpassword",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "空密码与空哈希",
|
||||
hashedPassword: "",
|
||||
password: "",
|
||||
wantMatch: false, // 空哈希无法验证
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CheckPassword(tt.hashedPassword, tt.password)
|
||||
if result != tt.wantMatch {
|
||||
t.Errorf("CheckPassword() = %v, want %v", result, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHashPassword_Uniqueness 测试每次加密结果不同
|
||||
func TestHashPassword_Uniqueness(t *testing.T) {
|
||||
password := "testpassword123"
|
||||
|
||||
// 多次加密同一密码
|
||||
hashes := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
hashed, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证每次加密的结果都不同(由于salt)
|
||||
if hashes[hashed] {
|
||||
t.Errorf("第%d次加密的结果与之前重复", i+1)
|
||||
}
|
||||
hashes[hashed] = true
|
||||
|
||||
// 但都能验证通过
|
||||
if !CheckPassword(hashed, password) {
|
||||
t.Errorf("第%d次加密的哈希无法验证原密码", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckPassword_Consistency 测试密码验证的一致性
|
||||
func TestCheckPassword_Consistency(t *testing.T) {
|
||||
password := "testpassword123"
|
||||
hashed, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword() 失败: %v", err)
|
||||
}
|
||||
|
||||
// 多次验证应该结果一致
|
||||
for i := 0; i < 10; i++ {
|
||||
if !CheckPassword(hashed, password) {
|
||||
t.Errorf("第%d次验证失败", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user