chore: 初始化仓库,排除二进制文件和覆盖率文件
Some checks failed
SonarQube Analysis / sonarqube (push) Has been cancelled
Test / test (push) Has been cancelled
Test / lint (push) Has been cancelled
Test / build (push) Has been cancelled

This commit is contained in:
lan
2025-11-28 23:30:49 +08:00
commit 4b4980820f
107 changed files with 20755 additions and 0 deletions

View File

@@ -0,0 +1,165 @@
package service
import (
"carrotskin/pkg/config"
"carrotskin/pkg/redis"
"context"
"errors"
"fmt"
"log"
"time"
"github.com/google/uuid"
"github.com/wenlng/go-captcha-assets/resources/imagesv2"
"github.com/wenlng/go-captcha-assets/resources/tiles"
"github.com/wenlng/go-captcha/v2/slide"
)
var (
slideTileCapt slide.Captcha
cfg *config.Config
)
// 常量定义业务相关配置与Redis连接配置分离
const (
redisKeyPrefix = "captcha:" // Redis键前缀便于区分业务
paddingValue = 3 // 验证允许的误差像素±3px
)
// Init 验证码图初始化
func init() {
cfg, _ = config.Load()
// 从默认仓库中获取主图
builder := slide.NewBuilder()
bgImage, err := imagesv2.GetImages()
if err != nil {
log.Fatalln(err)
}
// 滑块形状获取
graphs := getSlideTileGraphArr()
builder.SetResources(
slide.WithGraphImages(graphs),
slide.WithBackgrounds(bgImage),
)
slideTileCapt = builder.Make()
if slideTileCapt == nil {
log.Fatalln("验证码实例初始化失败")
}
}
// getSlideTileGraphArr 滑块选择
func getSlideTileGraphArr() []*slide.GraphImage {
graphs, err := tiles.GetTiles()
if err != nil {
log.Fatalln(err)
}
var newGraphs = make([]*slide.GraphImage, 0, len(graphs))
for i := 0; i < len(graphs); i++ {
graph := graphs[i]
newGraphs = append(newGraphs, &slide.GraphImage{
OverlayImage: graph.OverlayImage,
MaskImage: graph.MaskImage,
ShadowImage: graph.ShadowImage,
})
}
return newGraphs
}
// RedisData 存储到Redis的验证信息仅包含校验必需字段
type RedisData struct {
Tx int `json:"tx"` // 滑块目标X坐标
Ty int `json:"ty"` // 滑块目标Y坐标
}
// GenerateCaptchaData 提取生成验证码的相关信息
func GenerateCaptchaData(ctx context.Context, redisClient *redis.Client) (string, string, string, int, error) {
// 生成uuid作为验证码进程唯一标识
captchaID := uuid.NewString()
if captchaID == "" {
return "", "", "", 0, errors.New("生成验证码唯一标识失败")
}
captData, err := slideTileCapt.Generate()
if err != nil {
return "", "", "", 0, fmt.Errorf("生成验证码失败: %w", err)
}
blockData := captData.GetData()
if blockData == nil {
return "", "", "", 0, errors.New("获取验证码数据失败")
}
block, _ := json.Marshal(blockData)
var blockMap map[string]interface{}
if err := json.Unmarshal(block, &blockMap); err != nil {
return "", "", "", 0, fmt.Errorf("反序列化为map失败: %w", err)
}
// 提取x和y并转换为int类型
tx, ok := blockMap["x"].(float64)
if !ok {
return "", "", "", 0, errors.New("无法将x转换为float64")
}
var x = int(tx)
ty, ok := blockMap["y"].(float64)
if !ok {
return "", "", "", 0, errors.New("无法将y转换为float64")
}
var y = int(ty)
var mBase64, tBase64 string
mBase64, err = captData.GetMasterImage().ToBase64()
if err != nil {
return "", "", "", 0, fmt.Errorf("主图转换为base64失败: %w", err)
}
tBase64, err = captData.GetTileImage().ToBase64()
if err != nil {
return "", "", "", 0, fmt.Errorf("滑块图转换为base64失败: %w", err)
}
redisData := RedisData{
Tx: x,
Ty: y,
}
redisDataJSON, _ := json.Marshal(redisData)
redisKey := redisKeyPrefix + captchaID
expireTime := 300 * time.Second
// 使用注入的Redis客户端
if err := redisClient.Set(
ctx,
redisKey,
redisDataJSON,
expireTime,
); err != nil {
return "", "", "", 0, fmt.Errorf("存储验证码到Redis失败: %w", err)
}
return mBase64, tBase64, captchaID, y - 10, nil
}
// VerifyCaptchaData 验证用户验证码
func VerifyCaptchaData(ctx context.Context, redisClient *redis.Client, dx int, id string) (bool, error) {
redisKey := redisKeyPrefix + id
// 从Redis获取验证信息使用注入的客户端
dataJSON, err := redisClient.Get(ctx, redisKey)
if err != nil {
if redisClient.Nil(err) { // 使用封装客户端的Nil错误
return false, errors.New("验证码已过期或无效")
}
return false, fmt.Errorf("Redis查询失败: %w", err)
}
var redisData RedisData
if err := json.Unmarshal([]byte(dataJSON), &redisData); err != nil {
return false, fmt.Errorf("解析Redis数据失败: %w", err)
}
tx := redisData.Tx
ty := redisData.Ty
ok := slide.Validate(dx, ty, tx, ty, paddingValue)
// 验证后立即删除Redis记录防止重复使用
if ok {
if err := redisClient.Del(ctx, redisKey); err != nil {
// 记录警告但不影响验证结果
log.Printf("删除验证码Redis记录失败: %v", err)
}
}
return ok, nil
}

View File

@@ -0,0 +1,174 @@
package service
import (
"testing"
"time"
)
// TestCaptchaService_Constants 测试验证码服务常量
func TestCaptchaService_Constants(t *testing.T) {
if redisKeyPrefix != "captcha:" {
t.Errorf("redisKeyPrefix = %s, want 'captcha:'", redisKeyPrefix)
}
if paddingValue != 3 {
t.Errorf("paddingValue = %d, want 3", paddingValue)
}
}
// TestRedisData_Structure 测试RedisData结构
func TestRedisData_Structure(t *testing.T) {
data := RedisData{
Tx: 100,
Ty: 200,
}
if data.Tx != 100 {
t.Errorf("RedisData.Tx = %d, want 100", data.Tx)
}
if data.Ty != 200 {
t.Errorf("RedisData.Ty = %d, want 200", data.Ty)
}
}
// TestGenerateCaptchaData_Logic 测试生成验证码的逻辑部分
func TestGenerateCaptchaData_Logic(t *testing.T) {
tests := []struct {
name string
captchaID string
wantErr bool
errContains string
}{
{
name: "有效的captchaID",
captchaID: "test-uuid-123",
wantErr: false,
},
{
name: "空的captchaID应该失败",
captchaID: "",
wantErr: true,
errContains: "生成验证码唯一标识失败",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试UUID验证逻辑
if tt.captchaID == "" {
if !tt.wantErr {
t.Error("空captchaID应该返回错误")
}
} else {
if tt.wantErr {
t.Error("非空captchaID不应该返回错误")
}
}
})
}
}
// TestVerifyCaptchaData_Logic 测试验证验证码的逻辑部分
func TestVerifyCaptchaData_Logic(t *testing.T) {
tests := []struct {
name string
dx int
tx int
ty int
padding int
wantValid bool
}{
{
name: "精确匹配",
dx: 100,
tx: 100,
ty: 200,
padding: 3,
wantValid: true,
},
{
name: "在误差范围内(+3",
dx: 103,
tx: 100,
ty: 200,
padding: 3,
wantValid: true,
},
{
name: "在误差范围内(-3",
dx: 97,
tx: 100,
ty: 200,
padding: 3,
wantValid: true,
},
{
name: "超出误差范围(+4",
dx: 104,
tx: 100,
ty: 200,
padding: 3,
wantValid: false,
},
{
name: "超出误差范围(-4",
dx: 96,
tx: 100,
ty: 200,
padding: 3,
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 验证逻辑dx应该在[tx-padding, tx+padding]范围内
diff := tt.dx - tt.tx
if diff < 0 {
diff = -diff
}
isValid := diff <= tt.padding
if isValid != tt.wantValid {
t.Errorf("Validation failed: got %v, want %v (dx=%d, tx=%d, padding=%d)", isValid, tt.wantValid, tt.dx, tt.tx, tt.padding)
}
})
}
}
// TestVerifyCaptchaData_RedisKey 测试Redis键生成逻辑
func TestVerifyCaptchaData_RedisKey(t *testing.T) {
tests := []struct {
name string
id string
expected string
}{
{
name: "生成正确的Redis键",
id: "test-id-123",
expected: "captcha:test-id-123",
},
{
name: "空ID",
id: "",
expected: "captcha:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redisKey := redisKeyPrefix + tt.id
if redisKey != tt.expected {
t.Errorf("Redis key = %s, want %s", redisKey, tt.expected)
}
})
}
}
// TestGenerateCaptchaData_ExpireTime 测试过期时间
func TestGenerateCaptchaData_ExpireTime(t *testing.T) {
expectedExpireTime := 300 * time.Second
if expectedExpireTime != 5*time.Minute {
t.Errorf("Expire time should be 5 minutes")
}
}

View File

@@ -0,0 +1,13 @@
package service
import (
"time"
jsoniter "github.com/json-iterator/go"
)
// 统一的json变量用于整个service包
var json = jsoniter.ConfigCompatibleWithStandardLibrary
// DefaultTimeout 默认超时时间
const DefaultTimeout = 5 * time.Second

View File

@@ -0,0 +1,48 @@
package service
import (
"testing"
"time"
)
// TestCommon_Constants 测试common包的常量
func TestCommon_Constants(t *testing.T) {
if DefaultTimeout != 5*time.Second {
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
}
}
// TestCommon_JSON 测试JSON变量
func TestCommon_JSON(t *testing.T) {
// 验证json变量不为nil
if json == nil {
t.Error("json 变量不应为nil")
}
// 测试JSON序列化
testData := map[string]interface{}{
"name": "test",
"age": 25,
}
bytes, err := json.Marshal(testData)
if err != nil {
t.Fatalf("json.Marshal() 失败: %v", err)
}
if len(bytes) == 0 {
t.Error("json.Marshal() 返回的字节不应为空")
}
// 测试JSON反序列化
var result map[string]interface{}
err = json.Unmarshal(bytes, &result)
if err != nil {
t.Fatalf("json.Unmarshal() 失败: %v", err)
}
if result["name"] != "test" {
t.Errorf("反序列化结果 name = %v, want 'test'", result["name"])
}
}

View File

@@ -0,0 +1,252 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"gorm.io/gorm"
)
// CreateProfile 创建档案
func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, error) {
// 1. 验证用户存在
user, err := repository.FindUserByID(userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("用户不存在")
}
return nil, fmt.Errorf("查询用户失败: %w", err)
}
if user.Status != 1 {
return nil, fmt.Errorf("用户状态异常")
}
// 2. 检查角色名是否已存在
existingName, err := repository.FindProfileByName(name)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("查询角色名失败: %w", err)
}
if existingName != nil {
return nil, fmt.Errorf("角色名已被使用")
}
// 3. 生成UUID
profileUUID := uuid.New().String()
// 4. 生成RSA密钥对
privateKey, err := generateRSAPrivateKey()
if err != nil {
return nil, fmt.Errorf("生成RSA密钥失败: %w", err)
}
// 5. 创建档案
profile := &model.Profile{
UUID: profileUUID,
UserID: userID,
Name: name,
RSAPrivateKey: privateKey,
IsActive: true, // 新创建的档案默认为活跃状态
}
if err := repository.CreateProfile(profile); err != nil {
return nil, fmt.Errorf("创建档案失败: %w", err)
}
// 6. 将用户的其他档案设置为非活跃
if err := repository.SetActiveProfile(profileUUID, userID); err != nil {
return nil, fmt.Errorf("设置活跃状态失败: %w", err)
}
return profile, nil
}
// GetProfileByUUID 获取档案详情
func GetProfileByUUID(db *gorm.DB, uuid string) (*model.Profile, error) {
profile, err := repository.FindProfileByUUID(uuid)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("档案不存在")
}
return nil, fmt.Errorf("查询档案失败: %w", err)
}
return profile, nil
}
// GetUserProfiles 获取用户的所有档案
func GetUserProfiles(db *gorm.DB, userID int64) ([]*model.Profile, error) {
profiles, err := repository.FindProfilesByUserID(userID)
if err != nil {
return nil, fmt.Errorf("查询档案列表失败: %w", err)
}
return profiles, nil
}
// UpdateProfile 更新档案
func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
// 1. 查询档案
profile, err := repository.FindProfileByUUID(uuid)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("档案不存在")
}
return nil, fmt.Errorf("查询档案失败: %w", err)
}
// 2. 验证权限
if profile.UserID != userID {
return nil, fmt.Errorf("无权操作此档案")
}
// 3. 检查角色名是否重复
if name != nil && *name != profile.Name {
existingName, err := repository.FindProfileByName(*name)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("查询角色名失败: %w", err)
}
if existingName != nil {
return nil, fmt.Errorf("角色名已被使用")
}
profile.Name = *name
}
// 4. 更新皮肤和披风
if skinID != nil {
profile.SkinID = skinID
}
if capeID != nil {
profile.CapeID = capeID
}
// 5. 保存更新
if err := repository.UpdateProfile(profile); err != nil {
return nil, fmt.Errorf("更新档案失败: %w", err)
}
// 6. 重新加载关联数据
return repository.FindProfileByUUID(uuid)
}
// DeleteProfile 删除档案
func DeleteProfile(db *gorm.DB, uuid string, userID int64) error {
// 1. 查询档案
profile, err := repository.FindProfileByUUID(uuid)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("档案不存在")
}
return fmt.Errorf("查询档案失败: %w", err)
}
// 2. 验证权限
if profile.UserID != userID {
return fmt.Errorf("无权操作此档案")
}
// 3. 删除档案
if err := repository.DeleteProfile(uuid); err != nil {
return fmt.Errorf("删除档案失败: %w", err)
}
return nil
}
// SetActiveProfile 设置活跃档案
func SetActiveProfile(db *gorm.DB, uuid string, userID int64) error {
// 1. 查询档案
profile, err := repository.FindProfileByUUID(uuid)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("档案不存在")
}
return fmt.Errorf("查询档案失败: %w", err)
}
// 2. 验证权限
if profile.UserID != userID {
return fmt.Errorf("无权操作此档案")
}
// 3. 设置活跃状态
if err := repository.SetActiveProfile(uuid, userID); err != nil {
return fmt.Errorf("设置活跃状态失败: %w", err)
}
// 4. 更新最后使用时间
if err := repository.UpdateProfileLastUsedAt(uuid); err != nil {
return fmt.Errorf("更新使用时间失败: %w", err)
}
return nil
}
// CheckProfileLimit 检查用户档案数量限制
func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error {
count, err := repository.CountProfilesByUserID(userID)
if err != nil {
return fmt.Errorf("查询档案数量失败: %w", err)
}
if int(count) >= maxProfiles {
return fmt.Errorf("已达到档案数量上限(%d个", maxProfiles)
}
return nil
}
// generateRSAPrivateKey 生成RSA-2048私钥PEM格式
func generateRSAPrivateKey() (string, error) {
// 生成2048位RSA密钥对
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", err
}
// 将私钥编码为PEM格式
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: privateKeyBytes,
})
return string(privateKeyPEM), nil
}
func ValidateProfileByUserID(db *gorm.DB, userId int64, UUID string) (bool, error) {
if userId == 0 || UUID == "" {
return false, errors.New("用户ID或配置文件ID不能为空")
}
profile, err := repository.FindProfileByUUID(UUID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return false, errors.New("配置文件不存在")
}
return false, fmt.Errorf("验证配置文件失败: %w", err)
}
return profile.UserID == userId, nil
}
func GetProfilesDataByNames(db *gorm.DB, names []string) ([]*model.Profile, error) {
profiles, err := repository.GetProfilesByNames(names)
if err != nil {
return nil, fmt.Errorf("查找失败: %w", err)
}
return profiles, nil
}
// GetProfileKeyPair 从 PostgreSQL 获取密钥对GORM 实现,无手动 SQL
func GetProfileKeyPair(db *gorm.DB, profileId string) (*model.KeyPair, error) {
keyPair, err := repository.GetProfileKeyPair(profileId)
if err != nil {
return nil, fmt.Errorf("查找失败: %w", err)
}
return keyPair, nil
}

