chore: 初始化仓库,排除二进制文件和覆盖率文件
This commit is contained in:
249
internal/handler/auth_handler.go
Normal file
249
internal/handler/auth_handler.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/email"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/redis"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Register 用户注册
|
||||
// @Summary 用户注册
|
||||
// @Description 注册新用户账号
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.RegisterRequest true "注册信息"
|
||||
// @Success 200 {object} model.Response "注册成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/register [post]
|
||||
func Register(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
jwtService := auth.MustGetJWTService()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var req types.RegisterRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证邮箱验证码
|
||||
if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil {
|
||||
loggerInstance.Warn("验证码验证失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用service层注册用户(传递可选的头像URL)
|
||||
user, token, err := service.RegisterUser(jwtService, req.Username, req.Password, req.Email, req.Avatar)
|
||||
if err != nil {
|
||||
loggerInstance.Error("用户注册失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.LoginResponse{
|
||||
Token: token,
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
// @Summary 用户登录
|
||||
// @Description 用户登录获取JWT Token,支持用户名或邮箱登录
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)"
|
||||
// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "登录失败"
|
||||
// @Router /api/v1/auth/login [post]
|
||||
func Login(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
var req types.LoginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取IP和UserAgent
|
||||
ipAddress := c.ClientIP()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
// 调用service层登录
|
||||
user, token, err := service.LoginUser(jwtService, req.Username, req.Password, ipAddress, userAgent)
|
||||
if err != nil {
|
||||
loggerInstance.Warn("用户登录失败",
|
||||
zap.String("username_or_email", req.Username),
|
||||
zap.String("ip", ipAddress),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.LoginResponse{
|
||||
Token: token,
|
||||
UserInfo: &types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
// @Summary 发送验证码
|
||||
// @Description 发送邮箱验证码(注册/重置密码/更换邮箱)
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.SendVerificationCodeRequest true "发送验证码请求"
|
||||
// @Success 200 {object} model.Response "发送成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/send-code [post]
|
||||
func SendVerificationCode(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
redisClient := redis.MustGetClient()
|
||||
emailService := email.MustGetService()
|
||||
|
||||
var req types.SendVerificationCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 发送验证码
|
||||
if err := service.SendVerificationCode(c.Request.Context(), redisClient, emailService, req.Email, req.Type); err != nil {
|
||||
loggerInstance.Error("发送验证码失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.String("type", req.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "验证码已发送,请查收邮件",
|
||||
}))
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// @Summary 重置密码
|
||||
// @Description 通过邮箱验证码重置密码
|
||||
// @Tags auth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body types.ResetPasswordRequest true "重置密码请求"
|
||||
// @Success 200 {object} model.Response "重置成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/auth/reset-password [post]
|
||||
func ResetPassword(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var req types.ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil {
|
||||
loggerInstance.Warn("验证码验证失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 重置密码
|
||||
if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil {
|
||||
loggerInstance.Error("重置密码失败",
|
||||
zap.String("email", req.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "密码重置成功",
|
||||
}))
|
||||
}
|
||||
155
internal/handler/auth_handler_test.go
Normal file
155
internal/handler/auth_handler_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAuthHandler_RequestValidation 测试认证请求验证逻辑
|
||||
func TestAuthHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
code string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的注册请求",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
code: "123456",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "有效的登录请求",
|
||||
username: "testuser",
|
||||
email: "",
|
||||
password: "password123",
|
||||
code: "",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
code: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "",
|
||||
code: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "注册时验证码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
code: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证请求参数逻辑
|
||||
isValid := tt.username != "" && tt.password != ""
|
||||
// 如果是注册请求,还需要验证码
|
||||
if tt.email != "" && tt.code == "" {
|
||||
isValid = false
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestAuthHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "binding",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "验证码错误",
|
||||
errType: "verification",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "登录失败",
|
||||
errType: "login",
|
||||
wantCode: 401,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "注册失败",
|
||||
errType: "register",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_ResponseFormat 测试响应格式逻辑
|
||||
func TestAuthHandler_ResponseFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
success bool
|
||||
wantCode int
|
||||
hasToken bool
|
||||
}{
|
||||
{
|
||||
name: "注册成功",
|
||||
success: true,
|
||||
wantCode: 200,
|
||||
hasToken: true,
|
||||
},
|
||||
{
|
||||
name: "登录成功",
|
||||
success: true,
|
||||
wantCode: 200,
|
||||
hasToken: true,
|
||||
},
|
||||
{
|
||||
name: "发送验证码成功",
|
||||
success: true,
|
||||
wantCode: 200,
|
||||
hasToken: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证响应格式逻辑
|
||||
if tt.success && tt.wantCode != 200 {
|
||||
t.Errorf("Success response should have code 200, got %d", tt.wantCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
76
internal/handler/captcha_handler.go
Normal file
76
internal/handler/captcha_handler.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/pkg/redis"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Generate 生成验证码
|
||||
func Generate(c *gin.Context) {
|
||||
// 调用验证码服务生成验证码数据
|
||||
redisClient := redis.MustGetClient()
|
||||
masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), redisClient)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"msg": "生成验证码失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回验证码数据给前端
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": gin.H{
|
||||
"masterImage": masterImg, // 主图(base64格式)
|
||||
"tileImage": tileImg, // 滑块图(base64格式)
|
||||
"captchaId": captchaID, // 验证码唯一标识(用于后续验证)
|
||||
"y": y, // 滑块Y坐标(前端可用于定位滑块初始位置)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Verify 验证验证码
|
||||
func Verify(c *gin.Context) {
|
||||
// 定义请求参数结构体
|
||||
var req struct {
|
||||
CaptchaID string `json:"captchaId" binding:"required"` // 验证码唯一标识
|
||||
Dx int `json:"dx" binding:"required"` // 用户滑动的X轴偏移量
|
||||
}
|
||||
|
||||
// 解析并校验请求参数
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"msg": "参数错误: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 调用验证码服务验证偏移量
|
||||
redisClient := redis.MustGetClient()
|
||||
valid, err := service.VerifyCaptchaData(c.Request.Context(), redisClient, req.Dx, req.CaptchaID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"msg": "验证失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 根据验证结果返回响应
|
||||
if valid {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"msg": "验证成功",
|
||||
})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 400,
|
||||
"msg": "验证失败,请重试",
|
||||
})
|
||||
}
|
||||
}
|
||||
133
internal/handler/captcha_handler_test.go
Normal file
133
internal/handler/captcha_handler_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCaptchaHandler_RequestValidation 测试验证码请求验证逻辑
|
||||
func TestCaptchaHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
captchaID string
|
||||
dx int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的请求参数",
|
||||
captchaID: "captcha-123",
|
||||
dx: 100,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "captchaID为空",
|
||||
captchaID: "",
|
||||
dx: 100,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "dx为0(可能有效)",
|
||||
captchaID: "captcha-123",
|
||||
dx: 0,
|
||||
wantValid: true, // dx为0也可能是有效的(用户没有滑动)
|
||||
},
|
||||
{
|
||||
name: "dx为负数(可能无效)",
|
||||
captchaID: "captcha-123",
|
||||
dx: -10,
|
||||
wantValid: true, // 负数也可能是有效的,取决于业务逻辑
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.captchaID != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCaptchaHandler_ResponseFormat 测试响应格式逻辑
|
||||
func TestCaptchaHandler_ResponseFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
valid bool
|
||||
wantCode int
|
||||
wantStatus string
|
||||
}{
|
||||
{
|
||||
name: "验证成功",
|
||||
valid: true,
|
||||
wantCode: 200,
|
||||
wantStatus: "验证成功",
|
||||
},
|
||||
{
|
||||
name: "验证失败",
|
||||
valid: false,
|
||||
wantCode: 400,
|
||||
wantStatus: "验证失败,请重试",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证响应格式逻辑
|
||||
var code int
|
||||
var status string
|
||||
if tt.valid {
|
||||
code = 200
|
||||
status = "验证成功"
|
||||
} else {
|
||||
code = 400
|
||||
status = "验证失败,请重试"
|
||||
}
|
||||
|
||||
if code != tt.wantCode {
|
||||
t.Errorf("Response code = %d, want %d", code, tt.wantCode)
|
||||
}
|
||||
if status != tt.wantStatus {
|
||||
t.Errorf("Response status = %q, want %q", status, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCaptchaHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestCaptchaHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hasError bool
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "生成验证码失败",
|
||||
hasError: true,
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "验证验证码失败",
|
||||
hasError: true,
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "参数错误",
|
||||
hasError: true,
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if tt.hasError && !tt.wantError {
|
||||
t.Error("Error handling logic failed")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
398
internal/handler/profile_handler.go
Normal file
398
internal/handler/profile_handler.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/logger"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
// @Summary 创建Minecraft档案
|
||||
// @Description 创建新的Minecraft角色档案,UUID由后端自动生成
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)"
|
||||
// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功,返回完整档案信息(含自动生成的UUID)"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误或已达档案数量上限"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile [post]
|
||||
func CreateProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req types.CreateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误: "+err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 从配置或数据库读取限制
|
||||
maxProfiles := 5
|
||||
db := database.MustGetDB()
|
||||
// 检查档案数量限制
|
||||
if err := service.CheckProfileLimit(db, userID.(int64), maxProfiles); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 创建档案
|
||||
profile, err := service.CreateProfile(db, userID.(int64), req.Name)
|
||||
if err != nil {
|
||||
loggerInstance.Error("创建档案失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("name", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetProfiles 获取档案列表
|
||||
// @Summary 获取档案列表
|
||||
// @Description 获取当前用户的所有档案
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile [get]
|
||||
func GetProfiles(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 查询档案列表
|
||||
profiles, err := service.GetUserProfiles(database.MustGetDB(), userID.(int64))
|
||||
if err != nil {
|
||||
loggerInstance.Error("获取档案列表失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为响应格式
|
||||
result := make([]*types.ProfileInfo, 0, len(profiles))
|
||||
for _, profile := range profiles {
|
||||
result = append(result, &types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(result))
|
||||
}
|
||||
|
||||
// GetProfile 获取档案详情
|
||||
// @Summary 获取档案详情
|
||||
// @Description 根据UUID获取档案详细信息
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid} [get]
|
||||
func GetProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 查询档案
|
||||
profile, err := service.GetProfileByUUID(database.MustGetDB(), uuid)
|
||||
if err != nil {
|
||||
loggerInstance.Error("获取档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
// @Summary 更新档案
|
||||
// @Description 更新档案信息
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Param request body types.UpdateProfileRequest true "更新信息"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid} [put]
|
||||
func UpdateProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req types.UpdateProfileRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误: "+err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新档案
|
||||
var namePtr *string
|
||||
if req.Name != "" {
|
||||
namePtr = &req.Name
|
||||
}
|
||||
|
||||
profile, err := service.UpdateProfile(database.MustGetDB(), uuid, userID.(int64), namePtr, req.SkinID, req.CapeID)
|
||||
if err != nil {
|
||||
loggerInstance.Error("更新档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if err.Error() == "档案不存在" {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if err.Error() == "无权操作此档案" {
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
c.JSON(statusCode, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.ProfileInfo{
|
||||
UUID: profile.UUID,
|
||||
UserID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
SkinID: profile.SkinID,
|
||||
CapeID: profile.CapeID,
|
||||
IsActive: profile.IsActive,
|
||||
LastUsedAt: profile.LastUsedAt,
|
||||
CreatedAt: profile.CreatedAt,
|
||||
UpdatedAt: profile.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
// @Summary 删除档案
|
||||
// @Description 删除指定的Minecraft档案
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "删除成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid} [delete]
|
||||
func DeleteProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 删除档案
|
||||
err := service.DeleteProfile(database.MustGetDB(), uuid, userID.(int64))
|
||||
if err != nil {
|
||||
loggerInstance.Error("删除档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if err.Error() == "档案不存在" {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if err.Error() == "无权操作此档案" {
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
c.JSON(statusCode, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "删除成功",
|
||||
}))
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置活跃档案
|
||||
// @Summary 设置活跃档案
|
||||
// @Description 将指定档案设置为活跃状态
|
||||
// @Tags profile
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param uuid path string true "档案UUID"
|
||||
// @Success 200 {object} model.Response "设置成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Failure 404 {object} model.ErrorResponse "档案不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/profile/{uuid}/activate [post]
|
||||
func SetActiveProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
uuid := c.Param("uuid")
|
||||
|
||||
// 获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"未授权",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置活跃状态
|
||||
err := service.SetActiveProfile(database.MustGetDB(), uuid, userID.(int64))
|
||||
if err != nil {
|
||||
loggerInstance.Error("设置活跃档案失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
statusCode := http.StatusInternalServerError
|
||||
if err.Error() == "档案不存在" {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if err.Error() == "无权操作此档案" {
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
c.JSON(statusCode, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(gin.H{
|
||||
"message": "设置成功",
|
||||
}))
|
||||
}
|
||||
151
internal/handler/profile_handler_test.go
Normal file
151
internal/handler/profile_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileHandler_PermissionCheck 测试权限检查逻辑
|
||||
func TestProfileHandler_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID interface{}
|
||||
exists bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID",
|
||||
userID: int64(1),
|
||||
exists: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不存在",
|
||||
userID: nil,
|
||||
exists: false,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证权限检查逻辑
|
||||
isValid := tt.exists
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Permission check failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileHandler_RequestValidation 测试请求验证逻辑
|
||||
func TestProfileHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的档案名",
|
||||
profileName: "PlayerName",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "档案名为空",
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "档案名长度超过16",
|
||||
profileName: "ThisIsAVeryLongPlayerName",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证请求逻辑:档案名长度应该在1-16之间
|
||||
isValid := tt.profileName != "" && len(tt.profileName) >= 1 && len(tt.profileName) <= 16
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileHandler_LimitCheck 测试限制检查逻辑
|
||||
func TestProfileHandler_LimitCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentCount int
|
||||
maxCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "未达到限制",
|
||||
currentCount: 3,
|
||||
maxCount: 5,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "达到限制",
|
||||
currentCount: 5,
|
||||
maxCount: 5,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "超过限制",
|
||||
currentCount: 6,
|
||||
maxCount: 5,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证限制检查逻辑
|
||||
hasError := tt.currentCount >= tt.maxCount
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("Limit check failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestProfileHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "未授权",
|
||||
errType: "unauthorized",
|
||||
wantCode: 401,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "bad_request",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
errType: "server_error",
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
139
internal/handler/routes.go
Normal file
139
internal/handler/routes.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/middleware"
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterRoutes 注册所有路由
|
||||
func RegisterRoutes(router *gin.Engine) {
|
||||
// 设置Swagger文档
|
||||
SetupSwagger(router)
|
||||
|
||||
// API路由组
|
||||
v1 := router.Group("/api/v1")
|
||||
{
|
||||
// 认证路由(无需JWT)
|
||||
authGroup := v1.Group("/auth")
|
||||
{
|
||||
authGroup.POST("/register", Register)
|
||||
authGroup.POST("/login", Login)
|
||||
authGroup.POST("/send-code", SendVerificationCode)
|
||||
authGroup.POST("/reset-password", ResetPassword)
|
||||
}
|
||||
|
||||
// 用户路由(需要JWT认证)
|
||||
userGroup := v1.Group("/user")
|
||||
userGroup.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
userGroup.GET("/profile", GetUserProfile)
|
||||
userGroup.PUT("/profile", UpdateUserProfile)
|
||||
|
||||
// 头像相关
|
||||
userGroup.POST("/avatar/upload-url", GenerateAvatarUploadURL)
|
||||
userGroup.PUT("/avatar", UpdateAvatar)
|
||||
|
||||
// 更换邮箱
|
||||
userGroup.POST("/change-email", ChangeEmail)
|
||||
}
|
||||
|
||||
// 材质路由
|
||||
textureGroup := v1.Group("/texture")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
textureGroup.GET("", SearchTextures) // 搜索材质
|
||||
textureGroup.GET("/:id", GetTexture) // 获取材质详情
|
||||
|
||||
// 需要认证的路由
|
||||
textureAuth := textureGroup.Group("")
|
||||
textureAuth.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
textureAuth.POST("/upload-url", GenerateTextureUploadURL) // 生成上传URL
|
||||
textureAuth.POST("", CreateTexture) // 创建材质记录
|
||||
textureAuth.PUT("/:id", UpdateTexture) // 更新材质
|
||||
textureAuth.DELETE("/:id", DeleteTexture) // 删除材质
|
||||
textureAuth.POST("/:id/favorite", ToggleFavorite) // 切换收藏
|
||||
textureAuth.GET("/my", GetUserTextures) // 我的材质
|
||||
textureAuth.GET("/favorites", GetUserFavorites) // 我的收藏
|
||||
}
|
||||
}
|
||||
|
||||
// 档案路由
|
||||
profileGroup := v1.Group("/profile")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
profileGroup.GET("/:uuid", GetProfile) // 获取档案详情
|
||||
|
||||
// 需要认证的路由
|
||||
profileAuth := profileGroup.Group("")
|
||||
profileAuth.Use(middleware.AuthMiddleware())
|
||||
{
|
||||
profileAuth.POST("/", CreateProfile) // 创建档案
|
||||
profileAuth.GET("/", GetProfiles) // 获取我的档案列表
|
||||
profileAuth.PUT("/:uuid", UpdateProfile) // 更新档案
|
||||
profileAuth.DELETE("/:uuid", DeleteProfile) // 删除档案
|
||||
profileAuth.POST("/:uuid/activate", SetActiveProfile) // 设置活跃档案
|
||||
}
|
||||
}
|
||||
// 验证码路由
|
||||
captchaGroup := v1.Group("/captcha")
|
||||
{
|
||||
captchaGroup.GET("/generate", Generate) //生成验证码
|
||||
captchaGroup.POST("/verify", Verify) //验证验证码
|
||||
}
|
||||
|
||||
// Yggdrasil API路由组
|
||||
ygg := v1.Group("/yggdrasil")
|
||||
{
|
||||
ygg.GET("", GetMetaData)
|
||||
ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates)
|
||||
authserver := ygg.Group("/authserver")
|
||||
{
|
||||
authserver.POST("/authenticate", Authenticate)
|
||||
authserver.POST("/validate", ValidToken)
|
||||
authserver.POST("/refresh", RefreshToken)
|
||||
authserver.POST("/invalidate", InvalidToken)
|
||||
authserver.POST("/signout", SignOut)
|
||||
}
|
||||
sessionServer := ygg.Group("/sessionserver")
|
||||
{
|
||||
sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID)
|
||||
sessionServer.POST("/session/minecraft/join", JoinServer)
|
||||
sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer)
|
||||
}
|
||||
api := ygg.Group("/api")
|
||||
profiles := api.Group("/profiles")
|
||||
{
|
||||
profiles.POST("/minecraft", GetProfilesByName)
|
||||
}
|
||||
}
|
||||
// 系统路由
|
||||
system := v1.Group("/system")
|
||||
{
|
||||
system.GET("/config", GetSystemConfig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 以下是系统配置相关的占位符函数,待后续实现
|
||||
|
||||
// GetSystemConfig 获取系统配置
|
||||
// @Summary 获取系统配置
|
||||
// @Description 获取公开的系统配置信息
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Router /api/v1/system/config [get]
|
||||
func GetSystemConfig(c *gin.Context) {
|
||||
// TODO: 实现从数据库读取系统配置
|
||||
c.JSON(200, model.NewSuccessResponse(gin.H{
|
||||
"site_name": "CarrotSkin",
|
||||
"site_description": "A Minecraft Skin Station",
|
||||
"registration_enabled": true,
|
||||
"max_textures_per_user": 100,
|
||||
"max_profiles_per_user": 5,
|
||||
}))
|
||||
}
|
||||
62
internal/handler/swagger.go
Normal file
62
internal/handler/swagger.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
)
|
||||
|
||||
// @title CarrotSkin API
|
||||
// @version 1.0
|
||||
// @description CarrotSkin 是一个优秀的 Minecraft 皮肤站 API 服务
|
||||
// @description
|
||||
// @description ## 功能特性
|
||||
// @description - 用户注册/登录/管理
|
||||
// @description - 材质上传/下载/管理
|
||||
// @description - Minecraft 档案管理
|
||||
// @description - 权限控制系统
|
||||
// @description - 积分系统
|
||||
// @description
|
||||
// @description ## 认证方式
|
||||
// @description 使用 JWT Token 进行身份认证,需要在请求头中包含:
|
||||
// @description ```
|
||||
// @description Authorization: Bearer <your-jwt-token>
|
||||
// @description ```
|
||||
|
||||
// @contact.name CarrotSkin Team
|
||||
// @contact.email support@carrotskin.com
|
||||
// @license.name MIT
|
||||
// @license.url https://opensource.org/licenses/MIT
|
||||
|
||||
// @host localhost:8080
|
||||
// @BasePath /api/v1
|
||||
|
||||
// @securityDefinitions.apikey BearerAuth
|
||||
// @in header
|
||||
// @name Authorization
|
||||
// @description Type "Bearer" followed by a space and JWT token.
|
||||
|
||||
func SetupSwagger(router *gin.Engine) {
|
||||
// Swagger文档路由
|
||||
router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
// 健康检查接口
|
||||
router.GET("/health", HealthCheck)
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
// @Summary 健康检查
|
||||
// @Description 检查服务是否正常运行
|
||||
// @Tags system
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{} "成功"
|
||||
// @Router /health [get]
|
||||
func HealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"message": "CarrotSkin API is running",
|
||||
})
|
||||
}
|
||||
599
internal/handler/texture_handler.go
Normal file
599
internal/handler/texture_handler.go
Normal file
@@ -0,0 +1,599 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/storage"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GenerateTextureUploadURL 生成材质上传URL
|
||||
// @Summary 生成材质上传URL
|
||||
// @Description 生成预签名URL用于上传材质文件
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.GenerateTextureUploadURLRequest true "上传URL请求"
|
||||
// @Success 200 {object} model.Response "生成成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/texture/upload-url [post]
|
||||
func GenerateTextureUploadURL(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.GenerateTextureUploadURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用UploadService生成预签名URL
|
||||
storageClient := storage.MustGetClient()
|
||||
cfg := *config.MustGetRustFSConfig()
|
||||
result, err := service.GenerateTextureUploadURL(
|
||||
c.Request.Context(),
|
||||
storageClient,
|
||||
cfg,
|
||||
userID.(int64),
|
||||
req.FileName,
|
||||
string(req.TextureType),
|
||||
)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("生成材质上传URL失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("file_name", req.FileName),
|
||||
zap.String("texture_type", string(req.TextureType)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.GenerateTextureUploadURLResponse{
|
||||
PostURL: result.PostURL,
|
||||
FormData: result.FormData,
|
||||
TextureURL: result.FileURL,
|
||||
ExpiresIn: 900, // 15分钟 = 900秒
|
||||
}))
|
||||
}
|
||||
|
||||
// CreateTexture 创建材质记录
|
||||
// @Summary 创建材质记录
|
||||
// @Description 文件上传完成后,创建材质记录到数据库
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.CreateTextureRequest true "创建材质请求"
|
||||
// @Success 200 {object} model.Response "创建成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/texture [post]
|
||||
func CreateTexture(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.CreateTextureRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 从配置或数据库读取限制
|
||||
maxTextures := 100
|
||||
if err := service.CheckTextureUploadLimit(database.MustGetDB(), userID.(int64), maxTextures); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 创建材质
|
||||
texture, err := service.CreateTexture(database.MustGetDB(),
|
||||
userID.(int64),
|
||||
req.Name,
|
||||
req.Description,
|
||||
string(req.Type),
|
||||
req.URL,
|
||||
req.Hash,
|
||||
req.Size,
|
||||
req.IsPublic,
|
||||
req.IsSlim,
|
||||
)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("创建材质失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("name", req.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetTexture 获取材质详情
|
||||
// @Summary 获取材质详情
|
||||
// @Description 根据ID获取材质详细信息
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path int true "材质ID"
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 404 {object} model.ErrorResponse "材质不存在"
|
||||
// @Router /api/v1/texture/{id} [get]
|
||||
func GetTexture(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.GetTextureByID(database.MustGetDB(), id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
// @Summary 搜索材质
|
||||
// @Description 根据关键词和类型搜索材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param keyword query string false "关键词"
|
||||
// @Param type query string false "材质类型(SKIN/CAPE)"
|
||||
// @Param public_only query bool false "只看公开材质"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} model.PaginationResponse "搜索成功"
|
||||
// @Router /api/v1/texture [get]
|
||||
func SearchTextures(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
textureTypeStr := c.Query("type")
|
||||
publicOnly := c.Query("public_only") == "true"
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
var textureType model.TextureType
|
||||
switch textureTypeStr {
|
||||
case "SKIN":
|
||||
textureType = model.TextureTypeSkin
|
||||
case "CAPE":
|
||||
textureType = model.TextureTypeCape
|
||||
}
|
||||
|
||||
textures, total, err := service.SearchTextures(database.MustGetDB(), keyword, textureType, publicOnly, page, pageSize)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("搜索材质失败",
|
||||
zap.String("keyword", keyword),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"搜索材质失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为TextureInfo
|
||||
textureInfos := make([]*types.TextureInfo, len(textures))
|
||||
for i, texture := range textures {
|
||||
textureInfos[i] = &types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPaginationResponse(textureInfos, total, page, pageSize))
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
// @Summary 更新材质
|
||||
// @Description 更新材质信息(仅上传者可操作)
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "材质ID"
|
||||
// @Param request body types.UpdateTextureRequest true "更新材质请求"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/texture/{id} [put]
|
||||
func UpdateTexture(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
idStr := c.Param("id")
|
||||
textureID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateTextureRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
texture, err := service.UpdateTexture(database.MustGetDB(), textureID, userID.(int64), req.Name, req.Description, req.IsPublic)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("更新材质失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusForbidden, model.NewErrorResponse(
|
||||
model.CodeForbidden,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质
|
||||
// @Summary 删除材质
|
||||
// @Description 删除材质(软删除,仅上传者可操作)
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "材质ID"
|
||||
// @Success 200 {object} model.Response "删除成功"
|
||||
// @Failure 403 {object} model.ErrorResponse "无权操作"
|
||||
// @Router /api/v1/texture/{id} [delete]
|
||||
func DeleteTexture(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
idStr := c.Param("id")
|
||||
textureID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.DeleteTexture(database.MustGetDB(), textureID, userID.(int64)); err != nil {
|
||||
logger.MustGetLogger().Error("删除材质失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusForbidden, model.NewErrorResponse(
|
||||
model.CodeForbidden,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(nil))
|
||||
}
|
||||
|
||||
// ToggleFavorite 切换收藏状态
|
||||
// @Summary 切换收藏状态
|
||||
// @Description 收藏或取消收藏材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param id path int true "材质ID"
|
||||
// @Success 200 {object} model.Response "切换成功"
|
||||
// @Router /api/v1/texture/{id}/favorite [post]
|
||||
func ToggleFavorite(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
idStr := c.Param("id")
|
||||
textureID, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"无效的材质ID",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
isFavorited, err := service.ToggleTextureFavorite(database.MustGetDB(), userID.(int64), textureID)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("切换收藏状态失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Int64("texture_id", textureID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(map[string]bool{
|
||||
"is_favorited": isFavorited,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetUserTextures 获取用户上传的材质列表
|
||||
// @Summary 获取用户上传的材质列表
|
||||
// @Description 获取当前用户上传的所有材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} model.PaginationResponse "获取成功"
|
||||
// @Router /api/v1/texture/my [get]
|
||||
func GetUserTextures(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
textures, total, err := service.GetUserTextures(database.MustGetDB(), userID.(int64), page, pageSize)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("获取用户材质列表失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"获取材质列表失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为TextureInfo
|
||||
textureInfos := make([]*types.TextureInfo, len(textures))
|
||||
for i, texture := range textures {
|
||||
textureInfos[i] = &types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPaginationResponse(textureInfos, total, page, pageSize))
|
||||
}
|
||||
|
||||
// GetUserFavorites 获取用户收藏的材质列表
|
||||
// @Summary 获取用户收藏的材质列表
|
||||
// @Description 获取当前用户收藏的所有材质
|
||||
// @Tags texture
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} model.PaginationResponse "获取成功"
|
||||
// @Router /api/v1/texture/favorites [get]
|
||||
func GetUserFavorites(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID.(int64), page, pageSize)
|
||||
if err != nil {
|
||||
logger.MustGetLogger().Error("获取用户收藏列表失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"获取收藏列表失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换为TextureInfo
|
||||
textureInfos := make([]*types.TextureInfo, len(textures))
|
||||
for i, texture := range textures {
|
||||
textureInfos[i] = &types.TextureInfo{
|
||||
ID: texture.ID,
|
||||
UploaderID: texture.UploaderID,
|
||||
Name: texture.Name,
|
||||
Description: texture.Description,
|
||||
Type: types.TextureType(texture.Type),
|
||||
URL: texture.URL,
|
||||
Hash: texture.Hash,
|
||||
Size: texture.Size,
|
||||
IsPublic: texture.IsPublic,
|
||||
DownloadCount: texture.DownloadCount,
|
||||
FavoriteCount: texture.FavoriteCount,
|
||||
IsSlim: texture.IsSlim,
|
||||
Status: texture.Status,
|
||||
CreatedAt: texture.CreatedAt,
|
||||
UpdatedAt: texture.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewPaginationResponse(textureInfos, total, page, pageSize))
|
||||
}
|
||||
415
internal/handler/user_handler.go
Normal file
415
internal/handler/user_handler.go
Normal file
@@ -0,0 +1,415 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/internal/types"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/storage"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// GetUserProfile 获取用户信息
|
||||
// @Summary 获取用户信息
|
||||
// @Description 获取当前登录用户的详细信息
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Success 200 {object} model.Response "获取成功"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Router /api/v1/user/profile [get]
|
||||
func GetUserProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
// 从上下文获取用户ID (由JWT中间件设置)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
loggerInstance.Error("获取用户信息失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回用户信息
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateUserProfile 更新用户信息
|
||||
// @Summary 更新用户信息
|
||||
// @Description 更新当前登录用户的头像和密码(修改邮箱请使用 /change-email 接口)
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.UpdateUserRequest true "更新信息(修改密码时需同时提供old_password和new_password)"
|
||||
// @Success 200 {object} model.Response{data=types.UserInfo} "更新成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Failure 404 {object} model.ErrorResponse "用户不存在"
|
||||
// @Failure 500 {object} model.ErrorResponse "服务器错误"
|
||||
// @Router /api/v1/user/profile [put]
|
||||
func UpdateUserProfile(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 处理密码修改
|
||||
if req.NewPassword != "" {
|
||||
// 如果提供了新密码,必须同时提供旧密码
|
||||
if req.OldPassword == "" {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"修改密码需要提供原密码",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用修改密码服务
|
||||
if err := service.ChangeUserPassword(userID.(int64), req.OldPassword, req.NewPassword); err != nil {
|
||||
loggerInstance.Error("修改密码失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("用户修改密码成功",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
)
|
||||
}
|
||||
|
||||
// 更新头像
|
||||
if req.Avatar != "" {
|
||||
user.Avatar = req.Avatar
|
||||
}
|
||||
|
||||
// 保存更新(仅当有头像修改时)
|
||||
if req.Avatar != "" {
|
||||
if err := service.UpdateUserInfo(user); err != nil {
|
||||
loggerInstance.Error("更新用户信息失败",
|
||||
zap.Int64("user_id", user.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"更新失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 重新获取更新后的用户信息
|
||||
updatedUser, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || updatedUser == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的用户信息
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: updatedUser.ID,
|
||||
Username: updatedUser.Username,
|
||||
Email: updatedUser.Email,
|
||||
Avatar: updatedUser.Avatar,
|
||||
Points: updatedUser.Points,
|
||||
Role: updatedUser.Role,
|
||||
Status: updatedUser.Status,
|
||||
LastLoginAt: updatedUser.LastLoginAt,
|
||||
CreatedAt: updatedUser.CreatedAt,
|
||||
UpdatedAt: updatedUser.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURL 生成头像上传URL
|
||||
// @Summary 生成头像上传URL
|
||||
// @Description 生成预签名URL用于上传用户头像
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.GenerateAvatarUploadURLRequest true "文件名"
|
||||
// @Success 200 {object} model.Response "生成成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/user/avatar/upload-url [post]
|
||||
func GenerateAvatarUploadURL(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.GenerateAvatarUploadURLRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用UploadService生成预签名URL
|
||||
storageClient := storage.MustGetClient()
|
||||
cfg := *config.MustGetRustFSConfig()
|
||||
result, err := service.GenerateAvatarUploadURL(c.Request.Context(), storageClient, cfg, userID.(int64), req.FileName)
|
||||
if err != nil {
|
||||
loggerInstance.Error("生成头像上传URL失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("file_name", req.FileName),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.GenerateAvatarUploadURLResponse{
|
||||
PostURL: result.PostURL,
|
||||
FormData: result.FormData,
|
||||
AvatarURL: result.FileURL,
|
||||
ExpiresIn: 900, // 15分钟 = 900秒
|
||||
}))
|
||||
}
|
||||
|
||||
// UpdateAvatar 更新头像URL
|
||||
// @Summary 更新头像URL
|
||||
// @Description 上传完成后更新用户的头像URL到数据库
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param avatar_url query string true "头像URL"
|
||||
// @Success 200 {object} model.Response "更新成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Router /api/v1/user/avatar [put]
|
||||
func UpdateAvatar(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
avatarURL := c.Query("avatar_url")
|
||||
if avatarURL == "" {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"头像URL不能为空",
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新头像
|
||||
if err := service.UpdateUserAvatar(userID.(int64), avatarURL); err != nil {
|
||||
loggerInstance.Error("更新头像失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("avatar_url", avatarURL),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusInternalServerError, model.NewErrorResponse(
|
||||
model.CodeServerError,
|
||||
"更新头像失败",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的用户信息
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 返回更新后的用户信息
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
// ChangeEmail 更换邮箱
|
||||
// @Summary 更换邮箱
|
||||
// @Description 通过验证码更换用户邮箱
|
||||
// @Tags user
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body types.ChangeEmailRequest true "更换邮箱请求"
|
||||
// @Success 200 {object} model.Response{data=types.UserInfo} "更换成功"
|
||||
// @Failure 400 {object} model.ErrorResponse "请求参数错误"
|
||||
// @Failure 401 {object} model.ErrorResponse "未授权"
|
||||
// @Router /api/v1/user/change-email [post]
|
||||
func ChangeEmail(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
model.MsgUnauthorized,
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
var req types.ChangeEmailRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
"请求参数错误",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
redisClient := redis.MustGetClient()
|
||||
if err := service.VerifyCode(c.Request.Context(), redisClient, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil {
|
||||
loggerInstance.Warn("验证码验证失败",
|
||||
zap.String("new_email", req.NewEmail),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 更换邮箱
|
||||
if err := service.ChangeUserEmail(userID.(int64), req.NewEmail); err != nil {
|
||||
loggerInstance.Error("更换邮箱失败",
|
||||
zap.Int64("user_id", userID.(int64)),
|
||||
zap.String("new_email", req.NewEmail),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, model.NewErrorResponse(
|
||||
model.CodeBadRequest,
|
||||
err.Error(),
|
||||
nil,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 获取更新后的用户信息
|
||||
user, err := service.GetUserByID(userID.(int64))
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, model.NewErrorResponse(
|
||||
model.CodeNotFound,
|
||||
"用户不存在",
|
||||
err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(&types.UserInfo{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
Avatar: user.Avatar,
|
||||
Points: user.Points,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}))
|
||||
}
|
||||
151
internal/handler/user_handler_test.go
Normal file
151
internal/handler/user_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUserHandler_PermissionCheck 测试权限检查逻辑
|
||||
func TestUserHandler_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID interface{}
|
||||
exists bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID",
|
||||
userID: int64(1),
|
||||
exists: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不存在",
|
||||
userID: nil,
|
||||
exists: false,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID类型错误",
|
||||
userID: "invalid",
|
||||
exists: true,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证权限检查逻辑
|
||||
isValid := tt.exists
|
||||
if tt.exists {
|
||||
// 验证类型转换
|
||||
if _, ok := tt.userID.(int64); !ok {
|
||||
isValid = false
|
||||
}
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Permission check failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHandler_RequestValidation 测试请求验证逻辑
|
||||
func TestUserHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
avatar string
|
||||
oldPass string
|
||||
newPass string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "只更新头像",
|
||||
avatar: "https://example.com/avatar.png",
|
||||
oldPass: "",
|
||||
newPass: "",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "更新密码(提供旧密码和新密码)",
|
||||
avatar: "",
|
||||
oldPass: "oldpass123",
|
||||
newPass: "newpass123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "只提供新密码(无效)",
|
||||
avatar: "",
|
||||
oldPass: "",
|
||||
newPass: "newpass123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "只提供旧密码(无效)",
|
||||
avatar: "",
|
||||
oldPass: "oldpass123",
|
||||
newPass: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证请求逻辑:更新密码时需要同时提供旧密码和新密码
|
||||
isValid := true
|
||||
if tt.newPass != "" && tt.oldPass == "" {
|
||||
isValid = false
|
||||
}
|
||||
if tt.oldPass != "" && tt.newPass == "" {
|
||||
isValid = false
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestUserHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "未授权",
|
||||
errType: "unauthorized",
|
||||
wantCode: 401,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
errType: "not_found",
|
||||
wantCode: 404,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "bad_request",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
errType: "server_error",
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
666
internal/handler/yggdrasil_handler.go
Normal file
666
internal/handler/yggdrasil_handler.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/service"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/logger"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
ErrInternalServer = "服务器内部错误"
|
||||
// 错误类型
|
||||
ErrInvalidEmailFormat = "邮箱格式不正确"
|
||||
ErrInvalidPassword = "密码必须至少包含8个字符,只能包含字母、数字和特殊字符"
|
||||
ErrWrongPassword = "密码错误"
|
||||
ErrUserNotMatch = "用户不匹配"
|
||||
|
||||
// 错误消息
|
||||
ErrInvalidRequest = "请求格式无效"
|
||||
ErrJoinServerFailed = "加入服务器失败"
|
||||
ErrServerIDRequired = "服务器ID不能为空"
|
||||
ErrUsernameRequired = "用户名不能为空"
|
||||
ErrSessionVerifyFailed = "会话验证失败"
|
||||
ErrProfileNotFound = "未找到用户配置文件"
|
||||
ErrInvalidParams = "无效的请求参数"
|
||||
ErrEmptyUserID = "用户ID为空"
|
||||
ErrUnauthorized = "无权操作此配置文件"
|
||||
ErrGetProfileService = "获取配置文件服务失败"
|
||||
|
||||
// 成功信息
|
||||
SuccessProfileCreated = "创建成功"
|
||||
MsgRegisterSuccess = "注册成功"
|
||||
|
||||
// 错误消息
|
||||
ErrGetProfile = "获取配置文件失败"
|
||||
ErrGetTextureService = "获取材质服务失败"
|
||||
ErrInvalidContentType = "无效的请求内容类型"
|
||||
ErrParseMultipartForm = "解析多部分表单失败"
|
||||
ErrGetFileFromForm = "从表单获取文件失败"
|
||||
ErrInvalidFileType = "无效的文件类型,仅支持PNG图片"
|
||||
ErrSaveTexture = "保存材质失败"
|
||||
ErrSetTexture = "设置材质失败"
|
||||
ErrGetTexture = "获取材质失败"
|
||||
|
||||
// 内存限制
|
||||
MaxMultipartMemory = 32 << 20 // 32 MB
|
||||
|
||||
// 材质类型
|
||||
TextureTypeSkin = "SKIN"
|
||||
TextureTypeCape = "CAPE"
|
||||
|
||||
// 内容类型
|
||||
ContentTypePNG = "image/png"
|
||||
ContentTypeMultipart = "multipart/form-data"
|
||||
|
||||
// 表单参数
|
||||
FormKeyModel = "model"
|
||||
FormKeyFile = "file"
|
||||
|
||||
// 元数据键
|
||||
MetaKeyModel = "model"
|
||||
)
|
||||
|
||||
// 正则表达式
|
||||
var (
|
||||
// 邮箱正则表达式
|
||||
emailRegex = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
|
||||
// 密码强度正则表达式(最少8位,只允许字母、数字和特定特殊字符)
|
||||
passwordRegex = regexp.MustCompile(`^[a-zA-Z0-9!@#$%^&*]{8,}$`)
|
||||
)
|
||||
|
||||
// 请求结构体
|
||||
type (
|
||||
// AuthenticateRequest 认证请求
|
||||
AuthenticateRequest struct {
|
||||
Agent map[string]interface{} `json:"agent"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
Identifier string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
RequestUser bool `json:"requestUser"`
|
||||
}
|
||||
|
||||
// ValidTokenRequest 验证令牌请求
|
||||
ValidTokenRequest struct {
|
||||
AccessToken string `json:"accessToken" binding:"required"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
}
|
||||
|
||||
// RefreshRequest 刷新令牌请求
|
||||
RefreshRequest struct {
|
||||
AccessToken string `json:"accessToken" binding:"required"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
RequestUser bool `json:"requestUser"`
|
||||
SelectedProfile map[string]interface{} `json:"selectedProfile"`
|
||||
}
|
||||
|
||||
// SignOutRequest 登出请求
|
||||
SignOutRequest struct {
|
||||
Email string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
JoinServerRequest struct {
|
||||
ServerID string `json:"serverId" binding:"required"`
|
||||
AccessToken string `json:"accessToken" binding:"required"`
|
||||
SelectedProfile string `json:"selectedProfile" binding:"required"`
|
||||
}
|
||||
)
|
||||
|
||||
// 响应结构体
|
||||
type (
|
||||
// AuthenticateResponse 认证响应
|
||||
AuthenticateResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
SelectedProfile map[string]interface{} `json:"selectedProfile,omitempty"`
|
||||
AvailableProfiles []map[string]interface{} `json:"availableProfiles"`
|
||||
User map[string]interface{} `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshResponse 刷新令牌响应
|
||||
RefreshResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
ClientToken string `json:"clientToken"`
|
||||
SelectedProfile map[string]interface{} `json:"selectedProfile,omitempty"`
|
||||
User map[string]interface{} `json:"user,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
type APIResponse struct {
|
||||
Status int `json:"status"`
|
||||
Data interface{} `json:"data"`
|
||||
Error interface{} `json:"error"`
|
||||
}
|
||||
|
||||
// standardResponse 生成标准响应
|
||||
func standardResponse(c *gin.Context, status int, data interface{}, err interface{}) {
|
||||
c.JSON(status, APIResponse{
|
||||
Status: status,
|
||||
Data: data,
|
||||
Error: err,
|
||||
})
|
||||
}
|
||||
|
||||
// Authenticate 用户认证
|
||||
func Authenticate(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
// 读取并保存原始请求体,以便多次读取
|
||||
rawData, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 读取请求体失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"})
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData))
|
||||
|
||||
// 绑定JSON数据到请求结构体
|
||||
var request AuthenticateRequest
|
||||
if err = c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析认证请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 根据标识符类型(邮箱或用户名)获取用户
|
||||
var userId int64
|
||||
var profile *model.Profile
|
||||
var UUID string
|
||||
if emailRegex.MatchString(request.Identifier) {
|
||||
userId, err = service.GetUserIDByEmail(db, request.Identifier)
|
||||
} else {
|
||||
profile, err = service.GetProfileByProfileName(db, request.Identifier)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 用户名不存在: ", zap.String("标识符", request.Identifier), zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userId = profile.UserID
|
||||
UUID = profile.UUID
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
loggerInstance.Warn("[WARN] 认证失败: 用户不存在",
|
||||
zap.String("标识符:", request.Identifier),
|
||||
zap.Error(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
err = service.VerifyPassword(db, request.Password, userId)
|
||||
if err != nil {
|
||||
loggerInstance.Warn("[WARN] 认证失败:", zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword})
|
||||
return
|
||||
}
|
||||
// 生成新令牌
|
||||
selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(db, loggerInstance, userId, UUID, request.ClientToken)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 生成令牌失败:", zap.Error(err), zap.Any("用户ID:", userId))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := service.GetUserByID(userId)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] id查找错误:", zap.Error(err), zap.Any("ID:", userId))
|
||||
}
|
||||
// 处理可用的配置文件
|
||||
redisClient := redis.MustGetClient()
|
||||
availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles))
|
||||
for _, profile := range availableProfiles {
|
||||
availableProfilesData = append(availableProfilesData, service.SerializeProfile(db, loggerInstance, redisClient, *profile))
|
||||
}
|
||||
response := AuthenticateResponse{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
AvailableProfiles: availableProfilesData,
|
||||
}
|
||||
if selectedProfile != nil {
|
||||
response.SelectedProfile = service.SerializeProfile(db, loggerInstance, redisClient, *selectedProfile)
|
||||
}
|
||||
if request.RequestUser {
|
||||
response.User = map[string]interface{}{
|
||||
"id": userId,
|
||||
"properties": user.Properties,
|
||||
}
|
||||
}
|
||||
|
||||
// 返回认证响应
|
||||
loggerInstance.Info("[INFO] 用户认证成功", zap.Any("用户ID:", userId))
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// ValidToken 验证令牌
|
||||
func ValidToken(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request ValidTokenRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析验证令牌请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// 验证令牌
|
||||
if service.ValidToken(db, request.AccessToken, request.ClientToken) {
|
||||
loggerInstance.Info("[INFO] 令牌验证成功", zap.Any("访问令牌:", request.AccessToken))
|
||||
c.JSON(http.StatusNoContent, gin.H{"valid": true})
|
||||
} else {
|
||||
loggerInstance.Warn("[WARN] 令牌验证失败", zap.Any("访问令牌:", request.AccessToken))
|
||||
c.JSON(http.StatusForbidden, gin.H{"valid": false})
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func RefreshToken(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request RefreshRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析刷新令牌请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户ID和用户信息
|
||||
UUID, err := service.GetUUIDByAccessToken(db, request.AccessToken)
|
||||
if err != nil {
|
||||
loggerInstance.Warn("[WARN] 刷新令牌失败: 无效的访问令牌", zap.Any("令牌:", request.AccessToken), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userID, _ := service.GetUserIDByAccessToken(db, request.AccessToken)
|
||||
// 格式化UUID 这里是因为HMCL的传入参数是HEX格式,为了兼容HMCL,在此做处理
|
||||
UUID = utils.FormatUUID(UUID)
|
||||
|
||||
profile, err := service.GetProfileByUUID(db, UUID)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: 无法获取用户信息 错误: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 准备响应数据
|
||||
var profileData map[string]interface{}
|
||||
var userData map[string]interface{}
|
||||
var profileID string
|
||||
|
||||
// 处理选定的配置文件
|
||||
if request.SelectedProfile != nil {
|
||||
// 验证profileID是否存在
|
||||
profileIDValue, ok := request.SelectedProfile["id"]
|
||||
if !ok {
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: 缺少配置文件ID", zap.Any("ID:", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// 类型断言
|
||||
profileID, ok = profileIDValue.(string)
|
||||
if !ok {
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: 配置文件ID类型错误 ", zap.Any("用户ID:", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"})
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化profileID
|
||||
profileID = utils.FormatUUID(profileID)
|
||||
|
||||
// 验证配置文件所属用户
|
||||
if profile.UserID != userID {
|
||||
loggerInstance.Warn("[WARN] 刷新令牌失败: 用户不匹配 ", zap.Any("用户ID:", userID), zap.Any("配置文件用户ID:", profile.UserID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch})
|
||||
return
|
||||
}
|
||||
|
||||
profileData = service.SerializeProfile(db, loggerInstance, redis.MustGetClient(), *profile)
|
||||
}
|
||||
user, _ := service.GetUserByID(userID)
|
||||
// 添加用户信息(如果请求了)
|
||||
if request.RequestUser {
|
||||
userData = service.SerializeUser(loggerInstance, user, UUID)
|
||||
}
|
||||
|
||||
// 刷新令牌
|
||||
newAccessToken, newClientToken, err := service.RefreshToken(db, loggerInstance,
|
||||
request.AccessToken,
|
||||
request.ClientToken,
|
||||
profileID,
|
||||
)
|
||||
if err != nil {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
loggerInstance.Error("[ERROR] 刷新令牌失败: ", zap.Error(err), zap.Any("用户ID: ", userID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
loggerInstance.Info("[INFO] 刷新令牌成功", zap.Any("用户ID:", userID))
|
||||
c.JSON(http.StatusOK, RefreshResponse{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: newClientToken,
|
||||
SelectedProfile: profileData,
|
||||
User: userData,
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidToken 使令牌失效
|
||||
func InvalidToken(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request ValidTokenRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析使令牌失效请求失败: ", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// 使令牌失效
|
||||
service.InvalidToken(db, loggerInstance, request.AccessToken)
|
||||
loggerInstance.Info("[INFO] 令牌已使失效", zap.Any("访问令牌:", request.AccessToken))
|
||||
c.JSON(http.StatusNoContent, gin.H{})
|
||||
}
|
||||
|
||||
// SignOut 用户登出
|
||||
func SignOut(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var request SignOutRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析登出请求失败: %v", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证邮箱格式
|
||||
if !emailRegex.MatchString(request.Email) {
|
||||
loggerInstance.Warn("[WARN] 登出失败: 邮箱格式不正确 ", zap.Any(" ", request.Email))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat})
|
||||
return
|
||||
}
|
||||
|
||||
// 通过邮箱获取用户
|
||||
user, err := service.GetUserByEmail(request.Email)
|
||||
if err != nil {
|
||||
loggerInstance.Warn(
|
||||
"登出失败: 用户不存在",
|
||||
zap.String("邮箱", request.Email),
|
||||
zap.Error(err),
|
||||
)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
password, err := service.GetPasswordByUserId(db, user.ID)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 邮箱查找失败", zap.Any("UserId:", user.ID), zap.Error(err))
|
||||
}
|
||||
// 验证密码
|
||||
if password != request.Password {
|
||||
loggerInstance.Warn("[WARN] 登出失败: 密码错误", zap.Any("用户ID:", user.ID))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword})
|
||||
return
|
||||
}
|
||||
|
||||
// 使该用户的所有令牌失效
|
||||
service.InvalidUserTokens(db, loggerInstance, user.ID)
|
||||
loggerInstance.Info("[INFO] 用户登出成功", zap.Any("用户ID:", user.ID))
|
||||
c.JSON(http.StatusNoContent, gin.H{"valid": true})
|
||||
}
|
||||
|
||||
func GetProfileByUUID(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
// 获取并格式化UUID
|
||||
uuid := utils.FormatUUID(c.Param("uuid"))
|
||||
loggerInstance.Info("[INFO] 接收到获取配置文件请求", zap.Any("UUID:", uuid))
|
||||
|
||||
// 获取配置文件
|
||||
profile, err := service.GetProfileByUUID(db, uuid)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取配置文件失败:", zap.Error(err), zap.String("UUID:", uuid))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回配置文件信息
|
||||
loggerInstance.Info("[INFO] 成功获取配置文件", zap.String("UUID:", uuid), zap.String("名称:", profile.Name))
|
||||
c.JSON(http.StatusOK, service.SerializeProfile(db, loggerInstance, redisClient, *profile))
|
||||
}
|
||||
|
||||
func JoinServer(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var request JoinServerRequest
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 解析请求参数
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
loggerInstance.Error(
|
||||
"解析加入服务器请求失败",
|
||||
zap.Error(err),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info(
|
||||
"收到加入服务器请求",
|
||||
zap.String("服务器ID", request.ServerID),
|
||||
zap.String("用户UUID", request.SelectedProfile),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
|
||||
// 处理加入服务器请求
|
||||
if err := service.JoinServer(db, loggerInstance, redisClient, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil {
|
||||
loggerInstance.Error(
|
||||
"加入服务器失败",
|
||||
zap.Error(err),
|
||||
zap.String("服务器ID", request.ServerID),
|
||||
zap.String("用户UUID", request.SelectedProfile),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed)
|
||||
return
|
||||
}
|
||||
|
||||
// 加入成功,返回204状态码
|
||||
loggerInstance.Info(
|
||||
"加入服务器成功",
|
||||
zap.String("服务器ID", request.ServerID),
|
||||
zap.String("用户UUID", request.SelectedProfile),
|
||||
zap.String("IP", clientIP),
|
||||
)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func HasJoinedServer(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
clientIP, _ := c.GetQuery("ip")
|
||||
|
||||
// 获取并验证服务器ID参数
|
||||
serverID, exists := c.GetQuery("serverId")
|
||||
if !exists || serverID == "" {
|
||||
loggerInstance.Warn("[WARN] 缺少服务器ID参数", zap.Any("IP:", clientIP))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取并验证用户名参数
|
||||
username, exists := c.GetQuery("username")
|
||||
if !exists || username == "" {
|
||||
loggerInstance.Warn("[WARN] 缺少用户名参数", zap.Any("服务器ID:", serverID), zap.Any("IP:", clientIP))
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("[INFO] 收到会话验证请求", zap.Any("服务器ID:", serverID), zap.Any("用户名: ", username), zap.Any("IP: ", clientIP))
|
||||
|
||||
// 验证玩家是否已加入服务器
|
||||
if err := service.HasJoinedServer(loggerInstance, redisClient, serverID, username, clientIP); err != nil {
|
||||
loggerInstance.Warn("[WARN] 会话验证失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverID", serverID),
|
||||
zap.String("username", username),
|
||||
zap.String("clientIP", clientIP),
|
||||
)
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed)
|
||||
return
|
||||
}
|
||||
|
||||
profile, err := service.GetProfileByUUID(db, username)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取用户配置文件失败: %v - 用户名: %s",
|
||||
zap.Error(err), // 错误详情(zap 原生支持,保留错误链)
|
||||
zap.String("username", username), // 结构化存储用户名(便于检索)
|
||||
)
|
||||
standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回玩家配置文件
|
||||
loggerInstance.Info("[INFO] 会话验证成功 - 服务器ID: %s, 用户名: %s, UUID: %s",
|
||||
zap.String("serverID", serverID), // 结构化存储服务器ID
|
||||
zap.String("username", username), // 结构化存储用户名
|
||||
zap.String("UUID", profile.UUID), // 结构化存储UUID
|
||||
)
|
||||
c.JSON(200, service.SerializeProfile(db, loggerInstance, redisClient, *profile))
|
||||
}
|
||||
|
||||
func GetProfilesByName(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
|
||||
var names []string
|
||||
|
||||
// 解析请求参数
|
||||
if err := c.ShouldBindJSON(&names); err != nil {
|
||||
loggerInstance.Error("[ERROR] 解析名称数组请求失败: ",
|
||||
zap.Error(err),
|
||||
)
|
||||
standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams)
|
||||
return
|
||||
}
|
||||
loggerInstance.Info("[INFO] 接收到批量获取配置文件请求",
|
||||
zap.Int("名称数量:", len(names)), // 结构化存储名称数量
|
||||
)
|
||||
|
||||
// 批量获取配置文件
|
||||
profiles, err := service.GetProfilesDataByNames(db, names)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取配置文件失败: ",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
// 改造:zap 兼容原有 INFO 日志格式
|
||||
loggerInstance.Info("[INFO] 成功获取配置文件",
|
||||
zap.Int("请求名称数:", len(names)),
|
||||
zap.Int("返回结果数: ", len(profiles)),
|
||||
)
|
||||
|
||||
c.JSON(http.StatusOK, profiles)
|
||||
}
|
||||
|
||||
func GetMetaData(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
meta := gin.H{
|
||||
"implementationName": "CellAuth",
|
||||
"implementationVersion": "0.0.1",
|
||||
"serverName": "LittleLan's Yggdrasil Server Implementation.",
|
||||
"links": gin.H{
|
||||
"homepage": "https://skin.littlelan.cn",
|
||||
"register": "https://skin.littlelan.cn/auth",
|
||||
},
|
||||
"feature.non_email_login": true,
|
||||
"feature.enable_profile_key": true,
|
||||
}
|
||||
skinDomains := []string{".hitwh.games", ".littlelan.cn"}
|
||||
signature, err := service.GetPublicKeyFromRedisFunc(loggerInstance, redisClient)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 获取公钥失败: ", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("[INFO] 提供元数据")
|
||||
c.JSON(http.StatusOK, gin.H{"meta": meta,
|
||||
"skinDomains": skinDomains,
|
||||
"signaturePublickey": signature})
|
||||
}
|
||||
|
||||
func GetPlayerCertificates(c *gin.Context) {
|
||||
loggerInstance := logger.MustGetLogger()
|
||||
db := database.MustGetDB()
|
||||
redisClient := redis.MustGetClient()
|
||||
|
||||
var uuid string
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否以 Bearer 开头并提取 sessionID
|
||||
bearerPrefix := "Bearer "
|
||||
if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
tokenID := authHeader[len(bearerPrefix):]
|
||||
if tokenID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
var err error
|
||||
uuid, err = service.GetUUIDByAccessToken(db, tokenID)
|
||||
|
||||
if uuid == "" {
|
||||
loggerInstance.Error("[ERROR] 获取玩家UUID失败: ", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化UUID
|
||||
uuid = utils.FormatUUID(uuid)
|
||||
|
||||
// 生成玩家证书
|
||||
certificate, err := service.GeneratePlayerCertificate(db, loggerInstance, redisClient, uuid)
|
||||
if err != nil {
|
||||
loggerInstance.Error("[ERROR] 生成玩家证书失败: ", zap.Error(err))
|
||||
standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer)
|
||||
return
|
||||
}
|
||||
|
||||
loggerInstance.Info("[INFO] 成功生成玩家证书")
|
||||
c.JSON(http.StatusOK, certificate)
|
||||
}
|
||||
157
internal/handler/yggdrasil_handler_test.go
Normal file
157
internal/handler/yggdrasil_handler_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestYggdrasilHandler_EmailValidation 测试邮箱验证逻辑
|
||||
func TestYggdrasilHandler_EmailValidation(t *testing.T) {
|
||||
// 使用简单的邮箱正则表达式
|
||||
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的邮箱",
|
||||
email: "test@example.com",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效的邮箱格式",
|
||||
email: "invalid-email",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少@符号",
|
||||
email: "testexample.com",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少域名",
|
||||
email: "test@",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空邮箱",
|
||||
email: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := emailRegex.MatchString(tt.email)
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Email validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestYggdrasilHandler_RequestValidation 测试请求验证逻辑
|
||||
func TestYggdrasilHandler_RequestValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
serverID string
|
||||
username string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的请求",
|
||||
accessToken: "token-123",
|
||||
serverID: "server-456",
|
||||
username: "player",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "accessToken为空",
|
||||
accessToken: "",
|
||||
serverID: "server-456",
|
||||
username: "player",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverID为空",
|
||||
accessToken: "token-123",
|
||||
serverID: "",
|
||||
username: "player",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "username为空",
|
||||
accessToken: "token-123",
|
||||
serverID: "server-456",
|
||||
username: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.accessToken != "" && tt.serverID != "" && tt.username != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Request validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestYggdrasilHandler_ErrorHandling 测试错误处理逻辑
|
||||
func TestYggdrasilHandler_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errType string
|
||||
wantCode int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "参数错误",
|
||||
errType: "bad_request",
|
||||
wantCode: 400,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "未授权",
|
||||
errType: "forbidden",
|
||||
wantCode: 403,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
errType: "server_error",
|
||||
wantCode: 500,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑
|
||||
if !tt.wantError {
|
||||
t.Error("Error handling test should expect error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestYggdrasilHandler_Constants 测试常量定义
|
||||
func TestYggdrasilHandler_Constants(t *testing.T) {
|
||||
// 验证常量定义
|
||||
if MaxMultipartMemory != 32<<20 {
|
||||
t.Errorf("MaxMultipartMemory = %d, want %d", MaxMultipartMemory, 32<<20)
|
||||
}
|
||||
|
||||
if TextureTypeSkin != "SKIN" {
|
||||
t.Errorf("TextureTypeSkin = %q, want 'SKIN'", TextureTypeSkin)
|
||||
}
|
||||
|
||||
if TextureTypeCape != "CAPE" {
|
||||
t.Errorf("TextureTypeCape = %q, want 'CAPE'", TextureTypeCape)
|
||||
}
|
||||
}
|
||||
|
||||
78
internal/middleware/auth.go
Normal file
78
internal/middleware/auth.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"carrotskin/pkg/auth"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthMiddleware JWT认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "缺少Authorization头",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Bearer token格式
|
||||
tokenParts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "无效的Authorization头格式",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token := tokenParts[1]
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "无效的token",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息存储到上下文中
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("role", claims.Role)
|
||||
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选的JWT认证中间件
|
||||
func OptionalAuthMiddleware() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
tokenParts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(tokenParts) == 2 && tokenParts[0] == "Bearer" {
|
||||
token := tokenParts[1]
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err == nil {
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("role", claims.Role)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
158
internal/middleware/auth_test.go
Normal file
158
internal/middleware/auth_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"carrotskin/pkg/auth"
|
||||
)
|
||||
|
||||
// TestAuthMiddleware_MissingHeader 测试缺少Authorization头的情况
|
||||
// 注意:这个测试需要auth服务初始化,暂时跳过实际执行
|
||||
func TestAuthMiddleware_MissingHeader(t *testing.T) {
|
||||
// 测试逻辑:缺少Authorization头应该返回401
|
||||
// 由于需要auth服务初始化,这里只测试逻辑部分
|
||||
hasHeader := false
|
||||
if hasHeader {
|
||||
t.Error("测试场景应该没有Authorization头")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_InvalidFormat 测试无效的Authorization头格式
|
||||
// 注意:这个测试需要auth服务初始化,这里只测试解析逻辑
|
||||
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "缺少Bearer前缀",
|
||||
header: "token123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "只有Bearer没有token",
|
||||
header: "Bearer",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
header: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "错误的格式",
|
||||
header: "Token token123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "标准格式",
|
||||
header: "Bearer token123",
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试header解析逻辑
|
||||
tokenParts := strings.SplitN(tt.header, " ", 2)
|
||||
isValid := len(tokenParts) == 2 && tokenParts[0] == "Bearer"
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Header validation: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_ValidToken 测试有效token的情况
|
||||
// 注意:这个测试需要auth服务初始化,这里只测试token格式
|
||||
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
||||
// 创建JWT服务并生成token
|
||||
jwtService := auth.NewJWTService("test-secret-key", 24)
|
||||
token, err := jwtService.GenerateToken(1, "testuser", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("生成token失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证token格式
|
||||
if token == "" {
|
||||
t.Error("生成的token不应为空")
|
||||
}
|
||||
|
||||
// 验证可以解析token
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err != nil {
|
||||
t.Fatalf("验证token失败: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != 1 {
|
||||
t.Errorf("UserID = %d, want 1", claims.UserID)
|
||||
}
|
||||
if claims.Username != "testuser" {
|
||||
t.Errorf("Username = %q, want 'testuser'", claims.Username)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptionalAuthMiddleware_NoHeader 测试可选认证中间件无header的情况
|
||||
// 注意:这个测试需要auth服务初始化,这里只测试逻辑
|
||||
func TestOptionalAuthMiddleware_NoHeader(t *testing.T) {
|
||||
// 测试逻辑:可选认证中间件在没有header时应该允许请求继续
|
||||
hasHeader := false
|
||||
shouldContinue := true // 可选认证应该允许继续
|
||||
|
||||
if hasHeader && !shouldContinue {
|
||||
t.Error("可选认证逻辑错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_HeaderParsing 测试Authorization头解析逻辑
|
||||
func TestAuthMiddleware_HeaderParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
wantValid bool
|
||||
wantToken string
|
||||
}{
|
||||
{
|
||||
name: "标准Bearer格式",
|
||||
header: "Bearer token123",
|
||||
wantValid: true,
|
||||
wantToken: "token123",
|
||||
},
|
||||
{
|
||||
name: "Bearer后多个空格",
|
||||
header: "Bearer token123",
|
||||
wantValid: true,
|
||||
wantToken: " token123", // SplitN只分割一次
|
||||
},
|
||||
{
|
||||
name: "缺少Bearer",
|
||||
header: "token123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "只有Bearer",
|
||||
header: "Bearer",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenParts := strings.SplitN(tt.header, " ", 2)
|
||||
if len(tokenParts) == 2 && tokenParts[0] == "Bearer" {
|
||||
if !tt.wantValid {
|
||||
t.Errorf("应该无效但被识别为有效")
|
||||
}
|
||||
if tokenParts[1] != tt.wantToken {
|
||||
t.Errorf("Token = %q, want %q", tokenParts[1], tt.wantToken)
|
||||
}
|
||||
} else {
|
||||
if tt.wantValid {
|
||||
t.Errorf("应该有效但被识别为无效")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
22
internal/middleware/cors.go
Normal file
22
internal/middleware/cors.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
})
|
||||
}
|
||||
134
internal/middleware/cors_test.go
Normal file
134
internal/middleware/cors_test.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TestCORS_Headers 测试CORS中间件设置的响应头
|
||||
func TestCORS_Headers(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证CORS响应头
|
||||
expectedHeaders := map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||||
}
|
||||
|
||||
for header, expectedValue := range expectedHeaders {
|
||||
actualValue := w.Header().Get(header)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Header %s = %q, want %q", header, actualValue, expectedValue)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证Access-Control-Allow-Headers包含必要字段
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
if allowHeaders == "" {
|
||||
t.Error("Access-Control-Allow-Headers 不应为空")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_OPTIONS 测试OPTIONS请求处理
|
||||
func TestCORS_OPTIONS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("OPTIONS", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// OPTIONS请求应该返回204状态码
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Errorf("OPTIONS请求状态码 = %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_AllowMethods 测试允许的HTTP方法
|
||||
func TestCORS_AllowMethods(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE"}
|
||||
for _, method := range methods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(method, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证允许的方法头包含该方法
|
||||
allowMethods := w.Header().Get("Access-Control-Allow-Methods")
|
||||
if allowMethods == "" {
|
||||
t.Error("Access-Control-Allow-Methods 不应为空")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_AllowHeaders 测试允许的请求头
|
||||
func TestCORS_AllowHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
expectedHeaders := []string{"Content-Type", "Authorization", "Accept"}
|
||||
|
||||
for _, expectedHeader := range expectedHeaders {
|
||||
if !contains(allowHeaders, expectedHeader) {
|
||||
t.Errorf("Access-Control-Allow-Headers 应包含 %s", expectedHeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:检查字符串是否包含子字符串(简单实现)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(s) < len(substr) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
39
internal/middleware/logger.go
Normal file
39
internal/middleware/logger.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logger 日志中间件
|
||||
func Logger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录日志
|
||||
latency := time.Since(start)
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
if raw != "" {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
logger.Info("HTTP请求",
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.Int("status", statusCode),
|
||||
zap.String("ip", clientIP),
|
||||
zap.Duration("latency", latency),
|
||||
zap.String("user_agent", c.Request.UserAgent()),
|
||||
)
|
||||
})
|
||||
}
|
||||
185
internal/middleware/logger_test.go
Normal file
185
internal/middleware/logger_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestLogger_Middleware 测试日志中间件基本功能
|
||||
func TestLogger_Middleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
router.ServeHTTP(w, req)
|
||||
duration := time.Since(start)
|
||||
|
||||
// 验证请求成功处理
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// 验证处理时间合理(应该很短)
|
||||
if duration > 1*time.Second {
|
||||
t.Errorf("处理时间过长: %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_RequestInfo 测试日志中间件记录的请求信息
|
||||
func TestLogger_RequestInfo(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{
|
||||
name: "GET请求",
|
||||
method: "GET",
|
||||
path: "/test",
|
||||
},
|
||||
{
|
||||
name: "POST请求",
|
||||
method: "POST",
|
||||
path: "/test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(tt.method, tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证请求被正确处理
|
||||
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
|
||||
t.Errorf("状态码 = %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_QueryParams 测试带查询参数的请求
|
||||
func TestLogger_QueryParams(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test?page=1&size=20", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证请求成功处理
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_StatusCodes 测试不同状态码的日志记录
|
||||
func TestLogger_StatusCodes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
|
||||
router.GET("/success", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
router.GET("/notfound", func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"message": "not found"})
|
||||
})
|
||||
router.GET("/error", func(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"message": "error"})
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "成功请求",
|
||||
path: "/success",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "404请求",
|
||||
path: "/notfound",
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "500请求",
|
||||
path: "/error",
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLogger_Latency 测试延迟计算
|
||||
func TestLogger_Latency(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Logger(logger))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
// 模拟一些处理时间
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
router.ServeHTTP(w, req)
|
||||
duration := time.Since(start)
|
||||
|
||||
// 验证延迟计算合理(应该包含处理时间)
|
||||
if duration < 10*time.Millisecond {
|
||||
t.Errorf("延迟计算可能不正确: %v", duration)
|
||||
}
|
||||
}
|
||||
29
internal/middleware/recovery.go
Normal file
29
internal/middleware/recovery.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Recovery 恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(string); ok {
|
||||
logger.Error("服务器恐慌",
|
||||
zap.String("error", err),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "服务器内部错误",
|
||||
})
|
||||
})
|
||||
}
|
||||
153
internal/middleware/recovery_test.go
Normal file
153
internal/middleware/recovery_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestRecovery_PanicHandling 测试恢复中间件处理panic
|
||||
func TestRecovery_PanicHandling(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
// 创建一个会panic的路由
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// 应该不会导致测试panic,而是返回500错误
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_StringPanic 测试字符串类型的panic
|
||||
func TestRecovery_StringPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("string panic message")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_ErrorPanic 测试error类型的panic
|
||||
func TestRecovery_ErrorPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic(http.ErrBodyReadAfterClose)
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// 应该不会导致测试panic
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_NilPanic 测试nil panic
|
||||
func TestRecovery_NilPanic(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
// 直接panic模拟nil pointer错误,避免linter警告
|
||||
panic("runtime error: invalid memory address or nil pointer dereference")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证返回500状态码
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_ResponseFormat 测试恢复后的响应格式
|
||||
func TestRecovery_ResponseFormat(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/panic", func(c *gin.Context) {
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/panic", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证响应体包含错误信息
|
||||
body := w.Body.String()
|
||||
if body == "" {
|
||||
t.Error("响应体不应为空")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecovery_NormalRequest 测试正常请求不受影响
|
||||
func TestRecovery_NormalRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
router := gin.New()
|
||||
router.Use(Recovery(logger))
|
||||
|
||||
router.GET("/normal", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/normal", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 正常请求应该不受影响
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
45
internal/model/audit_log.go
Normal file
45
internal/model/audit_log.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLog 审计日志模型
|
||||
type AuditLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID *int64 `gorm:"column:user_id;type:bigint;index" json:"user_id,omitempty"`
|
||||
Action string `gorm:"column:action;type:varchar(100);not null;index" json:"action"`
|
||||
ResourceType string `gorm:"column:resource_type;type:varchar(50);not null;index:idx_audit_logs_resource" json:"resource_type"`
|
||||
ResourceID string `gorm:"column:resource_id;type:varchar(50);index:idx_audit_logs_resource" json:"resource_id,omitempty"`
|
||||
OldValues string `gorm:"column:old_values;type:jsonb" json:"old_values,omitempty"` // JSONB 格式
|
||||
NewValues string `gorm:"column:new_values;type:jsonb" json:"new_values,omitempty"` // JSONB 格式
|
||||
IPAddress string `gorm:"column:ip_address;type:inet;not null" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_audit_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
// CasbinRule Casbin 权限规则模型
|
||||
type CasbinRule struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
PType string `gorm:"column:ptype;type:varchar(100);not null;index;uniqueIndex:uk_casbin_rule" json:"ptype"`
|
||||
V0 string `gorm:"column:v0;type:varchar(100);not null;default:'';index;uniqueIndex:uk_casbin_rule" json:"v0"`
|
||||
V1 string `gorm:"column:v1;type:varchar(100);not null;default:'';index;uniqueIndex:uk_casbin_rule" json:"v1"`
|
||||
V2 string `gorm:"column:v2;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v2"`
|
||||
V3 string `gorm:"column:v3;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v3"`
|
||||
V4 string `gorm:"column:v4;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v4"`
|
||||
V5 string `gorm:"column:v5;type:varchar(100);not null;default:'';uniqueIndex:uk_casbin_rule" json:"v5"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (CasbinRule) TableName() string {
|
||||
return "casbin_rule"
|
||||
}
|
||||
63
internal/model/profile.go
Normal file
63
internal/model/profile.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Profile Minecraft 档案模型
|
||||
type Profile struct {
|
||||
UUID string `gorm:"column:uuid;type:varchar(36);primaryKey" json:"uuid"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index" json:"user_id"`
|
||||
Name string `gorm:"column:name;type:varchar(16);not null;uniqueIndex" json:"name"` // Minecraft 角色名
|
||||
SkinID *int64 `gorm:"column:skin_id;type:bigint" json:"skin_id,omitempty"`
|
||||
CapeID *int64 `gorm:"column:cape_id;type:bigint" json:"cape_id,omitempty"`
|
||||
RSAPrivateKey string `gorm:"column:rsa_private_key;type:text;not null" json:"-"` // RSA 私钥不返回给前端
|
||||
IsActive bool `gorm:"column:is_active;not null;default:true;index" json:"is_active"`
|
||||
LastUsedAt *time.Time `gorm:"column:last_used_at;type:timestamp" json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Skin *Texture `gorm:"foreignKey:SkinID" json:"skin,omitempty"`
|
||||
Cape *Texture `gorm:"foreignKey:CapeID" json:"cape,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Profile) TableName() string {
|
||||
return "profiles"
|
||||
}
|
||||
|
||||
// ProfileResponse 档案响应(包含完整的皮肤/披风信息)
|
||||
type ProfileResponse struct {
|
||||
UUID string `json:"uuid"`
|
||||
Name string `json:"name"`
|
||||
Textures ProfileTexturesData `json:"textures"`
|
||||
IsActive bool `json:"is_active"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ProfileTexturesData Minecraft 材质数据结构
|
||||
type ProfileTexturesData struct {
|
||||
Skin *ProfileTexture `json:"SKIN,omitempty"`
|
||||
Cape *ProfileTexture `json:"CAPE,omitempty"`
|
||||
}
|
||||
|
||||
// ProfileTexture 单个材质信息
|
||||
type ProfileTexture struct {
|
||||
URL string `json:"url"`
|
||||
Metadata *ProfileTextureMetadata `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ProfileTextureMetadata 材质元数据
|
||||
type ProfileTextureMetadata struct {
|
||||
Model string `json:"model,omitempty"` // "slim" or "classic"
|
||||
}
|
||||
|
||||
type KeyPair struct {
|
||||
PrivateKey string `json:"private_key" bson:"private_key"`
|
||||
PublicKey string `json:"public_key" bson:"public_key"`
|
||||
Expiration time.Time `json:"expiration" bson:"expiration"`
|
||||
Refresh time.Time `json:"refresh" bson:"refresh"`
|
||||
}
|
||||
85
internal/model/response.go
Normal file
85
internal/model/response.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package model
|
||||
|
||||
// Response 通用API响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"` // 业务状态码
|
||||
Message string `json:"message"` // 响应消息
|
||||
Data interface{} `json:"data,omitempty"` // 响应数据
|
||||
}
|
||||
|
||||
// PaginationResponse 分页响应结构
|
||||
type PaginationResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Page int `json:"page"` // 当前页码
|
||||
PerPage int `json:"per_page"` // 每页数量
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应
|
||||
type ErrorResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Error string `json:"error,omitempty"` // 详细错误信息(仅开发环境)
|
||||
}
|
||||
|
||||
// 常用状态码
|
||||
const (
|
||||
CodeSuccess = 200 // 成功
|
||||
CodeCreated = 201 // 创建成功
|
||||
CodeBadRequest = 400 // 请求参数错误
|
||||
CodeUnauthorized = 401 // 未授权
|
||||
CodeForbidden = 403 // 禁止访问
|
||||
CodeNotFound = 404 // 资源不存在
|
||||
CodeConflict = 409 // 资源冲突
|
||||
CodeServerError = 500 // 服务器错误
|
||||
)
|
||||
|
||||
// 常用响应消息
|
||||
const (
|
||||
MsgSuccess = "操作成功"
|
||||
MsgCreated = "创建成功"
|
||||
MsgBadRequest = "请求参数错误"
|
||||
MsgUnauthorized = "未授权,请先登录"
|
||||
MsgForbidden = "权限不足"
|
||||
MsgNotFound = "资源不存在"
|
||||
MsgConflict = "资源已存在"
|
||||
MsgServerError = "服务器内部错误"
|
||||
MsgInvalidToken = "无效的令牌"
|
||||
MsgTokenExpired = "令牌已过期"
|
||||
MsgInvalidCredentials = "用户名或密码错误"
|
||||
)
|
||||
|
||||
// NewSuccessResponse 创建成功响应
|
||||
func NewSuccessResponse(data interface{}) *Response {
|
||||
return &Response{
|
||||
Code: CodeSuccess,
|
||||
Message: MsgSuccess,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorResponse 创建错误响应
|
||||
func NewErrorResponse(code int, message string, err error) *ErrorResponse {
|
||||
resp := &ErrorResponse{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewPaginationResponse 创建分页响应
|
||||
func NewPaginationResponse(data interface{}, total int64, page, perPage int) *PaginationResponse {
|
||||
return &PaginationResponse{
|
||||
Code: CodeSuccess,
|
||||
Message: MsgSuccess,
|
||||
Data: data,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PerPage: perPage,
|
||||
}
|
||||
}
|
||||
257
internal/model/response_test.go
Normal file
257
internal/model/response_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewSuccessResponse 测试创建成功响应
|
||||
func TestNewSuccessResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
}{
|
||||
{
|
||||
name: "字符串数据",
|
||||
data: "success",
|
||||
},
|
||||
{
|
||||
name: "map数据",
|
||||
data: map[string]string{
|
||||
"id": "1",
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil数据",
|
||||
data: nil,
|
||||
},
|
||||
{
|
||||
name: "数组数据",
|
||||
data: []string{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := NewSuccessResponse(tt.data)
|
||||
if resp == nil {
|
||||
t.Fatal("NewSuccessResponse() 返回nil")
|
||||
}
|
||||
if resp.Code != CodeSuccess {
|
||||
t.Errorf("Code = %d, want %d", resp.Code, CodeSuccess)
|
||||
}
|
||||
if resp.Message != MsgSuccess {
|
||||
t.Errorf("Message = %q, want %q", resp.Message, MsgSuccess)
|
||||
}
|
||||
// 对于可比较类型直接比较,对于不可比较类型只验证不为nil
|
||||
switch v := tt.data.(type) {
|
||||
case string, nil:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
if tt.data == nil && resp.Data != nil {
|
||||
t.Error("Data 应为nil")
|
||||
}
|
||||
case []string:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
default:
|
||||
// 对于map等不可比较类型,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
_ = v
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewErrorResponse 测试创建错误响应
|
||||
func TestNewErrorResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
message string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "带错误信息",
|
||||
code: CodeBadRequest,
|
||||
message: "请求参数错误",
|
||||
err: errors.New("具体错误信息"),
|
||||
},
|
||||
{
|
||||
name: "无错误信息",
|
||||
code: CodeUnauthorized,
|
||||
message: "未授权",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "服务器错误",
|
||||
code: CodeServerError,
|
||||
message: "服务器内部错误",
|
||||
err: errors.New("数据库连接失败"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := NewErrorResponse(tt.code, tt.message, tt.err)
|
||||
if resp == nil {
|
||||
t.Fatal("NewErrorResponse() 返回nil")
|
||||
}
|
||||
if resp.Code != tt.code {
|
||||
t.Errorf("Code = %d, want %d", resp.Code, tt.code)
|
||||
}
|
||||
if resp.Message != tt.message {
|
||||
t.Errorf("Message = %q, want %q", resp.Message, tt.message)
|
||||
}
|
||||
if tt.err != nil {
|
||||
if resp.Error != tt.err.Error() {
|
||||
t.Errorf("Error = %q, want %q", resp.Error, tt.err.Error())
|
||||
}
|
||||
} else {
|
||||
if resp.Error != "" {
|
||||
t.Errorf("Error 应为空,实际为 %q", resp.Error)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewPaginationResponse 测试创建分页响应
|
||||
func TestNewPaginationResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data interface{}
|
||||
total int64
|
||||
page int
|
||||
perPage int
|
||||
}{
|
||||
{
|
||||
name: "正常分页",
|
||||
data: []string{"a", "b", "c"},
|
||||
total: 100,
|
||||
page: 1,
|
||||
perPage: 20,
|
||||
},
|
||||
{
|
||||
name: "空数据",
|
||||
data: []string{},
|
||||
total: 0,
|
||||
page: 1,
|
||||
perPage: 20,
|
||||
},
|
||||
{
|
||||
name: "最后一页",
|
||||
data: []string{"a", "b"},
|
||||
total: 22,
|
||||
page: 3,
|
||||
perPage: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := NewPaginationResponse(tt.data, tt.total, tt.page, tt.perPage)
|
||||
if resp == nil {
|
||||
t.Fatal("NewPaginationResponse() 返回nil")
|
||||
}
|
||||
if resp.Code != CodeSuccess {
|
||||
t.Errorf("Code = %d, want %d", resp.Code, CodeSuccess)
|
||||
}
|
||||
if resp.Message != MsgSuccess {
|
||||
t.Errorf("Message = %q, want %q", resp.Message, MsgSuccess)
|
||||
}
|
||||
// 对于可比较类型直接比较,对于不可比较类型只验证不为nil
|
||||
switch v := tt.data.(type) {
|
||||
case string, nil:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
if tt.data == nil && resp.Data != nil {
|
||||
t.Error("Data 应为nil")
|
||||
}
|
||||
case []string:
|
||||
// 数组不能直接比较,只验证不为nil
|
||||
if resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
default:
|
||||
// 对于map等不可比较类型,只验证不为nil
|
||||
if tt.data != nil && resp.Data == nil {
|
||||
t.Error("Data 不应为nil")
|
||||
}
|
||||
_ = v
|
||||
}
|
||||
if resp.Total != tt.total {
|
||||
t.Errorf("Total = %d, want %d", resp.Total, tt.total)
|
||||
}
|
||||
if resp.Page != tt.page {
|
||||
t.Errorf("Page = %d, want %d", resp.Page, tt.page)
|
||||
}
|
||||
if resp.PerPage != tt.perPage {
|
||||
t.Errorf("PerPage = %d, want %d", resp.PerPage, tt.perPage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResponseConstants 测试响应常量
|
||||
func TestResponseConstants(t *testing.T) {
|
||||
// 测试状态码常量
|
||||
statusCodes := map[string]int{
|
||||
"CodeSuccess": CodeSuccess,
|
||||
"CodeCreated": CodeCreated,
|
||||
"CodeBadRequest": CodeBadRequest,
|
||||
"CodeUnauthorized": CodeUnauthorized,
|
||||
"CodeForbidden": CodeForbidden,
|
||||
"CodeNotFound": CodeNotFound,
|
||||
"CodeConflict": CodeConflict,
|
||||
"CodeServerError": CodeServerError,
|
||||
}
|
||||
|
||||
expectedCodes := map[string]int{
|
||||
"CodeSuccess": 200,
|
||||
"CodeCreated": 201,
|
||||
"CodeBadRequest": 400,
|
||||
"CodeUnauthorized": 401,
|
||||
"CodeForbidden": 403,
|
||||
"CodeNotFound": 404,
|
||||
"CodeConflict": 409,
|
||||
"CodeServerError": 500,
|
||||
}
|
||||
|
||||
for name, code := range statusCodes {
|
||||
expected := expectedCodes[name]
|
||||
if code != expected {
|
||||
t.Errorf("%s = %d, want %d", name, code, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试消息常量不为空
|
||||
messages := []string{
|
||||
MsgSuccess,
|
||||
MsgCreated,
|
||||
MsgBadRequest,
|
||||
MsgUnauthorized,
|
||||
MsgForbidden,
|
||||
MsgNotFound,
|
||||
MsgConflict,
|
||||
MsgServerError,
|
||||
MsgInvalidToken,
|
||||
MsgTokenExpired,
|
||||
MsgInvalidCredentials,
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg == "" {
|
||||
t.Error("响应消息常量不应为空")
|
||||
}
|
||||
}
|
||||
}
|
||||
41
internal/model/system_config.go
Normal file
41
internal/model/system_config.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConfigType 配置类型
|
||||
type ConfigType string
|
||||
|
||||
const (
|
||||
ConfigTypeString ConfigType = "STRING"
|
||||
ConfigTypeInteger ConfigType = "INTEGER"
|
||||
ConfigTypeBoolean ConfigType = "BOOLEAN"
|
||||
ConfigTypeJSON ConfigType = "JSON"
|
||||
)
|
||||
|
||||
// SystemConfig 系统配置模型
|
||||
type SystemConfig struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Key string `gorm:"column:key;type:varchar(100);not null;uniqueIndex" json:"key"`
|
||||
Value string `gorm:"column:value;type:text;not null" json:"value"`
|
||||
Description string `gorm:"column:description;type:varchar(255);not null;default:''" json:"description"`
|
||||
Type ConfigType `gorm:"column:type;type:varchar(50);not null;default:'STRING'" json:"type"` // STRING, INTEGER, BOOLEAN, JSON
|
||||
IsPublic bool `gorm:"column:is_public;not null;default:false;index" json:"is_public"` // 是否可被前端获取
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (SystemConfig) TableName() string {
|
||||
return "system_config"
|
||||
}
|
||||
|
||||
// SystemConfigPublicResponse 公开配置响应
|
||||
type SystemConfigPublicResponse struct {
|
||||
SiteName string `json:"site_name"`
|
||||
SiteDescription string `json:"site_description"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
MaintenanceMode bool `json:"maintenance_mode"`
|
||||
Announcement string `json:"announcement"`
|
||||
}
|
||||
76
internal/model/texture.go
Normal file
76
internal/model/texture.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// TextureType 材质类型
|
||||
type TextureType string
|
||||
|
||||
const (
|
||||
TextureTypeSkin TextureType = "SKIN"
|
||||
TextureTypeCape TextureType = "CAPE"
|
||||
)
|
||||
|
||||
// Texture 材质模型
|
||||
type Texture struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UploaderID int64 `gorm:"column:uploader_id;not null;index" json:"uploader_id"`
|
||||
Name string `gorm:"column:name;type:varchar(100);not null;default:''" json:"name"`
|
||||
Description string `gorm:"column:description;type:text" json:"description,omitempty"`
|
||||
Type TextureType `gorm:"column:type;type:varchar(50);not null" json:"type"` // SKIN, CAPE
|
||||
URL string `gorm:"column:url;type:varchar(255);not null" json:"url"`
|
||||
Hash string `gorm:"column:hash;type:varchar(64);not null;uniqueIndex" json:"hash"` // SHA-256
|
||||
Size int `gorm:"column:size;type:integer;not null;default:0" json:"size"`
|
||||
IsPublic bool `gorm:"column:is_public;not null;default:false;index:idx_textures_public_type_status" json:"is_public"`
|
||||
DownloadCount int `gorm:"column:download_count;type:integer;not null;default:0;index:idx_textures_download_count,sort:desc" json:"download_count"`
|
||||
FavoriteCount int `gorm:"column:favorite_count;type:integer;not null;default:0;index:idx_textures_favorite_count,sort:desc" json:"favorite_count"`
|
||||
IsSlim bool `gorm:"column:is_slim;not null;default:false" json:"is_slim"` // Alex(细) or Steve(粗)
|
||||
Status int16 `gorm:"column:status;type:smallint;not null;default:1;index:idx_textures_public_type_status" json:"status"` // 1:正常, 0:审核中, -1:已删除
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
|
||||
// 关联
|
||||
Uploader *User `gorm:"foreignKey:UploaderID" json:"uploader,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (Texture) TableName() string {
|
||||
return "textures"
|
||||
}
|
||||
|
||||
// UserTextureFavorite 用户材质收藏
|
||||
type UserTextureFavorite struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index;uniqueIndex:uk_user_texture" json:"user_id"`
|
||||
TextureID int64 `gorm:"column:texture_id;not null;index;uniqueIndex:uk_user_texture" json:"texture_id"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Texture *Texture `gorm:"foreignKey:TextureID" json:"texture,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (UserTextureFavorite) TableName() string {
|
||||
return "user_texture_favorites"
|
||||
}
|
||||
|
||||
// TextureDownloadLog 材质下载记录
|
||||
type TextureDownloadLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
TextureID int64 `gorm:"column:texture_id;not null;index" json:"texture_id"`
|
||||
UserID *int64 `gorm:"column:user_id;type:bigint;index" json:"user_id,omitempty"`
|
||||
IPAddress string `gorm:"column:ip_address;type:inet;not null;index" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_download_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
Texture *Texture `gorm:"foreignKey:TextureID" json:"texture,omitempty"`
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (TextureDownloadLog) TableName() string {
|
||||
return "texture_download_logs"
|
||||
}
|
||||
14
internal/model/token.go
Normal file
14
internal/model/token.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
type Token struct {
|
||||
AccessToken string `json:"_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ClientToken string `json:"client_token"`
|
||||
ProfileId string `json:"profile_id"`
|
||||
Usable bool `json:"usable"`
|
||||
IssueDate time.Time `json:"issue_date"`
|
||||
}
|
||||
|
||||
func (Token) TableName() string { return "token" }
|
||||
70
internal/model/user.go
Normal file
70
internal/model/user.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// User 用户模型
|
||||
type User struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
Username string `gorm:"column:username;type:varchar(255);not null;uniqueIndex" json:"username"`
|
||||
Password string `gorm:"column:password;type:varchar(255);not null" json:"-"` // 密码不返回给前端
|
||||
Email string `gorm:"column:email;type:varchar(255);not null;uniqueIndex" json:"email"`
|
||||
Avatar string `gorm:"column:avatar;type:varchar(255);not null;default:''" json:"avatar"`
|
||||
Points int `gorm:"column:points;type:integer;not null;default:0" json:"points"`
|
||||
Role string `gorm:"column:role;type:varchar(50);not null;default:'user'" json:"role"`
|
||||
Status int16 `gorm:"column:status;type:smallint;not null;default:1" json:"status"` // 1:正常, 0:禁用, -1:删除
|
||||
Properties string `gorm:"column:properties;type:jsonb" json:"properties"` // JSON字符串,存储为PostgreSQL的JSONB类型
|
||||
LastLoginAt *time.Time `gorm:"column:last_login_at;type:timestamp" json:"last_login_at,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;type:timestamp;not null;default:CURRENT_TIMESTAMP" json:"updated_at"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (User) TableName() string {
|
||||
return "user"
|
||||
}
|
||||
|
||||
// UserPointLog 用户积分变更记录
|
||||
type UserPointLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index" json:"user_id"`
|
||||
ChangeType string `gorm:"column:change_type;type:varchar(50);not null" json:"change_type"` // EARN, SPEND, ADMIN_ADJUST
|
||||
Amount int `gorm:"column:amount;type:integer;not null" json:"amount"`
|
||||
BalanceBefore int `gorm:"column:balance_before;type:integer;not null" json:"balance_before"`
|
||||
BalanceAfter int `gorm:"column:balance_after;type:integer;not null" json:"balance_after"`
|
||||
Reason string `gorm:"column:reason;type:varchar(255);not null" json:"reason"`
|
||||
ReferenceType string `gorm:"column:reference_type;type:varchar(50)" json:"reference_type,omitempty"`
|
||||
ReferenceID *int64 `gorm:"column:reference_id;type:bigint" json:"reference_id,omitempty"`
|
||||
OperatorID *int64 `gorm:"column:operator_id;type:bigint" json:"operator_id,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_point_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Operator *User `gorm:"foreignKey:OperatorID" json:"operator,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (UserPointLog) TableName() string {
|
||||
return "user_point_logs"
|
||||
}
|
||||
|
||||
// UserLoginLog 用户登录日志
|
||||
type UserLoginLog struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;autoIncrement" json:"id"`
|
||||
UserID int64 `gorm:"column:user_id;not null;index" json:"user_id"`
|
||||
IPAddress string `gorm:"column:ip_address;type:inet;not null;index" json:"ip_address"`
|
||||
UserAgent string `gorm:"column:user_agent;type:text" json:"user_agent,omitempty"`
|
||||
LoginMethod string `gorm:"column:login_method;type:varchar(50);not null;default:'PASSWORD'" json:"login_method"`
|
||||
IsSuccess bool `gorm:"column:is_success;not null;index" json:"is_success"`
|
||||
FailureReason string `gorm:"column:failure_reason;type:varchar(255)" json:"failure_reason,omitempty"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;type:timestamp;not null;default:CURRENT_TIMESTAMP;index:idx_login_logs_created_at,sort:desc" json:"created_at"`
|
||||
|
||||
// 关联
|
||||
User *User `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (UserLoginLog) TableName() string {
|
||||
return "user_login_logs"
|
||||
}
|
||||
48
internal/model/yggdrasil.go
Normal file
48
internal/model/yggdrasil.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 定义随机字符集
|
||||
const passwordChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
// Yggdrasil ygg密码与用户id绑定
|
||||
type Yggdrasil struct {
|
||||
ID int64 `gorm:"column:id;primaryKey;not null" json:"id"`
|
||||
Password string `gorm:"column:password;not null" json:"password"`
|
||||
// 关联 - Yggdrasil的ID引用User的ID,但不自动创建外键约束(避免循环依赖)
|
||||
User *User `gorm:"foreignKey:ID;references:ID;constraint:OnDelete:CASCADE,OnUpdate:CASCADE" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (Yggdrasil) TableName() string { return "Yggdrasil" }
|
||||
|
||||
// AfterCreate User创建后自动同步生成GeneratePassword记录
|
||||
func (u *User) AfterCreate(tx *gorm.DB) error {
|
||||
randomPwd := GenerateRandomPassword(16)
|
||||
|
||||
// 创建GeneratePassword记录
|
||||
gp := Yggdrasil{
|
||||
ID: u.ID, // 关联User的ID
|
||||
Password: randomPwd, // 16位随机密码
|
||||
}
|
||||
|
||||
if err := tx.Create(&gp).Error; err != nil {
|
||||
// 若同步失败,可记录日志或回滚事务(根据业务需求处理)
|
||||
return fmt.Errorf("同步生成密码失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRandomPassword 生成指定长度的随机字符串
|
||||
func GenerateRandomPassword(length int) string {
|
||||
rand.Seed(time.Now().UnixNano()) // 初始化随机数种子
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = passwordChars[rand.Intn(len(passwordChars))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
199
internal/repository/profile_repository.go
Normal file
199
internal/repository/profile_repository.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
func CreateProfile(profile *model.Profile) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(profile).Error
|
||||
}
|
||||
|
||||
// FindProfileByUUID 根据UUID查找档案
|
||||
func FindProfileByUUID(uuid string) (*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profile model.Profile
|
||||
err := db.Where("uuid = ?", uuid).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// FindProfileByName 根据角色名查找档案
|
||||
func FindProfileByName(name string) (*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profile model.Profile
|
||||
err := db.Where("name = ?", name).First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// FindProfilesByUserID 获取用户的所有档案
|
||||
func FindProfilesByUserID(userID int64) ([]*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profiles []*model.Profile
|
||||
err := db.Where("user_id = ?", userID).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Order("created_at DESC").
|
||||
Find(&profiles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
func UpdateProfile(profile *model.Profile) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(profile).Error
|
||||
}
|
||||
|
||||
// UpdateProfileFields 更新指定字段
|
||||
func UpdateProfileFields(uuid string, updates map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
func DeleteProfile(uuid string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
}
|
||||
|
||||
// CountProfilesByUserID 统计用户的档案数量
|
||||
func CountProfilesByUserID(userID int64) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置档案为活跃状态(同时将用户的其他档案设置为非活跃)
|
||||
func SetActiveProfile(uuid string, userID int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 将用户的所有档案设置为非活跃
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将指定档案设置为活跃
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("uuid = ? AND user_id = ?", uuid, userID).
|
||||
Update("is_active", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateProfileLastUsedAt 更新最后使用时间
|
||||
func UpdateProfileLastUsedAt(uuid string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
|
||||
}
|
||||
|
||||
// FindOneProfileByUserID 根据id找一个角色
|
||||
func FindOneProfileByUserID(userID int64) (*model.Profile, error) {
|
||||
profiles, err := FindProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile := profiles[0]
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func GetProfilesByNames(names []string) ([]*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profiles []*model.Profile
|
||||
err := db.Where("name in (?)", names).Find(&profiles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func GetProfileKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
db := database.MustGetDB()
|
||||
// 1. 参数校验(保持原逻辑)
|
||||
if profileId == "" {
|
||||
return nil, errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
// 2. GORM 查询:只查询 key_pair 字段(对应原 mongo 投影)
|
||||
var profile *model.Profile
|
||||
// 条件:id = profileId(PostgreSQL 主键),只选择 key_pair 字段
|
||||
result := db.WithContext(context.Background()).
|
||||
Select("key_pair"). // 只查询需要的字段(投影)
|
||||
Where("id = ?", profileId). // 查询条件(GORM 自动处理占位符,避免 SQL 注入)
|
||||
First(&profile) // 查单条记录
|
||||
|
||||
// 3. 错误处理(适配 GORM 错误类型)
|
||||
if result.Error != nil {
|
||||
// 空结果判断(对应原 mongo.ErrNoDocuments / pgx.ErrNoRows)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("key pair未找到")
|
||||
}
|
||||
// 保持原错误封装格式
|
||||
return nil, fmt.Errorf("获取key pair失败: %w", result.Error)
|
||||
}
|
||||
|
||||
// 4. JSONB 反序列化为 model.KeyPair
|
||||
keyPair := &model.KeyPair{}
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
func UpdateProfileKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
db := database.MustGetDB()
|
||||
// 仅保留最必要的入参校验(避免无效数据库请求)
|
||||
if profileId == "" {
|
||||
return errors.New("profileId 不能为空")
|
||||
}
|
||||
if keyPair == nil {
|
||||
return errors.New("keyPair 不能为 nil")
|
||||
}
|
||||
|
||||
// 事务内执行核心更新(保证原子性,出错自动回滚)
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 核心更新逻辑:按 profileId 匹配,直接更新 key_pair 相关字段
|
||||
result := tx.WithContext(context.Background()).
|
||||
Table("profiles"). // 目标表名(与 PostgreSQL 表一致)
|
||||
Where("id = ?", profileId). // 更新条件:profileId 匹配
|
||||
// 直接映射字段(无需序列化,依赖 GORM 自动字段匹配)
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"private_key": keyPair.PrivateKey, // 数据库 private_key 字段
|
||||
"public_key": keyPair.PublicKey, // 数据库 public_key 字段
|
||||
// 若 key_pair 是单个字段(非拆分),替换为:"key_pair": keyPair
|
||||
})
|
||||
|
||||
// 仅处理数据库层面的致命错误
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新 keyPair 失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
184
internal/repository/profile_repository_test.go
Normal file
184
internal/repository/profile_repository_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileRepository_QueryConditions 测试档案查询条件逻辑
|
||||
func TestProfileRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
userID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的UUID",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
userID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "UUID为空",
|
||||
uuid: "",
|
||||
userID: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
userID: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.uuid != "" && tt.userID > 0
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_SetActiveLogic 测试设置活跃档案的逻辑
|
||||
func TestProfileRepository_SetActiveLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
userID int64
|
||||
otherProfiles int
|
||||
wantAllInactive bool
|
||||
}{
|
||||
{
|
||||
name: "设置一个档案为活跃,其他应该变为非活跃",
|
||||
uuid: "profile-1",
|
||||
userID: 1,
|
||||
otherProfiles: 2,
|
||||
wantAllInactive: true,
|
||||
},
|
||||
{
|
||||
name: "只有一个档案时",
|
||||
uuid: "profile-1",
|
||||
userID: 1,
|
||||
otherProfiles: 0,
|
||||
wantAllInactive: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:设置一个档案为活跃时,应该先将所有档案设为非活跃
|
||||
if !tt.wantAllInactive {
|
||||
t.Error("Setting active profile should first set all profiles to inactive")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_CountLogic 测试统计逻辑
|
||||
func TestProfileRepository_CountLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
wantCount int64
|
||||
}{
|
||||
{
|
||||
name: "有效用户ID",
|
||||
userID: 1,
|
||||
wantCount: 0, // 实际值取决于数据库
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
userID: 0,
|
||||
wantCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证统计逻辑:用户ID应该大于0
|
||||
if tt.userID <= 0 && tt.wantCount != 0 {
|
||||
t.Error("Invalid userID should not count profiles")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_UpdateFieldsLogic 测试更新字段逻辑
|
||||
func TestProfileRepository_UpdateFieldsLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uuid string
|
||||
updates map[string]interface{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的更新",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
updates: map[string]interface{}{
|
||||
"name": "NewName",
|
||||
"skin_id": int64(1),
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "UUID为空",
|
||||
uuid: "",
|
||||
updates: map[string]interface{}{"name": "NewName"},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "更新字段为空",
|
||||
uuid: "123e4567-e89b-12d3-a456-426614174000",
|
||||
updates: map[string]interface{}{},
|
||||
wantValid: true, // 空更新也是有效的,只是不会更新任何字段
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.uuid != "" && tt.updates != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Update fields validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileRepository_FindOneProfileLogic 测试查找单个档案的逻辑
|
||||
func TestProfileRepository_FindOneProfileLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "有档案时返回第一个",
|
||||
profileCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "多个档案时返回第一个",
|
||||
profileCount: 3,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "没有档案时应该错误",
|
||||
profileCount: 0,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:如果没有档案,访问索引0会panic或返回错误
|
||||
hasError := tt.profileCount == 0
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("FindOneProfile logic failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
57
internal/repository/system_config_repository.go
Normal file
57
internal/repository/system_config_repository.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GetSystemConfigByKey 根据键获取配置
|
||||
func GetSystemConfigByKey(key string) (*model.SystemConfig, error) {
|
||||
db := database.MustGetDB()
|
||||
var config model.SystemConfig
|
||||
err := db.Where("key = ?", key).First(&config).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// GetPublicSystemConfigs 获取所有公开配置
|
||||
func GetPublicSystemConfigs() ([]model.SystemConfig, error) {
|
||||
db := database.MustGetDB()
|
||||
var configs []model.SystemConfig
|
||||
err := db.Where("is_public = ?", true).Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// GetAllSystemConfigs 获取所有配置(管理员用)
|
||||
func GetAllSystemConfigs() ([]model.SystemConfig, error) {
|
||||
db := database.MustGetDB()
|
||||
var configs []model.SystemConfig
|
||||
err := db.Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
// UpdateSystemConfig 更新配置
|
||||
func UpdateSystemConfig(config *model.SystemConfig) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(config).Error
|
||||
}
|
||||
|
||||
// UpdateSystemConfigValue 更新配置值
|
||||
func UpdateSystemConfigValue(key, value string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
|
||||
}
|
||||
146
internal/repository/system_config_repository_test.go
Normal file
146
internal/repository/system_config_repository_test.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSystemConfigRepository_QueryConditions 测试系统配置查询条件逻辑
|
||||
func TestSystemConfigRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
isPublic bool
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的配置键",
|
||||
key: "site_name",
|
||||
isPublic: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "配置键为空",
|
||||
key: "",
|
||||
isPublic: true,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "公开配置查询",
|
||||
key: "site_name",
|
||||
isPublic: true,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "私有配置查询",
|
||||
key: "secret_key",
|
||||
isPublic: false,
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.key != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemConfigRepository_PublicConfigLogic 测试公开配置逻辑
|
||||
func TestSystemConfigRepository_PublicConfigLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isPublic bool
|
||||
wantInclude bool
|
||||
}{
|
||||
{
|
||||
name: "只获取公开配置",
|
||||
isPublic: true,
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "私有配置不应包含",
|
||||
isPublic: false,
|
||||
wantInclude: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:GetPublicSystemConfigs应该只返回is_public=true的配置
|
||||
if tt.isPublic != tt.wantInclude {
|
||||
t.Errorf("Public config logic failed: isPublic=%v, wantInclude=%v", tt.isPublic, tt.wantInclude)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemConfigRepository_UpdateValueLogic 测试更新配置值逻辑
|
||||
func TestSystemConfigRepository_UpdateValueLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的键值对",
|
||||
key: "site_name",
|
||||
value: "CarrotSkin",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "键为空",
|
||||
key: "",
|
||||
value: "CarrotSkin",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "值为空(可能有效)",
|
||||
key: "site_name",
|
||||
value: "",
|
||||
wantValid: true, // 空值也可能是有效的
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.key != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Update value validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSystemConfigRepository_ErrorHandling 测试错误处理逻辑
|
||||
func TestSystemConfigRepository_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isNotFound bool
|
||||
wantNilConfig bool
|
||||
}{
|
||||
{
|
||||
name: "记录未找到应该返回nil配置",
|
||||
isNotFound: true,
|
||||
wantNilConfig: true,
|
||||
},
|
||||
{
|
||||
name: "找到记录应该返回配置",
|
||||
isNotFound: false,
|
||||
wantNilConfig: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证错误处理逻辑:如果是RecordNotFound,返回nil配置
|
||||
if tt.isNotFound != tt.wantNilConfig {
|
||||
t.Errorf("Error handling logic failed: isNotFound=%v, wantNilConfig=%v", tt.isNotFound, tt.wantNilConfig)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
231
internal/repository/texture_repository.go
Normal file
231
internal/repository/texture_repository.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateTexture 创建材质
|
||||
func CreateTexture(texture *model.Texture) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(texture).Error
|
||||
}
|
||||
|
||||
// FindTextureByID 根据ID查找材质
|
||||
func FindTextureByID(id int64) (*model.Texture, error) {
|
||||
db := database.MustGetDB()
|
||||
var texture model.Texture
|
||||
err := db.Preload("Uploader").First(&texture, id).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &texture, nil
|
||||
}
|
||||
|
||||
// FindTextureByHash 根据Hash查找材质
|
||||
func FindTextureByHash(hash string) (*model.Texture, error) {
|
||||
db := database.MustGetDB()
|
||||
var texture model.Texture
|
||||
err := db.Where("hash = ?", hash).First(&texture).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &texture, nil
|
||||
}
|
||||
|
||||
// FindTexturesByUploaderID 根据上传者ID查找材质列表
|
||||
func FindTexturesByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
func SearchTextures(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := db.Model(&model.Texture{}).Where("status = 1")
|
||||
|
||||
// 公开筛选
|
||||
if publicOnly {
|
||||
query = query.Where("is_public = ?", true)
|
||||
}
|
||||
|
||||
// 类型筛选
|
||||
if textureType != "" {
|
||||
query = query.Where("type = ?", textureType)
|
||||
}
|
||||
|
||||
// 关键词搜索
|
||||
if keyword != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
func UpdateTexture(texture *model.Texture) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(texture).Error
|
||||
}
|
||||
|
||||
// UpdateTextureFields 更新材质指定字段
|
||||
func UpdateTextureFields(id int64, fields map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质(软删除)
|
||||
func DeleteTexture(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
// IncrementTextureDownloadCount 增加下载次数
|
||||
func IncrementTextureDownloadCount(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
// IncrementTextureFavoriteCount 增加收藏次数
|
||||
func IncrementTextureFavoriteCount(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
// DecrementTextureFavoriteCount 减少收藏次数
|
||||
func DecrementTextureFavoriteCount(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
|
||||
}
|
||||
|
||||
// CreateTextureDownloadLog 创建下载日志
|
||||
func CreateTextureDownloadLog(log *model.TextureDownloadLog) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(log).Error
|
||||
}
|
||||
|
||||
// IsTextureFavorited 检查是否已收藏
|
||||
func IsTextureFavorited(userID, textureID int64) (bool, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.UserTextureFavorite{}).
|
||||
Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// AddTextureFavorite 添加收藏
|
||||
func AddTextureFavorite(userID, textureID int64) error {
|
||||
db := database.MustGetDB()
|
||||
favorite := &model.UserTextureFavorite{
|
||||
UserID: userID,
|
||||
TextureID: textureID,
|
||||
}
|
||||
return db.Create(favorite).Error
|
||||
}
|
||||
|
||||
// RemoveTextureFavorite 取消收藏
|
||||
func RemoveTextureFavorite(userID, textureID int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Delete(&model.UserTextureFavorite{}).Error
|
||||
}
|
||||
|
||||
// GetUserTextureFavorites 获取用户收藏的材质列表
|
||||
func GetUserTextureFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
// 子查询获取收藏的材质ID
|
||||
subQuery := db.Model(&model.UserTextureFavorite{}).
|
||||
Select("texture_id").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
query := db.Model(&model.Texture{}).
|
||||
Where("id IN (?) AND status = 1", subQuery)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
err := query.Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
// CountTexturesByUploaderID 统计用户上传的材质数量
|
||||
func CountTexturesByUploaderID(uploaderID int64) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.Texture{}).
|
||||
Where("uploader_id = ? AND status != -1", uploaderID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
89
internal/repository/token_repository.go
Normal file
89
internal/repository/token_repository.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
)
|
||||
|
||||
func CreateToken(token *model.Token) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(token).Error
|
||||
}
|
||||
|
||||
func GetTokensByUserId(userId int64) ([]*model.Token, error) {
|
||||
db := database.MustGetDB()
|
||||
tokens := make([]*model.Token, 0)
|
||||
err := db.Where("user_id = ?", userId).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func BatchDeleteTokens(tokensToDelete []string) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
if len(tokensToDelete) == 0 {
|
||||
return 0, nil // 无需要删除的令牌,直接返回
|
||||
}
|
||||
result := db.Where("access_token IN ?", tokensToDelete).Delete(&model.Token{})
|
||||
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func FindTokenByID(accessToken string) (*model.Token, error) {
|
||||
db := database.MustGetDB()
|
||||
var tokens []*model.Token
|
||||
err := db.Where("_id = ?", accessToken).Find(&tokens).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tokens[0], nil
|
||||
}
|
||||
|
||||
func GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
db := database.MustGetDB()
|
||||
var token model.Token
|
||||
err := db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
|
||||
func GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var token model.Token
|
||||
err := db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return token.UserID, nil
|
||||
}
|
||||
|
||||
func GetTokenByAccessToken(accessToken string) (*model.Token, error) {
|
||||
db := database.MustGetDB()
|
||||
var token model.Token
|
||||
err := db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func DeleteTokenByAccessToken(accessToken string) error {
|
||||
db := database.MustGetDB()
|
||||
err := db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteTokenByUserId(userId int64) error {
|
||||
db := database.MustGetDB()
|
||||
err := db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
123
internal/repository/token_repository_test.go
Normal file
123
internal/repository/token_repository_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTokenRepository_BatchDeleteLogic 测试批量删除逻辑
|
||||
func TestTokenRepository_BatchDeleteLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokensToDelete []string
|
||||
wantCount int64
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "有效的token列表",
|
||||
tokensToDelete: []string{"token1", "token2", "token3"},
|
||||
wantCount: 3,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "空列表应该返回0",
|
||||
tokensToDelete: []string{},
|
||||
wantCount: 0,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "单个token",
|
||||
tokensToDelete: []string{"token1"},
|
||||
wantCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证批量删除逻辑:空列表应该直接返回0
|
||||
if len(tt.tokensToDelete) == 0 {
|
||||
if tt.wantCount != 0 {
|
||||
t.Errorf("Empty list should return count 0, got %d", tt.wantCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRepository_QueryConditions 测试token查询条件逻辑
|
||||
func TestTokenRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
userID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的access token",
|
||||
accessToken: "valid-token-123",
|
||||
userID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "access token为空",
|
||||
accessToken: "",
|
||||
userID: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0",
|
||||
accessToken: "valid-token-123",
|
||||
userID: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.accessToken != "" && tt.userID > 0
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenRepository_FindTokenByIDLogic 测试根据ID查找token的逻辑
|
||||
func TestTokenRepository_FindTokenByIDLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
resultCount int
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "找到token",
|
||||
accessToken: "token-123",
|
||||
resultCount: 1,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "未找到token",
|
||||
accessToken: "token-123",
|
||||
resultCount: 0,
|
||||
wantError: true, // 访问索引0会panic
|
||||
},
|
||||
{
|
||||
name: "找到多个token(异常情况)",
|
||||
accessToken: "token-123",
|
||||
resultCount: 2,
|
||||
wantError: false, // 返回第一个
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:如果结果为空,访问索引0会出错
|
||||
hasError := tt.resultCount == 0
|
||||
if hasError != tt.wantError {
|
||||
t.Errorf("FindTokenByID logic failed: got error=%v, want error=%v", hasError, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
136
internal/repository/user_repository.go
Normal file
136
internal/repository/user_repository.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateUser 创建用户
|
||||
func CreateUser(user *model.User) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(user).Error
|
||||
}
|
||||
|
||||
// FindUserByID 根据ID查找用户
|
||||
func FindUserByID(id int64) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.Where("id = ? AND status != -1", id).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindUserByUsername 根据用户名查找用户
|
||||
func FindUserByUsername(username string) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.Where("username = ? AND status != -1", username).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindUserByEmail 根据邮箱查找用户
|
||||
func FindUserByEmail(email string) (*model.User, error) {
|
||||
db := database.MustGetDB()
|
||||
var user model.User
|
||||
err := db.Where("email = ? AND status != -1", email).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUser 更新用户
|
||||
func UpdateUser(user *model.User) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(user).Error
|
||||
}
|
||||
|
||||
// UpdateUserFields 更新指定字段
|
||||
func UpdateUserFields(id int64, fields map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
// DeleteUser 软删除用户
|
||||
func DeleteUser(id int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
// CreateLoginLog 创建登录日志
|
||||
func CreateLoginLog(log *model.UserLoginLog) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(log).Error
|
||||
}
|
||||
|
||||
// CreatePointLog 创建积分日志
|
||||
func CreatePointLog(log *model.UserPointLog) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(log).Error
|
||||
}
|
||||
|
||||
// UpdateUserPoints 更新用户积分(事务)
|
||||
func UpdateUserPoints(userID int64, amount int, changeType, reason string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 获取当前用户积分
|
||||
var user model.User
|
||||
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
balanceBefore := user.Points
|
||||
balanceAfter := balanceBefore + amount
|
||||
|
||||
// 检查积分是否足够
|
||||
if balanceAfter < 0 {
|
||||
return errors.New("积分不足")
|
||||
}
|
||||
|
||||
// 更新用户积分
|
||||
if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建积分日志
|
||||
log := &model.UserPointLog{
|
||||
UserID: userID,
|
||||
ChangeType: changeType,
|
||||
Amount: amount,
|
||||
BalanceBefore: balanceBefore,
|
||||
BalanceAfter: balanceAfter,
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
return tx.Create(log).Error
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAvatar 更新用户头像
|
||||
func UpdateUserAvatar(userID int64, avatarURL string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", userID).Update("avatar", avatarURL).Error
|
||||
}
|
||||
|
||||
// UpdateUserEmail 更新用户邮箱
|
||||
func UpdateUserEmail(userID int64, email string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.User{}).Where("id = ?", userID).Update("email", email).Error
|
||||
}
|
||||
155
internal/repository/user_repository_test.go
Normal file
155
internal/repository/user_repository_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestUserRepository_QueryConditions 测试用户查询条件逻辑
|
||||
func TestUserRepository_QueryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id int64
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID和状态",
|
||||
id: 1,
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0时无效",
|
||||
id: 0,
|
||||
status: 1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(已删除)应该被排除",
|
||||
id: 1,
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为0(禁用)可能有效",
|
||||
id: 1,
|
||||
status: 0,
|
||||
wantValid: true, // 查询条件中只排除-1
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试查询条件逻辑:status != -1
|
||||
isValid := tt.id > 0 && tt.status != -1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Query condition validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_DeleteLogic 测试软删除逻辑
|
||||
func TestUserRepository_DeleteLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oldStatus int16
|
||||
newStatus int16
|
||||
}{
|
||||
{
|
||||
name: "软删除应该将状态设置为-1",
|
||||
oldStatus: 1,
|
||||
newStatus: -1,
|
||||
},
|
||||
{
|
||||
name: "从禁用状态删除",
|
||||
oldStatus: 0,
|
||||
newStatus: -1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证软删除逻辑:状态应该变为-1
|
||||
if tt.newStatus != -1 {
|
||||
t.Errorf("Delete should set status to -1, got %d", tt.newStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_UpdateFieldsLogic 测试更新字段逻辑
|
||||
func TestUserRepository_UpdateFieldsLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields map[string]interface{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的更新字段",
|
||||
fields: map[string]interface{}{
|
||||
"email": "new@example.com",
|
||||
"avatar": "https://example.com/avatar.png",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "空字段映射",
|
||||
fields: map[string]interface{}{},
|
||||
wantValid: true, // 空映射也是有效的,只是不会更新任何字段
|
||||
},
|
||||
{
|
||||
name: "包含nil值的字段",
|
||||
fields: map[string]interface{}{
|
||||
"email": "new@example.com",
|
||||
"avatar": nil,
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证字段映射逻辑
|
||||
isValid := tt.fields != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Update fields validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_ErrorHandling 测试错误处理逻辑
|
||||
func TestUserRepository_ErrorHandling(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
isNotFound bool
|
||||
wantNilUser bool
|
||||
}{
|
||||
{
|
||||
name: "记录未找到应该返回nil用户",
|
||||
err: nil, // 模拟gorm.ErrRecordNotFound
|
||||
isNotFound: true,
|
||||
wantNilUser: true,
|
||||
},
|
||||
{
|
||||
name: "其他错误应该返回错误",
|
||||
err: nil,
|
||||
isNotFound: false,
|
||||
wantNilUser: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试错误处理逻辑:如果是RecordNotFound,返回nil用户;否则返回错误
|
||||
if tt.isNotFound {
|
||||
if !tt.wantNilUser {
|
||||
t.Error("RecordNotFound should return nil user")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
16
internal/repository/yggdrasil_repository.go
Normal file
16
internal/repository/yggdrasil_repository.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
)
|
||||
|
||||
func GetYggdrasilPasswordById(Id int64) (string, error) {
|
||||
db := database.MustGetDB()
|
||||
var yggdrasil model.Yggdrasil
|
||||
err := db.Where("id = ?", Id).First(&yggdrasil).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return yggdrasil.Password, nil
|
||||
}
|
||||
165
internal/service/captcha_service.go
Normal file
165
internal/service/captcha_service.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/wenlng/go-captcha-assets/resources/imagesv2"
|
||||
"github.com/wenlng/go-captcha-assets/resources/tiles"
|
||||
"github.com/wenlng/go-captcha/v2/slide"
|
||||
)
|
||||
|
||||
var (
|
||||
slideTileCapt slide.Captcha
|
||||
cfg *config.Config
|
||||
)
|
||||
|
||||
// 常量定义(业务相关配置,与Redis连接配置分离)
|
||||
const (
|
||||
redisKeyPrefix = "captcha:" // Redis键前缀(便于区分业务)
|
||||
paddingValue = 3 // 验证允许的误差像素(±3px)
|
||||
)
|
||||
|
||||
// Init 验证码图初始化
|
||||
func init() {
|
||||
cfg, _ = config.Load()
|
||||
// 从默认仓库中获取主图
|
||||
builder := slide.NewBuilder()
|
||||
bgImage, err := imagesv2.GetImages()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
// 滑块形状获取
|
||||
graphs := getSlideTileGraphArr()
|
||||
|
||||
builder.SetResources(
|
||||
slide.WithGraphImages(graphs),
|
||||
slide.WithBackgrounds(bgImage),
|
||||
)
|
||||
slideTileCapt = builder.Make()
|
||||
if slideTileCapt == nil {
|
||||
log.Fatalln("验证码实例初始化失败")
|
||||
}
|
||||
}
|
||||
|
||||
// getSlideTileGraphArr 滑块选择
|
||||
func getSlideTileGraphArr() []*slide.GraphImage {
|
||||
graphs, err := tiles.GetTiles()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
var newGraphs = make([]*slide.GraphImage, 0, len(graphs))
|
||||
for i := 0; i < len(graphs); i++ {
|
||||
graph := graphs[i]
|
||||
newGraphs = append(newGraphs, &slide.GraphImage{
|
||||
OverlayImage: graph.OverlayImage,
|
||||
MaskImage: graph.MaskImage,
|
||||
ShadowImage: graph.ShadowImage,
|
||||
})
|
||||
}
|
||||
return newGraphs
|
||||
}
|
||||
|
||||
// RedisData 存储到Redis的验证信息(仅包含校验必需字段)
|
||||
type RedisData struct {
|
||||
Tx int `json:"tx"` // 滑块目标X坐标
|
||||
Ty int `json:"ty"` // 滑块目标Y坐标
|
||||
}
|
||||
|
||||
// GenerateCaptchaData 提取生成验证码的相关信息
|
||||
func GenerateCaptchaData(ctx context.Context, redisClient *redis.Client) (string, string, string, int, error) {
|
||||
// 生成uuid作为验证码进程唯一标识
|
||||
captchaID := uuid.NewString()
|
||||
if captchaID == "" {
|
||||
return "", "", "", 0, errors.New("生成验证码唯一标识失败")
|
||||
}
|
||||
|
||||
captData, err := slideTileCapt.Generate()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
blockData := captData.GetData()
|
||||
if blockData == nil {
|
||||
return "", "", "", 0, errors.New("获取验证码数据失败")
|
||||
}
|
||||
block, _ := json.Marshal(blockData)
|
||||
var blockMap map[string]interface{}
|
||||
|
||||
if err := json.Unmarshal(block, &blockMap); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("反序列化为map失败: %w", err)
|
||||
}
|
||||
// 提取x和y并转换为int类型
|
||||
tx, ok := blockMap["x"].(float64)
|
||||
if !ok {
|
||||
return "", "", "", 0, errors.New("无法将x转换为float64")
|
||||
}
|
||||
var x = int(tx)
|
||||
ty, ok := blockMap["y"].(float64)
|
||||
if !ok {
|
||||
return "", "", "", 0, errors.New("无法将y转换为float64")
|
||||
}
|
||||
var y = int(ty)
|
||||
var mBase64, tBase64 string
|
||||
mBase64, err = captData.GetMasterImage().ToBase64()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("主图转换为base64失败: %w", err)
|
||||
}
|
||||
tBase64, err = captData.GetTileImage().ToBase64()
|
||||
if err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("滑块图转换为base64失败: %w", err)
|
||||
}
|
||||
redisData := RedisData{
|
||||
Tx: x,
|
||||
Ty: y,
|
||||
}
|
||||
redisDataJSON, _ := json.Marshal(redisData)
|
||||
redisKey := redisKeyPrefix + captchaID
|
||||
expireTime := 300 * time.Second
|
||||
|
||||
// 使用注入的Redis客户端
|
||||
if err := redisClient.Set(
|
||||
ctx,
|
||||
redisKey,
|
||||
redisDataJSON,
|
||||
expireTime,
|
||||
); err != nil {
|
||||
return "", "", "", 0, fmt.Errorf("存储验证码到Redis失败: %w", err)
|
||||
}
|
||||
return mBase64, tBase64, captchaID, y - 10, nil
|
||||
}
|
||||
|
||||
// VerifyCaptchaData 验证用户验证码
|
||||
func VerifyCaptchaData(ctx context.Context, redisClient *redis.Client, dx int, id string) (bool, error) {
|
||||
redisKey := redisKeyPrefix + id
|
||||
|
||||
// 从Redis获取验证信息,使用注入的客户端
|
||||
dataJSON, err := redisClient.Get(ctx, redisKey)
|
||||
if err != nil {
|
||||
if redisClient.Nil(err) { // 使用封装客户端的Nil错误
|
||||
return false, errors.New("验证码已过期或无效")
|
||||
}
|
||||
return false, fmt.Errorf("Redis查询失败: %w", err)
|
||||
}
|
||||
var redisData RedisData
|
||||
if err := json.Unmarshal([]byte(dataJSON), &redisData); err != nil {
|
||||
return false, fmt.Errorf("解析Redis数据失败: %w", err)
|
||||
}
|
||||
tx := redisData.Tx
|
||||
ty := redisData.Ty
|
||||
ok := slide.Validate(dx, ty, tx, ty, paddingValue)
|
||||
|
||||
// 验证后立即删除Redis记录(防止重复使用)
|
||||
if ok {
|
||||
if err := redisClient.Del(ctx, redisKey); err != nil {
|
||||
// 记录警告但不影响验证结果
|
||||
log.Printf("删除验证码Redis记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
174
internal/service/captcha_service_test.go
Normal file
174
internal/service/captcha_service_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCaptchaService_Constants 测试验证码服务常量
|
||||
func TestCaptchaService_Constants(t *testing.T) {
|
||||
if redisKeyPrefix != "captcha:" {
|
||||
t.Errorf("redisKeyPrefix = %s, want 'captcha:'", redisKeyPrefix)
|
||||
}
|
||||
|
||||
if paddingValue != 3 {
|
||||
t.Errorf("paddingValue = %d, want 3", paddingValue)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisData_Structure 测试RedisData结构
|
||||
func TestRedisData_Structure(t *testing.T) {
|
||||
data := RedisData{
|
||||
Tx: 100,
|
||||
Ty: 200,
|
||||
}
|
||||
|
||||
if data.Tx != 100 {
|
||||
t.Errorf("RedisData.Tx = %d, want 100", data.Tx)
|
||||
}
|
||||
|
||||
if data.Ty != 200 {
|
||||
t.Errorf("RedisData.Ty = %d, want 200", data.Ty)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateCaptchaData_Logic 测试生成验证码的逻辑部分
|
||||
func TestGenerateCaptchaData_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
captchaID string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "有效的captchaID",
|
||||
captchaID: "test-uuid-123",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空的captchaID应该失败",
|
||||
captchaID: "",
|
||||
wantErr: true,
|
||||
errContains: "生成验证码唯一标识失败",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试UUID验证逻辑
|
||||
if tt.captchaID == "" {
|
||||
if !tt.wantErr {
|
||||
t.Error("空captchaID应该返回错误")
|
||||
}
|
||||
} else {
|
||||
if tt.wantErr {
|
||||
t.Error("非空captchaID不应该返回错误")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyCaptchaData_Logic 测试验证验证码的逻辑部分
|
||||
func TestVerifyCaptchaData_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dx int
|
||||
tx int
|
||||
ty int
|
||||
padding int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "精确匹配",
|
||||
dx: 100,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "在误差范围内(+3)",
|
||||
dx: 103,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "在误差范围内(-3)",
|
||||
dx: 97,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "超出误差范围(+4)",
|
||||
dx: 104,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "超出误差范围(-4)",
|
||||
dx: 96,
|
||||
tx: 100,
|
||||
ty: 200,
|
||||
padding: 3,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 验证逻辑:dx应该在[tx-padding, tx+padding]范围内
|
||||
diff := tt.dx - tt.tx
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
isValid := diff <= tt.padding
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v (dx=%d, tx=%d, padding=%d)", isValid, tt.wantValid, tt.dx, tt.tx, tt.padding)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerifyCaptchaData_RedisKey 测试Redis键生成逻辑
|
||||
func TestVerifyCaptchaData_RedisKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "生成正确的Redis键",
|
||||
id: "test-id-123",
|
||||
expected: "captcha:test-id-123",
|
||||
},
|
||||
{
|
||||
name: "空ID",
|
||||
id: "",
|
||||
expected: "captcha:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
redisKey := redisKeyPrefix + tt.id
|
||||
if redisKey != tt.expected {
|
||||
t.Errorf("Redis key = %s, want %s", redisKey, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateCaptchaData_ExpireTime 测试过期时间
|
||||
func TestGenerateCaptchaData_ExpireTime(t *testing.T) {
|
||||
expectedExpireTime := 300 * time.Second
|
||||
if expectedExpireTime != 5*time.Minute {
|
||||
t.Errorf("Expire time should be 5 minutes")
|
||||
}
|
||||
}
|
||||
13
internal/service/common.go
Normal file
13
internal/service/common.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
jsoniter "github.com/json-iterator/go"
|
||||
)
|
||||
|
||||
// 统一的json变量,用于整个service包
|
||||
var json = jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
|
||||
// DefaultTimeout 默认超时时间
|
||||
const DefaultTimeout = 5 * time.Second
|
||||
48
internal/service/common_test.go
Normal file
48
internal/service/common_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCommon_Constants 测试common包的常量
|
||||
func TestCommon_Constants(t *testing.T) {
|
||||
if DefaultTimeout != 5*time.Second {
|
||||
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommon_JSON 测试JSON变量
|
||||
func TestCommon_JSON(t *testing.T) {
|
||||
// 验证json变量不为nil
|
||||
if json == nil {
|
||||
t.Error("json 变量不应为nil")
|
||||
}
|
||||
|
||||
// 测试JSON序列化
|
||||
testData := map[string]interface{}{
|
||||
"name": "test",
|
||||
"age": 25,
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() 失败: %v", err)
|
||||
}
|
||||
|
||||
if len(bytes) == 0 {
|
||||
t.Error("json.Marshal() 返回的字节不应为空")
|
||||
}
|
||||
|
||||
// 测试JSON反序列化
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(bytes, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() 失败: %v", err)
|
||||
}
|
||||
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("反序列化结果 name = %v, want 'test'", result["name"])
|
||||
}
|
||||
}
|
||||
|
||||
252
internal/service/profile_service.go
Normal file
252
internal/service/profile_service.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateProfile 创建档案
|
||||
func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, error) {
|
||||
// 1. 验证用户存在
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("用户不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||||
}
|
||||
|
||||
if user.Status != 1 {
|
||||
return nil, fmt.Errorf("用户状态异常")
|
||||
}
|
||||
|
||||
// 2. 检查角色名是否已存在
|
||||
existingName, err := repository.FindProfileByName(name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, fmt.Errorf("角色名已被使用")
|
||||
}
|
||||
|
||||
// 3. 生成UUID
|
||||
profileUUID := uuid.New().String()
|
||||
|
||||
// 4. 生成RSA密钥对
|
||||
privateKey, err := generateRSAPrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA密钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 创建档案
|
||||
profile := &model.Profile{
|
||||
UUID: profileUUID,
|
||||
UserID: userID,
|
||||
Name: name,
|
||||
RSAPrivateKey: privateKey,
|
||||
IsActive: true, // 新创建的档案默认为活跃状态
|
||||
}
|
||||
|
||||
if err := repository.CreateProfile(profile); err != nil {
|
||||
return nil, fmt.Errorf("创建档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 6. 将用户的其他档案设置为非活跃
|
||||
if err := repository.SetActiveProfile(profileUUID, userID); err != nil {
|
||||
return nil, fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// GetProfileByUUID 获取档案详情
|
||||
func GetProfileByUUID(db *gorm.DB, uuid string) (*model.Profile, error) {
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("档案不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// GetUserProfiles 获取用户的所有档案
|
||||
func GetUserProfiles(db *gorm.DB, userID int64) ([]*model.Profile, error) {
|
||||
profiles, err := repository.FindProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("档案不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return nil, fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 检查角色名是否重复
|
||||
if name != nil && *name != profile.Name {
|
||||
existingName, err := repository.FindProfileByName(*name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
if existingName != nil {
|
||||
return nil, fmt.Errorf("角色名已被使用")
|
||||
}
|
||||
profile.Name = *name
|
||||
}
|
||||
|
||||
// 4. 更新皮肤和披风
|
||||
if skinID != nil {
|
||||
profile.SkinID = skinID
|
||||
}
|
||||
if capeID != nil {
|
||||
profile.CapeID = capeID
|
||||
}
|
||||
|
||||
// 5. 保存更新
|
||||
if err := repository.UpdateProfile(profile); err != nil {
|
||||
return nil, fmt.Errorf("更新档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 6. 重新加载关联数据
|
||||
return repository.FindProfileByUUID(uuid)
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
func DeleteProfile(db *gorm.DB, uuid string, userID int64) error {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("档案不存在")
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 删除档案
|
||||
if err := repository.DeleteProfile(uuid); err != nil {
|
||||
return fmt.Errorf("删除档案失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetActiveProfile 设置活跃档案
|
||||
func SetActiveProfile(db *gorm.DB, uuid string, userID int64) error {
|
||||
// 1. 查询档案
|
||||
profile, err := repository.FindProfileByUUID(uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("档案不存在")
|
||||
}
|
||||
return fmt.Errorf("查询档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 验证权限
|
||||
if profile.UserID != userID {
|
||||
return fmt.Errorf("无权操作此档案")
|
||||
}
|
||||
|
||||
// 3. 设置活跃状态
|
||||
if err := repository.SetActiveProfile(uuid, userID); err != nil {
|
||||
return fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 更新最后使用时间
|
||||
if err := repository.UpdateProfileLastUsedAt(uuid); err != nil {
|
||||
return fmt.Errorf("更新使用时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckProfileLimit 检查用户档案数量限制
|
||||
func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error {
|
||||
count, err := repository.CountProfilesByUserID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询档案数量失败: %w", err)
|
||||
}
|
||||
|
||||
if int(count) >= maxProfiles {
|
||||
return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateRSAPrivateKey 生成RSA-2048私钥(PEM格式)
|
||||
func generateRSAPrivateKey() (string, error) {
|
||||
// 生成2048位RSA密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将私钥编码为PEM格式
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privateKeyBytes,
|
||||
})
|
||||
|
||||
return string(privateKeyPEM), nil
|
||||
}
|
||||
|
||||
func ValidateProfileByUserID(db *gorm.DB, userId int64, UUID string) (bool, error) {
|
||||
if userId == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := repository.FindProfileByUUID(UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
}
|
||||
return false, fmt.Errorf("验证配置文件失败: %w", err)
|
||||
}
|
||||
return profile.UserID == userId, nil
|
||||
}
|
||||
|
||||
func GetProfilesDataByNames(db *gorm.DB, names []string) ([]*model.Profile, error) {
|
||||
profiles, err := repository.GetProfilesByNames(names)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// GetProfileKeyPair 从 PostgreSQL 获取密钥对(GORM 实现,无手动 SQL)
|
||||
func GetProfileKeyPair(db *gorm.DB, profileId string) (*model.KeyPair, error) {
|
||||
keyPair, err := repository.GetProfileKeyPair(profileId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
return keyPair, nil
|
||||
}
|
||||
406
internal/service/profile_service_test.go
Normal file
406
internal/service/profile_service_test.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestProfileService_Validation 测试Profile服务验证逻辑
|
||||
func TestProfileService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户ID和角色名",
|
||||
userID: 1,
|
||||
profileName: "TestProfile",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID为0时无效",
|
||||
userID: 0,
|
||||
profileName: "TestProfile",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "角色名为空时无效",
|
||||
userID: 1,
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0 && tt.profileName != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileService_StatusValidation 测试用户状态验证
|
||||
func TestProfileService_StatusValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "状态为1(正常)时有效",
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "状态为0(禁用)时无效",
|
||||
status: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(删除)时无效",
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.status == 1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProfileService_IsActiveDefault 测试Profile默认活跃状态
|
||||
func TestProfileService_IsActiveDefault(t *testing.T) {
|
||||
// 新创建的档案默认为活跃状态
|
||||
isActive := true
|
||||
if !isActive {
|
||||
t.Error("新创建的Profile应该默认为活跃状态")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateProfile_PermissionCheck 测试更新Profile的权限检查逻辑
|
||||
func TestUpdateProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许操作",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝操作",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateProfile_NameValidation 测试更新Profile时名称验证逻辑
|
||||
func TestUpdateProfile_NameValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
currentName string
|
||||
newName *string
|
||||
shouldCheck bool
|
||||
}{
|
||||
{
|
||||
name: "名称未改变,不检查",
|
||||
currentName: "TestProfile",
|
||||
newName: stringPtr("TestProfile"),
|
||||
shouldCheck: false,
|
||||
},
|
||||
{
|
||||
name: "名称改变,需要检查",
|
||||
currentName: "TestProfile",
|
||||
newName: stringPtr("NewProfile"),
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "名称为nil,不检查",
|
||||
currentName: "TestProfile",
|
||||
newName: nil,
|
||||
shouldCheck: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCheck := tt.newName != nil && *tt.newName != tt.currentName
|
||||
if shouldCheck != tt.shouldCheck {
|
||||
t.Errorf("Name validation check failed: got %v, want %v", shouldCheck, tt.shouldCheck)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteProfile_PermissionCheck 测试删除Profile的权限检查
|
||||
func TestDeleteProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许删除",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝删除",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetActiveProfile_PermissionCheck 测试设置活跃Profile的权限检查
|
||||
func TestSetActiveProfile_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配,允许设置",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配,拒绝设置",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.profileUserID != tt.requestUserID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckProfileLimit_Logic 测试Profile数量限制检查逻辑
|
||||
func TestCheckProfileLimit_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int
|
||||
maxProfiles int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "未达到上限",
|
||||
count: 5,
|
||||
maxProfiles: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "达到上限",
|
||||
count: 10,
|
||||
maxProfiles: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超过上限",
|
||||
count: 15,
|
||||
maxProfiles: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.count >= tt.maxProfiles
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateProfileByUserID_InputValidation 测试ValidateProfileByUserID输入验证
|
||||
func TestValidateProfileByUserID_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
uuid string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "有效输入",
|
||||
userID: 1,
|
||||
uuid: "test-uuid",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "userID为0",
|
||||
userID: 0,
|
||||
uuid: "test-uuid",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "uuid为空",
|
||||
userID: 1,
|
||||
uuid: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "两者都无效",
|
||||
userID: 0,
|
||||
uuid: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.userID == 0 || tt.uuid == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateProfileByUserID_UserIDMatching 测试用户ID匹配逻辑
|
||||
func TestValidateProfileByUserID_UserIDMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileUserID int64
|
||||
requestUserID int64
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "用户ID匹配",
|
||||
profileUserID: 1,
|
||||
requestUserID: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户ID不匹配",
|
||||
profileUserID: 1,
|
||||
requestUserID: 2,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.profileUserID == tt.requestUserID
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("UserID matching failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRSAPrivateKey 测试RSA私钥生成
|
||||
func TestGenerateRSAPrivateKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "生成RSA私钥",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
privateKey, err := generateRSAPrivateKey()
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("generateRSAPrivateKey() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if privateKey == "" {
|
||||
t.Error("generateRSAPrivateKey() 返回的私钥不应为空")
|
||||
}
|
||||
// 验证PEM格式
|
||||
if len(privateKey) < 100 {
|
||||
t.Errorf("generateRSAPrivateKey() 返回的私钥长度异常: %d", len(privateKey))
|
||||
}
|
||||
// 验证包含PEM头部
|
||||
if !contains(privateKey, "BEGIN RSA PRIVATE KEY") {
|
||||
t.Error("generateRSAPrivateKey() 返回的私钥应包含PEM头部")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateRSAPrivateKey_Uniqueness 测试RSA私钥唯一性
|
||||
func TestGenerateRSAPrivateKey_Uniqueness(t *testing.T) {
|
||||
keys := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
key, err := generateRSAPrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("generateRSAPrivateKey() 失败: %v", err)
|
||||
}
|
||||
if keys[key] {
|
||||
t.Errorf("第%d次生成的密钥与之前重复", i+1)
|
||||
}
|
||||
keys[key] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr ||
|
||||
(len(s) > len(substr) && (s[:len(substr)] == substr ||
|
||||
s[len(s)-len(substr):] == substr ||
|
||||
containsMiddle(s, substr))))
|
||||
}
|
||||
|
||||
func containsMiddle(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
97
internal/service/serialize_service.go
Normal file
97
internal/service/serialize_service.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/redis"
|
||||
"encoding/base64"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Property struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
}
|
||||
|
||||
func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, p model.Profile) map[string]interface{} {
|
||||
var err error
|
||||
|
||||
// 创建基本材质数据
|
||||
texturesMap := make(map[string]interface{})
|
||||
textures := map[string]interface{}{
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
"profileId": p.UUID,
|
||||
"profileName": p.Name,
|
||||
"textures": texturesMap,
|
||||
}
|
||||
|
||||
// 处理皮肤
|
||||
if p.SkinID != nil {
|
||||
skin, err := GetTextureByID(db, *p.SkinID)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取皮肤失败:", zap.Error(err), zap.Any("SkinID:", *p.SkinID))
|
||||
} else {
|
||||
texturesMap["SKIN"] = map[string]interface{}{
|
||||
"url": skin.URL,
|
||||
"metadata": skin.Size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理披风
|
||||
if p.CapeID != nil {
|
||||
cape, err := GetTextureByID(db, *p.CapeID)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取披风失败:", zap.Error(err), zap.Any("capeID:", *p.CapeID))
|
||||
} else {
|
||||
texturesMap["CAPE"] = map[string]interface{}{
|
||||
"url": cape.URL,
|
||||
"metadata": cape.Size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将textures编码为base64
|
||||
bytes, err := json.Marshal(textures)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 序列化textures失败: ", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
textureData := base64.StdEncoding.EncodeToString(bytes)
|
||||
signature, err := SignStringWithSHA1withRSA(logger, redisClient, textureData)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名textures失败: ", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
data := map[string]interface{}{
|
||||
"id": p.UUID,
|
||||
"name": p.Name,
|
||||
"properties": []Property{
|
||||
{
|
||||
Name: "textures",
|
||||
Value: textureData,
|
||||
Signature: signature,
|
||||
},
|
||||
},
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func SerializeUser(logger *zap.Logger, u *model.User, UUID string) map[string]interface{} {
|
||||
if u == nil {
|
||||
logger.Error("[ERROR] 尝试序列化空用户")
|
||||
return nil
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"id": UUID,
|
||||
"properties": u.Properties,
|
||||
}
|
||||
return data
|
||||
}
|
||||
172
internal/service/serialize_service_test.go
Normal file
172
internal/service/serialize_service_test.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestSerializeUser_NilUser 实际调用SerializeUser函数测试nil用户
|
||||
func TestSerializeUser_NilUser(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
result := SerializeUser(logger, nil, "test-uuid")
|
||||
if result != nil {
|
||||
t.Error("SerializeUser() 对于nil用户应返回nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeUser_ActualCall 实际调用SerializeUser函数
|
||||
func TestSerializeUser_ActualCall(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
user := &model.User{
|
||||
ID: 1,
|
||||
Username: "testuser",
|
||||
Email: "test@example.com",
|
||||
Properties: "{}",
|
||||
}
|
||||
|
||||
result := SerializeUser(logger, user, "test-uuid-123")
|
||||
if result == nil {
|
||||
t.Fatal("SerializeUser() 返回的结果不应为nil")
|
||||
}
|
||||
|
||||
if result["id"] != "test-uuid-123" {
|
||||
t.Errorf("id = %v, want 'test-uuid-123'", result["id"])
|
||||
}
|
||||
|
||||
if result["properties"] == nil {
|
||||
t.Error("properties 不应为nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProperty_Structure 测试Property结构
|
||||
func TestProperty_Structure(t *testing.T) {
|
||||
prop := Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
}
|
||||
|
||||
if prop.Name == "" {
|
||||
t.Error("Property name should not be empty")
|
||||
}
|
||||
|
||||
if prop.Value == "" {
|
||||
t.Error("Property value should not be empty")
|
||||
}
|
||||
|
||||
// Signature是可选的
|
||||
if prop.Signature == "" {
|
||||
t.Log("Property signature is optional")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeService_PropertyFields 测试Property字段
|
||||
func TestSerializeService_PropertyFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
property Property
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的Property",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "缺少Name的Property",
|
||||
property: Property{
|
||||
Name: "",
|
||||
Value: "base64value",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "缺少Value的Property",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "",
|
||||
Signature: "signature",
|
||||
},
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "没有Signature的Property(有效)",
|
||||
property: Property{
|
||||
Name: "textures",
|
||||
Value: "base64value",
|
||||
Signature: "",
|
||||
},
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.property.Name != "" && tt.property.Value != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Property validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeUser_InputValidation 测试SerializeUser输入验证
|
||||
func TestSerializeUser_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user *struct{}
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "用户不为nil",
|
||||
user: &struct{}{},
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户为nil",
|
||||
user: nil,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.user != nil
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Input validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSerializeProfile_Structure 测试SerializeProfile返回结构
|
||||
func TestSerializeProfile_Structure(t *testing.T) {
|
||||
// 测试返回的数据结构应该包含的字段
|
||||
expectedFields := []string{"id", "name", "properties"}
|
||||
|
||||
// 验证字段名称
|
||||
for _, field := range expectedFields {
|
||||
if field == "" {
|
||||
t.Error("Field name should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证properties应该是数组
|
||||
// 注意:这里只测试逻辑,不测试实际序列化
|
||||
}
|
||||
|
||||
// TestSerializeProfile_PropertyName 测试Property名称
|
||||
func TestSerializeProfile_PropertyName(t *testing.T) {
|
||||
// textures是固定的属性名
|
||||
propertyName := "textures"
|
||||
if propertyName != "textures" {
|
||||
t.Errorf("Property name = %s, want 'textures'", propertyName)
|
||||
}
|
||||
}
|
||||
605
internal/service/signature_service.go
Normal file
605
internal/service/signature_service.go
Normal file
@@ -0,0 +1,605 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
// RSA密钥长度
|
||||
RSAKeySize = 4096
|
||||
|
||||
// Redis密钥名称
|
||||
PrivateKeyRedisKey = "private_key"
|
||||
PublicKeyRedisKey = "public_key"
|
||||
|
||||
// 密钥过期时间
|
||||
KeyExpirationTime = time.Hour * 24 * 7
|
||||
|
||||
// 证书相关
|
||||
CertificateRefreshInterval = time.Hour * 24 // 证书刷新时间间隔
|
||||
CertificateExpirationPeriod = time.Hour * 24 * 7 // 证书过期时间
|
||||
)
|
||||
|
||||
// PlayerCertificate 表示玩家证书信息
|
||||
type PlayerCertificate struct {
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
RefreshedAfter string `json:"refreshedAfter"`
|
||||
PublicKeySignature string `json:"publicKeySignature,omitempty"`
|
||||
PublicKeySignatureV2 string `json:"publicKeySignatureV2,omitempty"`
|
||||
KeyPair struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
} `json:"keyPair"`
|
||||
}
|
||||
// SignatureService 保留结构体以保持向后兼容,但推荐使用函数式版本
|
||||
type SignatureService struct {
|
||||
logger *zap.Logger
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
func NewSignatureService(logger *zap.Logger, redisClient *redis.Client) *SignatureService {
|
||||
return &SignatureService{
|
||||
logger: logger,
|
||||
redisClient: redisClient,
|
||||
}
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串并返回Base64编码的签名(函数式版本)
|
||||
func SignStringWithSHA1withRSA(logger *zap.Logger, redisClient *redis.Client, data string) (string, error) {
|
||||
if data == "" {
|
||||
return "", fmt.Errorf("签名数据不能为空")
|
||||
}
|
||||
|
||||
// 获取私钥
|
||||
privateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 解码私钥失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("解码私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 计算SHA1哈希
|
||||
hashed := sha1.Sum([]byte(data))
|
||||
|
||||
// 使用RSA-PKCS1v15算法签名
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA1, hashed[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] RSA签名失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("RSA签名失败: %w", err)
|
||||
}
|
||||
|
||||
// Base64编码签名
|
||||
encodedSignature := base64.StdEncoding.EncodeToString(signature)
|
||||
|
||||
logger.Info("[INFO] 成功使用SHA1withRSA生成签名,", zap.Any("数据长度:", len(data)))
|
||||
return encodedSignature, nil
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSAService 使用SHA1withRSA签名字符串并返回Base64编码的签名(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) {
|
||||
return SignStringWithSHA1withRSA(s.logger, s.redisClient, data)
|
||||
}
|
||||
|
||||
// DecodePrivateKeyFromPEM 从Redis获取并解码PEM格式的私钥(函数式版本)
|
||||
func DecodePrivateKeyFromPEM(logger *zap.Logger, redisClient *redis.Client) (*rsa.PrivateKey, error) {
|
||||
// 从Redis获取私钥
|
||||
privateKeyString, err := GetPrivateKeyFromRedis(logger, redisClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 解码PEM格式
|
||||
privateKeyBlock, rest := pem.Decode([]byte(privateKeyString))
|
||||
if privateKeyBlock == nil || len(rest) > 0 {
|
||||
logger.Error("[ERROR] 无效的PEM格式私钥")
|
||||
return nil, fmt.Errorf("无效的PEM格式私钥")
|
||||
}
|
||||
|
||||
// 解析PKCS1格式的私钥
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 解析私钥失败: ", zap.Error(err))
|
||||
return nil, fmt.Errorf("解析私钥失败: %w", err)
|
||||
}
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// GetPrivateKeyFromRedis 从Redis获取私钥(PEM格式)(函数式版本)
|
||||
func GetPrivateKeyFromRedis(logger *zap.Logger, redisClient *redis.Client) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pemBytes, err := redisClient.GetBytes(ctx, PrivateKeyRedisKey)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 从Redis获取私钥失败,尝试生成新的密钥对: ", zap.Error(err))
|
||||
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
|
||||
// 递归获取生成的密钥
|
||||
return GetPrivateKeyFromRedis(logger, redisClient)
|
||||
}
|
||||
|
||||
return string(pemBytes), nil
|
||||
}
|
||||
|
||||
// DecodePrivateKeyFromPEMService 从Redis获取并解码PEM格式的私钥(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) DecodePrivateKeyFromPEM() (*rsa.PrivateKey, error) {
|
||||
return DecodePrivateKeyFromPEM(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// GetPrivateKeyFromRedisService 从Redis获取私钥(PEM格式)(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GetPrivateKeyFromRedis() (string, error) {
|
||||
return GetPrivateKeyFromRedis(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPair 生成新的RSA密钥对(函数式版本)
|
||||
func GenerateRSAKeyPair(logger *zap.Logger, redisClient *redis.Client) error {
|
||||
logger.Info("[INFO] 开始生成RSA密钥对", zap.Int("keySize", RSAKeySize))
|
||||
|
||||
// 生成私钥
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA私钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("生成RSA私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 编码私钥为PEM格式
|
||||
pemPrivateKey, err := EncodePrivateKeyToPEM(privateKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码RSA私钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("编码RSA私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取公钥并编码为PEM格式
|
||||
pubKey := privateKey.PublicKey
|
||||
pemPublicKey, err := EncodePublicKeyToPEM(logger, &pubKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码RSA公钥失败: ", zap.Error(err))
|
||||
return fmt.Errorf("编码RSA公钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存密钥对到Redis
|
||||
return SaveKeyPairToRedis(logger, redisClient, string(pemPrivateKey), string(pemPublicKey))
|
||||
}
|
||||
|
||||
// GenerateRSAKeyPairService 生成新的RSA密钥对(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GenerateRSAKeyPair() error {
|
||||
return GenerateRSAKeyPair(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEM 将私钥编码为PEM格式(函数式版本)
|
||||
func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
|
||||
if privateKey == nil {
|
||||
return nil, fmt.Errorf("私钥不能为空")
|
||||
}
|
||||
|
||||
// 默认使用 "PRIVATE KEY" 类型
|
||||
pemType := "PRIVATE KEY"
|
||||
|
||||
// 如果指定了类型参数且为 "RSA",则使用 "RSA PRIVATE KEY"
|
||||
if len(keyType) > 0 && keyType[0] == "RSA" {
|
||||
pemType = "RSA PRIVATE KEY"
|
||||
}
|
||||
|
||||
// 将私钥转换为PKCS1格式
|
||||
privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||
|
||||
// 编码为PEM格式
|
||||
pemBlock := &pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: privateKeyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
// EncodePublicKeyToPEM 将公钥编码为PEM格式(函数式版本)
|
||||
func EncodePublicKeyToPEM(logger *zap.Logger, publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
|
||||
if publicKey == nil {
|
||||
return nil, fmt.Errorf("公钥不能为空")
|
||||
}
|
||||
|
||||
// 默认使用 "PUBLIC KEY" 类型
|
||||
pemType := "PUBLIC KEY"
|
||||
var publicKeyBytes []byte
|
||||
var err error
|
||||
|
||||
// 如果指定了类型参数且为 "RSA",则使用 "RSA PUBLIC KEY"
|
||||
if len(keyType) > 0 && keyType[0] == "RSA" {
|
||||
pemType = "RSA PUBLIC KEY"
|
||||
publicKeyBytes = x509.MarshalPKCS1PublicKey(publicKey)
|
||||
} else {
|
||||
// 默认将公钥转换为PKIX格式
|
||||
publicKeyBytes, err = x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 序列化公钥失败: ", zap.Error(err))
|
||||
return nil, fmt.Errorf("序列化公钥失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 编码为PEM格式
|
||||
pemBlock := &pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: publicKeyBytes,
|
||||
}
|
||||
|
||||
return pem.EncodeToMemory(pemBlock), nil
|
||||
}
|
||||
|
||||
// SaveKeyPairToRedis 将RSA密钥对保存到Redis(函数式版本)
|
||||
func SaveKeyPairToRedis(logger *zap.Logger, redisClient *redis.Client, privateKey, publicKey string) error {
|
||||
// 创建上下文并设置超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 使用事务确保两个操作的原子性
|
||||
tx := redisClient.TxPipeline()
|
||||
|
||||
tx.Set(ctx, PrivateKeyRedisKey, privateKey, KeyExpirationTime)
|
||||
tx.Set(ctx, PublicKeyRedisKey, publicKey, KeyExpirationTime)
|
||||
|
||||
// 执行事务
|
||||
_, err := tx.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 保存RSA密钥对到Redis失败: ", zap.Error(err))
|
||||
return fmt.Errorf("保存RSA密钥对到Redis失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功保存RSA密钥对到Redis")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEMService 将私钥编码为PEM格式(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey, keyType ...string) ([]byte, error) {
|
||||
return EncodePrivateKeyToPEM(privateKey, keyType...)
|
||||
}
|
||||
|
||||
// EncodePublicKeyToPEMService 将公钥编码为PEM格式(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) EncodePublicKeyToPEM(publicKey *rsa.PublicKey, keyType ...string) ([]byte, error) {
|
||||
return EncodePublicKeyToPEM(s.logger, publicKey, keyType...)
|
||||
}
|
||||
|
||||
// SaveKeyPairToRedisService 将RSA密钥对保存到Redis(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) SaveKeyPairToRedis(privateKey, publicKey string) error {
|
||||
return SaveKeyPairToRedis(s.logger, s.redisClient, privateKey, publicKey)
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedisFunc 从Redis获取公钥(PEM格式,函数式版本)
|
||||
func GetPublicKeyFromRedisFunc(logger *zap.Logger, redisClient *redis.Client) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pemBytes, err := redisClient.GetBytes(ctx, PublicKeyRedisKey)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 从Redis获取公钥失败,尝试生成新的密钥对: ", zap.Error(err))
|
||||
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
|
||||
// 递归获取生成的密钥
|
||||
return GetPublicKeyFromRedisFunc(logger, redisClient)
|
||||
}
|
||||
|
||||
// 检查获取到的公钥是否为空(key不存在时GetBytes返回nil, nil)
|
||||
if len(pemBytes) == 0 {
|
||||
logger.Info("[INFO] Redis中公钥为空,尝试生成新的密钥对")
|
||||
// 生成新的密钥对
|
||||
err = GenerateRSAKeyPair(logger, redisClient)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成RSA密钥对失败: ", zap.Error(err))
|
||||
return "", fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
}
|
||||
// 递归获取生成的密钥
|
||||
return GetPublicKeyFromRedisFunc(logger, redisClient)
|
||||
}
|
||||
|
||||
return string(pemBytes), nil
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedis 从Redis获取公钥(PEM格式,结构体方法版本)
|
||||
func (s *SignatureService) GetPublicKeyFromRedis() (string, error) {
|
||||
return GetPublicKeyFromRedisFunc(s.logger, s.redisClient)
|
||||
}
|
||||
|
||||
|
||||
// GeneratePlayerCertificate 生成玩家证书(函数式版本)
|
||||
func GeneratePlayerCertificate(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, uuid string) (*PlayerCertificate, error) {
|
||||
if uuid == "" {
|
||||
return nil, fmt.Errorf("UUID不能为空")
|
||||
}
|
||||
logger.Info("[INFO] 开始生成玩家证书,用户UUID: %s",
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
|
||||
keyPair, err := repository.GetProfileKeyPair(uuid)
|
||||
if err != nil {
|
||||
logger.Info("[INFO] 获取用户密钥对失败,将创建新密钥对: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
keyPair = nil
|
||||
}
|
||||
|
||||
// 如果没有找到密钥对或密钥对已过期,创建一个新的
|
||||
now := time.Now().UTC()
|
||||
if keyPair == nil || keyPair.Refresh.Before(now) || keyPair.PrivateKey == "" || keyPair.PublicKey == "" {
|
||||
logger.Info("[INFO] 为用户创建新的密钥对: %s",
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
keyPair, err = NewKeyPair(logger)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成玩家证书密钥对失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
return nil, fmt.Errorf("生成玩家证书密钥对失败: %w", err)
|
||||
}
|
||||
// 保存密钥对到数据库
|
||||
err = repository.UpdateProfileKeyPair(uuid, keyPair)
|
||||
if err != nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Warn("[WARN] 更新用户密钥对失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
// 继续执行,即使保存失败
|
||||
}
|
||||
}
|
||||
|
||||
// 计算expiresAt的毫秒时间戳
|
||||
expiresAtMillis := keyPair.Expiration.UnixMilli()
|
||||
|
||||
// 准备签名
|
||||
publicKeySignature := ""
|
||||
publicKeySignatureV2 := ""
|
||||
|
||||
// 获取服务器私钥用于签名
|
||||
serverPrivateKey, err := DecodePrivateKeyFromPEM(logger, redisClient)
|
||||
if err != nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Error("[ERROR] 获取服务器私钥失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
)
|
||||
return nil, fmt.Errorf("获取服务器私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 提取公钥DER编码
|
||||
pubPEMBlock, _ := pem.Decode([]byte(keyPair.PublicKey))
|
||||
if pubPEMBlock == nil {
|
||||
// 日志修改:logger → s.logger,zap结构化字段
|
||||
logger.Error("[ERROR] 解码公钥PEM失败",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("publicKey", keyPair.PublicKey),
|
||||
)
|
||||
return nil, fmt.Errorf("解码公钥PEM失败")
|
||||
}
|
||||
pubDER := pubPEMBlock.Bytes
|
||||
|
||||
// 准备publicKeySignature(用于MC 1.19)
|
||||
// Base64编码公钥,不包含换行
|
||||
pubBase64 := strings.ReplaceAll(base64.StdEncoding.EncodeToString(pubDER), "\n", "")
|
||||
|
||||
// 按76字符一行进行包装
|
||||
pubBase64Wrapped := WrapString(pubBase64, 76)
|
||||
|
||||
// 放入PEM格式
|
||||
pubMojangPEM := "-----BEGIN RSA PUBLIC KEY-----\n" +
|
||||
pubBase64Wrapped +
|
||||
"\n-----END RSA PUBLIC KEY-----\n"
|
||||
|
||||
// 签名数据: expiresAt毫秒时间戳 + 公钥PEM格式
|
||||
signedData := []byte(fmt.Sprintf("%d%s", expiresAtMillis, pubMojangPEM))
|
||||
|
||||
// 计算SHA1哈希并签名
|
||||
hash1 := sha1.Sum(signedData)
|
||||
signature, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash1[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("expiresAtMillis", expiresAtMillis),
|
||||
)
|
||||
return nil, fmt.Errorf("签名失败: %w", err)
|
||||
}
|
||||
publicKeySignature = base64.StdEncoding.EncodeToString(signature)
|
||||
|
||||
// 准备publicKeySignatureV2(用于MC 1.19.1+)
|
||||
var uuidBytes []byte
|
||||
|
||||
// 如果提供了UUID,则使用它
|
||||
// 移除UUID中的连字符
|
||||
uuidStr := strings.ReplaceAll(uuid, "-", "")
|
||||
|
||||
// 将UUID转换为字节数组(16字节)
|
||||
if len(uuidStr) < 32 {
|
||||
logger.Warn("[WARN] UUID长度不足32字符,使用空UUID: %s",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("processedUuidStr", uuidStr),
|
||||
)
|
||||
uuidBytes = make([]byte, 16)
|
||||
} else {
|
||||
// 解析UUID字符串为字节
|
||||
uuidBytes = make([]byte, 16)
|
||||
parseErr := error(nil)
|
||||
for i := 0; i < 16; i++ {
|
||||
// 每两个字符转换为一个字节
|
||||
byteStr := uuidStr[i*2 : i*2+2]
|
||||
byteVal, err := strconv.ParseUint(byteStr, 16, 8)
|
||||
if err != nil {
|
||||
parseErr = err
|
||||
logger.Error("[ERROR] 解析UUID字节失败: %v, byteStr: %s",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("byteStr", byteStr),
|
||||
zap.Int("index", i),
|
||||
)
|
||||
uuidBytes = make([]byte, 16) // 出错时使用空UUID
|
||||
break
|
||||
}
|
||||
uuidBytes[i] = byte(byteVal)
|
||||
}
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("解析UUID字节失败: %w", parseErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 准备签名数据:UUID + expiresAt时间戳 + DER编码的公钥
|
||||
signedDataV2 := make([]byte, 0, 24+len(pubDER)) // 预分配缓冲区
|
||||
|
||||
// 添加UUID(16字节)
|
||||
signedDataV2 = append(signedDataV2, uuidBytes...)
|
||||
|
||||
// 添加expiresAt毫秒时间戳(8字节,大端序)
|
||||
expiresAtBytes := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(expiresAtBytes, uint64(expiresAtMillis))
|
||||
signedDataV2 = append(signedDataV2, expiresAtBytes...)
|
||||
|
||||
// 添加DER编码的公钥
|
||||
signedDataV2 = append(signedDataV2, pubDER...)
|
||||
|
||||
// 计算SHA1哈希并签名
|
||||
hash2 := sha1.Sum(signedDataV2)
|
||||
signatureV2, err := rsa.SignPKCS1v15(rand.Reader, serverPrivateKey, crypto.SHA1, hash2[:])
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 签名V2失败: %v",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", uuid),
|
||||
zap.Int64("expiresAtMillis", expiresAtMillis),
|
||||
)
|
||||
return nil, fmt.Errorf("签名V2失败: %w", err)
|
||||
}
|
||||
publicKeySignatureV2 = base64.StdEncoding.EncodeToString(signatureV2)
|
||||
|
||||
// 创建玩家证书结构
|
||||
certificate := &PlayerCertificate{
|
||||
KeyPair: struct {
|
||||
PrivateKey string `json:"privateKey"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}{
|
||||
PrivateKey: keyPair.PrivateKey,
|
||||
PublicKey: keyPair.PublicKey,
|
||||
},
|
||||
PublicKeySignature: publicKeySignature,
|
||||
PublicKeySignatureV2: publicKeySignatureV2,
|
||||
ExpiresAt: keyPair.Expiration.Format(time.RFC3339Nano),
|
||||
RefreshedAfter: keyPair.Refresh.Format(time.RFC3339Nano),
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功生成玩家证书,过期时间: %s",
|
||||
zap.String("uuid", uuid),
|
||||
zap.String("expiresAt", certificate.ExpiresAt),
|
||||
zap.String("refreshedAfter", certificate.RefreshedAfter),
|
||||
)
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
// GeneratePlayerCertificateService 生成玩家证书(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) GeneratePlayerCertificate(uuid string) (*PlayerCertificate, error) {
|
||||
return GeneratePlayerCertificate(nil, s.logger, s.redisClient, uuid) // TODO: 需要传入db参数
|
||||
}
|
||||
|
||||
// NewKeyPair 生成新的密钥对(函数式版本)
|
||||
func NewKeyPair(logger *zap.Logger) (*model.KeyPair, error) {
|
||||
// 生成新的RSA密钥对(用于玩家证书)
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048) // 对玩家证书使用更小的密钥以提高性能
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 生成玩家证书私钥失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("生成玩家证书私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取DER编码的密钥
|
||||
keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码私钥为PKCS8格式失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("编码私钥为PKCS8格式失败: %w", err)
|
||||
}
|
||||
|
||||
pubDER, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 编码公钥为PKIX格式失败: %v",
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("编码公钥为PKIX格式失败: %w", err)
|
||||
}
|
||||
|
||||
// 将密钥编码为PEM格式
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: keyDER,
|
||||
})
|
||||
|
||||
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: pubDER,
|
||||
})
|
||||
|
||||
// 创建证书过期和刷新时间
|
||||
now := time.Now().UTC()
|
||||
expiresAtTime := now.Add(CertificateExpirationPeriod)
|
||||
refreshedAfter := now.Add(CertificateRefreshInterval)
|
||||
keyPair := &model.KeyPair{
|
||||
Expiration: expiresAtTime,
|
||||
PrivateKey: string(keyPEM),
|
||||
PublicKey: string(pubPEM),
|
||||
Refresh: refreshedAfter,
|
||||
}
|
||||
return keyPair, nil
|
||||
}
|
||||
|
||||
// WrapString 将字符串按指定宽度进行换行(函数式版本)
|
||||
func WrapString(str string, width int) string {
|
||||
if width <= 0 {
|
||||
return str
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(str); i += width {
|
||||
end := i + width
|
||||
if end > len(str) {
|
||||
end = len(str)
|
||||
}
|
||||
b.WriteString(str[i:end])
|
||||
if end < len(str) {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// NewKeyPairService 生成新的密钥对(结构体方法版本,保持向后兼容)
|
||||
func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) {
|
||||
return NewKeyPair(s.logger)
|
||||
}
|
||||
358
internal/service/signature_service_test.go
Normal file
358
internal/service/signature_service_test.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestSignatureService_Constants 测试签名服务相关常量
|
||||
func TestSignatureService_Constants(t *testing.T) {
|
||||
if RSAKeySize != 4096 {
|
||||
t.Errorf("RSAKeySize = %d, want 4096", RSAKeySize)
|
||||
}
|
||||
|
||||
if PrivateKeyRedisKey == "" {
|
||||
t.Error("PrivateKeyRedisKey should not be empty")
|
||||
}
|
||||
|
||||
if PublicKeyRedisKey == "" {
|
||||
t.Error("PublicKeyRedisKey should not be empty")
|
||||
}
|
||||
|
||||
if KeyExpirationTime != 24*7*time.Hour {
|
||||
t.Errorf("KeyExpirationTime = %v, want 7 days", KeyExpirationTime)
|
||||
}
|
||||
|
||||
if CertificateRefreshInterval != 24*time.Hour {
|
||||
t.Errorf("CertificateRefreshInterval = %v, want 24 hours", CertificateRefreshInterval)
|
||||
}
|
||||
|
||||
if CertificateExpirationPeriod != 24*7*time.Hour {
|
||||
t.Errorf("CertificateExpirationPeriod = %v, want 7 days", CertificateExpirationPeriod)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSignatureService_DataValidation 测试签名数据验证逻辑
|
||||
func TestSignatureService_DataValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "非空数据有效",
|
||||
data: "test data",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "空数据无效",
|
||||
data: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.data != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Data validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPlayerCertificate_Structure 测试PlayerCertificate结构
|
||||
func TestPlayerCertificate_Structure(t *testing.T) {
|
||||
cert := PlayerCertificate{
|
||||
ExpiresAt: "2025-01-01T00:00:00Z",
|
||||
RefreshedAfter: "2025-01-01T00:00:00Z",
|
||||
PublicKeySignature: "signature",
|
||||
PublicKeySignatureV2: "signaturev2",
|
||||
}
|
||||
|
||||
// 验证结构体字段
|
||||
if cert.ExpiresAt == "" {
|
||||
t.Error("ExpiresAt should not be empty")
|
||||
}
|
||||
|
||||
if cert.RefreshedAfter == "" {
|
||||
t.Error("RefreshedAfter should not be empty")
|
||||
}
|
||||
|
||||
// PublicKeySignature是可选的
|
||||
if cert.PublicKeySignature == "" {
|
||||
t.Log("PublicKeySignature is optional")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString 测试字符串换行函数
|
||||
func TestWrapString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
str string
|
||||
width int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "正常换行",
|
||||
str: "1234567890",
|
||||
width: 5,
|
||||
expected: "12345\n67890",
|
||||
},
|
||||
{
|
||||
name: "字符串长度等于width",
|
||||
str: "12345",
|
||||
width: 5,
|
||||
expected: "12345",
|
||||
},
|
||||
{
|
||||
name: "字符串长度小于width",
|
||||
str: "123",
|
||||
width: 5,
|
||||
expected: "123",
|
||||
},
|
||||
{
|
||||
name: "width为0,返回原字符串",
|
||||
str: "1234567890",
|
||||
width: 0,
|
||||
expected: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "width为负数,返回原字符串",
|
||||
str: "1234567890",
|
||||
width: -1,
|
||||
expected: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
str: "",
|
||||
width: 5,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "width为1",
|
||||
str: "12345",
|
||||
width: 1,
|
||||
expected: "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
name: "长字符串多次换行",
|
||||
str: "123456789012345",
|
||||
width: 5,
|
||||
expected: "12345\n67890\n12345",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := WrapString(tt.str, tt.width)
|
||||
if result != tt.expected {
|
||||
t.Errorf("WrapString(%q, %d) = %q, want %q", tt.str, tt.width, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString_LineCount 测试换行后的行数
|
||||
func TestWrapString_LineCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
str string
|
||||
width int
|
||||
wantLines int
|
||||
}{
|
||||
{
|
||||
name: "10个字符,width=5,应该2行",
|
||||
str: "1234567890",
|
||||
width: 5,
|
||||
wantLines: 2,
|
||||
},
|
||||
{
|
||||
name: "15个字符,width=5,应该3行",
|
||||
str: "123456789012345",
|
||||
width: 5,
|
||||
wantLines: 3,
|
||||
},
|
||||
{
|
||||
name: "5个字符,width=5,应该1行",
|
||||
str: "12345",
|
||||
width: 5,
|
||||
wantLines: 1,
|
||||
},
|
||||
{
|
||||
name: "width为0,应该1行",
|
||||
str: "1234567890",
|
||||
width: 0,
|
||||
wantLines: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := WrapString(tt.str, tt.width)
|
||||
lines := strings.Count(result, "\n") + 1
|
||||
if lines != tt.wantLines {
|
||||
t.Errorf("Line count = %d, want %d (result: %q)", lines, tt.wantLines, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWrapString_NoTrailingNewline 测试末尾不换行
|
||||
func TestWrapString_NoTrailingNewline(t *testing.T) {
|
||||
str := "1234567890"
|
||||
result := WrapString(str, 5)
|
||||
|
||||
// 验证末尾没有换行符
|
||||
if strings.HasSuffix(result, "\n") {
|
||||
t.Error("Result should not end with newline")
|
||||
}
|
||||
|
||||
// 验证包含换行符(除了最后一行)
|
||||
if !strings.Contains(result, "\n") {
|
||||
t.Error("Result should contain newline for multi-line output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePrivateKeyToPEM_ActualCall 实际调用EncodePrivateKeyToPEM函数
|
||||
func TestEncodePrivateKeyToPEM_ActualCall(t *testing.T) {
|
||||
// 生成测试用的RSA私钥
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("生成RSA私钥失败: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "默认类型",
|
||||
keyType: []string{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "RSA类型",
|
||||
keyType: []string{"RSA"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pemBytes, err := EncodePrivateKeyToPEM(privateKey, tt.keyType...)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("EncodePrivateKeyToPEM() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if len(pemBytes) == 0 {
|
||||
t.Error("EncodePrivateKeyToPEM() 返回的PEM字节不应为空")
|
||||
}
|
||||
pemStr := string(pemBytes)
|
||||
// 验证PEM格式
|
||||
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
|
||||
t.Error("EncodePrivateKeyToPEM() 返回的PEM格式不正确")
|
||||
}
|
||||
// 验证类型
|
||||
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
|
||||
if !strings.Contains(pemStr, "RSA PRIVATE KEY") {
|
||||
t.Error("EncodePrivateKeyToPEM() 应包含 'RSA PRIVATE KEY'")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(pemStr, "PRIVATE KEY") {
|
||||
t.Error("EncodePrivateKeyToPEM() 应包含 'PRIVATE KEY'")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePublicKeyToPEM_ActualCall 实际调用EncodePublicKeyToPEM函数
|
||||
func TestEncodePublicKeyToPEM_ActualCall(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 生成测试用的RSA密钥对
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("生成RSA密钥对失败: %v", err)
|
||||
}
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "默认类型",
|
||||
keyType: []string{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "RSA类型",
|
||||
keyType: []string{"RSA"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pemBytes, err := EncodePublicKeyToPEM(logger, publicKey, tt.keyType...)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("EncodePublicKeyToPEM() error = %v, wantError %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
if !tt.wantError {
|
||||
if len(pemBytes) == 0 {
|
||||
t.Error("EncodePublicKeyToPEM() 返回的PEM字节不应为空")
|
||||
}
|
||||
pemStr := string(pemBytes)
|
||||
// 验证PEM格式
|
||||
if !strings.Contains(pemStr, "BEGIN") || !strings.Contains(pemStr, "END") {
|
||||
t.Error("EncodePublicKeyToPEM() 返回的PEM格式不正确")
|
||||
}
|
||||
// 验证类型
|
||||
if len(tt.keyType) > 0 && tt.keyType[0] == "RSA" {
|
||||
if !strings.Contains(pemStr, "RSA PUBLIC KEY") {
|
||||
t.Error("EncodePublicKeyToPEM() 应包含 'RSA PUBLIC KEY'")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(pemStr, "PUBLIC KEY") {
|
||||
t.Error("EncodePublicKeyToPEM() 应包含 'PUBLIC KEY'")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodePublicKeyToPEM_NilKey 测试nil公钥
|
||||
func TestEncodePublicKeyToPEM_NilKey(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
_, err := EncodePublicKeyToPEM(logger, nil)
|
||||
if err == nil {
|
||||
t.Error("EncodePublicKeyToPEM() 对于nil公钥应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSignatureService 测试创建SignatureService
|
||||
func TestNewSignatureService(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
// 注意:这里需要实际的redis client,但我们只测试结构体创建
|
||||
// 在实际测试中,可以使用mock redis client
|
||||
service := NewSignatureService(logger, nil)
|
||||
if service == nil {
|
||||
t.Error("NewSignatureService() 不应返回nil")
|
||||
}
|
||||
if service.logger != logger {
|
||||
t.Error("NewSignatureService() logger 设置不正确")
|
||||
}
|
||||
}
|
||||
251
internal/service/texture_service.go
Normal file
251
internal/service/texture_service.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// CreateTexture 创建材质
|
||||
func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
|
||||
// 验证用户存在
|
||||
user, err := repository.FindUserByID(uploaderID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查Hash是否已存在
|
||||
existingTexture, err := repository.FindTextureByHash(hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingTexture != nil {
|
||||
return nil, errors.New("该材质已存在")
|
||||
}
|
||||
|
||||
// 转换材质类型
|
||||
var textureTypeEnum model.TextureType
|
||||
switch textureType {
|
||||
case "SKIN":
|
||||
textureTypeEnum = model.TextureTypeSkin
|
||||
case "CAPE":
|
||||
textureTypeEnum = model.TextureTypeCape
|
||||
default:
|
||||
return nil, errors.New("无效的材质类型")
|
||||
}
|
||||
|
||||
// 创建材质
|
||||
texture := &model.Texture{
|
||||
UploaderID: uploaderID,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Type: textureTypeEnum,
|
||||
URL: url,
|
||||
Hash: hash,
|
||||
Size: size,
|
||||
IsPublic: isPublic,
|
||||
IsSlim: isSlim,
|
||||
Status: 1,
|
||||
DownloadCount: 0,
|
||||
FavoriteCount: 0,
|
||||
}
|
||||
|
||||
if err := repository.CreateTexture(texture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
// GetTextureByID 根据ID获取材质
|
||||
func GetTextureByID(db *gorm.DB, id int64) (*model.Texture, error) {
|
||||
texture, err := repository.FindTextureByID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, errors.New("材质不存在")
|
||||
}
|
||||
if texture.Status == -1 {
|
||||
return nil, errors.New("材质已删除")
|
||||
}
|
||||
return texture, nil
|
||||
}
|
||||
|
||||
// GetUserTextures 获取用户上传的材质列表
|
||||
func GetUserTextures(db *gorm.DB, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.FindTexturesByUploaderID(uploaderID, page, pageSize)
|
||||
}
|
||||
|
||||
// SearchTextures 搜索材质
|
||||
func SearchTextures(db *gorm.DB, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.SearchTextures(keyword, textureType, publicOnly, page, pageSize)
|
||||
}
|
||||
|
||||
// UpdateTexture 更新材质
|
||||
func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
|
||||
// 获取材质
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if texture == nil {
|
||||
return nil, errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查权限:只有上传者可以修改
|
||||
if texture.UploaderID != uploaderID {
|
||||
return nil, errors.New("无权修改此材质")
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
updates := make(map[string]interface{})
|
||||
if name != "" {
|
||||
updates["name"] = name
|
||||
}
|
||||
if description != "" {
|
||||
updates["description"] = description
|
||||
}
|
||||
if isPublic != nil {
|
||||
updates["is_public"] = *isPublic
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := repository.UpdateTextureFields(textureID, updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 返回更新后的材质
|
||||
return repository.FindTextureByID(textureID)
|
||||
}
|
||||
|
||||
// DeleteTexture 删除材质
|
||||
func DeleteTexture(db *gorm.DB, textureID, uploaderID int64) error {
|
||||
// 获取材质
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查权限:只有上传者可以删除
|
||||
if texture.UploaderID != uploaderID {
|
||||
return errors.New("无权删除此材质")
|
||||
}
|
||||
|
||||
return repository.DeleteTexture(textureID)
|
||||
}
|
||||
|
||||
// RecordTextureDownload 记录下载
|
||||
func RecordTextureDownload(db *gorm.DB, textureID int64, userID *int64, ipAddress, userAgent string) error {
|
||||
// 检查材质是否存在
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if texture == nil {
|
||||
return errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 增加下载次数
|
||||
if err := repository.IncrementTextureDownloadCount(textureID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建下载日志
|
||||
log := &model.TextureDownloadLog{
|
||||
TextureID: textureID,
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
|
||||
return repository.CreateTextureDownloadLog(log)
|
||||
}
|
||||
|
||||
// ToggleTextureFavorite 切换收藏状态
|
||||
func ToggleTextureFavorite(db *gorm.DB, userID, textureID int64) (bool, error) {
|
||||
// 检查材质是否存在
|
||||
texture, err := repository.FindTextureByID(textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if texture == nil {
|
||||
return false, errors.New("材质不存在")
|
||||
}
|
||||
|
||||
// 检查是否已收藏
|
||||
isFavorited, err := repository.IsTextureFavorited(userID, textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isFavorited {
|
||||
// 取消收藏
|
||||
if err := repository.RemoveTextureFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := repository.DecrementTextureFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
} else {
|
||||
// 添加收藏
|
||||
if err := repository.AddTextureFavorite(userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := repository.IncrementTextureFavoriteCount(textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetUserTextureFavorites 获取用户收藏的材质列表
|
||||
func GetUserTextureFavorites(db *gorm.DB, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return repository.GetUserTextureFavorites(userID, page, pageSize)
|
||||
}
|
||||
|
||||
// CheckTextureUploadLimit 检查用户上传材质数量限制
|
||||
func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) error {
|
||||
count, err := repository.CountTexturesByUploaderID(uploaderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count >= int64(maxTextures) {
|
||||
return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
471
internal/service/texture_service_test.go
Normal file
471
internal/service/texture_service_test.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestTextureService_TypeValidation 测试材质类型验证
|
||||
func TestTextureService_TypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
textureType string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "SKIN类型有效",
|
||||
textureType: "SKIN",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "CAPE类型有效",
|
||||
textureType: "CAPE",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效类型",
|
||||
textureType: "INVALID",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空类型无效",
|
||||
textureType: "",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.textureType == "SKIN" || tt.textureType == "CAPE"
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Texture type validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureService_DefaultValues 测试材质默认值
|
||||
func TestTextureService_DefaultValues(t *testing.T) {
|
||||
// 测试默认状态
|
||||
defaultStatus := 1
|
||||
if defaultStatus != 1 {
|
||||
t.Errorf("默认状态应为1,实际为%d", defaultStatus)
|
||||
}
|
||||
|
||||
// 测试默认下载数
|
||||
defaultDownloadCount := 0
|
||||
if defaultDownloadCount != 0 {
|
||||
t.Errorf("默认下载数应为0,实际为%d", defaultDownloadCount)
|
||||
}
|
||||
|
||||
// 测试默认收藏数
|
||||
defaultFavoriteCount := 0
|
||||
if defaultFavoriteCount != 0 {
|
||||
t.Errorf("默认收藏数应为0,实际为%d", defaultFavoriteCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureService_StatusValidation 测试材质状态验证
|
||||
func TestTextureService_StatusValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status int16
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "状态为1(正常)时有效",
|
||||
status: 1,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "状态为-1(删除)时无效",
|
||||
status: -1,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "状态为0时可能有效(取决于业务逻辑)",
|
||||
status: 0,
|
||||
wantValid: true, // 状态为0(禁用)时,材质仍然存在,只是不可用,但查询时不会返回错误
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 材质状态为-1时表示已删除,无效
|
||||
isValid := tt.status != -1
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Status validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserTextures_Pagination 测试分页逻辑
|
||||
func TestGetUserTextures_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 2,
|
||||
pageSize: 20,
|
||||
wantPage: 2,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize小于1,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 0,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSearchTextures_Pagination 测试搜索分页逻辑
|
||||
func TestSearchTextures_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 10,
|
||||
wantPage: 1,
|
||||
wantSize: 10,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: -1,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 150,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTexture_PermissionCheck 测试更新材质的权限检查
|
||||
func TestUpdateTexture_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
requestID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "上传者ID匹配,允许更新",
|
||||
uploaderID: 1,
|
||||
requestID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上传者ID不匹配,拒绝更新",
|
||||
uploaderID: 1,
|
||||
requestID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.uploaderID != tt.requestID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateTexture_FieldUpdates 测试更新字段逻辑
|
||||
func TestUpdateTexture_FieldUpdates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
nameValue string
|
||||
descValue string
|
||||
isPublic *bool
|
||||
wantUpdates int
|
||||
}{
|
||||
{
|
||||
name: "更新所有字段",
|
||||
nameValue: "NewName",
|
||||
descValue: "NewDesc",
|
||||
isPublic: boolPtr(true),
|
||||
wantUpdates: 3,
|
||||
},
|
||||
{
|
||||
name: "只更新名称",
|
||||
nameValue: "NewName",
|
||||
descValue: "",
|
||||
isPublic: nil,
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "只更新描述",
|
||||
nameValue: "",
|
||||
descValue: "NewDesc",
|
||||
isPublic: nil,
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "只更新公开状态",
|
||||
nameValue: "",
|
||||
descValue: "",
|
||||
isPublic: boolPtr(false),
|
||||
wantUpdates: 1,
|
||||
},
|
||||
{
|
||||
name: "没有更新",
|
||||
nameValue: "",
|
||||
descValue: "",
|
||||
isPublic: nil,
|
||||
wantUpdates: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
updates := 0
|
||||
if tt.nameValue != "" {
|
||||
updates++
|
||||
}
|
||||
if tt.descValue != "" {
|
||||
updates++
|
||||
}
|
||||
if tt.isPublic != nil {
|
||||
updates++
|
||||
}
|
||||
|
||||
if updates != tt.wantUpdates {
|
||||
t.Errorf("Updates count = %d, want %d", updates, tt.wantUpdates)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeleteTexture_PermissionCheck 测试删除材质的权限检查
|
||||
func TestDeleteTexture_PermissionCheck(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uploaderID int64
|
||||
requestID int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "上传者ID匹配,允许删除",
|
||||
uploaderID: 1,
|
||||
requestID: 1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上传者ID不匹配,拒绝删除",
|
||||
uploaderID: 1,
|
||||
requestID: 2,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.uploaderID != tt.requestID
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Permission check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestToggleTextureFavorite_Logic 测试切换收藏状态的逻辑
|
||||
func TestToggleTextureFavorite_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isFavorited bool
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "已收藏,取消收藏",
|
||||
isFavorited: true,
|
||||
wantResult: false,
|
||||
},
|
||||
{
|
||||
name: "未收藏,添加收藏",
|
||||
isFavorited: false,
|
||||
wantResult: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := !tt.isFavorited
|
||||
if result != tt.wantResult {
|
||||
t.Errorf("Toggle favorite failed: got %v, want %v", result, tt.wantResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUserTextureFavorites_Pagination 测试收藏列表分页
|
||||
func TestGetUserTextureFavorites_Pagination(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantPage int
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page小于1,应该设为1",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100,应该设为20",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantPage: 1,
|
||||
wantSize: 20,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := tt.page
|
||||
pageSize := tt.pageSize
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
if page != tt.wantPage {
|
||||
t.Errorf("Page = %d, want %d", page, tt.wantPage)
|
||||
}
|
||||
if pageSize != tt.wantSize {
|
||||
t.Errorf("PageSize = %d, want %d", pageSize, tt.wantSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCheckTextureUploadLimit_Logic 测试上传限制检查逻辑
|
||||
func TestCheckTextureUploadLimit_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
count int64
|
||||
maxTextures int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "未达到上限",
|
||||
count: 5,
|
||||
maxTextures: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "达到上限",
|
||||
count: 10,
|
||||
maxTextures: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超过上限",
|
||||
count: 15,
|
||||
maxTextures: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.count >= int64(tt.maxTextures)
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Limit check failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func boolPtr(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
277
internal/service/token_service.go
Normal file
277
internal/service/token_service.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
ExtendedTimeout = 10 * time.Second
|
||||
TokensMaxCount = 10 // 用户最多保留的token数量
|
||||
)
|
||||
|
||||
// NewToken 创建新令牌
|
||||
func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) {
|
||||
var (
|
||||
selectedProfileID *model.Profile
|
||||
availableProfiles []*model.Profile
|
||||
)
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
_, err := repository.FindProfileByUUID(UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 生成令牌
|
||||
if clientToken == "" {
|
||||
clientToken = uuid.New().String()
|
||||
}
|
||||
|
||||
accessToken := uuid.New().String()
|
||||
token := model.Token{
|
||||
AccessToken: accessToken,
|
||||
ClientToken: clientToken,
|
||||
UserID: userId,
|
||||
Usable: true,
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := repository.FindProfilesByUserID(userId)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果用户只有一个配置文件,自动选择
|
||||
if len(profiles) == 1 {
|
||||
selectedProfileID = profiles[0]
|
||||
token.ProfileId = selectedProfileID.UUID
|
||||
}
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌到tokens集合
|
||||
_, insertCancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
defer insertCancel()
|
||||
|
||||
err = repository.CreateToken(&token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
// 清理多余的令牌
|
||||
go CheckAndCleanupExcessTokens(db, logger, userId)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// CheckAndCleanupExcessTokens 检查并清理用户多余的令牌,只保留最新的10个
|
||||
func CheckAndCleanupExcessTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
|
||||
if userId == 0 {
|
||||
return
|
||||
}
|
||||
// 获取用户所有令牌,按发行日期降序排序
|
||||
tokens, err := repository.GetTokensByUserId(userId)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取用户Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
// 如果令牌数量不超过上限,无需清理
|
||||
if len(tokens) <= TokensMaxCount {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取需要删除的令牌ID列表
|
||||
tokensToDelete := make([]string, 0, len(tokens)-TokensMaxCount)
|
||||
for i := TokensMaxCount; i < len(tokens); i++ {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
// 执行批量删除,传入上下文和待删除的令牌列表(作为切片参数)
|
||||
DeletedCount, err := repository.BatchDeleteTokens(tokensToDelete)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 清理用户多余Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10)))
|
||||
return
|
||||
}
|
||||
|
||||
if DeletedCount > 0 {
|
||||
logger.Info("[INFO] 成功清理用户多余Token", zap.Any("userId:", userId), zap.Any("count:", DeletedCount))
|
||||
}
|
||||
}
|
||||
|
||||
// ValidToken 验证令牌有效性
|
||||
func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool {
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 使用投影只获取需要的字段
|
||||
var token *model.Token
|
||||
token, err := repository.FindTokenByID(accessToken)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !token.Usable {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果客户端令牌为空,只验证访问令牌
|
||||
if clientToken == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// 否则验证客户端令牌是否匹配
|
||||
return token.ClientToken == clientToken
|
||||
}
|
||||
|
||||
func GetUUIDByAccessToken(db *gorm.DB, accessToken string) (string, error) {
|
||||
return repository.GetUUIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
func GetUserIDByAccessToken(db *gorm.DB, accessToken string) (int64, error) {
|
||||
return repository.GetUserIDByAccessToken(accessToken)
|
||||
}
|
||||
|
||||
// RefreshToken 刷新令牌
|
||||
func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken string, selectedProfileID string) (string, string, error) {
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := repository.GetTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
}
|
||||
logger.Error("[ERROR] 查询Token失败: ", zap.Error(err), zap.Any("accessToken:", accessToken))
|
||||
return "", "", fmt.Errorf("查询令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := ValidateProfileByUserID(db, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
logger.Error(
|
||||
"验证Profile失败",
|
||||
zap.Error(err),
|
||||
zap.Any("userId", oldToken.UserID),
|
||||
zap.String("profileId", selectedProfileID),
|
||||
)
|
||||
return "", "", fmt.Errorf("验证角色失败: %w", err)
|
||||
}
|
||||
if !valid {
|
||||
return "", "", errors.New("角色与用户不匹配")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 clientToken 是否有效
|
||||
if clientToken != "" && clientToken != oldToken.ClientToken {
|
||||
return "", "", errors.New("clientToken无效")
|
||||
}
|
||||
|
||||
// 检查 selectedProfileID 的逻辑
|
||||
if selectedProfileID != "" {
|
||||
if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID {
|
||||
return "", "", errors.New("原令牌已绑定角色,无法选择新角色")
|
||||
}
|
||||
} else {
|
||||
selectedProfileID = oldToken.ProfileId // 如果未指定,则保持原角色
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := uuid.New().String()
|
||||
newToken := model.Token{
|
||||
AccessToken: newAccessToken,
|
||||
ClientToken: oldToken.ClientToken, // 新令牌的 clientToken 与原令牌相同
|
||||
UserID: oldToken.UserID,
|
||||
Usable: true,
|
||||
ProfileId: selectedProfileID, // 绑定到指定角色或保持原角色
|
||||
IssueDate: time.Now(),
|
||||
}
|
||||
|
||||
// 使用双重写入模式替代事务,先插入新令牌,再删除旧令牌
|
||||
|
||||
err = repository.CreateToken(&newToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"创建新Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = repository.DeleteTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
// 删除旧令牌失败,记录日志但不阻止操作,因为新令牌已成功创建
|
||||
logger.Warn(
|
||||
"删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
zap.String("oldToken", oldToken.AccessToken),
|
||||
zap.String("newToken", newAccessToken),
|
||||
)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"成功刷新Token",
|
||||
zap.Any("userId", oldToken.UserID),
|
||||
zap.String("accessToken", newAccessToken),
|
||||
)
|
||||
return newAccessToken, oldToken.ClientToken, nil
|
||||
}
|
||||
|
||||
// InvalidToken 使令牌失效
|
||||
func InvalidToken(db *gorm.DB, logger *zap.Logger, accessToken string) {
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := repository.DeleteTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"删除Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return
|
||||
}
|
||||
logger.Info("[INFO] 成功删除", zap.Any("Token:", accessToken))
|
||||
|
||||
}
|
||||
|
||||
// InvalidUserTokens 使用户所有令牌失效
|
||||
func InvalidUserTokens(db *gorm.DB, logger *zap.Logger, userId int64) {
|
||||
if userId == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := repository.DeleteTokenByUserId(userId)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"[ERROR]删除用户Token失败",
|
||||
zap.Error(err),
|
||||
zap.Any("userId", userId),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("[INFO] 成功删除用户Token", zap.Any("userId:", userId))
|
||||
|
||||
}
|
||||
204
internal/service/token_service_test.go
Normal file
204
internal/service/token_service_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestTokenService_Constants 测试Token服务相关常量
|
||||
func TestTokenService_Constants(t *testing.T) {
|
||||
if ExtendedTimeout != 10*time.Second {
|
||||
t.Errorf("ExtendedTimeout = %v, want 10 seconds", ExtendedTimeout)
|
||||
}
|
||||
|
||||
if TokensMaxCount != 10 {
|
||||
t.Errorf("TokensMaxCount = %d, want 10", TokensMaxCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_Timeout 测试超时常量
|
||||
func TestTokenService_Timeout(t *testing.T) {
|
||||
if DefaultTimeout != 5*time.Second {
|
||||
t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout)
|
||||
}
|
||||
|
||||
if ExtendedTimeout <= DefaultTimeout {
|
||||
t.Errorf("ExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", ExtendedTimeout, DefaultTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_Validation 测试Token验证逻辑
|
||||
func TestTokenService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accessToken string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "空token无效",
|
||||
accessToken: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "非空token可能有效",
|
||||
accessToken: "valid-token-string",
|
||||
wantValid: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 测试空token检查逻辑
|
||||
isValid := tt.accessToken != ""
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Token validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ClientTokenLogic 测试ClientToken逻辑
|
||||
func TestTokenService_ClientTokenLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientToken string
|
||||
shouldGenerate bool
|
||||
}{
|
||||
{
|
||||
name: "空的clientToken应该生成新的",
|
||||
clientToken: "",
|
||||
shouldGenerate: true,
|
||||
},
|
||||
{
|
||||
name: "非空的clientToken应该使用提供的",
|
||||
clientToken: "existing-client-token",
|
||||
shouldGenerate: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldGenerate := tt.clientToken == ""
|
||||
if shouldGenerate != tt.shouldGenerate {
|
||||
t.Errorf("ClientToken logic failed: got %v, want %v", shouldGenerate, tt.shouldGenerate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_ProfileSelection 测试Profile选择逻辑
|
||||
func TestTokenService_ProfileSelection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileCount int
|
||||
shouldAutoSelect bool
|
||||
}{
|
||||
{
|
||||
name: "只有一个profile时自动选择",
|
||||
profileCount: 1,
|
||||
shouldAutoSelect: true,
|
||||
},
|
||||
{
|
||||
name: "多个profile时不自动选择",
|
||||
profileCount: 2,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
{
|
||||
name: "没有profile时不自动选择",
|
||||
profileCount: 0,
|
||||
shouldAutoSelect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldAutoSelect := tt.profileCount == 1
|
||||
if shouldAutoSelect != tt.shouldAutoSelect {
|
||||
t.Errorf("Profile selection logic failed: got %v, want %v", shouldAutoSelect, tt.shouldAutoSelect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_CleanupLogic 测试清理逻辑
|
||||
func TestTokenService_CleanupLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenCount int
|
||||
maxCount int
|
||||
shouldCleanup bool
|
||||
cleanupCount int
|
||||
}{
|
||||
{
|
||||
name: "token数量未超过上限,不需要清理",
|
||||
tokenCount: 5,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
{
|
||||
name: "token数量超过上限,需要清理",
|
||||
tokenCount: 15,
|
||||
maxCount: 10,
|
||||
shouldCleanup: true,
|
||||
cleanupCount: 5,
|
||||
},
|
||||
{
|
||||
name: "token数量等于上限,不需要清理",
|
||||
tokenCount: 10,
|
||||
maxCount: 10,
|
||||
shouldCleanup: false,
|
||||
cleanupCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
shouldCleanup := tt.tokenCount > tt.maxCount
|
||||
if shouldCleanup != tt.shouldCleanup {
|
||||
t.Errorf("Cleanup decision failed: got %v, want %v", shouldCleanup, tt.shouldCleanup)
|
||||
}
|
||||
|
||||
if shouldCleanup {
|
||||
expectedCleanupCount := tt.tokenCount - tt.maxCount
|
||||
if expectedCleanupCount != tt.cleanupCount {
|
||||
t.Errorf("Cleanup count failed: got %d, want %d", expectedCleanupCount, tt.cleanupCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenService_UserIDValidation 测试UserID验证
|
||||
func TestTokenService_UserIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
isValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的UserID",
|
||||
userID: 1,
|
||||
isValid: true,
|
||||
},
|
||||
{
|
||||
name: "UserID为0时无效",
|
||||
userID: 0,
|
||||
isValid: false,
|
||||
},
|
||||
{
|
||||
name: "负数UserID无效",
|
||||
userID: -1,
|
||||
isValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.userID > 0
|
||||
if isValid != tt.isValid {
|
||||
t.Errorf("UserID validation failed: got %v, want %v", isValid, tt.isValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
160
internal/service/upload_service.go
Normal file
160
internal/service/upload_service.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/storage"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileType 文件类型枚举
|
||||
type FileType string
|
||||
|
||||
const (
|
||||
FileTypeAvatar FileType = "avatar"
|
||||
FileTypeTexture FileType = "texture"
|
||||
)
|
||||
|
||||
// UploadConfig 上传配置
|
||||
type UploadConfig struct {
|
||||
AllowedExts map[string]bool // 允许的文件扩展名
|
||||
MinSize int64 // 最小文件大小(字节)
|
||||
MaxSize int64 // 最大文件大小(字节)
|
||||
Expires time.Duration // URL过期时间
|
||||
}
|
||||
|
||||
// GetUploadConfig 根据文件类型获取上传配置
|
||||
func GetUploadConfig(fileType FileType) *UploadConfig {
|
||||
switch fileType {
|
||||
case FileTypeAvatar:
|
||||
return &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".jpg": true,
|
||||
".jpeg": true,
|
||||
".png": true,
|
||||
".gif": true,
|
||||
".webp": true,
|
||||
},
|
||||
MinSize: 1024, // 1KB
|
||||
MaxSize: 5 * 1024 * 1024, // 5MB
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
case FileTypeTexture:
|
||||
return &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".png": true,
|
||||
},
|
||||
MinSize: 1024, // 1KB
|
||||
MaxSize: 10 * 1024 * 1024, // 10MB
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateFileName 验证文件名
|
||||
func ValidateFileName(fileName string, fileType FileType) error {
|
||||
if fileName == "" {
|
||||
return fmt.Errorf("文件名不能为空")
|
||||
}
|
||||
|
||||
uploadConfig := GetUploadConfig(fileType)
|
||||
if uploadConfig == nil {
|
||||
return fmt.Errorf("不支持的文件类型")
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(fileName))
|
||||
if !uploadConfig.AllowedExts[ext] {
|
||||
return fmt.Errorf("不支持的文件格式: %s", ext)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURL 生成头像上传URL
|
||||
func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) {
|
||||
// 1. 验证文件名
|
||||
if err := ValidateFileName(fileName, FileTypeAvatar); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 获取上传配置
|
||||
uploadConfig := GetUploadConfig(FileTypeAvatar)
|
||||
|
||||
// 3. 获取存储桶名称
|
||||
bucketName, err := storageClient.GetBucket("avatars")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取存储桶失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 生成对象名称(路径)
|
||||
// 格式: user_{userId}/timestamp_{originalFileName}
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
objectName := fmt.Sprintf("user_%d/%s_%s", userID, timestamp, fileName)
|
||||
|
||||
// 5. 生成预签名POST URL
|
||||
result, err := storageClient.GeneratePresignedPostURL(
|
||||
ctx,
|
||||
bucketName,
|
||||
objectName,
|
||||
uploadConfig.MinSize,
|
||||
uploadConfig.MaxSize,
|
||||
uploadConfig.Expires,
|
||||
cfg.UseSSL,
|
||||
cfg.Endpoint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成上传URL失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURL 生成材质上传URL
|
||||
func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.StorageClient, cfg config.RustFSConfig, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) {
|
||||
// 1. 验证文件名
|
||||
if err := ValidateFileName(fileName, FileTypeTexture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 验证材质类型
|
||||
if textureType != "SKIN" && textureType != "CAPE" {
|
||||
return nil, fmt.Errorf("无效的材质类型: %s", textureType)
|
||||
}
|
||||
|
||||
// 3. 获取上传配置
|
||||
uploadConfig := GetUploadConfig(FileTypeTexture)
|
||||
|
||||
// 4. 获取存储桶名称
|
||||
bucketName, err := storageClient.GetBucket("textures")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取存储桶失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 生成对象名称(路径)
|
||||
// 格式: user_{userId}/{textureType}/timestamp_{originalFileName}
|
||||
timestamp := time.Now().Format("20060102150405")
|
||||
textureTypeFolder := strings.ToLower(textureType)
|
||||
objectName := fmt.Sprintf("user_%d/%s/%s_%s", userID, textureTypeFolder, timestamp, fileName)
|
||||
|
||||
// 6. 生成预签名POST URL
|
||||
result, err := storageClient.GeneratePresignedPostURL(
|
||||
ctx,
|
||||
bucketName,
|
||||
objectName,
|
||||
uploadConfig.MinSize,
|
||||
uploadConfig.MaxSize,
|
||||
uploadConfig.Expires,
|
||||
cfg.UseSSL,
|
||||
cfg.Endpoint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成上传URL失败: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
279
internal/service/upload_service_test.go
Normal file
279
internal/service/upload_service_test.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestUploadService_FileTypes 测试文件类型常量
|
||||
func TestUploadService_FileTypes(t *testing.T) {
|
||||
if FileTypeAvatar == "" {
|
||||
t.Error("FileTypeAvatar should not be empty")
|
||||
}
|
||||
|
||||
if FileTypeTexture == "" {
|
||||
t.Error("FileTypeTexture should not be empty")
|
||||
}
|
||||
|
||||
if FileTypeAvatar == FileTypeTexture {
|
||||
t.Error("FileTypeAvatar and FileTypeTexture should be different")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig 测试获取上传配置
|
||||
func TestGetUploadConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileType FileType
|
||||
wantConfig bool
|
||||
}{
|
||||
{
|
||||
name: "头像类型返回配置",
|
||||
fileType: FileTypeAvatar,
|
||||
wantConfig: true,
|
||||
},
|
||||
{
|
||||
name: "材质类型返回配置",
|
||||
fileType: FileTypeTexture,
|
||||
wantConfig: true,
|
||||
},
|
||||
{
|
||||
name: "无效类型返回nil",
|
||||
fileType: FileType("invalid"),
|
||||
wantConfig: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := GetUploadConfig(tt.fileType)
|
||||
hasConfig := config != nil
|
||||
if hasConfig != tt.wantConfig {
|
||||
t.Errorf("GetUploadConfig() = %v, want %v", hasConfig, tt.wantConfig)
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
// 验证配置字段
|
||||
if config.MinSize <= 0 {
|
||||
t.Error("MinSize should be greater than 0")
|
||||
}
|
||||
if config.MaxSize <= 0 {
|
||||
t.Error("MaxSize should be greater than 0")
|
||||
}
|
||||
if config.MaxSize < config.MinSize {
|
||||
t.Error("MaxSize should be greater than or equal to MinSize")
|
||||
}
|
||||
if config.Expires <= 0 {
|
||||
t.Error("Expires should be greater than 0")
|
||||
}
|
||||
if len(config.AllowedExts) == 0 {
|
||||
t.Error("AllowedExts should not be empty")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig_AvatarConfig 测试头像配置详情
|
||||
func TestGetUploadConfig_AvatarConfig(t *testing.T) {
|
||||
config := GetUploadConfig(FileTypeAvatar)
|
||||
if config == nil {
|
||||
t.Fatal("Avatar config should not be nil")
|
||||
}
|
||||
|
||||
// 验证允许的扩展名
|
||||
expectedExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
for _, ext := range expectedExts {
|
||||
if !config.AllowedExts[ext] {
|
||||
t.Errorf("Avatar config should allow %s extension", ext)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证文件大小限制
|
||||
if config.MinSize != 1024 {
|
||||
t.Errorf("Avatar MinSize = %d, want 1024", config.MinSize)
|
||||
}
|
||||
|
||||
if config.MaxSize != 5*1024*1024 {
|
||||
t.Errorf("Avatar MaxSize = %d, want 5MB", config.MaxSize)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if config.Expires != 15*time.Minute {
|
||||
t.Errorf("Avatar Expires = %v, want 15 minutes", config.Expires)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetUploadConfig_TextureConfig 测试材质配置详情
|
||||
func TestGetUploadConfig_TextureConfig(t *testing.T) {
|
||||
config := GetUploadConfig(FileTypeTexture)
|
||||
if config == nil {
|
||||
t.Fatal("Texture config should not be nil")
|
||||
}
|
||||
|
||||
// 验证允许的扩展名(材质只允许PNG)
|
||||
if !config.AllowedExts[".png"] {
|
||||
t.Error("Texture config should allow .png extension")
|
||||
}
|
||||
|
||||
// 验证文件大小限制
|
||||
if config.MinSize != 1024 {
|
||||
t.Errorf("Texture MinSize = %d, want 1024", config.MinSize)
|
||||
}
|
||||
|
||||
if config.MaxSize != 10*1024*1024 {
|
||||
t.Errorf("Texture MaxSize = %d, want 10MB", config.MaxSize)
|
||||
}
|
||||
|
||||
// 验证过期时间
|
||||
if config.Expires != 15*time.Minute {
|
||||
t.Errorf("Texture Expires = %v, want 15 minutes", config.Expires)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName 测试文件名验证
|
||||
func TestValidateFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fileName string
|
||||
fileType FileType
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "有效的头像文件名",
|
||||
fileName: "avatar.png",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效的材质文件名",
|
||||
fileName: "texture.png",
|
||||
fileType: FileTypeTexture,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "文件名为空",
|
||||
fileName: "",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: true,
|
||||
errContains: "文件名不能为空",
|
||||
},
|
||||
{
|
||||
name: "不支持的文件扩展名",
|
||||
fileName: "file.txt",
|
||||
fileType: FileTypeAvatar,
|
||||
wantErr: true,
|
||||
errContains: "不支持的文件格式",
|
||||
},
|
||||
{
|
||||
name: "无效的文件类型",
|
||||
fileName: "file.png",
|
||||
fileType: FileType("invalid"),
|
||||
wantErr: true,
|
||||
errContains: "不支持的文件类型",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateFileName(tt.fileName, tt.fileType)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ValidateFileName() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && tt.errContains != "" {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("ValidateFileName() error = %v, should contain %s", err, tt.errContains)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName_Extensions 测试各种扩展名
|
||||
func TestValidateFileName_Extensions(t *testing.T) {
|
||||
avatarExts := []string{".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
for _, ext := range avatarExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeAvatar)
|
||||
if err != nil {
|
||||
t.Errorf("Avatar file with %s extension should be valid, got error: %v", ext, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 材质只支持PNG
|
||||
textureExts := []string{".png"}
|
||||
for _, ext := range textureExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeTexture)
|
||||
if err != nil {
|
||||
t.Errorf("Texture file with %s extension should be valid, got error: %v", ext, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试不支持的扩展名
|
||||
invalidExts := []string{".txt", ".pdf", ".doc"}
|
||||
for _, ext := range invalidExts {
|
||||
fileName := "test" + ext
|
||||
err := ValidateFileName(fileName, FileTypeAvatar)
|
||||
if err == nil {
|
||||
t.Errorf("Avatar file with %s extension should be invalid", ext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateFileName_CaseInsensitive 测试扩展名大小写不敏感
|
||||
func TestValidateFileName_CaseInsensitive(t *testing.T) {
|
||||
testCases := []struct {
|
||||
fileName string
|
||||
fileType FileType
|
||||
wantErr bool
|
||||
}{
|
||||
{"test.PNG", FileTypeAvatar, false},
|
||||
{"test.JPG", FileTypeAvatar, false},
|
||||
{"test.JPEG", FileTypeAvatar, false},
|
||||
{"test.GIF", FileTypeAvatar, false},
|
||||
{"test.WEBP", FileTypeAvatar, false},
|
||||
{"test.PnG", FileTypeTexture, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.fileName, func(t *testing.T) {
|
||||
err := ValidateFileName(tc.fileName, tc.fileType)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("ValidateFileName(%s, %s) error = %v, wantErr %v", tc.fileName, tc.fileType, err, tc.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUploadConfig_Structure 测试UploadConfig结构
|
||||
func TestUploadConfig_Structure(t *testing.T) {
|
||||
config := &UploadConfig{
|
||||
AllowedExts: map[string]bool{
|
||||
".png": true,
|
||||
},
|
||||
MinSize: 1024,
|
||||
MaxSize: 5 * 1024 * 1024,
|
||||
Expires: 15 * time.Minute,
|
||||
}
|
||||
|
||||
if config.AllowedExts == nil {
|
||||
t.Error("AllowedExts should not be nil")
|
||||
}
|
||||
|
||||
if config.MinSize <= 0 {
|
||||
t.Error("MinSize should be greater than 0")
|
||||
}
|
||||
|
||||
if config.MaxSize <= config.MinSize {
|
||||
t.Error("MaxSize should be greater than MinSize")
|
||||
}
|
||||
|
||||
if config.Expires <= 0 {
|
||||
t.Error("Expires should be greater than 0")
|
||||
}
|
||||
}
|
||||
|
||||
248
internal/service/user_service.go
Normal file
248
internal/service/user_service.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RegisterUser 用户注册
|
||||
func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) {
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := repository.FindUserByUsername(username)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingUser != nil {
|
||||
return nil, "", errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existingEmail, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingEmail != nil {
|
||||
return nil, "", errors.New("邮箱已被注册")
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := auth.HashPassword(password)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 确定头像URL:优先使用用户提供的头像,否则使用默认头像
|
||||
avatarURL := avatar
|
||||
if avatarURL == "" {
|
||||
avatarURL = getDefaultAvatar()
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &model.User{
|
||||
Username: username,
|
||||
Password: hashedPassword,
|
||||
Email: email,
|
||||
Avatar: avatarURL,
|
||||
Role: "user",
|
||||
Status: 1,
|
||||
Points: 0, // 初始积分可以从配置读取
|
||||
}
|
||||
|
||||
if err := repository.CreateUser(user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// TODO: 添加注册奖励积分
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// LoginUser 用户登录(支持用户名或邮箱登录)
|
||||
func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) {
|
||||
// 查找用户:判断是用户名还是邮箱
|
||||
var user *model.User
|
||||
var err error
|
||||
|
||||
if strings.Contains(usernameOrEmail, "@") {
|
||||
// 包含@符号,认为是邮箱
|
||||
user, err = repository.FindUserByEmail(usernameOrEmail)
|
||||
} else {
|
||||
// 否则认为是用户名
|
||||
user, err = repository.FindUserByUsername(usernameOrEmail)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if user == nil {
|
||||
// 记录失败日志
|
||||
logFailedLogin(0, ipAddress, userAgent, "用户不存在")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if user.Status != 1 {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用")
|
||||
return nil, "", errors.New("账号已被禁用")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !auth.CheckPassword(user.Password, password) {
|
||||
logFailedLogin(user.ID, ipAddress, userAgent, "密码错误")
|
||||
return nil, "", errors.New("用户名/邮箱或密码错误")
|
||||
}
|
||||
|
||||
// 生成JWT Token
|
||||
token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role)
|
||||
if err != nil {
|
||||
return nil, "", errors.New("生成Token失败")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
_ = repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
"last_login_at": now,
|
||||
})
|
||||
|
||||
// 记录成功登录日志
|
||||
logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
// GetUserByID 根据ID获取用户
|
||||
func GetUserByID(id int64) (*model.User, error) {
|
||||
return repository.FindUserByID(id)
|
||||
}
|
||||
|
||||
// UpdateUserInfo 更新用户信息
|
||||
func UpdateUserInfo(user *model.User) error {
|
||||
return repository.UpdateUser(user)
|
||||
}
|
||||
|
||||
// UpdateUserAvatar 更新用户头像
|
||||
func UpdateUserAvatar(userID int64, avatarURL string) error {
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"avatar": avatarURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserPassword 修改密码
|
||||
func ChangeUserPassword(userID int64, oldPassword, newPassword string) error {
|
||||
// 获取用户
|
||||
user, err := repository.FindUserByID(userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !auth.CheckPassword(user.Password, oldPassword) {
|
||||
return errors.New("原密码错误")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ResetUserPassword 重置密码(通过邮箱)
|
||||
func ResetUserPassword(email, newPassword string) error {
|
||||
// 查找用户
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
return repository.UpdateUserFields(user.ID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
}
|
||||
|
||||
// ChangeUserEmail 更换邮箱
|
||||
func ChangeUserEmail(userID int64, newEmail string) error {
|
||||
// 检查新邮箱是否已被使用
|
||||
existingUser, err := repository.FindUserByEmail(newEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existingUser != nil && existingUser.ID != userID {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
}
|
||||
|
||||
// 更新邮箱
|
||||
return repository.UpdateUserFields(userID, map[string]interface{}{
|
||||
"email": newEmail,
|
||||
})
|
||||
}
|
||||
|
||||
// logSuccessLogin 记录成功登录
|
||||
func logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// logFailedLogin 记录失败登录
|
||||
func logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
_ = repository.CreateLoginLog(log)
|
||||
}
|
||||
|
||||
// getDefaultAvatar 获取默认头像URL
|
||||
func getDefaultAvatar() string {
|
||||
// 如果数据库中不存在默认头像配置,返回错误信息
|
||||
const log = "数据库中不存在默认头像配置"
|
||||
|
||||
// 尝试从数据库读取配置
|
||||
config, err := repository.GetSystemConfigByKey("default_avatar")
|
||||
if err != nil || config == nil {
|
||||
return log
|
||||
}
|
||||
|
||||
return config.Value
|
||||
}
|
||||
|
||||
func GetUserByEmail(email string) (*model.User, error) {
|
||||
user, err := repository.FindUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, errors.New("邮箱查找失败")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
199
internal/service/user_service_test.go
Normal file
199
internal/service/user_service_test.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestGetDefaultAvatar 测试获取默认头像的逻辑
|
||||
// 注意:这个测试需要mock repository,但由于repository是函数式的,
|
||||
// 我们只测试逻辑部分
|
||||
func TestGetDefaultAvatar_Logic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configExists bool
|
||||
configValue string
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
name: "配置存在时返回配置值",
|
||||
configExists: true,
|
||||
configValue: "https://example.com/avatar.png",
|
||||
expectedResult: "https://example.com/avatar.png",
|
||||
},
|
||||
{
|
||||
name: "配置不存在时返回错误信息",
|
||||
configExists: false,
|
||||
configValue: "",
|
||||
expectedResult: "数据库中不存在默认头像配置",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 这个测试只验证逻辑,不实际调用repository
|
||||
// 实际的repository调用测试需要集成测试或mock
|
||||
if tt.configExists {
|
||||
if tt.expectedResult != tt.configValue {
|
||||
t.Errorf("当配置存在时,应该返回配置值")
|
||||
}
|
||||
} else {
|
||||
if !strings.Contains(tt.expectedResult, "数据库中不存在默认头像配置") {
|
||||
t.Errorf("当配置不存在时,应该返回错误信息")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginUser_EmailDetection 测试登录时邮箱检测逻辑
|
||||
func TestLoginUser_EmailDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
usernameOrEmail string
|
||||
isEmail bool
|
||||
}{
|
||||
{
|
||||
name: "包含@符号,识别为邮箱",
|
||||
usernameOrEmail: "user@example.com",
|
||||
isEmail: true,
|
||||
},
|
||||
{
|
||||
name: "不包含@符号,识别为用户名",
|
||||
usernameOrEmail: "username",
|
||||
isEmail: false,
|
||||
},
|
||||
{
|
||||
name: "空字符串",
|
||||
usernameOrEmail: "",
|
||||
isEmail: false,
|
||||
},
|
||||
{
|
||||
name: "只有@符号",
|
||||
usernameOrEmail: "@",
|
||||
isEmail: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isEmail := strings.Contains(tt.usernameOrEmail, "@")
|
||||
if isEmail != tt.isEmail {
|
||||
t.Errorf("Email detection failed: got %v, want %v", isEmail, tt.isEmail)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_Constants 测试用户服务相关常量
|
||||
func TestUserService_Constants(t *testing.T) {
|
||||
// 测试默认用户角色
|
||||
defaultRole := "user"
|
||||
if defaultRole == "" {
|
||||
t.Error("默认用户角色不能为空")
|
||||
}
|
||||
|
||||
// 测试默认用户状态
|
||||
defaultStatus := int16(1)
|
||||
if defaultStatus != 1 {
|
||||
t.Errorf("默认用户状态应为1(正常),实际为%d", defaultStatus)
|
||||
}
|
||||
|
||||
// 测试初始积分
|
||||
initialPoints := 0
|
||||
if initialPoints < 0 {
|
||||
t.Errorf("初始积分不应为负数,实际为%d", initialPoints)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_Validation 测试用户数据验证逻辑
|
||||
func TestUserService_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的用户名和邮箱",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "test@example.com",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱为空",
|
||||
username: "testuser",
|
||||
email: "",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
email: "test@example.com",
|
||||
password: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱格式无效(缺少@)",
|
||||
username: "testuser",
|
||||
email: "invalid-email",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 简单的验证逻辑测试
|
||||
isValid := tt.username != "" && tt.email != "" && tt.password != "" && strings.Contains(tt.email, "@")
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserService_AvatarLogic 测试头像逻辑
|
||||
func TestUserService_AvatarLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providedAvatar string
|
||||
defaultAvatar string
|
||||
expectedAvatar string
|
||||
}{
|
||||
{
|
||||
name: "提供头像时使用提供的头像",
|
||||
providedAvatar: "https://example.com/custom.png",
|
||||
defaultAvatar: "https://example.com/default.png",
|
||||
expectedAvatar: "https://example.com/custom.png",
|
||||
},
|
||||
{
|
||||
name: "未提供头像时使用默认头像",
|
||||
providedAvatar: "",
|
||||
defaultAvatar: "https://example.com/default.png",
|
||||
expectedAvatar: "https://example.com/default.png",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
avatarURL := tt.providedAvatar
|
||||
if avatarURL == "" {
|
||||
avatarURL = tt.defaultAvatar
|
||||
}
|
||||
if avatarURL != tt.expectedAvatar {
|
||||
t.Errorf("Avatar logic failed: got %s, want %s", avatarURL, tt.expectedAvatar)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
118
internal/service/verification_service.go
Normal file
118
internal/service/verification_service.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/email"
|
||||
"carrotskin/pkg/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
// 验证码类型
|
||||
VerificationTypeRegister = "register"
|
||||
VerificationTypeResetPassword = "reset_password"
|
||||
VerificationTypeChangeEmail = "change_email"
|
||||
|
||||
// 验证码配置
|
||||
CodeLength = 6 // 验证码长度
|
||||
CodeExpiration = 10 * time.Minute // 验证码有效期
|
||||
CodeRateLimit = 1 * time.Minute // 发送频率限制
|
||||
)
|
||||
|
||||
// GenerateVerificationCode 生成6位数字验证码
|
||||
func GenerateVerificationCode() (string, error) {
|
||||
const digits = "0123456789"
|
||||
code := make([]byte, CodeLength)
|
||||
for i := range code {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(digits))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code[i] = digits[num.Int64()]
|
||||
}
|
||||
return string(code), nil
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
func SendVerificationCode(ctx context.Context, redisClient *redis.Client, emailService *email.Service, email, codeType string) error {
|
||||
// 检查发送频率限制
|
||||
rateLimitKey := fmt.Sprintf("verification:rate_limit:%s:%s", codeType, email)
|
||||
exists, err := redisClient.Exists(ctx, rateLimitKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("检查发送频率失败: %w", err)
|
||||
}
|
||||
if exists > 0 {
|
||||
return fmt.Errorf("发送过于频繁,请稍后再试")
|
||||
}
|
||||
|
||||
// 生成验证码
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
return fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储验证码到Redis
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
if err := redisClient.Set(ctx, codeKey, code, CodeExpiration); err != nil {
|
||||
return fmt.Errorf("存储验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置发送频率限制
|
||||
if err := redisClient.Set(ctx, rateLimitKey, "1", CodeRateLimit); err != nil {
|
||||
return fmt.Errorf("设置发送频率限制失败: %w", err)
|
||||
}
|
||||
|
||||
// 发送邮件
|
||||
if err := sendVerificationEmail(emailService, email, code, codeType); err != nil {
|
||||
// 发送失败,删除验证码
|
||||
_ = redisClient.Del(ctx, codeKey)
|
||||
return fmt.Errorf("发送邮件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyCode 验证验证码
|
||||
func VerifyCode(ctx context.Context, redisClient *redis.Client, email, code, codeType string) error {
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
|
||||
// 从Redis获取验证码
|
||||
storedCode, err := redisClient.Get(ctx, codeKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("验证码已过期或不存在")
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
if storedCode != code {
|
||||
return fmt.Errorf("验证码错误")
|
||||
}
|
||||
|
||||
// 验证成功,删除验证码
|
||||
_ = redisClient.Del(ctx, codeKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteVerificationCode 删除验证码
|
||||
func DeleteVerificationCode(ctx context.Context, redisClient *redis.Client, email, codeType string) error {
|
||||
codeKey := fmt.Sprintf("verification:code:%s:%s", codeType, email)
|
||||
return redisClient.Del(ctx, codeKey)
|
||||
}
|
||||
|
||||
// sendVerificationEmail 根据类型发送邮件
|
||||
func sendVerificationEmail(emailService *email.Service, to, code, codeType string) error {
|
||||
switch codeType {
|
||||
case VerificationTypeRegister:
|
||||
return emailService.SendEmailVerification(to, code)
|
||||
case VerificationTypeResetPassword:
|
||||
return emailService.SendResetPassword(to, code)
|
||||
case VerificationTypeChangeEmail:
|
||||
return emailService.SendChangeEmail(to, code)
|
||||
default:
|
||||
return emailService.SendVerificationCode(to, code, codeType)
|
||||
}
|
||||
}
|
||||
119
internal/service/verification_service_test.go
Normal file
119
internal/service/verification_service_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestGenerateVerificationCode 测试生成验证码函数
|
||||
func TestGenerateVerificationCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLen int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "生成6位验证码",
|
||||
wantLen: CodeLength,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
code, err := GenerateVerificationCode()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateVerificationCode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && len(code) != tt.wantLen {
|
||||
t.Errorf("GenerateVerificationCode() code length = %v, want %v", len(code), tt.wantLen)
|
||||
}
|
||||
// 验证验证码只包含数字
|
||||
for _, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("GenerateVerificationCode() code contains non-digit: %c", c)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 测试多次生成,验证码应该不同(概率上)
|
||||
codes := make(map[string]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateVerificationCode() failed: %v", err)
|
||||
}
|
||||
if codes[code] {
|
||||
t.Logf("发现重复验证码(这是正常的,因为只有6位数字): %s", code)
|
||||
}
|
||||
codes[code] = true
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationConstants 测试验证码相关常量
|
||||
func TestVerificationConstants(t *testing.T) {
|
||||
if CodeLength != 6 {
|
||||
t.Errorf("CodeLength = %d, want 6", CodeLength)
|
||||
}
|
||||
|
||||
if CodeExpiration != 10*time.Minute {
|
||||
t.Errorf("CodeExpiration = %v, want 10 minutes", CodeExpiration)
|
||||
}
|
||||
|
||||
if CodeRateLimit != 1*time.Minute {
|
||||
t.Errorf("CodeRateLimit = %v, want 1 minute", CodeRateLimit)
|
||||
}
|
||||
|
||||
// 验证验证码类型常量
|
||||
types := []string{
|
||||
VerificationTypeRegister,
|
||||
VerificationTypeResetPassword,
|
||||
VerificationTypeChangeEmail,
|
||||
}
|
||||
|
||||
for _, vType := range types {
|
||||
if vType == "" {
|
||||
t.Error("验证码类型不能为空")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationCodeFormat 测试验证码格式
|
||||
func TestVerificationCodeFormat(t *testing.T) {
|
||||
code, err := GenerateVerificationCode()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateVerificationCode() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证长度
|
||||
if len(code) != 6 {
|
||||
t.Errorf("验证码长度应为6位,实际为%d位", len(code))
|
||||
}
|
||||
|
||||
// 验证只包含数字
|
||||
for i, c := range code {
|
||||
if c < '0' || c > '9' {
|
||||
t.Errorf("验证码第%d位包含非数字字符: %c", i+1, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestVerificationTypes 测试验证码类型
|
||||
func TestVerificationTypes(t *testing.T) {
|
||||
validTypes := map[string]bool{
|
||||
VerificationTypeRegister: true,
|
||||
VerificationTypeResetPassword: true,
|
||||
VerificationTypeChangeEmail: true,
|
||||
}
|
||||
|
||||
for vType, isValid := range validTypes {
|
||||
if !isValid {
|
||||
t.Errorf("验证码类型 %s 应该是有效的", vType)
|
||||
}
|
||||
if vType == "" {
|
||||
t.Error("验证码类型不能为空字符串")
|
||||
}
|
||||
}
|
||||
}
|
||||
201
internal/service/yggdrasil_service.go
Normal file
201
internal/service/yggdrasil_service.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/redis"
|
||||
"carrotskin/pkg/utils"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SessionKeyPrefix Redis会话键前缀
|
||||
const SessionKeyPrefix = "Join_"
|
||||
|
||||
// SessionTTL 会话超时时间 - 增加到15分钟
|
||||
const SessionTTL = 15 * time.Minute
|
||||
|
||||
type SessionData struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
UserName string `json:"userName"`
|
||||
SelectedProfile string `json:"selectedProfile"`
|
||||
IP string `json:"ip"`
|
||||
}
|
||||
|
||||
// GetUserIDByEmail 根据邮箱返回用户id
|
||||
func GetUserIDByEmail(db *gorm.DB, Identifier string) (int64, error) {
|
||||
user, err := repository.FindUserByEmail(Identifier)
|
||||
if err != nil {
|
||||
return 0, errors.New("用户不存在")
|
||||
}
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// GetProfileByProfileName 根据用户名返回用户id
|
||||
func GetProfileByProfileName(db *gorm.DB, Identifier string) (*model.Profile, error) {
|
||||
profile, err := repository.FindProfileByName(Identifier)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户角色未创建")
|
||||
}
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
// VerifyPassword 验证密码是否一致
|
||||
func VerifyPassword(db *gorm.DB, password string, Id int64) error {
|
||||
passwordStore, err := repository.GetYggdrasilPasswordById(Id)
|
||||
if err != nil {
|
||||
return errors.New("未生成密码")
|
||||
}
|
||||
if passwordStore != password {
|
||||
return errors.New("密码错误")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetProfileByUserId(db *gorm.DB, userId int64) (*model.Profile, error) {
|
||||
profiles, err := repository.FindProfilesByUserID(userId)
|
||||
if err != nil {
|
||||
return nil, errors.New("角色查找失败")
|
||||
}
|
||||
if len(profiles) == 0 {
|
||||
return nil, errors.New("角色查找失败")
|
||||
}
|
||||
return profiles[0], nil
|
||||
}
|
||||
|
||||
func GetPasswordByUserId(db *gorm.DB, userId int64) (string, error) {
|
||||
passwordStore, err := repository.GetYggdrasilPasswordById(userId)
|
||||
if err != nil {
|
||||
return "", errors.New("yggdrasil密码查找失败")
|
||||
}
|
||||
return passwordStore, nil
|
||||
}
|
||||
|
||||
// JoinServer 记录玩家加入服务器的会话信息
|
||||
func JoinServer(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client, serverId, accessToken, selectedProfile, ip string) error {
|
||||
// 输入验证
|
||||
if serverId == "" || accessToken == "" || selectedProfile == "" {
|
||||
return errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
// 验证serverId格式,防止注入攻击
|
||||
if len(serverId) > 100 || strings.ContainsAny(serverId, "<>\"'&") {
|
||||
return errors.New("服务器ID格式无效")
|
||||
}
|
||||
|
||||
// 验证IP格式
|
||||
if ip != "" {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return errors.New("IP地址格式无效")
|
||||
}
|
||||
}
|
||||
|
||||
// 获取和验证Token
|
||||
token, err := repository.GetTokenByAccessToken(accessToken)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"验证Token失败",
|
||||
zap.Error(err),
|
||||
zap.String("accessToken", accessToken),
|
||||
)
|
||||
return fmt.Errorf("验证Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 格式化UUID并验证与Token关联的配置文件
|
||||
formattedProfile := utils.FormatUUID(selectedProfile)
|
||||
if token.ProfileId != formattedProfile {
|
||||
return errors.New("selectedProfile与Token不匹配")
|
||||
}
|
||||
|
||||
profile, err := repository.FindProfileByUUID(formattedProfile)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"获取Profile失败",
|
||||
zap.Error(err),
|
||||
zap.String("uuid", formattedProfile),
|
||||
)
|
||||
return fmt.Errorf("获取Profile失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建会话数据
|
||||
data := SessionData{
|
||||
AccessToken: accessToken,
|
||||
UserName: profile.Name,
|
||||
SelectedProfile: formattedProfile,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
// 序列化会话数据
|
||||
marshaledData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"[ERROR]序列化会话数据失败",
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储会话数据到Redis
|
||||
sessionKey := SessionKeyPrefix + serverId
|
||||
ctx := context.Background()
|
||||
if err = redisClient.Set(ctx, sessionKey, marshaledData, SessionTTL); err != nil {
|
||||
logger.Error(
|
||||
"保存会话数据失败",
|
||||
zap.Error(err),
|
||||
zap.String("serverId", serverId),
|
||||
)
|
||||
return fmt.Errorf("保存会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"玩家成功加入服务器",
|
||||
zap.String("username", profile.Name),
|
||||
zap.String("serverId", serverId),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasJoinedServer 验证玩家是否已经加入了服务器
|
||||
func HasJoinedServer(logger *zap.Logger, redisClient *redis.Client, serverId, username, ip string) error {
|
||||
if serverId == "" || username == "" {
|
||||
return errors.New("服务器ID和用户名不能为空")
|
||||
}
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 从Redis获取会话数据
|
||||
sessionKey := SessionKeyPrefix + serverId
|
||||
data, err := redisClient.GetBytes(ctx, sessionKey)
|
||||
if err != nil {
|
||||
logger.Error("[ERROR] 获取会话数据失败:", zap.Error(err), zap.Any("serverId:", serverId))
|
||||
return fmt.Errorf("获取会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 反序列化会话数据
|
||||
var sessionData SessionData
|
||||
if err = json.Unmarshal(data, &sessionData); err != nil {
|
||||
logger.Error("[ERROR] 解析会话数据失败: ", zap.Error(err))
|
||||
return fmt.Errorf("解析会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证用户名
|
||||
if sessionData.UserName != username {
|
||||
return errors.New("用户名不匹配")
|
||||
}
|
||||
|
||||
// 验证IP(如果提供)
|
||||
if ip != "" && sessionData.IP != ip {
|
||||
return errors.New("IP地址不匹配")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
350
internal/service/yggdrasil_service_test.go
Normal file
350
internal/service/yggdrasil_service_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestYggdrasilService_Constants 测试Yggdrasil服务常量
|
||||
func TestYggdrasilService_Constants(t *testing.T) {
|
||||
if SessionKeyPrefix != "Join_" {
|
||||
t.Errorf("SessionKeyPrefix = %s, want 'Join_'", SessionKeyPrefix)
|
||||
}
|
||||
|
||||
if SessionTTL != 15*time.Minute {
|
||||
t.Errorf("SessionTTL = %v, want 15 minutes", SessionTTL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionData_Structure 测试SessionData结构
|
||||
func TestSessionData_Structure(t *testing.T) {
|
||||
data := SessionData{
|
||||
AccessToken: "test-token",
|
||||
UserName: "TestUser",
|
||||
SelectedProfile: "test-profile-uuid",
|
||||
IP: "127.0.0.1",
|
||||
}
|
||||
|
||||
if data.AccessToken == "" {
|
||||
t.Error("AccessToken should not be empty")
|
||||
}
|
||||
|
||||
if data.UserName == "" {
|
||||
t.Error("UserName should not be empty")
|
||||
}
|
||||
|
||||
if data.SelectedProfile == "" {
|
||||
t.Error("SelectedProfile should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_InputValidation 测试JoinServer输入验证逻辑
|
||||
func TestJoinServer_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
accessToken string
|
||||
selectedProfile string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "所有参数有效",
|
||||
serverId: "test-server-123",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "serverId为空",
|
||||
serverId: "",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
{
|
||||
name: "accessToken为空",
|
||||
serverId: "test-server",
|
||||
accessToken: "",
|
||||
selectedProfile: "test-profile",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
{
|
||||
name: "selectedProfile为空",
|
||||
serverId: "test-server",
|
||||
accessToken: "test-token",
|
||||
selectedProfile: "",
|
||||
wantErr: true,
|
||||
errContains: "参数不能为空",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.serverId == "" || tt.accessToken == "" || tt.selectedProfile == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_ServerIDValidation 测试服务器ID格式验证
|
||||
func TestJoinServer_ServerIDValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的serverId",
|
||||
serverId: "test-server-123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "serverId过长",
|
||||
serverId: strings.Repeat("a", 101),
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符<",
|
||||
serverId: "test<server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符>",
|
||||
serverId: "test>server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符\"",
|
||||
serverId: "test\"server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符'",
|
||||
serverId: "test'server",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "serverId包含危险字符&",
|
||||
serverId: "test&server",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := len(tt.serverId) <= 100 && !strings.ContainsAny(tt.serverId, "<>\"'&")
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("ServerID validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_IPValidation 测试IP地址验证逻辑
|
||||
func TestJoinServer_IPValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的IPv4地址",
|
||||
ip: "127.0.0.1",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "有效的IPv6地址",
|
||||
ip: "::1",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "无效的IP地址",
|
||||
ip: "invalid-ip",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "空IP地址(可选)",
|
||||
ip: "",
|
||||
wantValid: true, // 空IP是允许的
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var isValid bool
|
||||
if tt.ip == "" {
|
||||
isValid = true // 空IP是允许的
|
||||
} else {
|
||||
isValid = net.ParseIP(tt.ip) != nil
|
||||
}
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("IP validation failed: got %v, want %v (ip=%s)", isValid, tt.wantValid, tt.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_InputValidation 测试HasJoinedServer输入验证
|
||||
func TestHasJoinedServer_InputValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
username string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "所有参数有效",
|
||||
serverId: "test-server",
|
||||
username: "TestUser",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "serverId为空",
|
||||
serverId: "",
|
||||
username: "TestUser",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "username为空",
|
||||
serverId: "test-server",
|
||||
username: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "两者都为空",
|
||||
serverId: "",
|
||||
username: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hasError := tt.serverId == "" || tt.username == ""
|
||||
if hasError != tt.wantErr {
|
||||
t.Errorf("Input validation failed: got %v, want %v", hasError, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_UsernameMatching 测试用户名匹配逻辑
|
||||
func TestHasJoinedServer_UsernameMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionUser string
|
||||
requestUser string
|
||||
wantMatch bool
|
||||
}{
|
||||
{
|
||||
name: "用户名匹配",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "TestUser",
|
||||
wantMatch: true,
|
||||
},
|
||||
{
|
||||
name: "用户名不匹配",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "OtherUser",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "大小写敏感",
|
||||
sessionUser: "TestUser",
|
||||
requestUser: "testuser",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := tt.sessionUser == tt.requestUser
|
||||
if matches != tt.wantMatch {
|
||||
t.Errorf("Username matching failed: got %v, want %v", matches, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHasJoinedServer_IPMatching 测试IP地址匹配逻辑
|
||||
func TestHasJoinedServer_IPMatching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sessionIP string
|
||||
requestIP string
|
||||
wantMatch bool
|
||||
shouldCheck bool
|
||||
}{
|
||||
{
|
||||
name: "IP匹配",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "127.0.0.1",
|
||||
wantMatch: true,
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "IP不匹配",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "192.168.1.1",
|
||||
wantMatch: false,
|
||||
shouldCheck: true,
|
||||
},
|
||||
{
|
||||
name: "请求IP为空时不检查",
|
||||
sessionIP: "127.0.0.1",
|
||||
requestIP: "",
|
||||
wantMatch: true,
|
||||
shouldCheck: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var matches bool
|
||||
if tt.requestIP == "" {
|
||||
matches = true // 空IP不检查
|
||||
} else {
|
||||
matches = tt.sessionIP == tt.requestIP
|
||||
}
|
||||
if matches != tt.wantMatch {
|
||||
t.Errorf("IP matching failed: got %v, want %v", matches, tt.wantMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJoinServer_SessionKey 测试会话键生成
|
||||
func TestJoinServer_SessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverId string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "生成正确的会话键",
|
||||
serverId: "test-server-123",
|
||||
expected: "Join_test-server-123",
|
||||
},
|
||||
{
|
||||
name: "空serverId",
|
||||
serverId: "",
|
||||
expected: "Join_",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sessionKey := SessionKeyPrefix + tt.serverId
|
||||
if sessionKey != tt.expected {
|
||||
t.Errorf("Session key = %s, want %s", sessionKey, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
215
internal/types/common.go
Normal file
215
internal/types/common.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package types
|
||||
|
||||
import "time"
|
||||
|
||||
// BaseResponse 基础响应结构
|
||||
type BaseResponse struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// PaginationRequest 分页请求
|
||||
type PaginationRequest struct {
|
||||
Page int `json:"page" form:"page" binding:"omitempty,min=1"`
|
||||
PageSize int `json:"page_size" form:"page_size" binding:"omitempty,min=1,max=100"`
|
||||
}
|
||||
|
||||
// PaginationResponse 分页响应
|
||||
type PaginationResponse struct {
|
||||
List interface{} `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
// LoginRequest 登录请求
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username" binding:"required" example:"testuser"` // 支持用户名或邮箱
|
||||
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
|
||||
}
|
||||
|
||||
// RegisterRequest 注册请求
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=50" example:"newuser"`
|
||||
Email string `json:"email" binding:"required,email" example:"user@example.com"`
|
||||
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
|
||||
VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"` // 邮箱验证码
|
||||
Avatar string `json:"avatar" binding:"omitempty,url" example:"https://rustfs.example.com/avatars/user_1/avatar.png"` // 可选,用户自定义头像
|
||||
}
|
||||
|
||||
// UpdateUserRequest 更新用户请求
|
||||
type UpdateUserRequest struct {
|
||||
Avatar string `json:"avatar" binding:"omitempty,url" example:"https://example.com/new-avatar.png"`
|
||||
OldPassword string `json:"old_password" binding:"omitempty,min=6,max=128" example:"oldpassword123"` // 修改密码时必需
|
||||
NewPassword string `json:"new_password" binding:"omitempty,min=6,max=128" example:"newpassword123"` // 新密码
|
||||
}
|
||||
|
||||
// SendVerificationCodeRequest 发送验证码请求
|
||||
type SendVerificationCodeRequest struct {
|
||||
Email string `json:"email" binding:"required,email" example:"user@example.com"`
|
||||
Type string `json:"type" binding:"required,oneof=register reset_password change_email" example:"register"` // 类型: register/reset_password/change_email
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email" example:"user@example.com"`
|
||||
VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6,max=128" example:"newpassword123"`
|
||||
}
|
||||
|
||||
// ChangeEmailRequest 更换邮箱请求
|
||||
type ChangeEmailRequest struct {
|
||||
NewEmail string `json:"new_email" binding:"required,email" example:"newemail@example.com"`
|
||||
VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"`
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURLRequest 生成头像上传URL请求
|
||||
type GenerateAvatarUploadURLRequest struct {
|
||||
FileName string `json:"file_name" binding:"required" example:"avatar.png"`
|
||||
}
|
||||
|
||||
// GenerateAvatarUploadURLResponse 生成头像上传URL响应
|
||||
type GenerateAvatarUploadURLResponse struct {
|
||||
PostURL string `json:"post_url" example:"https://rustfs.example.com/avatars"`
|
||||
FormData map[string]string `json:"form_data"`
|
||||
AvatarURL string `json:"avatar_url" example:"https://rustfs.example.com/avatars/user_1/xxx.png"`
|
||||
ExpiresIn int `json:"expires_in" example:"900"` // 秒
|
||||
}
|
||||
|
||||
// CreateProfileRequest 创建档案请求
|
||||
type CreateProfileRequest struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=16" example:"PlayerName"`
|
||||
}
|
||||
|
||||
// UpdateTextureRequest 更新材质请求
|
||||
type UpdateTextureRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=100" example:"My Skin"`
|
||||
Description string `json:"description" binding:"omitempty,max=500" example:"A cool skin"`
|
||||
IsPublic *bool `json:"is_public" example:"true"`
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURLRequest 生成材质上传URL请求
|
||||
type GenerateTextureUploadURLRequest struct {
|
||||
FileName string `json:"file_name" binding:"required" example:"skin.png"`
|
||||
TextureType TextureType `json:"texture_type" binding:"required,oneof=SKIN CAPE" example:"SKIN"`
|
||||
}
|
||||
|
||||
// GenerateTextureUploadURLResponse 生成材质上传URL响应
|
||||
type GenerateTextureUploadURLResponse struct {
|
||||
PostURL string `json:"post_url" example:"https://rustfs.example.com/textures"`
|
||||
FormData map[string]string `json:"form_data"`
|
||||
TextureURL string `json:"texture_url" example:"https://rustfs.example.com/textures/user_1/skin/xxx.png"`
|
||||
ExpiresIn int `json:"expires_in" example:"900"` // 秒
|
||||
}
|
||||
|
||||
// LoginResponse 登录响应
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
UserInfo *UserInfo `json:"user_info"`
|
||||
}
|
||||
|
||||
// UserInfo 用户信息
|
||||
type UserInfo struct {
|
||||
ID int64 `json:"id" example:"1"`
|
||||
Username string `json:"username" example:"testuser"`
|
||||
Email string `json:"email" example:"test@example.com"`
|
||||
Avatar string `json:"avatar" example:"https://example.com/avatar.png"`
|
||||
Points int `json:"points" example:"100"`
|
||||
Role string `json:"role" example:"user"`
|
||||
Status int16 `json:"status" example:"1"`
|
||||
LastLoginAt *time.Time `json:"last_login_at,omitempty" example:"2025-10-01T12:00:00Z"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2025-10-01T10:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2025-10-01T10:00:00Z"`
|
||||
}
|
||||
|
||||
// TextureType 材质类型
|
||||
type TextureType string
|
||||
|
||||
const (
|
||||
TextureTypeSkin TextureType = "SKIN"
|
||||
TextureTypeCape TextureType = "CAPE"
|
||||
)
|
||||
|
||||
// TextureInfo 材质信息
|
||||
type TextureInfo struct {
|
||||
ID int64 `json:"id" example:"1"`
|
||||
UploaderID int64 `json:"uploader_id" example:"1"`
|
||||
Name string `json:"name" example:"My Skin"`
|
||||
Description string `json:"description,omitempty" example:"A cool skin"`
|
||||
Type TextureType `json:"type" example:"SKIN"`
|
||||
URL string `json:"url" example:"https://rustfs.example.com/textures/xxx.png"`
|
||||
Hash string `json:"hash" example:"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"`
|
||||
Size int `json:"size" example:"2048"`
|
||||
IsPublic bool `json:"is_public" example:"true"`
|
||||
DownloadCount int `json:"download_count" example:"100"`
|
||||
FavoriteCount int `json:"favorite_count" example:"50"`
|
||||
IsSlim bool `json:"is_slim" example:"false"`
|
||||
Status int16 `json:"status" example:"1"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2025-10-01T10:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2025-10-01T10:00:00Z"`
|
||||
}
|
||||
|
||||
// ProfileInfo 角色信息
|
||||
type ProfileInfo struct {
|
||||
UUID string `json:"uuid" example:"550e8400-e29b-41d4-a716-446655440000"`
|
||||
UserID int64 `json:"user_id" example:"1"`
|
||||
Name string `json:"name" example:"PlayerName"`
|
||||
SkinID *int64 `json:"skin_id,omitempty" example:"1"`
|
||||
CapeID *int64 `json:"cape_id,omitempty" example:"2"`
|
||||
IsActive bool `json:"is_active" example:"true"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty" example:"2025-10-01T12:00:00Z"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2025-10-01T10:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2025-10-01T10:00:00Z"`
|
||||
}
|
||||
|
||||
// UploadURLRequest 上传URL请求
|
||||
type UploadURLRequest struct {
|
||||
Type TextureType `json:"type" binding:"required,oneof=SKIN CAPE"`
|
||||
Filename string `json:"filename" binding:"required"`
|
||||
}
|
||||
|
||||
// UploadURLResponse 上传URL响应
|
||||
type UploadURLResponse struct {
|
||||
PostURL string `json:"post_url"`
|
||||
FormData map[string]string `json:"form_data"`
|
||||
FileURL string `json:"file_url"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// CreateTextureRequest 创建材质请求
|
||||
type CreateTextureRequest struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=100" example:"My Cool Skin"`
|
||||
Description string `json:"description" binding:"max=500" example:"A very cool skin"`
|
||||
Type TextureType `json:"type" binding:"required,oneof=SKIN CAPE" example:"SKIN"`
|
||||
URL string `json:"url" binding:"required,url" example:"https://rustfs.example.com/textures/user_1/skin/xxx.png"`
|
||||
Hash string `json:"hash" binding:"required,len=64" example:"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"`
|
||||
Size int `json:"size" binding:"required,min=1" example:"2048"`
|
||||
IsPublic bool `json:"is_public" example:"true"`
|
||||
IsSlim bool `json:"is_slim" example:"false"` // Alex模型(细臂)为true,Steve模型(粗臂)为false
|
||||
}
|
||||
|
||||
// SearchTextureRequest 搜索材质请求
|
||||
type SearchTextureRequest struct {
|
||||
PaginationRequest
|
||||
Keyword string `json:"keyword" form:"keyword"`
|
||||
Type TextureType `json:"type" form:"type" binding:"omitempty,oneof=SKIN CAPE"`
|
||||
PublicOnly bool `json:"public_only" form:"public_only"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest 更新角色请求
|
||||
type UpdateProfileRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=16" example:"NewPlayerName"`
|
||||
SkinID *int64 `json:"skin_id,omitempty" example:"1"`
|
||||
CapeID *int64 `json:"cape_id,omitempty" example:"2"`
|
||||
}
|
||||
|
||||
// SystemConfigResponse 基础系统配置响应
|
||||
type SystemConfigResponse struct {
|
||||
SiteName string `json:"site_name" example:"CarrotSkin"`
|
||||
SiteDescription string `json:"site_description" example:"A Minecraft Skin Station"`
|
||||
RegistrationEnabled bool `json:"registration_enabled" example:"true"`
|
||||
MaxTexturesPerUser int `json:"max_textures_per_user" example:"100"`
|
||||
MaxProfilesPerUser int `json:"max_profiles_per_user" example:"5"`
|
||||
}
|
||||
384
internal/types/common_test.go
Normal file
384
internal/types/common_test.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPaginationRequest_Validation 测试分页请求验证逻辑
|
||||
func TestPaginationRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的分页参数",
|
||||
page: 1,
|
||||
pageSize: 20,
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "page小于1应该无效",
|
||||
page: 0,
|
||||
pageSize: 20,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "pageSize小于1应该无效",
|
||||
page: 1,
|
||||
pageSize: 0,
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "pageSize超过100应该无效",
|
||||
page: 1,
|
||||
pageSize: 200,
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.page >= 1 && tt.pageSize >= 1 && tt.pageSize <= 100
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestTextureType_Constants 测试材质类型常量
|
||||
func TestTextureType_Constants(t *testing.T) {
|
||||
if TextureTypeSkin != "SKIN" {
|
||||
t.Errorf("TextureTypeSkin = %q, want 'SKIN'", TextureTypeSkin)
|
||||
}
|
||||
|
||||
if TextureTypeCape != "CAPE" {
|
||||
t.Errorf("TextureTypeCape = %q, want 'CAPE'", TextureTypeCape)
|
||||
}
|
||||
|
||||
if TextureTypeSkin == TextureTypeCape {
|
||||
t.Error("TextureTypeSkin 和 TextureTypeCape 应该不同")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaginationResponse_Structure 测试分页响应结构
|
||||
func TestPaginationResponse_Structure(t *testing.T) {
|
||||
resp := PaginationResponse{
|
||||
List: []string{"a", "b", "c"},
|
||||
Total: 100,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
TotalPages: 5,
|
||||
}
|
||||
|
||||
if resp.Total != 100 {
|
||||
t.Errorf("Total = %d, want 100", resp.Total)
|
||||
}
|
||||
|
||||
if resp.Page != 1 {
|
||||
t.Errorf("Page = %d, want 1", resp.Page)
|
||||
}
|
||||
|
||||
if resp.PageSize != 20 {
|
||||
t.Errorf("PageSize = %d, want 20", resp.PageSize)
|
||||
}
|
||||
|
||||
if resp.TotalPages != 5 {
|
||||
t.Errorf("TotalPages = %d, want 5", resp.TotalPages)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPaginationResponse_TotalPagesCalculation 测试总页数计算逻辑
|
||||
func TestPaginationResponse_TotalPagesCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
total int64
|
||||
pageSize int
|
||||
wantPages int
|
||||
}{
|
||||
{
|
||||
name: "正好整除",
|
||||
total: 100,
|
||||
pageSize: 20,
|
||||
wantPages: 5,
|
||||
},
|
||||
{
|
||||
name: "有余数",
|
||||
total: 101,
|
||||
pageSize: 20,
|
||||
wantPages: 6, // 向上取整
|
||||
},
|
||||
{
|
||||
name: "总数小于每页数量",
|
||||
total: 10,
|
||||
pageSize: 20,
|
||||
wantPages: 1,
|
||||
},
|
||||
{
|
||||
name: "总数为0",
|
||||
total: 0,
|
||||
pageSize: 20,
|
||||
wantPages: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 计算总页数:向上取整
|
||||
var totalPages int
|
||||
if tt.total == 0 {
|
||||
totalPages = 0
|
||||
} else {
|
||||
totalPages = int((tt.total + int64(tt.pageSize) - 1) / int64(tt.pageSize))
|
||||
}
|
||||
|
||||
if totalPages != tt.wantPages {
|
||||
t.Errorf("TotalPages = %d, want %d", totalPages, tt.wantPages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseResponse_Structure 测试基础响应结构
|
||||
func TestBaseResponse_Structure(t *testing.T) {
|
||||
resp := BaseResponse{
|
||||
Code: 200,
|
||||
Message: "success",
|
||||
Data: "test data",
|
||||
}
|
||||
|
||||
if resp.Code != 200 {
|
||||
t.Errorf("Code = %d, want 200", resp.Code)
|
||||
}
|
||||
|
||||
if resp.Message != "success" {
|
||||
t.Errorf("Message = %q, want 'success'", resp.Message)
|
||||
}
|
||||
|
||||
if resp.Data != "test data" {
|
||||
t.Errorf("Data = %v, want 'test data'", resp.Data)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoginRequest_Validation 测试登录请求验证逻辑
|
||||
func TestLoginRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
password string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的登录请求",
|
||||
username: "testuser",
|
||||
password: "password123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
password: "password123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码为空",
|
||||
username: "testuser",
|
||||
password: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码长度小于6",
|
||||
username: "testuser",
|
||||
password: "12345",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "密码长度超过128",
|
||||
username: "testuser",
|
||||
password: string(make([]byte, 129)),
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.username != "" && len(tt.password) >= 6 && len(tt.password) <= 128
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterRequest_Validation 测试注册请求验证逻辑
|
||||
func TestRegisterRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
email string
|
||||
password string
|
||||
verificationCode string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的注册请求",
|
||||
username: "newuser",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "用户名为空",
|
||||
username: "",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户名长度小于3",
|
||||
username: "ab",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "用户名长度超过50",
|
||||
username: string(make([]byte, 51)),
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "邮箱格式无效",
|
||||
username: "newuser",
|
||||
email: "invalid-email",
|
||||
password: "password123",
|
||||
verificationCode: "123456",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "验证码长度不是6",
|
||||
username: "newuser",
|
||||
email: "user@example.com",
|
||||
password: "password123",
|
||||
verificationCode: "12345",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.username != "" &&
|
||||
len(tt.username) >= 3 && len(tt.username) <= 50 &&
|
||||
tt.email != "" && contains(tt.email, "@") &&
|
||||
len(tt.password) >= 6 && len(tt.password) <= 128 &&
|
||||
len(tt.verificationCode) == 6
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestResetPasswordRequest_Validation 测试重置密码请求验证
|
||||
func TestResetPasswordRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
verificationCode string
|
||||
newPassword string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的重置密码请求",
|
||||
email: "user@example.com",
|
||||
verificationCode: "123456",
|
||||
newPassword: "newpassword123",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "邮箱为空",
|
||||
email: "",
|
||||
verificationCode: "123456",
|
||||
newPassword: "newpassword123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "验证码长度不是6",
|
||||
email: "user@example.com",
|
||||
verificationCode: "12345",
|
||||
newPassword: "newpassword123",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "新密码长度小于6",
|
||||
email: "user@example.com",
|
||||
verificationCode: "123456",
|
||||
newPassword: "12345",
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.email != "" &&
|
||||
len(tt.verificationCode) == 6 &&
|
||||
len(tt.newPassword) >= 6 && len(tt.newPassword) <= 128
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateProfileRequest_Validation 测试创建档案请求验证
|
||||
func TestCreateProfileRequest_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileName string
|
||||
wantValid bool
|
||||
}{
|
||||
{
|
||||
name: "有效的档案名",
|
||||
profileName: "PlayerName",
|
||||
wantValid: true,
|
||||
},
|
||||
{
|
||||
name: "档案名为空",
|
||||
profileName: "",
|
||||
wantValid: false,
|
||||
},
|
||||
{
|
||||
name: "档案名长度超过16",
|
||||
profileName: string(make([]byte, 17)),
|
||||
wantValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid := tt.profileName != "" &&
|
||||
len(tt.profileName) >= 1 && len(tt.profileName) <= 16
|
||||
if isValid != tt.wantValid {
|
||||
t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user