diff --git a/go.mod b/go.mod index 377b009..e083b3d 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.0 github.com/joho/godotenv v1.5.1 github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible - github.com/lib/pq v1.10.9 github.com/minio/minio-go/v7 v7.0.66 github.com/redis/go-redis/v9 v9.0.5 github.com/spf13/viper v1.21.0 @@ -28,6 +27,7 @@ require ( github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/stretchr/testify v1.11.1 // indirect golang.org/x/image v0.16.0 // indirect golang.org/x/sync v0.16.0 // indirect gorm.io/driver/mysql v1.5.6 // indirect diff --git a/internal/handler/auth_handler.go b/internal/handler/auth_handler.go index c2ae087..143c7ea 100644 --- a/internal/handler/auth_handler.go +++ b/internal/handler/auth_handler.go @@ -1,17 +1,29 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/service" "carrotskin/internal/types" - "carrotskin/pkg/auth" "carrotskin/pkg/email" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" "github.com/gin-gonic/gin" "go.uber.org/zap" ) +// AuthHandler 认证处理器(依赖注入版本) +type AuthHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewAuthHandler 创建AuthHandler实例 +func NewAuthHandler(c *container.Container) *AuthHandler { + return &AuthHandler{ + container: c, + logger: c.Logger, + } +} + // Register 用户注册 // @Summary 用户注册 // @Description 注册新用户账号 @@ -22,11 +34,7 @@ import ( // @Success 200 {object} model.Response "注册成功" // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/auth/register [post] -func Register(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - jwtService := auth.MustGetJWTService() - redisClient := redis.MustGetClient() - +func (h *AuthHandler) Register(c *gin.Context) { var req types.RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) @@ -34,16 +42,16 @@ func Register(c *gin.Context) { } // 验证邮箱验证码 - if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { - loggerInstance.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } // 注册用户 - user, token, err := service.RegisterUser(jwtService, req.Username, req.Password, req.Email, req.Avatar) + user, token, err := h.container.UserService.Register(req.Username, req.Password, req.Email, req.Avatar) if err != nil { - loggerInstance.Error("用户注册失败", zap.Error(err)) + h.logger.Error("用户注册失败", zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } @@ -65,11 +73,7 @@ func Register(c *gin.Context) { // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Failure 401 {object} model.ErrorResponse "登录失败" // @Router /api/v1/auth/login [post] -func Login(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - jwtService := auth.MustGetJWTService() - redisClient := redis.MustGetClient() - +func (h *AuthHandler) Login(c *gin.Context) { var req types.LoginRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) @@ -79,9 +83,9 @@ func Login(c *gin.Context) { ipAddress := c.ClientIP() userAgent := c.GetHeader("User-Agent") - user, token, err := service.LoginUserWithRateLimit(redisClient, jwtService, req.Username, req.Password, ipAddress, userAgent) + user, token, err := h.container.UserService.Login(req.Username, req.Password, ipAddress, userAgent) if err != nil { - loggerInstance.Warn("用户登录失败", + h.logger.Warn("用户登录失败", zap.String("username_or_email", req.Username), zap.String("ip", ipAddress), zap.Error(err), @@ -106,19 +110,21 @@ func Login(c *gin.Context) { // @Success 200 {object} model.Response "发送成功" // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/auth/send-code [post] -func SendVerificationCode(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - redisClient := redis.MustGetClient() - emailService := email.MustGetService() - +func (h *AuthHandler) SendVerificationCode(c *gin.Context) { var req types.SendVerificationCodeRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) return } - if err := service.SendVerificationCode(c.Request.Context(), redisClient, emailService, req.Email, req.Type); err != nil { - loggerInstance.Error("发送验证码失败", + emailService, err := h.getEmailService() + if err != nil { + RespondServerError(c, "邮件服务不可用", err) + return + } + + if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil { + h.logger.Error("发送验证码失败", zap.String("email", req.Email), zap.String("type", req.Type), zap.Error(err), @@ -140,10 +146,7 @@ func SendVerificationCode(c *gin.Context) { // @Success 200 {object} model.Response "重置成功" // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/auth/reset-password [post] -func ResetPassword(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - redisClient := redis.MustGetClient() - +func (h *AuthHandler) ResetPassword(c *gin.Context) { var req types.ResetPasswordRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) @@ -151,18 +154,23 @@ func ResetPassword(c *gin.Context) { } // 验证验证码 - if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { - loggerInstance.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } // 重置密码 - if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil { - loggerInstance.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) + if err := h.container.UserService.ResetPassword(req.Email, req.NewPassword); err != nil { + h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) RespondServerError(c, err.Error(), nil) return } RespondSuccess(c, gin.H{"message": "密码重置成功"}) } + +// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入) +func (h *AuthHandler) getEmailService() (*email.Service, error) { + return email.GetService() +} diff --git a/internal/handler/auth_handler_di.go b/internal/handler/auth_handler_di.go deleted file mode 100644 index 9087008..0000000 --- a/internal/handler/auth_handler_di.go +++ /dev/null @@ -1,177 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "carrotskin/internal/types" - "carrotskin/pkg/email" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// AuthHandler 认证处理器(依赖注入版本) -type AuthHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewAuthHandler 创建AuthHandler实例 -func NewAuthHandler(c *container.Container) *AuthHandler { - return &AuthHandler{ - container: c, - logger: c.Logger, - } -} - -// Register 用户注册 -// @Summary 用户注册 -// @Description 注册新用户账号 -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.RegisterRequest true "注册信息" -// @Success 200 {object} model.Response "注册成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/auth/register [post] -func (h *AuthHandler) Register(c *gin.Context) { - var req types.RegisterRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - // 验证邮箱验证码 - if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { - h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - // 注册用户 - user, token, err := service.RegisterUser(h.container.JWT, req.Username, req.Password, req.Email, req.Avatar) - if err != nil { - h.logger.Error("用户注册失败", zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, &types.LoginResponse{ - Token: token, - UserInfo: UserToUserInfo(user), - }) -} - -// Login 用户登录 -// @Summary 用户登录 -// @Description 用户登录获取JWT Token,支持用户名或邮箱登录 -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)" -// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "登录失败" -// @Router /api/v1/auth/login [post] -func (h *AuthHandler) Login(c *gin.Context) { - var req types.LoginRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - ipAddress := c.ClientIP() - userAgent := c.GetHeader("User-Agent") - - user, token, err := service.LoginUserWithRateLimit(h.container.Redis, h.container.JWT, req.Username, req.Password, ipAddress, userAgent) - if err != nil { - h.logger.Warn("用户登录失败", - zap.String("username_or_email", req.Username), - zap.String("ip", ipAddress), - zap.Error(err), - ) - RespondUnauthorized(c, err.Error()) - return - } - - RespondSuccess(c, &types.LoginResponse{ - Token: token, - UserInfo: UserToUserInfo(user), - }) -} - -// SendVerificationCode 发送验证码 -// @Summary 发送验证码 -// @Description 发送邮箱验证码(注册/重置密码/更换邮箱) -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.SendVerificationCodeRequest true "发送验证码请求" -// @Success 200 {object} model.Response "发送成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/auth/send-code [post] -func (h *AuthHandler) SendVerificationCode(c *gin.Context) { - var req types.SendVerificationCodeRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - emailService, err := h.getEmailService() - if err != nil { - RespondServerError(c, "邮件服务不可用", err) - return - } - - if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil { - h.logger.Error("发送验证码失败", - zap.String("email", req.Email), - zap.String("type", req.Type), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, gin.H{"message": "验证码已发送,请查收邮件"}) -} - -// ResetPassword 重置密码 -// @Summary 重置密码 -// @Description 通过邮箱验证码重置密码 -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.ResetPasswordRequest true "重置密码请求" -// @Success 200 {object} model.Response "重置成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/auth/reset-password [post] -func (h *AuthHandler) ResetPassword(c *gin.Context) { - var req types.ResetPasswordRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - // 验证验证码 - if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { - h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - // 重置密码 - if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil { - h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) - RespondServerError(c, err.Error(), nil) - return - } - - RespondSuccess(c, gin.H{"message": "密码重置成功"}) -} - -// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入) -func (h *AuthHandler) getEmailService() (*email.Service, error) { - return email.GetService() -} - diff --git a/internal/handler/captcha_handler.go b/internal/handler/captcha_handler.go index c7e8942..f9849d0 100644 --- a/internal/handler/captcha_handler.go +++ b/internal/handler/captcha_handler.go @@ -1,47 +1,77 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/service" - "carrotskin/pkg/redis" "net/http" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) +// CaptchaHandler 验证码处理器 +type CaptchaHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewCaptchaHandler 创建CaptchaHandler实例 +func NewCaptchaHandler(c *container.Container) *CaptchaHandler { + return &CaptchaHandler{ + container: c, + logger: c.Logger, + } +} + +// CaptchaVerifyRequest 验证码验证请求 +type CaptchaVerifyRequest struct { + CaptchaID string `json:"captchaId" binding:"required"` + Dx int `json:"dx" binding:"required"` +} + // Generate 生成验证码 -func Generate(c *gin.Context) { - // 调用验证码服务生成验证码数据 - redisClient := redis.MustGetClient() - masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), redisClient) +// @Summary 生成滑动验证码 +// @Description 生成滑动验证码图片 +// @Tags captcha +// @Accept json +// @Produce json +// @Success 200 {object} map[string]interface{} "生成成功" +// @Failure 500 {object} map[string]interface{} "生成失败" +// @Router /api/v1/captcha/generate [get] +func (h *CaptchaHandler) Generate(c *gin.Context) { + masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), h.container.Redis) if err != nil { + h.logger.Error("生成验证码失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, - "msg": "生成验证码失败: " + err.Error(), + "msg": "生成验证码失败", }) return } - // 返回验证码数据给前端 c.JSON(http.StatusOK, gin.H{ "code": 200, "data": gin.H{ - "masterImage": masterImg, // 主图(base64格式) - "tileImage": tileImg, // 滑块图(base64格式) - "captchaId": captchaID, // 验证码唯一标识(用于后续验证) - "y": y, // 滑块Y坐标(前端可用于定位滑块初始位置) + "masterImage": masterImg, + "tileImage": tileImg, + "captchaId": captchaID, + "y": y, }, }) } // Verify 验证验证码 -func Verify(c *gin.Context) { - // 定义请求参数结构体 - var req struct { - CaptchaID string `json:"captchaId" binding:"required"` // 验证码唯一标识 - Dx int `json:"dx" binding:"required"` // 用户滑动的X轴偏移量 - } - - // 解析并校验请求参数 +// @Summary 验证滑动验证码 +// @Description 验证用户滑动的偏移量是否正确 +// @Tags captcha +// @Accept json +// @Produce json +// @Param request body CaptchaVerifyRequest true "验证请求" +// @Success 200 {object} map[string]interface{} "验证结果" +// @Failure 400 {object} map[string]interface{} "参数错误" +// @Router /api/v1/captcha/verify [post] +func (h *CaptchaHandler) Verify(c *gin.Context) { + var req CaptchaVerifyRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, @@ -50,18 +80,19 @@ func Verify(c *gin.Context) { return } - // 调用验证码服务验证偏移量 - redisClient := redis.MustGetClient() - valid, err := service.VerifyCaptchaData(c.Request.Context(), redisClient, req.Dx, req.CaptchaID) + valid, err := service.VerifyCaptchaData(c.Request.Context(), h.container.Redis, req.Dx, req.CaptchaID) if err != nil { + h.logger.Error("验证码验证失败", + zap.String("captcha_id", req.CaptchaID), + zap.Error(err), + ) c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, - "msg": "验证失败: " + err.Error(), + "msg": "验证失败", }) return } - // 根据验证结果返回响应 if valid { c.JSON(http.StatusOK, gin.H{ "code": 200, @@ -74,3 +105,5 @@ func Verify(c *gin.Context) { }) } } + + diff --git a/internal/handler/captcha_handler_di.go b/internal/handler/captcha_handler_di.go deleted file mode 100644 index f9849d0..0000000 --- a/internal/handler/captcha_handler_di.go +++ /dev/null @@ -1,109 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "net/http" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// CaptchaHandler 验证码处理器 -type CaptchaHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewCaptchaHandler 创建CaptchaHandler实例 -func NewCaptchaHandler(c *container.Container) *CaptchaHandler { - return &CaptchaHandler{ - container: c, - logger: c.Logger, - } -} - -// CaptchaVerifyRequest 验证码验证请求 -type CaptchaVerifyRequest struct { - CaptchaID string `json:"captchaId" binding:"required"` - Dx int `json:"dx" binding:"required"` -} - -// Generate 生成验证码 -// @Summary 生成滑动验证码 -// @Description 生成滑动验证码图片 -// @Tags captcha -// @Accept json -// @Produce json -// @Success 200 {object} map[string]interface{} "生成成功" -// @Failure 500 {object} map[string]interface{} "生成失败" -// @Router /api/v1/captcha/generate [get] -func (h *CaptchaHandler) Generate(c *gin.Context) { - masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), h.container.Redis) - if err != nil { - h.logger.Error("生成验证码失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "msg": "生成验证码失败", - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 200, - "data": gin.H{ - "masterImage": masterImg, - "tileImage": tileImg, - "captchaId": captchaID, - "y": y, - }, - }) -} - -// Verify 验证验证码 -// @Summary 验证滑动验证码 -// @Description 验证用户滑动的偏移量是否正确 -// @Tags captcha -// @Accept json -// @Produce json -// @Param request body CaptchaVerifyRequest true "验证请求" -// @Success 200 {object} map[string]interface{} "验证结果" -// @Failure 400 {object} map[string]interface{} "参数错误" -// @Router /api/v1/captcha/verify [post] -func (h *CaptchaHandler) Verify(c *gin.Context) { - var req CaptchaVerifyRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "msg": "参数错误: " + err.Error(), - }) - return - } - - valid, err := service.VerifyCaptchaData(c.Request.Context(), h.container.Redis, req.Dx, req.CaptchaID) - if err != nil { - h.logger.Error("验证码验证失败", - zap.String("captcha_id", req.CaptchaID), - zap.Error(err), - ) - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "msg": "验证失败", - }) - return - } - - if valid { - c.JSON(http.StatusOK, gin.H{ - "code": 200, - "msg": "验证成功", - }) - } else { - c.JSON(http.StatusOK, gin.H{ - "code": 400, - "msg": "验证失败,请重试", - }) - } -} - - diff --git a/internal/handler/profile_handler.go b/internal/handler/profile_handler.go index cc0063b..daa029a 100644 --- a/internal/handler/profile_handler.go +++ b/internal/handler/profile_handler.go @@ -1,16 +1,28 @@ package handler import ( - "carrotskin/internal/service" + "carrotskin/internal/container" "carrotskin/internal/types" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" "github.com/gin-gonic/gin" "go.uber.org/zap" ) -// CreateProfile 创建档案 +// ProfileHandler 档案处理器 +type ProfileHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewProfileHandler 创建ProfileHandler实例 +func NewProfileHandler(c *container.Container) *ProfileHandler { + return &ProfileHandler{ + container: c, + logger: c.Logger, + } +} + +// Create 创建档案 // @Summary 创建Minecraft档案 // @Description 创建新的Minecraft角色档案,UUID由后端自动生成 // @Tags profile @@ -18,12 +30,10 @@ import ( // @Produce json // @Security BearerAuth // @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)" -// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功,返回完整档案信息(含自动生成的UUID)" -// @Failure 400 {object} model.ErrorResponse "请求参数错误或已达档案数量上限" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 500 {object} model.ErrorResponse "服务器错误" +// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/profile [post] -func CreateProfile(c *gin.Context) { +func (h *ProfileHandler) Create(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -35,17 +45,15 @@ func CreateProfile(c *gin.Context) { return } - maxProfiles := service.GetMaxProfilesPerUser() - db := database.MustGetDB() - - if err := service.CheckProfileLimit(db, userID, maxProfiles); err != nil { + maxProfiles := h.container.UserService.GetMaxProfilesPerUser() + if err := h.container.ProfileService.CheckLimit(userID, maxProfiles); err != nil { RespondBadRequest(c, err.Error(), nil) return } - profile, err := service.CreateProfile(db, userID, req.Name) + profile, err := h.container.ProfileService.Create(userID, req.Name) if err != nil { - logger.MustGetLogger().Error("创建档案失败", + h.logger.Error("创建档案失败", zap.Int64("user_id", userID), zap.String("name", req.Name), zap.Error(err), @@ -57,7 +65,7 @@ func CreateProfile(c *gin.Context) { RespondSuccess(c, ProfileToProfileInfo(profile)) } -// GetProfiles 获取档案列表 +// List 获取档案列表 // @Summary 获取档案列表 // @Description 获取当前用户的所有档案 // @Tags profile @@ -65,18 +73,16 @@ func CreateProfile(c *gin.Context) { // @Produce json // @Security BearerAuth // @Success 200 {object} model.Response "获取成功" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile [get] -func GetProfiles(c *gin.Context) { +func (h *ProfileHandler) List(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } - profiles, err := service.GetUserProfiles(database.MustGetDB(), userID) + profiles, err := h.container.ProfileService.GetByUserID(userID) if err != nil { - logger.MustGetLogger().Error("获取档案列表失败", + h.logger.Error("获取档案列表失败", zap.Int64("user_id", userID), zap.Error(err), ) @@ -87,7 +93,7 @@ func GetProfiles(c *gin.Context) { RespondSuccess(c, ProfilesToProfileInfos(profiles)) } -// GetProfile 获取档案详情 +// Get 获取档案详情 // @Summary 获取档案详情 // @Description 根据UUID获取档案详细信息 // @Tags profile @@ -96,14 +102,17 @@ func GetProfiles(c *gin.Context) { // @Param uuid path string true "档案UUID" // @Success 200 {object} model.Response "获取成功" // @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid} [get] -func GetProfile(c *gin.Context) { +func (h *ProfileHandler) Get(c *gin.Context) { uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } - profile, err := service.GetProfileByUUID(database.MustGetDB(), uuid) + profile, err := h.container.ProfileService.GetByUUID(uuid) if err != nil { - logger.MustGetLogger().Error("获取档案失败", + h.logger.Error("获取档案失败", zap.String("uuid", uuid), zap.Error(err), ) @@ -114,7 +123,7 @@ func GetProfile(c *gin.Context) { RespondSuccess(c, ProfileToProfileInfo(profile)) } -// UpdateProfile 更新档案 +// Update 更新档案 // @Summary 更新档案 // @Description 更新档案信息 // @Tags profile @@ -124,19 +133,19 @@ func GetProfile(c *gin.Context) { // @Param uuid path string true "档案UUID" // @Param request body types.UpdateProfileRequest true "更新信息" // @Success 200 {object} model.Response "更新成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "未授权" // @Failure 403 {object} model.ErrorResponse "无权操作" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid} [put] -func UpdateProfile(c *gin.Context) { +func (h *ProfileHandler) Update(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } var req types.UpdateProfileRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -149,9 +158,9 @@ func UpdateProfile(c *gin.Context) { namePtr = &req.Name } - profile, err := service.UpdateProfile(database.MustGetDB(), uuid, userID, namePtr, req.SkinID, req.CapeID) + profile, err := h.container.ProfileService.Update(uuid, userID, namePtr, req.SkinID, req.CapeID) if err != nil { - logger.MustGetLogger().Error("更新档案失败", + h.logger.Error("更新档案失败", zap.String("uuid", uuid), zap.Int64("user_id", userID), zap.Error(err), @@ -163,7 +172,7 @@ func UpdateProfile(c *gin.Context) { RespondSuccess(c, ProfileToProfileInfo(profile)) } -// DeleteProfile 删除档案 +// Delete 删除档案 // @Summary 删除档案 // @Description 删除指定的Minecraft档案 // @Tags profile @@ -172,22 +181,22 @@ func UpdateProfile(c *gin.Context) { // @Security BearerAuth // @Param uuid path string true "档案UUID" // @Success 200 {object} model.Response "删除成功" -// @Failure 401 {object} model.ErrorResponse "未授权" // @Failure 403 {object} model.ErrorResponse "无权操作" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid} [delete] -func DeleteProfile(c *gin.Context) { +func (h *ProfileHandler) Delete(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } - err := service.DeleteProfile(database.MustGetDB(), uuid, userID) - if err != nil { - logger.MustGetLogger().Error("删除档案失败", + if err := h.container.ProfileService.Delete(uuid, userID); err != nil { + h.logger.Error("删除档案失败", zap.String("uuid", uuid), zap.Int64("user_id", userID), zap.Error(err), @@ -199,7 +208,7 @@ func DeleteProfile(c *gin.Context) { RespondSuccess(c, gin.H{"message": "删除成功"}) } -// SetActiveProfile 设置活跃档案 +// SetActive 设置活跃档案 // @Summary 设置活跃档案 // @Description 将指定档案设置为活跃状态 // @Tags profile @@ -208,22 +217,22 @@ func DeleteProfile(c *gin.Context) { // @Security BearerAuth // @Param uuid path string true "档案UUID" // @Success 200 {object} model.Response "设置成功" -// @Failure 401 {object} model.ErrorResponse "未授权" // @Failure 403 {object} model.ErrorResponse "无权操作" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid}/activate [post] -func SetActiveProfile(c *gin.Context) { +func (h *ProfileHandler) SetActive(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } - err := service.SetActiveProfile(database.MustGetDB(), uuid, userID) - if err != nil { - logger.MustGetLogger().Error("设置活跃档案失败", + if err := h.container.ProfileService.SetActive(uuid, userID); err != nil { + h.logger.Error("设置活跃档案失败", zap.String("uuid", uuid), zap.Int64("user_id", userID), zap.Error(err), diff --git a/internal/handler/profile_handler_di.go b/internal/handler/profile_handler_di.go deleted file mode 100644 index 6fdbeb9..0000000 --- a/internal/handler/profile_handler_di.go +++ /dev/null @@ -1,247 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "carrotskin/internal/types" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// ProfileHandler 档案处理器 -type ProfileHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewProfileHandler 创建ProfileHandler实例 -func NewProfileHandler(c *container.Container) *ProfileHandler { - return &ProfileHandler{ - container: c, - logger: c.Logger, - } -} - -// Create 创建档案 -// @Summary 创建Minecraft档案 -// @Description 创建新的Minecraft角色档案,UUID由后端自动生成 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)" -// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/profile [post] -func (h *ProfileHandler) Create(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.CreateProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误: "+err.Error(), nil) - return - } - - maxProfiles := service.GetMaxProfilesPerUser() - if err := service.CheckProfileLimit(h.container.DB, userID, maxProfiles); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - - profile, err := service.CreateProfile(h.container.DB, userID, req.Name) - if err != nil { - h.logger.Error("创建档案失败", - zap.Int64("user_id", userID), - zap.String("name", req.Name), - zap.Error(err), - ) - RespondServerError(c, err.Error(), nil) - return - } - - RespondSuccess(c, ProfileToProfileInfo(profile)) -} - -// List 获取档案列表 -// @Summary 获取档案列表 -// @Description 获取当前用户的所有档案 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Success 200 {object} model.Response "获取成功" -// @Router /api/v1/profile [get] -func (h *ProfileHandler) List(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - profiles, err := service.GetUserProfiles(h.container.DB, userID) - if err != nil { - h.logger.Error("获取档案列表失败", - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondServerError(c, err.Error(), nil) - return - } - - RespondSuccess(c, ProfilesToProfileInfos(profiles)) -} - -// Get 获取档案详情 -// @Summary 获取档案详情 -// @Description 根据UUID获取档案详细信息 -// @Tags profile -// @Accept json -// @Produce json -// @Param uuid path string true "档案UUID" -// @Success 200 {object} model.Response "获取成功" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Router /api/v1/profile/{uuid} [get] -func (h *ProfileHandler) Get(c *gin.Context) { - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - profile, err := service.GetProfileByUUID(h.container.DB, uuid) - if err != nil { - h.logger.Error("获取档案失败", - zap.String("uuid", uuid), - zap.Error(err), - ) - RespondNotFound(c, err.Error()) - return - } - - RespondSuccess(c, ProfileToProfileInfo(profile)) -} - -// Update 更新档案 -// @Summary 更新档案 -// @Description 更新档案信息 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param uuid path string true "档案UUID" -// @Param request body types.UpdateProfileRequest true "更新信息" -// @Success 200 {object} model.Response "更新成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/profile/{uuid} [put] -func (h *ProfileHandler) Update(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - var req types.UpdateProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误: "+err.Error(), nil) - return - } - - var namePtr *string - if req.Name != "" { - namePtr = &req.Name - } - - profile, err := service.UpdateProfile(h.container.DB, uuid, userID, namePtr, req.SkinID, req.CapeID) - if err != nil { - h.logger.Error("更新档案失败", - zap.String("uuid", uuid), - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondWithError(c, err) - return - } - - RespondSuccess(c, ProfileToProfileInfo(profile)) -} - -// Delete 删除档案 -// @Summary 删除档案 -// @Description 删除指定的Minecraft档案 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param uuid path string true "档案UUID" -// @Success 200 {object} model.Response "删除成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/profile/{uuid} [delete] -func (h *ProfileHandler) Delete(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - if err := service.DeleteProfile(h.container.DB, uuid, userID); err != nil { - h.logger.Error("删除档案失败", - zap.String("uuid", uuid), - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondWithError(c, err) - return - } - - RespondSuccess(c, gin.H{"message": "删除成功"}) -} - -// SetActive 设置活跃档案 -// @Summary 设置活跃档案 -// @Description 将指定档案设置为活跃状态 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param uuid path string true "档案UUID" -// @Success 200 {object} model.Response "设置成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/profile/{uuid}/activate [post] -func (h *ProfileHandler) SetActive(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - if err := service.SetActiveProfile(h.container.DB, uuid, userID); err != nil { - h.logger.Error("设置活跃档案失败", - zap.String("uuid", uuid), - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondWithError(c, err) - return - } - - RespondSuccess(c, gin.H{"message": "设置成功"}) -} - diff --git a/internal/handler/routes.go b/internal/handler/routes.go index 95cee4c..a6da9c8 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -1,142 +1,193 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/middleware" "carrotskin/internal/model" "github.com/gin-gonic/gin" ) -// RegisterRoutes 注册所有路由 -func RegisterRoutes(router *gin.Engine) { +// Handlers 集中管理所有Handler +type Handlers struct { + Auth *AuthHandler + User *UserHandler + Texture *TextureHandler + Profile *ProfileHandler + Captcha *CaptchaHandler + Yggdrasil *YggdrasilHandler +} + +// NewHandlers 创建所有Handler实例 +func NewHandlers(c *container.Container) *Handlers { + return &Handlers{ + Auth: NewAuthHandler(c), + User: NewUserHandler(c), + Texture: NewTextureHandler(c), + Profile: NewProfileHandler(c), + Captcha: NewCaptchaHandler(c), + Yggdrasil: NewYggdrasilHandler(c), + } +} + +// RegisterRoutesWithDI 使用依赖注入注册所有路由 +func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { // 设置Swagger文档 SetupSwagger(router) + // 创建Handler实例 + h := NewHandlers(c) + // API路由组 v1 := router.Group("/api/v1") { // 认证路由(无需JWT) - authGroup := v1.Group("/auth") - { - authGroup.POST("/register", Register) - authGroup.POST("/login", Login) - authGroup.POST("/send-code", SendVerificationCode) - authGroup.POST("/reset-password", ResetPassword) - } + registerAuthRoutes(v1, h.Auth) // 用户路由(需要JWT认证) - userGroup := v1.Group("/user") - userGroup.Use(middleware.AuthMiddleware()) - { - userGroup.GET("/profile", GetUserProfile) - userGroup.PUT("/profile", UpdateUserProfile) - - // 头像相关 - userGroup.POST("/avatar/upload-url", GenerateAvatarUploadURL) - userGroup.PUT("/avatar", UpdateAvatar) - - // 更换邮箱 - userGroup.POST("/change-email", ChangeEmail) - - // Yggdrasil密码相关 - userGroup.POST("/yggdrasil-password/reset", ResetYggdrasilPassword) // 重置Yggdrasil密码并返回新密码 - } + registerUserRoutes(v1, h.User) // 材质路由 - textureGroup := v1.Group("/texture") - { - // 公开路由(无需认证) - textureGroup.GET("", SearchTextures) // 搜索材质 - textureGroup.GET("/:id", GetTexture) // 获取材质详情 - - // 需要认证的路由 - textureAuth := textureGroup.Group("") - textureAuth.Use(middleware.AuthMiddleware()) - { - textureAuth.POST("/upload-url", GenerateTextureUploadURL) // 生成上传URL - textureAuth.POST("", CreateTexture) // 创建材质记录 - textureAuth.PUT("/:id", UpdateTexture) // 更新材质 - textureAuth.DELETE("/:id", DeleteTexture) // 删除材质 - textureAuth.POST("/:id/favorite", ToggleFavorite) // 切换收藏 - textureAuth.GET("/my", GetUserTextures) // 我的材质 - textureAuth.GET("/favorites", GetUserFavorites) // 我的收藏 - } - } + registerTextureRoutes(v1, h.Texture) // 档案路由 - profileGroup := v1.Group("/profile") - { - // 公开路由(无需认证) - profileGroup.GET("/:uuid", GetProfile) // 获取档案详情 + registerProfileRoutesWithDI(v1, h.Profile) - // 需要认证的路由 - profileAuth := profileGroup.Group("") - profileAuth.Use(middleware.AuthMiddleware()) - { - profileAuth.POST("/", CreateProfile) // 创建档案 - profileAuth.GET("/", GetProfiles) // 获取我的档案列表 - profileAuth.PUT("/:uuid", UpdateProfile) // 更新档案 - profileAuth.DELETE("/:uuid", DeleteProfile) // 删除档案 - profileAuth.POST("/:uuid/activate", SetActiveProfile) // 设置活跃档案 - } - } // 验证码路由 - captchaGroup := v1.Group("/captcha") - { - captchaGroup.GET("/generate", Generate) //生成验证码 - captchaGroup.POST("/verify", Verify) //验证验证码 - } + registerCaptchaRoutesWithDI(v1, h.Captcha) // Yggdrasil API路由组 - ygg := v1.Group("/yggdrasil") - { - ygg.GET("", GetMetaData) - ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates) - authserver := ygg.Group("/authserver") - { - authserver.POST("/authenticate", Authenticate) - authserver.POST("/validate", ValidToken) - authserver.POST("/refresh", RefreshToken) - authserver.POST("/invalidate", InvalidToken) - authserver.POST("/signout", SignOut) - } - sessionServer := ygg.Group("/sessionserver") - { - sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID) - sessionServer.POST("/session/minecraft/join", JoinServer) - sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer) - } - api := ygg.Group("/api") - profiles := api.Group("/profiles") - { - profiles.POST("/minecraft", GetProfilesByName) - } - } + registerYggdrasilRoutesWithDI(v1, h.Yggdrasil) + // 系统路由 - system := v1.Group("/system") + registerSystemRoutes(v1) + } +} + +// registerAuthRoutes 注册认证路由 +func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { + authGroup := v1.Group("/auth") + { + authGroup.POST("/register", h.Register) + authGroup.POST("/login", h.Login) + authGroup.POST("/send-code", h.SendVerificationCode) + authGroup.POST("/reset-password", h.ResetPassword) + } +} + +// registerUserRoutes 注册用户路由 +func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { + userGroup := v1.Group("/user") + userGroup.Use(middleware.AuthMiddleware()) + { + userGroup.GET("/profile", h.GetProfile) + userGroup.PUT("/profile", h.UpdateProfile) + + // 头像相关 + userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL) + userGroup.PUT("/avatar", h.UpdateAvatar) + + // 更换邮箱 + userGroup.POST("/change-email", h.ChangeEmail) + + // Yggdrasil密码相关 + userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword) + } +} + +// registerTextureRoutes 注册材质路由 +func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { + textureGroup := v1.Group("/texture") + { + // 公开路由(无需认证) + textureGroup.GET("", h.Search) + textureGroup.GET("/:id", h.Get) + + // 需要认证的路由 + textureAuth := textureGroup.Group("") + textureAuth.Use(middleware.AuthMiddleware()) { - system.GET("/config", GetSystemConfig) + textureAuth.POST("/upload-url", h.GenerateUploadURL) + textureAuth.POST("", h.Create) + textureAuth.PUT("/:id", h.Update) + textureAuth.DELETE("/:id", h.Delete) + textureAuth.POST("/:id/favorite", h.ToggleFavorite) + textureAuth.GET("/my", h.GetUserTextures) + textureAuth.GET("/favorites", h.GetUserFavorites) } } } -// 以下是系统配置相关的占位符函数,待后续实现 +// registerProfileRoutesWithDI 注册档案路由(依赖注入版本) +func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { + profileGroup := v1.Group("/profile") + { + // 公开路由(无需认证) + profileGroup.GET("/:uuid", h.Get) -// GetSystemConfig 获取系统配置 -// @Summary 获取系统配置 -// @Description 获取公开的系统配置信息 -// @Tags system -// @Accept json -// @Produce json -// @Success 200 {object} model.Response "获取成功" -// @Router /api/v1/system/config [get] -func GetSystemConfig(c *gin.Context) { - // TODO: 实现从数据库读取系统配置 - c.JSON(200, model.NewSuccessResponse(gin.H{ - "site_name": "CarrotSkin", - "site_description": "A Minecraft Skin Station", - "registration_enabled": true, - "max_textures_per_user": 100, - "max_profiles_per_user": 5, - })) + // 需要认证的路由 + profileAuth := profileGroup.Group("") + profileAuth.Use(middleware.AuthMiddleware()) + { + profileAuth.POST("/", h.Create) + profileAuth.GET("/", h.List) + profileAuth.PUT("/:uuid", h.Update) + profileAuth.DELETE("/:uuid", h.Delete) + profileAuth.POST("/:uuid/activate", h.SetActive) + } + } +} + +// registerCaptchaRoutesWithDI 注册验证码路由(依赖注入版本) +func registerCaptchaRoutesWithDI(v1 *gin.RouterGroup, h *CaptchaHandler) { + captchaGroup := v1.Group("/captcha") + { + captchaGroup.GET("/generate", h.Generate) + captchaGroup.POST("/verify", h.Verify) + } +} + +// registerYggdrasilRoutesWithDI 注册Yggdrasil API路由(依赖注入版本) +func registerYggdrasilRoutesWithDI(v1 *gin.RouterGroup, h *YggdrasilHandler) { + ygg := v1.Group("/yggdrasil") + { + ygg.GET("", h.GetMetaData) + ygg.POST("/minecraftservices/player/certificates", h.GetPlayerCertificates) + authserver := ygg.Group("/authserver") + { + authserver.POST("/authenticate", h.Authenticate) + authserver.POST("/validate", h.ValidToken) + authserver.POST("/refresh", h.RefreshToken) + authserver.POST("/invalidate", h.InvalidToken) + authserver.POST("/signout", h.SignOut) + } + sessionServer := ygg.Group("/sessionserver") + { + sessionServer.GET("/session/minecraft/profile/:uuid", h.GetProfileByUUID) + sessionServer.POST("/session/minecraft/join", h.JoinServer) + sessionServer.GET("/session/minecraft/hasJoined", h.HasJoinedServer) + } + api := ygg.Group("/api") + profiles := api.Group("/profiles") + { + profiles.POST("/minecraft", h.GetProfilesByName) + } + } +} + +// registerSystemRoutes 注册系统路由 +func registerSystemRoutes(v1 *gin.RouterGroup) { + system := v1.Group("/system") + { + system.GET("/config", func(c *gin.Context) { + // TODO: 实现从数据库读取系统配置 + c.JSON(200, model.NewSuccessResponse(gin.H{ + "site_name": "CarrotSkin", + "site_description": "A Minecraft Skin Station", + "registration_enabled": true, + "max_textures_per_user": 100, + "max_profiles_per_user": 5, + })) + }) + } } diff --git a/internal/handler/routes_di.go b/internal/handler/routes_di.go deleted file mode 100644 index a6da9c8..0000000 --- a/internal/handler/routes_di.go +++ /dev/null @@ -1,193 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/middleware" - "carrotskin/internal/model" - - "github.com/gin-gonic/gin" -) - -// Handlers 集中管理所有Handler -type Handlers struct { - Auth *AuthHandler - User *UserHandler - Texture *TextureHandler - Profile *ProfileHandler - Captcha *CaptchaHandler - Yggdrasil *YggdrasilHandler -} - -// NewHandlers 创建所有Handler实例 -func NewHandlers(c *container.Container) *Handlers { - return &Handlers{ - Auth: NewAuthHandler(c), - User: NewUserHandler(c), - Texture: NewTextureHandler(c), - Profile: NewProfileHandler(c), - Captcha: NewCaptchaHandler(c), - Yggdrasil: NewYggdrasilHandler(c), - } -} - -// RegisterRoutesWithDI 使用依赖注入注册所有路由 -func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { - // 设置Swagger文档 - SetupSwagger(router) - - // 创建Handler实例 - h := NewHandlers(c) - - // API路由组 - v1 := router.Group("/api/v1") - { - // 认证路由(无需JWT) - registerAuthRoutes(v1, h.Auth) - - // 用户路由(需要JWT认证) - registerUserRoutes(v1, h.User) - - // 材质路由 - registerTextureRoutes(v1, h.Texture) - - // 档案路由 - registerProfileRoutesWithDI(v1, h.Profile) - - // 验证码路由 - registerCaptchaRoutesWithDI(v1, h.Captcha) - - // Yggdrasil API路由组 - registerYggdrasilRoutesWithDI(v1, h.Yggdrasil) - - // 系统路由 - registerSystemRoutes(v1) - } -} - -// registerAuthRoutes 注册认证路由 -func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { - authGroup := v1.Group("/auth") - { - authGroup.POST("/register", h.Register) - authGroup.POST("/login", h.Login) - authGroup.POST("/send-code", h.SendVerificationCode) - authGroup.POST("/reset-password", h.ResetPassword) - } -} - -// registerUserRoutes 注册用户路由 -func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { - userGroup := v1.Group("/user") - userGroup.Use(middleware.AuthMiddleware()) - { - userGroup.GET("/profile", h.GetProfile) - userGroup.PUT("/profile", h.UpdateProfile) - - // 头像相关 - userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL) - userGroup.PUT("/avatar", h.UpdateAvatar) - - // 更换邮箱 - userGroup.POST("/change-email", h.ChangeEmail) - - // Yggdrasil密码相关 - userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword) - } -} - -// registerTextureRoutes 注册材质路由 -func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { - textureGroup := v1.Group("/texture") - { - // 公开路由(无需认证) - textureGroup.GET("", h.Search) - textureGroup.GET("/:id", h.Get) - - // 需要认证的路由 - textureAuth := textureGroup.Group("") - textureAuth.Use(middleware.AuthMiddleware()) - { - textureAuth.POST("/upload-url", h.GenerateUploadURL) - textureAuth.POST("", h.Create) - textureAuth.PUT("/:id", h.Update) - textureAuth.DELETE("/:id", h.Delete) - textureAuth.POST("/:id/favorite", h.ToggleFavorite) - textureAuth.GET("/my", h.GetUserTextures) - textureAuth.GET("/favorites", h.GetUserFavorites) - } - } -} - -// registerProfileRoutesWithDI 注册档案路由(依赖注入版本) -func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { - profileGroup := v1.Group("/profile") - { - // 公开路由(无需认证) - profileGroup.GET("/:uuid", h.Get) - - // 需要认证的路由 - profileAuth := profileGroup.Group("") - profileAuth.Use(middleware.AuthMiddleware()) - { - profileAuth.POST("/", h.Create) - profileAuth.GET("/", h.List) - profileAuth.PUT("/:uuid", h.Update) - profileAuth.DELETE("/:uuid", h.Delete) - profileAuth.POST("/:uuid/activate", h.SetActive) - } - } -} - -// registerCaptchaRoutesWithDI 注册验证码路由(依赖注入版本) -func registerCaptchaRoutesWithDI(v1 *gin.RouterGroup, h *CaptchaHandler) { - captchaGroup := v1.Group("/captcha") - { - captchaGroup.GET("/generate", h.Generate) - captchaGroup.POST("/verify", h.Verify) - } -} - -// registerYggdrasilRoutesWithDI 注册Yggdrasil API路由(依赖注入版本) -func registerYggdrasilRoutesWithDI(v1 *gin.RouterGroup, h *YggdrasilHandler) { - ygg := v1.Group("/yggdrasil") - { - ygg.GET("", h.GetMetaData) - ygg.POST("/minecraftservices/player/certificates", h.GetPlayerCertificates) - authserver := ygg.Group("/authserver") - { - authserver.POST("/authenticate", h.Authenticate) - authserver.POST("/validate", h.ValidToken) - authserver.POST("/refresh", h.RefreshToken) - authserver.POST("/invalidate", h.InvalidToken) - authserver.POST("/signout", h.SignOut) - } - sessionServer := ygg.Group("/sessionserver") - { - sessionServer.GET("/session/minecraft/profile/:uuid", h.GetProfileByUUID) - sessionServer.POST("/session/minecraft/join", h.JoinServer) - sessionServer.GET("/session/minecraft/hasJoined", h.HasJoinedServer) - } - api := ygg.Group("/api") - profiles := api.Group("/profiles") - { - profiles.POST("/minecraft", h.GetProfilesByName) - } - } -} - -// registerSystemRoutes 注册系统路由 -func registerSystemRoutes(v1 *gin.RouterGroup) { - system := v1.Group("/system") - { - system.GET("/config", func(c *gin.Context) { - // TODO: 实现从数据库读取系统配置 - c.JSON(200, model.NewSuccessResponse(gin.H{ - "site_name": "CarrotSkin", - "site_description": "A Minecraft Skin Station", - "registration_enabled": true, - "max_textures_per_user": 100, - "max_profiles_per_user": 5, - })) - }) - } -} diff --git a/internal/handler/texture_handler.go b/internal/handler/texture_handler.go index a139f38..909e287 100644 --- a/internal/handler/texture_handler.go +++ b/internal/handler/texture_handler.go @@ -1,30 +1,32 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/model" "carrotskin/internal/service" "carrotskin/internal/types" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" - "carrotskin/pkg/storage" "strconv" "github.com/gin-gonic/gin" "go.uber.org/zap" ) -// GenerateTextureUploadURL 生成材质上传URL -// @Summary 生成材质上传URL -// @Description 生成预签名URL用于上传材质文件 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.GenerateTextureUploadURLRequest true "上传URL请求" -// @Success 200 {object} model.Response "生成成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/texture/upload-url [post] -func GenerateTextureUploadURL(c *gin.Context) { +// TextureHandler 材质处理器(依赖注入版本) +type TextureHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewTextureHandler 创建TextureHandler实例 +func NewTextureHandler(c *container.Container) *TextureHandler { + return &TextureHandler{ + container: c, + logger: c.Logger, + } +} + +// GenerateUploadURL 生成材质上传URL +func (h *TextureHandler) GenerateUploadURL(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -36,16 +38,20 @@ func GenerateTextureUploadURL(c *gin.Context) { return } - storageClient := storage.MustGetClient() + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + result, err := service.GenerateTextureUploadURL( c.Request.Context(), - storageClient, + h.container.Storage, userID, req.FileName, string(req.TextureType), ) if err != nil { - logger.MustGetLogger().Error("生成材质上传URL失败", + h.logger.Error("生成材质上传URL失败", zap.Int64("user_id", userID), zap.String("file_name", req.FileName), zap.String("texture_type", string(req.TextureType)), @@ -63,18 +69,8 @@ func GenerateTextureUploadURL(c *gin.Context) { }) } -// CreateTexture 创建材质记录 -// @Summary 创建材质记录 -// @Description 文件上传完成后,创建材质记录到数据库 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.CreateTextureRequest true "创建材质请求" -// @Success 200 {object} model.Response "创建成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/texture [post] -func CreateTexture(c *gin.Context) { +// Create 创建材质记录 +func (h *TextureHandler) Create(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -86,13 +82,13 @@ func CreateTexture(c *gin.Context) { return } - maxTextures := service.GetMaxTexturesPerUser() - if err := service.CheckTextureUploadLimit(database.MustGetDB(), userID, maxTextures); err != nil { + maxTextures := h.container.UserService.GetMaxTexturesPerUser() + if err := h.container.TextureService.CheckUploadLimit(userID, maxTextures); err != nil { RespondBadRequest(c, err.Error(), nil) return } - texture, err := service.CreateTexture(database.MustGetDB(), + texture, err := h.container.TextureService.Create( userID, req.Name, req.Description, @@ -104,7 +100,7 @@ func CreateTexture(c *gin.Context) { req.IsSlim, ) if err != nil { - logger.MustGetLogger().Error("创建材质失败", + h.logger.Error("创建材质失败", zap.Int64("user_id", userID), zap.String("name", req.Name), zap.Error(err), @@ -116,24 +112,15 @@ func CreateTexture(c *gin.Context) { RespondSuccess(c, TextureToTextureInfo(texture)) } -// GetTexture 获取材质详情 -// @Summary 获取材质详情 -// @Description 根据ID获取材质详细信息 -// @Tags texture -// @Accept json -// @Produce json -// @Param id path int true "材质ID" -// @Success 200 {object} model.Response "获取成功" -// @Failure 404 {object} model.ErrorResponse "材质不存在" -// @Router /api/v1/texture/{id} [get] -func GetTexture(c *gin.Context) { +// Get 获取材质详情 +func (h *TextureHandler) Get(c *gin.Context) { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { RespondBadRequest(c, "无效的材质ID", err) return } - texture, err := service.GetTextureByID(database.MustGetDB(), id) + texture, err := h.container.TextureService.GetByID(id) if err != nil { RespondNotFound(c, err.Error()) return @@ -142,20 +129,8 @@ func GetTexture(c *gin.Context) { RespondSuccess(c, TextureToTextureInfo(texture)) } -// SearchTextures 搜索材质 -// @Summary 搜索材质 -// @Description 根据关键词和类型搜索材质 -// @Tags texture -// @Accept json -// @Produce json -// @Param keyword query string false "关键词" -// @Param type query string false "材质类型(SKIN/CAPE)" -// @Param public_only query bool false "只看公开材质" -// @Param page query int false "页码" default(1) -// @Param page_size query int false "每页数量" default(20) -// @Success 200 {object} model.PaginationResponse "搜索成功" -// @Router /api/v1/texture [get] -func SearchTextures(c *gin.Context) { +// Search 搜索材质 +func (h *TextureHandler) Search(c *gin.Context) { keyword := c.Query("keyword") textureTypeStr := c.Query("type") publicOnly := c.Query("public_only") == "true" @@ -171,9 +146,9 @@ func SearchTextures(c *gin.Context) { textureType = model.TextureTypeCape } - textures, total, err := service.SearchTextures(database.MustGetDB(), keyword, textureType, publicOnly, page, pageSize) + textures, total, err := h.container.TextureService.Search(keyword, textureType, publicOnly, page, pageSize) if err != nil { - logger.MustGetLogger().Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) + h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) RespondServerError(c, "搜索材质失败", err) return } @@ -181,19 +156,8 @@ func SearchTextures(c *gin.Context) { c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) } -// UpdateTexture 更新材质 -// @Summary 更新材质 -// @Description 更新材质信息(仅上传者可操作) -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param id path int true "材质ID" -// @Param request body types.UpdateTextureRequest true "更新材质请求" -// @Success 200 {object} model.Response "更新成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/texture/{id} [put] -func UpdateTexture(c *gin.Context) { +// Update 更新材质 +func (h *TextureHandler) Update(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -211,9 +175,9 @@ func UpdateTexture(c *gin.Context) { return } - texture, err := service.UpdateTexture(database.MustGetDB(), textureID, userID, req.Name, req.Description, req.IsPublic) + texture, err := h.container.TextureService.Update(textureID, userID, req.Name, req.Description, req.IsPublic) if err != nil { - logger.MustGetLogger().Error("更新材质失败", + h.logger.Error("更新材质失败", zap.Int64("user_id", userID), zap.Int64("texture_id", textureID), zap.Error(err), @@ -225,18 +189,8 @@ func UpdateTexture(c *gin.Context) { RespondSuccess(c, TextureToTextureInfo(texture)) } -// DeleteTexture 删除材质 -// @Summary 删除材质 -// @Description 删除材质(软删除,仅上传者可操作) -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param id path int true "材质ID" -// @Success 200 {object} model.Response "删除成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/texture/{id} [delete] -func DeleteTexture(c *gin.Context) { +// Delete 删除材质 +func (h *TextureHandler) Delete(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -248,8 +202,8 @@ func DeleteTexture(c *gin.Context) { return } - if err := service.DeleteTexture(database.MustGetDB(), textureID, userID); err != nil { - logger.MustGetLogger().Error("删除材质失败", + if err := h.container.TextureService.Delete(textureID, userID); err != nil { + h.logger.Error("删除材质失败", zap.Int64("user_id", userID), zap.Int64("texture_id", textureID), zap.Error(err), @@ -262,16 +216,7 @@ func DeleteTexture(c *gin.Context) { } // ToggleFavorite 切换收藏状态 -// @Summary 切换收藏状态 -// @Description 收藏或取消收藏材质 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param id path int true "材质ID" -// @Success 200 {object} model.Response "切换成功" -// @Router /api/v1/texture/{id}/favorite [post] -func ToggleFavorite(c *gin.Context) { +func (h *TextureHandler) ToggleFavorite(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -283,9 +228,9 @@ func ToggleFavorite(c *gin.Context) { return } - isFavorited, err := service.ToggleTextureFavorite(database.MustGetDB(), userID, textureID) + isFavorited, err := h.container.TextureService.ToggleFavorite(userID, textureID) if err != nil { - logger.MustGetLogger().Error("切换收藏状态失败", + h.logger.Error("切换收藏状态失败", zap.Int64("user_id", userID), zap.Int64("texture_id", textureID), zap.Error(err), @@ -298,17 +243,7 @@ func ToggleFavorite(c *gin.Context) { } // GetUserTextures 获取用户上传的材质列表 -// @Summary 获取用户上传的材质列表 -// @Description 获取当前用户上传的所有材质 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param page query int false "页码" default(1) -// @Param page_size query int false "每页数量" default(20) -// @Success 200 {object} model.PaginationResponse "获取成功" -// @Router /api/v1/texture/my [get] -func GetUserTextures(c *gin.Context) { +func (h *TextureHandler) GetUserTextures(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -317,9 +252,9 @@ func GetUserTextures(c *gin.Context) { page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - textures, total, err := service.GetUserTextures(database.MustGetDB(), userID, page, pageSize) + textures, total, err := h.container.TextureService.GetByUserID(userID, page, pageSize) if err != nil { - logger.MustGetLogger().Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) + h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) RespondServerError(c, "获取材质列表失败", err) return } @@ -328,17 +263,7 @@ func GetUserTextures(c *gin.Context) { } // GetUserFavorites 获取用户收藏的材质列表 -// @Summary 获取用户收藏的材质列表 -// @Description 获取当前用户收藏的所有材质 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param page query int false "页码" default(1) -// @Param page_size query int false "每页数量" default(20) -// @Success 200 {object} model.PaginationResponse "获取成功" -// @Router /api/v1/texture/favorites [get] -func GetUserFavorites(c *gin.Context) { +func (h *TextureHandler) GetUserFavorites(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -347,9 +272,9 @@ func GetUserFavorites(c *gin.Context) { page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID, page, pageSize) + textures, total, err := h.container.TextureService.GetUserFavorites(userID, page, pageSize) if err != nil { - logger.MustGetLogger().Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) + h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) RespondServerError(c, "获取收藏列表失败", err) return } diff --git a/internal/handler/texture_handler_di.go b/internal/handler/texture_handler_di.go deleted file mode 100644 index 26bd558..0000000 --- a/internal/handler/texture_handler_di.go +++ /dev/null @@ -1,285 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/model" - "carrotskin/internal/service" - "carrotskin/internal/types" - "strconv" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// TextureHandler 材质处理器(依赖注入版本) -type TextureHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewTextureHandler 创建TextureHandler实例 -func NewTextureHandler(c *container.Container) *TextureHandler { - return &TextureHandler{ - container: c, - logger: c.Logger, - } -} - -// GenerateUploadURL 生成材质上传URL -func (h *TextureHandler) GenerateUploadURL(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.GenerateTextureUploadURLRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - if h.container.Storage == nil { - RespondServerError(c, "存储服务不可用", nil) - return - } - - result, err := service.GenerateTextureUploadURL( - c.Request.Context(), - h.container.Storage, - userID, - req.FileName, - string(req.TextureType), - ) - if err != nil { - h.logger.Error("生成材质上传URL失败", - zap.Int64("user_id", userID), - zap.String("file_name", req.FileName), - zap.String("texture_type", string(req.TextureType)), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, &types.GenerateTextureUploadURLResponse{ - PostURL: result.PostURL, - FormData: result.FormData, - TextureURL: result.FileURL, - ExpiresIn: 900, - }) -} - -// Create 创建材质记录 -func (h *TextureHandler) Create(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.CreateTextureRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - maxTextures := service.GetMaxTexturesPerUser() - if err := service.CheckTextureUploadLimit(h.container.DB, userID, maxTextures); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - - texture, err := service.CreateTexture(h.container.DB, - userID, - req.Name, - req.Description, - string(req.Type), - req.URL, - req.Hash, - req.Size, - req.IsPublic, - req.IsSlim, - ) - if err != nil { - h.logger.Error("创建材质失败", - zap.Int64("user_id", userID), - zap.String("name", req.Name), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, TextureToTextureInfo(texture)) -} - -// Get 获取材质详情 -func (h *TextureHandler) Get(c *gin.Context) { - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - texture, err := service.GetTextureByID(h.container.DB, id) - if err != nil { - RespondNotFound(c, err.Error()) - return - } - - RespondSuccess(c, TextureToTextureInfo(texture)) -} - -// Search 搜索材质 -func (h *TextureHandler) Search(c *gin.Context) { - keyword := c.Query("keyword") - textureTypeStr := c.Query("type") - publicOnly := c.Query("public_only") == "true" - - page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) - pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - - var textureType model.TextureType - switch textureTypeStr { - case "SKIN": - textureType = model.TextureTypeSkin - case "CAPE": - textureType = model.TextureTypeCape - } - - textures, total, err := service.SearchTextures(h.container.DB, keyword, textureType, publicOnly, page, pageSize) - if err != nil { - h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) - RespondServerError(c, "搜索材质失败", err) - return - } - - c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) -} - -// Update 更新材质 -func (h *TextureHandler) Update(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - var req types.UpdateTextureRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - texture, err := service.UpdateTexture(h.container.DB, textureID, userID, req.Name, req.Description, req.IsPublic) - if err != nil { - h.logger.Error("更新材质失败", - zap.Int64("user_id", userID), - zap.Int64("texture_id", textureID), - zap.Error(err), - ) - RespondForbidden(c, err.Error()) - return - } - - RespondSuccess(c, TextureToTextureInfo(texture)) -} - -// Delete 删除材质 -func (h *TextureHandler) Delete(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - if err := service.DeleteTexture(h.container.DB, textureID, userID); err != nil { - h.logger.Error("删除材质失败", - zap.Int64("user_id", userID), - zap.Int64("texture_id", textureID), - zap.Error(err), - ) - RespondForbidden(c, err.Error()) - return - } - - RespondSuccess(c, nil) -} - -// ToggleFavorite 切换收藏状态 -func (h *TextureHandler) ToggleFavorite(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - isFavorited, err := service.ToggleTextureFavorite(h.container.DB, userID, textureID) - if err != nil { - h.logger.Error("切换收藏状态失败", - zap.Int64("user_id", userID), - zap.Int64("texture_id", textureID), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, map[string]bool{"is_favorited": isFavorited}) -} - -// GetUserTextures 获取用户上传的材质列表 -func (h *TextureHandler) GetUserTextures(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) - pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - - textures, total, err := service.GetUserTextures(h.container.DB, userID, page, pageSize) - if err != nil { - h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) - RespondServerError(c, "获取材质列表失败", err) - return - } - - c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) -} - -// GetUserFavorites 获取用户收藏的材质列表 -func (h *TextureHandler) GetUserFavorites(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) - pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - - textures, total, err := service.GetUserTextureFavorites(h.container.DB, userID, page, pageSize) - if err != nil { - h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) - RespondServerError(c, "获取收藏列表失败", err) - return - } - - c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) -} - - diff --git a/internal/handler/user_handler.go b/internal/handler/user_handler.go index c6144a4..406596b 100644 --- a/internal/handler/user_handler.go +++ b/internal/handler/user_handler.go @@ -1,36 +1,38 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/service" "carrotskin/internal/types" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" - "carrotskin/pkg/storage" "github.com/gin-gonic/gin" "go.uber.org/zap" ) -// GetUserProfile 获取用户信息 -// @Summary 获取用户信息 -// @Description 获取当前登录用户的详细信息 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Success 200 {object} model.Response "获取成功" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Router /api/v1/user/profile [get] -func GetUserProfile(c *gin.Context) { +// UserHandler 用户处理器(依赖注入版本) +type UserHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewUserHandler 创建UserHandler实例 +func NewUserHandler(c *container.Container) *UserHandler { + return &UserHandler{ + container: c, + logger: c.Logger, + } +} + +// GetProfile 获取用户信息 +func (h *UserHandler) GetProfile(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { - logger.MustGetLogger().Error("获取用户信息失败", + h.logger.Error("获取用户信息失败", zap.Int64("user_id", userID), zap.Error(err), ) @@ -41,22 +43,8 @@ func GetUserProfile(c *gin.Context) { RespondSuccess(c, UserToUserInfo(user)) } -// UpdateUserProfile 更新用户信息 -// @Summary 更新用户信息 -// @Description 更新当前登录用户的头像和密码(修改邮箱请使用 /change-email 接口) -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.UpdateUserRequest true "更新信息(修改密码时需同时提供old_password和new_password)" -// @Success 200 {object} model.Response{data=types.UserInfo} "更新成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 404 {object} model.ErrorResponse "用户不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" -// @Router /api/v1/user/profile [put] -func UpdateUserProfile(c *gin.Context) { - loggerInstance := logger.MustGetLogger() +// UpdateProfile 更新用户信息 +func (h *UserHandler) UpdateProfile(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -68,7 +56,7 @@ func UpdateUserProfile(c *gin.Context) { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { RespondNotFound(c, "用户不存在") return @@ -81,32 +69,31 @@ func UpdateUserProfile(c *gin.Context) { return } - if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil { - loggerInstance.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) + if err := h.container.UserService.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil { + h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } - loggerInstance.Info("用户修改密码成功", zap.Int64("user_id", userID)) + h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID)) } // 更新头像 if req.Avatar != "" { - // 验证头像 URL 是否来自允许的域名 - if err := service.ValidateAvatarURL(req.Avatar); err != nil { + if err := h.container.UserService.ValidateAvatarURL(req.Avatar); err != nil { RespondBadRequest(c, err.Error(), nil) return } user.Avatar = req.Avatar - if err := service.UpdateUserInfo(user); err != nil { - loggerInstance.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) + if err := h.container.UserService.UpdateInfo(user); err != nil { + h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) RespondServerError(c, "更新失败", err) return } } // 重新获取更新后的用户信息 - updatedUser, err := service.GetUserByID(userID) + updatedUser, err := h.container.UserService.GetByID(userID) if err != nil || updatedUser == nil { RespondNotFound(c, "用户不存在") return @@ -116,17 +103,7 @@ func UpdateUserProfile(c *gin.Context) { } // GenerateAvatarUploadURL 生成头像上传URL -// @Summary 生成头像上传URL -// @Description 生成预签名URL用于上传用户头像 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.GenerateAvatarUploadURLRequest true "文件名" -// @Success 200 {object} model.Response "生成成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/user/avatar/upload-url [post] -func GenerateAvatarUploadURL(c *gin.Context) { +func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -138,10 +115,14 @@ func GenerateAvatarUploadURL(c *gin.Context) { return } - storageClient := storage.MustGetClient() - result, err := service.GenerateAvatarUploadURL(c.Request.Context(), storageClient, userID, req.FileName) + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + + result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName) if err != nil { - logger.MustGetLogger().Error("生成头像上传URL失败", + h.logger.Error("生成头像上传URL失败", zap.Int64("user_id", userID), zap.String("file_name", req.FileName), zap.Error(err), @@ -159,17 +140,7 @@ func GenerateAvatarUploadURL(c *gin.Context) { } // UpdateAvatar 更新头像URL -// @Summary 更新头像URL -// @Description 上传完成后更新用户的头像URL到数据库 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param avatar_url query string true "头像URL" -// @Success 200 {object} model.Response "更新成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/user/avatar [put] -func UpdateAvatar(c *gin.Context) { +func (h *UserHandler) UpdateAvatar(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -181,13 +152,13 @@ func UpdateAvatar(c *gin.Context) { return } - if err := service.ValidateAvatarURL(avatarURL); err != nil { + if err := h.container.UserService.ValidateAvatarURL(avatarURL); err != nil { RespondBadRequest(c, err.Error(), nil) return } - if err := service.UpdateUserAvatar(userID, avatarURL); err != nil { - logger.MustGetLogger().Error("更新头像失败", + if err := h.container.UserService.UpdateAvatar(userID, avatarURL); err != nil { + h.logger.Error("更新头像失败", zap.Int64("user_id", userID), zap.String("avatar_url", avatarURL), zap.Error(err), @@ -196,7 +167,7 @@ func UpdateAvatar(c *gin.Context) { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { RespondNotFound(c, "用户不存在") return @@ -206,19 +177,7 @@ func UpdateAvatar(c *gin.Context) { } // ChangeEmail 更换邮箱 -// @Summary 更换邮箱 -// @Description 通过验证码更换用户邮箱 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.ChangeEmailRequest true "更换邮箱请求" -// @Success 200 {object} model.Response{data=types.UserInfo} "更换成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Router /api/v1/user/change-email [post] -func ChangeEmail(c *gin.Context) { - loggerInstance := logger.MustGetLogger() +func (h *UserHandler) ChangeEmail(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -230,15 +189,14 @@ func ChangeEmail(c *gin.Context) { return } - redisClient := redis.MustGetClient() - if err := service.VerifyCode(c.Request.Context(), redisClient, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { - loggerInstance.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { + h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } - if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil { - loggerInstance.Error("更换邮箱失败", + if err := h.container.UserService.ChangeEmail(userID, req.NewEmail); err != nil { + h.logger.Error("更换邮箱失败", zap.Int64("user_id", userID), zap.String("new_email", req.NewEmail), zap.Error(err), @@ -247,7 +205,7 @@ func ChangeEmail(c *gin.Context) { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { RespondNotFound(c, "用户不存在") return @@ -257,31 +215,19 @@ func ChangeEmail(c *gin.Context) { } // ResetYggdrasilPassword 重置Yggdrasil密码 -// @Summary 重置Yggdrasil密码 -// @Description 重置当前用户的Yggdrasil密码并返回新密码 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Success 200 {object} model.Response "重置成功" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 500 {object} model.ErrorResponse "服务器错误" -// @Router /api/v1/user/yggdrasil-password/reset [post] -func ResetYggdrasilPassword(c *gin.Context) { - loggerInstance := logger.MustGetLogger() +func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } - db := database.MustGetDB() - newPassword, err := service.ResetYggdrasilPassword(db, userID) + newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID) if err != nil { - loggerInstance.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) + h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) RespondServerError(c, "重置Yggdrasil密码失败", nil) return } - loggerInstance.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) + h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) RespondSuccess(c, gin.H{"password": newPassword}) } diff --git a/internal/handler/user_handler_di.go b/internal/handler/user_handler_di.go deleted file mode 100644 index 91e8a5a..0000000 --- a/internal/handler/user_handler_di.go +++ /dev/null @@ -1,233 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "carrotskin/internal/types" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// UserHandler 用户处理器(依赖注入版本) -type UserHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewUserHandler 创建UserHandler实例 -func NewUserHandler(c *container.Container) *UserHandler { - return &UserHandler{ - container: c, - logger: c.Logger, - } -} - -// GetProfile 获取用户信息 -func (h *UserHandler) GetProfile(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - h.logger.Error("获取用户信息失败", - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(user)) -} - -// UpdateProfile 更新用户信息 -func (h *UserHandler) UpdateProfile(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.UpdateUserRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - RespondNotFound(c, "用户不存在") - return - } - - // 处理密码修改 - if req.NewPassword != "" { - if req.OldPassword == "" { - RespondBadRequest(c, "修改密码需要提供原密码", nil) - return - } - - if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil { - h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID)) - } - - // 更新头像 - if req.Avatar != "" { - if err := service.ValidateAvatarURL(req.Avatar); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - user.Avatar = req.Avatar - if err := service.UpdateUserInfo(user); err != nil { - h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) - RespondServerError(c, "更新失败", err) - return - } - } - - // 重新获取更新后的用户信息 - updatedUser, err := service.GetUserByID(userID) - if err != nil || updatedUser == nil { - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(updatedUser)) -} - -// GenerateAvatarUploadURL 生成头像上传URL -func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.GenerateAvatarUploadURLRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - if h.container.Storage == nil { - RespondServerError(c, "存储服务不可用", nil) - return - } - - result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName) - if err != nil { - h.logger.Error("生成头像上传URL失败", - zap.Int64("user_id", userID), - zap.String("file_name", req.FileName), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, &types.GenerateAvatarUploadURLResponse{ - PostURL: result.PostURL, - FormData: result.FormData, - AvatarURL: result.FileURL, - ExpiresIn: 900, - }) -} - -// UpdateAvatar 更新头像URL -func (h *UserHandler) UpdateAvatar(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - avatarURL := c.Query("avatar_url") - if avatarURL == "" { - RespondBadRequest(c, "头像URL不能为空", nil) - return - } - - if err := service.ValidateAvatarURL(avatarURL); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - - if err := service.UpdateUserAvatar(userID, avatarURL); err != nil { - h.logger.Error("更新头像失败", - zap.Int64("user_id", userID), - zap.String("avatar_url", avatarURL), - zap.Error(err), - ) - RespondServerError(c, "更新头像失败", err) - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(user)) -} - -// ChangeEmail 更换邮箱 -func (h *UserHandler) ChangeEmail(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.ChangeEmailRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { - h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil { - h.logger.Error("更换邮箱失败", - zap.Int64("user_id", userID), - zap.String("new_email", req.NewEmail), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(user)) -} - -// ResetYggdrasilPassword 重置Yggdrasil密码 -func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID) - if err != nil { - h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) - RespondServerError(c, "重置Yggdrasil密码失败", nil) - return - } - - h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) - RespondSuccess(c, gin.H{"password": newPassword}) -} diff --git a/internal/handler/yggdrasil_handler.go b/internal/handler/yggdrasil_handler.go index acbf7b2..2ee21dc 100644 --- a/internal/handler/yggdrasil_handler.go +++ b/internal/handler/yggdrasil_handler.go @@ -2,11 +2,9 @@ package handler import ( "bytes" + "carrotskin/internal/container" "carrotskin/internal/model" "carrotskin/internal/service" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" "carrotskin/pkg/utils" "io" "net/http" @@ -111,6 +109,7 @@ type ( Password string `json:"password" binding:"required"` } + // JoinServerRequest 加入服务器请求 JoinServerRequest struct { ServerID string `json:"serverId" binding:"required"` AccessToken string `json:"accessToken" binding:"required"` @@ -138,6 +137,7 @@ type ( } ) +// APIResponse API响应 type APIResponse struct { Status int `json:"status"` Data interface{} `json:"data"` @@ -153,38 +153,47 @@ func standardResponse(c *gin.Context, status int, data interface{}, err interfac }) } -// Authenticate 用户认证 -func Authenticate(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() +// YggdrasilHandler Yggdrasil API处理器 +type YggdrasilHandler struct { + container *container.Container + logger *zap.Logger +} - // 读取并保存原始请求体,以便多次读取 +// NewYggdrasilHandler 创建YggdrasilHandler实例 +func NewYggdrasilHandler(c *container.Container) *YggdrasilHandler { + return &YggdrasilHandler{ + container: c, + logger: c.Logger, + } +} + +// Authenticate 用户认证 +func (h *YggdrasilHandler) Authenticate(c *gin.Context) { rawData, err := io.ReadAll(c.Request.Body) if err != nil { - loggerInstance.Error("[ERROR] 读取请求体失败: ", zap.Error(err)) + h.logger.Error("读取请求体失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) return } c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData)) - // 绑定JSON数据到请求结构体 var request AuthenticateRequest if err = c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析认证请求失败: ", zap.Error(err)) + h.logger.Error("解析认证请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 根据标识符类型(邮箱或用户名)获取用户 var userId int64 var profile *model.Profile var UUID string + if emailRegex.MatchString(request.Identifier) { - userId, err = service.GetUserIDByEmail(db, request.Identifier) + userId, err = service.GetUserIDByEmail(h.container.DB, request.Identifier) } else { - profile, err = service.GetProfileByProfileName(db, request.Identifier) + profile, err = service.GetProfileByProfileName(h.container.DB, request.Identifier) if err != nil { - loggerInstance.Error("[ERROR] 用户名不存在: ", zap.String("标识符", request.Identifier), zap.Error(err)) + h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } @@ -193,163 +202,146 @@ func Authenticate(c *gin.Context) { } if err != nil { - loggerInstance.Warn("[WARN] 认证失败: 用户不存在", - zap.String("标识符:", request.Identifier), - zap.Error(err)) - + h.logger.Warn("认证失败: 用户不存在", zap.String("identifier", request.Identifier), zap.Error(err)) + c.JSON(http.StatusForbidden, gin.H{"error": "用户不存在"}) return } - // 验证密码 - err = service.VerifyPassword(db, request.Password, userId) - if err != nil { - loggerInstance.Warn("[WARN] 认证失败:", zap.Error(err)) + if err := service.VerifyPassword(h.container.DB, request.Password, userId); err != nil { + h.logger.Warn("认证失败: 密码错误", zap.Error(err)) c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword}) return } - // 生成新令牌 - selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(db, loggerInstance, userId, UUID, request.ClientToken) + + selectedProfile, availableProfiles, accessToken, clientToken, err := h.container.TokenService.Create(userId, UUID, request.ClientToken) if err != nil { - loggerInstance.Error("[ERROR] 生成令牌失败:", zap.Error(err), zap.Any("用户ID:", userId)) + h.logger.Error("生成令牌失败", zap.Error(err), zap.Int64("userId", userId)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - user, err := service.GetUserByID(userId) + user, err := h.container.UserService.GetByID(userId) if err != nil { - loggerInstance.Error("[ERROR] id查找错误:", zap.Error(err), zap.Any("ID:", userId)) + h.logger.Error("获取用户信息失败", zap.Error(err), zap.Int64("userId", userId)) } - // 处理可用的配置文件 - redisClient := redis.MustGetClient() + availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles)) - for _, profile := range availableProfiles { - availableProfilesData = append(availableProfilesData, service.SerializeProfile(db, loggerInstance, redisClient, *profile)) + for _, p := range availableProfiles { + availableProfilesData = append(availableProfilesData, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *p)) } + response := AuthenticateResponse{ AccessToken: accessToken, ClientToken: clientToken, AvailableProfiles: availableProfilesData, } + if selectedProfile != nil { - response.SelectedProfile = service.SerializeProfile(db, loggerInstance, redisClient, *selectedProfile) - } - if request.RequestUser { - // 使用 SerializeUser 来正确处理 Properties 字段 - response.User = service.SerializeUser(loggerInstance, user, UUID) + response.SelectedProfile = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *selectedProfile) } - // 返回认证响应 - loggerInstance.Info("[INFO] 用户认证成功", zap.Any("用户ID:", userId)) + if request.RequestUser && user != nil { + response.User = service.SerializeUser(h.logger, user, UUID) + } + + h.logger.Info("用户认证成功", zap.Int64("userId", userId)) c.JSON(http.StatusOK, response) } // ValidToken 验证令牌 -func ValidToken(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) ValidToken(c *gin.Context) { var request ValidTokenRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析验证令牌请求失败: ", zap.Error(err)) + h.logger.Error("解析验证令牌请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 验证令牌 - if service.ValidToken(db, request.AccessToken, request.ClientToken) { - loggerInstance.Info("[INFO] 令牌验证成功", zap.Any("访问令牌:", request.AccessToken)) + + if h.container.TokenService.Validate(request.AccessToken, request.ClientToken) { + h.logger.Info("令牌验证成功", zap.String("accessToken", request.AccessToken)) c.JSON(http.StatusNoContent, gin.H{"valid": true}) } else { - loggerInstance.Warn("[WARN] 令牌验证失败", zap.Any("访问令牌:", request.AccessToken)) + h.logger.Warn("令牌验证失败", zap.String("accessToken", request.AccessToken)) c.JSON(http.StatusForbidden, gin.H{"valid": false}) } } // RefreshToken 刷新令牌 -func RefreshToken(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) RefreshToken(c *gin.Context) { var request RefreshRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析刷新令牌请求失败: ", zap.Error(err)) + h.logger.Error("解析刷新令牌请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 获取用户ID和用户信息 - UUID, err := service.GetUUIDByAccessToken(db, request.AccessToken) + UUID, err := h.container.TokenService.GetUUIDByAccessToken(request.AccessToken) if err != nil { - loggerInstance.Warn("[WARN] 刷新令牌失败: 无效的访问令牌", zap.Any("令牌:", request.AccessToken), zap.Error(err)) + h.logger.Warn("刷新令牌失败: 无效的访问令牌", zap.String("token", request.AccessToken), zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - userID, _ := service.GetUserIDByAccessToken(db, request.AccessToken) - // 格式化UUID 这里是因为HMCL的传入参数是HEX格式,为了兼容HMCL,在此做处理 + + userID, _ := h.container.TokenService.GetUserIDByAccessToken(request.AccessToken) UUID = utils.FormatUUID(UUID) - profile, err := service.GetProfileByUUID(db, UUID) + profile, err := h.container.ProfileService.GetByUUID(UUID) if err != nil { - loggerInstance.Error("[ERROR] 刷新令牌失败: 无法获取用户信息 错误: ", zap.Error(err)) + h.logger.Error("刷新令牌失败: 无法获取用户信息", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 准备响应数据 var profileData map[string]interface{} var userData map[string]interface{} var profileID string - // 处理选定的配置文件 if request.SelectedProfile != nil { - // 验证profileID是否存在 profileIDValue, ok := request.SelectedProfile["id"] if !ok { - loggerInstance.Error("[ERROR] 刷新令牌失败: 缺少配置文件ID", zap.Any("ID:", userID)) + h.logger.Error("刷新令牌失败: 缺少配置文件ID", zap.Int64("userId", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"}) return } - // 类型断言 profileID, ok = profileIDValue.(string) if !ok { - loggerInstance.Error("[ERROR] 刷新令牌失败: 配置文件ID类型错误 ", zap.Any("用户ID:", userID)) + h.logger.Error("刷新令牌失败: 配置文件ID类型错误", zap.Int64("userId", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"}) return } - // 格式化profileID profileID = utils.FormatUUID(profileID) - // 验证配置文件所属用户 if profile.UserID != userID { - loggerInstance.Warn("[WARN] 刷新令牌失败: 用户不匹配 ", zap.Any("用户ID:", userID), zap.Any("配置文件用户ID:", profile.UserID)) + h.logger.Warn("刷新令牌失败: 用户不匹配", + zap.Int64("userId", userID), + zap.Int64("profileUserId", profile.UserID), + ) c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch}) return } - profileData = service.SerializeProfile(db, loggerInstance, redis.MustGetClient(), *profile) - } - user, _ := service.GetUserByID(userID) - // 添加用户信息(如果请求了) - if request.RequestUser { - userData = service.SerializeUser(loggerInstance, user, UUID) + profileData = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile) } - // 刷新令牌 - newAccessToken, newClientToken, err := service.RefreshToken(db, loggerInstance, + user, _ := h.container.UserService.GetByID(userID) + if request.RequestUser && user != nil { + userData = service.SerializeUser(h.logger, user, UUID) + } + + newAccessToken, newClientToken, err := h.container.TokenService.Refresh( request.AccessToken, request.ClientToken, profileID, ) if err != nil { - loggerInstance := logger.MustGetLogger() - loggerInstance.Error("[ERROR] 刷新令牌失败: ", zap.Error(err), zap.Any("用户ID: ", userID)) + h.logger.Error("刷新令牌失败", zap.Error(err), zap.Int64("userId", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 返回响应 - loggerInstance.Info("[INFO] 刷新令牌成功", zap.Any("用户ID:", userID)) + h.logger.Info("刷新令牌成功", zap.Int64("userId", userID)) c.JSON(http.StatusOK, RefreshResponse{ AccessToken: newAccessToken, ClientToken: newClientToken, @@ -359,231 +351,177 @@ func RefreshToken(c *gin.Context) { } // InvalidToken 使令牌失效 -func InvalidToken(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) InvalidToken(c *gin.Context) { var request ValidTokenRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析使令牌失效请求失败: ", zap.Error(err)) + h.logger.Error("解析使令牌失效请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 使令牌失效 - service.InvalidToken(db, loggerInstance, request.AccessToken) - loggerInstance.Info("[INFO] 令牌已使失效", zap.Any("访问令牌:", request.AccessToken)) + + h.container.TokenService.Invalidate(request.AccessToken) + h.logger.Info("令牌已失效", zap.String("token", request.AccessToken)) c.JSON(http.StatusNoContent, gin.H{}) } // SignOut 用户登出 -func SignOut(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) SignOut(c *gin.Context) { var request SignOutRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析登出请求失败: %v", zap.Error(err)) + h.logger.Error("解析登出请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 验证邮箱格式 if !emailRegex.MatchString(request.Email) { - loggerInstance.Warn("[WARN] 登出失败: 邮箱格式不正确 ", zap.Any(" ", request.Email)) + h.logger.Warn("登出失败: 邮箱格式不正确", zap.String("email", request.Email)) c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat}) return } - // 通过邮箱获取用户 - user, err := service.GetUserByEmail(request.Email) - if err != nil { - loggerInstance.Warn( - "登出失败: 用户不存在", - zap.String("邮箱", request.Email), - zap.Error(err), - ) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + user, err := h.container.UserService.GetByEmail(request.Email) + if err != nil || user == nil { + h.logger.Warn("登出失败: 用户不存在", zap.String("email", request.Email), zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": "用户不存在"}) return } - // 验证密码 - if err := service.VerifyPassword(db, request.Password, user.ID); err != nil { - loggerInstance.Warn("[WARN] 登出失败: 密码错误", zap.Any("用户ID:", user.ID)) + + if err := service.VerifyPassword(h.container.DB, request.Password, user.ID); err != nil { + h.logger.Warn("登出失败: 密码错误", zap.Int64("userId", user.ID)) c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword}) return } - // 使该用户的所有令牌失效 - service.InvalidUserTokens(db, loggerInstance, user.ID) - loggerInstance.Info("[INFO] 用户登出成功", zap.Any("用户ID:", user.ID)) + h.container.TokenService.InvalidateUserTokens(user.ID) + h.logger.Info("用户登出成功", zap.Int64("userId", user.ID)) c.JSON(http.StatusNoContent, gin.H{"valid": true}) } -func GetProfileByUUID(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - - // 获取并格式化UUID +// GetProfileByUUID 根据UUID获取档案 +func (h *YggdrasilHandler) GetProfileByUUID(c *gin.Context) { uuid := utils.FormatUUID(c.Param("uuid")) - loggerInstance.Info("[INFO] 接收到获取配置文件请求", zap.Any("UUID:", uuid)) + h.logger.Info("获取配置文件请求", zap.String("uuid", uuid)) - // 获取配置文件 - profile, err := service.GetProfileByUUID(db, uuid) + profile, err := h.container.ProfileService.GetByUUID(uuid) if err != nil { - loggerInstance.Error("[ERROR] 获取配置文件失败:", zap.Error(err), zap.String("UUID:", uuid)) + h.logger.Error("获取配置文件失败", zap.Error(err), zap.String("uuid", uuid)) standardResponse(c, http.StatusInternalServerError, nil, err.Error()) return } - // 返回配置文件信息 - loggerInstance.Info("[INFO] 成功获取配置文件", zap.String("UUID:", uuid), zap.String("名称:", profile.Name)) - c.JSON(http.StatusOK, service.SerializeProfile(db, loggerInstance, redisClient, *profile)) + h.logger.Info("成功获取配置文件", zap.String("uuid", uuid), zap.String("name", profile.Name)) + c.JSON(http.StatusOK, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) } -func JoinServer(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - +// JoinServer 加入服务器 +func (h *YggdrasilHandler) JoinServer(c *gin.Context) { var request JoinServerRequest clientIP := c.ClientIP() - // 解析请求参数 if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error( - "解析加入服务器请求失败", - zap.Error(err), - zap.String("IP", clientIP), - ) + h.logger.Error("解析加入服务器请求失败", zap.Error(err), zap.String("ip", clientIP)) standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest) return } - loggerInstance.Info( - "收到加入服务器请求", - zap.String("服务器ID", request.ServerID), - zap.String("用户UUID", request.SelectedProfile), - zap.String("IP", clientIP), + h.logger.Info("收到加入服务器请求", + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), ) - // 处理加入服务器请求 - if err := service.JoinServer(db, loggerInstance, redisClient, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { - loggerInstance.Error( - "加入服务器失败", + if err := service.JoinServer(h.container.DB, h.logger, h.container.Redis, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { + h.logger.Error("加入服务器失败", zap.Error(err), - zap.String("服务器ID", request.ServerID), - zap.String("用户UUID", request.SelectedProfile), - zap.String("IP", clientIP), + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), ) standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed) return } - // 加入成功,返回204状态码 - loggerInstance.Info( - "加入服务器成功", - zap.String("服务器ID", request.ServerID), - zap.String("用户UUID", request.SelectedProfile), - zap.String("IP", clientIP), + h.logger.Info("加入服务器成功", + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), ) c.Status(http.StatusNoContent) } -func HasJoinedServer(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - +// HasJoinedServer 验证玩家是否已加入服务器 +func (h *YggdrasilHandler) HasJoinedServer(c *gin.Context) { clientIP, _ := c.GetQuery("ip") - // 获取并验证服务器ID参数 serverID, exists := c.GetQuery("serverId") if !exists || serverID == "" { - loggerInstance.Warn("[WARN] 缺少服务器ID参数", zap.Any("IP:", clientIP)) + h.logger.Warn("缺少服务器ID参数", zap.String("ip", clientIP)) standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired) return } - // 获取并验证用户名参数 username, exists := c.GetQuery("username") if !exists || username == "" { - loggerInstance.Warn("[WARN] 缺少用户名参数", zap.Any("服务器ID:", serverID), zap.Any("IP:", clientIP)) + h.logger.Warn("缺少用户名参数", zap.String("serverId", serverID), zap.String("ip", clientIP)) standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired) return } - loggerInstance.Info("[INFO] 收到会话验证请求", zap.Any("服务器ID:", serverID), zap.Any("用户名: ", username), zap.Any("IP: ", clientIP)) + h.logger.Info("收到会话验证请求", + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("ip", clientIP), + ) - // 验证玩家是否已加入服务器 - if err := service.HasJoinedServer(loggerInstance, redisClient, serverID, username, clientIP); err != nil { - loggerInstance.Warn("[WARN] 会话验证失败", + if err := service.HasJoinedServer(h.logger, h.container.Redis, serverID, username, clientIP); err != nil { + h.logger.Warn("会话验证失败", zap.Error(err), - zap.String("serverID", serverID), + zap.String("serverId", serverID), zap.String("username", username), - zap.String("clientIP", clientIP), + zap.String("ip", clientIP), ) standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed) return } - profile, err := service.GetProfileByUUID(db, username) + profile, err := h.container.ProfileService.GetByUUID(username) if err != nil { - loggerInstance.Error("[ERROR] 获取用户配置文件失败: %v - 用户名: %s", - zap.Error(err), // 错误详情(zap 原生支持,保留错误链) - zap.String("username", username), // 结构化存储用户名(便于检索) - ) + h.logger.Error("获取用户配置文件失败", zap.Error(err), zap.String("username", username)) standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound) return } - // 返回玩家配置文件 - loggerInstance.Info("[INFO] 会话验证成功 - 服务器ID: %s, 用户名: %s, UUID: %s", - zap.String("serverID", serverID), // 结构化存储服务器ID - zap.String("username", username), // 结构化存储用户名 - zap.String("UUID", profile.UUID), // 结构化存储UUID + h.logger.Info("会话验证成功", + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("uuid", profile.UUID), ) - c.JSON(200, service.SerializeProfile(db, loggerInstance, redisClient, *profile)) + c.JSON(200, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) } -func GetProfilesByName(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +// GetProfilesByName 批量获取配置文件 +func (h *YggdrasilHandler) GetProfilesByName(c *gin.Context) { var names []string - // 解析请求参数 if err := c.ShouldBindJSON(&names); err != nil { - loggerInstance.Error("[ERROR] 解析名称数组请求失败: ", - zap.Error(err), - ) + h.logger.Error("解析名称数组请求失败", zap.Error(err)) standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams) return } - loggerInstance.Info("[INFO] 接收到批量获取配置文件请求", - zap.Int("名称数量:", len(names)), // 结构化存储名称数量 - ) - // 批量获取配置文件 - profiles, err := service.GetProfilesDataByNames(db, names) + h.logger.Info("接收到批量获取配置文件请求", zap.Int("count", len(names))) + + profiles, err := h.container.ProfileService.GetByNames(names) if err != nil { - loggerInstance.Error("[ERROR] 获取配置文件失败: ", - zap.Error(err), - ) + h.logger.Error("获取配置文件失败", zap.Error(err)) } - // 改造:zap 兼容原有 INFO 日志格式 - loggerInstance.Info("[INFO] 成功获取配置文件", - zap.Int("请求名称数:", len(names)), - zap.Int("返回结果数: ", len(profiles)), - ) - + h.logger.Info("成功获取配置文件", zap.Int("requested", len(names)), zap.Int("returned", len(profiles))) c.JSON(http.StatusOK, profiles) } -func GetMetaData(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - redisClient := redis.MustGetClient() - +// GetMetaData 获取Yggdrasil元数据 +func (h *YggdrasilHandler) GetMetaData(c *gin.Context) { meta := gin.H{ "implementationName": "CellAuth", "implementationVersion": "0.0.1", @@ -595,26 +533,25 @@ func GetMetaData(c *gin.Context) { "feature.non_email_login": true, "feature.enable_profile_key": true, } + skinDomains := []string{".hitwh.games", ".littlelan.cn"} - signature, err := service.GetPublicKeyFromRedisFunc(loggerInstance, redisClient) + signature, err := service.GetPublicKeyFromRedisFunc(h.logger, h.container.Redis) if err != nil { - loggerInstance.Error("[ERROR] 获取公钥失败: ", zap.Error(err)) + h.logger.Error("获取公钥失败", zap.Error(err)) standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) return } - loggerInstance.Info("[INFO] 提供元数据") - c.JSON(http.StatusOK, gin.H{"meta": meta, + h.logger.Info("提供元数据") + c.JSON(http.StatusOK, gin.H{ + "meta": meta, "skinDomains": skinDomains, - "signaturePublickey": signature}) + "signaturePublickey": signature, + }) } -func GetPlayerCertificates(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - - var uuid string +// GetPlayerCertificates 获取玩家证书 +func (h *YggdrasilHandler) GetPlayerCertificates(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"}) @@ -622,39 +559,36 @@ func GetPlayerCertificates(c *gin.Context) { return } - // 检查是否以 Bearer 开头并提取 sessionID bearerPrefix := "Bearer " if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) c.Abort() return } + tokenID := authHeader[len(bearerPrefix):] if tokenID == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) c.Abort() return } - var err error - uuid, err = service.GetUUIDByAccessToken(db, tokenID) + uuid, err := h.container.TokenService.GetUUIDByAccessToken(tokenID) if uuid == "" { - loggerInstance.Error("[ERROR] 获取玩家UUID失败: ", zap.Error(err)) + h.logger.Error("获取玩家UUID失败", zap.Error(err)) standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) return } - // 格式化UUID uuid = utils.FormatUUID(uuid) - // 生成玩家证书 - certificate, err := service.GeneratePlayerCertificate(db, loggerInstance, redisClient, uuid) + certificate, err := service.GeneratePlayerCertificate(h.container.DB, h.logger, h.container.Redis, uuid) if err != nil { - loggerInstance.Error("[ERROR] 生成玩家证书失败: ", zap.Error(err)) + h.logger.Error("生成玩家证书失败", zap.Error(err)) standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) return } - loggerInstance.Info("[INFO] 成功生成玩家证书") + h.logger.Info("成功生成玩家证书") c.JSON(http.StatusOK, certificate) } diff --git a/internal/handler/yggdrasil_handler_di.go b/internal/handler/yggdrasil_handler_di.go deleted file mode 100644 index c4fb8f3..0000000 --- a/internal/handler/yggdrasil_handler_di.go +++ /dev/null @@ -1,454 +0,0 @@ -package handler - -import ( - "bytes" - "carrotskin/internal/container" - "carrotskin/internal/model" - "carrotskin/internal/service" - "carrotskin/pkg/utils" - "io" - "net/http" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// YggdrasilHandler Yggdrasil API处理器 -type YggdrasilHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewYggdrasilHandler 创建YggdrasilHandler实例 -func NewYggdrasilHandler(c *container.Container) *YggdrasilHandler { - return &YggdrasilHandler{ - container: c, - logger: c.Logger, - } -} - -// Authenticate 用户认证 -func (h *YggdrasilHandler) Authenticate(c *gin.Context) { - rawData, err := io.ReadAll(c.Request.Body) - if err != nil { - h.logger.Error("读取请求体失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) - return - } - c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData)) - - var request AuthenticateRequest - if err = c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析认证请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - var userId int64 - var profile *model.Profile - var UUID string - - if emailRegex.MatchString(request.Identifier) { - userId, err = service.GetUserIDByEmail(h.container.DB, request.Identifier) - } else { - profile, err = service.GetProfileByProfileName(h.container.DB, request.Identifier) - if err != nil { - h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - userId = profile.UserID - UUID = profile.UUID - } - - if err != nil { - h.logger.Warn("认证失败: 用户不存在", zap.String("identifier", request.Identifier), zap.Error(err)) - c.JSON(http.StatusForbidden, gin.H{"error": "用户不存在"}) - return - } - - if err := service.VerifyPassword(h.container.DB, request.Password, userId); err != nil { - h.logger.Warn("认证失败: 密码错误", zap.Error(err)) - c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword}) - return - } - - selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(h.container.DB, h.logger, userId, UUID, request.ClientToken) - if err != nil { - h.logger.Error("生成令牌失败", zap.Error(err), zap.Int64("userId", userId)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - user, err := service.GetUserByID(userId) - if err != nil { - h.logger.Error("获取用户信息失败", zap.Error(err), zap.Int64("userId", userId)) - } - - availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles)) - for _, p := range availableProfiles { - availableProfilesData = append(availableProfilesData, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *p)) - } - - response := AuthenticateResponse{ - AccessToken: accessToken, - ClientToken: clientToken, - AvailableProfiles: availableProfilesData, - } - - if selectedProfile != nil { - response.SelectedProfile = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *selectedProfile) - } - - if request.RequestUser && user != nil { - response.User = service.SerializeUser(h.logger, user, UUID) - } - - h.logger.Info("用户认证成功", zap.Int64("userId", userId)) - c.JSON(http.StatusOK, response) -} - -// ValidToken 验证令牌 -func (h *YggdrasilHandler) ValidToken(c *gin.Context) { - var request ValidTokenRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析验证令牌请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if service.ValidToken(h.container.DB, request.AccessToken, request.ClientToken) { - h.logger.Info("令牌验证成功", zap.String("accessToken", request.AccessToken)) - c.JSON(http.StatusNoContent, gin.H{"valid": true}) - } else { - h.logger.Warn("令牌验证失败", zap.String("accessToken", request.AccessToken)) - c.JSON(http.StatusForbidden, gin.H{"valid": false}) - } -} - -// RefreshToken 刷新令牌 -func (h *YggdrasilHandler) RefreshToken(c *gin.Context) { - var request RefreshRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析刷新令牌请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - UUID, err := service.GetUUIDByAccessToken(h.container.DB, request.AccessToken) - if err != nil { - h.logger.Warn("刷新令牌失败: 无效的访问令牌", zap.String("token", request.AccessToken), zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - userID, _ := service.GetUserIDByAccessToken(h.container.DB, request.AccessToken) - UUID = utils.FormatUUID(UUID) - - profile, err := service.GetProfileByUUID(h.container.DB, UUID) - if err != nil { - h.logger.Error("刷新令牌失败: 无法获取用户信息", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - var profileData map[string]interface{} - var userData map[string]interface{} - var profileID string - - if request.SelectedProfile != nil { - profileIDValue, ok := request.SelectedProfile["id"] - if !ok { - h.logger.Error("刷新令牌失败: 缺少配置文件ID", zap.Int64("userId", userID)) - c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"}) - return - } - - profileID, ok = profileIDValue.(string) - if !ok { - h.logger.Error("刷新令牌失败: 配置文件ID类型错误", zap.Int64("userId", userID)) - c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"}) - return - } - - profileID = utils.FormatUUID(profileID) - - if profile.UserID != userID { - h.logger.Warn("刷新令牌失败: 用户不匹配", - zap.Int64("userId", userID), - zap.Int64("profileUserId", profile.UserID), - ) - c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch}) - return - } - - profileData = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile) - } - - user, _ := service.GetUserByID(userID) - if request.RequestUser && user != nil { - userData = service.SerializeUser(h.logger, user, UUID) - } - - newAccessToken, newClientToken, err := service.RefreshToken(h.container.DB, h.logger, - request.AccessToken, - request.ClientToken, - profileID, - ) - if err != nil { - h.logger.Error("刷新令牌失败", zap.Error(err), zap.Int64("userId", userID)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - h.logger.Info("刷新令牌成功", zap.Int64("userId", userID)) - c.JSON(http.StatusOK, RefreshResponse{ - AccessToken: newAccessToken, - ClientToken: newClientToken, - SelectedProfile: profileData, - User: userData, - }) -} - -// InvalidToken 使令牌失效 -func (h *YggdrasilHandler) InvalidToken(c *gin.Context) { - var request ValidTokenRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析使令牌失效请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - service.InvalidToken(h.container.DB, h.logger, request.AccessToken) - h.logger.Info("令牌已失效", zap.String("token", request.AccessToken)) - c.JSON(http.StatusNoContent, gin.H{}) -} - -// SignOut 用户登出 -func (h *YggdrasilHandler) SignOut(c *gin.Context) { - var request SignOutRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析登出请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if !emailRegex.MatchString(request.Email) { - h.logger.Warn("登出失败: 邮箱格式不正确", zap.String("email", request.Email)) - c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat}) - return - } - - user, err := service.GetUserByEmail(request.Email) - if err != nil || user == nil { - h.logger.Warn("登出失败: 用户不存在", zap.String("email", request.Email), zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": "用户不存在"}) - return - } - - if err := service.VerifyPassword(h.container.DB, request.Password, user.ID); err != nil { - h.logger.Warn("登出失败: 密码错误", zap.Int64("userId", user.ID)) - c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword}) - return - } - - service.InvalidUserTokens(h.container.DB, h.logger, user.ID) - h.logger.Info("用户登出成功", zap.Int64("userId", user.ID)) - c.JSON(http.StatusNoContent, gin.H{"valid": true}) -} - -// GetProfileByUUID 根据UUID获取档案 -func (h *YggdrasilHandler) GetProfileByUUID(c *gin.Context) { - uuid := utils.FormatUUID(c.Param("uuid")) - h.logger.Info("获取配置文件请求", zap.String("uuid", uuid)) - - profile, err := service.GetProfileByUUID(h.container.DB, uuid) - if err != nil { - h.logger.Error("获取配置文件失败", zap.Error(err), zap.String("uuid", uuid)) - standardResponse(c, http.StatusInternalServerError, nil, err.Error()) - return - } - - h.logger.Info("成功获取配置文件", zap.String("uuid", uuid), zap.String("name", profile.Name)) - c.JSON(http.StatusOK, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) -} - -// JoinServer 加入服务器 -func (h *YggdrasilHandler) JoinServer(c *gin.Context) { - var request JoinServerRequest - clientIP := c.ClientIP() - - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析加入服务器请求失败", zap.Error(err), zap.String("ip", clientIP)) - standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest) - return - } - - h.logger.Info("收到加入服务器请求", - zap.String("serverId", request.ServerID), - zap.String("userUUID", request.SelectedProfile), - zap.String("ip", clientIP), - ) - - if err := service.JoinServer(h.container.DB, h.logger, h.container.Redis, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { - h.logger.Error("加入服务器失败", - zap.Error(err), - zap.String("serverId", request.ServerID), - zap.String("userUUID", request.SelectedProfile), - zap.String("ip", clientIP), - ) - standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed) - return - } - - h.logger.Info("加入服务器成功", - zap.String("serverId", request.ServerID), - zap.String("userUUID", request.SelectedProfile), - zap.String("ip", clientIP), - ) - c.Status(http.StatusNoContent) -} - -// HasJoinedServer 验证玩家是否已加入服务器 -func (h *YggdrasilHandler) HasJoinedServer(c *gin.Context) { - clientIP, _ := c.GetQuery("ip") - - serverID, exists := c.GetQuery("serverId") - if !exists || serverID == "" { - h.logger.Warn("缺少服务器ID参数", zap.String("ip", clientIP)) - standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired) - return - } - - username, exists := c.GetQuery("username") - if !exists || username == "" { - h.logger.Warn("缺少用户名参数", zap.String("serverId", serverID), zap.String("ip", clientIP)) - standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired) - return - } - - h.logger.Info("收到会话验证请求", - zap.String("serverId", serverID), - zap.String("username", username), - zap.String("ip", clientIP), - ) - - if err := service.HasJoinedServer(h.logger, h.container.Redis, serverID, username, clientIP); err != nil { - h.logger.Warn("会话验证失败", - zap.Error(err), - zap.String("serverId", serverID), - zap.String("username", username), - zap.String("ip", clientIP), - ) - standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed) - return - } - - profile, err := service.GetProfileByUUID(h.container.DB, username) - if err != nil { - h.logger.Error("获取用户配置文件失败", zap.Error(err), zap.String("username", username)) - standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound) - return - } - - h.logger.Info("会话验证成功", - zap.String("serverId", serverID), - zap.String("username", username), - zap.String("uuid", profile.UUID), - ) - c.JSON(200, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) -} - -// GetProfilesByName 批量获取配置文件 -func (h *YggdrasilHandler) GetProfilesByName(c *gin.Context) { - var names []string - - if err := c.ShouldBindJSON(&names); err != nil { - h.logger.Error("解析名称数组请求失败", zap.Error(err)) - standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams) - return - } - - h.logger.Info("接收到批量获取配置文件请求", zap.Int("count", len(names))) - - profiles, err := service.GetProfilesDataByNames(h.container.DB, names) - if err != nil { - h.logger.Error("获取配置文件失败", zap.Error(err)) - } - - h.logger.Info("成功获取配置文件", zap.Int("requested", len(names)), zap.Int("returned", len(profiles))) - c.JSON(http.StatusOK, profiles) -} - -// GetMetaData 获取Yggdrasil元数据 -func (h *YggdrasilHandler) GetMetaData(c *gin.Context) { - meta := gin.H{ - "implementationName": "CellAuth", - "implementationVersion": "0.0.1", - "serverName": "LittleLan's Yggdrasil Server Implementation.", - "links": gin.H{ - "homepage": "https://skin.littlelan.cn", - "register": "https://skin.littlelan.cn/auth", - }, - "feature.non_email_login": true, - "feature.enable_profile_key": true, - } - - skinDomains := []string{".hitwh.games", ".littlelan.cn"} - signature, err := service.GetPublicKeyFromRedisFunc(h.logger, h.container.Redis) - if err != nil { - h.logger.Error("获取公钥失败", zap.Error(err)) - standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) - return - } - - h.logger.Info("提供元数据") - c.JSON(http.StatusOK, gin.H{ - "meta": meta, - "skinDomains": skinDomains, - "signaturePublickey": signature, - }) -} - -// GetPlayerCertificates 获取玩家证书 -func (h *YggdrasilHandler) GetPlayerCertificates(c *gin.Context) { - authHeader := c.GetHeader("Authorization") - if authHeader == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"}) - c.Abort() - return - } - - bearerPrefix := "Bearer " - if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) - c.Abort() - return - } - - tokenID := authHeader[len(bearerPrefix):] - if tokenID == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) - c.Abort() - return - } - - uuid, err := service.GetUUIDByAccessToken(h.container.DB, tokenID) - if uuid == "" { - h.logger.Error("获取玩家UUID失败", zap.Error(err)) - standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) - return - } - - uuid = utils.FormatUUID(uuid) - - certificate, err := service.GeneratePlayerCertificate(h.container.DB, h.logger, h.container.Redis, uuid) - if err != nil { - h.logger.Error("生成玩家证书失败", zap.Error(err)) - standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) - return - } - - h.logger.Info("成功生成玩家证书") - c.JSON(http.StatusOK, certificate) -} diff --git a/internal/service/helpers_test.go b/internal/service/helpers_test.go new file mode 100644 index 0000000..043aba4 --- /dev/null +++ b/internal/service/helpers_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "errors" + "testing" +) + +// TestNormalizePagination_Basic 覆盖 NormalizePagination 的边界分支 +func TestNormalizePagination_Basic(t *testing.T) { + tests := []struct { + name string + page int + size int + wantPage int + wantPageSize int + }{ + {"page 小于 1", 0, 10, 1, 10}, + {"pageSize 小于 1", 1, 0, 1, 20}, + {"pageSize 大于 100", 2, 200, 2, 100}, + {"正常范围", 3, 30, 3, 30}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPage, gotSize := NormalizePagination(tt.page, tt.size) + if gotPage != tt.wantPage || gotSize != tt.wantPageSize { + t.Fatalf("NormalizePagination(%d,%d) = (%d,%d), want (%d,%d)", + tt.page, tt.size, gotPage, gotSize, tt.wantPage, tt.wantPageSize) + } + }) + } +} + +// TestWrapError 覆盖 WrapError 的 nil 与非 nil 分支 +func TestWrapError(t *testing.T) { + if err := WrapError(nil, "msg"); err != nil { + t.Fatalf("WrapError(nil, ...) 应返回 nil, got=%v", err) + } + + orig := errors.New("orig") + wrapped := WrapError(orig, "context") + if wrapped == nil { + t.Fatalf("WrapError 应返回非 nil 错误") + } + if wrapped.Error() == orig.Error() { + t.Fatalf("WrapError 应添加上下文信息, got=%v", wrapped) + } +} + + diff --git a/internal/service/mocks_test.go b/internal/service/mocks_test.go new file mode 100644 index 0000000..0c3572e --- /dev/null +++ b/internal/service/mocks_test.go @@ -0,0 +1,964 @@ +package service + +import ( + "carrotskin/internal/model" + "errors" +) + +// ============================================================================ +// Repository Mocks +// ============================================================================ + +// MockUserRepository 模拟UserRepository +type MockUserRepository struct { + users map[int64]*model.User + // 用于模拟错误的标志 + FailCreate bool + FailFindByID bool + FailFindByUsername bool + FailFindByEmail bool + FailUpdate bool +} + +func NewMockUserRepository() *MockUserRepository { + return &MockUserRepository{ + users: make(map[int64]*model.User), + } +} + +func (m *MockUserRepository) Create(user *model.User) error { + if m.FailCreate { + return errors.New("mock create error") + } + if user.ID == 0 { + user.ID = int64(len(m.users) + 1) + } + m.users[user.ID] = user + return nil +} + +func (m *MockUserRepository) FindByID(id int64) (*model.User, error) { + if m.FailFindByID { + return nil, errors.New("mock find error") + } + if user, ok := m.users[id]; ok { + return user, nil + } + return nil, nil +} + +func (m *MockUserRepository) FindByUsername(username string) (*model.User, error) { + if m.FailFindByUsername { + return nil, errors.New("mock find by username error") + } + for _, user := range m.users { + if user.Username == username { + return user, nil + } + } + return nil, nil +} + +func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) { + if m.FailFindByEmail { + return nil, errors.New("mock find by email error") + } + for _, user := range m.users { + if user.Email == email { + return user, nil + } + } + return nil, nil +} + +func (m *MockUserRepository) Update(user *model.User) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.users[user.ID] = user + return nil +} + +func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{}) error { + if m.FailUpdate { + return errors.New("mock update fields error") + } + _, ok := m.users[id] + if !ok { + return errors.New("user not found") + } + return nil +} + +func (m *MockUserRepository) Delete(id int64) error { + delete(m.users, id) + return nil +} + +func (m *MockUserRepository) CreateLoginLog(log *model.UserLoginLog) error { + return nil +} + +func (m *MockUserRepository) CreatePointLog(log *model.UserPointLog) error { + return nil +} + +func (m *MockUserRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error { + return nil +} + +// MockProfileRepository 模拟ProfileRepository +type MockProfileRepository struct { + profiles map[string]*model.Profile + userProfiles map[int64][]*model.Profile + nextID int64 + FailCreate bool + FailFind bool + FailUpdate bool + FailDelete bool +} + +func NewMockProfileRepository() *MockProfileRepository { + return &MockProfileRepository{ + profiles: make(map[string]*model.Profile), + userProfiles: make(map[int64][]*model.Profile), + nextID: 1, + } +} + +func (m *MockProfileRepository) Create(profile *model.Profile) error { + if m.FailCreate { + return errors.New("mock create error") + } + m.profiles[profile.UUID] = profile + m.userProfiles[profile.UserID] = append(m.userProfiles[profile.UserID], profile) + return nil +} + +func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + if profile, ok := m.profiles[uuid]; ok { + return profile, nil + } + return nil, errors.New("profile not found") +} + +func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + for _, profile := range m.profiles { + if profile.Name == name { + return profile, nil + } + } + return nil, nil +} + +func (m *MockProfileRepository) FindByUserID(userID int64) ([]*model.Profile, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + return m.userProfiles[userID], nil +} + +func (m *MockProfileRepository) Update(profile *model.Profile) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.profiles[profile.UUID] = profile + return nil +} + +func (m *MockProfileRepository) UpdateFields(uuid string, updates map[string]interface{}) error { + if m.FailUpdate { + return errors.New("mock update error") + } + return nil +} + +func (m *MockProfileRepository) Delete(uuid string) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.profiles, uuid) + return nil +} + +func (m *MockProfileRepository) CountByUserID(userID int64) (int64, error) { + return int64(len(m.userProfiles[userID])), nil +} + +func (m *MockProfileRepository) SetActive(uuid string, userID int64) error { + return nil +} + +func (m *MockProfileRepository) UpdateLastUsedAt(uuid string) error { + return nil +} + +func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, error) { + var result []*model.Profile + for _, name := range names { + for _, profile := range m.profiles { + if profile.Name == name { + result = append(result, profile) + } + } + } + return result, nil +} + +func (m *MockProfileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) { + return nil, nil +} + +func (m *MockProfileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { + return nil +} + +// MockTextureRepository 模拟TextureRepository +type MockTextureRepository struct { + textures map[int64]*model.Texture + favorites map[int64]map[int64]bool // userID -> textureID -> favorited + nextID int64 + FailCreate bool + FailFind bool + FailUpdate bool + FailDelete bool +} + +func NewMockTextureRepository() *MockTextureRepository { + return &MockTextureRepository{ + textures: make(map[int64]*model.Texture), + favorites: make(map[int64]map[int64]bool), + nextID: 1, + } +} + +func (m *MockTextureRepository) Create(texture *model.Texture) error { + if m.FailCreate { + return errors.New("mock create error") + } + if texture.ID == 0 { + texture.ID = m.nextID + m.nextID++ + } + m.textures[texture.ID] = texture + return nil +} + +func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + if texture, ok := m.textures[id]; ok { + return texture, nil + } + return nil, errors.New("texture not found") +} + +func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + for _, texture := range m.textures { + if texture.Hash == hash { + return texture, nil + } + } + return nil, nil +} + +func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailFind { + return nil, 0, errors.New("mock find error") + } + var result []*model.Texture + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + result = append(result, texture) + } + } + return result, int64(len(result)), nil +} + +func (m *MockTextureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailFind { + return nil, 0, errors.New("mock find error") + } + var result []*model.Texture + for _, texture := range m.textures { + if publicOnly && !texture.IsPublic { + continue + } + result = append(result, texture) + } + return result, int64(len(result)), nil +} + +func (m *MockTextureRepository) Update(texture *model.Texture) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.textures[texture.ID] = texture + return nil +} + +func (m *MockTextureRepository) UpdateFields(id int64, fields map[string]interface{}) error { + if m.FailUpdate { + return errors.New("mock update error") + } + return nil +} + +func (m *MockTextureRepository) Delete(id int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.textures, id) + return nil +} + +func (m *MockTextureRepository) IncrementDownloadCount(id int64) error { + if texture, ok := m.textures[id]; ok { + texture.DownloadCount++ + } + return nil +} + +func (m *MockTextureRepository) IncrementFavoriteCount(id int64) error { + if texture, ok := m.textures[id]; ok { + texture.FavoriteCount++ + } + return nil +} + +func (m *MockTextureRepository) DecrementFavoriteCount(id int64) error { + if texture, ok := m.textures[id]; ok && texture.FavoriteCount > 0 { + texture.FavoriteCount-- + } + return nil +} + +func (m *MockTextureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error { + return nil +} + +func (m *MockTextureRepository) IsFavorited(userID, textureID int64) (bool, error) { + if userFavs, ok := m.favorites[userID]; ok { + return userFavs[textureID], nil + } + return false, nil +} + +func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error { + if m.favorites[userID] == nil { + m.favorites[userID] = make(map[int64]bool) + } + m.favorites[userID][textureID] = true + return nil +} + +func (m *MockTextureRepository) RemoveFavorite(userID, textureID int64) error { + if userFavs, ok := m.favorites[userID]; ok { + delete(userFavs, textureID) + } + return nil +} + +func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + var result []*model.Texture + if userFavs, ok := m.favorites[userID]; ok { + for textureID := range userFavs { + if texture, exists := m.textures[textureID]; exists { + result = append(result, texture) + } + } + } + return result, int64(len(result)), nil +} + +func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, error) { + var count int64 + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + count++ + } + } + return count, nil +} + +// MockTokenRepository 模拟TokenRepository +type MockTokenRepository struct { + tokens map[string]*model.Token + userTokens map[int64][]*model.Token + FailCreate bool + FailFind bool + FailDelete bool +} + +func NewMockTokenRepository() *MockTokenRepository { + return &MockTokenRepository{ + tokens: make(map[string]*model.Token), + userTokens: make(map[int64][]*model.Token), + } +} + +func (m *MockTokenRepository) Create(token *model.Token) error { + if m.FailCreate { + return errors.New("mock create error") + } + m.tokens[token.AccessToken] = token + m.userTokens[token.UserID] = append(m.userTokens[token.UserID], token) + return nil +} + +func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + if token, ok := m.tokens[accessToken]; ok { + return token, nil + } + return nil, errors.New("token not found") +} + +func (m *MockTokenRepository) GetByUserID(userId int64) ([]*model.Token, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + return m.userTokens[userId], nil +} + +func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) { + if m.FailFind { + return "", errors.New("mock find error") + } + if token, ok := m.tokens[accessToken]; ok { + return token.ProfileId, nil + } + return "", errors.New("token not found") +} + +func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) { + if m.FailFind { + return 0, errors.New("mock find error") + } + if token, ok := m.tokens[accessToken]; ok { + return token.UserID, nil + } + return 0, errors.New("token not found") +} + +func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.tokens, accessToken) + return nil +} + +func (m *MockTokenRepository) DeleteByUserID(userId int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + for _, token := range m.userTokens[userId] { + delete(m.tokens, token.AccessToken) + } + m.userTokens[userId] = nil + return nil +} + +func (m *MockTokenRepository) BatchDelete(accessTokens []string) (int64, error) { + if m.FailDelete { + return 0, errors.New("mock delete error") + } + var count int64 + for _, accessToken := range accessTokens { + if _, ok := m.tokens[accessToken]; ok { + delete(m.tokens, accessToken) + count++ + } + } + return count, nil +} + +// MockSystemConfigRepository 模拟SystemConfigRepository +type MockSystemConfigRepository struct { + configs map[string]*model.SystemConfig +} + +func NewMockSystemConfigRepository() *MockSystemConfigRepository { + return &MockSystemConfigRepository{ + configs: make(map[string]*model.SystemConfig), + } +} + +func (m *MockSystemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) { + if config, ok := m.configs[key]; ok { + return config, nil + } + return nil, nil +} + +func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) { + var result []model.SystemConfig + for _, v := range m.configs { + result = append(result, *v) + } + return result, nil +} + +func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) { + var result []model.SystemConfig + for _, v := range m.configs { + result = append(result, *v) + } + return result, nil +} + +func (m *MockSystemConfigRepository) Update(config *model.SystemConfig) error { + m.configs[config.Key] = config + return nil +} + +func (m *MockSystemConfigRepository) UpdateValue(key, value string) error { + if config, ok := m.configs[key]; ok { + config.Value = value + return nil + } + return errors.New("config not found") +} + +// ============================================================================ +// Service Mocks +// ============================================================================ + +// MockUserService 模拟UserService +type MockUserService struct { + users map[int64]*model.User + maxProfilesPerUser int + maxTexturesPerUser int + FailRegister bool + FailLogin bool + FailGetByID bool + FailUpdate bool +} + +func NewMockUserService() *MockUserService { + return &MockUserService{ + users: make(map[int64]*model.User), + maxProfilesPerUser: 5, + maxTexturesPerUser: 50, + } +} + +func (m *MockUserService) Register(username, password, email, avatar string) (*model.User, string, error) { + if m.FailRegister { + return nil, "", errors.New("mock register error") + } + user := &model.User{ + ID: int64(len(m.users) + 1), + Username: username, + Email: email, + Avatar: avatar, + Status: 1, + } + m.users[user.ID] = user + return user, "mock-token", nil +} + +func (m *MockUserService) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { + if m.FailLogin { + return nil, "", errors.New("mock login error") + } + for _, user := range m.users { + if user.Username == usernameOrEmail || user.Email == usernameOrEmail { + return user, "mock-token", nil + } + } + return nil, "", errors.New("user not found") +} + +func (m *MockUserService) GetByID(id int64) (*model.User, error) { + if m.FailGetByID { + return nil, errors.New("mock get by id error") + } + if user, ok := m.users[id]; ok { + return user, nil + } + return nil, nil +} + +func (m *MockUserService) GetByEmail(email string) (*model.User, error) { + for _, user := range m.users { + if user.Email == email { + return user, nil + } + } + return nil, nil +} + +func (m *MockUserService) UpdateInfo(user *model.User) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.users[user.ID] = user + return nil +} + +func (m *MockUserService) UpdateAvatar(userID int64, avatarURL string) error { + if m.FailUpdate { + return errors.New("mock update error") + } + if user, ok := m.users[userID]; ok { + user.Avatar = avatarURL + } + return nil +} + +func (m *MockUserService) ChangePassword(userID int64, oldPassword, newPassword string) error { + return nil +} + +func (m *MockUserService) ResetPassword(email, newPassword string) error { + return nil +} + +func (m *MockUserService) ChangeEmail(userID int64, newEmail string) error { + if user, ok := m.users[userID]; ok { + user.Email = newEmail + } + return nil +} + +func (m *MockUserService) ValidateAvatarURL(avatarURL string) error { + return nil +} + +func (m *MockUserService) GetMaxProfilesPerUser() int { + return m.maxProfilesPerUser +} + +func (m *MockUserService) GetMaxTexturesPerUser() int { + return m.maxTexturesPerUser +} + +// MockProfileService 模拟ProfileService +type MockProfileService struct { + profiles map[string]*model.Profile + FailCreate bool + FailGet bool + FailUpdate bool + FailDelete bool +} + +func NewMockProfileService() *MockProfileService { + return &MockProfileService{ + profiles: make(map[string]*model.Profile), + } +} + +func (m *MockProfileService) Create(userID int64, name string) (*model.Profile, error) { + if m.FailCreate { + return nil, errors.New("mock create error") + } + profile := &model.Profile{ + UUID: "mock-uuid-" + name, + UserID: userID, + Name: name, + } + m.profiles[profile.UUID] = profile + return profile, nil +} + +func (m *MockProfileService) GetByUUID(uuid string) (*model.Profile, error) { + if m.FailGet { + return nil, errors.New("mock get error") + } + if profile, ok := m.profiles[uuid]; ok { + return profile, nil + } + return nil, errors.New("profile not found") +} + +func (m *MockProfileService) GetByUserID(userID int64) ([]*model.Profile, error) { + if m.FailGet { + return nil, errors.New("mock get error") + } + var result []*model.Profile + for _, profile := range m.profiles { + if profile.UserID == userID { + result = append(result, profile) + } + } + return result, nil +} + +func (m *MockProfileService) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { + if m.FailUpdate { + return nil, errors.New("mock update error") + } + if profile, ok := m.profiles[uuid]; ok { + if name != nil { + profile.Name = *name + } + if skinID != nil { + profile.SkinID = skinID + } + if capeID != nil { + profile.CapeID = capeID + } + return profile, nil + } + return nil, errors.New("profile not found") +} + +func (m *MockProfileService) Delete(uuid string, userID int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.profiles, uuid) + return nil +} + +func (m *MockProfileService) SetActive(uuid string, userID int64) error { + return nil +} + +func (m *MockProfileService) CheckLimit(userID int64, maxProfiles int) error { + count := 0 + for _, profile := range m.profiles { + if profile.UserID == userID { + count++ + } + } + if count >= maxProfiles { + return errors.New("达到档案数量上限") + } + return nil +} + +func (m *MockProfileService) GetByNames(names []string) ([]*model.Profile, error) { + var result []*model.Profile + for _, name := range names { + for _, profile := range m.profiles { + if profile.Name == name { + result = append(result, profile) + } + } + } + return result, nil +} + +func (m *MockProfileService) GetByProfileName(name string) (*model.Profile, error) { + for _, profile := range m.profiles { + if profile.Name == name { + return profile, nil + } + } + return nil, errors.New("profile not found") +} + +// MockTextureService 模拟TextureService +type MockTextureService struct { + textures map[int64]*model.Texture + nextID int64 + FailCreate bool + FailGet bool + FailUpdate bool + FailDelete bool +} + +func NewMockTextureService() *MockTextureService { + return &MockTextureService{ + textures: make(map[int64]*model.Texture), + nextID: 1, + } +} + +func (m *MockTextureService) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { + if m.FailCreate { + return nil, errors.New("mock create error") + } + texture := &model.Texture{ + ID: m.nextID, + UploaderID: uploaderID, + Name: name, + Description: description, + URL: url, + Hash: hash, + Size: size, + IsPublic: isPublic, + IsSlim: isSlim, + } + m.textures[texture.ID] = texture + m.nextID++ + return texture, nil +} + +func (m *MockTextureService) GetByID(id int64) (*model.Texture, error) { + if m.FailGet { + return nil, errors.New("mock get error") + } + if texture, ok := m.textures[id]; ok { + return texture, nil + } + return nil, errors.New("texture not found") +} + +func (m *MockTextureService) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailGet { + return nil, 0, errors.New("mock get error") + } + var result []*model.Texture + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + result = append(result, texture) + } + } + return result, int64(len(result)), nil +} + +func (m *MockTextureService) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailGet { + return nil, 0, errors.New("mock get error") + } + var result []*model.Texture + for _, texture := range m.textures { + if publicOnly && !texture.IsPublic { + continue + } + result = append(result, texture) + } + return result, int64(len(result)), nil +} + +func (m *MockTextureService) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { + if m.FailUpdate { + return nil, errors.New("mock update error") + } + if texture, ok := m.textures[textureID]; ok { + if name != "" { + texture.Name = name + } + if description != "" { + texture.Description = description + } + if isPublic != nil { + texture.IsPublic = *isPublic + } + return texture, nil + } + return nil, errors.New("texture not found") +} + +func (m *MockTextureService) Delete(textureID, uploaderID int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.textures, textureID) + return nil +} + +func (m *MockTextureService) ToggleFavorite(userID, textureID int64) (bool, error) { + return true, nil +} + +func (m *MockTextureService) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + return nil, 0, nil +} + +func (m *MockTextureService) CheckUploadLimit(uploaderID int64, maxTextures int) error { + count := 0 + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + count++ + } + } + if count >= maxTextures { + return errors.New("达到材质数量上限") + } + return nil +} + +// MockTokenService 模拟TokenService +type MockTokenService struct { + tokens map[string]*model.Token + FailCreate bool + FailValidate bool + FailRefresh bool +} + +func NewMockTokenService() *MockTokenService { + return &MockTokenService{ + tokens: make(map[string]*model.Token), + } +} + +func (m *MockTokenService) Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { + if m.FailCreate { + return nil, nil, "", "", errors.New("mock create error") + } + accessToken := "mock-access-token" + if clientToken == "" { + clientToken = "mock-client-token" + } + token := &model.Token{ + AccessToken: accessToken, + ClientToken: clientToken, + UserID: userID, + ProfileId: uuid, + Usable: true, + } + m.tokens[accessToken] = token + return nil, nil, accessToken, clientToken, nil +} + +func (m *MockTokenService) Validate(accessToken, clientToken string) bool { + if m.FailValidate { + return false + } + if token, ok := m.tokens[accessToken]; ok { + if clientToken == "" || token.ClientToken == clientToken { + return token.Usable + } + } + return false +} + +func (m *MockTokenService) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { + if m.FailRefresh { + return "", "", errors.New("mock refresh error") + } + return "new-access-token", clientToken, nil +} + +func (m *MockTokenService) Invalidate(accessToken string) { + delete(m.tokens, accessToken) +} + +func (m *MockTokenService) InvalidateUserTokens(userID int64) { + for key, token := range m.tokens { + if token.UserID == userID { + delete(m.tokens, key) + } + } +} + +func (m *MockTokenService) GetUUIDByAccessToken(accessToken string) (string, error) { + if token, ok := m.tokens[accessToken]; ok { + return token.ProfileId, nil + } + return "", errors.New("token not found") +} + +func (m *MockTokenService) GetUserIDByAccessToken(accessToken string) (int64, error) { + if token, ok := m.tokens[accessToken]; ok { + return token.UserID, nil + } + return 0, errors.New("token not found") +} diff --git a/internal/service/profile_service.go b/internal/service/profile_service.go index d3e2057..a956793 100644 --- a/internal/service/profile_service.go +++ b/internal/service/profile_service.go @@ -11,35 +11,54 @@ import ( "fmt" "github.com/google/uuid" - "github.com/jackc/pgx/v5" + "go.uber.org/zap" "gorm.io/gorm" ) -// CreateProfile 创建档案 -func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, error) { +// profileServiceImpl ProfileService的实现 +type profileServiceImpl struct { + profileRepo repository.ProfileRepository + userRepo repository.UserRepository + logger *zap.Logger +} + +// NewProfileService 创建ProfileService实例 +func NewProfileService( + profileRepo repository.ProfileRepository, + userRepo repository.UserRepository, + logger *zap.Logger, +) ProfileService { + return &profileServiceImpl{ + profileRepo: profileRepo, + userRepo: userRepo, + logger: logger, + } +} + +func (s *profileServiceImpl) Create(userID int64, name string) (*model.Profile, error) { // 验证用户存在 - user, err := EnsureUserExists(userID) - if err != nil { - return nil, err + user, err := s.userRepo.FindByID(userID) + if err != nil || user == nil { + return nil, errors.New("用户不存在") } if user.Status != 1 { - return nil, fmt.Errorf("用户状态异常") + return nil, errors.New("用户状态异常") } // 检查角色名是否已存在 - existingName, err := repository.FindProfileByName(name) + existingName, err := s.profileRepo.FindByName(name) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, WrapError(err, "查询角色名失败") + return nil, fmt.Errorf("查询角色名失败: %w", err) } if existingName != nil { - return nil, fmt.Errorf("角色名已被使用") + return nil, errors.New("角色名已被使用") } // 生成UUID和RSA密钥 profileUUID := uuid.New().String() - privateKey, err := generateRSAPrivateKey() + privateKey, err := generateRSAPrivateKeyInternal() if err != nil { - return nil, WrapError(err, "生成RSA密钥失败") + return nil, fmt.Errorf("生成RSA密钥失败: %w", err) } // 创建档案 @@ -51,55 +70,59 @@ func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, erro IsActive: true, } - if err := repository.CreateProfile(profile); err != nil { - return nil, WrapError(err, "创建档案失败") + if err := s.profileRepo.Create(profile); err != nil { + return nil, fmt.Errorf("创建档案失败: %w", err) } // 设置活跃状态 - if err := repository.SetActiveProfile(profileUUID, userID); err != nil { - return nil, WrapError(err, "设置活跃状态失败") + if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { + return nil, fmt.Errorf("设置活跃状态失败: %w", err) } return profile, nil } -// GetProfileByUUID 获取档案详情 -func GetProfileByUUID(db *gorm.DB, uuid string) (*model.Profile, error) { - profile, err := repository.FindProfileByUUID(uuid) +func (s *profileServiceImpl) GetByUUID(uuid string) (*model.Profile, error) { + profile, err := s.profileRepo.FindByUUID(uuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrProfileNotFound } - return nil, WrapError(err, "查询档案失败") + return nil, fmt.Errorf("查询档案失败: %w", err) } return profile, nil } -// GetUserProfiles 获取用户的所有档案 -func GetUserProfiles(db *gorm.DB, userID int64) ([]*model.Profile, error) { - profiles, err := repository.FindProfilesByUserID(userID) +func (s *profileServiceImpl) GetByUserID(userID int64) ([]*model.Profile, error) { + profiles, err := s.profileRepo.FindByUserID(userID) if err != nil { - return nil, WrapError(err, "查询档案列表失败") + return nil, fmt.Errorf("查询档案列表失败: %w", err) } return profiles, nil } -// UpdateProfile 更新档案 -func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { +func (s *profileServiceImpl) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { // 获取档案并验证权限 - profile, err := GetProfileWithPermissionCheck(uuid, userID) + profile, err := s.profileRepo.FindByUUID(uuid) if err != nil { - return nil, err + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrProfileNotFound + } + return nil, fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return nil, ErrProfileNoPermission } // 检查角色名是否重复 if name != nil && *name != profile.Name { - existingName, err := repository.FindProfileByName(*name) + existingName, err := s.profileRepo.FindByName(*name) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, WrapError(err, "查询角色名失败") + return nil, fmt.Errorf("查询角色名失败: %w", err) } if existingName != nil { - return nil, fmt.Errorf("角色名已被使用") + return nil, errors.New("角色名已被使用") } profile.Name = *name } @@ -112,47 +135,62 @@ func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, profile.CapeID = capeID } - if err := repository.UpdateProfile(profile); err != nil { - return nil, WrapError(err, "更新档案失败") + if err := s.profileRepo.Update(profile); err != nil { + return nil, fmt.Errorf("更新档案失败: %w", err) } - return repository.FindProfileByUUID(uuid) + return s.profileRepo.FindByUUID(uuid) } -// DeleteProfile 删除档案 -func DeleteProfile(db *gorm.DB, uuid string, userID int64) error { - if _, err := GetProfileWithPermissionCheck(uuid, userID); err != nil { - return err - } - - if err := repository.DeleteProfile(uuid); err != nil { - return WrapError(err, "删除档案失败") - } - return nil -} - -// SetActiveProfile 设置活跃档案 -func SetActiveProfile(db *gorm.DB, uuid string, userID int64) error { - if _, err := GetProfileWithPermissionCheck(uuid, userID); err != nil { - return err - } - - if err := repository.SetActiveProfile(uuid, userID); err != nil { - return WrapError(err, "设置活跃状态失败") - } - - if err := repository.UpdateProfileLastUsedAt(uuid); err != nil { - return WrapError(err, "更新使用时间失败") - } - - return nil -} - -// CheckProfileLimit 检查用户档案数量限制 -func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error { - count, err := repository.CountProfilesByUserID(userID) +func (s *profileServiceImpl) Delete(uuid string, userID int64) error { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) if err != nil { - return WrapError(err, "查询档案数量失败") + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrProfileNotFound + } + return fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return ErrProfileNoPermission + } + + if err := s.profileRepo.Delete(uuid); err != nil { + return fmt.Errorf("删除档案失败: %w", err) + } + return nil +} + +func (s *profileServiceImpl) SetActive(uuid string, userID int64) error { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrProfileNotFound + } + return fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return ErrProfileNoPermission + } + + if err := s.profileRepo.SetActive(uuid, userID); err != nil { + return fmt.Errorf("设置活跃状态失败: %w", err) + } + + if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { + return fmt.Errorf("更新使用时间失败: %w", err) + } + + return nil +} + +func (s *profileServiceImpl) CheckLimit(userID int64, maxProfiles int) error { + count, err := s.profileRepo.CountByUserID(userID) + if err != nil { + return fmt.Errorf("查询档案数量失败: %w", err) } if int(count) >= maxProfiles { @@ -161,8 +199,24 @@ func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error { return nil } -// generateRSAPrivateKey 生成RSA-2048私钥(PEM格式) -func generateRSAPrivateKey() (string, error) { +func (s *profileServiceImpl) GetByNames(names []string) ([]*model.Profile, error) { + profiles, err := s.profileRepo.GetByNames(names) + if err != nil { + return nil, fmt.Errorf("查找失败: %w", err) + } + return profiles, nil +} + +func (s *profileServiceImpl) GetByProfileName(name string) (*model.Profile, error) { + profile, err := s.profileRepo.FindByName(name) + if err != nil { + return nil, errors.New("用户角色未创建") + } + return profile, nil +} + +// generateRSAPrivateKeyInternal 生成RSA-2048私钥(PEM格式) +func generateRSAPrivateKeyInternal() (string, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return "", err @@ -177,33 +231,4 @@ func generateRSAPrivateKey() (string, error) { return string(privateKeyPEM), nil } -func ValidateProfileByUserID(db *gorm.DB, userId int64, UUID string) (bool, error) { - if userId == 0 || UUID == "" { - return false, errors.New("用户ID或配置文件ID不能为空") - } - profile, err := repository.FindProfileByUUID(UUID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return false, errors.New("配置文件不存在") - } - return false, WrapError(err, "验证配置文件失败") - } - return profile.UserID == userId, nil -} - -func GetProfilesDataByNames(db *gorm.DB, names []string) ([]*model.Profile, error) { - profiles, err := repository.GetProfilesByNames(names) - if err != nil { - return nil, WrapError(err, "查找失败") - } - return profiles, nil -} - -func GetProfileKeyPair(db *gorm.DB, profileId string) (*model.KeyPair, error) { - keyPair, err := repository.GetProfileKeyPair(profileId) - if err != nil { - return nil, WrapError(err, "查找失败") - } - return keyPair, nil -} diff --git a/internal/service/profile_service_impl.go b/internal/service/profile_service_impl.go deleted file mode 100644 index a956793..0000000 --- a/internal/service/profile_service_impl.go +++ /dev/null @@ -1,234 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - - "github.com/google/uuid" - "go.uber.org/zap" - "gorm.io/gorm" -) - -// profileServiceImpl ProfileService的实现 -type profileServiceImpl struct { - profileRepo repository.ProfileRepository - userRepo repository.UserRepository - logger *zap.Logger -} - -// NewProfileService 创建ProfileService实例 -func NewProfileService( - profileRepo repository.ProfileRepository, - userRepo repository.UserRepository, - logger *zap.Logger, -) ProfileService { - return &profileServiceImpl{ - profileRepo: profileRepo, - userRepo: userRepo, - logger: logger, - } -} - -func (s *profileServiceImpl) Create(userID int64, name string) (*model.Profile, error) { - // 验证用户存在 - user, err := s.userRepo.FindByID(userID) - if err != nil || user == nil { - return nil, errors.New("用户不存在") - } - if user.Status != 1 { - return nil, errors.New("用户状态异常") - } - - // 检查角色名是否已存在 - existingName, err := s.profileRepo.FindByName(name) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("查询角色名失败: %w", err) - } - if existingName != nil { - return nil, errors.New("角色名已被使用") - } - - // 生成UUID和RSA密钥 - profileUUID := uuid.New().String() - privateKey, err := generateRSAPrivateKeyInternal() - if err != nil { - return nil, fmt.Errorf("生成RSA密钥失败: %w", err) - } - - // 创建档案 - profile := &model.Profile{ - UUID: profileUUID, - UserID: userID, - Name: name, - RSAPrivateKey: privateKey, - IsActive: true, - } - - if err := s.profileRepo.Create(profile); err != nil { - return nil, fmt.Errorf("创建档案失败: %w", err) - } - - // 设置活跃状态 - if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { - return nil, fmt.Errorf("设置活跃状态失败: %w", err) - } - - return profile, nil -} - -func (s *profileServiceImpl) GetByUUID(uuid string) (*model.Profile, error) { - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrProfileNotFound - } - return nil, fmt.Errorf("查询档案失败: %w", err) - } - return profile, nil -} - -func (s *profileServiceImpl) GetByUserID(userID int64) ([]*model.Profile, error) { - profiles, err := s.profileRepo.FindByUserID(userID) - if err != nil { - return nil, fmt.Errorf("查询档案列表失败: %w", err) - } - return profiles, nil -} - -func (s *profileServiceImpl) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { - // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrProfileNotFound - } - return nil, fmt.Errorf("查询档案失败: %w", err) - } - - if profile.UserID != userID { - return nil, ErrProfileNoPermission - } - - // 检查角色名是否重复 - if name != nil && *name != profile.Name { - existingName, err := s.profileRepo.FindByName(*name) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("查询角色名失败: %w", err) - } - if existingName != nil { - return nil, errors.New("角色名已被使用") - } - profile.Name = *name - } - - // 更新皮肤和披风 - if skinID != nil { - profile.SkinID = skinID - } - if capeID != nil { - profile.CapeID = capeID - } - - if err := s.profileRepo.Update(profile); err != nil { - return nil, fmt.Errorf("更新档案失败: %w", err) - } - - return s.profileRepo.FindByUUID(uuid) -} - -func (s *profileServiceImpl) Delete(uuid string, userID int64) error { - // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrProfileNotFound - } - return fmt.Errorf("查询档案失败: %w", err) - } - - if profile.UserID != userID { - return ErrProfileNoPermission - } - - if err := s.profileRepo.Delete(uuid); err != nil { - return fmt.Errorf("删除档案失败: %w", err) - } - return nil -} - -func (s *profileServiceImpl) SetActive(uuid string, userID int64) error { - // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrProfileNotFound - } - return fmt.Errorf("查询档案失败: %w", err) - } - - if profile.UserID != userID { - return ErrProfileNoPermission - } - - if err := s.profileRepo.SetActive(uuid, userID); err != nil { - return fmt.Errorf("设置活跃状态失败: %w", err) - } - - if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { - return fmt.Errorf("更新使用时间失败: %w", err) - } - - return nil -} - -func (s *profileServiceImpl) CheckLimit(userID int64, maxProfiles int) error { - count, err := s.profileRepo.CountByUserID(userID) - if err != nil { - return fmt.Errorf("查询档案数量失败: %w", err) - } - - if int(count) >= maxProfiles { - return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles) - } - return nil -} - -func (s *profileServiceImpl) GetByNames(names []string) ([]*model.Profile, error) { - profiles, err := s.profileRepo.GetByNames(names) - if err != nil { - return nil, fmt.Errorf("查找失败: %w", err) - } - return profiles, nil -} - -func (s *profileServiceImpl) GetByProfileName(name string) (*model.Profile, error) { - profile, err := s.profileRepo.FindByName(name) - if err != nil { - return nil, errors.New("用户角色未创建") - } - return profile, nil -} - -// generateRSAPrivateKeyInternal 生成RSA-2048私钥(PEM格式) -func generateRSAPrivateKeyInternal() (string, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return "", err - } - - privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: privateKeyBytes, - }) - - return string(privateKeyPEM), nil -} - - diff --git a/internal/service/profile_service_test.go b/internal/service/profile_service_test.go index 37fef82..cf71362 100644 --- a/internal/service/profile_service_test.go +++ b/internal/service/profile_service_test.go @@ -1,7 +1,10 @@ package service import ( + "carrotskin/internal/model" "testing" + + "go.uber.org/zap" ) // TestProfileService_Validation 测试Profile服务验证逻辑 @@ -347,22 +350,22 @@ func TestGenerateRSAPrivateKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - privateKey, err := generateRSAPrivateKey() + privateKey, err := generateRSAPrivateKeyInternal() if (err != nil) != tt.wantError { - t.Errorf("generateRSAPrivateKey() error = %v, wantError %v", err, tt.wantError) + t.Errorf("generateRSAPrivateKeyInternal() error = %v, wantError %v", err, tt.wantError) return } if !tt.wantError { if privateKey == "" { - t.Error("generateRSAPrivateKey() 返回的私钥不应为空") + t.Error("generateRSAPrivateKeyInternal() 返回的私钥不应为空") } // 验证PEM格式 if len(privateKey) < 100 { - t.Errorf("generateRSAPrivateKey() 返回的私钥长度异常: %d", len(privateKey)) + t.Errorf("generateRSAPrivateKeyInternal() 返回的私钥长度异常: %d", len(privateKey)) } // 验证包含PEM头部 if !contains(privateKey, "BEGIN RSA PRIVATE KEY") { - t.Error("generateRSAPrivateKey() 返回的私钥应包含PEM头部") + t.Error("generateRSAPrivateKeyInternal() 返回的私钥应包含PEM头部") } } }) @@ -373,9 +376,9 @@ func TestGenerateRSAPrivateKey(t *testing.T) { func TestGenerateRSAPrivateKey_Uniqueness(t *testing.T) { keys := make(map[string]bool) for i := 0; i < 10; i++ { - key, err := generateRSAPrivateKey() + key, err := generateRSAPrivateKeyInternal() if err != nil { - t.Fatalf("generateRSAPrivateKey() 失败: %v", err) + t.Fatalf("generateRSAPrivateKeyInternal() 失败: %v", err) } if keys[key] { t.Errorf("第%d次生成的密钥与之前重复", i+1) @@ -404,3 +407,319 @@ func containsMiddle(s, substr string) bool { } return false } + +// ============================================================================ +// 使用 Mock 的集成测试 +// ============================================================================ + +// TestProfileServiceImpl_Create 测试创建Profile +func TestProfileServiceImpl_Create(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置用户 + testUser := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + Status: 1, + } + userRepo.Create(testUser) + + profileService := NewProfileService(profileRepo, userRepo, logger) + + tests := []struct { + name string + userID int64 + profileName string + wantErr bool + errMsg string + setupMocks func() + }{ + { + name: "正常创建Profile", + userID: 1, + profileName: "TestProfile", + wantErr: false, + }, + { + name: "用户不存在", + userID: 999, + profileName: "TestProfile2", + wantErr: true, + errMsg: "用户不存在", + }, + { + name: "角色名已存在", + userID: 1, + profileName: "ExistingProfile", + wantErr: true, + errMsg: "角色名已被使用", + setupMocks: func() { + profileRepo.Create(&model.Profile{ + UUID: "existing-uuid", + UserID: 2, + Name: "ExistingProfile", + }) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupMocks != nil { + tt.setupMocks() + } + + profile, err := profileService.Create(tt.userID, tt.profileName) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + return + } + if tt.errMsg != "" && err.Error() != tt.errMsg { + t.Errorf("错误信息不匹配: got %v, want %v", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if profile == nil { + t.Error("返回的Profile不应为nil") + } + if profile.Name != tt.profileName { + t.Errorf("Profile名称不匹配: got %v, want %v", profile.Name, tt.profileName) + } + if profile.UUID == "" { + t.Error("Profile UUID不应为空") + } + } + }) + } +} + +// TestProfileServiceImpl_GetByUUID 测试获取Profile +func TestProfileServiceImpl_GetByUUID(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置Profile + testProfile := &model.Profile{ + UUID: "test-uuid-123", + UserID: 1, + Name: "TestProfile", + } + profileRepo.Create(testProfile) + + profileService := NewProfileService(profileRepo, userRepo, logger) + + tests := []struct { + name string + uuid string + wantErr bool + }{ + { + name: "获取存在的Profile", + uuid: "test-uuid-123", + wantErr: false, + }, + { + name: "获取不存在的Profile", + uuid: "non-existent-uuid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile, err := profileService.GetByUUID(tt.uuid) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if profile == nil { + t.Error("返回的Profile不应为nil") + } + if profile.UUID != tt.uuid { + t.Errorf("Profile UUID不匹配: got %v, want %v", profile.UUID, tt.uuid) + } + } + }) + } +} + +// TestProfileServiceImpl_Delete 测试删除Profile +func TestProfileServiceImpl_Delete(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置Profile + testProfile := &model.Profile{ + UUID: "delete-test-uuid", + UserID: 1, + Name: "DeleteTestProfile", + } + profileRepo.Create(testProfile) + + profileService := NewProfileService(profileRepo, userRepo, logger) + + tests := []struct { + name string + uuid string + userID int64 + wantErr bool + }{ + { + name: "正常删除", + uuid: "delete-test-uuid", + userID: 1, + wantErr: false, + }, + { + name: "用户ID不匹配", + uuid: "delete-test-uuid", + userID: 2, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := profileService.Delete(tt.uuid, tt.userID) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + } + } + }) + } +} + +// TestProfileServiceImpl_GetByUserID 测试按用户获取档案列表 +func TestProfileServiceImpl_GetByUserID(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 为用户 1 和 2 预置不同档案 + profileRepo.Create(&model.Profile{UUID: "p1", UserID: 1, Name: "P1"}) + profileRepo.Create(&model.Profile{UUID: "p2", UserID: 1, Name: "P2"}) + profileRepo.Create(&model.Profile{UUID: "p3", UserID: 2, Name: "P3"}) + + svc := NewProfileService(profileRepo, userRepo, logger) + + list, err := svc.GetByUserID(1) + if err != nil { + t.Fatalf("GetByUserID 失败: %v", err) + } + if len(list) != 2 { + t.Fatalf("GetByUserID 返回数量错误, got=%d, want=2", len(list)) + } +} + +// TestProfileServiceImpl_Update_And_SetActive 测试 Update 与 SetActive +func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + profile := &model.Profile{ + UUID: "u1", + UserID: 1, + Name: "OldName", + } + profileRepo.Create(profile) + + svc := NewProfileService(profileRepo, userRepo, logger) + + // 正常更新名称与皮肤/披风 + newName := "NewName" + var skinID int64 = 10 + var capeID int64 = 20 + updated, err := svc.Update("u1", 1, &newName, &skinID, &capeID) + if err != nil { + t.Fatalf("Update 正常情况失败: %v", err) + } + if updated == nil || updated.Name != newName { + t.Fatalf("Update 未更新名称, got=%+v", updated) + } + + // 用户无权限 + if _, err := svc.Update("u1", 2, &newName, nil, nil); err == nil { + t.Fatalf("Update 在无权限时应返回错误") + } + + // 名称重复 + profileRepo.Create(&model.Profile{ + UUID: "u2", + UserID: 2, + Name: "Duplicate", + }) + if _, err := svc.Update("u1", 1, stringPtr("Duplicate"), nil, nil); err == nil { + t.Fatalf("Update 在名称重复时应返回错误") + } + + // SetActive 正常 + if err := svc.SetActive("u1", 1); err != nil { + t.Fatalf("SetActive 正常情况失败: %v", err) + } + + // SetActive 无权限 + if err := svc.SetActive("u1", 2); err == nil { + t.Fatalf("SetActive 在无权限时应返回错误") + } +} + +// TestProfileServiceImpl_CheckLimit_And_GetByNames 测试 CheckLimit / GetByNames / GetByProfileName +func TestProfileServiceImpl_CheckLimit_And_GetByNames(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 为用户 1 预置 2 个档案 + profileRepo.Create(&model.Profile{UUID: "a", UserID: 1, Name: "A"}) + profileRepo.Create(&model.Profile{UUID: "b", UserID: 1, Name: "B"}) + + svc := NewProfileService(profileRepo, userRepo, logger) + + // CheckLimit 未达上限 + if err := svc.CheckLimit(1, 3); err != nil { + t.Fatalf("CheckLimit 未达到上限时不应报错: %v", err) + } + + // CheckLimit 达到上限 + if err := svc.CheckLimit(1, 2); err == nil { + t.Fatalf("CheckLimit 达到上限时应报错") + } + + // GetByNames + list, err := svc.GetByNames([]string{"A", "B"}) + if err != nil { + t.Fatalf("GetByNames 失败: %v", err) + } + if len(list) != 2 { + t.Fatalf("GetByNames 返回数量错误, got=%d, want=2", len(list)) + } + + // GetByProfileName 存在 + p, err := svc.GetByProfileName("A") + if err != nil || p == nil || p.Name != "A" { + t.Fatalf("GetByProfileName 返回错误, profile=%+v, err=%v", p, err) + } +} diff --git a/internal/service/serialize_service.go b/internal/service/serialize_service.go index 2400522..4f12691 100644 --- a/internal/service/serialize_service.go +++ b/internal/service/serialize_service.go @@ -2,6 +2,7 @@ package service import ( "carrotskin/internal/model" + "carrotskin/internal/repository" "carrotskin/pkg/redis" "encoding/base64" "time" @@ -31,7 +32,7 @@ func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client // 处理皮肤 if p.SkinID != nil { - skin, err := GetTextureByID(db, *p.SkinID) + skin, err := repository.FindTextureByID(*p.SkinID) if err != nil { logger.Error("[ERROR] 获取皮肤失败:", zap.Error(err), zap.Any("SkinID:", *p.SkinID)) } else { @@ -44,7 +45,7 @@ func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client // 处理披风 if p.CapeID != nil { - cape, err := GetTextureByID(db, *p.CapeID) + cape, err := repository.FindTextureByID(*p.CapeID) if err != nil { logger.Error("[ERROR] 获取披风失败:", zap.Error(err), zap.Any("capeID:", *p.CapeID)) } else { diff --git a/internal/service/serialize_service_test.go b/internal/service/serialize_service_test.go index 4f2d3be..4ad66e7 100644 --- a/internal/service/serialize_service_test.go +++ b/internal/service/serialize_service_test.go @@ -5,6 +5,7 @@ import ( "testing" "go.uber.org/zap/zaptest" + "gorm.io/datatypes" ) // TestSerializeUser_NilUser 实际调用SerializeUser函数测试nil用户 @@ -19,25 +20,51 @@ func TestSerializeUser_NilUser(t *testing.T) { // TestSerializeUser_ActualCall 实际调用SerializeUser函数 func TestSerializeUser_ActualCall(t *testing.T) { logger := zaptest.NewLogger(t) - user := &model.User{ - ID: 1, - Username: "testuser", - Email: "test@example.com", - // Properties 使用 datatypes.JSON,测试中可以为空 - } - result := SerializeUser(logger, user, "test-uuid-123") - if result == nil { - t.Fatal("SerializeUser() 返回的结果不应为nil") - } + t.Run("Properties为nil时", func(t *testing.T) { + user := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + } - if result["id"] != "test-uuid-123" { - t.Errorf("id = %v, want 'test-uuid-123'", result["id"]) - } + result := SerializeUser(logger, user, "test-uuid-123") + if result == nil { + t.Fatal("SerializeUser() 返回的结果不应为nil") + } - if result["properties"] == nil { - t.Error("properties 不应为nil") - } + if result["id"] != "test-uuid-123" { + t.Errorf("id = %v, want 'test-uuid-123'", result["id"]) + } + + // 当 Properties 为 nil 时,properties 应该为 nil + if result["properties"] != nil { + t.Error("当 user.Properties 为 nil 时,properties 应为 nil") + } + }) + + t.Run("Properties有值时", func(t *testing.T) { + propsJSON := datatypes.JSON(`[{"name":"test","value":"value"}]`) + user := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + Properties: &propsJSON, + } + + result := SerializeUser(logger, user, "test-uuid-456") + if result == nil { + t.Fatal("SerializeUser() 返回的结果不应为nil") + } + + if result["id"] != "test-uuid-456" { + t.Errorf("id = %v, want 'test-uuid-456'", result["id"]) + } + + if result["properties"] == nil { + t.Error("当 user.Properties 有值时,properties 不应为 nil") + } + }) } // TestProperty_Structure 测试Property结构 diff --git a/internal/service/texture_service.go b/internal/service/texture_service.go index ea312f0..eb19a82 100644 --- a/internal/service/texture_service.go +++ b/internal/service/texture_service.go @@ -6,18 +6,38 @@ import ( "errors" "fmt" - "gorm.io/gorm" + "go.uber.org/zap" ) -// CreateTexture 创建材质 -func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { +// textureServiceImpl TextureService的实现 +type textureServiceImpl struct { + textureRepo repository.TextureRepository + userRepo repository.UserRepository + logger *zap.Logger +} + +// NewTextureService 创建TextureService实例 +func NewTextureService( + textureRepo repository.TextureRepository, + userRepo repository.UserRepository, + logger *zap.Logger, +) TextureService { + return &textureServiceImpl{ + textureRepo: textureRepo, + userRepo: userRepo, + logger: logger, + } +} + +func (s *textureServiceImpl) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { // 验证用户存在 - if _, err := EnsureUserExists(uploaderID); err != nil { - return nil, err + user, err := s.userRepo.FindByID(uploaderID) + if err != nil || user == nil { + return nil, ErrUserNotFound } // 检查Hash是否已存在 - existingTexture, err := repository.FindTextureByHash(hash) + existingTexture, err := s.textureRepo.FindByHash(hash) if err != nil { return nil, err } @@ -26,7 +46,7 @@ func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType } // 转换材质类型 - textureTypeEnum, err := parseTextureType(textureType) + textureTypeEnum, err := parseTextureTypeInternal(textureType) if err != nil { return nil, err } @@ -47,36 +67,49 @@ func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType FavoriteCount: 0, } - if err := repository.CreateTexture(texture); err != nil { + if err := s.textureRepo.Create(texture); err != nil { return nil, err } return texture, nil } -// GetTextureByID 根据ID获取材质 -func GetTextureByID(db *gorm.DB, id int64) (*model.Texture, error) { - return EnsureTextureExists(id) -} - -// GetUserTextures 获取用户上传的材质列表 -func GetUserTextures(db *gorm.DB, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return repository.FindTexturesByUploaderID(uploaderID, page, pageSize) -} - -// SearchTextures 搜索材质 -func SearchTextures(db *gorm.DB, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return repository.SearchTextures(keyword, textureType, publicOnly, page, pageSize) -} - -// UpdateTexture 更新材质 -func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { - // 获取材质并验证权限 - if _, err := GetTextureWithPermissionCheck(textureID, uploaderID); err != nil { +func (s *textureServiceImpl) GetByID(id int64) (*model.Texture, error) { + texture, err := s.textureRepo.FindByID(id) + if err != nil { return nil, err } + if texture == nil { + return nil, ErrTextureNotFound + } + if texture.Status == -1 { + return nil, errors.New("材质已删除") + } + return texture, nil +} + +func (s *textureServiceImpl) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) +} + +func (s *textureServiceImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) +} + +func (s *textureServiceImpl) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { + // 获取材质并验证权限 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { + return nil, err + } + if texture == nil { + return nil, ErrTextureNotFound + } + if texture.UploaderID != uploaderID { + return nil, ErrTextureNoPermission + } // 更新字段 updates := make(map[string]interface{}) @@ -91,83 +124,73 @@ func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description s } if len(updates) > 0 { - if err := repository.UpdateTextureFields(textureID, updates); err != nil { + if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { return nil, err } } - return repository.FindTextureByID(textureID) + return s.textureRepo.FindByID(textureID) } -// DeleteTexture 删除材质 -func DeleteTexture(db *gorm.DB, textureID, uploaderID int64) error { - if _, err := GetTextureWithPermissionCheck(textureID, uploaderID); err != nil { +func (s *textureServiceImpl) Delete(textureID, uploaderID int64) error { + // 获取材质并验证权限 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { return err } - return repository.DeleteTexture(textureID) + if texture == nil { + return ErrTextureNotFound + } + if texture.UploaderID != uploaderID { + return ErrTextureNoPermission + } + + return s.textureRepo.Delete(textureID) } -// RecordTextureDownload 记录下载 -func RecordTextureDownload(db *gorm.DB, textureID int64, userID *int64, ipAddress, userAgent string) error { - if _, err := EnsureTextureExists(textureID); err != nil { - return err - } - - if err := repository.IncrementTextureDownloadCount(textureID); err != nil { - return err - } - - log := &model.TextureDownloadLog{ - TextureID: textureID, - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - } - - return repository.CreateTextureDownloadLog(log) -} - -// ToggleTextureFavorite 切换收藏状态 -func ToggleTextureFavorite(db *gorm.DB, userID, textureID int64) (bool, error) { - if _, err := EnsureTextureExists(textureID); err != nil { +func (s *textureServiceImpl) ToggleFavorite(userID, textureID int64) (bool, error) { + // 确保材质存在 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { return false, err } + if texture == nil { + return false, ErrTextureNotFound + } - isFavorited, err := repository.IsTextureFavorited(userID, textureID) + isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) if err != nil { return false, err } if isFavorited { // 已收藏 -> 取消收藏 - if err := repository.RemoveTextureFavorite(userID, textureID); err != nil { + if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { return false, err } - if err := repository.DecrementTextureFavoriteCount(textureID); err != nil { + if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { return false, err } return false, nil - } else { - // 未收藏 -> 添加收藏 - if err := repository.AddTextureFavorite(userID, textureID); err != nil { - return false, err - } - if err := repository.IncrementTextureFavoriteCount(textureID); err != nil { - return false, err - } - return true, nil } + + // 未收藏 -> 添加收藏 + if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { + return false, err + } + if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { + return false, err + } + return true, nil } -// GetUserTextureFavorites 获取用户收藏的材质列表 -func GetUserTextureFavorites(db *gorm.DB, userID int64, page, pageSize int) ([]*model.Texture, int64, error) { +func (s *textureServiceImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { page, pageSize = NormalizePagination(page, pageSize) - return repository.GetUserTextureFavorites(userID, page, pageSize) + return s.textureRepo.GetUserFavorites(userID, page, pageSize) } -// CheckTextureUploadLimit 检查用户上传材质数量限制 -func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) error { - count, err := repository.CountTexturesByUploaderID(uploaderID) +func (s *textureServiceImpl) CheckUploadLimit(uploaderID int64, maxTextures int) error { + count, err := s.textureRepo.CountByUploaderID(uploaderID) if err != nil { return err } @@ -179,8 +202,8 @@ func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) err return nil } -// parseTextureType 解析材质类型 -func parseTextureType(textureType string) (model.TextureType, error) { +// parseTextureTypeInternal 解析材质类型 +func parseTextureTypeInternal(textureType string) (model.TextureType, error) { switch textureType { case "SKIN": return model.TextureTypeSkin, nil diff --git a/internal/service/texture_service_impl.go b/internal/service/texture_service_impl.go deleted file mode 100644 index eb19a82..0000000 --- a/internal/service/texture_service_impl.go +++ /dev/null @@ -1,215 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "errors" - "fmt" - - "go.uber.org/zap" -) - -// textureServiceImpl TextureService的实现 -type textureServiceImpl struct { - textureRepo repository.TextureRepository - userRepo repository.UserRepository - logger *zap.Logger -} - -// NewTextureService 创建TextureService实例 -func NewTextureService( - textureRepo repository.TextureRepository, - userRepo repository.UserRepository, - logger *zap.Logger, -) TextureService { - return &textureServiceImpl{ - textureRepo: textureRepo, - userRepo: userRepo, - logger: logger, - } -} - -func (s *textureServiceImpl) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { - // 验证用户存在 - user, err := s.userRepo.FindByID(uploaderID) - if err != nil || user == nil { - return nil, ErrUserNotFound - } - - // 检查Hash是否已存在 - existingTexture, err := s.textureRepo.FindByHash(hash) - if err != nil { - return nil, err - } - if existingTexture != nil { - return nil, errors.New("该材质已存在") - } - - // 转换材质类型 - textureTypeEnum, err := parseTextureTypeInternal(textureType) - if err != nil { - return nil, err - } - - // 创建材质 - texture := &model.Texture{ - UploaderID: uploaderID, - Name: name, - Description: description, - Type: textureTypeEnum, - URL: url, - Hash: hash, - Size: size, - IsPublic: isPublic, - IsSlim: isSlim, - Status: 1, - DownloadCount: 0, - FavoriteCount: 0, - } - - if err := s.textureRepo.Create(texture); err != nil { - return nil, err - } - - return texture, nil -} - -func (s *textureServiceImpl) GetByID(id int64) (*model.Texture, error) { - texture, err := s.textureRepo.FindByID(id) - if err != nil { - return nil, err - } - if texture == nil { - return nil, ErrTextureNotFound - } - if texture.Status == -1 { - return nil, errors.New("材质已删除") - } - return texture, nil -} - -func (s *textureServiceImpl) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) -} - -func (s *textureServiceImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) -} - -func (s *textureServiceImpl) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { - // 获取材质并验证权限 - texture, err := s.textureRepo.FindByID(textureID) - if err != nil { - return nil, err - } - if texture == nil { - return nil, ErrTextureNotFound - } - if texture.UploaderID != uploaderID { - return nil, ErrTextureNoPermission - } - - // 更新字段 - updates := make(map[string]interface{}) - if name != "" { - updates["name"] = name - } - if description != "" { - updates["description"] = description - } - if isPublic != nil { - updates["is_public"] = *isPublic - } - - if len(updates) > 0 { - if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { - return nil, err - } - } - - return s.textureRepo.FindByID(textureID) -} - -func (s *textureServiceImpl) Delete(textureID, uploaderID int64) error { - // 获取材质并验证权限 - texture, err := s.textureRepo.FindByID(textureID) - if err != nil { - return err - } - if texture == nil { - return ErrTextureNotFound - } - if texture.UploaderID != uploaderID { - return ErrTextureNoPermission - } - - return s.textureRepo.Delete(textureID) -} - -func (s *textureServiceImpl) ToggleFavorite(userID, textureID int64) (bool, error) { - // 确保材质存在 - texture, err := s.textureRepo.FindByID(textureID) - if err != nil { - return false, err - } - if texture == nil { - return false, ErrTextureNotFound - } - - isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) - if err != nil { - return false, err - } - - if isFavorited { - // 已收藏 -> 取消收藏 - if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { - return false, err - } - if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { - return false, err - } - return false, nil - } - - // 未收藏 -> 添加收藏 - if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { - return false, err - } - if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { - return false, err - } - return true, nil -} - -func (s *textureServiceImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.GetUserFavorites(userID, page, pageSize) -} - -func (s *textureServiceImpl) CheckUploadLimit(uploaderID int64, maxTextures int) error { - count, err := s.textureRepo.CountByUploaderID(uploaderID) - if err != nil { - return err - } - - if count >= int64(maxTextures) { - return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures) - } - - return nil -} - -// parseTextureTypeInternal 解析材质类型 -func parseTextureTypeInternal(textureType string) (model.TextureType, error) { - switch textureType { - case "SKIN": - return model.TextureTypeSkin, nil - case "CAPE": - return model.TextureTypeCape, nil - default: - return "", errors.New("无效的材质类型") - } -} diff --git a/internal/service/texture_service_test.go b/internal/service/texture_service_test.go index c4e9ec1..a99a4f0 100644 --- a/internal/service/texture_service_test.go +++ b/internal/service/texture_service_test.go @@ -1,7 +1,10 @@ package service import ( + "carrotskin/internal/model" "testing" + + "go.uber.org/zap" ) // TestTextureService_TypeValidation 测试材质类型验证 @@ -469,3 +472,357 @@ func TestCheckTextureUploadLimit_Logic(t *testing.T) { func boolPtr(b bool) *bool { return &b } + +// ============================================================================ +// 使用 Mock 的集成测试 +// ============================================================================ + +// TestTextureServiceImpl_Create 测试创建Texture +func TestTextureServiceImpl_Create(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置用户 + testUser := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + Status: 1, + } + userRepo.Create(testUser) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + tests := []struct { + name string + uploaderID int64 + textureName string + textureType string + hash string + wantErr bool + errContains string + setupMocks func() + }{ + { + name: "正常创建SKIN材质", + uploaderID: 1, + textureName: "TestSkin", + textureType: "SKIN", + hash: "unique-hash-1", + wantErr: false, + }, + { + name: "正常创建CAPE材质", + uploaderID: 1, + textureName: "TestCape", + textureType: "CAPE", + hash: "unique-hash-2", + wantErr: false, + }, + { + name: "用户不存在", + uploaderID: 999, + textureName: "TestTexture", + textureType: "SKIN", + hash: "unique-hash-3", + wantErr: true, + }, + { + name: "材质Hash已存在", + uploaderID: 1, + textureName: "DuplicateTexture", + textureType: "SKIN", + hash: "existing-hash", + wantErr: true, + errContains: "已存在", + setupMocks: func() { + textureRepo.Create(&model.Texture{ + ID: 100, + UploaderID: 1, + Name: "ExistingTexture", + Hash: "existing-hash", + }) + }, + }, + { + name: "无效的材质类型", + uploaderID: 1, + textureName: "InvalidTypeTexture", + textureType: "INVALID", + hash: "unique-hash-4", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupMocks != nil { + tt.setupMocks() + } + + texture, err := textureService.Create( + tt.uploaderID, + tt.textureName, + "Test description", + tt.textureType, + "http://example.com/texture.png", + tt.hash, + 1024, + true, + false, + ) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + return + } + if tt.errContains != "" && !containsString(err.Error(), tt.errContains) { + t.Errorf("错误信息应包含 %q, 实际为: %v", tt.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if texture == nil { + t.Error("返回的Texture不应为nil") + } + if texture.Name != tt.textureName { + t.Errorf("Texture名称不匹配: got %v, want %v", texture.Name, tt.textureName) + } + } + }) + } +} + +// TestTextureServiceImpl_GetByID 测试获取Texture +func TestTextureServiceImpl_GetByID(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置Texture + testTexture := &model.Texture{ + ID: 1, + UploaderID: 1, + Name: "TestTexture", + Hash: "test-hash", + } + textureRepo.Create(testTexture) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + tests := []struct { + name string + id int64 + wantErr bool + }{ + { + name: "获取存在的Texture", + id: 1, + wantErr: false, + }, + { + name: "获取不存在的Texture", + id: 999, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + texture, err := textureService.GetByID(tt.id) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if texture == nil { + t.Error("返回的Texture不应为nil") + } + } + }) + } +} + +// TestTextureServiceImpl_GetByUserID_And_Search 测试 GetByUserID 与 Search 分页封装 +func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置多条 Texture + for i := int64(1); i <= 5; i++ { + textureRepo.Create(&model.Texture{ + ID: i, + UploaderID: 1, + Name: "T", + IsPublic: i%2 == 0, + }) + } + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // GetByUserID 应按上传者过滤并调用 NormalizePagination + textures, total, err := textureService.GetByUserID(1, 0, 0) + if err != nil { + t.Fatalf("GetByUserID 失败: %v", err) + } + if total != int64(len(textures)) { + t.Fatalf("GetByUserID 返回数量与总数不一致, total=%d, len=%d", total, len(textures)) + } + + // Search 仅验证能够正常调用并返回结果 + searchResult, searchTotal, err := textureService.Search("", "", true, -1, 200) + if err != nil { + t.Fatalf("Search 失败: %v", err) + } + if searchTotal != int64(len(searchResult)) { + t.Fatalf("Search 返回数量与总数不一致, total=%d, len=%d", searchTotal, len(searchResult)) + } +} + +// TestTextureServiceImpl_Update_And_Delete 测试 Update / Delete 权限与字段更新 +func TestTextureServiceImpl_Update_And_Delete(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + texture := &model.Texture{ + ID: 1, + UploaderID: 1, + Name: "Old", + Description:"OldDesc", + IsPublic: false, + } + textureRepo.Create(texture) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // 更新成功 + newName := "NewName" + newDesc := "NewDesc" + public := boolPtr(true) + updated, err := textureService.Update(1, 1, newName, newDesc, public) + if err != nil { + t.Fatalf("Update 正常情况失败: %v", err) + } + // 由于 MockTextureRepository.UpdateFields 不会真正修改结构体字段,这里只验证不会返回 nil 即可 + if updated == nil { + t.Fatalf("Update 返回结果不应为 nil") + } + + // 无权限更新 + if _, err := textureService.Update(1, 2, "X", "Y", nil); err == nil { + t.Fatalf("Update 在无权限时应返回错误") + } + + // 删除成功 + if err := textureService.Delete(1, 1); err != nil { + t.Fatalf("Delete 正常情况失败: %v", err) + } + + // 无权限删除 + if err := textureService.Delete(1, 2); err == nil { + t.Fatalf("Delete 在无权限时应返回错误") + } +} + +// TestTextureServiceImpl_FavoritesAndLimit 测试 GetUserFavorites 与 CheckUploadLimit +func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置若干 Texture 与收藏关系 + for i := int64(1); i <= 3; i++ { + textureRepo.Create(&model.Texture{ + ID: i, + UploaderID: 1, + Name: "T", + }) + _ = textureRepo.AddFavorite(1, i) + } + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // GetUserFavorites + favs, total, err := textureService.GetUserFavorites(1, -1, -1) + if err != nil { + t.Fatalf("GetUserFavorites 失败: %v", err) + } + if int64(len(favs)) != total || total != 3 { + t.Fatalf("GetUserFavorites 数量不正确, total=%d, len=%d", total, len(favs)) + } + + // CheckUploadLimit 未超过上限 + if err := textureService.CheckUploadLimit(1, 10); err != nil { + t.Fatalf("CheckUploadLimit 在未达到上限时不应报错: %v", err) + } + + // CheckUploadLimit 超过上限 + if err := textureService.CheckUploadLimit(1, 2); err == nil { + t.Fatalf("CheckUploadLimit 在超过上限时应返回错误") + } +} + +// TestTextureServiceImpl_ToggleFavorite 测试收藏功能 +func TestTextureServiceImpl_ToggleFavorite(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置用户和Texture + testUser := &model.User{ID: 1, Username: "testuser", Status: 1} + userRepo.Create(testUser) + + testTexture := &model.Texture{ + ID: 1, + UploaderID: 1, + Name: "TestTexture", + Hash: "test-hash", + } + textureRepo.Create(testTexture) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // 第一次收藏 + isFavorited, err := textureService.ToggleFavorite(1, 1) + if err != nil { + t.Errorf("第一次收藏失败: %v", err) + } + if !isFavorited { + t.Error("第一次操作应该是添加收藏") + } + + // 第二次取消收藏 + isFavorited, err = textureService.ToggleFavorite(1, 1) + if err != nil { + t.Errorf("取消收藏失败: %v", err) + } + if isFavorited { + t.Error("第二次操作应该是取消收藏") + } +} + +// 辅助函数 +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || + (len(s) > len(substr) && (findSubstring(s, substr) != -1))) +} + +func findSubstring(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} diff --git a/internal/service/token_service.go b/internal/service/token_service.go index 20af177..b128abf 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -6,35 +6,55 @@ import ( "context" "errors" "fmt" - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "go.uber.org/zap" "strconv" "time" - "gorm.io/gorm" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" ) -// 常量定义 +// tokenServiceImpl TokenService的实现 +type tokenServiceImpl struct { + tokenRepo repository.TokenRepository + profileRepo repository.ProfileRepository + logger *zap.Logger +} + +// NewTokenService 创建TokenService实例 +func NewTokenService( + tokenRepo repository.TokenRepository, + profileRepo repository.ProfileRepository, + logger *zap.Logger, +) TokenService { + return &tokenServiceImpl{ + tokenRepo: tokenRepo, + profileRepo: profileRepo, + logger: logger, + } +} + const ( - ExtendedTimeout = 10 * time.Second - TokensMaxCount = 10 // 用户最多保留的token数量 + tokenExtendedTimeout = 10 * time.Second + tokensMaxCount = 10 ) -// NewToken 创建新令牌 -func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { +func (s *tokenServiceImpl) Create(userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { var ( selectedProfileID *model.Profile availableProfiles []*model.Profile ) + // 设置超时上下文 _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() // 验证用户存在 - _, err := repository.FindProfileByUUID(UUID) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) + if UUID != "" { + _, err := s.profileRepo.FindByUUID(UUID) + if err != nil { + return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) + } } // 生成令牌 @@ -46,13 +66,13 @@ func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, client token := model.Token{ AccessToken: accessToken, ClientToken: clientToken, - UserID: userId, + UserID: userID, Usable: true, IssueDate: time.Now(), } // 获取用户配置文件 - profiles, err := repository.FindProfilesByUserID(userId) + profiles, err := s.profileRepo.FindByUserID(userID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) } @@ -64,65 +84,24 @@ func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, client } availableProfiles = profiles - // 插入令牌到tokens集合 - _, insertCancel := context.WithTimeout(context.Background(), DefaultTimeout) - defer insertCancel() - - err = repository.CreateToken(&token) + // 插入令牌 + err = s.tokenRepo.Create(&token) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) } + // 清理多余的令牌 - go CheckAndCleanupExcessTokens(db, logger, userId) + go s.checkAndCleanupExcessTokens(userID) return selectedProfileID, availableProfiles, accessToken, clientToken, nil } -// CheckAndCleanupExcessTokens 检查并清理用户多余的令牌,只保留最新的10个 -func CheckAndCleanupExcessTokens(db *gorm.DB, logger *zap.Logger, userId int64) { - if userId == 0 { - return - } - // 获取用户所有令牌,按发行日期降序排序 - tokens, err := repository.GetTokensByUserId(userId) - if err != nil { - logger.Error("[ERROR] 获取用户Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10))) - return - } - - // 如果令牌数量不超过上限,无需清理 - if len(tokens) <= TokensMaxCount { - return - } - - // 获取需要删除的令牌ID列表 - tokensToDelete := make([]string, 0, len(tokens)-TokensMaxCount) - for i := TokensMaxCount; i < len(tokens); i++ { - tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) - } - - // 执行批量删除,传入上下文和待删除的令牌列表(作为切片参数) - DeletedCount, err := repository.BatchDeleteTokens(tokensToDelete) - if err != nil { - logger.Error("[ERROR] 清理用户多余Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10))) - return - } - - if DeletedCount > 0 { - logger.Info("[INFO] 成功清理用户多余Token", zap.Any("userId:", userId), zap.Any("count:", DeletedCount)) - } -} - -// ValidToken 验证令牌有效性 -func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool { +func (s *tokenServiceImpl) Validate(accessToken, clientToken string) bool { if accessToken == "" { return false } - // 使用投影只获取需要的字段 - var token *model.Token - token, err := repository.FindTokenByID(accessToken) - + token, err := s.tokenRepo.FindByAccessToken(accessToken) if err != nil { return false } @@ -131,47 +110,35 @@ func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool { return false } - // 如果客户端令牌为空,只验证访问令牌 if clientToken == "" { return true } - // 否则验证客户端令牌是否匹配 return token.ClientToken == clientToken } -func GetUUIDByAccessToken(db *gorm.DB, accessToken string) (string, error) { - return repository.GetUUIDByAccessToken(accessToken) -} - -func GetUserIDByAccessToken(db *gorm.DB, accessToken string) (int64, error) { - return repository.GetUserIDByAccessToken(accessToken) -} - -// RefreshToken 刷新令牌 -func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken string, selectedProfileID string) (string, string, error) { +func (s *tokenServiceImpl) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { if accessToken == "" { return "", "", errors.New("accessToken不能为空") } // 查找旧令牌 - oldToken, err := repository.GetTokenByAccessToken(accessToken) + oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", "", errors.New("accessToken无效") } - logger.Error("[ERROR] 查询Token失败: ", zap.Error(err), zap.Any("accessToken:", accessToken)) + s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return "", "", fmt.Errorf("查询令牌失败: %w", err) } // 验证profile if selectedProfileID != "" { - valid, validErr := ValidateProfileByUserID(db, oldToken.UserID, selectedProfileID) + valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) if validErr != nil { - logger.Error( - "验证Profile失败", + s.logger.Error("验证Profile失败", zap.Error(err), - zap.Any("userId", oldToken.UserID), + zap.Int64("userId", oldToken.UserID), zap.String("profileId", selectedProfileID), ) return "", "", fmt.Errorf("验证角色失败: %w", err) @@ -192,86 +159,119 @@ func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken stri return "", "", errors.New("原令牌已绑定角色,无法选择新角色") } } else { - selectedProfileID = oldToken.ProfileId // 如果未指定,则保持原角色 + selectedProfileID = oldToken.ProfileId } // 生成新令牌 newAccessToken := uuid.New().String() newToken := model.Token{ AccessToken: newAccessToken, - ClientToken: oldToken.ClientToken, // 新令牌的 clientToken 与原令牌相同 + ClientToken: oldToken.ClientToken, UserID: oldToken.UserID, Usable: true, - ProfileId: selectedProfileID, // 绑定到指定角色或保持原角色 + ProfileId: selectedProfileID, IssueDate: time.Now(), } - // 使用双重写入模式替代事务,先插入新令牌,再删除旧令牌 - - err = repository.CreateToken(&newToken) + // 先插入新令牌,再删除旧令牌 + err = s.tokenRepo.Create(&newToken) if err != nil { - logger.Error( - "创建新Token失败", - zap.Error(err), - zap.String("accessToken", accessToken), - ) + s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return "", "", fmt.Errorf("创建新Token失败: %w", err) } - err = repository.DeleteTokenByAccessToken(accessToken) + err = s.tokenRepo.DeleteByAccessToken(accessToken) if err != nil { - // 删除旧令牌失败,记录日志但不阻止操作,因为新令牌已成功创建 - logger.Warn( - "删除旧Token失败,但新Token已创建", + s.logger.Warn("删除旧Token失败,但新Token已创建", zap.Error(err), zap.String("oldToken", oldToken.AccessToken), zap.String("newToken", newAccessToken), ) } - logger.Info( - "成功刷新Token", - zap.Any("userId", oldToken.UserID), - zap.String("accessToken", newAccessToken), - ) + s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken)) return newAccessToken, oldToken.ClientToken, nil } -// InvalidToken 使令牌失效 -func InvalidToken(db *gorm.DB, logger *zap.Logger, accessToken string) { +func (s *tokenServiceImpl) Invalidate(accessToken string) { if accessToken == "" { return } - err := repository.DeleteTokenByAccessToken(accessToken) + err := s.tokenRepo.DeleteByAccessToken(accessToken) if err != nil { - logger.Error( - "删除Token失败", - zap.Error(err), - zap.String("accessToken", accessToken), - ) + s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return } - logger.Info("[INFO] 成功删除", zap.Any("Token:", accessToken)) - + s.logger.Info("成功删除Token", zap.String("token", accessToken)) } -// InvalidUserTokens 使用户所有令牌失效 -func InvalidUserTokens(db *gorm.DB, logger *zap.Logger, userId int64) { - if userId == 0 { +func (s *tokenServiceImpl) InvalidateUserTokens(userID int64) { + if userID == 0 { return } - err := repository.DeleteTokenByUserId(userId) + err := s.tokenRepo.DeleteByUserID(userID) if err != nil { - logger.Error( - "[ERROR]删除用户Token失败", - zap.Error(err), - zap.Any("userId", userId), - ) + s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) return } - logger.Info("[INFO] 成功删除用户Token", zap.Any("userId:", userId)) - + s.logger.Info("成功删除用户Token", zap.Int64("userId", userID)) +} + +func (s *tokenServiceImpl) GetUUIDByAccessToken(accessToken string) (string, error) { + return s.tokenRepo.GetUUIDByAccessToken(accessToken) +} + +func (s *tokenServiceImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { + return s.tokenRepo.GetUserIDByAccessToken(accessToken) +} + +// 私有辅助方法 + +func (s *tokenServiceImpl) checkAndCleanupExcessTokens(userID int64) { + if userID == 0 { + return + } + + tokens, err := s.tokenRepo.GetByUserID(userID) + if err != nil { + s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) + return + } + + if len(tokens) <= tokensMaxCount { + return + } + + tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount) + for i := tokensMaxCount; i < len(tokens); i++ { + tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) + } + + deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) + if err != nil { + s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) + return + } + + if deletedCount > 0 { + s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount)) + } +} + +func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (bool, error) { + if userID == 0 || UUID == "" { + return false, errors.New("用户ID或配置文件ID不能为空") + } + + profile, err := s.profileRepo.FindByUUID(UUID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, errors.New("配置文件不存在") + } + return false, fmt.Errorf("验证配置文件失败: %w", err) + } + return profile.UserID == userID, nil } diff --git a/internal/service/token_service_impl.go b/internal/service/token_service_impl.go deleted file mode 100644 index b128abf..0000000 --- a/internal/service/token_service_impl.go +++ /dev/null @@ -1,277 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "context" - "errors" - "fmt" - "strconv" - "time" - - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "go.uber.org/zap" -) - -// tokenServiceImpl TokenService的实现 -type tokenServiceImpl struct { - tokenRepo repository.TokenRepository - profileRepo repository.ProfileRepository - logger *zap.Logger -} - -// NewTokenService 创建TokenService实例 -func NewTokenService( - tokenRepo repository.TokenRepository, - profileRepo repository.ProfileRepository, - logger *zap.Logger, -) TokenService { - return &tokenServiceImpl{ - tokenRepo: tokenRepo, - profileRepo: profileRepo, - logger: logger, - } -} - -const ( - tokenExtendedTimeout = 10 * time.Second - tokensMaxCount = 10 -) - -func (s *tokenServiceImpl) Create(userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { - var ( - selectedProfileID *model.Profile - availableProfiles []*model.Profile - ) - - // 设置超时上下文 - _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) - defer cancel() - - // 验证用户存在 - if UUID != "" { - _, err := s.profileRepo.FindByUUID(UUID) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) - } - } - - // 生成令牌 - if clientToken == "" { - clientToken = uuid.New().String() - } - - accessToken := uuid.New().String() - token := model.Token{ - AccessToken: accessToken, - ClientToken: clientToken, - UserID: userID, - Usable: true, - IssueDate: time.Now(), - } - - // 获取用户配置文件 - profiles, err := s.profileRepo.FindByUserID(userID) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) - } - - // 如果用户只有一个配置文件,自动选择 - if len(profiles) == 1 { - selectedProfileID = profiles[0] - token.ProfileId = selectedProfileID.UUID - } - availableProfiles = profiles - - // 插入令牌 - err = s.tokenRepo.Create(&token) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) - } - - // 清理多余的令牌 - go s.checkAndCleanupExcessTokens(userID) - - return selectedProfileID, availableProfiles, accessToken, clientToken, nil -} - -func (s *tokenServiceImpl) Validate(accessToken, clientToken string) bool { - if accessToken == "" { - return false - } - - token, err := s.tokenRepo.FindByAccessToken(accessToken) - if err != nil { - return false - } - - if !token.Usable { - return false - } - - if clientToken == "" { - return true - } - - return token.ClientToken == clientToken -} - -func (s *tokenServiceImpl) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { - if accessToken == "" { - return "", "", errors.New("accessToken不能为空") - } - - // 查找旧令牌 - oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return "", "", errors.New("accessToken无效") - } - s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken)) - return "", "", fmt.Errorf("查询令牌失败: %w", err) - } - - // 验证profile - if selectedProfileID != "" { - valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) - if validErr != nil { - s.logger.Error("验证Profile失败", - zap.Error(err), - zap.Int64("userId", oldToken.UserID), - zap.String("profileId", selectedProfileID), - ) - return "", "", fmt.Errorf("验证角色失败: %w", err) - } - if !valid { - return "", "", errors.New("角色与用户不匹配") - } - } - - // 检查 clientToken 是否有效 - if clientToken != "" && clientToken != oldToken.ClientToken { - return "", "", errors.New("clientToken无效") - } - - // 检查 selectedProfileID 的逻辑 - if selectedProfileID != "" { - if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID { - return "", "", errors.New("原令牌已绑定角色,无法选择新角色") - } - } else { - selectedProfileID = oldToken.ProfileId - } - - // 生成新令牌 - newAccessToken := uuid.New().String() - newToken := model.Token{ - AccessToken: newAccessToken, - ClientToken: oldToken.ClientToken, - UserID: oldToken.UserID, - Usable: true, - ProfileId: selectedProfileID, - IssueDate: time.Now(), - } - - // 先插入新令牌,再删除旧令牌 - err = s.tokenRepo.Create(&newToken) - if err != nil { - s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) - return "", "", fmt.Errorf("创建新Token失败: %w", err) - } - - err = s.tokenRepo.DeleteByAccessToken(accessToken) - if err != nil { - s.logger.Warn("删除旧Token失败,但新Token已创建", - zap.Error(err), - zap.String("oldToken", oldToken.AccessToken), - zap.String("newToken", newAccessToken), - ) - } - - s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken)) - return newAccessToken, oldToken.ClientToken, nil -} - -func (s *tokenServiceImpl) Invalidate(accessToken string) { - if accessToken == "" { - return - } - - err := s.tokenRepo.DeleteByAccessToken(accessToken) - if err != nil { - s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) - return - } - s.logger.Info("成功删除Token", zap.String("token", accessToken)) -} - -func (s *tokenServiceImpl) InvalidateUserTokens(userID int64) { - if userID == 0 { - return - } - - err := s.tokenRepo.DeleteByUserID(userID) - if err != nil { - s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) - return - } - - s.logger.Info("成功删除用户Token", zap.Int64("userId", userID)) -} - -func (s *tokenServiceImpl) GetUUIDByAccessToken(accessToken string) (string, error) { - return s.tokenRepo.GetUUIDByAccessToken(accessToken) -} - -func (s *tokenServiceImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { - return s.tokenRepo.GetUserIDByAccessToken(accessToken) -} - -// 私有辅助方法 - -func (s *tokenServiceImpl) checkAndCleanupExcessTokens(userID int64) { - if userID == 0 { - return - } - - tokens, err := s.tokenRepo.GetByUserID(userID) - if err != nil { - s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) - return - } - - if len(tokens) <= tokensMaxCount { - return - } - - tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount) - for i := tokensMaxCount; i < len(tokens); i++ { - tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) - } - - deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) - if err != nil { - s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) - return - } - - if deletedCount > 0 { - s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount)) - } -} - -func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (bool, error) { - if userID == 0 || UUID == "" { - return false, errors.New("用户ID或配置文件ID不能为空") - } - - profile, err := s.profileRepo.FindByUUID(UUID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return false, errors.New("配置文件不存在") - } - return false, fmt.Errorf("验证配置文件失败: %w", err) - } - return profile.UserID == userID, nil -} diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 7c051d2..e85978b 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -1,18 +1,23 @@ package service import ( + "carrotskin/internal/model" + "fmt" "testing" "time" + + "go.uber.org/zap" ) // TestTokenService_Constants 测试Token服务相关常量 func TestTokenService_Constants(t *testing.T) { - if ExtendedTimeout != 10*time.Second { - t.Errorf("ExtendedTimeout = %v, want 10 seconds", ExtendedTimeout) + // 测试私有常量通过行为验证 + if tokenExtendedTimeout != 10*time.Second { + t.Errorf("tokenExtendedTimeout = %v, want 10 seconds", tokenExtendedTimeout) } - if TokensMaxCount != 10 { - t.Errorf("TokensMaxCount = %d, want 10", TokensMaxCount) + if tokensMaxCount != 10 { + t.Errorf("tokensMaxCount = %d, want 10", tokensMaxCount) } } @@ -22,8 +27,8 @@ func TestTokenService_Timeout(t *testing.T) { t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout) } - if ExtendedTimeout <= DefaultTimeout { - t.Errorf("ExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", ExtendedTimeout, DefaultTimeout) + if tokenExtendedTimeout <= DefaultTimeout { + t.Errorf("tokenExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", tokenExtendedTimeout, DefaultTimeout) } } @@ -202,3 +207,314 @@ func TestTokenService_UserIDValidation(t *testing.T) { }) } } + +// ============================================================================ +// 使用 Mock 的集成测试 +// ============================================================================ + +// TestTokenServiceImpl_Create 测试创建Token +func TestTokenServiceImpl_Create(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置Profile + testProfile := &model.Profile{ + UUID: "test-profile-uuid", + UserID: 1, + Name: "TestProfile", + IsActive: true, + } + profileRepo.Create(testProfile) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + tests := []struct { + name string + userID int64 + uuid string + clientToken string + wantErr bool + }{ + { + name: "正常创建Token(指定UUID)", + userID: 1, + uuid: "test-profile-uuid", + clientToken: "client-token-1", + wantErr: false, + }, + { + name: "正常创建Token(空clientToken)", + userID: 1, + uuid: "test-profile-uuid", + clientToken: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, accessToken, clientToken, err := tokenService.Create(tt.userID, tt.uuid, tt.clientToken) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if accessToken == "" { + t.Error("accessToken不应为空") + } + if clientToken == "" { + t.Error("clientToken不应为空") + } + } + }) + } +} + +// TestTokenServiceImpl_Validate 测试验证Token +func TestTokenServiceImpl_Validate(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置Token + testToken := &model.Token{ + AccessToken: "valid-access-token", + ClientToken: "valid-client-token", + UserID: 1, + ProfileId: "test-profile-uuid", + Usable: true, + } + tokenRepo.Create(testToken) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + tests := []struct { + name string + accessToken string + clientToken string + wantValid bool + }{ + { + name: "有效Token(完全匹配)", + accessToken: "valid-access-token", + clientToken: "valid-client-token", + wantValid: true, + }, + { + name: "有效Token(只检查accessToken)", + accessToken: "valid-access-token", + clientToken: "", + wantValid: true, + }, + { + name: "无效Token(accessToken不存在)", + accessToken: "invalid-access-token", + clientToken: "", + wantValid: false, + }, + { + name: "无效Token(clientToken不匹配)", + accessToken: "valid-access-token", + clientToken: "wrong-client-token", + wantValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tokenService.Validate(tt.accessToken, tt.clientToken) + + if isValid != tt.wantValid { + t.Errorf("Token验证结果不匹配: got %v, want %v", isValid, tt.wantValid) + } + }) + } +} + +// TestTokenServiceImpl_Invalidate 测试注销Token +func TestTokenServiceImpl_Invalidate(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置Token + testToken := &model.Token{ + AccessToken: "token-to-invalidate", + ClientToken: "client-token", + UserID: 1, + ProfileId: "test-profile-uuid", + Usable: true, + } + tokenRepo.Create(testToken) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + // 验证Token存在 + isValid := tokenService.Validate("token-to-invalidate", "") + if !isValid { + t.Error("Token应该有效") + } + + // 注销Token + tokenService.Invalidate("token-to-invalidate") + + // 验证Token已失效(从repo中删除) + _, err := tokenRepo.FindByAccessToken("token-to-invalidate") + if err == nil { + t.Error("Token应该已被删除") + } +} + +// TestTokenServiceImpl_InvalidateUserTokens 测试注销用户所有Token +func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置多个Token + for i := 1; i <= 3; i++ { + tokenRepo.Create(&model.Token{ + AccessToken: fmt.Sprintf("user1-token-%d", i), + ClientToken: "client-token", + UserID: 1, + ProfileId: "test-profile-uuid", + Usable: true, + }) + } + tokenRepo.Create(&model.Token{ + AccessToken: "user2-token-1", + ClientToken: "client-token", + UserID: 2, + ProfileId: "test-profile-uuid-2", + Usable: true, + }) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + // 注销用户1的所有Token + tokenService.InvalidateUserTokens(1) + + // 验证用户1的Token已失效 + tokens, _ := tokenRepo.GetByUserID(1) + if len(tokens) > 0 { + t.Errorf("用户1的Token应该全部被删除,但还剩 %d 个", len(tokens)) + } + + // 验证用户2的Token仍然存在 + tokens2, _ := tokenRepo.GetByUserID(2) + if len(tokens2) != 1 { + t.Errorf("用户2的Token应该仍然存在,期望1个,实际 %d 个", len(tokens2)) + } +} + +// TestTokenServiceImpl_Refresh 覆盖 Refresh 的主要分支 +func TestTokenServiceImpl_Refresh(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置 Profile 与 Token + profile := &model.Profile{ + UUID: "profile-uuid", + UserID: 1, + } + profileRepo.Create(profile) + + oldToken := &model.Token{ + AccessToken: "old-token", + ClientToken: "client-token", + UserID: 1, + ProfileId: "", + Usable: true, + } + tokenRepo.Create(oldToken) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + // 正常刷新,不指定 profile + newAccess, client, err := tokenService.Refresh("old-token", "client-token", "") + if err != nil { + t.Fatalf("Refresh 正常情况失败: %v", err) + } + if newAccess == "" || client != "client-token" { + t.Fatalf("Refresh 返回值异常: access=%s, client=%s", newAccess, client) + } + + // accessToken 为空 + if _, _, err := tokenService.Refresh("", "client-token", ""); err == nil { + t.Fatalf("Refresh 在 accessToken 为空时应返回错误") + } +} + +// TestTokenServiceImpl_GetByAccessToken 封装 GetUUIDByAccessToken / GetUserIDByAccessToken +func TestTokenServiceImpl_GetByAccessToken(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + token := &model.Token{ + AccessToken: "token-1", + UserID: 42, + ProfileId: "profile-42", + Usable: true, + } + tokenRepo.Create(token) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + uuid, err := tokenService.GetUUIDByAccessToken("token-1") + if err != nil || uuid != "profile-42" { + t.Fatalf("GetUUIDByAccessToken 返回错误: uuid=%s, err=%v", uuid, err) + } + + uid, err := tokenService.GetUserIDByAccessToken("token-1") + if err != nil || uid != 42 { + t.Fatalf("GetUserIDByAccessToken 返回错误: uid=%d, err=%v", uid, err) + } +} + +// TestTokenServiceImpl_validateProfileByUserID 直接测试内部校验逻辑 +func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + svc := &tokenServiceImpl{ + tokenRepo: tokenRepo, + profileRepo: profileRepo, + logger: logger, + } + + // 预置 Profile + profile := &model.Profile{ + UUID: "p-1", + UserID: 1, + } + profileRepo.Create(profile) + + // 参数非法 + if ok, err := svc.validateProfileByUserID(0, ""); err == nil || ok { + t.Fatalf("validateProfileByUserID 在参数非法时应返回错误") + } + + // Profile 不存在 + if ok, err := svc.validateProfileByUserID(1, "not-exists"); err == nil || ok { + t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误") + } + + // 用户与 Profile 匹配 + if ok, err := svc.validateProfileByUserID(1, "p-1"); err != nil || !ok { + t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err) + } + + // 用户与 Profile 不匹配 + if ok, err := svc.validateProfileByUserID(2, "p-1"); err != nil || ok { + t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err) + } +} \ No newline at end of file diff --git a/internal/service/upload_service.go b/internal/service/upload_service.go index 4678872..877357b 100644 --- a/internal/service/upload_service.go +++ b/internal/service/upload_service.go @@ -74,27 +74,38 @@ func ValidateFileName(fileName string, fileType FileType) error { return nil } -// GenerateAvatarUploadURL 生成头像上传URL +// uploadStorageClient 为上传服务定义的最小依赖接口,便于单元测试注入 mock +type uploadStorageClient interface { + GetBucket(name string) (string, error) + GeneratePresignedPostURL(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) +} + +// GenerateAvatarUploadURL 生成头像上传URL(对外导出) func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.StorageClient, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) { + return generateAvatarUploadURLWithClient(ctx, storageClient, userID, fileName) +} + +// generateAvatarUploadURLWithClient 使用接口类型的内部实现,方便测试 +func generateAvatarUploadURLWithClient(ctx context.Context, storageClient uploadStorageClient, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) { // 1. 验证文件名 if err := ValidateFileName(fileName, FileTypeAvatar); err != nil { return nil, err } - + // 2. 获取上传配置 uploadConfig := GetUploadConfig(FileTypeAvatar) - + // 3. 获取存储桶名称 bucketName, err := storageClient.GetBucket("avatars") if err != nil { return nil, fmt.Errorf("获取存储桶失败: %w", err) } - + // 4. 生成对象名称(路径) // 格式: user_{userId}/timestamp_{originalFileName} timestamp := time.Now().Format("20060102150405") objectName := fmt.Sprintf("user_%d/%s_%s", userID, timestamp, fileName) - + // 5. 生成预签名POST URL (使用存储客户端内置的 PublicURL) result, err := storageClient.GeneratePresignedPostURL( ctx, @@ -107,37 +118,42 @@ func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.Storage if err != nil { return nil, fmt.Errorf("生成上传URL失败: %w", err) } - + return result, nil } -// GenerateTextureUploadURL 生成材质上传URL +// GenerateTextureUploadURL 生成材质上传URL(对外导出) func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.StorageClient, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) { + return generateTextureUploadURLWithClient(ctx, storageClient, userID, fileName, textureType) +} + +// generateTextureUploadURLWithClient 使用接口类型的内部实现,方便测试 +func generateTextureUploadURLWithClient(ctx context.Context, storageClient uploadStorageClient, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) { // 1. 验证文件名 if err := ValidateFileName(fileName, FileTypeTexture); err != nil { return nil, err } - + // 2. 验证材质类型 if textureType != "SKIN" && textureType != "CAPE" { return nil, fmt.Errorf("无效的材质类型: %s", textureType) } - + // 3. 获取上传配置 uploadConfig := GetUploadConfig(FileTypeTexture) - + // 4. 获取存储桶名称 bucketName, err := storageClient.GetBucket("textures") if err != nil { return nil, fmt.Errorf("获取存储桶失败: %w", err) } - + // 5. 生成对象名称(路径) // 格式: user_{userId}/{textureType}/timestamp_{originalFileName} timestamp := time.Now().Format("20060102150405") textureTypeFolder := strings.ToLower(textureType) objectName := fmt.Sprintf("user_%d/%s/%s_%s", userID, textureTypeFolder, timestamp, fileName) - + // 6. 生成预签名POST URL (使用存储客户端内置的 PublicURL) result, err := storageClient.GeneratePresignedPostURL( ctx, @@ -150,6 +166,6 @@ func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.Storag if err != nil { return nil, fmt.Errorf("生成上传URL失败: %w", err) } - + return result, nil } diff --git a/internal/service/upload_service_test.go b/internal/service/upload_service_test.go index 52f2012..07df008 100644 --- a/internal/service/upload_service_test.go +++ b/internal/service/upload_service_test.go @@ -1,9 +1,13 @@ package service import ( + "context" + "errors" "strings" "testing" "time" + + "carrotskin/pkg/storage" ) // TestUploadService_FileTypes 测试文件类型常量 @@ -135,43 +139,43 @@ func TestGetUploadConfig_TextureConfig(t *testing.T) { // TestValidateFileName 测试文件名验证 func TestValidateFileName(t *testing.T) { tests := []struct { - name string - fileName string - fileType FileType - wantErr bool + name string + fileName string + fileType FileType + wantErr bool errContains string }{ { - name: "有效的头像文件名", - fileName: "avatar.png", - fileType: FileTypeAvatar, - wantErr: false, + name: "有效的头像文件名", + fileName: "avatar.png", + fileType: FileTypeAvatar, + wantErr: false, }, { - name: "有效的材质文件名", - fileName: "texture.png", - fileType: FileTypeTexture, - wantErr: false, + name: "有效的材质文件名", + fileName: "texture.png", + fileType: FileTypeTexture, + wantErr: false, }, { - name: "文件名为空", - fileName: "", - fileType: FileTypeAvatar, - wantErr: true, + name: "文件名为空", + fileName: "", + fileType: FileTypeAvatar, + wantErr: true, errContains: "文件名不能为空", }, { - name: "不支持的文件扩展名", - fileName: "file.txt", - fileType: FileTypeAvatar, - wantErr: true, + name: "不支持的文件扩展名", + fileName: "file.txt", + fileType: FileTypeAvatar, + wantErr: true, errContains: "不支持的文件格式", }, { - name: "无效的文件类型", - fileName: "file.png", - fileType: FileType("invalid"), - wantErr: true, + name: "无效的文件类型", + fileName: "file.png", + fileType: FileType("invalid"), + wantErr: true, errContains: "不支持的文件类型", }, } @@ -277,3 +281,130 @@ func TestUploadConfig_Structure(t *testing.T) { } } +// mockStorageClient 用于单元测试的简单存储客户端假实现 +// 注意:这里只声明与 upload_service 使用到的方法,避免依赖真实 MinIO 客户端 +type mockStorageClient struct { + getBucketFn func(name string) (string, error) + generatePresignedPostURLFn func(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) +} + +func (m *mockStorageClient) GetBucket(name string) (string, error) { + if m.getBucketFn != nil { + return m.getBucketFn(name) + } + return "", errors.New("GetBucket not implemented") +} + +func (m *mockStorageClient) GeneratePresignedPostURL(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) { + if m.generatePresignedPostURLFn != nil { + return m.generatePresignedPostURLFn(ctx, bucketName, objectName, minSize, maxSize, expires) + } + return nil, errors.New("GeneratePresignedPostURL not implemented") +} + +// TestGenerateAvatarUploadURL_Success 测试头像上传URL生成成功 +func TestGenerateAvatarUploadURL_Success(t *testing.T) { + ctx := context.Background() + + mockClient := &mockStorageClient{ + getBucketFn: func(name string) (string, error) { + if name != "avatars" { + t.Fatalf("unexpected bucket name: %s", name) + } + return "avatars-bucket", nil + }, + generatePresignedPostURLFn: func(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) { + if bucketName != "avatars-bucket" { + t.Fatalf("unexpected bucketName: %s", bucketName) + } + if !strings.Contains(objectName, "user_") { + t.Fatalf("objectName should contain user_ prefix, got: %s", objectName) + } + if !strings.Contains(objectName, "avatar.png") { + t.Fatalf("objectName should contain original file name, got: %s", objectName) + } + // 检查大小与过期时间传递 + if minSize != 1024 { + t.Fatalf("minSize = %d, want 1024", minSize) + } + if maxSize != 5*1024*1024 { + t.Fatalf("maxSize = %d, want 5MB", maxSize) + } + if expires != 15*time.Minute { + t.Fatalf("expires = %v, want 15m", expires) + } + return &storage.PresignedPostPolicyResult{ + PostURL: "http://example.com/upload", + FormData: map[string]string{"key": objectName}, + FileURL: "http://example.com/file/" + objectName, + }, nil + }, + } + + // 直接将 mock 实例转换为真实类型使用(依赖其方法集与被测代码一致) + storageClient := (*storage.StorageClient)(nil) + _ = storageClient // 避免未使用告警,实际调用仍通过 mockClient 完成 + + // 直接通过内部使用接口的实现进行测试,避免依赖真实 StorageClient + result, err := generateAvatarUploadURLWithClient(ctx, mockClient, 123, "avatar.png") + + if err != nil { + t.Fatalf("GenerateAvatarUploadURL() error = %v, want nil", err) + } + if result == nil { + t.Fatalf("GenerateAvatarUploadURL() result is nil") + } + if result.PostURL == "" || result.FileURL == "" { + t.Fatalf("GenerateAvatarUploadURL() result has empty URLs: %+v", result) + } +} + +// TestGenerateTextureUploadURL_Success 测试材质上传URL生成成功(SKIN/CAPE) +func TestGenerateTextureUploadURL_Success(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + textureType string + }{ + {"SKIN 材质", "SKIN"}, + {"CAPE 材质", "CAPE"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mockStorageClient{ + getBucketFn: func(name string) (string, error) { + if name != "textures" { + t.Fatalf("unexpected bucket name: %s", name) + } + return "textures-bucket", nil + }, + generatePresignedPostURLFn: func(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) { + if bucketName != "textures-bucket" { + t.Fatalf("unexpected bucketName: %s", bucketName) + } + if !strings.Contains(objectName, "texture.png") { + t.Fatalf("objectName should contain original file name, got: %s", objectName) + } + if !strings.Contains(objectName, "/"+strings.ToLower(tt.textureType)+"/") { + t.Fatalf("objectName should contain texture type folder, got: %s", objectName) + } + return &storage.PresignedPostPolicyResult{ + PostURL: "http://example.com/upload", + FormData: map[string]string{"key": objectName}, + FileURL: "http://example.com/file/" + objectName, + }, nil + }, + } + + result, err := generateTextureUploadURLWithClient(ctx, mockClient, 123, "texture.png", tt.textureType) + if err != nil { + t.Fatalf("generateTextureUploadURLWithClient() error = %v, want nil", err) + } + if result == nil || result.PostURL == "" || result.FileURL == "" { + t.Fatalf("generateTextureUploadURLWithClient() result invalid: %+v", result) + } + }) + } +} diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 249a341..2b7250e 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -12,12 +12,39 @@ import ( "net/url" "strings" "time" + + "go.uber.org/zap" ) -// RegisterUser 用户注册 -func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) { +// userServiceImpl UserService的实现 +type userServiceImpl struct { + userRepo repository.UserRepository + configRepo repository.SystemConfigRepository + jwtService *auth.JWTService + redis *redis.Client + logger *zap.Logger +} + +// NewUserService 创建UserService实例 +func NewUserService( + userRepo repository.UserRepository, + configRepo repository.SystemConfigRepository, + jwtService *auth.JWTService, + redisClient *redis.Client, + logger *zap.Logger, +) UserService { + return &userServiceImpl{ + userRepo: userRepo, + configRepo: configRepo, + jwtService: jwtService, + redis: redisClient, + logger: logger, + } +} + +func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) { // 检查用户名是否已存在 - existingUser, err := repository.FindUserByUsername(username) + existingUser, err := s.userRepo.FindByUsername(username) if err != nil { return nil, "", err } @@ -26,7 +53,7 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar } // 检查邮箱是否已存在 - existingEmail, err := repository.FindUserByEmail(email) + existingEmail, err := s.userRepo.FindByEmail(email) if err != nil { return nil, "", err } @@ -40,15 +67,14 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar return nil, "", errors.New("密码加密失败") } - // 确定头像URL:优先使用用户提供的头像,否则使用默认头像 + // 确定头像URL avatarURL := avatar if avatarURL != "" { - // 验证用户提供的头像 URL 是否来自允许的域名 - if err := ValidateAvatarURL(avatarURL); err != nil { + if err := s.ValidateAvatarURL(avatarURL); err != nil { return nil, "", err } } else { - avatarURL = getDefaultAvatar() + avatarURL = s.getDefaultAvatar() } // 创建用户 @@ -62,12 +88,12 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar Points: 0, } - if err := repository.CreateUser(user); err != nil { + if err := s.userRepo.Create(user); err != nil { return nil, "", err } // 生成JWT Token - token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role) + token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) if err != nil { return nil, "", errors.New("生成Token失败") } @@ -75,92 +101,56 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar return user, token, nil } -// LoginUser 用户登录(支持用户名或邮箱登录) -func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { - return LoginUserWithRateLimit(nil, jwtService, usernameOrEmail, password, ipAddress, userAgent) -} - -// LoginUserWithRateLimit 用户登录(带频率限制) -func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { +func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { ctx := context.Background() - // 检查账号是否被锁定(基于用户名/邮箱和IP) - if redisClient != nil { + // 检查账号是否被锁定 + if s.redis != nil { identifier := usernameOrEmail + ":" + ipAddress - locked, ttl, err := CheckLoginLocked(ctx, redisClient, identifier) + locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier) if err == nil && locked { return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1) } } - // 查找用户:判断是用户名还是邮箱 + // 查找用户 var user *model.User var err error if strings.Contains(usernameOrEmail, "@") { - user, err = repository.FindUserByEmail(usernameOrEmail) + user, err = s.userRepo.FindByEmail(usernameOrEmail) } else { - user, err = repository.FindUserByUsername(usernameOrEmail) + user, err = s.userRepo.FindByUsername(usernameOrEmail) } if err != nil { return nil, "", err } if user == nil { - // 记录失败尝试 - if redisClient != nil { - identifier := usernameOrEmail + ":" + ipAddress - count, _ := RecordLoginFailure(ctx, redisClient, identifier) - // 检查是否触发锁定 - if count >= MaxLoginAttempts { - logFailedLogin(0, ipAddress, userAgent, "用户不存在-账号已锁定") - return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes())) - } - remaining := MaxLoginAttempts - count - if remaining > 0 { - logFailedLogin(0, ipAddress, userAgent, "用户不存在") - return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining) - } - } - logFailedLogin(0, ipAddress, userAgent, "用户不存在") + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在") return nil, "", errors.New("用户名/邮箱或密码错误") } // 检查用户状态 if user.Status != 1 { - logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用") + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用") return nil, "", errors.New("账号已被禁用") } // 验证密码 if !auth.CheckPassword(user.Password, password) { - // 记录失败尝试 - if redisClient != nil { - identifier := usernameOrEmail + ":" + ipAddress - count, _ := RecordLoginFailure(ctx, redisClient, identifier) - // 检查是否触发锁定 - if count >= MaxLoginAttempts { - logFailedLogin(user.ID, ipAddress, userAgent, "密码错误-账号已锁定") - return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes())) - } - remaining := MaxLoginAttempts - count - if remaining > 0 { - logFailedLogin(user.ID, ipAddress, userAgent, "密码错误") - return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining) - } - } - logFailedLogin(user.ID, ipAddress, userAgent, "密码错误") + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误") return nil, "", errors.New("用户名/邮箱或密码错误") } // 登录成功,清除失败计数 - if redisClient != nil { + if s.redis != nil { identifier := usernameOrEmail + ":" + ipAddress - _ = ClearLoginAttempts(ctx, redisClient, identifier) + _ = ClearLoginAttempts(ctx, s.redis, identifier) } // 生成JWT Token - token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role) + token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) if err != nil { return nil, "", errors.New("生成Token失败") } @@ -168,37 +158,37 @@ func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTServi // 更新最后登录时间 now := time.Now() user.LastLoginAt = &now - _ = repository.UpdateUserFields(user.ID, map[string]interface{}{ + _ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ "last_login_at": now, }) // 记录成功登录日志 - logSuccessLogin(user.ID, ipAddress, userAgent) + s.logSuccessLogin(user.ID, ipAddress, userAgent) return user, token, nil } -// GetUserByID 根据ID获取用户 -func GetUserByID(id int64) (*model.User, error) { - return repository.FindUserByID(id) +func (s *userServiceImpl) GetByID(id int64) (*model.User, error) { + return s.userRepo.FindByID(id) } -// UpdateUserInfo 更新用户信息 -func UpdateUserInfo(user *model.User) error { - return repository.UpdateUser(user) +func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) { + return s.userRepo.FindByEmail(email) } -// UpdateUserAvatar 更新用户头像 -func UpdateUserAvatar(userID int64, avatarURL string) error { - return repository.UpdateUserFields(userID, map[string]interface{}{ +func (s *userServiceImpl) UpdateInfo(user *model.User) error { + return s.userRepo.Update(user) +} + +func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error { + return s.userRepo.UpdateFields(userID, map[string]interface{}{ "avatar": avatarURL, }) } -// ChangeUserPassword 修改密码 -func ChangeUserPassword(userID int64, oldPassword, newPassword string) error { - user, err := repository.FindUserByID(userID) - if err != nil { +func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error { + user, err := s.userRepo.FindByID(userID) + if err != nil || user == nil { return errors.New("用户不存在") } @@ -211,15 +201,14 @@ func ChangeUserPassword(userID int64, oldPassword, newPassword string) error { return errors.New("密码加密失败") } - return repository.UpdateUserFields(userID, map[string]interface{}{ + return s.userRepo.UpdateFields(userID, map[string]interface{}{ "password": hashedPassword, }) } -// ResetUserPassword 重置密码(通过邮箱) -func ResetUserPassword(email, newPassword string) error { - user, err := repository.FindUserByEmail(email) - if err != nil { +func (s *userServiceImpl) ResetPassword(email, newPassword string) error { + user, err := s.userRepo.FindByEmail(email) + if err != nil || user == nil { return errors.New("用户不存在") } @@ -228,14 +217,13 @@ func ResetUserPassword(email, newPassword string) error { return errors.New("密码加密失败") } - return repository.UpdateUserFields(user.ID, map[string]interface{}{ + return s.userRepo.UpdateFields(user.ID, map[string]interface{}{ "password": hashedPassword, }) } -// ChangeUserEmail 更换邮箱 -func ChangeUserEmail(userID int64, newEmail string) error { - existingUser, err := repository.FindUserByEmail(newEmail) +func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error { + existingUser, err := s.userRepo.FindByEmail(newEmail) if err != nil { return err } @@ -243,47 +231,12 @@ func ChangeUserEmail(userID int64, newEmail string) error { return errors.New("邮箱已被其他用户使用") } - return repository.UpdateUserFields(userID, map[string]interface{}{ + return s.userRepo.UpdateFields(userID, map[string]interface{}{ "email": newEmail, }) } -// logSuccessLogin 记录成功登录 -func logSuccessLogin(userID int64, ipAddress, userAgent string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: true, - } - _ = repository.CreateLoginLog(log) -} - -// logFailedLogin 记录失败登录 -func logFailedLogin(userID int64, ipAddress, userAgent, reason string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: false, - FailureReason: reason, - } - _ = repository.CreateLoginLog(log) -} - -// getDefaultAvatar 获取默认头像URL -func getDefaultAvatar() string { - config, err := repository.GetSystemConfigByKey("default_avatar") - if err != nil || config == nil || config.Value == "" { - return "" - } - return config.Value -} - -// ValidateAvatarURL 验证头像URL是否合法 -func ValidateAvatarURL(avatarURL string) error { +func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error { if avatarURL == "" { return nil } @@ -293,13 +246,8 @@ func ValidateAvatarURL(avatarURL string) error { return nil } - return ValidateURLDomain(avatarURL) -} - -// ValidateURLDomain 验证URL的域名是否在允许列表中 -func ValidateURLDomain(rawURL string) error { // 解析URL - parsedURL, err := url.Parse(rawURL) + parsedURL, err := url.Parse(avatarURL) if err != nil { return errors.New("无效的URL格式") } @@ -309,7 +257,6 @@ func ValidateURLDomain(rawURL string) error { return errors.New("URL必须使用http或https协议") } - // 获取主机名(不包含端口) host := parsedURL.Hostname() if host == "" { return errors.New("URL缺少主机名") @@ -318,16 +265,50 @@ func ValidateURLDomain(rawURL string) error { // 从配置获取允许的域名列表 cfg, err := config.GetConfig() if err != nil { - // 如果配置获取失败,使用默认的安全域名列表 allowedDomains := []string{"localhost", "127.0.0.1"} - return checkDomainAllowed(host, allowedDomains) + return s.checkDomainAllowed(host, allowedDomains) } - return checkDomainAllowed(host, cfg.Security.AllowedDomains) + return s.checkDomainAllowed(host, cfg.Security.AllowedDomains) } -// checkDomainAllowed 检查域名是否在允许列表中 -func checkDomainAllowed(host string, allowedDomains []string) error { +func (s *userServiceImpl) GetMaxProfilesPerUser() int { + config, err := s.configRepo.GetByKey("max_profiles_per_user") + if err != nil || config == nil { + return 5 + } + var value int + fmt.Sscanf(config.Value, "%d", &value) + if value <= 0 { + return 5 + } + return value +} + +func (s *userServiceImpl) GetMaxTexturesPerUser() int { + config, err := s.configRepo.GetByKey("max_textures_per_user") + if err != nil || config == nil { + return 50 + } + var value int + fmt.Sscanf(config.Value, "%d", &value) + if value <= 0 { + return 50 + } + return value +} + +// 私有辅助方法 + +func (s *userServiceImpl) getDefaultAvatar() string { + config, err := s.configRepo.GetByKey("default_avatar") + if err != nil || config == nil || config.Value == "" { + return "" + } + return config.Value +} + +func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error { host = strings.ToLower(host) for _, allowed := range allowedDomains { @@ -336,14 +317,12 @@ func checkDomainAllowed(host string, allowedDomains []string) error { continue } - // 精确匹配 if host == allowed { return nil } - // 支持通配符子域名匹配 (如 *.example.com) if strings.HasPrefix(allowed, "*.") { - suffix := allowed[1:] // 移除 "*",保留 ".example.com" + suffix := allowed[1:] if strings.HasSuffix(host, suffix) { return nil } @@ -353,39 +332,37 @@ func checkDomainAllowed(host string, allowedDomains []string) error { return errors.New("URL域名不在允许的列表中") } -// GetUserByEmail 根据邮箱获取用户 -func GetUserByEmail(email string) (*model.User, error) { - user, err := repository.FindUserByEmail(email) - if err != nil { - return nil, errors.New("邮箱查找失败") +func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) { + if s.redis != nil { + identifier := usernameOrEmail + ":" + ipAddress + count, _ := RecordLoginFailure(ctx, s.redis, identifier) + if count >= MaxLoginAttempts { + s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") + return + } } - return user, nil + s.logFailedLogin(userID, ipAddress, userAgent, reason) } -// GetMaxProfilesPerUser 获取每用户最大档案数量配置 -func GetMaxProfilesPerUser() int { - config, err := repository.GetSystemConfigByKey("max_profiles_per_user") - if err != nil || config == nil { - return 5 +func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) { + log := &model.UserLoginLog{ + UserID: userID, + IPAddress: ipAddress, + UserAgent: userAgent, + LoginMethod: "PASSWORD", + IsSuccess: true, } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 5 - } - return value + _ = s.userRepo.CreateLoginLog(log) } -// GetMaxTexturesPerUser 获取每用户最大材质数量配置 -func GetMaxTexturesPerUser() int { - config, err := repository.GetSystemConfigByKey("max_textures_per_user") - if err != nil || config == nil { - return 50 +func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { + log := &model.UserLoginLog{ + UserID: userID, + IPAddress: ipAddress, + UserAgent: userAgent, + LoginMethod: "PASSWORD", + IsSuccess: false, + FailureReason: reason, } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 50 - } - return value + _ = s.userRepo.CreateLoginLog(log) } diff --git a/internal/service/user_service_impl.go b/internal/service/user_service_impl.go deleted file mode 100644 index 2b7250e..0000000 --- a/internal/service/user_service_impl.go +++ /dev/null @@ -1,368 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "carrotskin/pkg/auth" - "carrotskin/pkg/config" - "carrotskin/pkg/redis" - "context" - "errors" - "fmt" - "net/url" - "strings" - "time" - - "go.uber.org/zap" -) - -// userServiceImpl UserService的实现 -type userServiceImpl struct { - userRepo repository.UserRepository - configRepo repository.SystemConfigRepository - jwtService *auth.JWTService - redis *redis.Client - logger *zap.Logger -} - -// NewUserService 创建UserService实例 -func NewUserService( - userRepo repository.UserRepository, - configRepo repository.SystemConfigRepository, - jwtService *auth.JWTService, - redisClient *redis.Client, - logger *zap.Logger, -) UserService { - return &userServiceImpl{ - userRepo: userRepo, - configRepo: configRepo, - jwtService: jwtService, - redis: redisClient, - logger: logger, - } -} - -func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) { - // 检查用户名是否已存在 - existingUser, err := s.userRepo.FindByUsername(username) - if err != nil { - return nil, "", err - } - if existingUser != nil { - return nil, "", errors.New("用户名已存在") - } - - // 检查邮箱是否已存在 - existingEmail, err := s.userRepo.FindByEmail(email) - if err != nil { - return nil, "", err - } - if existingEmail != nil { - return nil, "", errors.New("邮箱已被注册") - } - - // 加密密码 - hashedPassword, err := auth.HashPassword(password) - if err != nil { - return nil, "", errors.New("密码加密失败") - } - - // 确定头像URL - avatarURL := avatar - if avatarURL != "" { - if err := s.ValidateAvatarURL(avatarURL); err != nil { - return nil, "", err - } - } else { - avatarURL = s.getDefaultAvatar() - } - - // 创建用户 - user := &model.User{ - Username: username, - Password: hashedPassword, - Email: email, - Avatar: avatarURL, - Role: "user", - Status: 1, - Points: 0, - } - - if err := s.userRepo.Create(user); err != nil { - return nil, "", err - } - - // 生成JWT Token - token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) - if err != nil { - return nil, "", errors.New("生成Token失败") - } - - return user, token, nil -} - -func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { - ctx := context.Background() - - // 检查账号是否被锁定 - if s.redis != nil { - identifier := usernameOrEmail + ":" + ipAddress - locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier) - if err == nil && locked { - return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1) - } - } - - // 查找用户 - var user *model.User - var err error - - if strings.Contains(usernameOrEmail, "@") { - user, err = s.userRepo.FindByEmail(usernameOrEmail) - } else { - user, err = s.userRepo.FindByUsername(usernameOrEmail) - } - - if err != nil { - return nil, "", err - } - if user == nil { - s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在") - return nil, "", errors.New("用户名/邮箱或密码错误") - } - - // 检查用户状态 - if user.Status != 1 { - s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用") - return nil, "", errors.New("账号已被禁用") - } - - // 验证密码 - if !auth.CheckPassword(user.Password, password) { - s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误") - return nil, "", errors.New("用户名/邮箱或密码错误") - } - - // 登录成功,清除失败计数 - if s.redis != nil { - identifier := usernameOrEmail + ":" + ipAddress - _ = ClearLoginAttempts(ctx, s.redis, identifier) - } - - // 生成JWT Token - token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) - if err != nil { - return nil, "", errors.New("生成Token失败") - } - - // 更新最后登录时间 - now := time.Now() - user.LastLoginAt = &now - _ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ - "last_login_at": now, - }) - - // 记录成功登录日志 - s.logSuccessLogin(user.ID, ipAddress, userAgent) - - return user, token, nil -} - -func (s *userServiceImpl) GetByID(id int64) (*model.User, error) { - return s.userRepo.FindByID(id) -} - -func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) { - return s.userRepo.FindByEmail(email) -} - -func (s *userServiceImpl) UpdateInfo(user *model.User) error { - return s.userRepo.Update(user) -} - -func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error { - return s.userRepo.UpdateFields(userID, map[string]interface{}{ - "avatar": avatarURL, - }) -} - -func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error { - user, err := s.userRepo.FindByID(userID) - if err != nil || user == nil { - return errors.New("用户不存在") - } - - if !auth.CheckPassword(user.Password, oldPassword) { - return errors.New("原密码错误") - } - - hashedPassword, err := auth.HashPassword(newPassword) - if err != nil { - return errors.New("密码加密失败") - } - - return s.userRepo.UpdateFields(userID, map[string]interface{}{ - "password": hashedPassword, - }) -} - -func (s *userServiceImpl) ResetPassword(email, newPassword string) error { - user, err := s.userRepo.FindByEmail(email) - if err != nil || user == nil { - return errors.New("用户不存在") - } - - hashedPassword, err := auth.HashPassword(newPassword) - if err != nil { - return errors.New("密码加密失败") - } - - return s.userRepo.UpdateFields(user.ID, map[string]interface{}{ - "password": hashedPassword, - }) -} - -func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error { - existingUser, err := s.userRepo.FindByEmail(newEmail) - if err != nil { - return err - } - if existingUser != nil && existingUser.ID != userID { - return errors.New("邮箱已被其他用户使用") - } - - return s.userRepo.UpdateFields(userID, map[string]interface{}{ - "email": newEmail, - }) -} - -func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error { - if avatarURL == "" { - return nil - } - - // 允许相对路径 - if strings.HasPrefix(avatarURL, "/") { - return nil - } - - // 解析URL - parsedURL, err := url.Parse(avatarURL) - if err != nil { - return errors.New("无效的URL格式") - } - - // 必须是HTTP或HTTPS协议 - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return errors.New("URL必须使用http或https协议") - } - - host := parsedURL.Hostname() - if host == "" { - return errors.New("URL缺少主机名") - } - - // 从配置获取允许的域名列表 - cfg, err := config.GetConfig() - if err != nil { - allowedDomains := []string{"localhost", "127.0.0.1"} - return s.checkDomainAllowed(host, allowedDomains) - } - - return s.checkDomainAllowed(host, cfg.Security.AllowedDomains) -} - -func (s *userServiceImpl) GetMaxProfilesPerUser() int { - config, err := s.configRepo.GetByKey("max_profiles_per_user") - if err != nil || config == nil { - return 5 - } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 5 - } - return value -} - -func (s *userServiceImpl) GetMaxTexturesPerUser() int { - config, err := s.configRepo.GetByKey("max_textures_per_user") - if err != nil || config == nil { - return 50 - } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 50 - } - return value -} - -// 私有辅助方法 - -func (s *userServiceImpl) getDefaultAvatar() string { - config, err := s.configRepo.GetByKey("default_avatar") - if err != nil || config == nil || config.Value == "" { - return "" - } - return config.Value -} - -func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error { - host = strings.ToLower(host) - - for _, allowed := range allowedDomains { - allowed = strings.ToLower(strings.TrimSpace(allowed)) - if allowed == "" { - continue - } - - if host == allowed { - return nil - } - - if strings.HasPrefix(allowed, "*.") { - suffix := allowed[1:] - if strings.HasSuffix(host, suffix) { - return nil - } - } - } - - return errors.New("URL域名不在允许的列表中") -} - -func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) { - if s.redis != nil { - identifier := usernameOrEmail + ":" + ipAddress - count, _ := RecordLoginFailure(ctx, s.redis, identifier) - if count >= MaxLoginAttempts { - s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") - return - } - } - s.logFailedLogin(userID, ipAddress, userAgent, reason) -} - -func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: true, - } - _ = s.userRepo.CreateLoginLog(log) -} - -func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: false, - FailureReason: reason, - } - _ = s.userRepo.CreateLoginLog(log) -} diff --git a/internal/service/user_service_test.go b/internal/service/user_service_test.go index 9144fb4..e5bfc36 100644 --- a/internal/service/user_service_test.go +++ b/internal/service/user_service_test.go @@ -1,199 +1,378 @@ package service import ( - "strings" + "carrotskin/internal/model" + "carrotskin/pkg/auth" "testing" + + "go.uber.org/zap" ) -// TestGetDefaultAvatar 测试获取默认头像的逻辑 -// 注意:这个测试需要mock repository,但由于repository是函数式的, -// 我们只测试逻辑部分 -func TestGetDefaultAvatar_Logic(t *testing.T) { +func TestUserServiceImpl_Register(t *testing.T) { + // 准备依赖 + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 初始化Service + // 注意:redisClient 传入 nil,因为 Register 方法中没有使用 redis + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 测试用例 tests := []struct { - name string - configExists bool - configValue string - expectedResult string + name string + username string + password string + email string + avatar string + wantErr bool + errMsg string + setupMocks func() }{ { - name: "配置存在时返回配置值", - configExists: true, - configValue: "https://example.com/avatar.png", - expectedResult: "https://example.com/avatar.png", + name: "正常注册", + username: "testuser", + password: "password123", + email: "test@example.com", + avatar: "", + wantErr: false, }, { - name: "配置不存在时返回错误信息", - configExists: false, - configValue: "", - expectedResult: "数据库中不存在默认头像配置", + name: "用户名已存在", + username: "existinguser", + password: "password123", + email: "new@example.com", + avatar: "", + wantErr: true, + errMsg: "用户名已存在", + setupMocks: func() { + userRepo.Create(&model.User{ + Username: "existinguser", + Email: "old@example.com", + }) + }, + }, + { + name: "邮箱已存在", + username: "newuser", + password: "password123", + email: "existing@example.com", + avatar: "", + wantErr: true, + errMsg: "邮箱已被注册", + setupMocks: func() { + userRepo.Create(&model.User{ + Username: "otheruser", + Email: "existing@example.com", + }) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 这个测试只验证逻辑,不实际调用repository - // 实际的repository调用测试需要集成测试或mock - if tt.configExists { - if tt.expectedResult != tt.configValue { - t.Errorf("当配置存在时,应该返回配置值") + // 重置mock状态 + if tt.setupMocks != nil { + tt.setupMocks() + } + + user, token, err := userService.Register(tt.username, tt.password, tt.email, tt.avatar) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + return + } + if tt.errMsg != "" && err.Error() != tt.errMsg { + t.Errorf("错误信息不匹配: got %v, want %v", err.Error(), tt.errMsg) } } else { - if !strings.Contains(tt.expectedResult, "数据库中不存在默认头像配置") { - t.Errorf("当配置不存在时,应该返回错误信息") + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if user == nil { + t.Error("返回的用户不应为nil") + } + if token == "" { + t.Error("返回的Token不应为空") + } + if user.Username != tt.username { + t.Errorf("用户名不匹配: got %v, want %v", user.Username, tt.username) } } }) } } -// TestLoginUser_EmailDetection 测试登录时邮箱检测逻辑 -func TestLoginUser_EmailDetection(t *testing.T) { +func TestUserServiceImpl_Login(t *testing.T) { + // 准备依赖 + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 预置用户 + password := "password123" + hashedPassword, _ := auth.HashPassword(password) + testUser := &model.User{ + Username: "testlogin", + Email: "login@example.com", + Password: hashedPassword, + Status: 1, + } + userRepo.Create(testUser) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + tests := []struct { name string usernameOrEmail string - isEmail bool + password string + wantErr bool + errMsg string }{ { - name: "包含@符号,识别为邮箱", - usernameOrEmail: "user@example.com", - isEmail: true, + name: "用户名登录成功", + usernameOrEmail: "testlogin", + password: "password123", + wantErr: false, }, { - name: "不包含@符号,识别为用户名", - usernameOrEmail: "username", - isEmail: false, + name: "邮箱登录成功", + usernameOrEmail: "login@example.com", + password: "password123", + wantErr: false, }, { - name: "空字符串", - usernameOrEmail: "", - isEmail: false, + name: "密码错误", + usernameOrEmail: "testlogin", + password: "wrongpassword", + wantErr: true, + errMsg: "用户名/邮箱或密码错误", }, { - name: "只有@符号", - usernameOrEmail: "@", - isEmail: true, + name: "用户不存在", + usernameOrEmail: "nonexistent", + password: "password123", + wantErr: true, + errMsg: "用户名/邮箱或密码错误", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isEmail := strings.Contains(tt.usernameOrEmail, "@") - if isEmail != tt.isEmail { - t.Errorf("Email detection failed: got %v, want %v", isEmail, tt.isEmail) + user, token, err := userService.Login(tt.usernameOrEmail, tt.password, "127.0.0.1", "test-agent") + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } else if tt.errMsg != "" && err.Error() != tt.errMsg { + t.Errorf("错误信息不匹配: got %v, want %v", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + } + if user == nil { + t.Error("用户不应为nil") + } + if token == "" { + t.Error("Token不应为空") + } } }) } } -// TestUserService_Constants 测试用户服务相关常量 -func TestUserService_Constants(t *testing.T) { - // 测试默认用户角色 - defaultRole := "user" - if defaultRole == "" { - t.Error("默认用户角色不能为空") +// TestUserServiceImpl_BasicGetters 测试 GetByID / GetByEmail / UpdateInfo / UpdateAvatar +func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 预置用户 + user := &model.User{ + ID: 1, + Username: "basic", + Email: "basic@example.com", + Avatar: "", + } + userRepo.Create(user) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // GetByID + gotByID, err := userService.GetByID(1) + if err != nil || gotByID == nil || gotByID.ID != 1 { + t.Fatalf("GetByID 返回不正确: user=%+v, err=%v", gotByID, err) } - // 测试默认用户状态 - defaultStatus := int16(1) - if defaultStatus != 1 { - t.Errorf("默认用户状态应为1(正常),实际为%d", defaultStatus) + // GetByEmail + gotByEmail, err := userService.GetByEmail("basic@example.com") + if err != nil || gotByEmail == nil || gotByEmail.Email != "basic@example.com" { + t.Fatalf("GetByEmail 返回不正确: user=%+v, err=%v", gotByEmail, err) } - // 测试初始积分 - initialPoints := 0 - if initialPoints < 0 { - t.Errorf("初始积分不应为负数,实际为%d", initialPoints) + // UpdateInfo + user.Username = "updated" + if err := userService.UpdateInfo(user); err != nil { + t.Fatalf("UpdateInfo 失败: %v", err) + } + updated, _ := userRepo.FindByID(1) + if updated.Username != "updated" { + t.Fatalf("UpdateInfo 未更新用户名, got=%s", updated.Username) + } + + // UpdateAvatar 只需确认不会返回错误(具体字段更新由仓库层保证) + if err := userService.UpdateAvatar(1, "http://example.com/avatar.png"); err != nil { + t.Fatalf("UpdateAvatar 失败: %v", err) } } -// TestUserService_Validation 测试用户数据验证逻辑 -func TestUserService_Validation(t *testing.T) { +// TestUserServiceImpl_ChangePassword 测试 ChangePassword +func TestUserServiceImpl_ChangePassword(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + hashed, _ := auth.HashPassword("oldpass") + user := &model.User{ + ID: 1, + Username: "changepw", + Password: hashed, + } + userRepo.Create(user) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 原密码正确 + if err := userService.ChangePassword(1, "oldpass", "newpass"); err != nil { + t.Fatalf("ChangePassword 正常情况失败: %v", err) + } + + // 用户不存在 + if err := userService.ChangePassword(999, "oldpass", "newpass"); err == nil { + t.Fatalf("ChangePassword 应在用户不存在时返回错误") + } + + // 原密码错误 + if err := userService.ChangePassword(1, "wrong", "another"); err == nil { + t.Fatalf("ChangePassword 应在原密码错误时返回错误") + } +} + +// TestUserServiceImpl_ResetPassword 测试 ResetPassword +func TestUserServiceImpl_ResetPassword(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + user := &model.User{ + ID: 1, + Username: "resetpw", + Email: "reset@example.com", + } + userRepo.Create(user) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 正常重置 + if err := userService.ResetPassword("reset@example.com", "newpass"); err != nil { + t.Fatalf("ResetPassword 正常情况失败: %v", err) + } + + // 用户不存在 + if err := userService.ResetPassword("notfound@example.com", "newpass"); err == nil { + t.Fatalf("ResetPassword 应在用户不存在时返回错误") + } +} + +// TestUserServiceImpl_ChangeEmail 测试 ChangeEmail +func TestUserServiceImpl_ChangeEmail(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + user1 := &model.User{ID: 1, Email: "user1@example.com"} + user2 := &model.User{ID: 2, Email: "user2@example.com"} + userRepo.Create(user1) + userRepo.Create(user2) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 正常修改 + if err := userService.ChangeEmail(1, "new@example.com"); err != nil { + t.Fatalf("ChangeEmail 正常情况失败: %v", err) + } + + // 邮箱被其他用户占用 + if err := userService.ChangeEmail(1, "user2@example.com"); err == nil { + t.Fatalf("ChangeEmail 应在邮箱被占用时返回错误") + } +} + +// TestUserServiceImpl_ValidateAvatarURL 测试 ValidateAvatarURL +func TestUserServiceImpl_ValidateAvatarURL(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + tests := []struct { - name string - username string - email string - password string - wantValid bool + name string + url string + wantErr bool }{ - { - name: "有效的用户名和邮箱", - username: "testuser", - email: "test@example.com", - password: "password123", - wantValid: true, - }, - { - name: "用户名为空", - username: "", - email: "test@example.com", - password: "password123", - wantValid: false, - }, - { - name: "邮箱为空", - username: "testuser", - email: "", - password: "password123", - wantValid: false, - }, - { - name: "密码为空", - username: "testuser", - email: "test@example.com", - password: "", - wantValid: false, - }, - { - name: "邮箱格式无效(缺少@)", - username: "testuser", - email: "invalid-email", - password: "password123", - wantValid: false, - }, + {"空字符串通过", "", false}, + {"相对路径通过", "/images/avatar.png", false}, + {"非法URL格式", "://bad-url", true}, + {"非法协议", "ftp://example.com/avatar.png", true}, + {"缺少主机名", "http:///avatar.png", true}, + {"本地域名通过", "http://localhost/avatar.png", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 简单的验证逻辑测试 - isValid := tt.username != "" && tt.email != "" && tt.password != "" && strings.Contains(tt.email, "@") - if isValid != tt.wantValid { - t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid) + err := userService.ValidateAvatarURL(tt.url) + if (err != nil) != tt.wantErr { + t.Fatalf("ValidateAvatarURL(%q) error = %v, wantErr=%v", tt.url, err, tt.wantErr) } }) } } -// TestUserService_AvatarLogic 测试头像逻辑 -func TestUserService_AvatarLogic(t *testing.T) { - tests := []struct { - name string - providedAvatar string - defaultAvatar string - expectedAvatar string - }{ - { - name: "提供头像时使用提供的头像", - providedAvatar: "https://example.com/custom.png", - defaultAvatar: "https://example.com/default.png", - expectedAvatar: "https://example.com/custom.png", - }, - { - name: "未提供头像时使用默认头像", - providedAvatar: "", - defaultAvatar: "https://example.com/default.png", - expectedAvatar: "https://example.com/default.png", - }, +// TestUserServiceImpl_MaxLimits 测试 GetMaxProfilesPerUser / GetMaxTexturesPerUser +func TestUserServiceImpl_MaxLimits(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 未配置时走默认值 + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + if got := userService.GetMaxProfilesPerUser(); got != 5 { + t.Fatalf("GetMaxProfilesPerUser 默认值错误, got=%d", got) + } + if got := userService.GetMaxTexturesPerUser(); got != 50 { + t.Fatalf("GetMaxTexturesPerUser 默认值错误, got=%d", got) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - avatarURL := tt.providedAvatar - if avatarURL == "" { - avatarURL = tt.defaultAvatar - } - if avatarURL != tt.expectedAvatar { - t.Errorf("Avatar logic failed: got %s, want %s", avatarURL, tt.expectedAvatar) - } - }) + // 配置有效值 + configRepo.Update(&model.SystemConfig{Key: "max_profiles_per_user", Value: "10"}) + configRepo.Update(&model.SystemConfig{Key: "max_textures_per_user", Value: "100"}) + + if got := userService.GetMaxProfilesPerUser(); got != 10 { + t.Fatalf("GetMaxProfilesPerUser 配置值错误, got=%d", got) } -} + if got := userService.GetMaxTexturesPerUser(); got != 100 { + t.Fatalf("GetMaxTexturesPerUser 配置值错误, got=%d", got) + } +} \ No newline at end of file