chore: 初始化仓库,排除二进制文件和覆盖率文件
This commit is contained in:
165
internal/service/captcha_service.go
Normal file
165
internal/service/captcha_service.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wenlng/go-captcha-assets/resources/imagesv2"
|
||||
"github.com/wenlng/go-captcha-assets/resources/tiles"
|
||||
"github.com/wenlng/go-captcha/v2/slide"
|
||||
)
|
||||
|
||||
var (
|
||||
slideTileCapt slide.Captcha
|
||||
cfg *config.Config
|
||||
)
|
||||
|
||||
// 常量定义(业务相关配置,与Redis连接配置分离)
|
||||
const (
|
||||
redisKeyPrefix = "captcha:" // Redis键前缀(便于区分业务)
|
||||
paddingValue = 3 // 验证允许的误差像素(±3px)
|
||||
)
|
||||
|
||||
// Init 验证码图初始化
|
||||
func init() {
|
||||
cfg, _ = config.Load()
|
||||
// 从默认仓库中获取主图
|
||||
builder := slide.NewBuilder()
|
||||
bgImage, err := imagesv2.GetImages()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
// 滑块形状获取
|
||||
graphs := getSlideTileGraphArr()
|
||||
|
||||
builder.SetResources(
|
||||
slide.WithGraphImages(graphs),
|
||||
slide.WithBackgrounds(bgImage),
|
||||
)
|
||||
slideTileCapt = builder.Make()
|
||||
if slideTileCapt == nil {
|
||||
log.Fatalln("验证码实例初始化失败")
|
||||
}
|
||||
}
|
||||
|
||||
// getSlideTileGraphArr 滑块选择
|
||||
func getSlideTileGraphArr() []*slide.GraphImage {
|
||||
graphs, err := tiles.GetTiles()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var newGraphs = make([]*slide.GraphImage, 0, len(graphs))
|
||||
for i := 0; i < len(graphs); i++ {
|
||||
graph := graphs[i]
|
||||
newGraphs = append(newGraphs, &slide.GraphImage{
|
||||
OverlayImage: graph.OverlayImage,
|
||||
MaskImage: graph.MaskImage,
|
||||
ShadowImage: graph.ShadowImage,
|
||||
})
|
||||
}
|
||||
return newGraphs
|
||||
}
|
||||
|
||||
// RedisData 存储到Redis的验证信息(仅包含校验必需字段)
|
||||
type RedisData struct {
|
||||
Tx int `json:"tx"` // 滑块目标X坐标
|
||||
Ty int `json:"ty"` // 滑块目标Y坐标
|
||||
}
|
||||
|
||||
// GenerateCaptchaData 提取生成验证码的相关信息
|
||||
func GenerateCaptchaData(ctx context.Context, redisClient *redis.Client) (string, string, string, int, error) {
|
||||
// 生成uuid作为验证码进程唯一标识
|
||||
captchaID := uuid.NewString()
|
||||
if captchaID == "" {
|
||||
return "", "", "", 0, errors.New("生成验证码唯一标识失败")
|
||||
}
|
||||
|
||||
captData, err := slideTileCapt.Generate()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
blockData := captData.GetData()
|
||||
if blockData == nil {
|
||||
return "", "", "", 0, errors.New("获取验证码数据失败")
|
||||
}
|
||||
block, _ := json.Marshal(blockData)
|
||||
var blockMap map[string]interface{}
|
||||
|
||||
if err := json.Unmarshal(block, &blockMap); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("反序列化为map失败: %w", err)
|
||||
}
|
||||
// 提取x和y并转换为int类型
|
||||
tx, ok := blockMap["x"].(float64)
|
||||
if !ok {
|
||||
return "", "", "", 0, errors.New("无法将x转换为float64")
|
||||
}
|
||||
var x = int(tx)
|
||||
ty, ok := blockMap["y"].(float64)
|
||||
if !ok {
|
||||
return "", "", "", 0, errors.New("无法将y转换为float64")
|
||||
}
|
||||
var y = int(ty)
|
||||
var mBase64, tBase64 string
|
||||
mBase64, err = captData.GetMasterImage().ToBase64()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("主图转换为base64失败: %w", err)
|
||||
}
|
||||
tBase64, err = captData.GetTileImage().ToBase64()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("滑块图转换为base64失败: %w", err)
|
||||
}
|
||||
redisData := RedisData{
|
||||
Tx: x,
|
||||
Ty: y,
|
||||
}
|
||||
redisDataJSON, _ := json.Marshal(redisData)
|
||||
redisKey := redisKeyPrefix + captchaID
|
||||
expireTime := 300 * time.Second
|
||||
|
||||
// 使用注入的Redis客户端
|
||||
if err := redisClient.Set(
|
||||
ctx,
|
||||
redisKey,
|
||||
redisDataJSON,
|
||||
expireTime,
|
||||
); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("存储验证码到Redis失败: %w", err)
|
||||
}
|
||||
return mBase64, tBase64, captchaID, y - 10, nil
|
||||
}
|
||||
|
||||
// VerifyCaptchaData 验证用户验证码
|
||||
func VerifyCaptchaData(ctx context.Context, redisClient *redis.Client, dx int, id string) (bool, error) {
|
||||
redisKey := redisKeyPrefix + id
|
||||
|
||||
// 从Redis获取验证信息,使用注入的客户端
|
||||
dataJSON, err := redisClient.Get(ctx, redisKey)
|
||||
if err != nil {
|
||||
if redisClient.Nil(err) { // 使用封装客户端的Nil错误
|
||||
return false, errors.New("验证码已过期或无效")
|
||||
}
|
||||
return false, fmt.Errorf("Redis查询失败: %w", err)
|
||||
}
|
||||
var redisData RedisData
|
||||
if err := json.Unmarshal([]byte(dataJSON), &redisData); err != nil {
|
||||
return false, fmt.Errorf("解析Redis数据失败: %w", err)
|
||||
}
|
||||
tx := redisData.Tx
|
||||
ty := redisData.Ty
|
||||
ok := slide.Validate(dx, ty, tx, ty, paddingValue)
|
||||
|
||||
// 验证后立即删除Redis记录(防止重复使用)
|
||||
if ok {
|
||||
if err := redisClient.Del(ctx, redisKey); err != nil {
|
||||
// 记录警告但不影响验证结果
|
||||
log.Printf("删除验证码Redis记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
174
internal/service/captcha_service_test.go
Normal file
174
internal/service/captcha_service_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCaptchaService_Constants 测试验证码服务常量
|
||||
func TestCaptchaService_Constants(t *testing.T) {
|
||||
if redisKeyPrefix != "captcha:" {
|
||||
t.Errorf("redisKeyPrefix = %s, want 'captcha:'", redisKeyPrefix)
|
||||
}
|
||||
|
||||
if paddingValue != 3 {
|
||||
t.Errorf("paddingValue = %d, want 3", paddingValue)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisData_Structure 测试RedisData结构
|
||||
func TestRedisData_Structure(t *testing.T) {
|
||||
data := RedisData{
|
||||
Tx: 100,
|
||||
Ty: 200,
|
||||
}
|
||||
|
||||
if data.Tx != 100 {
|
||||
t.Errorf("RedisData.Tx = %d, want 100", data.Tx)
|
||||
}
|
||||
|
||||
if data.Ty != 200 {
|
||||
t.Errorf("RedisData.Ty = %d, want 200", data.Ty)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateCaptchaData_Logic 测试生成验证码的逻辑部分
|
||||
func TestGenerateCaptchaData_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
captchaID string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "有效的captchaID",
|
||||
captchaID: "test-uuid-123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空的captchaID应该失败",
|
||||
captchaID: "",
|
||||
wantErr: true,
|
||||
errContains: "生成验证码唯一标识失败",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试UUID验证逻辑
|
||||
if tt.captchaID == "" {
|
||||
if !tt.wantErr {
|
||||
t.Error("空captchaID应该返回错误")
|
||||
}
|
||||
} else {
|
||||
if tt.wantErr {
|
||||
t.Error("非空captchaID不应该返回错误")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyCaptchaData_Logic 测试验证验证码的逻辑部分
|
||||
func TestVerifyCaptchaData_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dx int
|
||||
tx int
|
||||
ty int
|
||||
padding int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "精确匹配",
|
||||
dx: 100,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "在误差范围内(+3)",
|
||||
dx: 103,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "在误差范围内(-3)",
|
||||
dx: 97,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "超出误差范围(+4)",
|
||||
dx: 104,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "超出误差范围(-4)",
|
||||
dx: 96,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:dx应该在[tx-padding, tx+padding]范围内
|
||||
diff := tt.dx - tt.tx
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
isValid := diff <= tt.padding
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v (dx=%d, tx=%d, padding=%d)", isValid, tt.wantValid, tt.dx, tt.tx, tt.padding)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyCaptchaData_RedisKey 测试Redis键生成逻辑
|
||||
func TestVerifyCaptchaData_RedisKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "生成正确的Redis键",
|
||||
id: "test-id-123",
|
||||
expected: "captcha:test-id-123",
|
||||
},
|
||||
{
|
||||
name: "空ID",
|
||||
id: "",
|
||||
expected: "captcha:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
redisKey := redisKeyPrefix + tt.id
|
||||
if redisKey != tt.expected {
|
||||
t.Errorf("Redis key = %s, want %s", redisKey, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateCaptchaData_ExpireTime 测试过期时间
|
||||
func TestGenerateCaptchaData_ExpireTime(t *testing.T) {
|
||||
expectedExpireTime := 300 * time.Second
|
||||
if expectedExpireTime != 5*time.Minute {
|
||||
t.Errorf("Expire time should be 5 minutes")
|
||||
}
|
||||
}
|
||||
13
internal/service/common.go
Normal file
13
internal/service/common.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
)
|
||||
|
||||
// 统一的json变量,用于整个service包
|
||||
var json = jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
|
||||
// DefaultTimeout 默认超时时间
|
||||
const DefaultTimeout = 5 * time.Second
|
||||
48
internal/service/common_test.go
Normal file
48
internal/service/common_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCommon_Constants 测试common包的常量
|
||||
func TestCommon_Constants(t *testing.T) {
|
||||
if DefaultTimeout != 5*time.Second {
|
||||
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommon_JSON 测试JSON变量
|
||||
func TestCommon_JSON(t *testing.T) {
|
||||
// 验证json变量不为nil
|
||||
if json == nil {
|
||||
t.Error("json 变量不应为nil")
|
||||
}
|
||||
|
||||
// 测试JSON序列化
|
||||
testData := map[string]interface{}{
|
||||
"name": "test",
|
||||
"age": 25,
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() 失败: %v", err)
|
||||
}
|
||||
|
||||
if len(bytes) == 0 {
|
||||
t.Error("json.Marshal() 返回的字节不应为空")
|
||||
}
|
||||
|
||||
// 测试JSON反序列化
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(bytes, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() 失败: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("反序列化结果 name = %v, want 'test'", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
252
internal/service/profile_service.go
Normal file
252
internal/service/profile_service.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, error) {
|
||||
// 1. 验证用户存在
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("用户不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||||
}
|
||||
|
||||
if user.Status != 1 {
|
||||
return nil, fmt.Errorf("用户状态异常")
|
||||
}
|
||||
|
||||
// 2. 检查角色名是否已存在
|
||||
existingName, err := repository.FindProfileByName(name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, fmt.Errorf("角色名已被使用")
|
||||
}
|
||||
|
||||
// 3. 生成UUID
|
||||
profileUUID := uuid.New().String()
|
||||
|
||||
// 4. 生成RSA密钥对
|
||||
privateKey, err := generateRSAPrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA密钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 创建档案
|
||||
profile := &model.Profile{
|
||||
UUID: profileUUID,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
RSAPrivateKey: privateKey,
|
||||
IsActive: true, // 新创建的档案默认为活跃状态
|
||||
}
|
||||
|
||||
if err := repository.CreateProfile(profile); err != nil {
|
||||
return nil, fmt.Errorf("创建档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 6. 将用户的其他档案设置为非活跃
|
||||
if err := repository.SetActiveProfile(profileUUID, userID); err != nil {
|
||||
return nil, fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// GetProfileByUUID 获取档案详情
|
||||
func GetProfileByUUID(db *gorm.DB, uuid string) (*model.Profile, error) {
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("档案不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// GetUserProfiles 获取用户的所有档案
|
||||
func GetUserProfiles(db *gorm.DB, userID int64) ([]*model.Profile, error) {
|
||||
profiles, err := repository.FindProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("档案不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return nil, fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 检查角色名是否重复
|
||||
if name != nil && *name != profile.Name {
|
||||
existingName, err := repository.FindProfileByName(*name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, fmt.Errorf("角色名已被使用")
|
||||
}
|
||||
profile.Name = *name
|
||||
}
|
||||
|
||||
// 4. 更新皮肤和披风
|
||||
if skinID != nil {
|
||||
profile.SkinID = skinID
|
||||
}
|
||||
if capeID != nil {
|
||||
profile.CapeID = capeID
|
||||
}
|
||||
|
||||
// 5. 保存更新
|
||||
if err := repository.UpdateProfile(profile); err != nil {
|
||||
return nil, fmt.Errorf("更新档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 6. 重新加载关联数据
|
||||
return repository.FindProfileByUUID(uuid)
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
func DeleteProfile(db *gorm.DB, uuid string, userID int64) error {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("档案不存在")
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 删除档案
|
||||
if err := repository.DeleteProfile(uuid); err != nil {
|
||||
return fmt.Errorf("删除档案失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置活跃档案
|
||||
func SetActiveProfile(db *gorm.DB, uuid string, userID int64) error {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("档案不存在")
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 设置活跃状态
|
||||
if err := repository.SetActiveProfile(uuid, userID); err != nil {
|
||||
return fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 更新最后使用时间
|
||||
if err := repository.UpdateProfileLastUsedAt(uuid); err != nil {
|
||||
return fmt.Errorf("更新使用时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckProfileLimit 检查用户档案数量限制
|
||||
func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error {
|
||||
count, err := repository.CountProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询档案数量失败: %w", err)
|
||||
}
|
||||
|
||||
if int(count) >= maxProfiles {
|
||||
return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRSAPrivateKey 生成RSA-2048私钥(PEM格式)
|
||||
func generateRSAPrivateKey() (string, error) {
|
||||
// 生成2048位RSA密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将私钥编码为PEM格式
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privateKeyBytes,
|
||||
})
|
||||
|
||||
return string(privateKeyPEM), nil
|
||||
}
|
||||
|
||||
func ValidateProfileByUserID(db *gorm.DB, userId int64, UUID string) (bool, error) {
|
||||
if userId == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := repository.FindProfileByUUID(UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
}
|
||||
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||
}
|
||||
return profile.UserID == userId, nil
|
||||
}
|
||||
|
||||
func GetProfilesDataByNames(db *gorm.DB, names []string) ([]*model.Profile, error) {
|
||||
profiles, err := repository.GetProfilesByNames(names)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// GetProfileKeyPair 从 PostgreSQL 获取密钥对(GORM 实现,无手动 SQL)
|
||||
func GetProfileKeyPair(db *gorm.DB, profileId string) (*model.KeyPair, error) {
|
||||
keyPair, err := repository.GetProfileKeyPair(profileId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return keyPair, nil
|
||||
}
|
||||
406
internal/service/profile_service_test.go
Normal file
406
internal/service/profile_service_test.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileService_Validation 测试Profile服务验证逻辑
|
||||
func TestProfileService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID和角色名",
|
||||
userID: 1,
|
||||
profileName: "TestProfile",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0时无效",
|
||||
userID: 0,
|
||||
profileName: "TestProfile",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "角色名为空时无效",
|
||||
userID: 1,
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0 && tt.profileName != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileService_StatusValidation 测试用户状态验证
|
||||
func TestProfileService_StatusValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "状态为1(正常)时有效",
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "状态为0(禁用)时无效",
|
||||
status: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(删除)时无效",
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.status == 1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileService_IsActiveDefault 测试Profile默认活跃状态
|
||||
func TestProfileService_IsActiveDefault(t *testing.T) {
|
||||
// 新创建的档案默认为活跃状态
|
||||
isActive := true
|
||||
if !isActive {
|
||||
t.Error("新创建的Profile应该默认为活跃状态")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateProfile_PermissionCheck 测试更新Profile的权限检查逻辑
|
||||
func TestUpdateProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许操作",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝操作",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateProfile_NameValidation 测试更新Profile时名称验证逻辑
|
||||
func TestUpdateProfile_NameValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentName string
|
||||
newName *string
|
||||
shouldCheck bool
|
||||
}{
|
||||
{
|
||||
name: "名称未改变,不检查",
|
||||
currentName: "TestProfile",
|
||||
newName: stringPtr("TestProfile"),
|
||||
shouldCheck: false,
|
||||
},
|
||||
{
|
||||
name: "名称改变,需要检查",
|
||||
currentName: "TestProfile",
|
||||
newName: stringPtr("NewProfile"),
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "名称为nil,不检查",
|
||||
currentName: "TestProfile",
|
||||
newName: nil,
|
||||
shouldCheck: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCheck := tt.newName != nil && *tt.newName != tt.currentName
|
||||
if shouldCheck != tt.shouldCheck {
|
||||
t.Errorf("Name validation check failed: got %v, want %v", shouldCheck, tt.shouldCheck)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteProfile_PermissionCheck 测试删除Profile的权限检查
|
||||
func TestDeleteProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许删除",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝删除",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetActiveProfile_PermissionCheck 测试设置活跃Profile的权限检查
|
||||
func TestSetActiveProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许设置",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝设置",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckProfileLimit_Logic 测试Profile数量限制检查逻辑
|
||||
func TestCheckProfileLimit_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int
|
||||
maxProfiles int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "未达到上限",
|
||||
count: 5,
|
||||
maxProfiles: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "达到上限",
|
||||
count: 10,
|
||||
maxProfiles: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超过上限",
|
||||
count: 15,
|
||||
maxProfiles: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.count >= tt.maxProfiles
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateProfileByUserID_InputValidation 测试ValidateProfileByUserID输入验证
|
||||
func TestValidateProfileByUserID_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
uuid string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "有效输入",
|
||||
userID: 1,
|
||||
uuid: "test-uuid",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "userID为0",
|
||||
userID: 0,
|
||||
uuid: "test-uuid",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "uuid为空",
|
||||
userID: 1,
|
||||
uuid: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "两者都无效",
|
||||
userID: 0,
|
||||
uuid: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.userID == 0 || tt.uuid == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateProfileByUserID_UserIDMatching 测试用户ID匹配逻辑
|
||||
func TestValidateProfileByUserID_UserIDMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.profileUserID == tt.requestUserID
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("UserID matching failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRSAPrivateKey 测试RSA私钥生成
|
||||
func TestGenerateRSAPrivateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "生成RSA私钥",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
privateKey, err := generateRSAPrivateKey()
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("generateRSAPrivateKey() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if privateKey == "" {
|
||||
t.Error("generateRSAPrivateKey() 返回的私钥不应为空")
|
||||
}
|
||||
// 验证PEM格式
|
||||
if len(privateKey) < 100 {
|
||||
t.Errorf("generateRSAPrivateKey() 返回的私钥长度异常: %d", len(privateKey))
|
||||
}
|
||||
// 验证包含PEM头部
|
||||
if !contains(privateKey, "BEGIN RSA PRIVATE KEY") {
|
||||
t.Error("generateRSAPrivateKey() 返回的私钥应包含PEM头部")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRSAPrivateKey_Uniqueness 测试RSA私钥唯一性
|
||||
func TestGenerateRSAPrivateKey_Uniqueness(t *testing.T) {
|
||||
keys := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
key, err := generateRSAPrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generateRSAPrivateKey() 失败: %v", err)
|
||||
}
|
||||
if keys[key] {
|
||||
t.Errorf("第%d次生成的密钥与之前重复", i+1)
|
||||
}
|
||||
keys[key] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr ||
|
||||
(len(s) > len(substr) && (s[:len(substr)] == substr ||
|
||||
s[len(s)-len(substr):] == substr ||
|
||||
containsMiddle(s, substr))))
|
||||
}
|
||||
|
||||
func containsMiddle(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
97
internal/service/serialize_service.go
Normal file
97
internal/service/serialize_service.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/redis"
|
||||
"encoding/base64"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Property struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, p model.Profile) map[string]interface{} {
|
||||
var err error
|
||||
|
||||
// 创建基本材质数据
|
||||
texturesMap := make(map[string]interface{})
|
||||
textures := map[string]interface{}{
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
"profileId": p.UUID,
|
||||
"profileName": p.Name,
|
||||
"textures": texturesMap,
|
||||
}
|
||||
|
||||
// 处理皮肤
|
||||
if p.SkinID != nil {
|
||||
skin, err := GetTextureByID(db, *p.SkinID)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取皮肤失败:", zap.Error(err), zap.Any("SkinID:", *p.SkinID))
|
||||
} else {
|
||||
texturesMap["SKIN"] = map[string]interface{}{
|
||||
"url": skin.URL,
|
||||
"metadata": skin.Size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理披风
|
||||
if p.CapeID != nil {
|
||||
cape, err := GetTextureByID(db, *p.CapeID)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取披风失败:", zap.Error(err), zap.Any("capeID:", *p.CapeID))
|
||||
} else {
|
||||
texturesMap["CAPE"] = map[string]interface{}{
|
||||
"url": cape.URL,
|
||||
"metadata": cape.Size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将textures编码为base64
|
||||
bytes, err := json.Marshal(textures)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 序列化textures失败: ", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
textureData := base64.StdEncoding.EncodeToString(bytes)
|
||||
signature, err := SignStringWithSHA1withRSA(logger, redisClient, textureData)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名textures失败: ", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
data := map[string]interface{}{
|
||||
"id": p.UUID,
|
||||
"name": p.Name,
|
||||
"properties": []Property{
|
||||
{
|
||||
Name: "textures",
|
||||
Value: textureData,
|
||||
Signature: signature,
|
||||
},
|
||||
},
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func SerializeUser(logger *zap.Logger, u *model.User, UUID string) map[string]interface{} {
|
||||
if u == nil {
|
||||
logger.Error("[ERROR] 尝试序列化空用户")
|
||||
return nil
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"id": UUID,
|
||||
"properties": u.Properties,
|
||||
}
|
||||
return data
|
||||
}
|
||||
172
internal/service/serialize_service_test.go
Normal file
172
internal/service/serialize_service_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestSerializeUser_NilUser 实际调用SerializeUser函数测试nil用户
|
||||
func TestSerializeUser_NilUser(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
result := SerializeUser(logger, nil, "test-uuid")
|
||||
if result != nil {
|
||||
t.Error("SerializeUser() 对于nil用户应返回nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeUser_ActualCall 实际调用SerializeUser函数
|
||||
func TestSerializeUser_ActualCall(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
user := &model.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Properties: "{}",
|
||||
}
|
||||
|
||||
result := SerializeUser(logger, user, "test-uuid-123")
|
||||
if result == nil {
|
||||
t.Fatal("SerializeUser() 返回的结果不应为nil")
|
||||
}
|
||||
|
||||
if result["id"] != "test-uuid-123" {
|
||||
t.Errorf("id = %v, want 'test-uuid-123'", result["id"])
|
||||
}
|
||||
|
||||
if result["properties"] == nil {
|
||||
t.Error("properties 不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProperty_Structure 测试Property结构
|
||||
func TestProperty_Structure(t *testing.T) {
|
||||
prop := Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
}
|
||||
|
||||
if prop.Name == "" {
|
||||
t.Error("Property name should not be empty")
|
||||
}
|
||||
|
||||
if prop.Value == "" {
|
||||
t.Error("Property value should not be empty")
|
||||
}
|
||||
|
||||
// Signature是可选的
|
||||
if prop.Signature == "" {
|
||||
t.Log("Property signature is optional")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeService_PropertyFields 测试Property字段
|
||||
func TestSerializeService_PropertyFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
property Property
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的Property",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "缺少Name的Property",
|
||||
property: Property{
|
||||
Name: "",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少Value的Property",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "没有Signature的Property(有效)",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.property.Name != "" && tt.property.Value != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Property validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeUser_InputValidation 测试SerializeUser输入验证
|
||||
func TestSerializeUser_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *struct{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "用户不为nil",
|
||||
user: &struct{}{},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户为nil",
|
||||
user: nil,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.user != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Input validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeProfile_Structure 测试SerializeProfile返回结构
|
||||
func TestSerializeProfile_Structure(t *testing.T) {
|
||||
// 测试返回的数据结构应该包含的字段
|
||||
expectedFields := []string{"id", "name", "properties"}
|
||||
|
||||
// 验证字段名称
|
||||
for _, field := range expectedFields {
|
||||
if field == "" {
|
||||
t.Error("Field name should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证properties应该是数组
|
||||
// 注意:这里只测试逻辑,不测试实际序列化
|
||||
}
|
||||
|
||||
// TestSerializeProfile_PropertyName 测试Property名称
|
||||
func TestSerializeProfile_PropertyName(t *testing.T) {
|
||||
// textures是固定的属性名
|
||||
propertyName := "textures"
|
||||
if propertyName != "textures" {
|
||||
t.Errorf("Property name = %s, want 'textures'", propertyName)
|
||||
}
|
||||
}
|
||||
605
internal/service/signature_service.go
Normal file
605
internal/service/signature_service.go
Normal file
@@ -0,0 +1,605 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
// RSA密钥长度
|
||||
RSAKeySize = 4096
|
||||
|
||||
// Redis密钥名称
|
||||
PrivateKeyRedisKey = "private_key"
|
||||
PublicKeyRedisKey = "public_key"
|
||||
|
||||
// 密钥过期时间
|
||||
KeyExpirationTime = time.Hour * 24 * 7
|
||||
|
||||
// 证书相关
|
||||
CertificateRefreshInterval = time.Hour * 24 // 证书刷新时间间隔
|
||||
CertificateExpirationPeriod = time.Hour * 24 * 7 // 证书过期时间
|
||||
)
|
||||
|
||||
// PlayerCertificate 表示玩家证书信息
|
||||
type PlayerCertificate struct {
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
RefreshedAfter string `json:"refreshedAfter"`
|
||||
PublicKeySignature string `json:"publicKeySignature,omitempty"`
|
||||
PublicKeySignatureV2 string `json:"publicKeySignatureV2,omitempty"`
|
||||
KeyPair struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
} `json:"keyPair"`
|
||||
}
|
||||
// SignatureService 保留结构体以保持向后兼容,但推荐使用函数式版本
|
||||
type SignatureService struct {
|
||||
logger *zap.Logger
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
func NewSignatureService(logger *zap.Logger, redisClient *redis.Client) *SignatureService {
|
||||
return &SignatureService{
|
||||
logger: logger,
|
||||
redisClient: redisClient,
|
||||
}
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串并返回Base64编码的签名(函数式版本)
|
||||
func SignStringWithSHA1withRSA(logger *zap.Logger, redisClient *redis.Client, data string) (string, error) {
|
||||
if data == "" {
|
||||
return "", fmt.Errorf("签名数据不能为空")
|
||||
}
|
||||
|
||||
// 获取私钥
|
||||
privateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 解码私钥失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("解码私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算SHA1哈希
|
||||
hashed := sha1.Sum([]byte(data))
|
||||
|
||||
// 使用RSA-PKCS1v15算法签名
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA1, hashed[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] RSA签名失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("RSA签名失败: %w", err)
|
||||
}
|
||||
|
||||
// Base64编码签名
|
||||
encodedSignature := base64.StdEncoding.EncodeToString(signature)
|
||||
|
||||
logger.Info("[INFO] 成功使用SHA1withRSA生成签名,", zap.Any("数据长度:", len(data)))
|
||||
return encodedSignature, nil
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSAService 使用SHA1withRSA签名字符串并返回Base64编码的签名(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) {
|
||||
return SignStringWithSHA1withRSA(s.logger, s.redisClient, data)
|
||||
}
|
||||
|
||||
// DecodePrivateKeyFromPEM 从Redis获取并解码PEM格式的私钥(函数式版本)
|
||||
func DecodePrivateKeyFromPEM(logger *zap.Logger, redisClient *redis.Client) (*rsa.PrivateKey, error) {
|
||||
// 从Redis获取私钥
|
||||
privateKeyString, err := GetPrivateKeyFromRedis(logger, redisClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 解码PEM格式
|
||||
privateKeyBlock, rest := pem.Decode([]byte(privateKeyString))
|
||||
if privateKeyBlock == nil || len(rest) > 0 {
|
||||
logger.Error("[ERROR] 无效的PEM格式私钥")
|
||||
return nil, fmt.Errorf("无效的PEM格式私钥")
|
||||
}
|
||||
|
||||
// 解析PKCS1格式的私钥
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 解析私钥失败: ", zap.Error(err))
|
||||
return nil, fmt.Errorf("解析私钥失败: %w", err)
|
||||
}
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// GetPrivateKeyFromRedis 从Redis获取私钥(PEM格式)(函数式版本)
|
||||
func GetPrivateKeyFromRedis(logger *zap.Logger, redisClient *redis.Client) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pemBytes, err := redisClient.GetBytes(ctx, PrivateKeyRedisKey)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 从Redis获取私钥失败,尝试生成新的密钥对: ", zap.Error(err))
|
||||
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
|
||||
// 递归获取生成的密钥
|
||||
return GetPrivateKeyFromRedis(logger, redisClient)
|
||||
}
|
||||
|
||||
return string(pemBytes), nil
|
||||
}
|
||||
|
||||
// DecodePrivateKeyFromPEMService 从Redis获取并解码PEM格式的私钥(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) DecodePrivateKeyFromPEM() (*rsa.PrivateKey, error) {
|
||||
return DecodePrivateKeyFromPEM(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// GetPrivateKeyFromRedisService 从Redis获取私钥(PEM格式)(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GetPrivateKeyFromRedis() (string, error) {
|
||||
return GetPrivateKeyFromRedis(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPair 生成新的RSA密钥对(函数式版本)
|
||||
func GenerateRSAKeyPair(logger *zap.Logger, redisClient *redis.Client) error {
|
||||
logger.Info("[INFO] 开始生成RSA密钥对", zap.Int("keySize", RSAKeySize))
|
||||
|
||||
// 生成私钥
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA私钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("生成RSA私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 编码私钥为PEM格式
|
||||
pemPrivateKey, err := EncodePrivateKeyToPEM(privateKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码RSA私钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("编码RSA私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取公钥并编码为PEM格式
|
||||
pubKey := privateKey.PublicKey
|
||||
pemPublicKey, err := EncodePublicKeyToPEM(logger, &pubKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码RSA公钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("编码RSA公钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存密钥对到Redis
|
||||
return SaveKeyPairToRedis(logger, redisClient, string(pemPrivateKey), string(pemPublicKey))
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPairService 生成新的RSA密钥对(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GenerateRSAKeyPair() error {
|
||||
return GenerateRSAKeyPair(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEM 将私钥编码为PEM格式(函数式版本)
|
||||
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
|
||||
if privateKey == nil {
|
||||
return nil, fmt.Errorf("私钥不能为空")
|
||||
}
|
||||
|
||||
// 默认使用 "PRIVATE KEY" 类型
|
||||
pemType := "PRIVATE KEY"
|
||||
|
||||
// 如果指定了类型参数且为 "RSA",则使用 "RSA PRIVATE KEY"
|
||||
if len(keyType) > 0 && keyType[0] == "RSA" {
|
||||
pemType = "RSA PRIVATE KEY"
|
||||
}
|
||||
|
||||
// 将私钥转换为PKCS1格式
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
|
||||
// 编码为PEM格式
|
||||
pemBlock := &pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: privateKeyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
// EncodePublicKeyToPEM 将公钥编码为PEM格式(函数式版本)
|
||||
func EncodePublicKeyToPEM(logger *zap.Logger, publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
|
||||
if publicKey == nil {
|
||||
return nil, fmt.Errorf("公钥不能为空")
|
||||
}
|
||||
|
||||
// 默认使用 "PUBLIC KEY" 类型
|
||||
pemType := "PUBLIC KEY"
|
||||
var publicKeyBytes []byte
|
||||
var err error
|
||||
|
||||
// 如果指定了类型参数且为 "RSA",则使用 "RSA PUBLIC KEY"
|
||||
if len(keyType) > 0 && keyType[0] == "RSA" {
|
||||
pemType = "RSA PUBLIC KEY"
|
||||
publicKeyBytes = x509.MarshalPKCS1PublicKey(publicKey)
|
||||
} else {
|
||||
// 默认将公钥转换为PKIX格式
|
||||
publicKeyBytes, err = x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 序列化公钥失败: ", zap.Error(err))
|
||||
return nil, fmt.Errorf("序列化公钥失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 编码为PEM格式
|
||||
pemBlock := &pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: publicKeyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
// SaveKeyPairToRedis 将RSA密钥对保存到Redis(函数式版本)
|
||||
func SaveKeyPairToRedis(logger *zap.Logger, redisClient *redis.Client, privateKey, publicKey string) error {
|
||||
// 创建上下文并设置超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 使用事务确保两个操作的原子性
|
||||
tx := redisClient.TxPipeline()
|
||||
|
||||
tx.Set(ctx, PrivateKeyRedisKey, privateKey, KeyExpirationTime)
|
||||
tx.Set(ctx, PublicKeyRedisKey, publicKey, KeyExpirationTime)
|
||||
|
||||
// 执行事务
|
||||
_, err := tx.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 保存RSA密钥对到Redis失败: ", zap.Error(err))
|
||||
return fmt.Errorf("保存RSA密钥对到Redis失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功保存RSA密钥对到Redis")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEMService 将私钥编码为PEM格式(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
|
||||
return EncodePrivateKeyToPEM(privateKey, keyType...)
|
||||
}
|
||||
|
||||
// EncodePublicKeyToPEMService 将公钥编码为PEM格式(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) EncodePublicKeyToPEM(publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
|
||||
return EncodePublicKeyToPEM(s.logger, publicKey, keyType...)
|
||||
}
|
||||
|
||||
// SaveKeyPairToRedisService 将RSA密钥对保存到Redis(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) SaveKeyPairToRedis(privateKey, publicKey string) error {
|
||||
return SaveKeyPairToRedis(s.logger, s.redisClient, privateKey, publicKey)
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedisFunc 从Redis获取公钥(PEM格式,函数式版本)
|
||||
func GetPublicKeyFromRedisFunc(logger *zap.Logger, redisClient *redis.Client) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pemBytes, err := redisClient.GetBytes(ctx, PublicKeyRedisKey)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 从Redis获取公钥失败,尝试生成新的密钥对: ", zap.Error(err))
|
||||
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
|
||||
// 递归获取生成的密钥
|
||||
return GetPublicKeyFromRedisFunc(logger, redisClient)
|
||||
}
|
||||
|
||||
// 检查获取到的公钥是否为空(key不存在时GetBytes返回nil, nil)
|
||||
if len(pemBytes) == 0 {
|
||||
logger.Info("[INFO] Redis中公钥为空,尝试生成新的密钥对")
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
// 递归获取生成的密钥
|
||||
return GetPublicKeyFromRedisFunc(logger, redisClient)
|
||||
}
|
||||
|
||||
return string(pemBytes), nil
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedis 从Redis获取公钥(PEM格式,结构体方法版本)
|
||||
func (s *SignatureService) GetPublicKeyFromRedis() (string, error) {
|
||||
return GetPublicKeyFromRedisFunc(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
|
||||
// GeneratePlayerCertificate 生成玩家证书(函数式版本)
|
||||
func GeneratePlayerCertificate(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, uuid string) (*PlayerCertificate, error) {
|
||||
if uuid == "" {
|
||||
return nil, fmt.Errorf("UUID不能为空")
|
||||
}
|
||||
logger.Info("[INFO] 开始生成玩家证书,用户UUID: %s",
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
|
||||
keyPair, err := repository.GetProfileKeyPair(uuid)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 获取用户密钥对失败,将创建新密钥对: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
keyPair = nil
|
||||
}
|
||||
|
||||
// 如果没有找到密钥对或密钥对已过期,创建一个新的
|
||||
now := time.Now().UTC()
|
||||
if keyPair == nil || keyPair.Refresh.Before(now) || keyPair.PrivateKey == "" || keyPair.PublicKey == "" {
|
||||
logger.Info("[INFO] 为用户创建新的密钥对: %s",
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
keyPair, err = NewKeyPair(logger)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成玩家证书密钥对失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
return nil, fmt.Errorf("生成玩家证书密钥对失败: %w", err)
|
||||
}
|
||||
// 保存密钥对到数据库
|
||||
err = repository.UpdateProfileKeyPair(uuid, keyPair)
|
||||
if err != nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Warn("[WARN] 更新用户密钥对失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
// 继续执行,即使保存失败
|
||||
}
|
||||
}
|
||||
|
||||
// 计算expiresAt的毫秒时间戳
|
||||
expiresAtMillis := keyPair.Expiration.UnixMilli()
|
||||
|
||||
// 准备签名
|
||||
publicKeySignature := ""
|
||||
publicKeySignatureV2 := ""
|
||||
|
||||
// 获取服务器私钥用于签名
|
||||
serverPrivateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
|
||||
if err != nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Error("[ERROR] 获取服务器私钥失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
return nil, fmt.Errorf("获取服务器私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 提取公钥DER编码
|
||||
pubPEMBlock, _ := pem.Decode([]byte(keyPair.PublicKey))
|
||||
if pubPEMBlock == nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Error("[ERROR] 解码公钥PEM失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("publicKey", keyPair.PublicKey),
|
||||
)
|
||||
return nil, fmt.Errorf("解码公钥PEM失败")
|
||||
}
|
||||
pubDER := pubPEMBlock.Bytes
|
||||
|
||||
// 准备publicKeySignature(用于MC 1.19)
|
||||
// Base64编码公钥,不包含换行
|
||||
pubBase64 := strings.ReplaceAll(base64.StdEncoding.EncodeToString(pubDER), "\n", "")
|
||||
|
||||
// 按76字符一行进行包装
|
||||
pubBase64Wrapped := WrapString(pubBase64, 76)
|
||||
|
||||
// 放入PEM格式
|
||||
pubMojangPEM := "-----BEGIN RSA PUBLIC KEY-----\n" +
|
||||
pubBase64Wrapped +
|
||||
"\n-----END RSA PUBLIC KEY-----\n"
|
||||
|
||||
// 签名数据: expiresAt毫秒时间戳 + 公钥PEM格式
|
||||
signedData := []byte(fmt.Sprintf("%d%s", expiresAtMillis, pubMojangPEM))
|
||||
|
||||
// 计算SHA1哈希并签名
|
||||
hash1 := sha1.Sum(signedData)
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash1[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("expiresAtMillis", expiresAtMillis),
|
||||
)
|
||||
return nil, fmt.Errorf("签名失败: %w", err)
|
||||
}
|
||||
publicKeySignature = base64.StdEncoding.EncodeToString(signature)
|
||||
|
||||
// 准备publicKeySignatureV2(用于MC 1.19.1+)
|
||||
var uuidBytes []byte
|
||||
|
||||
// 如果提供了UUID,则使用它
|
||||
// 移除UUID中的连字符
|
||||
uuidStr := strings.ReplaceAll(uuid, "-", "")
|
||||
|
||||
// 将UUID转换为字节数组(16字节)
|
||||
if len(uuidStr) < 32 {
|
||||
logger.Warn("[WARN] UUID长度不足32字符,使用空UUID: %s",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("processedUuidStr", uuidStr),
|
||||
)
|
||||
uuidBytes = make([]byte, 16)
|
||||
} else {
|
||||
// 解析UUID字符串为字节
|
||||
uuidBytes = make([]byte, 16)
|
||||
parseErr := error(nil)
|
||||
for i := 0; i < 16; i++ {
|
||||
// 每两个字符转换为一个字节
|
||||
byteStr := uuidStr[i*2 : i*2+2]
|
||||
byteVal, err := strconv.ParseUint(byteStr, 16, 8)
|
||||
if err != nil {
|
||||
parseErr = err
|
||||
logger.Error("[ERROR] 解析UUID字节失败: %v, byteStr: %s",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("byteStr", byteStr),
|
||||
zap.Int("index", i),
|
||||
)
|
||||
uuidBytes = make([]byte, 16) // 出错时使用空UUID
|
||||
break
|
||||
}
|
||||
uuidBytes[i] = byte(byteVal)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("解析UUID字节失败: %w", parseErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 准备签名数据:UUID + expiresAt时间戳 + DER编码的公钥
|
||||
signedDataV2 := make([]byte, 0, 24+len(pubDER)) // 预分配缓冲区
|
||||
|
||||
// 添加UUID(16字节)
|
||||
signedDataV2 = append(signedDataV2, uuidBytes...)
|
||||
|
||||
// 添加expiresAt毫秒时间戳(8字节,大端序)
|
||||
expiresAtBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expiresAtBytes, uint64(expiresAtMillis))
|
||||
signedDataV2 = append(signedDataV2, expiresAtBytes...)
|
||||
|
||||
// 添加DER编码的公钥
|
||||
signedDataV2 = append(signedDataV2, pubDER...)
|
||||
|
||||
// 计算SHA1哈希并签名
|
||||
hash2 := sha1.Sum(signedDataV2)
|
||||
signatureV2, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash2[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名V2失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("expiresAtMillis", expiresAtMillis),
|
||||
)
|
||||
return nil, fmt.Errorf("签名V2失败: %w", err)
|
||||
}
|
||||
publicKeySignatureV2 = base64.StdEncoding.EncodeToString(signatureV2)
|
||||
|
||||
// 创建玩家证书结构
|
||||
certificate := &PlayerCertificate{
|
||||
KeyPair: struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}{
|
||||
PrivateKey: keyPair.PrivateKey,
|
||||
PublicKey: keyPair.PublicKey,
|
||||
},
|
||||
PublicKeySignature: publicKeySignature,
|
||||
PublicKeySignatureV2: publicKeySignatureV2,
|
||||
ExpiresAt: keyPair.Expiration.Format(time.RFC3339Nano),
|
||||
RefreshedAfter: keyPair.Refresh.Format(time.RFC3339Nano),
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功生成玩家证书,过期时间: %s",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("expiresAt", certificate.ExpiresAt),
|
||||
zap.String("refreshedAfter", certificate.RefreshedAfter),
|
||||
)
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
// GeneratePlayerCertificateService 生成玩家证书(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GeneratePlayerCertificate(uuid string) (*PlayerCertificate, error) {
|
||||
return GeneratePlayerCertificate(nil, s.logger, s.redisClient, uuid) // TODO: 需要传入db参数
|
||||
}
|
||||
|
||||
// NewKeyPair 生成新的密钥对(函数式版本)
|
||||
func NewKeyPair(logger *zap.Logger) (*model.KeyPair, error) {
|
||||
// 生成新的RSA密钥对(用于玩家证书)
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) // 对玩家证书使用更小的密钥以提高性能
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成玩家证书私钥失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("生成玩家证书私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取DER编码的密钥
|
||||
keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码私钥为PKCS8格式失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("编码私钥为PKCS8格式失败: %w", err)
|
||||
}
|
||||
|
||||
pubDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码公钥为PKIX格式失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("编码公钥为PKIX格式失败: %w", err)
|
||||
}
|
||||
|
||||
// 将密钥编码为PEM格式
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: keyDER,
|
||||
})
|
||||
|
||||
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: pubDER,
|
||||
})
|
||||
|
||||
// 创建证书过期和刷新时间
|
||||
now := time.Now().UTC()
|
||||
expiresAtTime := now.Add(CertificateExpirationPeriod)
|
||||
refreshedAfter := now.Add(CertificateRefreshInterval)
|
||||
keyPair := &model.KeyPair{
|
||||
Expiration: expiresAtTime,
|
||||
PrivateKey: string(keyPEM),
|
||||
PublicKey: string(pubPEM),
|
||||
Refresh: refreshedAfter,
|
||||
}
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
// WrapString 将字符串按指定宽度进行换行(函数式版本)
|
||||
func WrapString(str string, width int) string {
|
||||
if width <= 0 {
|
||||
return str
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(str); i += width {
|
||||
end := i + width
|
||||
if end > len(str) {
|
||||
end = len(str)
|
||||
}
|
||||
b.WriteString(str[i:end])
|
||||
if end < len(str) {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// NewKeyPairService 生成新的密钥对(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) {
|
||||
return NewKeyPair(s.logger)
|
||||
}
|
||||
358
internal/service/signature_service_test.go
Normal file
358
internal/service/signature_service_test.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestSignatureService_Constants 测试签名服务相关常量
|
||||
func TestSignatureService_Constants(t *testing.T) {
|
||||
if RSAKeySize != 4096 {
|
||||
t.Errorf("RSAKeySize = %d, want 4096", RSAKeySize)
|
||||
}
|
||||
|
||||
if PrivateKeyRedisKey == "" {
|
||||
t.Error("PrivateKeyRedisKey should not be empty")
|
||||
}
|
||||
|
||||
if PublicKeyRedisKey == "" {
|
||||
t.Error("PublicKeyRedisKey should not be empty")
|
||||
}
|
||||
|
||||
if KeyExpirationTime != 24*7*time.Hour {
|
||||
t.Errorf("KeyExpirationTime = %v, want 7 days", KeyExpirationTime)
|
||||
}
|
||||
|
||||
if CertificateRefreshInterval != 24*time.Hour {
|
||||
t.Errorf("CertificateRefreshInterval = %v, want 24 hours", CertificateRefreshInterval)
|
||||
}
|
||||
|
||||
if CertificateExpirationPeriod != 24*7*time.Hour {
|
||||
t.Errorf("CertificateExpirationPeriod = %v, want 7 days", CertificateExpirationPeriod)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSignatureService_DataValidation 测试签名数据验证逻辑
|
||||
func TestSignatureService_DataValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "非空数据有效",
|
||||
data: "test data",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "空数据无效",
|
||||
data: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.data != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Data validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlayerCertificate_Structure 测试PlayerCertificate结构
|
||||
func TestPlayerCertificate_Structure(t *testing.T) {
|
||||
cert := PlayerCertificate{
|
||||
ExpiresAt: "2025-01-01T00:00:00Z",
|
||||
RefreshedAfter: "2025-01-01T00:00:00Z",
|
||||
PublicKeySignature: "signature",
|
||||
PublicKeySignatureV2: "signaturev2",
|
||||
}
|
||||
|
||||
// 验证结构体字段
|
||||
if cert.ExpiresAt == "" {
|
||||
t.Error("ExpiresAt should not be empty")
|
||||
}
|
||||
|
||||
if cert.RefreshedAfter == "" {
|
||||
t.Error("RefreshedAfter should not be empty")
|
||||
}
|
||||
|
||||
// PublicKeySignature是可选的
|
||||
if cert.PublicKeySignature == "" {
|
||||
t.Log("PublicKeySignature is optional")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString 测试字符串换行函数
|
||||
func TestWrapString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
str string
|
||||
width int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "正常换行",
|
||||
str: "1234567890",
|
||||
width: 5,
|
||||
expected: "12345\n67890",
|
||||
},
|
||||
{
|
||||
name: "字符串长度等于width",
|
||||
str: "12345",
|
||||
width: 5,
|
||||
expected: "12345",
|
||||
},
|
||||
{
|
||||
name: "字符串长度小于width",
|
||||
str: "123",
|
||||
width: 5,
|
||||
expected: "123",
|
||||
},
|
||||
{
|
||||
name: "width为0,返回原字符串",
|
||||
str: "1234567890",
|
||||
width: 0,
|
||||
expected: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "width为负数,返回原字符串",
|
||||
str: "1234567890",
|
||||
width: -1,
|
||||
expected: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
str: "",
|
||||
width: 5,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "width为1",
|
||||
str: "12345",
|
||||
width: 1,
|
||||
expected: "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
name: "长字符串多次换行",
|
||||
str: "123456789012345",
|
||||
width: 5,
|
||||
expected: "12345\n67890\n12345",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := WrapString(tt.str, tt.width)
|
||||
if result != tt.expected {
|
||||
t.Errorf("WrapString(%q, %d) = %q, want %q", tt.str, tt.width, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString_LineCount 测试换行后的行数
|
||||
func TestWrapString_LineCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
str string
|
||||
width int
|
||||
wantLines int
|
||||
}{
|
||||
{
|
||||
name: "10个字符,width=5,应该2行",
|
||||
str: "1234567890",
|
||||
width: 5,
|
||||
wantLines: 2,
|
||||
},
|
||||
{
|
||||
name: "15个字符,width=5,应该3行",
|
||||
str: "123456789012345",
|
||||
width: 5,
|
||||
wantLines: 3,
|
||||
},
|
||||
{
|
||||
name: "5个字符,width=5,应该1行",
|
||||
str: "12345",
|
||||
width: 5,
|
||||
wantLines: 1,
|
||||
},
|
||||
{
|
||||
name: "width为0,应该1行",
|
||||
str: "1234567890",
|
||||
width: 0,
|
||||
wantLines: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := WrapString(tt.str, tt.width)
|
||||
lines := strings.Count(result, "\n") + 1
|
||||
if lines != tt.wantLines {
|
||||
t.Errorf("Line count = %d, want %d (result: %q)", lines, tt.wantLines, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString_NoTrailingNewline 测试末尾不换行
|
||||
func TestWrapString_NoTrailingNewline(t *testing.T) {
|
||||
str := "1234567890"
|
||||
result := WrapString(str, 5)
|
||||
|
||||
// 验证末尾没有换行符
|
||||
if strings.HasSuffix(result, "\n") {
|
||||
t.Error("Result should not end with newline")
|
||||
}
|
||||
|
||||
// 验证包含换行符(除了最后一行)
|
||||
if !strings.Contains(result, "\n") {
|
||||
t.Error("Result should contain newline for multi-line output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePrivateKeyToPEM_ActualCall 实际调用EncodePrivateKeyToPEM函数
|
||||
func TestEncodePrivateKeyToPEM_ActualCall(t *testing.T) {
|
||||
// 生成测试用的RSA私钥
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("生成RSA私钥失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "默认类型",
|
||||
keyType: []string{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "RSA类型",
|
||||
keyType: []string{"RSA"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pemBytes, err := EncodePrivateKeyToPEM(privateKey, tt.keyType...)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("EncodePrivateKeyToPEM() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if len(pemBytes) == 0 {
|
||||
t.Error("EncodePrivateKeyToPEM() 返回的PEM字节不应为空")
|
||||
}
|
||||
pemStr := string(pemBytes)
|
||||
// 验证PEM格式
|
||||
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
|
||||
t.Error("EncodePrivateKeyToPEM() 返回的PEM格式不正确")
|
||||
}
|
||||
// 验证类型
|
||||
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
|
||||
if !strings.Contains(pemStr, "RSA PRIVATE KEY") {
|
||||
t.Error("EncodePrivateKeyToPEM() 应包含 'RSA PRIVATE KEY'")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(pemStr, "PRIVATE KEY") {
|
||||
t.Error("EncodePrivateKeyToPEM() 应包含 'PRIVATE KEY'")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePublicKeyToPEM_ActualCall 实际调用EncodePublicKeyToPEM函数
|
||||
func TestEncodePublicKeyToPEM_ActualCall(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 生成测试用的RSA密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("生成RSA密钥对失败: %v", err)
|
||||
}
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "默认类型",
|
||||
keyType: []string{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "RSA类型",
|
||||
keyType: []string{"RSA"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pemBytes, err := EncodePublicKeyToPEM(logger, publicKey, tt.keyType...)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("EncodePublicKeyToPEM() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if len(pemBytes) == 0 {
|
||||
t.Error("EncodePublicKeyToPEM() 返回的PEM字节不应为空")
|
||||
}
|
||||
pemStr := string(pemBytes)
|
||||
// 验证PEM格式
|
||||
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
|
||||
t.Error("EncodePublicKeyToPEM() 返回的PEM格式不正确")
|
||||
}
|
||||
// 验证类型
|
||||
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
|
||||
if !strings.Contains(pemStr, "RSA PUBLIC KEY") {
|
||||
t.Error("EncodePublicKeyToPEM() 应包含 'RSA PUBLIC KEY'")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(pemStr, "PUBLIC KEY") {
|
||||
t.Error("EncodePublicKeyToPEM() 应包含 'PUBLIC KEY'")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePublicKeyToPEM_NilKey 测试nil公钥
|
||||
func TestEncodePublicKeyToPEM_NilKey(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
_, err := EncodePublicKeyToPEM(logger, nil)
|
||||
if err == nil {
|
||||
t.Error("EncodePublicKeyToPEM() 对于nil公钥应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSignatureService 测试创建SignatureService
|
||||
func TestNewSignatureService(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
// 注意:这里需要实际的redis client,但我们只测试结构体创建
|
||||
// 在实际测试中,可以使用mock redis client
|
||||
service := NewSignatureService(logger, nil)
|
||||
if service == nil {
|
||||
t.Error("NewSignatureService() 不应返回nil")
|
||||
}
|
||||
if service.logger != logger {
|
||||
t.Error("NewSignatureService() logger 设置不正确")
|
||||
}
|
||||
}
|
||||
251
internal/service/texture_service.go
Normal file
251
internal/service/texture_service.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateTexture 创建材质
|
||||
func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
|
||||
// 验证用户存在
|
||||
user, err := repository.FindUserByID(uploaderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查Hash是否已存在
|
||||
existingTexture, err := repository.FindTextureByHash(hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingTexture != nil {
|
||||
return nil, errors.New("该材质已存在")
|
||||
}
|
||||
|
||||
// 转换材质类型
|
||||
var textureTypeEnum model.TextureType
|
||||
switch textureType {
|
||||
case "SKIN":
|
||||
textureTypeEnum = model.TextureTypeSkin
|
||||
case "CAPE":
|
||||
textureTypeEnum = model.TextureTypeCape
|
||||
default:
|
||||
return nil, errors.New("无效的材质类型")
|
||||
}
|
||||
|
||||
// 创建材质
|
||||
texture := &model.Texture{
|
||||
UploaderID: uploaderID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Type: textureTypeEnum,
|
||||
URL: url,
|
||||
Hash: hash,
|
||||
Size: size,
|
||||
IsPublic: isPublic,
|
||||
IsSlim: isSlim,
|
||||
Status: 1,
|
||||
DownloadCount: 0,
|
||||
FavoriteCount: 0,
|
||||
}
|
||||
|
||||
if err := repository.CreateTexture(texture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
// GetTextureByID 根据ID获取材质
|
||||
func GetTextureByID(db *gorm.DB, id int64) (*model.Texture, error) {
|
||||
texture, err := repository.FindTextureByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, errors.New("材质不存在")
|
||||
}
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
// GetUserTextures 获取用户上传的材质列表
|
||||
func GetUserTextures(db *gorm.DB, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.FindTexturesByUploaderID(uploaderID, page, pageSize)
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
func SearchTextures(db *gorm.DB, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.SearchTextures(keyword, textureType, publicOnly, page, pageSize)
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
|
||||
// 获取材质
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查权限:只有上传者可以修改
|
||||
if texture.UploaderID != uploaderID {
|
||||
return nil, errors.New("无权修改此材质")
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
updates := make(map[string]interface{})
|
||||
if name != "" {
|
||||
updates["name"] = name
|
||||
}
|
||||
if description != "" {
|
||||
updates["description"] = description
|
||||
}
|
||||
if isPublic != nil {
|
||||
updates["is_public"] = *isPublic
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := repository.UpdateTextureFields(textureID, updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 返回更新后的材质
|
||||
return repository.FindTextureByID(textureID)
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质
|
||||
func DeleteTexture(db *gorm.DB, textureID, uploaderID int64) error {
|
||||
// 获取材质
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查权限:只有上传者可以删除
|
||||
if texture.UploaderID != uploaderID {
|
||||
return errors.New("无权删除此材质")
|
||||
}
|
||||
|
||||
return repository.DeleteTexture(textureID)
|
||||
}
|
||||
|
||||
// RecordTextureDownload 记录下载
|
||||
func RecordTextureDownload(db *gorm.DB, textureID int64, userID *int64, ipAddress, userAgent string) error {
|
||||
// 检查材质是否存在
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 增加下载次数
|
||||
if err := repository.IncrementTextureDownloadCount(textureID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建下载日志
|
||||
log := &model.TextureDownloadLog{
|
||||
TextureID: textureID,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
|
||||
return repository.CreateTextureDownloadLog(log)
|
||||
}
|
||||
|
||||
// ToggleTextureFavorite 切换收藏状态
|
||||
func ToggleTextureFavorite(db *gorm.DB, userID, textureID int64) (bool, error) {
|
||||
// 检查材质是否存在
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if texture == nil {
|
||||
return false, errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查是否已收藏
|
||||
isFavorited, err := repository.IsTextureFavorited(userID, textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isFavorited {
|
||||
// 取消收藏
|
||||
if err := repository.RemoveTextureFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := repository.DecrementTextureFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
// 添加收藏
|
||||
if err := repository.AddTextureFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := repository.IncrementTextureFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserTextureFavorites 获取用户收藏的材质列表
|
||||
func GetUserTextureFavorites(db *gorm.DB, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.GetUserTextureFavorites(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// CheckTextureUploadLimit 检查用户上传材质数量限制
|
||||
func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) error {
|
||||
count, err := repository.CountTexturesByUploaderID(uploaderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count >= int64(maxTextures) {
|
||||
return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
471
internal/service/texture_service_test.go
Normal file
471
internal/service/texture_service_test.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTextureService_TypeValidation 测试材质类型验证
|
||||
func TestTextureService_TypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
textureType string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "SKIN类型有效",
|
||||
textureType: "SKIN",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "CAPE类型有效",
|
||||
textureType: "CAPE",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效类型",
|
||||
textureType: "INVALID",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空类型无效",
|
||||
textureType: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.textureType == "SKIN" || tt.textureType == "CAPE"
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Texture type validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureService_DefaultValues 测试材质默认值
|
||||
func TestTextureService_DefaultValues(t *testing.T) {
|
||||
// 测试默认状态
|
||||
defaultStatus := 1
|
||||
if defaultStatus != 1 {
|
||||
t.Errorf("默认状态应为1,实际为%d", defaultStatus)
|
||||
}
|
||||
|
||||
// 测试默认下载数
|
||||
defaultDownloadCount := 0
|
||||
if defaultDownloadCount != 0 {
|
||||
t.Errorf("默认下载数应为0,实际为%d", defaultDownloadCount)
|
||||
}
|
||||
|
||||
// 测试默认收藏数
|
||||
defaultFavoriteCount := 0
|
||||
if defaultFavoriteCount != 0 {
|
||||
t.Errorf("默认收藏数应为0,实际为%d", defaultFavoriteCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureService_StatusValidation 测试材质状态验证
|
||||
func TestTextureService_StatusValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "状态为1(正常)时有效",
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(删除)时无效",
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为0时可能有效(取决于业务逻辑)",
|
||||
status: 0,
|
||||
wantValid: true, // 状态为0(禁用)时,材质仍然存在,只是不可用,但查询时不会返回错误
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 材质状态为-1时表示已删除,无效
|
||||
isValid := tt.status != -1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserTextures_Pagination 测试分页逻辑
|
||||
func TestGetUserTextures_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 2,
|
||||
pageSize: 20,
|
||||
wantPage: 2,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize小于1,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 0,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchTextures_Pagination 测试搜索分页逻辑
|
||||
func TestSearchTextures_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPage: 1,
|
||||
wantSize: 10,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: -1,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 150,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTexture_PermissionCheck 测试更新材质的权限检查
|
||||
func TestUpdateTexture_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
requestID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "上传者ID匹配,允许更新",
|
||||
uploaderID: 1,
|
||||
requestID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上传者ID不匹配,拒绝更新",
|
||||
uploaderID: 1,
|
||||
requestID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.uploaderID != tt.requestID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTexture_FieldUpdates 测试更新字段逻辑
|
||||
func TestUpdateTexture_FieldUpdates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nameValue string
|
||||
descValue string
|
||||
isPublic *bool
|
||||
wantUpdates int
|
||||
}{
|
||||
{
|
||||
name: "更新所有字段",
|
||||
nameValue: "NewName",
|
||||
descValue: "NewDesc",
|
||||
isPublic: boolPtr(true),
|
||||
wantUpdates: 3,
|
||||
},
|
||||
{
|
||||
name: "只更新名称",
|
||||
nameValue: "NewName",
|
||||
descValue: "",
|
||||
isPublic: nil,
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "只更新描述",
|
||||
nameValue: "",
|
||||
descValue: "NewDesc",
|
||||
isPublic: nil,
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "只更新公开状态",
|
||||
nameValue: "",
|
||||
descValue: "",
|
||||
isPublic: boolPtr(false),
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "没有更新",
|
||||
nameValue: "",
|
||||
descValue: "",
|
||||
isPublic: nil,
|
||||
wantUpdates: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
updates := 0
|
||||
if tt.nameValue != "" {
|
||||
updates++
|
||||
}
|
||||
if tt.descValue != "" {
|
||||
updates++
|
||||
}
|
||||
if tt.isPublic != nil {
|
||||
updates++
|
||||
}
|
||||
|
||||
if updates != tt.wantUpdates {
|
||||
t.Errorf("Updates count = %d, want %d", updates, tt.wantUpdates)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteTexture_PermissionCheck 测试删除材质的权限检查
|
||||
func TestDeleteTexture_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
requestID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "上传者ID匹配,允许删除",
|
||||
uploaderID: 1,
|
||||
requestID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上传者ID不匹配,拒绝删除",
|
||||
uploaderID: 1,
|
||||
requestID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.uploaderID != tt.requestID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToggleTextureFavorite_Logic 测试切换收藏状态的逻辑
|
||||
func TestToggleTextureFavorite_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isFavorited bool
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "已收藏,取消收藏",
|
||||
isFavorited: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "未收藏,添加收藏",
|
||||
isFavorited: false,
|
||||
wantResult: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := !tt.isFavorited
|
||||
if result != tt.wantResult {
|
||||
t.Errorf("Toggle favorite failed: got %v, want %v", result, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserTextureFavorites_Pagination 测试收藏列表分页
|
||||
func TestGetUserTextureFavorites_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckTextureUploadLimit_Logic 测试上传限制检查逻辑
|
||||
func TestCheckTextureUploadLimit_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
maxTextures int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "未达到上限",
|
||||
count: 5,
|
||||
maxTextures: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "达到上限",
|
||||
count: 10,
|
||||
maxTextures: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超过上限",
|
||||
count: 15,
|
||||
maxTextures: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.count >= int64(tt.maxTextures)
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
277
internal/service/token_service.go
Normal file
277
internal/service/token_service.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
ExtendedTimeout = 10 * time.Second
|
||||
TokensMaxCount = 10 // 用户最多保留的token数量
|
||||
)
|
||||
|
||||
// NewToken 创建新令牌
|
||||
func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
)
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
_, err := repository.FindProfileByUUID(UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
if clientToken == "" {
|
||||
clientToken = uuid.New().String()
|
||||
}
|
||||
|
||||
accessToken := uuid.New().String()
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userId,
|
||||
Usable: true,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := repository.FindProfilesByUserID(userId)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果用户只有一个配置文件,自动选择
|
||||
if len(profiles) == 1 {
|
||||
selectedProfileID = profiles[0]
|
||||
token.ProfileId = selectedProfileID.UUID
|
||||
}
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌到tokens集合
|
||||
_, insertCancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer insertCancel()
|
||||
|
||||
err = repository.CreateToken(&token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
// 清理多余的令牌
|
||||
go CheckAndCleanupExcessTokens(db, logger, userId)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// CheckAndCleanupExcessTokens 检查并清理用户多余的令牌,只保留最新的10个
|
||||
func CheckAndCleanupExcessTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
|
||||
if userId == 0 {
|
||||
return
|
||||
}
|
||||
// 获取用户所有令牌,按发行日期降序排序
|
||||
tokens, err := repository.GetTokensByUserId(userId)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取用户Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
// 如果令牌数量不超过上限,无需清理
|
||||
if len(tokens) <= TokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取需要删除的令牌ID列表
|
||||
tokensToDelete := make([]string, 0, len(tokens)-TokensMaxCount)
|
||||
for i := TokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
// 执行批量删除,传入上下文和待删除的令牌列表(作为切片参数)
|
||||
DeletedCount, err := repository.BatchDeleteTokens(tokensToDelete)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 清理用户多余Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if DeletedCount > 0 {
|
||||
logger.Info("[INFO] 成功清理用户多余Token", zap.Any("userId:", userId), zap.Any("count:", DeletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
// ValidToken 验证令牌有效性
|
||||
func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool {
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用投影只获取需要的字段
|
||||
var token *model.Token
|
||||
token, err := repository.FindTokenByID(accessToken)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !token.Usable {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果客户端令牌为空,只验证访问令牌
|
||||
if clientToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 否则验证客户端令牌是否匹配
|
||||
return token.ClientToken == clientToken
|
||||
}
|
||||
|
||||
func GetUUIDByAccessToken(db *gorm.DB, accessToken string) (string, error) {
|
||||
return repository.GetUUIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
func GetUserIDByAccessToken(db *gorm.DB, accessToken string) (int64, error) {
|
||||
return repository.GetUserIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken string, selectedProfileID string) (string, string, error) {
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := repository.GetTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
}
|
||||
logger.Error("[ERROR] 查询Token失败: ", zap.Error(err), zap.Any("accessToken:", accessToken))
|
||||
return "", "", fmt.Errorf("查询令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := ValidateProfileByUserID(db, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
logger.Error(
|
||||
"验证Profile失败",
|
||||
zap.Error(err),
|
||||
zap.Any("userId", oldToken.UserID),
|
||||
zap.String("profileId", selectedProfileID),
|
||||
)
|
||||
return "", "", fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return "", "", errors.New("角色与用户不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 clientToken 是否有效
|
||||
if clientToken != "" && clientToken != oldToken.ClientToken {
|
||||
return "", "", errors.New("clientToken无效")
|
||||
}
|
||||
|
||||
// 检查 selectedProfileID 的逻辑
|
||||
if selectedProfileID != "" {
|
||||
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
|
||||
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||
}
|
||||
} else {
|
||||
selectedProfileID = oldToken.ProfileId // 如果未指定,则保持原角色
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := uuid.New().String()
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: oldToken.ClientToken, // 新令牌的 clientToken 与原令牌相同
|
||||
UserID: oldToken.UserID,
|
||||
Usable: true,
|
||||
ProfileId: selectedProfileID, // 绑定到指定角色或保持原角色
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 使用双重写入模式替代事务,先插入新令牌,再删除旧令牌
|
||||
|
||||
err = repository.CreateToken(&newToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"创建新Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = repository.DeleteTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
// 删除旧令牌失败,记录日志但不阻止操作,因为新令牌已成功创建
|
||||
logger.Warn(
|
||||
"删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
zap.String("oldToken", oldToken.AccessToken),
|
||||
zap.String("newToken", newAccessToken),
|
||||
)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"成功刷新Token",
|
||||
zap.Any("userId", oldToken.UserID),
|
||||
zap.String("accessToken", newAccessToken),
|
||||
)
|
||||
return newAccessToken, oldToken.ClientToken, nil
|
||||
}
|
||||
|
||||
// InvalidToken 使令牌失效
|
||||
func InvalidToken(db *gorm.DB, logger *zap.Logger, accessToken string) {
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := repository.DeleteTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"删除Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.Info("[INFO] 成功删除", zap.Any("Token:", accessToken))
|
||||
|
||||
}
|
||||
|
||||
// InvalidUserTokens 使用户所有令牌失效
|
||||
func InvalidUserTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
|
||||
if userId == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := repository.DeleteTokenByUserId(userId)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"[ERROR]删除用户Token失败",
|
||||
zap.Error(err),
|
||||
zap.Any("userId", userId),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功删除用户Token", zap.Any("userId:", userId))
|
||||
|
||||
}
|
||||
204
internal/service/token_service_test.go
Normal file
204
internal/service/token_service_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTokenService_Constants 测试Token服务相关常量
|
||||
func TestTokenService_Constants(t *testing.T) {
|
||||
if ExtendedTimeout != 10*time.Second {
|
||||
t.Errorf("ExtendedTimeout = %v, want 10 seconds", ExtendedTimeout)
|
||||
}
|
||||
|
||||
if TokensMaxCount != 10 {
|
||||
t.Errorf("TokensMaxCount = %d, want 10", TokensMaxCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_Timeout 测试超时常量
|
||||
func TestTokenService_Timeout(t *testing.T) {
|
||||
if DefaultTimeout != 5*time.Second {
|
||||
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
|
||||
}
|
||||
|
||||
if ExtendedTimeout <= DefaultTimeout {
|
||||
t.Errorf("ExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", ExtendedTimeout, DefaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_Validation 测试Token验证逻辑
|
||||
func TestTokenService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "空token无效",
|
||||
accessToken: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "非空token可能有效",
|
||||
accessToken: "valid-token-string",
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试空token检查逻辑
|
||||
isValid := tt.accessToken != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Token validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ClientTokenLogic 测试ClientToken逻辑
|
||||
func TestTokenService_ClientTokenLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientToken string
|
||||
shouldGenerate bool
|
||||
}{
|
||||
{
|
||||
name: "空的clientToken应该生成新的",
|
||||
clientToken: "",
|
||||
shouldGenerate: true,
|
||||
},
|
||||
{
|
||||
name: "非空的clientToken应该使用提供的",
|
||||
clientToken: "existing-client-token",
|
||||
shouldGenerate: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldGenerate := tt.clientToken == ""
|
||||
if shouldGenerate != tt.shouldGenerate {
|
||||
t.Errorf("ClientToken logic failed: got %v, want %v", shouldGenerate, tt.shouldGenerate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ProfileSelection 测试Profile选择逻辑
|
||||
func TestTokenService_ProfileSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileCount int
|
||||
shouldAutoSelect bool
|
||||
}{
|
||||
{
|
||||
name: "只有一个profile时自动选择",
|
||||
profileCount: 1,
|
||||
shouldAutoSelect: true,
|
||||
},
|
||||
{
|
||||
name: "多个profile时不自动选择",
|
||||
profileCount: 2,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
{
|
||||
name: "没有profile时不自动选择",
|
||||
profileCount: 0,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldAutoSelect := tt.profileCount == 1
|
||||
if shouldAutoSelect != tt.shouldAutoSelect {
|
||||
t.Errorf("Profile selection logic failed: got %v, want %v", shouldAutoSelect, tt.shouldAutoSelect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_CleanupLogic 测试清理逻辑
|
||||
func TestTokenService_CleanupLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenCount int
|
||||
maxCount int
|
||||
shouldCleanup bool
|
||||
cleanupCount int
|
||||
}{
|
||||
{
|
||||
name: "token数量未超过上限,不需要清理",
|
||||
tokenCount: 5,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
{
|
||||
name: "token数量超过上限,需要清理",
|
||||
tokenCount: 15,
|
||||
maxCount: 10,
|
||||
shouldCleanup: true,
|
||||
cleanupCount: 5,
|
||||
},
|
||||
{
|
||||
name: "token数量等于上限,不需要清理",
|
||||
tokenCount: 10,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCleanup := tt.tokenCount > tt.maxCount
|
||||
if shouldCleanup != tt.shouldCleanup {
|
||||
t.Errorf("Cleanup decision failed: got %v, want %v", shouldCleanup, tt.shouldCleanup)
|
||||
}
|
||||
|
||||
if shouldCleanup {
|
||||
expectedCleanupCount := tt.tokenCount - tt.maxCount
|
||||
if expectedCleanupCount != tt.cleanupCount {
|
||||
t.Errorf("Cleanup count failed: got %d, want %d", expectedCleanupCount, tt.cleanupCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_UserIDValidation 测试UserID验证
|
||||
func TestTokenService_UserIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
isValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的UserID",
|
||||
userID: 1,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "UserID为0时无效",
|
||||
userID: 0,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "负数UserID无效",
|
||||
userID: -1,
|
||||
isValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("UserID validation failed: got %v, want %v", isValid, tt.isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
160
internal/service/upload_service.go
Normal file
160
internal/service/upload_service.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/storage"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileType 文件类型枚举
|
||||
type FileType string
|
||||
|
||||
const (
|
||||
FileTypeAvatar FileType = "avatar"
|
||||
FileTypeTexture FileType = "texture"
|
||||
)
|
||||
|
||||
// UploadConfig 上传配置
|
||||
type UploadConfig struct {
|
||||
AllowedExts map[string]bool // 允许的文件扩展名
|
||||
MinSize int64 // 最小文件大小(字节)
|
||||
MaxSize int64 // 最大文件大小(字节)
|
||||
Expires time.Duration // URL过期时间
|
||||
}
|
||||
|
||||
// GetUploadConfig 根据文件类型获取上传配置
|
||||
func GetUploadConfig(fileType FileType) *UploadConfig {
|
||||
switch fileType {
|
||||
case FileTypeAvatar:
|
||||
return &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".jpg": true,
|
||||
".jpeg": true,
|
||||
".png": true,
|
||||
".gif": true,
|
||||
".webp": true,
|
||||
},
|
||||
MinSize: 1024, // 1KB
|
||||
MaxSize: 5 * 1024 * 1024, // 5MB
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
case FileTypeTexture:
|
||||
return &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".png": true,
|
||||
},
|
||||
MinSize: 1024, // 1KB
|
||||
MaxSize: 10 * 1024 * 1024, // 10MB
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateFileName 验证文件名
|
||||
func ValidateFileName(fileName string, fileType FileType) error {
|
||||
if fileName == "" {
|
||||
return fmt.Errorf("文件名不能为空")
|
||||
}
|
||||
|
||||
uploadConfig := GetUploadConfig(fileType)
|
||||
if uploadConfig == nil {
|
||||
return fmt.Errorf("不支持的文件类型")
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(fileName))
|
||||
if !uploadConfig.AllowedExts[ext] {
|
||||
return fmt.Errorf("不支持的文件格式: %s", ext)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURL 生成头像上传URL
|
||||
func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) {
|
||||
// 1. 验证文件名
|
||||
if err := ValidateFileName(fileName, FileTypeAvatar); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 获取上传配置
|
||||
uploadConfig := GetUploadConfig(FileTypeAvatar)
|
||||
|
||||
// 3. 获取存储桶名称
|
||||
bucketName, err := storageClient.GetBucket("avatars")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取存储桶失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 生成对象名称(路径)
|
||||
// 格式: user_{userId}/timestamp_{originalFileName}
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
objectName := fmt.Sprintf("user_%d/%s_%s", userID, timestamp, fileName)
|
||||
|
||||
// 5. 生成预签名POST URL
|
||||
result, err := storageClient.GeneratePresignedPostURL(
|
||||
ctx,
|
||||
bucketName,
|
||||
objectName,
|
||||
uploadConfig.MinSize,
|
||||
uploadConfig.MaxSize,
|
||||
uploadConfig.Expires,
|
||||
cfg.UseSSL,
|
||||
cfg.Endpoint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成上传URL失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURL 生成材质上传URL
|
||||
func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) {
|
||||
// 1. 验证文件名
|
||||
if err := ValidateFileName(fileName, FileTypeTexture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 验证材质类型
|
||||
if textureType != "SKIN" && textureType != "CAPE" {
|
||||
return nil, fmt.Errorf("无效的材质类型: %s", textureType)
|
||||
}
|
||||
|
||||
// 3. 获取上传配置
|
||||
uploadConfig := GetUploadConfig(FileTypeTexture)
|
||||
|
||||
// 4. 获取存储桶名称
|
||||
bucketName, err := storageClient.GetBucket("textures")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取存储桶失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 生成对象名称(路径)
|
||||
// 格式: user_{userId}/{textureType}/timestamp_{originalFileName}
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
textureTypeFolder := strings.ToLower(textureType)
|
||||
objectName := fmt.Sprintf("user_%d/%s/%s_%s", userID, textureTypeFolder, timestamp, fileName)
|
||||
|
||||
// 6. 生成预签名POST URL
|
||||
result, err := storageClient.GeneratePresignedPostURL(
|
||||
ctx,
|
||||
bucketName,
|
||||
objectName,
|
||||
uploadConfig.MinSize,
|
||||
uploadConfig.MaxSize,
|
||||
uploadConfig.Expires,
|
||||
cfg.UseSSL,
|
||||
cfg.Endpoint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成上传URL失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
279
internal/service/upload_service_test.go
Normal file
279
internal/service/upload_service_test.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestUploadService_FileTypes 测试文件类型常量
|
||||
func TestUploadService_FileTypes(t *testing.T) {
|
||||
if FileTypeAvatar == "" {
|
||||
t.Error("FileTypeAvatar should not be empty")
|
||||
}
|
||||
|
||||
if FileTypeTexture == "" {
|
||||
t.Error("FileTypeTexture should not be empty")
|
||||
}
|
||||
|
||||
if FileTypeAvatar == FileTypeTexture {
|
||||
t.Error("FileTypeAvatar and FileTypeTexture should be different")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig 测试获取上传配置
|
||||
func TestGetUploadConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileType FileType
|
||||
wantConfig bool
|
||||
}{
|
||||
{
|
||||
name: "头像类型返回配置",
|
||||
fileType: FileTypeAvatar,
|
||||
wantConfig: true,
|
||||
},
|
||||
{
|
||||
name: "材质类型返回配置",
|
||||
fileType: FileTypeTexture,
|
||||
wantConfig: true,
|
||||
},
|
||||
{
|
||||
name: "无效类型返回nil",
|
||||
fileType: FileType("invalid"),
|
||||
wantConfig: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := GetUploadConfig(tt.fileType)
|
||||
hasConfig := config != nil
|
||||
if hasConfig != tt.wantConfig {
|
||||
t.Errorf("GetUploadConfig() = %v, want %v", hasConfig, tt.wantConfig)
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
// 验证配置字段
|
||||
if config.MinSize <= 0 {
|
||||
t.Error("MinSize should be greater than 0")
|
||||
}
|
||||
if config.MaxSize <= 0 {
|
||||
t.Error("MaxSize should be greater than 0")
|
||||
}
|
||||
if config.MaxSize < config.MinSize {
|
||||
t.Error("MaxSize should be greater than or equal to MinSize")
|
||||
}
|
||||
if config.Expires <= 0 {
|
||||
t.Error("Expires should be greater than 0")
|
||||
}
|
||||
if len(config.AllowedExts) == 0 {
|
||||
t.Error("AllowedExts should not be empty")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig_AvatarConfig 测试头像配置详情
|
||||
func TestGetUploadConfig_AvatarConfig(t *testing.T) {
|
||||
config := GetUploadConfig(FileTypeAvatar)
|
||||
if config == nil {
|
||||
t.Fatal("Avatar config should not be nil")
|
||||
}
|
||||
|
||||
// 验证允许的扩展名
|
||||
expectedExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
for _, ext := range expectedExts {
|
||||
if !config.AllowedExts[ext] {
|
||||
t.Errorf("Avatar config should allow %s extension", ext)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证文件大小限制
|
||||
if config.MinSize != 1024 {
|
||||
t.Errorf("Avatar MinSize = %d, want 1024", config.MinSize)
|
||||
}
|
||||
|
||||
if config.MaxSize != 5*1024*1024 {
|
||||
t.Errorf("Avatar MaxSize = %d, want 5MB", config.MaxSize)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if config.Expires != 15*time.Minute {
|
||||
t.Errorf("Avatar Expires = %v, want 15 minutes", config.Expires)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig_TextureConfig 测试材质配置详情
|
||||
func TestGetUploadConfig_TextureConfig(t *testing.T) {
|
||||
config := GetUploadConfig(FileTypeTexture)
|
||||
if config == nil {
|
||||
t.Fatal("Texture config should not be nil")
|
||||
}
|
||||
|
||||
// 验证允许的扩展名(材质只允许PNG)
|
||||
if !config.AllowedExts[".png"] {
|
||||
t.Error("Texture config should allow .png extension")
|
||||
}
|
||||
|
||||
// 验证文件大小限制
|
||||
if config.MinSize != 1024 {
|
||||
t.Errorf("Texture MinSize = %d, want 1024", config.MinSize)
|
||||
}
|
||||
|
||||
if config.MaxSize != 10*1024*1024 {
|
||||
t.Errorf("Texture MaxSize = %d, want 10MB", config.MaxSize)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if config.Expires != 15*time.Minute {
|
||||
t.Errorf("Texture Expires = %v, want 15 minutes", config.Expires)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName 测试文件名验证
|
||||
func TestValidateFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileName string
|
||||
fileType FileType
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "有效的头像文件名",
|
||||
fileName: "avatar.png",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效的材质文件名",
|
||||
fileName: "texture.png",
|
||||
fileType: FileTypeTexture,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "文件名为空",
|
||||
fileName: "",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: true,
|
||||
errContains: "文件名不能为空",
|
||||
},
|
||||
{
|
||||
name: "不支持的文件扩展名",
|
||||
fileName: "file.txt",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: true,
|
||||
errContains: "不支持的文件格式",
|
||||
},
|
||||
{
|
||||
name: "无效的文件类型",
|
||||
fileName: "file.png",
|
||||
fileType: FileType("invalid"),
|
||||
wantErr: true,
|
||||
errContains: "不支持的文件类型",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateFileName(tt.fileName, tt.fileType)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateFileName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && tt.errContains != "" {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("ValidateFileName() error = %v, should contain %s", err, tt.errContains)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName_Extensions 测试各种扩展名
|
||||
func TestValidateFileName_Extensions(t *testing.T) {
|
||||
avatarExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
for _, ext := range avatarExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeAvatar)
|
||||
if err != nil {
|
||||
t.Errorf("Avatar file with %s extension should be valid, got error: %v", ext, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 材质只支持PNG
|
||||
textureExts := []string{".png"}
|
||||
for _, ext := range textureExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeTexture)
|
||||
if err != nil {
|
||||
t.Errorf("Texture file with %s extension should be valid, got error: %v", ext, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试不支持的扩展名
|
||||
invalidExts := []string{".txt", ".pdf", ".doc"}
|
||||
for _, ext := range invalidExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeAvatar)
|
||||
if err == nil {
|
||||
t.Errorf("Avatar file with %s extension should be invalid", ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName_CaseInsensitive 测试扩展名大小写不敏感
|
||||
func TestValidateFileName_CaseInsensitive(t *testing.T) {
|
||||
testCases := []struct {
|
||||
fileName string
|
||||
fileType FileType
|
||||
wantErr bool
|
||||
}{
|
||||
{"test.PNG", FileTypeAvatar, false},
|
||||
{"test.JPG", FileTypeAvatar, false},
|
||||
{"test.JPEG", FileTypeAvatar, false},
|
||||
{"test.GIF", FileTypeAvatar, false},
|
||||
{"test.WEBP", FileTypeAvatar, false},
|
||||
{"test.PnG", FileTypeTexture, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.fileName, func(t *testing.T) {
|
||||
err := ValidateFileName(tc.fileName, tc.fileType)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("ValidateFileName(%s, %s) error = %v, wantErr %v", tc.fileName, tc.fileType, err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadConfig_Structure 测试UploadConfig结构
|
||||
func TestUploadConfig_Structure(t *testing.T) {
|
||||
config := &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".png": true,
|
||||
},
|
||||
MinSize: 1024,
|
||||
MaxSize: 5 * 1024 * 1024,
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
|
||||
if config.AllowedExts == nil {
|
||||
t.Error("AllowedExts should not be nil")
|
||||
}
|
||||
|
||||
if config.MinSize <= 0 {
|
||||
t.Error("MinSize should be greater than 0")
|
||||
}
|
||||
|
||||
if config.MaxSize <= config.MinSize {
|
||||
t.Error("MaxSize should be greater than MinSize")
|
||||
}
|
||||
|
||||
if config.Expires <= 0 {
|
||||
t.Error("Expires should be greater than 0")
|
||||
}
|
||||
}
|
||||
|
||||
248
internal/service/user_service.go
Normal file
248
internal/service/user_service.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RegisterUser 用户注册
|
||||
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := repository.FindUserByUsername(username)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingUser != nil {
|
||||
return nil, "", errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existingEmail, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingEmail != nil {
|
||||
return nil, "", errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 确定头像URL:优先使用用户提供的头像,否则使用默认头像
|
||||
avatarURL := avatar
|
||||
if avatarURL == "" {
|
||||
avatarURL = getDefaultAvatar()
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: username,
|
||||
Password: hashedPassword,
|
||||
Email: email,
|
||||
Avatar: avatarURL,
|
||||
Role: "user",
|
||||
Status: 1,
|
||||
Points: 0, // 初始积分可以从配置读取
|
||||
}
|
||||
|
||||
if err := repository.CreateUser(user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// TODO: 添加注册奖励积分
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// LoginUser 用户登录(支持用户名或邮箱登录)
|
||||
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
// 查找用户:判断是用户名还是邮箱
|
||||
var user *model.User
|
||||
var err error
|
||||
|
||||
if strings.Contains(usernameOrEmail, "@") {
|
||||
// 包含@符号,认为是邮箱
|
||||
user, err = repository.FindUserByEmail(usernameOrEmail)
|
||||
} else {
|
||||
// 否则认为是用户名
|
||||
user, err = repository.FindUserByUsername(usernameOrEmail)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if user == nil {
|
||||
// 记录失败日志
|
||||
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if user.Status != 1 {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
|
||||
return nil, "", errors.New("账号已被禁用")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !auth.CheckPassword(user.Password, password) {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
"last_login_at": now,
|
||||
})
|
||||
|
||||
// 记录成功登录日志
|
||||
logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func GetUserByID(id int64) (*model.User, error) {
|
||||
return repository.FindUserByID(id)
|
||||
}
|
||||
|
||||
// UpdateUserInfo 更新用户信息
|
||||
func UpdateUserInfo(user *model.User) error {
|
||||
return repository.UpdateUser(user)
|
||||
}
|
||||
|
||||
// UpdateUserAvatar 更新用户头像
|
||||
func UpdateUserAvatar(userID int64, avatarURL string) error {
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"avatar": avatarURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserPassword 修改密码
|
||||
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
|
||||
// 获取用户
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !auth.CheckPassword(user.Password, oldPassword) {
|
||||
return errors.New("原密码错误")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ResetUserPassword 重置密码(通过邮箱)
|
||||
func ResetUserPassword(email, newPassword string) error {
|
||||
// 查找用户
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserEmail 更换邮箱
|
||||
func ChangeUserEmail(userID int64, newEmail string) error {
|
||||
// 检查新邮箱是否已被使用
|
||||
existingUser, err := repository.FindUserByEmail(newEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existingUser != nil && existingUser.ID != userID {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
// 更新邮箱
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"email": newEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// logSuccessLogin 记录成功登录
|
||||
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// logFailedLogin 记录失败登录
|
||||
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// getDefaultAvatar 获取默认头像URL
|
||||
func getDefaultAvatar() string {
|
||||
// 如果数据库中不存在默认头像配置,返回错误信息
|
||||
const log = "数据库中不存在默认头像配置"
|
||||
|
||||
// 尝试从数据库读取配置
|
||||
config, err := repository.GetSystemConfigByKey("default_avatar")
|
||||
if err != nil || config == nil {
|
||||
return log
|
||||
}
|
||||
|
||||
return config.Value
|
||||
}
|
||||
|
||||
func GetUserByEmail(email string) (*model.User, error) {
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, errors.New("邮箱查找失败")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
199
internal/service/user_service_test.go
Normal file
199
internal/service/user_service_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetDefaultAvatar 测试获取默认头像的逻辑
|
||||
// 注意:这个测试需要mock repository,但由于repository是函数式的,
|
||||
// 我们只测试逻辑部分
|
||||
func TestGetDefaultAvatar_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configExists bool
|
||||
configValue string
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
name: "配置存在时返回配置值",
|
||||
configExists: true,
|
||||
configValue: "https://example.com/avatar.png",
|
||||
expectedResult: "https://example.com/avatar.png",
|
||||
},
|
||||
{
|
||||
name: "配置不存在时返回错误信息",
|
||||
configExists: false,
|
||||
configValue: "",
|
||||
expectedResult: "数据库中不存在默认头像配置",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 这个测试只验证逻辑,不实际调用repository
|
||||
// 实际的repository调用测试需要集成测试或mock
|
||||
if tt.configExists {
|
||||
if tt.expectedResult != tt.configValue {
|
||||
t.Errorf("当配置存在时,应该返回配置值")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(tt.expectedResult, "数据库中不存在默认头像配置") {
|
||||
t.Errorf("当配置不存在时,应该返回错误信息")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginUser_EmailDetection 测试登录时邮箱检测逻辑
|
||||
func TestLoginUser_EmailDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
usernameOrEmail string
|
||||
isEmail bool
|
||||
}{
|
||||
{
|
||||
name: "包含@符号,识别为邮箱",
|
||||
usernameOrEmail: "user@example.com",
|
||||
isEmail: true,
|
||||
},
|
||||
{
|
||||
name: "不包含@符号,识别为用户名",
|
||||
usernameOrEmail: "username",
|
||||
isEmail: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
usernameOrEmail: "",
|
||||
isEmail: false,
|
||||
},
|
||||
{
|
||||
name: "只有@符号",
|
||||
usernameOrEmail: "@",
|
||||
isEmail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isEmail := strings.Contains(tt.usernameOrEmail, "@")
|
||||
if isEmail != tt.isEmail {
|
||||
t.Errorf("Email detection failed: got %v, want %v", isEmail, tt.isEmail)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_Constants 测试用户服务相关常量
|
||||
func TestUserService_Constants(t *testing.T) {
|
||||
// 测试默认用户角色
|
||||
defaultRole := "user"
|
||||
if defaultRole == "" {
|
||||
t.Error("默认用户角色不能为空")
|
||||
}
|
||||
|
||||
// 测试默认用户状态
|
||||
defaultStatus := int16(1)
|
||||
if defaultStatus != 1 {
|
||||
t.Errorf("默认用户状态应为1(正常),实际为%d", defaultStatus)
|
||||
}
|
||||
|
||||
// 测试初始积分
|
||||
initialPoints := 0
|
||||
if initialPoints < 0 {
|
||||
t.Errorf("初始积分不应为负数,实际为%d", initialPoints)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_Validation 测试用户数据验证逻辑
|
||||
func TestUserService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户名和邮箱",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱为空",
|
||||
username: "testuser",
|
||||
email: "",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱格式无效(缺少@)",
|
||||
username: "testuser",
|
||||
email: "invalid-email",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 简单的验证逻辑测试
|
||||
isValid := tt.username != "" && tt.email != "" && tt.password != "" && strings.Contains(tt.email, "@")
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_AvatarLogic 测试头像逻辑
|
||||
func TestUserService_AvatarLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providedAvatar string
|
||||
defaultAvatar string
|
||||
expectedAvatar string
|
||||
}{
|
||||
{
|
||||
name: "提供头像时使用提供的头像",
|
||||
providedAvatar: "https://example.com/custom.png",
|
||||
defaultAvatar: "https://example.com/default.png",
|
||||
expectedAvatar: "https://example.com/custom.png",
|
||||
},
|
||||
{
|
||||
name: "未提供头像时使用默认头像",
|
||||
providedAvatar: "",
|
||||
defaultAvatar: "https://example.com/default.png",
|
||||
expectedAvatar: "https://example.com/default.png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
avatarURL := tt.providedAvatar
|
||||
if avatarURL == "" {
|
||||
avatarURL = tt.defaultAvatar
|
||||
}
|
||||
if avatarURL != tt.expectedAvatar {
|
||||
t.Errorf("Avatar logic failed: got %s, want %s", avatarURL, tt.expectedAvatar)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
118
internal/service/verification_service.go
Normal file
118
internal/service/verification_service.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/email"
|
||||
"carrotskin/pkg/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
// 验证码类型
|
||||
VerificationTypeRegister = "register"
|
||||
VerificationTypeResetPassword = "reset_password"
|
||||
VerificationTypeChangeEmail = "change_email"
|
||||
|
||||
// 验证码配置
|
||||
CodeLength = 6 // 验证码长度
|
||||
CodeExpiration = 10 * time.Minute // 验证码有效期
|
||||
CodeRateLimit = 1 * time.Minute // 发送频率限制
|
||||
)
|
||||
|
||||
// GenerateVerificationCode 生成6位数字验证码
|
||||
func GenerateVerificationCode() (string, error) {
|
||||
const digits = "0123456789"
|
||||
code := make([]byte, CodeLength)
|
||||
for i := range code {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = digits[num.Int64()]
|
||||
}
|
||||
return string(code), nil
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
func SendVerificationCode(ctx context.Context, redisClient *redis.Client, emailService *email.Service, email, codeType string) error {
|
||||
// 检查发送频率限制
|
||||
rateLimitKey := fmt.Sprintf("verification:rate_limit:%s:%s", codeType, email)
|
||||
exists, err := redisClient.Exists(ctx, rateLimitKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查发送频率失败: %w", err)
|
||||
}
|
||||
if exists > 0 {
|
||||
return fmt.Errorf("发送过于频繁,请稍后再试")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储验证码到Redis
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
if err := redisClient.Set(ctx, codeKey, code, CodeExpiration); err != nil {
|
||||
return fmt.Errorf("存储验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置发送频率限制
|
||||
if err := redisClient.Set(ctx, rateLimitKey, "1", CodeRateLimit); err != nil {
|
||||
return fmt.Errorf("设置发送频率限制失败: %w", err)
|
||||
}
|
||||
|
||||
// 发送邮件
|
||||
if err := sendVerificationEmail(emailService, email, code, codeType); err != nil {
|
||||
// 发送失败,删除验证码
|
||||
_ = redisClient.Del(ctx, codeKey)
|
||||
return fmt.Errorf("发送邮件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func VerifyCode(ctx context.Context, redisClient *redis.Client, email, code, codeType string) error {
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
|
||||
// 从Redis获取验证码
|
||||
storedCode, err := redisClient.Get(ctx, codeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("验证码已过期或不存在")
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
if storedCode != code {
|
||||
return fmt.Errorf("验证码错误")
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
_ = redisClient.Del(ctx, codeKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteVerificationCode 删除验证码
|
||||
func DeleteVerificationCode(ctx context.Context, redisClient *redis.Client, email, codeType string) error {
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
return redisClient.Del(ctx, codeKey)
|
||||
}
|
||||
|
||||
// sendVerificationEmail 根据类型发送邮件
|
||||
func sendVerificationEmail(emailService *email.Service, to, code, codeType string) error {
|
||||
switch codeType {
|
||||
case VerificationTypeRegister:
|
||||
return emailService.SendEmailVerification(to, code)
|
||||
case VerificationTypeResetPassword:
|
||||
return emailService.SendResetPassword(to, code)
|
||||
case VerificationTypeChangeEmail:
|
||||
return emailService.SendChangeEmail(to, code)
|
||||
default:
|
||||
return emailService.SendVerificationCode(to, code, codeType)
|
||||
}
|
||||
}
|
||||
119
internal/service/verification_service_test.go
Normal file
119
internal/service/verification_service_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestGenerateVerificationCode 测试生成验证码函数
|
||||
func TestGenerateVerificationCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "生成6位验证码",
|
||||
wantLen: CodeLength,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, err := GenerateVerificationCode()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateVerificationCode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && len(code) != tt.wantLen {
|
||||
t.Errorf("GenerateVerificationCode() code length = %v, want %v", len(code), tt.wantLen)
|
||||
}
|
||||
// 验证验证码只包含数字
|
||||
for _, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("GenerateVerificationCode() code contains non-digit: %c", c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 测试多次生成,验证码应该不同(概率上)
|
||||
codes := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateVerificationCode() failed: %v", err)
|
||||
}
|
||||
if codes[code] {
|
||||
t.Logf("发现重复验证码(这是正常的,因为只有6位数字): %s", code)
|
||||
}
|
||||
codes[code] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationConstants 测试验证码相关常量
|
||||
func TestVerificationConstants(t *testing.T) {
|
||||
if CodeLength != 6 {
|
||||
t.Errorf("CodeLength = %d, want 6", CodeLength)
|
||||
}
|
||||
|
||||
if CodeExpiration != 10*time.Minute {
|
||||
t.Errorf("CodeExpiration = %v, want 10 minutes", CodeExpiration)
|
||||
}
|
||||
|
||||
if CodeRateLimit != 1*time.Minute {
|
||||
t.Errorf("CodeRateLimit = %v, want 1 minute", CodeRateLimit)
|
||||
}
|
||||
|
||||
// 验证验证码类型常量
|
||||
types := []string{
|
||||
VerificationTypeRegister,
|
||||
VerificationTypeResetPassword,
|
||||
VerificationTypeChangeEmail,
|
||||
}
|
||||
|
||||
for _, vType := range types {
|
||||
if vType == "" {
|
||||
t.Error("验证码类型不能为空")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationCodeFormat 测试验证码格式
|
||||
func TestVerificationCodeFormat(t *testing.T) {
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateVerificationCode() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证长度
|
||||
if len(code) != 6 {
|
||||
t.Errorf("验证码长度应为6位,实际为%d位", len(code))
|
||||
}
|
||||
|
||||
// 验证只包含数字
|
||||
for i, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("验证码第%d位包含非数字字符: %c", i+1, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationTypes 测试验证码类型
|
||||
func TestVerificationTypes(t *testing.T) {
|
||||
validTypes := map[string]bool{
|
||||
VerificationTypeRegister: true,
|
||||
VerificationTypeResetPassword: true,
|
||||
VerificationTypeChangeEmail: true,
|
||||
}
|
||||
|
||||
for vType, isValid := range validTypes {
|
||||
if !isValid {
|
||||
t.Errorf("验证码类型 %s 应该是有效的", vType)
|
||||
}
|
||||
if vType == "" {
|
||||
t.Error("验证码类型不能为空字符串")
|
||||
}
|
||||
}
|
||||
}
|
||||
201
internal/service/yggdrasil_service.go
Normal file
201
internal/service/yggdrasil_service.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SessionKeyPrefix Redis会话键前缀
|
||||
const SessionKeyPrefix = "Join_"
|
||||
|
||||
// SessionTTL 会话超时时间 - 增加到15分钟
|
||||
const SessionTTL = 15 * time.Minute
|
||||
|
||||
type SessionData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
UserName string `json:"userName"`
|
||||
SelectedProfile string `json:"selectedProfile"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
// GetUserIDByEmail 根据邮箱返回用户id
|
||||
func GetUserIDByEmail(db *gorm.DB, Identifier string) (int64, error) {
|
||||
user, err := repository.FindUserByEmail(Identifier)
|
||||
if err != nil {
|
||||
return 0, errors.New("用户不存在")
|
||||
}
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// GetProfileByProfileName 根据用户名返回用户id
|
||||
func GetProfileByProfileName(db *gorm.DB, Identifier string) (*model.Profile, error) {
|
||||
profile, err := repository.FindProfileByName(Identifier)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户角色未创建")
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// VerifyPassword 验证密码是否一致
|
||||
func VerifyPassword(db *gorm.DB, password string, Id int64) error {
|
||||
passwordStore, err := repository.GetYggdrasilPasswordById(Id)
|
||||
if err != nil {
|
||||
return errors.New("未生成密码")
|
||||
}
|
||||
if passwordStore != password {
|
||||
return errors.New("密码错误")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetProfileByUserId(db *gorm.DB, userId int64) (*model.Profile, error) {
|
||||
profiles, err := repository.FindProfilesByUserID(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("角色查找失败")
|
||||
}
|
||||
if len(profiles) == 0 {
|
||||
return nil, errors.New("角色查找失败")
|
||||
}
|
||||
return profiles[0], nil
|
||||
}
|
||||
|
||||
func GetPasswordByUserId(db *gorm.DB, userId int64) (string, error) {
|
||||
passwordStore, err := repository.GetYggdrasilPasswordById(userId)
|
||||
if err != nil {
|
||||
return "", errors.New("yggdrasil密码查找失败")
|
||||
}
|
||||
return passwordStore, nil
|
||||
}
|
||||
|
||||
// JoinServer 记录玩家加入服务器的会话信息
|
||||
func JoinServer(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, serverId, accessToken, selectedProfile, ip string) error {
|
||||
// 输入验证
|
||||
if serverId == "" || accessToken == "" || selectedProfile == "" {
|
||||
return errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
// 验证serverId格式,防止注入攻击
|
||||
if len(serverId) > 100 || strings.ContainsAny(serverId, "<>\"'&") {
|
||||
return errors.New("服务器ID格式无效")
|
||||
}
|
||||
|
||||
// 验证IP格式
|
||||
if ip != "" {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return errors.New("IP地址格式无效")
|
||||
}
|
||||
}
|
||||
|
||||
// 获取和验证Token
|
||||
token, err := repository.GetTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"验证Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return fmt.Errorf("验证Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 格式化UUID并验证与Token关联的配置文件
|
||||
formattedProfile := utils.FormatUUID(selectedProfile)
|
||||
if token.ProfileId != formattedProfile {
|
||||
return errors.New("selectedProfile与Token不匹配")
|
||||
}
|
||||
|
||||
profile, err := repository.FindProfileByUUID(formattedProfile)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"获取Profile失败",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", formattedProfile),
|
||||
)
|
||||
return fmt.Errorf("获取Profile失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建会话数据
|
||||
data := SessionData{
|
||||
AccessToken: accessToken,
|
||||
UserName: profile.Name,
|
||||
SelectedProfile: formattedProfile,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
// 序列化会话数据
|
||||
marshaledData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"[ERROR]序列化会话数据失败",
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储会话数据到Redis
|
||||
sessionKey := SessionKeyPrefix + serverId
|
||||
ctx := context.Background()
|
||||
if err = redisClient.Set(ctx, sessionKey, marshaledData, SessionTTL); err != nil {
|
||||
logger.Error(
|
||||
"保存会话数据失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverId", serverId),
|
||||
)
|
||||
return fmt.Errorf("保存会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"玩家成功加入服务器",
|
||||
zap.String("username", profile.Name),
|
||||
zap.String("serverId", serverId),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasJoinedServer 验证玩家是否已经加入了服务器
|
||||
func HasJoinedServer(logger *zap.Logger, redisClient *redis.Client, serverId, username, ip string) error {
|
||||
if serverId == "" || username == "" {
|
||||
return errors.New("服务器ID和用户名不能为空")
|
||||
}
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 从Redis获取会话数据
|
||||
sessionKey := SessionKeyPrefix + serverId
|
||||
data, err := redisClient.GetBytes(ctx, sessionKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取会话数据失败:", zap.Error(err), zap.Any("serverId:", serverId))
|
||||
return fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 反序列化会话数据
|
||||
var sessionData SessionData
|
||||
if err = json.Unmarshal(data, &sessionData); err != nil {
|
||||
logger.Error("[ERROR] 解析会话数据失败: ", zap.Error(err))
|
||||
return fmt.Errorf("解析会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户名
|
||||
if sessionData.UserName != username {
|
||||
return errors.New("用户名不匹配")
|
||||
}
|
||||
|
||||
// 验证IP(如果提供)
|
||||
if ip != "" && sessionData.IP != ip {
|
||||
return errors.New("IP地址不匹配")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
350
internal/service/yggdrasil_service_test.go
Normal file
350
internal/service/yggdrasil_service_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestYggdrasilService_Constants 测试Yggdrasil服务常量
|
||||
func TestYggdrasilService_Constants(t *testing.T) {
|
||||
if SessionKeyPrefix != "Join_" {
|
||||
t.Errorf("SessionKeyPrefix = %s, want 'Join_'", SessionKeyPrefix)
|
||||
}
|
||||
|
||||
if SessionTTL != 15*time.Minute {
|
||||
t.Errorf("SessionTTL = %v, want 15 minutes", SessionTTL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionData_Structure 测试SessionData结构
|
||||
func TestSessionData_Structure(t *testing.T) {
|
||||
data := SessionData{
|
||||
AccessToken: "test-token",
|
||||
UserName: "TestUser",
|
||||
SelectedProfile: "test-profile-uuid",
|
||||
IP: "127.0.0.1",
|
||||
}
|
||||
|
||||
if data.AccessToken == "" {
|
||||
t.Error("AccessToken should not be empty")
|
||||
}
|
||||
|
||||
if data.UserName == "" {
|
||||
t.Error("UserName should not be empty")
|
||||
}
|
||||
|
||||
if data.SelectedProfile == "" {
|
||||
t.Error("SelectedProfile should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_InputValidation 测试JoinServer输入验证逻辑
|
||||
func TestJoinServer_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
accessToken string
|
||||
selectedProfile string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "所有参数有效",
|
||||
serverId: "test-server-123",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "serverId为空",
|
||||
serverId: "",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
{
|
||||
name: "accessToken为空",
|
||||
serverId: "test-server",
|
||||
accessToken: "",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
{
|
||||
name: "selectedProfile为空",
|
||||
serverId: "test-server",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.serverId == "" || tt.accessToken == "" || tt.selectedProfile == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_ServerIDValidation 测试服务器ID格式验证
|
||||
func TestJoinServer_ServerIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的serverId",
|
||||
serverId: "test-server-123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "serverId过长",
|
||||
serverId: strings.Repeat("a", 101),
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符<",
|
||||
serverId: "test<server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符>",
|
||||
serverId: "test>server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符\"",
|
||||
serverId: "test\"server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符'",
|
||||
serverId: "test'server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符&",
|
||||
serverId: "test&server",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := len(tt.serverId) <= 100 && !strings.ContainsAny(tt.serverId, "<>\"'&")
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_IPValidation 测试IP地址验证逻辑
|
||||
func TestJoinServer_IPValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的IPv4地址",
|
||||
ip: "127.0.0.1",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "有效的IPv6地址",
|
||||
ip: "::1",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效的IP地址",
|
||||
ip: "invalid-ip",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空IP地址(可选)",
|
||||
ip: "",
|
||||
wantValid: true, // 空IP是允许的
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var isValid bool
|
||||
if tt.ip == "" {
|
||||
isValid = true // 空IP是允许的
|
||||
} else {
|
||||
isValid = net.ParseIP(tt.ip) != nil
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("IP validation failed: got %v, want %v (ip=%s)", isValid, tt.wantValid, tt.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_InputValidation 测试HasJoinedServer输入验证
|
||||
func TestHasJoinedServer_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
username string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "所有参数有效",
|
||||
serverId: "test-server",
|
||||
username: "TestUser",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "serverId为空",
|
||||
serverId: "",
|
||||
username: "TestUser",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "username为空",
|
||||
serverId: "test-server",
|
||||
username: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "两者都为空",
|
||||
serverId: "",
|
||||
username: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.serverId == "" || tt.username == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_UsernameMatching 测试用户名匹配逻辑
|
||||
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionUser string
|
||||
requestUser string
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
name: "用户名匹配",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "TestUser",
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "用户名不匹配",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "OtherUser",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "大小写敏感",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "testuser",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := tt.sessionUser == tt.requestUser
|
||||
if matches != tt.wantMatch {
|
||||
t.Errorf("Username matching failed: got %v, want %v", matches, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_IPMatching 测试IP地址匹配逻辑
|
||||
func TestHasJoinedServer_IPMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionIP string
|
||||
requestIP string
|
||||
wantMatch bool
|
||||
shouldCheck bool
|
||||
}{
|
||||
{
|
||||
name: "IP匹配",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "127.0.0.1",
|
||||
wantMatch: true,
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "IP不匹配",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "192.168.1.1",
|
||||
wantMatch: false,
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "请求IP为空时不检查",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "",
|
||||
wantMatch: true,
|
||||
shouldCheck: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var matches bool
|
||||
if tt.requestIP == "" {
|
||||
matches = true // 空IP不检查
|
||||
} else {
|
||||
matches = tt.sessionIP == tt.requestIP
|
||||
}
|
||||
if matches != tt.wantMatch {
|
||||
t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_SessionKey 测试会话键生成
|
||||
func TestJoinServer_SessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "生成正确的会话键",
|
||||
serverId: "test-server-123",
|
||||
expected: "Join_test-server-123",
|
||||
},
|
||||
{
|
||||
name: "空serverId",
|
||||
serverId: "",
|
||||
expected: "Join_",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionKey := SessionKeyPrefix + tt.serverId
|
||||
if sessionKey != tt.expected {
|
||||
t.Errorf("Session key = %s, want %s", sessionKey, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user