View File

@@ -0,0 +1,406 @@
package service
import (
"testing"
)
// TestProfileService_Validation 测试Profile服务验证逻辑
func TestProfileService_Validation(t *testing.T) {
tests := []struct {
name string
userID int64
profileName string
wantValid bool
}{
{
name: "有效的用户ID和角色名",
userID: 1,
profileName: "TestProfile",
wantValid: true,
},
{
name: "用户ID为0时无效",
userID: 0,
profileName: "TestProfile",
wantValid: false,
},
{
name: "角色名为空时无效",
userID: 1,
profileName: "",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.userID > 0 && tt.profileName != ""
if isValid != tt.wantValid {
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestProfileService_StatusValidation 测试用户状态验证
func TestProfileService_StatusValidation(t *testing.T) {
tests := []struct {
name string
status int16
wantValid bool
}{
{
name: "状态为1正常时有效",
status: 1,
wantValid: true,
},
{
name: "状态为0禁用时无效",
status: 0,
wantValid: false,
},
{
name: "状态为-1删除时无效",
status: -1,
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.status == 1
if isValid != tt.wantValid {
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestProfileService_IsActiveDefault 测试Profile默认活跃状态
func TestProfileService_IsActiveDefault(t *testing.T) {
// 新创建的档案默认为活跃状态
isActive := true
if !isActive {
t.Error("新创建的Profile应该默认为活跃状态")
}
}
// TestUpdateProfile_PermissionCheck 测试更新Profile的权限检查逻辑
func TestUpdateProfile_PermissionCheck(t *testing.T) {
tests := []struct {
name string
profileUserID int64
requestUserID int64
wantErr bool
}{
{
name: "用户ID匹配允许操作",
profileUserID: 1,
requestUserID: 1,
wantErr: false,
},
{
name: "用户ID不匹配拒绝操作",
profileUserID: 1,
requestUserID: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.profileUserID != tt.requestUserID
if hasError != tt.wantErr {
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestUpdateProfile_NameValidation 测试更新Profile时名称验证逻辑
func TestUpdateProfile_NameValidation(t *testing.T) {
tests := []struct {
name string
currentName string
newName *string
shouldCheck bool
}{
{
name: "名称未改变,不检查",
currentName: "TestProfile",
newName: stringPtr("TestProfile"),
shouldCheck: false,
},
{
name: "名称改变,需要检查",
currentName: "TestProfile",
newName: stringPtr("NewProfile"),
shouldCheck: true,
},
{
name: "名称为nil不检查",
currentName: "TestProfile",
newName: nil,
shouldCheck: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldCheck := tt.newName != nil && *tt.newName != tt.currentName
if shouldCheck != tt.shouldCheck {
t.Errorf("Name validation check failed: got %v, want %v", shouldCheck, tt.shouldCheck)
}
})
}
}
// TestDeleteProfile_PermissionCheck 测试删除Profile的权限检查
func TestDeleteProfile_PermissionCheck(t *testing.T) {
tests := []struct {
name string
profileUserID int64
requestUserID int64
wantErr bool
}{
{
name: "用户ID匹配允许删除",
profileUserID: 1,
requestUserID: 1,
wantErr: false,
},
{
name: "用户ID不匹配拒绝删除",
profileUserID: 1,
requestUserID: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.profileUserID != tt.requestUserID
if hasError != tt.wantErr {
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestSetActiveProfile_PermissionCheck 测试设置活跃Profile的权限检查
func TestSetActiveProfile_PermissionCheck(t *testing.T) {
tests := []struct {
name string
profileUserID int64
requestUserID int64
wantErr bool
}{
{
name: "用户ID匹配允许设置",
profileUserID: 1,
requestUserID: 1,
wantErr: false,
},
{
name: "用户ID不匹配拒绝设置",
profileUserID: 1,
requestUserID: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.profileUserID != tt.requestUserID
if hasError != tt.wantErr {
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestCheckProfileLimit_Logic 测试Profile数量限制检查逻辑
func TestCheckProfileLimit_Logic(t *testing.T) {
tests := []struct {
name string
count int
maxProfiles int
wantErr bool
}{
{
name: "未达到上限",
count: 5,
maxProfiles: 10,
wantErr: false,
},
{
name: "达到上限",
count: 10,
maxProfiles: 10,
wantErr: true,
},
{
name: "超过上限",
count: 15,
maxProfiles: 10,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.count >= tt.maxProfiles
if hasError != tt.wantErr {
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestValidateProfileByUserID_InputValidation 测试ValidateProfileByUserID输入验证
func TestValidateProfileByUserID_InputValidation(t *testing.T) {
tests := []struct {
name string
userID int64
uuid string
wantErr bool
}{
{
name: "有效输入",
userID: 1,
uuid: "test-uuid",
wantErr: false,
},
{
name: "userID为0",
userID: 0,
uuid: "test-uuid",
wantErr: true,
},
{
name: "uuid为空",
userID: 1,
uuid: "",
wantErr: true,
},
{
name: "两者都无效",
userID: 0,
uuid: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.userID == 0 || tt.uuid == ""
if hasError != tt.wantErr {
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestValidateProfileByUserID_UserIDMatching 测试用户ID匹配逻辑
func TestValidateProfileByUserID_UserIDMatching(t *testing.T) {
tests := []struct {
name string
profileUserID int64
requestUserID int64
wantValid bool
}{
{
name: "用户ID匹配",
profileUserID: 1,
requestUserID: 1,
wantValid: true,
},
{
name: "用户ID不匹配",
profileUserID: 1,
requestUserID: 2,
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.profileUserID == tt.requestUserID
if isValid != tt.wantValid {
t.Errorf("UserID matching failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestGenerateRSAPrivateKey 测试RSA私钥生成
func TestGenerateRSAPrivateKey(t *testing.T) {
tests := []struct {
name string
wantError bool
}{
{
name: "生成RSA私钥",
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
privateKey, err := generateRSAPrivateKey()
if (err != nil) != tt.wantError {
t.Errorf("generateRSAPrivateKey() error = %v, wantError %v", err, tt.wantError)
return
}
if !tt.wantError {
if privateKey == "" {
t.Error("generateRSAPrivateKey() 返回的私钥不应为空")
}
// 验证PEM格式
if len(privateKey) < 100 {
t.Errorf("generateRSAPrivateKey() 返回的私钥长度异常: %d", len(privateKey))
}
// 验证包含PEM头部
if !contains(privateKey, "BEGIN RSA PRIVATE KEY") {
t.Error("generateRSAPrivateKey() 返回的私钥应包含PEM头部")
}
}
})
}
}
// TestGenerateRSAPrivateKey_Uniqueness 测试RSA私钥唯一性
func TestGenerateRSAPrivateKey_Uniqueness(t *testing.T) {
keys := make(map[string]bool)
for i := 0; i < 10; i++ {
key, err := generateRSAPrivateKey()
if err != nil {
t.Fatalf("generateRSAPrivateKey() 失败: %v", err)
}
if keys[key] {
t.Errorf("第%d次生成的密钥与之前重复", i+1)
}
keys[key] = true
}
}
// 辅助函数
func stringPtr(s string) *string {
return &s
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr ||
(len(s) > len(substr) && (s[:len(substr)] == substr ||
s[len(s)-len(substr):] == substr ||
containsMiddle(s, substr))))
}
func containsMiddle(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,97 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/pkg/redis"
"encoding/base64"
"go.uber.org/zap"
"time"
"gorm.io/gorm"
)
type Property struct {
Name string `json:"name"`
Value string `json:"value"`
Signature string `json:"signature,omitempty"`
}
func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, p model.Profile) map[string]interface{} {
var err error
// 创建基本材质数据
texturesMap := make(map[string]interface{})
textures := map[string]interface{}{
"timestamp": time.Now().UnixMilli(),
"profileId": p.UUID,
"profileName": p.Name,
"textures": texturesMap,
}
// 处理皮肤
if p.SkinID != nil {
skin, err := GetTextureByID(db, *p.SkinID)
if err != nil {
logger.Error("[ERROR] 获取皮肤失败:", zap.Error(err), zap.Any("SkinID:", *p.SkinID))
} else {
texturesMap["SKIN"] = map[string]interface{}{
"url": skin.URL,
"metadata": skin.Size,
}
}
}
// 处理披风
if p.CapeID != nil {
cape, err := GetTextureByID(db, *p.CapeID)
if err != nil {
logger.Error("[ERROR] 获取披风失败:", zap.Error(err), zap.Any("capeID:", *p.CapeID))
} else {
texturesMap["CAPE"] = map[string]interface{}{
"url": cape.URL,
"metadata": cape.Size,
}
}
}
// 将textures编码为base64
bytes, err := json.Marshal(textures)
if err != nil {
logger.Error("[ERROR] 序列化textures失败: ", zap.Error(err))
return nil
}
textureData := base64.StdEncoding.EncodeToString(bytes)
signature, err := SignStringWithSHA1withRSA(logger, redisClient, textureData)
if err != nil {
logger.Error("[ERROR] 签名textures失败: ", zap.Error(err))
return nil
}
// 构建结果
data := map[string]interface{}{
"id": p.UUID,
"name": p.Name,
"properties": []Property{
{
Name: "textures",
Value: textureData,
Signature: signature,
},
},
}
return data
}
func SerializeUser(logger *zap.Logger, u *model.User, UUID string) map[string]interface{} {
if u == nil {
logger.Error("[ERROR] 尝试序列化空用户")
return nil
}
data := map[string]interface{}{
"id": UUID,
"properties": u.Properties,
}
return data
}

View File

@@ -0,0 +1,172 @@
package service
import (
"carrotskin/internal/model"
"testing"
"go.uber.org/zap/zaptest"
)
// TestSerializeUser_NilUser 实际调用SerializeUser函数测试nil用户
func TestSerializeUser_NilUser(t *testing.T) {
logger := zaptest.NewLogger(t)
result := SerializeUser(logger, nil, "test-uuid")
if result != nil {
t.Error("SerializeUser() 对于nil用户应返回nil")
}
}
// TestSerializeUser_ActualCall 实际调用SerializeUser函数
func TestSerializeUser_ActualCall(t *testing.T) {
logger := zaptest.NewLogger(t)
user := &model.User{
ID: 1,
Username: "testuser",
Email: "test@example.com",
Properties: "{}",
}
result := SerializeUser(logger, user, "test-uuid-123")
if result == nil {
t.Fatal("SerializeUser() 返回的结果不应为nil")
}
if result["id"] != "test-uuid-123" {
t.Errorf("id = %v, want 'test-uuid-123'", result["id"])
}
if result["properties"] == nil {
t.Error("properties 不应为nil")
}
}
// TestProperty_Structure 测试Property结构
func TestProperty_Structure(t *testing.T) {
prop := Property{
Name: "textures",
Value: "base64value",
Signature: "signature",
}
if prop.Name == "" {
t.Error("Property name should not be empty")
}
if prop.Value == "" {
t.Error("Property value should not be empty")
}
// Signature是可选的
if prop.Signature == "" {
t.Log("Property signature is optional")
}
}
// TestSerializeService_PropertyFields 测试Property字段
func TestSerializeService_PropertyFields(t *testing.T) {
tests := []struct {
name string
property Property
wantValid bool
}{
{
name: "有效的Property",
property: Property{
Name: "textures",
Value: "base64value",
Signature: "signature",
},
wantValid: true,
},
{
name: "缺少Name的Property",
property: Property{
Name: "",
Value: "base64value",
Signature: "signature",
},
wantValid: false,
},
{
name: "缺少Value的Property",
property: Property{
Name: "textures",
Value: "",
Signature: "signature",
},
wantValid: false,
},
{
name: "没有Signature的Property有效",
property: Property{
Name: "textures",
Value: "base64value",
Signature: "",
},
wantValid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.property.Name != "" && tt.property.Value != ""
if isValid != tt.wantValid {
t.Errorf("Property validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestSerializeUser_InputValidation 测试SerializeUser输入验证
func TestSerializeUser_InputValidation(t *testing.T) {
tests := []struct {
name string
user *struct{}
wantValid bool
}{
{
name: "用户不为nil",
user: &struct{}{},
wantValid: true,
},
{
name: "用户为nil",
user: nil,
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.user != nil
if isValid != tt.wantValid {
t.Errorf("Input validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestSerializeProfile_Structure 测试SerializeProfile返回结构
func TestSerializeProfile_Structure(t *testing.T) {
// 测试返回的数据结构应该包含的字段
expectedFields := []string{"id", "name", "properties"}
// 验证字段名称
for _, field := range expectedFields {
if field == "" {
t.Error("Field name should not be empty")
}
}
// 验证properties应该是数组
// 注意:这里只测试逻辑,不测试实际序列化
}
// TestSerializeProfile_PropertyName 测试Property名称
func TestSerializeProfile_PropertyName(t *testing.T) {
// textures是固定的属性名
propertyName := "textures"
if propertyName != "textures" {
t.Errorf("Property name = %s, want 'textures'", propertyName)
}
}

View File

@@ -0,0 +1,605 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/redis"
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/pem"
"fmt"
"go.uber.org/zap"
"strconv"
"strings"
"time"
"gorm.io/gorm"
)
// 常量定义
const (
// RSA密钥长度
RSAKeySize = 4096
// Redis密钥名称
PrivateKeyRedisKey = "private_key"
PublicKeyRedisKey = "public_key"
// 密钥过期时间
KeyExpirationTime = time.Hour * 24 * 7
// 证书相关
CertificateRefreshInterval = time.Hour * 24 // 证书刷新时间间隔
CertificateExpirationPeriod = time.Hour * 24 * 7 // 证书过期时间
)
// PlayerCertificate 表示玩家证书信息
type PlayerCertificate struct {
ExpiresAt string `json:"expiresAt"`
RefreshedAfter string `json:"refreshedAfter"`
PublicKeySignature string `json:"publicKeySignature,omitempty"`
PublicKeySignatureV2 string `json:"publicKeySignatureV2,omitempty"`
KeyPair struct {
PrivateKey string `json:"privateKey"`
PublicKey string `json:"publicKey"`
} `json:"keyPair"`
}
// SignatureService 保留结构体以保持向后兼容,但推荐使用函数式版本
type SignatureService struct {
logger *zap.Logger
redisClient *redis.Client
}
func NewSignatureService(logger *zap.Logger, redisClient *redis.Client) *SignatureService {
return &SignatureService{
logger: logger,
redisClient: redisClient,
}
}
// SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串并返回Base64编码的签名函数式版本
func SignStringWithSHA1withRSA(logger *zap.Logger, redisClient *redis.Client, data string) (string, error) {
if data == "" {
return "", fmt.Errorf("签名数据不能为空")
}
// 获取私钥
privateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
if err != nil {
logger.Error("[ERROR] 解码私钥失败: ", zap.Error(err))
return "", fmt.Errorf("解码私钥失败: %w", err)
}
// 计算SHA1哈希
hashed := sha1.Sum([]byte(data))
// 使用RSA-PKCS1v15算法签名
signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA1, hashed[:])
if err != nil {
logger.Error("[ERROR] RSA签名失败: ", zap.Error(err))
return "", fmt.Errorf("RSA签名失败: %w", err)
}
// Base64编码签名
encodedSignature := base64.StdEncoding.EncodeToString(signature)
logger.Info("[INFO] 成功使用SHA1withRSA生成签名,", zap.Any("数据长度:", len(data)))
return encodedSignature, nil
}
// SignStringWithSHA1withRSAService 使用SHA1withRSA签名字符串并返回Base64编码的签名结构体方法版本保持向后兼容
func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) {
return SignStringWithSHA1withRSA(s.logger, s.redisClient, data)
}
// DecodePrivateKeyFromPEM 从Redis获取并解码PEM格式的私钥函数式版本
func DecodePrivateKeyFromPEM(logger *zap.Logger, redisClient *redis.Client) (*rsa.PrivateKey, error) {
// 从Redis获取私钥
privateKeyString, err := GetPrivateKeyFromRedis(logger, redisClient)
if err != nil {
return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
}
// 解码PEM格式
privateKeyBlock, rest := pem.Decode([]byte(privateKeyString))
if privateKeyBlock == nil || len(rest) > 0 {
logger.Error("[ERROR] 无效的PEM格式私钥")
return nil, fmt.Errorf("无效的PEM格式私钥")
}
// 解析PKCS1格式的私钥
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
logger.Error("[ERROR] 解析私钥失败: ", zap.Error(err))
return nil, fmt.Errorf("解析私钥失败: %w", err)
}
return privateKey, nil
}
// GetPrivateKeyFromRedis 从Redis获取私钥PEM格式函数式版本
func GetPrivateKeyFromRedis(logger *zap.Logger, redisClient *redis.Client) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
pemBytes, err := redisClient.GetBytes(ctx, PrivateKeyRedisKey)
if err != nil {
logger.Info("[INFO] 从Redis获取私钥失败尝试生成新的密钥对: ", zap.Error(err))
// 生成新的密钥对
err = GenerateRSAKeyPair(logger, redisClient)
if err != nil {
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
}
// 递归获取生成的密钥
return GetPrivateKeyFromRedis(logger, redisClient)
}
return string(pemBytes), nil
}
// DecodePrivateKeyFromPEMService 从Redis获取并解码PEM格式的私钥结构体方法版本保持向后兼容
func (s *SignatureService) DecodePrivateKeyFromPEM() (*rsa.PrivateKey, error) {
return DecodePrivateKeyFromPEM(s.logger, s.redisClient)
}
// GetPrivateKeyFromRedisService 从Redis获取私钥PEM格式结构体方法版本保持向后兼容
func (s *SignatureService) GetPrivateKeyFromRedis() (string, error) {
return GetPrivateKeyFromRedis(s.logger, s.redisClient)
}
// GenerateRSAKeyPair 生成新的RSA密钥对函数式版本
func GenerateRSAKeyPair(logger *zap.Logger, redisClient *redis.Client) error {
logger.Info("[INFO] 开始生成RSA密钥对", zap.Int("keySize", RSAKeySize))
// 生成私钥
privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
if err != nil {
logger.Error("[ERROR] 生成RSA私钥失败: ", zap.Error(err))
return fmt.Errorf("生成RSA私钥失败: %w", err)
}
// 编码私钥为PEM格式
pemPrivateKey, err := EncodePrivateKeyToPEM(privateKey)
if err != nil {
logger.Error("[ERROR] 编码RSA私钥失败: ", zap.Error(err))
return fmt.Errorf("编码RSA私钥失败: %w", err)
}
// 获取公钥并编码为PEM格式
pubKey := privateKey.PublicKey
pemPublicKey, err := EncodePublicKeyToPEM(logger, &pubKey)
if err != nil {
logger.Error("[ERROR] 编码RSA公钥失败: ", zap.Error(err))
return fmt.Errorf("编码RSA公钥失败: %w", err)
}
// 保存密钥对到Redis
return SaveKeyPairToRedis(logger, redisClient, string(pemPrivateKey), string(pemPublicKey))
}
// GenerateRSAKeyPairService 生成新的RSA密钥对结构体方法版本保持向后兼容
func (s *SignatureService) GenerateRSAKeyPair() error {
return GenerateRSAKeyPair(s.logger, s.redisClient)
}
// EncodePrivateKeyToPEM 将私钥编码为PEM格式函数式版本
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
if privateKey == nil {
return nil, fmt.Errorf("私钥不能为空")
}
// 默认使用 "PRIVATE KEY" 类型
pemType := "PRIVATE KEY"
// 如果指定了类型参数且为 "RSA",则使用 "RSA PRIVATE KEY"
if len(keyType) > 0 && keyType[0] == "RSA" {
pemType = "RSA PRIVATE KEY"
}
// 将私钥转换为PKCS1格式
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
// 编码为PEM格式
pemBlock := &pem.Block{
Type: pemType,
Bytes: privateKeyBytes,
}
return pem.EncodeToMemory(pemBlock), nil
}
// EncodePublicKeyToPEM 将公钥编码为PEM格式函数式版本
func EncodePublicKeyToPEM(logger *zap.Logger, publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
if publicKey == nil {
return nil, fmt.Errorf("公钥不能为空")
}
// 默认使用 "PUBLIC KEY" 类型
pemType := "PUBLIC KEY"
var publicKeyBytes []byte
var err error
// 如果指定了类型参数且为 "RSA",则使用 "RSA PUBLIC KEY"
if len(keyType) > 0 && keyType[0] == "RSA" {
pemType = "RSA PUBLIC KEY"
publicKeyBytes = x509.MarshalPKCS1PublicKey(publicKey)
} else {
// 默认将公钥转换为PKIX格式
publicKeyBytes, err = x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
logger.Error("[ERROR] 序列化公钥失败: ", zap.Error(err))
return nil, fmt.Errorf("序列化公钥失败: %w", err)
}
}
// 编码为PEM格式
pemBlock := &pem.Block{
Type: pemType,
Bytes: publicKeyBytes,
}
return pem.EncodeToMemory(pemBlock), nil
}
// SaveKeyPairToRedis 将RSA密钥对保存到Redis函数式版本
func SaveKeyPairToRedis(logger *zap.Logger, redisClient *redis.Client, privateKey, publicKey string) error {
// 创建上下文并设置超时
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
// 使用事务确保两个操作的原子性
tx := redisClient.TxPipeline()
tx.Set(ctx, PrivateKeyRedisKey, privateKey, KeyExpirationTime)
tx.Set(ctx, PublicKeyRedisKey, publicKey, KeyExpirationTime)
// 执行事务
_, err := tx.Exec(ctx)
if err != nil {
logger.Error("[ERROR] 保存RSA密钥对到Redis失败: ", zap.Error(err))
return fmt.Errorf("保存RSA密钥对到Redis失败: %w", err)
}
logger.Info("[INFO] 成功保存RSA密钥对到Redis")
return nil
}
// EncodePrivateKeyToPEMService 将私钥编码为PEM格式结构体方法版本保持向后兼容
func (s *SignatureService) EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
return EncodePrivateKeyToPEM(privateKey, keyType...)
}
// EncodePublicKeyToPEMService 将公钥编码为PEM格式结构体方法版本保持向后兼容
func (s *SignatureService) EncodePublicKeyToPEM(publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
return EncodePublicKeyToPEM(s.logger, publicKey, keyType...)
}
// SaveKeyPairToRedisService 将RSA密钥对保存到Redis结构体方法版本保持向后兼容
func (s *SignatureService) SaveKeyPairToRedis(privateKey, publicKey string) error {
return SaveKeyPairToRedis(s.logger, s.redisClient, privateKey, publicKey)
}
// GetPublicKeyFromRedisFunc 从Redis获取公钥PEM格式函数式版本
func GetPublicKeyFromRedisFunc(logger *zap.Logger, redisClient *redis.Client) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
pemBytes, err := redisClient.GetBytes(ctx, PublicKeyRedisKey)
if err != nil {
logger.Info("[INFO] 从Redis获取公钥失败尝试生成新的密钥对: ", zap.Error(err))
// 生成新的密钥对
err = GenerateRSAKeyPair(logger, redisClient)
if err != nil {
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
}
// 递归获取生成的密钥
return GetPublicKeyFromRedisFunc(logger, redisClient)
}
// 检查获取到的公钥是否为空key不存在时GetBytes返回nil, nil
if len(pemBytes) == 0 {
logger.Info("[INFO] Redis中公钥为空尝试生成新的密钥对")
// 生成新的密钥对
err = GenerateRSAKeyPair(logger, redisClient)
if err != nil {
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
}
// 递归获取生成的密钥
return GetPublicKeyFromRedisFunc(logger, redisClient)
}
return string(pemBytes), nil
}
// GetPublicKeyFromRedis 从Redis获取公钥PEM格式结构体方法版本
func (s *SignatureService) GetPublicKeyFromRedis() (string, error) {
return GetPublicKeyFromRedisFunc(s.logger, s.redisClient)
}
// GeneratePlayerCertificate 生成玩家证书(函数式版本)
func GeneratePlayerCertificate(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, uuid string) (*PlayerCertificate, error) {
if uuid == "" {
return nil, fmt.Errorf("UUID不能为空")
}
logger.Info("[INFO] 开始生成玩家证书用户UUID: %s",
zap.String("uuid", uuid),
)
keyPair, err := repository.GetProfileKeyPair(uuid)
if err != nil {
logger.Info("[INFO] 获取用户密钥对失败,将创建新密钥对: %v",
zap.Error(err),
zap.String("uuid", uuid),
)
keyPair = nil
}
// 如果没有找到密钥对或密钥对已过期,创建一个新的
now := time.Now().UTC()
if keyPair == nil || keyPair.Refresh.Before(now) || keyPair.PrivateKey == "" || keyPair.PublicKey == "" {
logger.Info("[INFO] 为用户创建新的密钥对: %s",
zap.String("uuid", uuid),
)
keyPair, err = NewKeyPair(logger)
if err != nil {
logger.Error("[ERROR] 生成玩家证书密钥对失败: %v",
zap.Error(err),
zap.String("uuid", uuid),
)
return nil, fmt.Errorf("生成玩家证书密钥对失败: %w", err)
}
// 保存密钥对到数据库
err = repository.UpdateProfileKeyPair(uuid, keyPair)
if err != nil {
// 日志修改logger → s.loggerzap结构化字段
logger.Warn("[WARN] 更新用户密钥对失败: %v",
zap.Error(err),
zap.String("uuid", uuid),
)
// 继续执行,即使保存失败
}
}
// 计算expiresAt的毫秒时间戳
expiresAtMillis := keyPair.Expiration.UnixMilli()
// 准备签名
publicKeySignature := ""
publicKeySignatureV2 := ""
// 获取服务器私钥用于签名
serverPrivateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
if err != nil {
// 日志修改logger → s.loggerzap结构化字段
logger.Error("[ERROR] 获取服务器私钥失败: %v",
zap.Error(err),
zap.String("uuid", uuid),
)
return nil, fmt.Errorf("获取服务器私钥失败: %w", err)
}
// 提取公钥DER编码
pubPEMBlock, _ := pem.Decode([]byte(keyPair.PublicKey))
if pubPEMBlock == nil {
// 日志修改logger → s.loggerzap结构化字段
logger.Error("[ERROR] 解码公钥PEM失败",
zap.String("uuid", uuid),
zap.String("publicKey", keyPair.PublicKey),
)
return nil, fmt.Errorf("解码公钥PEM失败")
}
pubDER := pubPEMBlock.Bytes
// 准备publicKeySignature用于MC 1.19
// Base64编码公钥不包含换行
pubBase64 := strings.ReplaceAll(base64.StdEncoding.EncodeToString(pubDER), "\n", "")
// 按76字符一行进行包装
pubBase64Wrapped := WrapString(pubBase64, 76)
// 放入PEM格式
pubMojangPEM := "-----BEGIN RSA PUBLIC KEY-----\n" +
pubBase64Wrapped +
"\n-----END RSA PUBLIC KEY-----\n"
// 签名数据: expiresAt毫秒时间戳 + 公钥PEM格式
signedData := []byte(fmt.Sprintf("%d%s", expiresAtMillis, pubMojangPEM))
// 计算SHA1哈希并签名
hash1 := sha1.Sum(signedData)
signature, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash1[:])
if err != nil {
logger.Error("[ERROR] 签名失败: %v",
zap.Error(err),
zap.String("uuid", uuid),
zap.Int64("expiresAtMillis", expiresAtMillis),
)
return nil, fmt.Errorf("签名失败: %w", err)
}
publicKeySignature = base64.StdEncoding.EncodeToString(signature)
// 准备publicKeySignatureV2用于MC 1.19.1+
var uuidBytes []byte
// 如果提供了UUID则使用它
// 移除UUID中的连字符
uuidStr := strings.ReplaceAll(uuid, "-", "")
// 将UUID转换为字节数组16字节
if len(uuidStr) < 32 {
logger.Warn("[WARN] UUID长度不足32字符使用空UUID: %s",
zap.String("uuid", uuid),
zap.String("processedUuidStr", uuidStr),
)
uuidBytes = make([]byte, 16)
} else {
// 解析UUID字符串为字节
uuidBytes = make([]byte, 16)
parseErr := error(nil)
for i := 0; i < 16; i++ {
// 每两个字符转换为一个字节
byteStr := uuidStr[i*2 : i*2+2]
byteVal, err := strconv.ParseUint(byteStr, 16, 8)
if err != nil {
parseErr = err
logger.Error("[ERROR] 解析UUID字节失败: %v, byteStr: %s",
zap.Error(err),
zap.String("uuid", uuid),
zap.String("byteStr", byteStr),
zap.Int("index", i),
)
uuidBytes = make([]byte, 16) // 出错时使用空UUID
break
}
uuidBytes[i] = byte(byteVal)
}
if parseErr != nil {
return nil, fmt.Errorf("解析UUID字节失败: %w", parseErr)
}
}
// 准备签名数据UUID + expiresAt时间戳 + DER编码的公钥
signedDataV2 := make([]byte, 0, 24+len(pubDER)) // 预分配缓冲区
// 添加UUID16字节
signedDataV2 = append(signedDataV2, uuidBytes...)
// 添加expiresAt毫秒时间戳8字节大端序
expiresAtBytes := make([]byte, 8)
binary.BigEndian.PutUint64(expiresAtBytes, uint64(expiresAtMillis))
signedDataV2 = append(signedDataV2, expiresAtBytes...)
// 添加DER编码的公钥
signedDataV2 = append(signedDataV2, pubDER...)
// 计算SHA1哈希并签名
hash2 := sha1.Sum(signedDataV2)
signatureV2, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash2[:])
if err != nil {
logger.Error("[ERROR] 签名V2失败: %v",
zap.Error(err),
zap.String("uuid", uuid),
zap.Int64("expiresAtMillis", expiresAtMillis),
)
return nil, fmt.Errorf("签名V2失败: %w", err)
}
publicKeySignatureV2 = base64.StdEncoding.EncodeToString(signatureV2)
// 创建玩家证书结构
certificate := &PlayerCertificate{
KeyPair: struct {
PrivateKey string `json:"privateKey"`
PublicKey string `json:"publicKey"`
}{
PrivateKey: keyPair.PrivateKey,
PublicKey: keyPair.PublicKey,
},
PublicKeySignature: publicKeySignature,
PublicKeySignatureV2: publicKeySignatureV2,
ExpiresAt: keyPair.Expiration.Format(time.RFC3339Nano),
RefreshedAfter: keyPair.Refresh.Format(time.RFC3339Nano),
}
logger.Info("[INFO] 成功生成玩家证书,过期时间: %s",
zap.String("uuid", uuid),
zap.String("expiresAt", certificate.ExpiresAt),
zap.String("refreshedAfter", certificate.RefreshedAfter),
)
return certificate, nil
}
// GeneratePlayerCertificateService 生成玩家证书(结构体方法版本,保持向后兼容)
func (s *SignatureService) GeneratePlayerCertificate(uuid string) (*PlayerCertificate, error) {
return GeneratePlayerCertificate(nil, s.logger, s.redisClient, uuid) // TODO: 需要传入db参数
}
// NewKeyPair 生成新的密钥对(函数式版本)
func NewKeyPair(logger *zap.Logger) (*model.KeyPair, error) {
// 生成新的RSA密钥对用于玩家证书
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) // 对玩家证书使用更小的密钥以提高性能
if err != nil {
logger.Error("[ERROR] 生成玩家证书私钥失败: %v",
zap.Error(err),
)
return nil, fmt.Errorf("生成玩家证书私钥失败: %w", err)
}
// 获取DER编码的密钥
keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
logger.Error("[ERROR] 编码私钥为PKCS8格式失败: %v",
zap.Error(err),
)
return nil, fmt.Errorf("编码私钥为PKCS8格式失败: %w", err)
}
pubDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
if err != nil {
logger.Error("[ERROR] 编码公钥为PKIX格式失败: %v",
zap.Error(err),
)
return nil, fmt.Errorf("编码公钥为PKIX格式失败: %w", err)
}
// 将密钥编码为PEM格式
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: keyDER,
})
pubPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: pubDER,
})
// 创建证书过期和刷新时间
now := time.Now().UTC()
expiresAtTime := now.Add(CertificateExpirationPeriod)
refreshedAfter := now.Add(CertificateRefreshInterval)
keyPair := &model.KeyPair{
Expiration: expiresAtTime,
PrivateKey: string(keyPEM),
PublicKey: string(pubPEM),
Refresh: refreshedAfter,
}
return keyPair, nil
}
// WrapString 将字符串按指定宽度进行换行(函数式版本)
func WrapString(str string, width int) string {
if width <= 0 {
return str
}
var b strings.Builder
for i := 0; i < len(str); i += width {
end := i + width
if end > len(str) {
end = len(str)
}
b.WriteString(str[i:end])
if end < len(str) {
b.WriteString("\n")
}
}
return b.String()
}
// NewKeyPairService 生成新的密钥对(结构体方法版本,保持向后兼容)
func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) {
return NewKeyPair(s.logger)
}

View File

@@ -0,0 +1,358 @@
package service
import (
"crypto/rand"
"crypto/rsa"
"strings"
"testing"
"time"
"go.uber.org/zap/zaptest"
)
// TestSignatureService_Constants 测试签名服务相关常量
func TestSignatureService_Constants(t *testing.T) {
if RSAKeySize != 4096 {
t.Errorf("RSAKeySize = %d, want 4096", RSAKeySize)
}
if PrivateKeyRedisKey == "" {
t.Error("PrivateKeyRedisKey should not be empty")
}
if PublicKeyRedisKey == "" {
t.Error("PublicKeyRedisKey should not be empty")
}
if KeyExpirationTime != 24*7*time.Hour {
t.Errorf("KeyExpirationTime = %v, want 7 days", KeyExpirationTime)
}
if CertificateRefreshInterval != 24*time.Hour {
t.Errorf("CertificateRefreshInterval = %v, want 24 hours", CertificateRefreshInterval)
}
if CertificateExpirationPeriod != 24*7*time.Hour {
t.Errorf("CertificateExpirationPeriod = %v, want 7 days", CertificateExpirationPeriod)
}
}
// TestSignatureService_DataValidation 测试签名数据验证逻辑
func TestSignatureService_DataValidation(t *testing.T) {
tests := []struct {
name string
data string
wantValid bool
}{
{
name: "非空数据有效",
data: "test data",
wantValid: true,
},
{
name: "空数据无效",
data: "",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.data != ""
if isValid != tt.wantValid {
t.Errorf("Data validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestPlayerCertificate_Structure 测试PlayerCertificate结构
func TestPlayerCertificate_Structure(t *testing.T) {
cert := PlayerCertificate{
ExpiresAt: "2025-01-01T00:00:00Z",
RefreshedAfter: "2025-01-01T00:00:00Z",
PublicKeySignature: "signature",
PublicKeySignatureV2: "signaturev2",
}
// 验证结构体字段
if cert.ExpiresAt == "" {
t.Error("ExpiresAt should not be empty")
}
if cert.RefreshedAfter == "" {
t.Error("RefreshedAfter should not be empty")
}
// PublicKeySignature是可选的
if cert.PublicKeySignature == "" {
t.Log("PublicKeySignature is optional")
}
}
// TestWrapString 测试字符串换行函数
func TestWrapString(t *testing.T) {
tests := []struct {
name string
str string
width int
expected string
}{
{
name: "正常换行",
str: "1234567890",
width: 5,
expected: "12345\n67890",
},
{
name: "字符串长度等于width",
str: "12345",
width: 5,
expected: "12345",
},
{
name: "字符串长度小于width",
str: "123",
width: 5,
expected: "123",
},
{
name: "width为0返回原字符串",
str: "1234567890",
width: 0,
expected: "1234567890",
},
{
name: "width为负数返回原字符串",
str: "1234567890",
width: -1,
expected: "1234567890",
},
{
name: "空字符串",
str: "",
width: 5,
expected: "",
},
{
name: "width为1",
str: "12345",
width: 1,
expected: "1\n2\n3\n4\n5",
},
{
name: "长字符串多次换行",
str: "123456789012345",
width: 5,
expected: "12345\n67890\n12345",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := WrapString(tt.str, tt.width)
if result != tt.expected {
t.Errorf("WrapString(%q, %d) = %q, want %q", tt.str, tt.width, result, tt.expected)
}
})
}
}
// TestWrapString_LineCount 测试换行后的行数
func TestWrapString_LineCount(t *testing.T) {
tests := []struct {
name string
str string
width int
wantLines int
}{
{
name: "10个字符width=5应该2行",
str: "1234567890",
width: 5,
wantLines: 2,
},
{
name: "15个字符width=5应该3行",
str: "123456789012345",
width: 5,
wantLines: 3,
},
{
name: "5个字符width=5应该1行",
str: "12345",
width: 5,
wantLines: 1,
},
{
name: "width为0应该1行",
str: "1234567890",
width: 0,
wantLines: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := WrapString(tt.str, tt.width)
lines := strings.Count(result, "\n") + 1
if lines != tt.wantLines {
t.Errorf("Line count = %d, want %d (result: %q)", lines, tt.wantLines, result)
}
})
}
}
// TestWrapString_NoTrailingNewline 测试末尾不换行
func TestWrapString_NoTrailingNewline(t *testing.T) {
str := "1234567890"
result := WrapString(str, 5)
// 验证末尾没有换行符
if strings.HasSuffix(result, "\n") {
t.Error("Result should not end with newline")
}
// 验证包含换行符(除了最后一行)
if !strings.Contains(result, "\n") {
t.Error("Result should contain newline for multi-line output")
}
}
// TestEncodePrivateKeyToPEM_ActualCall 实际调用EncodePrivateKeyToPEM函数
func TestEncodePrivateKeyToPEM_ActualCall(t *testing.T) {
// 生成测试用的RSA私钥
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("生成RSA私钥失败: %v", err)
}
tests := []struct {
name string
keyType []string
wantError bool
}{
{
name: "默认类型",
keyType: []string{},
wantError: false,
},
{
name: "RSA类型",
keyType: []string{"RSA"},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pemBytes, err := EncodePrivateKeyToPEM(privateKey, tt.keyType...)
if (err != nil) != tt.wantError {
t.Errorf("EncodePrivateKeyToPEM() error = %v, wantError %v", err, tt.wantError)
return
}
if !tt.wantError {
if len(pemBytes) == 0 {
t.Error("EncodePrivateKeyToPEM() 返回的PEM字节不应为空")
}
pemStr := string(pemBytes)
// 验证PEM格式
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
t.Error("EncodePrivateKeyToPEM() 返回的PEM格式不正确")
}
// 验证类型
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
if !strings.Contains(pemStr, "RSA PRIVATE KEY") {
t.Error("EncodePrivateKeyToPEM() 应包含 'RSA PRIVATE KEY'")
}
} else {
if !strings.Contains(pemStr, "PRIVATE KEY") {
t.Error("EncodePrivateKeyToPEM() 应包含 'PRIVATE KEY'")
}
}
}
})
}
}
// TestEncodePublicKeyToPEM_ActualCall 实际调用EncodePublicKeyToPEM函数
func TestEncodePublicKeyToPEM_ActualCall(t *testing.T) {
logger := zaptest.NewLogger(t)
// 生成测试用的RSA密钥对
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("生成RSA密钥对失败: %v", err)
}
publicKey := &privateKey.PublicKey
tests := []struct {
name string
keyType []string
wantError bool
}{
{
name: "默认类型",
keyType: []string{},
wantError: false,
},
{
name: "RSA类型",
keyType: []string{"RSA"},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pemBytes, err := EncodePublicKeyToPEM(logger, publicKey, tt.keyType...)
if (err != nil) != tt.wantError {
t.Errorf("EncodePublicKeyToPEM() error = %v, wantError %v", err, tt.wantError)
return
}
if !tt.wantError {
if len(pemBytes) == 0 {
t.Error("EncodePublicKeyToPEM() 返回的PEM字节不应为空")
}
pemStr := string(pemBytes)
// 验证PEM格式
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
t.Error("EncodePublicKeyToPEM() 返回的PEM格式不正确")
}
// 验证类型
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
if !strings.Contains(pemStr, "RSA PUBLIC KEY") {
t.Error("EncodePublicKeyToPEM() 应包含 'RSA PUBLIC KEY'")
}
} else {
if !strings.Contains(pemStr, "PUBLIC KEY") {
t.Error("EncodePublicKeyToPEM() 应包含 'PUBLIC KEY'")
}
}
}
})
}
}
// TestEncodePublicKeyToPEM_NilKey 测试nil公钥
func TestEncodePublicKeyToPEM_NilKey(t *testing.T) {
logger := zaptest.NewLogger(t)
_, err := EncodePublicKeyToPEM(logger, nil)
if err == nil {
t.Error("EncodePublicKeyToPEM() 对于nil公钥应返回错误")
}
}
// TestNewSignatureService 测试创建SignatureService
func TestNewSignatureService(t *testing.T) {
logger := zaptest.NewLogger(t)
// 注意这里需要实际的redis client但我们只测试结构体创建
// 在实际测试中可以使用mock redis client
service := NewSignatureService(logger, nil)
if service == nil {
t.Error("NewSignatureService() 不应返回nil")
}
if service.logger != logger {
t.Error("NewSignatureService() logger 设置不正确")
}
}

View File

@@ -0,0 +1,251 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"errors"
"fmt"
"gorm.io/gorm"
)
// CreateTexture 创建材质
func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
// 验证用户存在
user, err := repository.FindUserByID(uploaderID)
if err != nil {
return nil, err
}
if user == nil {
return nil, errors.New("用户不存在")
}
// 检查Hash是否已存在
existingTexture, err := repository.FindTextureByHash(hash)
if err != nil {
return nil, err
}
if existingTexture != nil {
return nil, errors.New("该材质已存在")
}
// 转换材质类型
var textureTypeEnum model.TextureType
switch textureType {
case "SKIN":
textureTypeEnum = model.TextureTypeSkin
case "CAPE":
textureTypeEnum = model.TextureTypeCape
default:
return nil, errors.New("无效的材质类型")
}
// 创建材质
texture := &model.Texture{
UploaderID: uploaderID,
Name: name,
Description: description,
Type: textureTypeEnum,
URL: url,
Hash: hash,
Size: size,
IsPublic: isPublic,
IsSlim: isSlim,
Status: 1,
DownloadCount: 0,
FavoriteCount: 0,
}
if err := repository.CreateTexture(texture); err != nil {
return nil, err
}
return texture, nil
}
// GetTextureByID 根据ID获取材质
func GetTextureByID(db *gorm.DB, id int64) (*model.Texture, error) {
texture, err := repository.FindTextureByID(id)
if err != nil {
return nil, err
}
if texture == nil {
return nil, errors.New("材质不存在")
}
if texture.Status == -1 {
return nil, errors.New("材质已删除")
}
return texture, nil
}
// GetUserTextures 获取用户上传的材质列表
func GetUserTextures(db *gorm.DB, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
return repository.FindTexturesByUploaderID(uploaderID, page, pageSize)
}
// SearchTextures 搜索材质
func SearchTextures(db *gorm.DB, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
return repository.SearchTextures(keyword, textureType, publicOnly, page, pageSize)
}
// UpdateTexture 更新材质
func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
// 获取材质
texture, err := repository.FindTextureByID(textureID)
if err != nil {
return nil, err
}
if texture == nil {
return nil, errors.New("材质不存在")
}
// 检查权限:只有上传者可以修改
if texture.UploaderID != uploaderID {
return nil, errors.New("无权修改此材质")
}
// 更新字段
updates := make(map[string]interface{})
if name != "" {
updates["name"] = name
}
if description != "" {
updates["description"] = description
}
if isPublic != nil {
updates["is_public"] = *isPublic
}
if len(updates) > 0 {
if err := repository.UpdateTextureFields(textureID, updates); err != nil {
return nil, err
}
}
// 返回更新后的材质
return repository.FindTextureByID(textureID)
}
// DeleteTexture 删除材质
func DeleteTexture(db *gorm.DB, textureID, uploaderID int64) error {
// 获取材质
texture, err := repository.FindTextureByID(textureID)
if err != nil {
return err
}
if texture == nil {
return errors.New("材质不存在")
}
// 检查权限:只有上传者可以删除
if texture.UploaderID != uploaderID {
return errors.New("无权删除此材质")
}
return repository.DeleteTexture(textureID)
}
// RecordTextureDownload 记录下载
func RecordTextureDownload(db *gorm.DB, textureID int64, userID *int64, ipAddress, userAgent string) error {
// 检查材质是否存在
texture, err := repository.FindTextureByID(textureID)
if err != nil {
return err
}
if texture == nil {
return errors.New("材质不存在")
}
// 增加下载次数
if err := repository.IncrementTextureDownloadCount(textureID); err != nil {
return err
}
// 创建下载日志
log := &model.TextureDownloadLog{
TextureID: textureID,
UserID: userID,
IPAddress: ipAddress,
UserAgent: userAgent,
}
return repository.CreateTextureDownloadLog(log)
}
// ToggleTextureFavorite 切换收藏状态
func ToggleTextureFavorite(db *gorm.DB, userID, textureID int64) (bool, error) {
// 检查材质是否存在
texture, err := repository.FindTextureByID(textureID)
if err != nil {
return false, err
}
if texture == nil {
return false, errors.New("材质不存在")
}
// 检查是否已收藏
isFavorited, err := repository.IsTextureFavorited(userID, textureID)
if err != nil {
return false, err
}
if isFavorited {
// 取消收藏
if err := repository.RemoveTextureFavorite(userID, textureID); err != nil {
return false, err
}
if err := repository.DecrementTextureFavoriteCount(textureID); err != nil {
return false, err
}
return false, nil
} else {
// 添加收藏
if err := repository.AddTextureFavorite(userID, textureID); err != nil {
return false, err
}
if err := repository.IncrementTextureFavoriteCount(textureID); err != nil {
return false, err
}
return true, nil
}
}
// GetUserTextureFavorites 获取用户收藏的材质列表
func GetUserTextureFavorites(db *gorm.DB, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
return repository.GetUserTextureFavorites(userID, page, pageSize)
}
// CheckTextureUploadLimit 检查用户上传材质数量限制
func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) error {
count, err := repository.CountTexturesByUploaderID(uploaderID)
if err != nil {
return err
}
if count >= int64(maxTextures) {
return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures)
}
return nil
}

View File

@@ -0,0 +1,471 @@
package service
import (
"testing"
)
// TestTextureService_TypeValidation 测试材质类型验证
func TestTextureService_TypeValidation(t *testing.T) {
tests := []struct {
name string
textureType string
wantValid bool
}{
{
name: "SKIN类型有效",
textureType: "SKIN",
wantValid: true,
},
{
name: "CAPE类型有效",
textureType: "CAPE",
wantValid: true,
},
{
name: "无效类型",
textureType: "INVALID",
wantValid: false,
},
{
name: "空类型无效",
textureType: "",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.textureType == "SKIN" || tt.textureType == "CAPE"
if isValid != tt.wantValid {
t.Errorf("Texture type validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestTextureService_DefaultValues 测试材质默认值
func TestTextureService_DefaultValues(t *testing.T) {
// 测试默认状态
defaultStatus := 1
if defaultStatus != 1 {
t.Errorf("默认状态应为1实际为%d", defaultStatus)
}
// 测试默认下载数
defaultDownloadCount := 0
if defaultDownloadCount != 0 {
t.Errorf("默认下载数应为0实际为%d", defaultDownloadCount)
}
// 测试默认收藏数
defaultFavoriteCount := 0
if defaultFavoriteCount != 0 {
t.Errorf("默认收藏数应为0实际为%d", defaultFavoriteCount)
}
}
// TestTextureService_StatusValidation 测试材质状态验证
func TestTextureService_StatusValidation(t *testing.T) {
tests := []struct {
name string
status int16
wantValid bool
}{
{
name: "状态为1正常时有效",
status: 1,
wantValid: true,
},
{
name: "状态为-1删除时无效",
status: -1,
wantValid: false,
},
{
name: "状态为0时可能有效取决于业务逻辑",
status: 0,
wantValid: true, // 状态为0禁用材质仍然存在只是不可用但查询时不会返回错误
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 材质状态为-1时表示已删除无效
isValid := tt.status != -1
if isValid != tt.wantValid {
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestGetUserTextures_Pagination 测试分页逻辑
func TestGetUserTextures_Pagination(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
wantPage int
wantSize int
}{
{
name: "有效的分页参数",
page: 2,
pageSize: 20,
wantPage: 2,
wantSize: 20,
},
{
name: "page小于1应该设为1",
page: 0,
pageSize: 20,
wantPage: 1,
wantSize: 20,
},
{
name: "pageSize小于1应该设为20",
page: 1,
pageSize: 0,
wantPage: 1,
wantSize: 20,
},
{
name: "pageSize超过100应该设为20",
page: 1,
pageSize: 200,
wantPage: 1,
wantSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
page := tt.page
pageSize := tt.pageSize
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
if page != tt.wantPage {
t.Errorf("Page = %d, want %d", page, tt.wantPage)
}
if pageSize != tt.wantSize {
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
}
})
}
}
// TestSearchTextures_Pagination 测试搜索分页逻辑
func TestSearchTextures_Pagination(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
wantPage int
wantSize int
}{
{
name: "有效的分页参数",
page: 1,
pageSize: 10,
wantPage: 1,
wantSize: 10,
},
{
name: "page小于1应该设为1",
page: -1,
pageSize: 20,
wantPage: 1,
wantSize: 20,
},
{
name: "pageSize超过100应该设为20",
page: 1,
pageSize: 150,
wantPage: 1,
wantSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
page := tt.page
pageSize := tt.pageSize
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
if page != tt.wantPage {
t.Errorf("Page = %d, want %d", page, tt.wantPage)
}
if pageSize != tt.wantSize {
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
}
})
}
}
// TestUpdateTexture_PermissionCheck 测试更新材质的权限检查
func TestUpdateTexture_PermissionCheck(t *testing.T) {
tests := []struct {
name string
uploaderID int64
requestID int64
wantErr bool
}{
{
name: "上传者ID匹配允许更新",
uploaderID: 1,
requestID: 1,
wantErr: false,
},
{
name: "上传者ID不匹配拒绝更新",
uploaderID: 1,
requestID: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.uploaderID != tt.requestID
if hasError != tt.wantErr {
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestUpdateTexture_FieldUpdates 测试更新字段逻辑
func TestUpdateTexture_FieldUpdates(t *testing.T) {
tests := []struct {
name string
nameValue string
descValue string
isPublic *bool
wantUpdates int
}{
{
name: "更新所有字段",
nameValue: "NewName",
descValue: "NewDesc",
isPublic: boolPtr(true),
wantUpdates: 3,
},
{
name: "只更新名称",
nameValue: "NewName",
descValue: "",
isPublic: nil,
wantUpdates: 1,
},
{
name: "只更新描述",
nameValue: "",
descValue: "NewDesc",
isPublic: nil,
wantUpdates: 1,
},
{
name: "只更新公开状态",
nameValue: "",
descValue: "",
isPublic: boolPtr(false),
wantUpdates: 1,
},
{
name: "没有更新",
nameValue: "",
descValue: "",
isPublic: nil,
wantUpdates: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
updates := 0
if tt.nameValue != "" {
updates++
}
if tt.descValue != "" {
updates++
}
if tt.isPublic != nil {
updates++
}
if updates != tt.wantUpdates {
t.Errorf("Updates count = %d, want %d", updates, tt.wantUpdates)
}
})
}
}
// TestDeleteTexture_PermissionCheck 测试删除材质的权限检查
func TestDeleteTexture_PermissionCheck(t *testing.T) {
tests := []struct {
name string
uploaderID int64
requestID int64
wantErr bool
}{
{
name: "上传者ID匹配允许删除",
uploaderID: 1,
requestID: 1,
wantErr: false,
},
{
name: "上传者ID不匹配拒绝删除",
uploaderID: 1,
requestID: 2,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.uploaderID != tt.requestID
if hasError != tt.wantErr {
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestToggleTextureFavorite_Logic 测试切换收藏状态的逻辑
func TestToggleTextureFavorite_Logic(t *testing.T) {
tests := []struct {
name string
isFavorited bool
wantResult bool
}{
{
name: "已收藏,取消收藏",
isFavorited: true,
wantResult: false,
},
{
name: "未收藏,添加收藏",
isFavorited: false,
wantResult: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := !tt.isFavorited
if result != tt.wantResult {
t.Errorf("Toggle favorite failed: got %v, want %v", result, tt.wantResult)
}
})
}
}
// TestGetUserTextureFavorites_Pagination 测试收藏列表分页
func TestGetUserTextureFavorites_Pagination(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
wantPage int
wantSize int
}{
{
name: "有效的分页参数",
page: 1,
pageSize: 20,
wantPage: 1,
wantSize: 20,
},
{
name: "page小于1应该设为1",
page: 0,
pageSize: 20,
wantPage: 1,
wantSize: 20,
},
{
name: "pageSize超过100应该设为20",
page: 1,
pageSize: 200,
wantPage: 1,
wantSize: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
page := tt.page
pageSize := tt.pageSize
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
if page != tt.wantPage {
t.Errorf("Page = %d, want %d", page, tt.wantPage)
}
if pageSize != tt.wantSize {
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
}
})
}
}
// TestCheckTextureUploadLimit_Logic 测试上传限制检查逻辑
func TestCheckTextureUploadLimit_Logic(t *testing.T) {
tests := []struct {
name string
count int64
maxTextures int
wantErr bool
}{
{
name: "未达到上限",
count: 5,
maxTextures: 10,
wantErr: false,
},
{
name: "达到上限",
count: 10,
maxTextures: 10,
wantErr: true,
},
{
name: "超过上限",
count: 15,
maxTextures: 10,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.count >= int64(tt.maxTextures)
if hasError != tt.wantErr {
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// 辅助函数
func boolPtr(b bool) *bool {
return &b
}

View File

@@ -0,0 +1,277 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"context"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"go.uber.org/zap"
"strconv"
"time"
"gorm.io/gorm"
)
// 常量定义
const (
ExtendedTimeout = 10 * time.Second
TokensMaxCount = 10 // 用户最多保留的token数量
)
// NewToken 创建新令牌
func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
var (
selectedProfileID *model.Profile
availableProfiles []*model.Profile
)
// 设置超时上下文
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer cancel()
// 验证用户存在
_, err := repository.FindProfileByUUID(UUID)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
}
// 生成令牌
if clientToken == "" {
clientToken = uuid.New().String()
}
accessToken := uuid.New().String()
token := model.Token{
AccessToken: accessToken,
ClientToken: clientToken,
UserID: userId,
Usable: true,
IssueDate: time.Now(),
}
// 获取用户配置文件
profiles, err := repository.FindProfilesByUserID(userId)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
}
// 如果用户只有一个配置文件,自动选择
if len(profiles) == 1 {
selectedProfileID = profiles[0]
token.ProfileId = selectedProfileID.UUID
}
availableProfiles = profiles
// 插入令牌到tokens集合
_, insertCancel := context.WithTimeout(context.Background(), DefaultTimeout)
defer insertCancel()
err = repository.CreateToken(&token)
if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
}
// 清理多余的令牌
go CheckAndCleanupExcessTokens(db, logger, userId)
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
}
// CheckAndCleanupExcessTokens 检查并清理用户多余的令牌只保留最新的10个
func CheckAndCleanupExcessTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
if userId == 0 {
return
}
// 获取用户所有令牌,按发行日期降序排序
tokens, err := repository.GetTokensByUserId(userId)
if err != nil {
logger.Error("[ERROR] 获取用户Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
return
}
// 如果令牌数量不超过上限,无需清理
if len(tokens) <= TokensMaxCount {
return
}
// 获取需要删除的令牌ID列表
tokensToDelete := make([]string, 0, len(tokens)-TokensMaxCount)
for i := TokensMaxCount; i < len(tokens); i++ {
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
}
// 执行批量删除,传入上下文和待删除的令牌列表(作为切片参数)
DeletedCount, err := repository.BatchDeleteTokens(tokensToDelete)
if err != nil {
logger.Error("[ERROR] 清理用户多余Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
return
}
if DeletedCount > 0 {
logger.Info("[INFO] 成功清理用户多余Token", zap.Any("userId:", userId), zap.Any("count:", DeletedCount))
}
}
// ValidToken 验证令牌有效性
func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool {
if accessToken == "" {
return false
}
// 使用投影只获取需要的字段
var token *model.Token
token, err := repository.FindTokenByID(accessToken)
if err != nil {
return false
}
if !token.Usable {
return false
}
// 如果客户端令牌为空,只验证访问令牌
if clientToken == "" {
return true
}
// 否则验证客户端令牌是否匹配
return token.ClientToken == clientToken
}
func GetUUIDByAccessToken(db *gorm.DB, accessToken string) (string, error) {
return repository.GetUUIDByAccessToken(accessToken)
}
func GetUserIDByAccessToken(db *gorm.DB, accessToken string) (int64, error) {
return repository.GetUserIDByAccessToken(accessToken)
}
// RefreshToken 刷新令牌
func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken string, selectedProfileID string) (string, string, error) {
if accessToken == "" {
return "", "", errors.New("accessToken不能为空")
}
// 查找旧令牌
oldToken, err := repository.GetTokenByAccessToken(accessToken)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return "", "", errors.New("accessToken无效")
}
logger.Error("[ERROR] 查询Token失败: ", zap.Error(err), zap.Any("accessToken:", accessToken))
return "", "", fmt.Errorf("查询令牌失败: %w", err)
}
// 验证profile
if selectedProfileID != "" {
valid, validErr := ValidateProfileByUserID(db, oldToken.UserID, selectedProfileID)
if validErr != nil {
logger.Error(
"验证Profile失败",
zap.Error(err),
zap.Any("userId", oldToken.UserID),
zap.String("profileId", selectedProfileID),
)
return "", "", fmt.Errorf("验证角色失败: %w", err)
}
if !valid {
return "", "", errors.New("角色与用户不匹配")
}
}
// 检查 clientToken 是否有效
if clientToken != "" && clientToken != oldToken.ClientToken {
return "", "", errors.New("clientToken无效")
}
// 检查 selectedProfileID 的逻辑
if selectedProfileID != "" {
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
}
} else {
selectedProfileID = oldToken.ProfileId // 如果未指定,则保持原角色
}
// 生成新令牌
newAccessToken := uuid.New().String()
newToken := model.Token{
AccessToken: newAccessToken,
ClientToken: oldToken.ClientToken, // 新令牌的 clientToken 与原令牌相同
UserID: oldToken.UserID,
Usable: true,
ProfileId: selectedProfileID, // 绑定到指定角色或保持原角色
IssueDate: time.Now(),
}
// 使用双重写入模式替代事务,先插入新令牌,再删除旧令牌
err = repository.CreateToken(&newToken)
if err != nil {
logger.Error(
"创建新Token失败",
zap.Error(err),
zap.String("accessToken", accessToken),
)
return "", "", fmt.Errorf("创建新Token失败: %w", err)
}
err = repository.DeleteTokenByAccessToken(accessToken)
if err != nil {
// 删除旧令牌失败,记录日志但不阻止操作,因为新令牌已成功创建
logger.Warn(
"删除旧Token失败但新Token已创建",
zap.Error(err),
zap.String("oldToken", oldToken.AccessToken),
zap.String("newToken", newAccessToken),
)
}
logger.Info(
"成功刷新Token",
zap.Any("userId", oldToken.UserID),
zap.String("accessToken", newAccessToken),
)
return newAccessToken, oldToken.ClientToken, nil
}
// InvalidToken 使令牌失效
func InvalidToken(db *gorm.DB, logger *zap.Logger, accessToken string) {
if accessToken == "" {
return
}
err := repository.DeleteTokenByAccessToken(accessToken)
if err != nil {
logger.Error(
"删除Token失败",
zap.Error(err),
zap.String("accessToken", accessToken),
)
return
}
logger.Info("[INFO] 成功删除", zap.Any("Token:", accessToken))
}
// InvalidUserTokens 使用户所有令牌失效
func InvalidUserTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
if userId == 0 {
return
}
err := repository.DeleteTokenByUserId(userId)
if err != nil {
logger.Error(
"[ERROR]删除用户Token失败",
zap.Error(err),
zap.Any("userId", userId),
)
return
}
logger.Info("[INFO] 成功删除用户Token", zap.Any("userId:", userId))
}

View File

@@ -0,0 +1,204 @@
package service
import (
"testing"
"time"
)
// TestTokenService_Constants 测试Token服务相关常量
func TestTokenService_Constants(t *testing.T) {
if ExtendedTimeout != 10*time.Second {
t.Errorf("ExtendedTimeout = %v, want 10 seconds", ExtendedTimeout)
}
if TokensMaxCount != 10 {
t.Errorf("TokensMaxCount = %d, want 10", TokensMaxCount)
}
}
// TestTokenService_Timeout 测试超时常量
func TestTokenService_Timeout(t *testing.T) {
if DefaultTimeout != 5*time.Second {
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
}
if ExtendedTimeout <= DefaultTimeout {
t.Errorf("ExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", ExtendedTimeout, DefaultTimeout)
}
}
// TestTokenService_Validation 测试Token验证逻辑
func TestTokenService_Validation(t *testing.T) {
tests := []struct {
name string
accessToken string
wantValid bool
}{
{
name: "空token无效",
accessToken: "",
wantValid: false,
},
{
name: "非空token可能有效",
accessToken: "valid-token-string",
wantValid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 测试空token检查逻辑
isValid := tt.accessToken != ""
if isValid != tt.wantValid {
t.Errorf("Token validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestTokenService_ClientTokenLogic 测试ClientToken逻辑
func TestTokenService_ClientTokenLogic(t *testing.T) {
tests := []struct {
name string
clientToken string
shouldGenerate bool
}{
{
name: "空的clientToken应该生成新的",
clientToken: "",
shouldGenerate: true,
},
{
name: "非空的clientToken应该使用提供的",
clientToken: "existing-client-token",
shouldGenerate: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldGenerate := tt.clientToken == ""
if shouldGenerate != tt.shouldGenerate {
t.Errorf("ClientToken logic failed: got %v, want %v", shouldGenerate, tt.shouldGenerate)
}
})
}
}
// TestTokenService_ProfileSelection 测试Profile选择逻辑
func TestTokenService_ProfileSelection(t *testing.T) {
tests := []struct {
name string
profileCount int
shouldAutoSelect bool
}{
{
name: "只有一个profile时自动选择",
profileCount: 1,
shouldAutoSelect: true,
},
{
name: "多个profile时不自动选择",
profileCount: 2,
shouldAutoSelect: false,
},
{
name: "没有profile时不自动选择",
profileCount: 0,
shouldAutoSelect: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldAutoSelect := tt.profileCount == 1
if shouldAutoSelect != tt.shouldAutoSelect {
t.Errorf("Profile selection logic failed: got %v, want %v", shouldAutoSelect, tt.shouldAutoSelect)
}
})
}
}
// TestTokenService_CleanupLogic 测试清理逻辑
func TestTokenService_CleanupLogic(t *testing.T) {
tests := []struct {
name string
tokenCount int
maxCount int
shouldCleanup bool
cleanupCount int
}{
{
name: "token数量未超过上限不需要清理",
tokenCount: 5,
maxCount: 10,
shouldCleanup: false,
cleanupCount: 0,
},
{
name: "token数量超过上限需要清理",
tokenCount: 15,
maxCount: 10,
shouldCleanup: true,
cleanupCount: 5,
},
{
name: "token数量等于上限不需要清理",
tokenCount: 10,
maxCount: 10,
shouldCleanup: false,
cleanupCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldCleanup := tt.tokenCount > tt.maxCount
if shouldCleanup != tt.shouldCleanup {
t.Errorf("Cleanup decision failed: got %v, want %v", shouldCleanup, tt.shouldCleanup)
}
if shouldCleanup {
expectedCleanupCount := tt.tokenCount - tt.maxCount
if expectedCleanupCount != tt.cleanupCount {
t.Errorf("Cleanup count failed: got %d, want %d", expectedCleanupCount, tt.cleanupCount)
}
}
})
}
}
// TestTokenService_UserIDValidation 测试UserID验证
func TestTokenService_UserIDValidation(t *testing.T) {
tests := []struct {
name string
userID int64
isValid bool
}{
{
name: "有效的UserID",
userID: 1,
isValid: true,
},
{
name: "UserID为0时无效",
userID: 0,
isValid: false,
},
{
name: "负数UserID无效",
userID: -1,
isValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.userID > 0
if isValid != tt.isValid {
t.Errorf("UserID validation failed: got %v, want %v", isValid, tt.isValid)
}
})
}
}

View File

@@ -0,0 +1,160 @@
package service
import (
"carrotskin/pkg/config"
"carrotskin/pkg/storage"
"context"
"fmt"
"path/filepath"
"strings"
"time"
)
// FileType 文件类型枚举
type FileType string
const (
FileTypeAvatar FileType = "avatar"
FileTypeTexture FileType = "texture"
)
// UploadConfig 上传配置
type UploadConfig struct {
AllowedExts map[string]bool // 允许的文件扩展名
MinSize int64 // 最小文件大小(字节)
MaxSize int64 // 最大文件大小(字节)
Expires time.Duration // URL过期时间
}
// GetUploadConfig 根据文件类型获取上传配置
func GetUploadConfig(fileType FileType) *UploadConfig {
switch fileType {
case FileTypeAvatar:
return &UploadConfig{
AllowedExts: map[string]bool{
".jpg": true,
".jpeg": true,
".png": true,
".gif": true,
".webp": true,
},
MinSize: 1024, // 1KB
MaxSize: 5 * 1024 * 1024, // 5MB
Expires: 15 * time.Minute,
}
case FileTypeTexture:
return &UploadConfig{
AllowedExts: map[string]bool{
".png": true,
},
MinSize: 1024, // 1KB
MaxSize: 10 * 1024 * 1024, // 10MB
Expires: 15 * time.Minute,
}
default:
return nil
}
}
// ValidateFileName 验证文件名
func ValidateFileName(fileName string, fileType FileType) error {
if fileName == "" {
return fmt.Errorf("文件名不能为空")
}
uploadConfig := GetUploadConfig(fileType)
if uploadConfig == nil {
return fmt.Errorf("不支持的文件类型")
}
ext := strings.ToLower(filepath.Ext(fileName))
if !uploadConfig.AllowedExts[ext] {
return fmt.Errorf("不支持的文件格式: %s", ext)
}
return nil
}
// GenerateAvatarUploadURL 生成头像上传URL
func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) {
// 1. 验证文件名
if err := ValidateFileName(fileName, FileTypeAvatar); err != nil {
return nil, err
}
// 2. 获取上传配置
uploadConfig := GetUploadConfig(FileTypeAvatar)
// 3. 获取存储桶名称
bucketName, err := storageClient.GetBucket("avatars")
if err != nil {
return nil, fmt.Errorf("获取存储桶失败: %w", err)
}
// 4. 生成对象名称(路径)
// 格式: user_{userId}/timestamp_{originalFileName}
timestamp := time.Now().Format("20060102150405")
objectName := fmt.Sprintf("user_%d/%s_%s", userID, timestamp, fileName)
// 5. 生成预签名POST URL
result, err := storageClient.GeneratePresignedPostURL(
ctx,
bucketName,
objectName,
uploadConfig.MinSize,
uploadConfig.MaxSize,
uploadConfig.Expires,
cfg.UseSSL,
cfg.Endpoint,
)
if err != nil {
return nil, fmt.Errorf("生成上传URL失败: %w", err)
}
return result, nil
}
// GenerateTextureUploadURL 生成材质上传URL
func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) {
// 1. 验证文件名
if err := ValidateFileName(fileName, FileTypeTexture); err != nil {
return nil, err
}
// 2. 验证材质类型
if textureType != "SKIN" && textureType != "CAPE" {
return nil, fmt.Errorf("无效的材质类型: %s", textureType)
}
// 3. 获取上传配置
uploadConfig := GetUploadConfig(FileTypeTexture)
// 4. 获取存储桶名称
bucketName, err := storageClient.GetBucket("textures")
if err != nil {
return nil, fmt.Errorf("获取存储桶失败: %w", err)
}
// 5. 生成对象名称(路径)
// 格式: user_{userId}/{textureType}/timestamp_{originalFileName}
timestamp := time.Now().Format("20060102150405")
textureTypeFolder := strings.ToLower(textureType)
objectName := fmt.Sprintf("user_%d/%s/%s_%s", userID, textureTypeFolder, timestamp, fileName)
// 6. 生成预签名POST URL
result, err := storageClient.GeneratePresignedPostURL(
ctx,
bucketName,
objectName,
uploadConfig.MinSize,
uploadConfig.MaxSize,
uploadConfig.Expires,
cfg.UseSSL,
cfg.Endpoint,
)
if err != nil {
return nil, fmt.Errorf("生成上传URL失败: %w", err)
}
return result, nil
}

View File

@@ -0,0 +1,279 @@
package service
import (
"strings"
"testing"
"time"
)
// TestUploadService_FileTypes 测试文件类型常量
func TestUploadService_FileTypes(t *testing.T) {
if FileTypeAvatar == "" {
t.Error("FileTypeAvatar should not be empty")
}
if FileTypeTexture == "" {
t.Error("FileTypeTexture should not be empty")
}
if FileTypeAvatar == FileTypeTexture {
t.Error("FileTypeAvatar and FileTypeTexture should be different")
}
}
// TestGetUploadConfig 测试获取上传配置
func TestGetUploadConfig(t *testing.T) {
tests := []struct {
name string
fileType FileType
wantConfig bool
}{
{
name: "头像类型返回配置",
fileType: FileTypeAvatar,
wantConfig: true,
},
{
name: "材质类型返回配置",
fileType: FileTypeTexture,
wantConfig: true,
},
{
name: "无效类型返回nil",
fileType: FileType("invalid"),
wantConfig: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := GetUploadConfig(tt.fileType)
hasConfig := config != nil
if hasConfig != tt.wantConfig {
t.Errorf("GetUploadConfig() = %v, want %v", hasConfig, tt.wantConfig)
}
if config != nil {
// 验证配置字段
if config.MinSize <= 0 {
t.Error("MinSize should be greater than 0")
}
if config.MaxSize <= 0 {
t.Error("MaxSize should be greater than 0")
}
if config.MaxSize < config.MinSize {
t.Error("MaxSize should be greater than or equal to MinSize")
}
if config.Expires <= 0 {
t.Error("Expires should be greater than 0")
}
if len(config.AllowedExts) == 0 {
t.Error("AllowedExts should not be empty")
}
}
})
}
}
// TestGetUploadConfig_AvatarConfig 测试头像配置详情
func TestGetUploadConfig_AvatarConfig(t *testing.T) {
config := GetUploadConfig(FileTypeAvatar)
if config == nil {
t.Fatal("Avatar config should not be nil")
}
// 验证允许的扩展名
expectedExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
for _, ext := range expectedExts {
if !config.AllowedExts[ext] {
t.Errorf("Avatar config should allow %s extension", ext)
}
}
// 验证文件大小限制
if config.MinSize != 1024 {
t.Errorf("Avatar MinSize = %d, want 1024", config.MinSize)
}
if config.MaxSize != 5*1024*1024 {
t.Errorf("Avatar MaxSize = %d, want 5MB", config.MaxSize)
}
// 验证过期时间
if config.Expires != 15*time.Minute {
t.Errorf("Avatar Expires = %v, want 15 minutes", config.Expires)
}
}
// TestGetUploadConfig_TextureConfig 测试材质配置详情
func TestGetUploadConfig_TextureConfig(t *testing.T) {
config := GetUploadConfig(FileTypeTexture)
if config == nil {
t.Fatal("Texture config should not be nil")
}
// 验证允许的扩展名材质只允许PNG
if !config.AllowedExts[".png"] {
t.Error("Texture config should allow .png extension")
}
// 验证文件大小限制
if config.MinSize != 1024 {
t.Errorf("Texture MinSize = %d, want 1024", config.MinSize)
}
if config.MaxSize != 10*1024*1024 {
t.Errorf("Texture MaxSize = %d, want 10MB", config.MaxSize)
}
// 验证过期时间
if config.Expires != 15*time.Minute {
t.Errorf("Texture Expires = %v, want 15 minutes", config.Expires)
}
}
// TestValidateFileName 测试文件名验证
func TestValidateFileName(t *testing.T) {
tests := []struct {
name string
fileName string
fileType FileType
wantErr bool
errContains string
}{
{
name: "有效的头像文件名",
fileName: "avatar.png",
fileType: FileTypeAvatar,
wantErr: false,
},
{
name: "有效的材质文件名",
fileName: "texture.png",
fileType: FileTypeTexture,
wantErr: false,
},
{
name: "文件名为空",
fileName: "",
fileType: FileTypeAvatar,
wantErr: true,
errContains: "文件名不能为空",
},
{
name: "不支持的文件扩展名",
fileName: "file.txt",
fileType: FileTypeAvatar,
wantErr: true,
errContains: "不支持的文件格式",
},
{
name: "无效的文件类型",
fileName: "file.png",
fileType: FileType("invalid"),
wantErr: true,
errContains: "不支持的文件类型",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateFileName(tt.fileName, tt.fileType)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateFileName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errContains != "" {
if err == nil || !strings.Contains(err.Error(), tt.errContains) {
t.Errorf("ValidateFileName() error = %v, should contain %s", err, tt.errContains)
}
}
})
}
}
// TestValidateFileName_Extensions 测试各种扩展名
func TestValidateFileName_Extensions(t *testing.T) {
avatarExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
for _, ext := range avatarExts {
fileName := "test" + ext
err := ValidateFileName(fileName, FileTypeAvatar)
if err != nil {
t.Errorf("Avatar file with %s extension should be valid, got error: %v", ext, err)
}
}
// 材质只支持PNG
textureExts := []string{".png"}
for _, ext := range textureExts {
fileName := "test" + ext
err := ValidateFileName(fileName, FileTypeTexture)
if err != nil {
t.Errorf("Texture file with %s extension should be valid, got error: %v", ext, err)
}
}
// 测试不支持的扩展名
invalidExts := []string{".txt", ".pdf", ".doc"}
for _, ext := range invalidExts {
fileName := "test" + ext
err := ValidateFileName(fileName, FileTypeAvatar)
if err == nil {
t.Errorf("Avatar file with %s extension should be invalid", ext)
}
}
}
// TestValidateFileName_CaseInsensitive 测试扩展名大小写不敏感
func TestValidateFileName_CaseInsensitive(t *testing.T) {
testCases := []struct {
fileName string
fileType FileType
wantErr bool
}{
{"test.PNG", FileTypeAvatar, false},
{"test.JPG", FileTypeAvatar, false},
{"test.JPEG", FileTypeAvatar, false},
{"test.GIF", FileTypeAvatar, false},
{"test.WEBP", FileTypeAvatar, false},
{"test.PnG", FileTypeTexture, false},
}
for _, tc := range testCases {
t.Run(tc.fileName, func(t *testing.T) {
err := ValidateFileName(tc.fileName, tc.fileType)
if (err != nil) != tc.wantErr {
t.Errorf("ValidateFileName(%s, %s) error = %v, wantErr %v", tc.fileName, tc.fileType, err, tc.wantErr)
}
})
}
}
// TestUploadConfig_Structure 测试UploadConfig结构
func TestUploadConfig_Structure(t *testing.T) {
config := &UploadConfig{
AllowedExts: map[string]bool{
".png": true,
},
MinSize: 1024,
MaxSize: 5 * 1024 * 1024,
Expires: 15 * time.Minute,
}
if config.AllowedExts == nil {
t.Error("AllowedExts should not be nil")
}
if config.MinSize <= 0 {
t.Error("MinSize should be greater than 0")
}
if config.MaxSize <= config.MinSize {
t.Error("MaxSize should be greater than MinSize")
}
if config.Expires <= 0 {
t.Error("Expires should be greater than 0")
}
}

View File

@@ -0,0 +1,248 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"errors"
"strings"
"time"
)
// RegisterUser 用户注册
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
// 检查用户名是否已存在
existingUser, err := repository.FindUserByUsername(username)
if err != nil {
return nil, "", err
}
if existingUser != nil {
return nil, "", errors.New("用户名已存在")
}
// 检查邮箱是否已存在
existingEmail, err := repository.FindUserByEmail(email)
if err != nil {
return nil, "", err
}
if existingEmail != nil {
return nil, "", errors.New("邮箱已被注册")
}
// 加密密码
hashedPassword, err := auth.HashPassword(password)
if err != nil {
return nil, "", errors.New("密码加密失败")
}
// 确定头像URL优先使用用户提供的头像否则使用默认头像
avatarURL := avatar
if avatarURL == "" {
avatarURL = getDefaultAvatar()
}
// 创建用户
user := &model.User{
Username: username,
Password: hashedPassword,
Email: email,
Avatar: avatarURL,
Role: "user",
Status: 1,
Points: 0, // 初始积分可以从配置读取
}
if err := repository.CreateUser(user); err != nil {
return nil, "", err
}
// 生成JWT Token
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
if err != nil {
return nil, "", errors.New("生成Token失败")
}
// TODO: 添加注册奖励积分
return user, token, nil
}
// LoginUser 用户登录(支持用户名或邮箱登录)
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
// 查找用户:判断是用户名还是邮箱
var user *model.User
var err error
if strings.Contains(usernameOrEmail, "@") {
// 包含@符号,认为是邮箱
user, err = repository.FindUserByEmail(usernameOrEmail)
} else {
// 否则认为是用户名
user, err = repository.FindUserByUsername(usernameOrEmail)
}
if err != nil {
return nil, "", err
}
if user == nil {
// 记录失败日志
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
return nil, "", errors.New("用户名/邮箱或密码错误")
}
// 检查用户状态
if user.Status != 1 {
logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
return nil, "", errors.New("账号已被禁用")
}
// 验证密码
if !auth.CheckPassword(user.Password, password) {
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
return nil, "", errors.New("用户名/邮箱或密码错误")
}
// 生成JWT Token
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
if err != nil {
return nil, "", errors.New("生成Token失败")
}
// 更新最后登录时间
now := time.Now()
user.LastLoginAt = &now
_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
"last_login_at": now,
})
// 记录成功登录日志
logSuccessLogin(user.ID, ipAddress, userAgent)
return user, token, nil
}
// GetUserByID 根据ID获取用户
func GetUserByID(id int64) (*model.User, error) {
return repository.FindUserByID(id)
}
// UpdateUserInfo 更新用户信息
func UpdateUserInfo(user *model.User) error {
return repository.UpdateUser(user)
}
// UpdateUserAvatar 更新用户头像
func UpdateUserAvatar(userID int64, avatarURL string) error {
return repository.UpdateUserFields(userID, map[string]interface{}{
"avatar": avatarURL,
})
}
// ChangeUserPassword 修改密码
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
// 获取用户
user, err := repository.FindUserByID(userID)
if err != nil {
return errors.New("用户不存在")
}
// 验证旧密码
if !auth.CheckPassword(user.Password, oldPassword) {
return errors.New("原密码错误")
}
// 加密新密码
hashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return errors.New("密码加密失败")
}
// 更新密码
return repository.UpdateUserFields(userID, map[string]interface{}{
"password": hashedPassword,
})
}
// ResetUserPassword 重置密码(通过邮箱)
func ResetUserPassword(email, newPassword string) error {
// 查找用户
user, err := repository.FindUserByEmail(email)
if err != nil {
return errors.New("用户不存在")
}
// 加密新密码
hashedPassword, err := auth.HashPassword(newPassword)
if err != nil {
return errors.New("密码加密失败")
}
// 更新密码
return repository.UpdateUserFields(user.ID, map[string]interface{}{
"password": hashedPassword,
})
}
// ChangeUserEmail 更换邮箱
func ChangeUserEmail(userID int64, newEmail string) error {
// 检查新邮箱是否已被使用
existingUser, err := repository.FindUserByEmail(newEmail)
if err != nil {
return err
}
if existingUser != nil && existingUser.ID != userID {
return errors.New("邮箱已被其他用户使用")
}
// 更新邮箱
return repository.UpdateUserFields(userID, map[string]interface{}{
"email": newEmail,
})
}
// logSuccessLogin 记录成功登录
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
log := &model.UserLoginLog{
UserID: userID,
IPAddress: ipAddress,
UserAgent: userAgent,
LoginMethod: "PASSWORD",
IsSuccess: true,
}
_ = repository.CreateLoginLog(log)
}
// logFailedLogin 记录失败登录
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
log := &model.UserLoginLog{
UserID: userID,
IPAddress: ipAddress,
UserAgent: userAgent,
LoginMethod: "PASSWORD",
IsSuccess: false,
FailureReason: reason,
}
_ = repository.CreateLoginLog(log)
}
// getDefaultAvatar 获取默认头像URL
func getDefaultAvatar() string {
// 如果数据库中不存在默认头像配置,返回错误信息
const log = "数据库中不存在默认头像配置"
// 尝试从数据库读取配置
config, err := repository.GetSystemConfigByKey("default_avatar")
if err != nil || config == nil {
return log
}
return config.Value
}
func GetUserByEmail(email string) (*model.User, error) {
user, err := repository.FindUserByEmail(email)
if err != nil {
return nil, errors.New("邮箱查找失败")
}
return user, nil
}

View File

@@ -0,0 +1,199 @@
package service
import (
"strings"
"testing"
)
// TestGetDefaultAvatar 测试获取默认头像的逻辑
// 注意这个测试需要mock repository但由于repository是函数式的
// 我们只测试逻辑部分
func TestGetDefaultAvatar_Logic(t *testing.T) {
tests := []struct {
name string
configExists bool
configValue string
expectedResult string
}{
{
name: "配置存在时返回配置值",
configExists: true,
configValue: "https://example.com/avatar.png",
expectedResult: "https://example.com/avatar.png",
},
{
name: "配置不存在时返回错误信息",
configExists: false,
configValue: "",
expectedResult: "数据库中不存在默认头像配置",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 这个测试只验证逻辑不实际调用repository
// 实际的repository调用测试需要集成测试或mock
if tt.configExists {
if tt.expectedResult != tt.configValue {
t.Errorf("当配置存在时,应该返回配置值")
}
} else {
if !strings.Contains(tt.expectedResult, "数据库中不存在默认头像配置") {
t.Errorf("当配置不存在时,应该返回错误信息")
}
}
})
}
}
// TestLoginUser_EmailDetection 测试登录时邮箱检测逻辑
func TestLoginUser_EmailDetection(t *testing.T) {
tests := []struct {
name string
usernameOrEmail string
isEmail bool
}{
{
name: "包含@符号,识别为邮箱",
usernameOrEmail: "user@example.com",
isEmail: true,
},
{
name: "不包含@符号,识别为用户名",
usernameOrEmail: "username",
isEmail: false,
},
{
name: "空字符串",
usernameOrEmail: "",
isEmail: false,
},
{
name: "只有@符号",
usernameOrEmail: "@",
isEmail: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isEmail := strings.Contains(tt.usernameOrEmail, "@")
if isEmail != tt.isEmail {
t.Errorf("Email detection failed: got %v, want %v", isEmail, tt.isEmail)
}
})
}
}
// TestUserService_Constants 测试用户服务相关常量
func TestUserService_Constants(t *testing.T) {
// 测试默认用户角色
defaultRole := "user"
if defaultRole == "" {
t.Error("默认用户角色不能为空")
}
// 测试默认用户状态
defaultStatus := int16(1)
if defaultStatus != 1 {
t.Errorf("默认用户状态应为1正常实际为%d", defaultStatus)
}
// 测试初始积分
initialPoints := 0
if initialPoints < 0 {
t.Errorf("初始积分不应为负数,实际为%d", initialPoints)
}
}
// TestUserService_Validation 测试用户数据验证逻辑
func TestUserService_Validation(t *testing.T) {
tests := []struct {
name string
username string
email string
password string
wantValid bool
}{
{
name: "有效的用户名和邮箱",
username: "testuser",
email: "test@example.com",
password: "password123",
wantValid: true,
},
{
name: "用户名为空",
username: "",
email: "test@example.com",
password: "password123",
wantValid: false,
},
{
name: "邮箱为空",
username: "testuser",
email: "",
password: "password123",
wantValid: false,
},
{
name: "密码为空",
username: "testuser",
email: "test@example.com",
password: "",
wantValid: false,
},
{
name: "邮箱格式无效(缺少@",
username: "testuser",
email: "invalid-email",
password: "password123",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 简单的验证逻辑测试
isValid := tt.username != "" && tt.email != "" && tt.password != "" && strings.Contains(tt.email, "@")
if isValid != tt.wantValid {
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestUserService_AvatarLogic 测试头像逻辑
func TestUserService_AvatarLogic(t *testing.T) {
tests := []struct {
name string
providedAvatar string
defaultAvatar string
expectedAvatar string
}{
{
name: "提供头像时使用提供的头像",
providedAvatar: "https://example.com/custom.png",
defaultAvatar: "https://example.com/default.png",
expectedAvatar: "https://example.com/custom.png",
},
{
name: "未提供头像时使用默认头像",
providedAvatar: "",
defaultAvatar: "https://example.com/default.png",
expectedAvatar: "https://example.com/default.png",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
avatarURL := tt.providedAvatar
if avatarURL == "" {
avatarURL = tt.defaultAvatar
}
if avatarURL != tt.expectedAvatar {
t.Errorf("Avatar logic failed: got %s, want %s", avatarURL, tt.expectedAvatar)
}
})
}
}

View File

@@ -0,0 +1,118 @@
package service
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"time"
"carrotskin/pkg/email"
"carrotskin/pkg/redis"
)
const (
// 验证码类型
VerificationTypeRegister = "register"
VerificationTypeResetPassword = "reset_password"
VerificationTypeChangeEmail = "change_email"
// 验证码配置
CodeLength = 6 // 验证码长度
CodeExpiration = 10 * time.Minute // 验证码有效期
CodeRateLimit = 1 * time.Minute // 发送频率限制
)
// GenerateVerificationCode 生成6位数字验证码
func GenerateVerificationCode() (string, error) {
const digits = "0123456789"
code := make([]byte, CodeLength)
for i := range code {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
if err != nil {
return "", err
}
code[i] = digits[num.Int64()]
}
return string(code), nil
}
// SendVerificationCode 发送验证码
func SendVerificationCode(ctx context.Context, redisClient *redis.Client, emailService *email.Service, email, codeType string) error {
// 检查发送频率限制
rateLimitKey := fmt.Sprintf("verification:rate_limit:%s:%s", codeType, email)
exists, err := redisClient.Exists(ctx, rateLimitKey)
if err != nil {
return fmt.Errorf("检查发送频率失败: %w", err)
}
if exists > 0 {
return fmt.Errorf("发送过于频繁,请稍后再试")
}
// 生成验证码
code, err := GenerateVerificationCode()
if err != nil {
return fmt.Errorf("生成验证码失败: %w", err)
}
// 存储验证码到Redis
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
if err := redisClient.Set(ctx, codeKey, code, CodeExpiration); err != nil {
return fmt.Errorf("存储验证码失败: %w", err)
}
// 设置发送频率限制
if err := redisClient.Set(ctx, rateLimitKey, "1", CodeRateLimit); err != nil {
return fmt.Errorf("设置发送频率限制失败: %w", err)
}
// 发送邮件
if err := sendVerificationEmail(emailService, email, code, codeType); err != nil {
// 发送失败,删除验证码
_ = redisClient.Del(ctx, codeKey)
return fmt.Errorf("发送邮件失败: %w", err)
}
return nil
}
// VerifyCode 验证验证码
func VerifyCode(ctx context.Context, redisClient *redis.Client, email, code, codeType string) error {
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
// 从Redis获取验证码
storedCode, err := redisClient.Get(ctx, codeKey)
if err != nil {
return fmt.Errorf("验证码已过期或不存在")
}
// 验证验证码
if storedCode != code {
return fmt.Errorf("验证码错误")
}
// 验证成功,删除验证码
_ = redisClient.Del(ctx, codeKey)
return nil
}
// DeleteVerificationCode 删除验证码
func DeleteVerificationCode(ctx context.Context, redisClient *redis.Client, email, codeType string) error {
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
return redisClient.Del(ctx, codeKey)
}
// sendVerificationEmail 根据类型发送邮件
func sendVerificationEmail(emailService *email.Service, to, code, codeType string) error {
switch codeType {
case VerificationTypeRegister:
return emailService.SendEmailVerification(to, code)
case VerificationTypeResetPassword:
return emailService.SendResetPassword(to, code)
case VerificationTypeChangeEmail:
return emailService.SendChangeEmail(to, code)
default:
return emailService.SendVerificationCode(to, code, codeType)
}
}

View File

@@ -0,0 +1,119 @@
package service
import (
"testing"
"time"
)
// TestGenerateVerificationCode 测试生成验证码函数
func TestGenerateVerificationCode(t *testing.T) {
tests := []struct {
name string
wantLen int
wantErr bool
}{
{
name: "生成6位验证码",
wantLen: CodeLength,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
code, err := GenerateVerificationCode()
if (err != nil) != tt.wantErr {
t.Errorf("GenerateVerificationCode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && len(code) != tt.wantLen {
t.Errorf("GenerateVerificationCode() code length = %v, want %v", len(code), tt.wantLen)
}
// 验证验证码只包含数字
for _, c := range code {
if c < '0' || c > '9' {
t.Errorf("GenerateVerificationCode() code contains non-digit: %c", c)
}
}
})
}
// 测试多次生成,验证码应该不同(概率上)
codes := make(map[string]bool)
for i := 0; i < 100; i++ {
code, err := GenerateVerificationCode()
if err != nil {
t.Fatalf("GenerateVerificationCode() failed: %v", err)
}
if codes[code] {
t.Logf("发现重复验证码这是正常的因为只有6位数字: %s", code)
}
codes[code] = true
}
}
// TestVerificationConstants 测试验证码相关常量
func TestVerificationConstants(t *testing.T) {
if CodeLength != 6 {
t.Errorf("CodeLength = %d, want 6", CodeLength)
}
if CodeExpiration != 10*time.Minute {
t.Errorf("CodeExpiration = %v, want 10 minutes", CodeExpiration)
}
if CodeRateLimit != 1*time.Minute {
t.Errorf("CodeRateLimit = %v, want 1 minute", CodeRateLimit)
}
// 验证验证码类型常量
types := []string{
VerificationTypeRegister,
VerificationTypeResetPassword,
VerificationTypeChangeEmail,
}
for _, vType := range types {
if vType == "" {
t.Error("验证码类型不能为空")
}
}
}
// TestVerificationCodeFormat 测试验证码格式
func TestVerificationCodeFormat(t *testing.T) {
code, err := GenerateVerificationCode()
if err != nil {
t.Fatalf("GenerateVerificationCode() failed: %v", err)
}
// 验证长度
if len(code) != 6 {
t.Errorf("验证码长度应为6位实际为%d位", len(code))
}
// 验证只包含数字
for i, c := range code {
if c < '0' || c > '9' {
t.Errorf("验证码第%d位包含非数字字符: %c", i+1, c)
}
}
}
// TestVerificationTypes 测试验证码类型
func TestVerificationTypes(t *testing.T) {
validTypes := map[string]bool{
VerificationTypeRegister: true,
VerificationTypeResetPassword: true,
VerificationTypeChangeEmail: true,
}
for vType, isValid := range validTypes {
if !isValid {
t.Errorf("验证码类型 %s 应该是有效的", vType)
}
if vType == "" {
t.Error("验证码类型不能为空字符串")
}
}
}

View File

@@ -0,0 +1,201 @@
package service
import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/redis"
"carrotskin/pkg/utils"
"context"
"errors"
"fmt"
"go.uber.org/zap"
"net"
"strings"
"time"
"gorm.io/gorm"
)
// SessionKeyPrefix Redis会话键前缀
const SessionKeyPrefix = "Join_"
// SessionTTL 会话超时时间 - 增加到15分钟
const SessionTTL = 15 * time.Minute
type SessionData struct {
AccessToken string `json:"accessToken"`
UserName string `json:"userName"`
SelectedProfile string `json:"selectedProfile"`
IP string `json:"ip"`
}
// GetUserIDByEmail 根据邮箱返回用户id
func GetUserIDByEmail(db *gorm.DB, Identifier string) (int64, error) {
user, err := repository.FindUserByEmail(Identifier)
if err != nil {
return 0, errors.New("用户不存在")
}
return user.ID, nil
}
// GetProfileByProfileName 根据用户名返回用户id
func GetProfileByProfileName(db *gorm.DB, Identifier string) (*model.Profile, error) {
profile, err := repository.FindProfileByName(Identifier)
if err != nil {
return nil, errors.New("用户角色未创建")
}
return profile, nil
}
// VerifyPassword 验证密码是否一致
func VerifyPassword(db *gorm.DB, password string, Id int64) error {
passwordStore, err := repository.GetYggdrasilPasswordById(Id)
if err != nil {
return errors.New("未生成密码")
}
if passwordStore != password {
return errors.New("密码错误")
}
return nil
}
func GetProfileByUserId(db *gorm.DB, userId int64) (*model.Profile, error) {
profiles, err := repository.FindProfilesByUserID(userId)
if err != nil {
return nil, errors.New("角色查找失败")
}
if len(profiles) == 0 {
return nil, errors.New("角色查找失败")
}
return profiles[0], nil
}
func GetPasswordByUserId(db *gorm.DB, userId int64) (string, error) {
passwordStore, err := repository.GetYggdrasilPasswordById(userId)
if err != nil {
return "", errors.New("yggdrasil密码查找失败")
}
return passwordStore, nil
}
// JoinServer 记录玩家加入服务器的会话信息
func JoinServer(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, serverId, accessToken, selectedProfile, ip string) error {
// 输入验证
if serverId == "" || accessToken == "" || selectedProfile == "" {
return errors.New("参数不能为空")
}
// 验证serverId格式防止注入攻击
if len(serverId) > 100 || strings.ContainsAny(serverId, "<>\"'&") {
return errors.New("服务器ID格式无效")
}
// 验证IP格式
if ip != "" {
if net.ParseIP(ip) == nil {
return errors.New("IP地址格式无效")
}
}
// 获取和验证Token
token, err := repository.GetTokenByAccessToken(accessToken)
if err != nil {
logger.Error(
"验证Token失败",
zap.Error(err),
zap.String("accessToken", accessToken),
)
return fmt.Errorf("验证Token失败: %w", err)
}
// 格式化UUID并验证与Token关联的配置文件
formattedProfile := utils.FormatUUID(selectedProfile)
if token.ProfileId != formattedProfile {
return errors.New("selectedProfile与Token不匹配")
}
profile, err := repository.FindProfileByUUID(formattedProfile)
if err != nil {
logger.Error(
"获取Profile失败",
zap.Error(err),
zap.String("uuid", formattedProfile),
)
return fmt.Errorf("获取Profile失败: %w", err)
}
// 创建会话数据
data := SessionData{
AccessToken: accessToken,
UserName: profile.Name,
SelectedProfile: formattedProfile,
IP: ip,
}
// 序列化会话数据
marshaledData, err := json.Marshal(data)
if err != nil {
logger.Error(
"[ERROR]序列化会话数据失败",
zap.Error(err),
)
return fmt.Errorf("序列化会话数据失败: %w", err)
}
// 存储会话数据到Redis
sessionKey := SessionKeyPrefix + serverId
ctx := context.Background()
if err = redisClient.Set(ctx, sessionKey, marshaledData, SessionTTL); err != nil {
logger.Error(
"保存会话数据失败",
zap.Error(err),
zap.String("serverId", serverId),
)
return fmt.Errorf("保存会话数据失败: %w", err)
}
logger.Info(
"玩家成功加入服务器",
zap.String("username", profile.Name),
zap.String("serverId", serverId),
)
return nil
}
// HasJoinedServer 验证玩家是否已经加入了服务器
func HasJoinedServer(logger *zap.Logger, redisClient *redis.Client, serverId, username, ip string) error {
if serverId == "" || username == "" {
return errors.New("服务器ID和用户名不能为空")
}
// 设置超时上下文
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// 从Redis获取会话数据
sessionKey := SessionKeyPrefix + serverId
data, err := redisClient.GetBytes(ctx, sessionKey)
if err != nil {
logger.Error("[ERROR] 获取会话数据失败:", zap.Error(err), zap.Any("serverId:", serverId))
return fmt.Errorf("获取会话数据失败: %w", err)
}
// 反序列化会话数据
var sessionData SessionData
if err = json.Unmarshal(data, &sessionData); err != nil {
logger.Error("[ERROR] 解析会话数据失败: ", zap.Error(err))
return fmt.Errorf("解析会话数据失败: %w", err)
}
// 验证用户名
if sessionData.UserName != username {
return errors.New("用户名不匹配")
}
// 验证IP(如果提供)
if ip != "" && sessionData.IP != ip {
return errors.New("IP地址不匹配")
}
return nil
}

View File

@@ -0,0 +1,350 @@
package service
import (
"net"
"strings"
"testing"
"time"
)
// TestYggdrasilService_Constants 测试Yggdrasil服务常量
func TestYggdrasilService_Constants(t *testing.T) {
if SessionKeyPrefix != "Join_" {
t.Errorf("SessionKeyPrefix = %s, want 'Join_'", SessionKeyPrefix)
}
if SessionTTL != 15*time.Minute {
t.Errorf("SessionTTL = %v, want 15 minutes", SessionTTL)
}
}
// TestSessionData_Structure 测试SessionData结构
func TestSessionData_Structure(t *testing.T) {
data := SessionData{
AccessToken: "test-token",
UserName: "TestUser",
SelectedProfile: "test-profile-uuid",
IP: "127.0.0.1",
}
if data.AccessToken == "" {
t.Error("AccessToken should not be empty")
}
if data.UserName == "" {
t.Error("UserName should not be empty")
}
if data.SelectedProfile == "" {
t.Error("SelectedProfile should not be empty")
}
}
// TestJoinServer_InputValidation 测试JoinServer输入验证逻辑
func TestJoinServer_InputValidation(t *testing.T) {
tests := []struct {
name string
serverId string
accessToken string
selectedProfile string
wantErr bool
errContains string
}{
{
name: "所有参数有效",
serverId: "test-server-123",
accessToken: "test-token",
selectedProfile: "test-profile",
wantErr: false,
},
{
name: "serverId为空",
serverId: "",
accessToken: "test-token",
selectedProfile: "test-profile",
wantErr: true,
errContains: "参数不能为空",
},
{
name: "accessToken为空",
serverId: "test-server",
accessToken: "",
selectedProfile: "test-profile",
wantErr: true,
errContains: "参数不能为空",
},
{
name: "selectedProfile为空",
serverId: "test-server",
accessToken: "test-token",
selectedProfile: "",
wantErr: true,
errContains: "参数不能为空",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.serverId == "" || tt.accessToken == "" || tt.selectedProfile == ""
if hasError != tt.wantErr {
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestJoinServer_ServerIDValidation 测试服务器ID格式验证
func TestJoinServer_ServerIDValidation(t *testing.T) {
tests := []struct {
name string
serverId string
wantValid bool
}{
{
name: "有效的serverId",
serverId: "test-server-123",
wantValid: true,
},
{
name: "serverId过长",
serverId: strings.Repeat("a", 101),
wantValid: false,
},
{
name: "serverId包含危险字符<",
serverId: "test<server",
wantValid: false,
},
{
name: "serverId包含危险字符>",
serverId: "test>server",
wantValid: false,
},
{
name: "serverId包含危险字符\"",
serverId: "test\"server",
wantValid: false,
},
{
name: "serverId包含危险字符'",
serverId: "test'server",
wantValid: false,
},
{
name: "serverId包含危险字符&",
serverId: "test&server",
wantValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := len(tt.serverId) <= 100 && !strings.ContainsAny(tt.serverId, "<>\"'&")
if isValid != tt.wantValid {
t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
}
})
}
}
// TestJoinServer_IPValidation 测试IP地址验证逻辑
func TestJoinServer_IPValidation(t *testing.T) {
tests := []struct {
name string
ip string
wantValid bool
}{
{
name: "有效的IPv4地址",
ip: "127.0.0.1",
wantValid: true,
},
{
name: "有效的IPv6地址",
ip: "::1",
wantValid: true,
},
{
name: "无效的IP地址",
ip: "invalid-ip",
wantValid: false,
},
{
name: "空IP地址可选",
ip: "",
wantValid: true, // 空IP是允许的
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var isValid bool
if tt.ip == "" {
isValid = true // 空IP是允许的
} else {
isValid = net.ParseIP(tt.ip) != nil
}
if isValid != tt.wantValid {
t.Errorf("IP validation failed: got %v, want %v (ip=%s)", isValid, tt.wantValid, tt.ip)
}
})
}
}
// TestHasJoinedServer_InputValidation 测试HasJoinedServer输入验证
func TestHasJoinedServer_InputValidation(t *testing.T) {
tests := []struct {
name string
serverId string
username string
wantErr bool
}{
{
name: "所有参数有效",
serverId: "test-server",
username: "TestUser",
wantErr: false,
},
{
name: "serverId为空",
serverId: "",
username: "TestUser",
wantErr: true,
},
{
name: "username为空",
serverId: "test-server",
username: "",
wantErr: true,
},
{
name: "两者都为空",
serverId: "",
username: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hasError := tt.serverId == "" || tt.username == ""
if hasError != tt.wantErr {
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
}
})
}
}
// TestHasJoinedServer_UsernameMatching 测试用户名匹配逻辑
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
tests := []struct {
name string
sessionUser string
requestUser string
wantMatch bool
}{
{
name: "用户名匹配",
sessionUser: "TestUser",
requestUser: "TestUser",
wantMatch: true,
},
{
name: "用户名不匹配",
sessionUser: "TestUser",
requestUser: "OtherUser",
wantMatch: false,
},
{
name: "大小写敏感",
sessionUser: "TestUser",
requestUser: "testuser",
wantMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := tt.sessionUser == tt.requestUser
if matches != tt.wantMatch {
t.Errorf("Username matching failed: got %v, want %v", matches, tt.wantMatch)
}
})
}
}
// TestHasJoinedServer_IPMatching 测试IP地址匹配逻辑
func TestHasJoinedServer_IPMatching(t *testing.T) {
tests := []struct {
name string
sessionIP string
requestIP string
wantMatch bool
shouldCheck bool
}{
{
name: "IP匹配",
sessionIP: "127.0.0.1",
requestIP: "127.0.0.1",
wantMatch: true,
shouldCheck: true,
},
{
name: "IP不匹配",
sessionIP: "127.0.0.1",
requestIP: "192.168.1.1",
wantMatch: false,
shouldCheck: true,
},
{
name: "请求IP为空时不检查",
sessionIP: "127.0.0.1",
requestIP: "",
wantMatch: true,
shouldCheck: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var matches bool
if tt.requestIP == "" {
matches = true // 空IP不检查
} else {
matches = tt.sessionIP == tt.requestIP
}
if matches != tt.wantMatch {
t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
}
})
}
}
// TestJoinServer_SessionKey 测试会话键生成
func TestJoinServer_SessionKey(t *testing.T) {
tests := []struct {
name string
serverId string
expected string
}{
{
name: "生成正确的会话键",
serverId: "test-server-123",
expected: "Join_test-server-123",
},
{
name: "空serverId",
serverId: "",
expected: "Join_",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sessionKey := SessionKeyPrefix + tt.serverId
if sessionKey != tt.expected {
t.Errorf("Session key = %s, want %s", sessionKey, tt.expected)
}
})
}
}