From f7589ebbb8a3fe8e144aabd7b870d04082e62e57 Mon Sep 17 00:00:00 2001 From: lan Date: Tue, 2 Dec 2025 17:40:39 +0800 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20=E5=BC=95=E5=85=A5=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E6=B3=A8=E5=85=A5=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等) - 创建Repository接口实现 - 创建依赖注入容器(container.Container) - 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler) - 创建新的路由注册方式(RegisterRoutesWithDI) - 提供main.go示例文件展示如何使用依赖注入 同时包含之前的安全修复: - CORS配置安全加固 - 头像URL验证安全修复 - JWT algorithm confusion漏洞修复 - Recovery中间件增强 - 敏感错误信息泄露修复 - 类型断言安全修复 --- .dockerignore | 2 + .gitea/workflows/docker.yml | 84 ------ Dockerfile | 2 + cmd/server/main_di_example.go.example | 146 +++++++++ internal/container/container.go | 138 +++++++++ internal/handler/auth_handler_di.go | 177 +++++++++++ internal/handler/helpers.go | 27 +- internal/handler/routes_di.go | 191 ++++++++++++ internal/handler/texture_handler.go | 12 +- internal/handler/texture_handler_di.go | 284 ++++++++++++++++++ internal/handler/user_handler_di.go | 233 ++++++++++++++ internal/middleware/cors.go | 36 ++- internal/middleware/cors_test.go | 36 ++- internal/middleware/recovery.go | 27 +- internal/model/response.go | 41 ++- internal/repository/interfaces.go | 85 ++++++ .../repository/profile_repository_impl.go | 149 +++++++++ .../system_config_repository_impl.go | 44 +++ .../repository/texture_repository_impl.go | 175 +++++++++++ internal/repository/token_repository_impl.go | 71 +++++ internal/repository/user_repository_impl.go | 103 +++++++ internal/service/user_service.go | 75 ++++- pkg/auth/jwt.go | 4 + pkg/config/config.go | 23 +- pkg/config/manager.go | 3 - 25 files changed, 2029 insertions(+), 139 deletions(-) delete mode 100644 .gitea/workflows/docker.yml create mode 100644 cmd/server/main_di_example.go.example create mode 100644 internal/container/container.go create mode 100644 internal/handler/auth_handler_di.go create mode 100644 internal/handler/routes_di.go create mode 100644 internal/handler/texture_handler_di.go create mode 100644 internal/handler/user_handler_di.go create mode 100644 internal/repository/interfaces.go create mode 100644 internal/repository/profile_repository_impl.go create mode 100644 internal/repository/system_config_repository_impl.go create mode 100644 internal/repository/texture_repository_impl.go create mode 100644 internal/repository/token_repository_impl.go create mode 100644 internal/repository/user_repository_impl.go diff --git a/.dockerignore b/.dockerignore index e375c38..6686339 100644 --- a/.dockerignore +++ b/.dockerignore @@ -74,3 +74,5 @@ local/ dev/ minio-data/ + + diff --git a/.gitea/workflows/docker.yml b/.gitea/workflows/docker.yml deleted file mode 100644 index 4595bfb..0000000 --- a/.gitea/workflows/docker.yml +++ /dev/null @@ -1,84 +0,0 @@ -name: Build and Push Docker Image - -on: - push: - branches: - - main - - master - - dev - tags: - - 'v*' - workflow_dispatch: - -env: - REGISTRY: code.littlelan.cn - IMAGE_NAME: carrotskin/backend - -jobs: - build-and-push: - runs-on: ubuntu-latest - container: - image: quay.io/buildah/stable:latest - options: --privileged - - steps: - - name: Install dependencies - run: | - dnf install -y git nodejs - - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Login to registry - run: | - buildah login \ - -u "${{ secrets.REGISTRY_USERNAME }}" \ - -p "${{ secrets.REGISTRY_PASSWORD }}" \ - ${{ env.REGISTRY }} - echo "Registry 登录成功" - - - name: Build image - run: | - buildah bud \ - --format docker \ - --layers \ - -t ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - -f Dockerfile \ - . - echo "镜像构建完成" - - - name: Tag and push image - run: | - SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7) - REF_NAME="${{ github.ref_name }}" - REF="${{ github.ref }}" - - # 推送分支/标签名 - buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${REF_NAME} - buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${REF_NAME} - echo "✓ 推送: ${REF_NAME}" - - # 推送 SHA 标签 - buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA} - buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA} - echo "✓ 推送: sha-${SHORT_SHA}" - - # main/master 推送 latest - if [ "$REF" = "refs/heads/main" ] || [ "$REF" = "refs/heads/master" ]; then - buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest - buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest - echo "✓ 推送: latest" - fi - - - name: Build summary - run: | - echo "==============================" - echo "✅ 镜像构建完成!" - echo "仓库: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}" - echo "分支: ${{ github.ref_name }}" - echo "==============================" diff --git a/Dockerfile b/Dockerfile index 118c7a4..b5a00ab 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,3 +59,5 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ # 启动应用 ENTRYPOINT ["./server"] + + diff --git a/cmd/server/main_di_example.go.example b/cmd/server/main_di_example.go.example new file mode 100644 index 0000000..d9168ef --- /dev/null +++ b/cmd/server/main_di_example.go.example @@ -0,0 +1,146 @@ +// +build ignore +// 此文件是依赖注入版本的main.go示例 +// 可以参考此文件改造原有的main.go + +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + _ "carrotskin/docs" // Swagger文档 + "carrotskin/internal/container" + "carrotskin/internal/handler" + "carrotskin/internal/middleware" + "carrotskin/pkg/auth" + "carrotskin/pkg/config" + "carrotskin/pkg/database" + "carrotskin/pkg/email" + "carrotskin/pkg/logger" + "carrotskin/pkg/redis" + "carrotskin/pkg/storage" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func main() { + // 初始化配置 + if err := config.Init(); err != nil { + log.Fatalf("配置加载失败: %v", err) + } + cfg := config.MustGetConfig() + + // 初始化日志 + if err := logger.Init(cfg.Log); err != nil { + log.Fatalf("日志初始化失败: %v", err) + } + loggerInstance := logger.MustGetLogger() + defer loggerInstance.Sync() + + // 初始化数据库 + if err := database.Init(cfg.Database, loggerInstance); err != nil { + loggerInstance.Fatal("数据库初始化失败", zap.Error(err)) + } + defer database.Close() + + // 执行数据库迁移 + if err := database.AutoMigrate(loggerInstance); err != nil { + loggerInstance.Fatal("数据库迁移失败", zap.Error(err)) + } + + // 初始化种子数据 + if err := database.Seed(loggerInstance); err != nil { + loggerInstance.Fatal("种子数据初始化失败", zap.Error(err)) + } + + // 初始化JWT服务 + if err := auth.Init(cfg.JWT); err != nil { + loggerInstance.Fatal("JWT服务初始化失败", zap.Error(err)) + } + + // 初始化Redis + if err := redis.Init(cfg.Redis, loggerInstance); err != nil { + loggerInstance.Fatal("Redis连接失败", zap.Error(err)) + } + defer redis.MustGetClient().Close() + + // 初始化对象存储 + var storageClient *storage.StorageClient + if err := storage.Init(cfg.RustFS); err != nil { + loggerInstance.Warn("对象存储连接失败,某些功能可能不可用", zap.Error(err)) + } else { + storageClient = storage.MustGetClient() + loggerInstance.Info("对象存储连接成功") + } + + // 初始化邮件服务 + if err := email.Init(cfg.Email, loggerInstance); err != nil { + loggerInstance.Fatal("邮件服务初始化失败", zap.Error(err)) + } + + // ============ 依赖注入改动部分 ============ + // 创建依赖注入容器 + c := container.NewContainer( + database.MustGetDB(), + redis.MustGetClient(), + loggerInstance, + auth.MustGetJWTService(), + storageClient, + ) + + // 设置Gin模式 + if cfg.Server.Mode == "production" { + gin.SetMode(gin.ReleaseMode) + } + + // 创建路由 + router := gin.New() + + // 添加中间件 + router.Use(middleware.Logger(loggerInstance)) + router.Use(middleware.Recovery(loggerInstance)) + router.Use(middleware.CORS()) + + // 使用依赖注入方式注册路由 + handler.RegisterRoutesWithDI(router, c) + // ============ 依赖注入改动结束 ============ + + // 创建HTTP服务器 + srv := &http.Server{ + Addr: cfg.Server.Port, + Handler: router, + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, + } + + // 启动服务器 + go func() { + loggerInstance.Info("服务器启动", zap.String("port", cfg.Server.Port)) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + loggerInstance.Fatal("服务器启动失败", zap.Error(err)) + } + }() + + // 等待中断信号优雅关闭 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + loggerInstance.Info("正在关闭服务器...") + + // 设置关闭超时 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + loggerInstance.Fatal("服务器强制关闭", zap.Error(err)) + } + + loggerInstance.Info("服务器已关闭") +} + diff --git a/internal/container/container.go b/internal/container/container.go new file mode 100644 index 0000000..230e68f --- /dev/null +++ b/internal/container/container.go @@ -0,0 +1,138 @@ +package container + +import ( + "carrotskin/internal/repository" + "carrotskin/pkg/auth" + "carrotskin/pkg/redis" + "carrotskin/pkg/storage" + + "go.uber.org/zap" + "gorm.io/gorm" +) + +// Container 依赖注入容器 +// 集中管理所有依赖,便于测试和维护 +type Container struct { + // 基础设施依赖 + DB *gorm.DB + Redis *redis.Client + Logger *zap.Logger + JWT *auth.JWTService + Storage *storage.StorageClient + + // Repository层 + UserRepo repository.UserRepository + ProfileRepo repository.ProfileRepository + TextureRepo repository.TextureRepository + TokenRepo repository.TokenRepository + ConfigRepo repository.SystemConfigRepository +} + +// NewContainer 创建依赖容器 +func NewContainer( + db *gorm.DB, + redisClient *redis.Client, + logger *zap.Logger, + jwtService *auth.JWTService, + storageClient *storage.StorageClient, +) *Container { + c := &Container{ + DB: db, + Redis: redisClient, + Logger: logger, + JWT: jwtService, + Storage: storageClient, + } + + // 初始化Repository + c.UserRepo = repository.NewUserRepository(db) + c.ProfileRepo = repository.NewProfileRepository(db) + c.TextureRepo = repository.NewTextureRepository(db) + c.TokenRepo = repository.NewTokenRepository(db) + c.ConfigRepo = repository.NewSystemConfigRepository(db) + + return c +} + +// NewTestContainer 创建测试用容器(可注入mock依赖) +func NewTestContainer(opts ...Option) *Container { + c := &Container{} + for _, opt := range opts { + opt(c) + } + return c +} + +// Option 容器配置选项 +type Option func(*Container) + +// WithDB 设置数据库连接 +func WithDB(db *gorm.DB) Option { + return func(c *Container) { + c.DB = db + } +} + +// WithRedis 设置Redis客户端 +func WithRedis(redis *redis.Client) Option { + return func(c *Container) { + c.Redis = redis + } +} + +// WithLogger 设置日志 +func WithLogger(logger *zap.Logger) Option { + return func(c *Container) { + c.Logger = logger + } +} + +// WithJWT 设置JWT服务 +func WithJWT(jwt *auth.JWTService) Option { + return func(c *Container) { + c.JWT = jwt + } +} + +// WithStorage 设置存储客户端 +func WithStorage(storage *storage.StorageClient) Option { + return func(c *Container) { + c.Storage = storage + } +} + +// WithUserRepo 设置用户仓储 +func WithUserRepo(repo repository.UserRepository) Option { + return func(c *Container) { + c.UserRepo = repo + } +} + +// WithProfileRepo 设置档案仓储 +func WithProfileRepo(repo repository.ProfileRepository) Option { + return func(c *Container) { + c.ProfileRepo = repo + } +} + +// WithTextureRepo 设置材质仓储 +func WithTextureRepo(repo repository.TextureRepository) Option { + return func(c *Container) { + c.TextureRepo = repo + } +} + +// WithTokenRepo 设置令牌仓储 +func WithTokenRepo(repo repository.TokenRepository) Option { + return func(c *Container) { + c.TokenRepo = repo + } +} + +// WithConfigRepo 设置系统配置仓储 +func WithConfigRepo(repo repository.SystemConfigRepository) Option { + return func(c *Container) { + c.ConfigRepo = repo + } +} + diff --git a/internal/handler/auth_handler_di.go b/internal/handler/auth_handler_di.go new file mode 100644 index 0000000..9087008 --- /dev/null +++ b/internal/handler/auth_handler_di.go @@ -0,0 +1,177 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/service" + "carrotskin/internal/types" + "carrotskin/pkg/email" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AuthHandler 认证处理器(依赖注入版本) +type AuthHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewAuthHandler 创建AuthHandler实例 +func NewAuthHandler(c *container.Container) *AuthHandler { + return &AuthHandler{ + container: c, + logger: c.Logger, + } +} + +// Register 用户注册 +// @Summary 用户注册 +// @Description 注册新用户账号 +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.RegisterRequest true "注册信息" +// @Success 200 {object} model.Response "注册成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/auth/register [post] +func (h *AuthHandler) Register(c *gin.Context) { + var req types.RegisterRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + // 验证邮箱验证码 + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + // 注册用户 + user, token, err := service.RegisterUser(h.container.JWT, req.Username, req.Password, req.Email, req.Avatar) + if err != nil { + h.logger.Error("用户注册失败", zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, &types.LoginResponse{ + Token: token, + UserInfo: UserToUserInfo(user), + }) +} + +// Login 用户登录 +// @Summary 用户登录 +// @Description 用户登录获取JWT Token,支持用户名或邮箱登录 +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)" +// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Failure 401 {object} model.ErrorResponse "登录失败" +// @Router /api/v1/auth/login [post] +func (h *AuthHandler) Login(c *gin.Context) { + var req types.LoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + ipAddress := c.ClientIP() + userAgent := c.GetHeader("User-Agent") + + user, token, err := service.LoginUserWithRateLimit(h.container.Redis, h.container.JWT, req.Username, req.Password, ipAddress, userAgent) + if err != nil { + h.logger.Warn("用户登录失败", + zap.String("username_or_email", req.Username), + zap.String("ip", ipAddress), + zap.Error(err), + ) + RespondUnauthorized(c, err.Error()) + return + } + + RespondSuccess(c, &types.LoginResponse{ + Token: token, + UserInfo: UserToUserInfo(user), + }) +} + +// SendVerificationCode 发送验证码 +// @Summary 发送验证码 +// @Description 发送邮箱验证码(注册/重置密码/更换邮箱) +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.SendVerificationCodeRequest true "发送验证码请求" +// @Success 200 {object} model.Response "发送成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/auth/send-code [post] +func (h *AuthHandler) SendVerificationCode(c *gin.Context) { + var req types.SendVerificationCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + emailService, err := h.getEmailService() + if err != nil { + RespondServerError(c, "邮件服务不可用", err) + return + } + + if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil { + h.logger.Error("发送验证码失败", + zap.String("email", req.Email), + zap.String("type", req.Type), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, gin.H{"message": "验证码已发送,请查收邮件"}) +} + +// ResetPassword 重置密码 +// @Summary 重置密码 +// @Description 通过邮箱验证码重置密码 +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.ResetPasswordRequest true "重置密码请求" +// @Success 200 {object} model.Response "重置成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/auth/reset-password [post] +func (h *AuthHandler) ResetPassword(c *gin.Context) { + var req types.ResetPasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + // 验证验证码 + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + // 重置密码 + if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil { + h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) + RespondServerError(c, err.Error(), nil) + return + } + + RespondSuccess(c, gin.H{"message": "密码重置成功"}) +} + +// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入) +func (h *AuthHandler) getEmailService() (*email.Service, error) { + return email.GetService() +} + diff --git a/internal/handler/helpers.go b/internal/handler/helpers.go index 3e4489a..390b162 100644 --- a/internal/handler/helpers.go +++ b/internal/handler/helpers.go @@ -4,14 +4,24 @@ import ( "carrotskin/internal/model" "carrotskin/internal/types" "net/http" + "strconv" "github.com/gin-gonic/gin" ) +// parseIntWithDefault 将字符串解析为整数,解析失败返回默认值 +func parseIntWithDefault(s string, defaultVal int) int { + val, err := strconv.Atoi(s) + if err != nil { + return defaultVal + } + return val +} + // GetUserIDFromContext 从上下文获取用户ID,如果不存在返回未授权响应 // 返回值: userID, ok (如果ok为false,已经发送了错误响应) func GetUserIDFromContext(c *gin.Context) (int64, bool) { - userID, exists := c.Get("user_id") + userIDValue, exists := c.Get("user_id") if !exists { c.JSON(http.StatusUnauthorized, model.NewErrorResponse( model.CodeUnauthorized, @@ -20,7 +30,19 @@ func GetUserIDFromContext(c *gin.Context) (int64, bool) { )) return 0, false } - return userID.(int64), true + + // 安全的类型断言 + userID, ok := userIDValue.(int64) + if !ok { + c.JSON(http.StatusInternalServerError, model.NewErrorResponse( + model.CodeServerError, + "用户ID类型错误", + nil, + )) + return 0, false + } + + return userID, true } // UserToUserInfo 将 User 模型转换为 UserInfo 响应 @@ -157,4 +179,3 @@ func RespondWithError(c *gin.Context, err error) { RespondServerError(c, msg, nil) } } - diff --git a/internal/handler/routes_di.go b/internal/handler/routes_di.go new file mode 100644 index 0000000..d022cf6 --- /dev/null +++ b/internal/handler/routes_di.go @@ -0,0 +1,191 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/middleware" + "carrotskin/internal/model" + + "github.com/gin-gonic/gin" +) + +// Handlers 集中管理所有Handler +type Handlers struct { + Auth *AuthHandler + User *UserHandler + Texture *TextureHandler + // Profile *ProfileHandler // 后续添加 + // Captcha *CaptchaHandler // 后续添加 + // Yggdrasil *YggdrasilHandler // 后续添加 +} + +// NewHandlers 创建所有Handler实例 +func NewHandlers(c *container.Container) *Handlers { + return &Handlers{ + Auth: NewAuthHandler(c), + User: NewUserHandler(c), + Texture: NewTextureHandler(c), + } +} + +// RegisterRoutesWithDI 使用依赖注入注册所有路由 +func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { + // 设置Swagger文档 + SetupSwagger(router) + + // 创建Handler实例 + h := NewHandlers(c) + + // API路由组 + v1 := router.Group("/api/v1") + { + // 认证路由(无需JWT) + registerAuthRoutes(v1, h.Auth) + + // 用户路由(需要JWT认证) + registerUserRoutes(v1, h.User) + + // 材质路由 + registerTextureRoutes(v1, h.Texture) + + // 档案路由(暂时保持原有方式) + registerProfileRoutes(v1) + + // 验证码路由(暂时保持原有方式) + registerCaptchaRoutes(v1) + + // Yggdrasil API路由组(暂时保持原有方式) + registerYggdrasilRoutes(v1) + + // 系统路由 + registerSystemRoutes(v1) + } +} + +// registerAuthRoutes 注册认证路由 +func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { + authGroup := v1.Group("/auth") + { + authGroup.POST("/register", h.Register) + authGroup.POST("/login", h.Login) + authGroup.POST("/send-code", h.SendVerificationCode) + authGroup.POST("/reset-password", h.ResetPassword) + } +} + +// registerUserRoutes 注册用户路由 +func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { + userGroup := v1.Group("/user") + userGroup.Use(middleware.AuthMiddleware()) + { + userGroup.GET("/profile", h.GetProfile) + userGroup.PUT("/profile", h.UpdateProfile) + + // 头像相关 + userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL) + userGroup.PUT("/avatar", h.UpdateAvatar) + + // 更换邮箱 + userGroup.POST("/change-email", h.ChangeEmail) + + // Yggdrasil密码相关 + userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword) + } +} + +// registerTextureRoutes 注册材质路由 +func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { + textureGroup := v1.Group("/texture") + { + // 公开路由(无需认证) + textureGroup.GET("", h.Search) + textureGroup.GET("/:id", h.Get) + + // 需要认证的路由 + textureAuth := textureGroup.Group("") + textureAuth.Use(middleware.AuthMiddleware()) + { + textureAuth.POST("/upload-url", h.GenerateUploadURL) + textureAuth.POST("", h.Create) + textureAuth.PUT("/:id", h.Update) + textureAuth.DELETE("/:id", h.Delete) + textureAuth.POST("/:id/favorite", h.ToggleFavorite) + textureAuth.GET("/my", h.GetUserTextures) + textureAuth.GET("/favorites", h.GetUserFavorites) + } + } +} + +// registerProfileRoutes 注册档案路由(保持原有方式,后续改造) +func registerProfileRoutes(v1 *gin.RouterGroup) { + profileGroup := v1.Group("/profile") + { + // 公开路由(无需认证) + profileGroup.GET("/:uuid", GetProfile) + + // 需要认证的路由 + profileAuth := profileGroup.Group("") + profileAuth.Use(middleware.AuthMiddleware()) + { + profileAuth.POST("/", CreateProfile) + profileAuth.GET("/", GetProfiles) + profileAuth.PUT("/:uuid", UpdateProfile) + profileAuth.DELETE("/:uuid", DeleteProfile) + profileAuth.POST("/:uuid/activate", SetActiveProfile) + } + } +} + +// registerCaptchaRoutes 注册验证码路由(保持原有方式) +func registerCaptchaRoutes(v1 *gin.RouterGroup) { + captchaGroup := v1.Group("/captcha") + { + captchaGroup.GET("/generate", Generate) + captchaGroup.POST("/verify", Verify) + } +} + +// registerYggdrasilRoutes 注册Yggdrasil API路由(保持原有方式) +func registerYggdrasilRoutes(v1 *gin.RouterGroup) { + ygg := v1.Group("/yggdrasil") + { + ygg.GET("", GetMetaData) + ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates) + authserver := ygg.Group("/authserver") + { + authserver.POST("/authenticate", Authenticate) + authserver.POST("/validate", ValidToken) + authserver.POST("/refresh", RefreshToken) + authserver.POST("/invalidate", InvalidToken) + authserver.POST("/signout", SignOut) + } + sessionServer := ygg.Group("/sessionserver") + { + sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID) + sessionServer.POST("/session/minecraft/join", JoinServer) + sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer) + } + api := ygg.Group("/api") + profiles := api.Group("/profiles") + { + profiles.POST("/minecraft", GetProfilesByName) + } + } +} + +// registerSystemRoutes 注册系统路由 +func registerSystemRoutes(v1 *gin.RouterGroup) { + system := v1.Group("/system") + { + system.GET("/config", func(c *gin.Context) { + // TODO: 实现从数据库读取系统配置 + c.JSON(200, model.NewSuccessResponse(gin.H{ + "site_name": "CarrotSkin", + "site_description": "A Minecraft Skin Station", + "registration_enabled": true, + "max_textures_per_user": 100, + "max_profiles_per_user": 5, + })) + }) + } +} + diff --git a/internal/handler/texture_handler.go b/internal/handler/texture_handler.go index c7a5184..a139f38 100644 --- a/internal/handler/texture_handler.go +++ b/internal/handler/texture_handler.go @@ -160,8 +160,8 @@ func SearchTextures(c *gin.Context) { textureTypeStr := c.Query("type") publicOnly := c.Query("public_only") == "true" - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) var textureType model.TextureType switch textureTypeStr { @@ -314,8 +314,8 @@ func GetUserTextures(c *gin.Context) { return } - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) textures, total, err := service.GetUserTextures(database.MustGetDB(), userID, page, pageSize) if err != nil { @@ -344,8 +344,8 @@ func GetUserFavorites(c *gin.Context) { return } - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID, page, pageSize) if err != nil { diff --git a/internal/handler/texture_handler_di.go b/internal/handler/texture_handler_di.go new file mode 100644 index 0000000..8233184 --- /dev/null +++ b/internal/handler/texture_handler_di.go @@ -0,0 +1,284 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/model" + "carrotskin/internal/service" + "carrotskin/internal/types" + "strconv" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// TextureHandler 材质处理器(依赖注入版本) +type TextureHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewTextureHandler 创建TextureHandler实例 +func NewTextureHandler(c *container.Container) *TextureHandler { + return &TextureHandler{ + container: c, + logger: c.Logger, + } +} + +// GenerateUploadURL 生成材质上传URL +func (h *TextureHandler) GenerateUploadURL(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.GenerateTextureUploadURLRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + + result, err := service.GenerateTextureUploadURL( + c.Request.Context(), + h.container.Storage, + userID, + req.FileName, + string(req.TextureType), + ) + if err != nil { + h.logger.Error("生成材质上传URL失败", + zap.Int64("user_id", userID), + zap.String("file_name", req.FileName), + zap.String("texture_type", string(req.TextureType)), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, &types.GenerateTextureUploadURLResponse{ + PostURL: result.PostURL, + FormData: result.FormData, + TextureURL: result.FileURL, + ExpiresIn: 900, + }) +} + +// Create 创建材质记录 +func (h *TextureHandler) Create(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.CreateTextureRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + maxTextures := service.GetMaxTexturesPerUser() + if err := service.CheckTextureUploadLimit(h.container.DB, userID, maxTextures); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + + texture, err := service.CreateTexture(h.container.DB, + userID, + req.Name, + req.Description, + string(req.Type), + req.URL, + req.Hash, + req.Size, + req.IsPublic, + req.IsSlim, + ) + if err != nil { + h.logger.Error("创建材质失败", + zap.Int64("user_id", userID), + zap.String("name", req.Name), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, TextureToTextureInfo(texture)) +} + +// Get 获取材质详情 +func (h *TextureHandler) Get(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + texture, err := service.GetTextureByID(h.container.DB, id) + if err != nil { + RespondNotFound(c, err.Error()) + return + } + + RespondSuccess(c, TextureToTextureInfo(texture)) +} + +// Search 搜索材质 +func (h *TextureHandler) Search(c *gin.Context) { + keyword := c.Query("keyword") + textureTypeStr := c.Query("type") + publicOnly := c.Query("public_only") == "true" + + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) + + var textureType model.TextureType + switch textureTypeStr { + case "SKIN": + textureType = model.TextureTypeSkin + case "CAPE": + textureType = model.TextureTypeCape + } + + textures, total, err := service.SearchTextures(h.container.DB, keyword, textureType, publicOnly, page, pageSize) + if err != nil { + h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) + RespondServerError(c, "搜索材质失败", err) + return + } + + c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) +} + +// Update 更新材质 +func (h *TextureHandler) Update(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + var req types.UpdateTextureRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + texture, err := service.UpdateTexture(h.container.DB, textureID, userID, req.Name, req.Description, req.IsPublic) + if err != nil { + h.logger.Error("更新材质失败", + zap.Int64("user_id", userID), + zap.Int64("texture_id", textureID), + zap.Error(err), + ) + RespondForbidden(c, err.Error()) + return + } + + RespondSuccess(c, TextureToTextureInfo(texture)) +} + +// Delete 删除材质 +func (h *TextureHandler) Delete(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + if err := service.DeleteTexture(h.container.DB, textureID, userID); err != nil { + h.logger.Error("删除材质失败", + zap.Int64("user_id", userID), + zap.Int64("texture_id", textureID), + zap.Error(err), + ) + RespondForbidden(c, err.Error()) + return + } + + RespondSuccess(c, nil) +} + +// ToggleFavorite 切换收藏状态 +func (h *TextureHandler) ToggleFavorite(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + isFavorited, err := service.ToggleTextureFavorite(h.container.DB, userID, textureID) + if err != nil { + h.logger.Error("切换收藏状态失败", + zap.Int64("user_id", userID), + zap.Int64("texture_id", textureID), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, map[string]bool{"is_favorited": isFavorited}) +} + +// GetUserTextures 获取用户上传的材质列表 +func (h *TextureHandler) GetUserTextures(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) + + textures, total, err := service.GetUserTextures(h.container.DB, userID, page, pageSize) + if err != nil { + h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) + RespondServerError(c, "获取材质列表失败", err) + return + } + + c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) +} + +// GetUserFavorites 获取用户收藏的材质列表 +func (h *TextureHandler) GetUserFavorites(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) + + textures, total, err := service.GetUserTextureFavorites(h.container.DB, userID, page, pageSize) + if err != nil { + h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) + RespondServerError(c, "获取收藏列表失败", err) + return + } + + c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) +} + diff --git a/internal/handler/user_handler_di.go b/internal/handler/user_handler_di.go new file mode 100644 index 0000000..91e8a5a --- /dev/null +++ b/internal/handler/user_handler_di.go @@ -0,0 +1,233 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/service" + "carrotskin/internal/types" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// UserHandler 用户处理器(依赖注入版本) +type UserHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewUserHandler 创建UserHandler实例 +func NewUserHandler(c *container.Container) *UserHandler { + return &UserHandler{ + container: c, + logger: c.Logger, + } +} + +// GetProfile 获取用户信息 +func (h *UserHandler) GetProfile(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + h.logger.Error("获取用户信息失败", + zap.Int64("user_id", userID), + zap.Error(err), + ) + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(user)) +} + +// UpdateProfile 更新用户信息 +func (h *UserHandler) UpdateProfile(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.UpdateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + RespondNotFound(c, "用户不存在") + return + } + + // 处理密码修改 + if req.NewPassword != "" { + if req.OldPassword == "" { + RespondBadRequest(c, "修改密码需要提供原密码", nil) + return + } + + if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil { + h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID)) + } + + // 更新头像 + if req.Avatar != "" { + if err := service.ValidateAvatarURL(req.Avatar); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + user.Avatar = req.Avatar + if err := service.UpdateUserInfo(user); err != nil { + h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) + RespondServerError(c, "更新失败", err) + return + } + } + + // 重新获取更新后的用户信息 + updatedUser, err := service.GetUserByID(userID) + if err != nil || updatedUser == nil { + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(updatedUser)) +} + +// GenerateAvatarUploadURL 生成头像上传URL +func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.GenerateAvatarUploadURLRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + + result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName) + if err != nil { + h.logger.Error("生成头像上传URL失败", + zap.Int64("user_id", userID), + zap.String("file_name", req.FileName), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, &types.GenerateAvatarUploadURLResponse{ + PostURL: result.PostURL, + FormData: result.FormData, + AvatarURL: result.FileURL, + ExpiresIn: 900, + }) +} + +// UpdateAvatar 更新头像URL +func (h *UserHandler) UpdateAvatar(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + avatarURL := c.Query("avatar_url") + if avatarURL == "" { + RespondBadRequest(c, "头像URL不能为空", nil) + return + } + + if err := service.ValidateAvatarURL(avatarURL); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + + if err := service.UpdateUserAvatar(userID, avatarURL); err != nil { + h.logger.Error("更新头像失败", + zap.Int64("user_id", userID), + zap.String("avatar_url", avatarURL), + zap.Error(err), + ) + RespondServerError(c, "更新头像失败", err) + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(user)) +} + +// ChangeEmail 更换邮箱 +func (h *UserHandler) ChangeEmail(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.ChangeEmailRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { + h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil { + h.logger.Error("更换邮箱失败", + zap.Int64("user_id", userID), + zap.String("new_email", req.NewEmail), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(user)) +} + +// ResetYggdrasilPassword 重置Yggdrasil密码 +func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID) + if err != nil { + h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) + RespondServerError(c, "重置Yggdrasil密码失败", nil) + return + } + + h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) + RespondSuccess(c, gin.H{"password": newPassword}) +} diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index aaf1847..a806368 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -1,16 +1,48 @@ package middleware import ( + "carrotskin/pkg/config" + "github.com/gin-gonic/gin" ) // CORS 跨域中间件 func CORS() gin.HandlerFunc { + // 获取配置,如果配置未初始化则使用默认值 + var allowedOrigins []string + if cfg, err := config.GetConfig(); err == nil { + allowedOrigins = cfg.Security.AllowedOrigins + } else { + // 默认允许所有来源(向后兼容) + allowedOrigins = []string{"*"} + } + return gin.HandlerFunc(func(c *gin.Context) { - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Credentials", "true") + origin := c.GetHeader("Origin") + + // 检查是否允许该来源 + allowOrigin := "*" + if len(allowedOrigins) > 0 && allowedOrigins[0] != "*" { + allowOrigin = "" + for _, allowed := range allowedOrigins { + if allowed == origin || allowed == "*" { + allowOrigin = origin + break + } + } + } + + if allowOrigin != "" { + c.Header("Access-Control-Allow-Origin", allowOrigin) + // 只有在非通配符模式下才允许credentials + if allowOrigin != "*" { + c.Header("Access-Control-Allow-Credentials", "true") + } + } + c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") + c.Header("Access-Control-Max-Age", "86400") // 缓存预检请求结果24小时 if c.Request.Method == "OPTIONS" { c.AbortWithStatus(204) diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go index a07f6c7..833e76d 100644 --- a/internal/middleware/cors_test.go +++ b/internal/middleware/cors_test.go @@ -24,10 +24,11 @@ func TestCORS_Headers(t *testing.T) { router.ServeHTTP(w, req) // 验证CORS响应头 + // 注意:当 Access-Control-Allow-Origin 为 "*" 时,根据CORS规范, + // 不应该设置 Access-Control-Allow-Credentials 为 "true" expectedHeaders := map[string]string{ - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Credentials": "true", - "Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE", } for header, expectedValue := range expectedHeaders { @@ -37,6 +38,11 @@ func TestCORS_Headers(t *testing.T) { } } + // 验证在通配符模式下不设置Credentials(这是正确的安全行为) + if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" { + t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials) + } + // 验证Access-Control-Allow-Headers包含必要字段 allowHeaders := w.Header().Get("Access-Control-Allow-Headers") if allowHeaders == "" { @@ -117,6 +123,30 @@ func TestCORS_AllowHeaders(t *testing.T) { } } +// TestCORS_WithSpecificOrigin 测试配置了具体origin时的CORS行为 +func TestCORS_WithSpecificOrigin(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 注意:此测试验证的是在配置了具体allowed origins时的行为 + // 在没有配置初始化的情况下,默认使用通配符模式 + router := gin.New() + router.Use(CORS()) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://example.com") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + // 默认配置下使用通配符,所以不应该设置credentials + if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" { + t.Logf("当前模式下 Access-Control-Allow-Credentials = %q (通配符模式不设置)", credentials) + } +} + // 辅助函数:检查字符串是否包含子字符串(简单实现) func contains(s, substr string) bool { if len(substr) == 0 { diff --git a/internal/middleware/recovery.go b/internal/middleware/recovery.go index 8277182..3293f42 100644 --- a/internal/middleware/recovery.go +++ b/internal/middleware/recovery.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "runtime/debug" @@ -11,16 +12,26 @@ import ( // Recovery 恢复中间件 func Recovery(logger *zap.Logger) gin.HandlerFunc { return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { - if err, ok := recovered.(string); ok { - logger.Error("服务器恐慌", - zap.String("error", err), - zap.String("path", c.Request.URL.Path), - zap.String("method", c.Request.Method), - zap.String("ip", c.ClientIP()), - zap.String("stack", string(debug.Stack())), - ) + // 将任意类型的panic转换为字符串 + var errMsg string + switch v := recovered.(type) { + case string: + errMsg = v + case error: + errMsg = v.Error() + default: + errMsg = fmt.Sprintf("%v", v) } + logger.Error("服务器恐慌", + zap.String("error", errMsg), + zap.String("path", c.Request.URL.Path), + zap.String("method", c.Request.Method), + zap.String("ip", c.ClientIP()), + zap.String("user_agent", c.GetHeader("User-Agent")), + zap.String("stack", string(debug.Stack())), + ) + c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "服务器内部错误", diff --git a/internal/model/response.go b/internal/model/response.go index e76dac0..26865ed 100644 --- a/internal/model/response.go +++ b/internal/model/response.go @@ -1,10 +1,12 @@ package model +import "os" + // Response 通用API响应结构 type Response struct { - Code int `json:"code"` // 业务状态码 - Message string `json:"message"` // 响应消息 - Data interface{} `json:"data,omitempty"` // 响应数据 + Code int `json:"code"` // 业务状态码 + Message string `json:"message"` // 响应消息 + Data interface{} `json:"data,omitempty"` // 响应数据 } // PaginationResponse 分页响应结构 @@ -12,9 +14,9 @@ type PaginationResponse struct { Code int `json:"code"` Message string `json:"message"` Data interface{} `json:"data"` - Total int64 `json:"total"` // 总记录数 - Page int `json:"page"` // 当前页码 - PerPage int `json:"per_page"` // 每页数量 + Total int64 `json:"total"` // 总记录数 + Page int `json:"page"` // 当前页码 + PerPage int `json:"per_page"` // 每页数量 } // ErrorResponse 错误响应 @@ -26,14 +28,14 @@ type ErrorResponse struct { // 常用状态码 const ( - CodeSuccess = 200 // 成功 - CodeCreated = 201 // 创建成功 - CodeBadRequest = 400 // 请求参数错误 - CodeUnauthorized = 401 // 未授权 - CodeForbidden = 403 // 禁止访问 - CodeNotFound = 404 // 资源不存在 - CodeConflict = 409 // 资源冲突 - CodeServerError = 500 // 服务器错误 + CodeSuccess = 200 // 成功 + CodeCreated = 201 // 创建成功 + CodeBadRequest = 400 // 请求参数错误 + CodeUnauthorized = 401 // 未授权 + CodeForbidden = 403 // 禁止访问 + CodeNotFound = 404 // 资源不存在 + CodeConflict = 409 // 资源冲突 + CodeServerError = 500 // 服务器错误 ) // 常用响应消息 @@ -61,17 +63,26 @@ func NewSuccessResponse(data interface{}) *Response { } // NewErrorResponse 创建错误响应 +// 注意:err参数仅在开发环境下显示,生产环境不应暴露详细错误信息 func NewErrorResponse(code int, message string, err error) *ErrorResponse { resp := &ErrorResponse{ Code: code, Message: message, } - if err != nil { + // 仅在非生产环境下返回详细错误信息 + // 可以通过环境变量 ENVIRONMENT 控制 + if err != nil && !isProductionEnvironment() { resp.Error = err.Error() } return resp } +// isProductionEnvironment 检查是否为生产环境 +func isProductionEnvironment() bool { + env := os.Getenv("ENVIRONMENT") + return env == "production" || env == "prod" +} + // NewPaginationResponse 创建分页响应 func NewPaginationResponse(data interface{}, total int64, page, perPage int) *PaginationResponse { return &PaginationResponse{ diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go new file mode 100644 index 0000000..f72ca88 --- /dev/null +++ b/internal/repository/interfaces.go @@ -0,0 +1,85 @@ +package repository + +import ( + "carrotskin/internal/model" +) + +// UserRepository 用户仓储接口 +type UserRepository interface { + Create(user *model.User) error + FindByID(id int64) (*model.User, error) + FindByUsername(username string) (*model.User, error) + FindByEmail(email string) (*model.User, error) + Update(user *model.User) error + UpdateFields(id int64, fields map[string]interface{}) error + Delete(id int64) error + CreateLoginLog(log *model.UserLoginLog) error + CreatePointLog(log *model.UserPointLog) error + UpdatePoints(userID int64, amount int, changeType, reason string) error +} + +// ProfileRepository 档案仓储接口 +type ProfileRepository interface { + Create(profile *model.Profile) error + FindByUUID(uuid string) (*model.Profile, error) + FindByName(name string) (*model.Profile, error) + FindByUserID(userID int64) ([]*model.Profile, error) + Update(profile *model.Profile) error + UpdateFields(uuid string, updates map[string]interface{}) error + Delete(uuid string) error + CountByUserID(userID int64) (int64, error) + SetActive(uuid string, userID int64) error + UpdateLastUsedAt(uuid string) error + GetByNames(names []string) ([]*model.Profile, error) + GetKeyPair(profileId string) (*model.KeyPair, error) + UpdateKeyPair(profileId string, keyPair *model.KeyPair) error +} + +// TextureRepository 材质仓储接口 +type TextureRepository interface { + Create(texture *model.Texture) error + FindByID(id int64) (*model.Texture, error) + FindByHash(hash string) (*model.Texture, error) + FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) + Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) + Update(texture *model.Texture) error + UpdateFields(id int64, fields map[string]interface{}) error + Delete(id int64) error + IncrementDownloadCount(id int64) error + IncrementFavoriteCount(id int64) error + DecrementFavoriteCount(id int64) error + CreateDownloadLog(log *model.TextureDownloadLog) error + IsFavorited(userID, textureID int64) (bool, error) + AddFavorite(userID, textureID int64) error + RemoveFavorite(userID, textureID int64) error + GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) + CountByUploaderID(uploaderID int64) (int64, error) +} + +// TokenRepository 令牌仓储接口 +type TokenRepository interface { + Create(token *model.Token) error + FindByAccessToken(accessToken string) (*model.Token, error) + GetByUserID(userId int64) ([]*model.Token, error) + GetUUIDByAccessToken(accessToken string) (string, error) + GetUserIDByAccessToken(accessToken string) (int64, error) + DeleteByAccessToken(accessToken string) error + DeleteByUserID(userId int64) error + BatchDelete(accessTokens []string) (int64, error) +} + +// SystemConfigRepository 系统配置仓储接口 +type SystemConfigRepository interface { + GetByKey(key string) (*model.SystemConfig, error) + GetPublic() ([]model.SystemConfig, error) + GetAll() ([]model.SystemConfig, error) + Update(config *model.SystemConfig) error + UpdateValue(key, value string) error +} + +// YggdrasilRepository Yggdrasil仓储接口 +type YggdrasilRepository interface { + GetPasswordByID(id int64) (string, error) + ResetPassword(id int64, password string) error +} + diff --git a/internal/repository/profile_repository_impl.go b/internal/repository/profile_repository_impl.go new file mode 100644 index 0000000..ebe3fdb --- /dev/null +++ b/internal/repository/profile_repository_impl.go @@ -0,0 +1,149 @@ +package repository + +import ( + "carrotskin/internal/model" + "context" + "errors" + "fmt" + + "gorm.io/gorm" +) + +// profileRepositoryImpl ProfileRepository的实现 +type profileRepositoryImpl struct { + db *gorm.DB +} + +// NewProfileRepository 创建ProfileRepository实例 +func NewProfileRepository(db *gorm.DB) ProfileRepository { + return &profileRepositoryImpl{db: db} +} + +func (r *profileRepositoryImpl) Create(profile *model.Profile) error { + return r.db.Create(profile).Error +} + +func (r *profileRepositoryImpl) FindByUUID(uuid string) (*model.Profile, error) { + var profile model.Profile + err := r.db.Where("uuid = ?", uuid). + Preload("Skin"). + Preload("Cape"). + First(&profile).Error + if err != nil { + return nil, err + } + return &profile, nil +} + +func (r *profileRepositoryImpl) FindByName(name string) (*model.Profile, error) { + var profile model.Profile + err := r.db.Where("name = ?", name).First(&profile).Error + if err != nil { + return nil, err + } + return &profile, nil +} + +func (r *profileRepositoryImpl) FindByUserID(userID int64) ([]*model.Profile, error) { + var profiles []*model.Profile + err := r.db.Where("user_id = ?", userID). + Preload("Skin"). + Preload("Cape"). + Order("created_at DESC"). + Find(&profiles).Error + return profiles, err +} + +func (r *profileRepositoryImpl) Update(profile *model.Profile) error { + return r.db.Save(profile).Error +} + +func (r *profileRepositoryImpl) UpdateFields(uuid string, updates map[string]interface{}) error { + return r.db.Model(&model.Profile{}). + Where("uuid = ?", uuid). + Updates(updates).Error +} + +func (r *profileRepositoryImpl) Delete(uuid string) error { + return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error +} + +func (r *profileRepositoryImpl) CountByUserID(userID int64) (int64, error) { + var count int64 + err := r.db.Model(&model.Profile{}). + Where("user_id = ?", userID). + Count(&count).Error + return count, err +} + +func (r *profileRepositoryImpl) SetActive(uuid string, userID int64) error { + return r.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&model.Profile{}). + Where("user_id = ?", userID). + Update("is_active", false).Error; err != nil { + return err + } + + return tx.Model(&model.Profile{}). + Where("uuid = ? AND user_id = ?", uuid, userID). + Update("is_active", true).Error + }) +} + +func (r *profileRepositoryImpl) UpdateLastUsedAt(uuid string) error { + return r.db.Model(&model.Profile{}). + Where("uuid = ?", uuid). + Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error +} + +func (r *profileRepositoryImpl) GetByNames(names []string) ([]*model.Profile, error) { + var profiles []*model.Profile + err := r.db.Where("name in (?)", names).Find(&profiles).Error + return profiles, err +} + +func (r *profileRepositoryImpl) GetKeyPair(profileId string) (*model.KeyPair, error) { + if profileId == "" { + return nil, errors.New("参数不能为空") + } + + var profile model.Profile + result := r.db.WithContext(context.Background()). + Select("key_pair"). + Where("id = ?", profileId). + First(&profile) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, errors.New("key pair未找到") + } + return nil, fmt.Errorf("获取key pair失败: %w", result.Error) + } + + return &model.KeyPair{}, nil +} + +func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { + if profileId == "" { + return errors.New("profileId 不能为空") + } + if keyPair == nil { + return errors.New("keyPair 不能为 nil") + } + + return r.db.Transaction(func(tx *gorm.DB) error { + result := tx.WithContext(context.Background()). + Table("profiles"). + Where("id = ?", profileId). + UpdateColumns(map[string]interface{}{ + "private_key": keyPair.PrivateKey, + "public_key": keyPair.PublicKey, + }) + + if result.Error != nil { + return fmt.Errorf("更新 keyPair 失败: %w", result.Error) + } + return nil + }) +} + diff --git a/internal/repository/system_config_repository_impl.go b/internal/repository/system_config_repository_impl.go new file mode 100644 index 0000000..2bb5844 --- /dev/null +++ b/internal/repository/system_config_repository_impl.go @@ -0,0 +1,44 @@ +package repository + +import ( + "carrotskin/internal/model" + + "gorm.io/gorm" +) + +// systemConfigRepositoryImpl SystemConfigRepository的实现 +type systemConfigRepositoryImpl struct { + db *gorm.DB +} + +// NewSystemConfigRepository 创建SystemConfigRepository实例 +func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository { + return &systemConfigRepositoryImpl{db: db} +} + +func (r *systemConfigRepositoryImpl) GetByKey(key string) (*model.SystemConfig, error) { + var config model.SystemConfig + err := r.db.Where("key = ?", key).First(&config).Error + return handleNotFoundResult(&config, err) +} + +func (r *systemConfigRepositoryImpl) GetPublic() ([]model.SystemConfig, error) { + var configs []model.SystemConfig + err := r.db.Where("is_public = ?", true).Find(&configs).Error + return configs, err +} + +func (r *systemConfigRepositoryImpl) GetAll() ([]model.SystemConfig, error) { + var configs []model.SystemConfig + err := r.db.Find(&configs).Error + return configs, err +} + +func (r *systemConfigRepositoryImpl) Update(config *model.SystemConfig) error { + return r.db.Save(config).Error +} + +func (r *systemConfigRepositoryImpl) UpdateValue(key, value string) error { + return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error +} + diff --git a/internal/repository/texture_repository_impl.go b/internal/repository/texture_repository_impl.go new file mode 100644 index 0000000..c6a2971 --- /dev/null +++ b/internal/repository/texture_repository_impl.go @@ -0,0 +1,175 @@ +package repository + +import ( + "carrotskin/internal/model" + + "gorm.io/gorm" +) + +// textureRepositoryImpl TextureRepository的实现 +type textureRepositoryImpl struct { + db *gorm.DB +} + +// NewTextureRepository 创建TextureRepository实例 +func NewTextureRepository(db *gorm.DB) TextureRepository { + return &textureRepositoryImpl{db: db} +} + +func (r *textureRepositoryImpl) Create(texture *model.Texture) error { + return r.db.Create(texture).Error +} + +func (r *textureRepositoryImpl) FindByID(id int64) (*model.Texture, error) { + var texture model.Texture + err := r.db.Preload("Uploader").First(&texture, id).Error + return handleNotFoundResult(&texture, err) +} + +func (r *textureRepositoryImpl) FindByHash(hash string) (*model.Texture, error) { + var texture model.Texture + err := r.db.Where("hash = ?", hash).First(&texture).Error + return handleNotFoundResult(&texture, err) +} + +func (r *textureRepositoryImpl) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + var textures []*model.Texture + var total int64 + + query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Scopes(Paginate(page, pageSize)). + Preload("Uploader"). + Order("created_at DESC"). + Find(&textures).Error + + if err != nil { + return nil, 0, err + } + + return textures, total, nil +} + +func (r *textureRepositoryImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + var textures []*model.Texture + var total int64 + + query := r.db.Model(&model.Texture{}).Where("status = 1") + + if publicOnly { + query = query.Where("is_public = ?", true) + } + if textureType != "" { + query = query.Where("type = ?", textureType) + } + if keyword != "" { + query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%") + } + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Scopes(Paginate(page, pageSize)). + Preload("Uploader"). + Order("created_at DESC"). + Find(&textures).Error + + if err != nil { + return nil, 0, err + } + + return textures, total, nil +} + +func (r *textureRepositoryImpl) Update(texture *model.Texture) error { + return r.db.Save(texture).Error +} + +func (r *textureRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error +} + +func (r *textureRepositoryImpl) Delete(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error +} + +func (r *textureRepositoryImpl) IncrementDownloadCount(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id). + UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error +} + +func (r *textureRepositoryImpl) IncrementFavoriteCount(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id). + UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error +} + +func (r *textureRepositoryImpl) DecrementFavoriteCount(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id). + UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error +} + +func (r *textureRepositoryImpl) CreateDownloadLog(log *model.TextureDownloadLog) error { + return r.db.Create(log).Error +} + +func (r *textureRepositoryImpl) IsFavorited(userID, textureID int64) (bool, error) { + var count int64 + err := r.db.Model(&model.UserTextureFavorite{}). + Where("user_id = ? AND texture_id = ?", userID, textureID). + Count(&count).Error + return count > 0, err +} + +func (r *textureRepositoryImpl) AddFavorite(userID, textureID int64) error { + favorite := &model.UserTextureFavorite{ + UserID: userID, + TextureID: textureID, + } + return r.db.Create(favorite).Error +} + +func (r *textureRepositoryImpl) RemoveFavorite(userID, textureID int64) error { + return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID). + Delete(&model.UserTextureFavorite{}).Error +} + +func (r *textureRepositoryImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + var textures []*model.Texture + var total int64 + + subQuery := r.db.Model(&model.UserTextureFavorite{}). + Select("texture_id"). + Where("user_id = ?", userID) + + query := r.db.Model(&model.Texture{}). + Where("id IN (?) AND status = 1", subQuery) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Scopes(Paginate(page, pageSize)). + Preload("Uploader"). + Order("created_at DESC"). + Find(&textures).Error + + if err != nil { + return nil, 0, err + } + + return textures, total, nil +} + +func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, error) { + var count int64 + err := r.db.Model(&model.Texture{}). + Where("uploader_id = ? AND status != -1", uploaderID). + Count(&count).Error + return count, err +} + diff --git a/internal/repository/token_repository_impl.go b/internal/repository/token_repository_impl.go new file mode 100644 index 0000000..e4c94e1 --- /dev/null +++ b/internal/repository/token_repository_impl.go @@ -0,0 +1,71 @@ +package repository + +import ( + "carrotskin/internal/model" + + "gorm.io/gorm" +) + +// tokenRepositoryImpl TokenRepository的实现 +type tokenRepositoryImpl struct { + db *gorm.DB +} + +// NewTokenRepository 创建TokenRepository实例 +func NewTokenRepository(db *gorm.DB) TokenRepository { + return &tokenRepositoryImpl{db: db} +} + +func (r *tokenRepositoryImpl) Create(token *model.Token) error { + return r.db.Create(token).Error +} + +func (r *tokenRepositoryImpl) FindByAccessToken(accessToken string) (*model.Token, error) { + var token model.Token + err := r.db.Where("access_token = ?", accessToken).First(&token).Error + if err != nil { + return nil, err + } + return &token, nil +} + +func (r *tokenRepositoryImpl) GetByUserID(userId int64) ([]*model.Token, error) { + var tokens []*model.Token + err := r.db.Where("user_id = ?", userId).Find(&tokens).Error + return tokens, err +} + +func (r *tokenRepositoryImpl) GetUUIDByAccessToken(accessToken string) (string, error) { + var token model.Token + err := r.db.Where("access_token = ?", accessToken).First(&token).Error + if err != nil { + return "", err + } + return token.ProfileId, nil +} + +func (r *tokenRepositoryImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { + var token model.Token + err := r.db.Where("access_token = ?", accessToken).First(&token).Error + if err != nil { + return 0, err + } + return token.UserID, nil +} + +func (r *tokenRepositoryImpl) DeleteByAccessToken(accessToken string) error { + return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error +} + +func (r *tokenRepositoryImpl) DeleteByUserID(userId int64) error { + return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error +} + +func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) { + if len(accessTokens) == 0 { + return 0, nil + } + result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{}) + return result.RowsAffected, result.Error +} + diff --git a/internal/repository/user_repository_impl.go b/internal/repository/user_repository_impl.go new file mode 100644 index 0000000..57ec4c8 --- /dev/null +++ b/internal/repository/user_repository_impl.go @@ -0,0 +1,103 @@ +package repository + +import ( + "carrotskin/internal/model" + "errors" + + "gorm.io/gorm" +) + +// userRepositoryImpl UserRepository的实现 +type userRepositoryImpl struct { + db *gorm.DB +} + +// NewUserRepository 创建UserRepository实例 +func NewUserRepository(db *gorm.DB) UserRepository { + return &userRepositoryImpl{db: db} +} + +func (r *userRepositoryImpl) Create(user *model.User) error { + return r.db.Create(user).Error +} + +func (r *userRepositoryImpl) FindByID(id int64) (*model.User, error) { + var user model.User + err := r.db.Where("id = ? AND status != -1", id).First(&user).Error + return handleNotFoundResult(&user, err) +} + +func (r *userRepositoryImpl) FindByUsername(username string) (*model.User, error) { + var user model.User + err := r.db.Where("username = ? AND status != -1", username).First(&user).Error + return handleNotFoundResult(&user, err) +} + +func (r *userRepositoryImpl) FindByEmail(email string) (*model.User, error) { + var user model.User + err := r.db.Where("email = ? AND status != -1", email).First(&user).Error + return handleNotFoundResult(&user, err) +} + +func (r *userRepositoryImpl) Update(user *model.User) error { + return r.db.Save(user).Error +} + +func (r *userRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error { + return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error +} + +func (r *userRepositoryImpl) Delete(id int64) error { + return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error +} + +func (r *userRepositoryImpl) CreateLoginLog(log *model.UserLoginLog) error { + return r.db.Create(log).Error +} + +func (r *userRepositoryImpl) CreatePointLog(log *model.UserPointLog) error { + return r.db.Create(log).Error +} + +func (r *userRepositoryImpl) UpdatePoints(userID int64, amount int, changeType, reason string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + var user model.User + if err := tx.Where("id = ?", userID).First(&user).Error; err != nil { + return err + } + + balanceBefore := user.Points + balanceAfter := balanceBefore + amount + + if balanceAfter < 0 { + return errors.New("积分不足") + } + + if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil { + return err + } + + log := &model.UserPointLog{ + UserID: userID, + ChangeType: changeType, + Amount: amount, + BalanceBefore: balanceBefore, + BalanceAfter: balanceAfter, + Reason: reason, + } + + return tx.Create(log).Error + }) +} + +// handleNotFoundResult 处理记录未找到的情况 +func handleNotFoundResult[T any](result *T, err error) (*T, error) { + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return result, nil +} + diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 4a98ca7..249a341 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -4,10 +4,12 @@ import ( "carrotskin/internal/model" "carrotskin/internal/repository" "carrotskin/pkg/auth" + "carrotskin/pkg/config" "carrotskin/pkg/redis" "context" "errors" "fmt" + "net/url" "strings" "time" ) @@ -286,24 +288,69 @@ func ValidateAvatarURL(avatarURL string) error { return nil } - // 允许的域名列表 - allowedDomains := []string{ - "rustfs.example.com", - "localhost", - "127.0.0.1", - } - - for _, domain := range allowedDomains { - if strings.Contains(avatarURL, domain) { - return nil - } - } - + // 允许相对路径 if strings.HasPrefix(avatarURL, "/") { return nil } - return errors.New("头像URL不在允许的域名列表中") + return ValidateURLDomain(avatarURL) +} + +// ValidateURLDomain 验证URL的域名是否在允许列表中 +func ValidateURLDomain(rawURL string) error { + // 解析URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return errors.New("无效的URL格式") + } + + // 必须是HTTP或HTTPS协议 + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return errors.New("URL必须使用http或https协议") + } + + // 获取主机名(不包含端口) + host := parsedURL.Hostname() + if host == "" { + return errors.New("URL缺少主机名") + } + + // 从配置获取允许的域名列表 + cfg, err := config.GetConfig() + if err != nil { + // 如果配置获取失败,使用默认的安全域名列表 + allowedDomains := []string{"localhost", "127.0.0.1"} + return checkDomainAllowed(host, allowedDomains) + } + + return checkDomainAllowed(host, cfg.Security.AllowedDomains) +} + +// checkDomainAllowed 检查域名是否在允许列表中 +func checkDomainAllowed(host string, allowedDomains []string) error { + host = strings.ToLower(host) + + for _, allowed := range allowedDomains { + allowed = strings.ToLower(strings.TrimSpace(allowed)) + if allowed == "" { + continue + } + + // 精确匹配 + if host == allowed { + return nil + } + + // 支持通配符子域名匹配 (如 *.example.com) + if strings.HasPrefix(allowed, "*.") { + suffix := allowed[1:] // 移除 "*",保留 ".example.com" + if strings.HasSuffix(host, suffix) { + return nil + } + } + } + + return errors.New("URL域名不在允许的列表中") } // GetUserByEmail 根据邮箱获取用户 diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 275ee86..b509c70 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -55,6 +55,10 @@ func (j *JWTService) GenerateToken(userID int64, username, role string) (string, // ValidateToken 验证JWT Token func (j *JWTService) ValidateToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + // 验证签名算法,防止algorithm confusion攻击 + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("不支持的签名算法") + } return []byte(j.secretKey), nil }) diff --git a/pkg/config/config.go b/pkg/config/config.go index 919e4c3..83c5188 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strconv" + "strings" "time" "github.com/joho/godotenv" @@ -22,6 +23,7 @@ type Config struct { Log LogConfig `mapstructure:"log"` Upload UploadConfig `mapstructure:"upload"` Email EmailConfig `mapstructure:"email"` + Security SecurityConfig `mapstructure:"security"` } // ServerConfig 服务器配置 @@ -107,6 +109,12 @@ type EmailConfig struct { FromName string `mapstructure:"from_name"` } +// SecurityConfig 安全配置 +type SecurityConfig struct { + AllowedOrigins []string `mapstructure:"allowed_origins"` // 允许的CORS来源 + AllowedDomains []string `mapstructure:"allowed_domains"` // 允许的头像/材质URL域名 +} + // Load 加载配置 - 完全从环境变量加载,不依赖YAML文件 func Load() (*Config, error) { // 加载.env文件(如果存在) @@ -160,7 +168,7 @@ func setDefaults() { // RustFS默认配置 viper.SetDefault("rustfs.endpoint", "127.0.0.1:9000") - viper.SetDefault("rustfs.public_url", "") // 为空时使用 endpoint 构建 URL + viper.SetDefault("rustfs.public_url", "") // 为空时使用 endpoint 构建 URL viper.SetDefault("rustfs.use_ssl", false) // JWT默认配置 @@ -188,6 +196,10 @@ func setDefaults() { // 邮件默认配置 viper.SetDefault("email.enabled", false) viper.SetDefault("email.smtp_port", 587) + + // 安全默认配置 + viper.SetDefault("security.allowed_origins", []string{"*"}) + viper.SetDefault("security.allowed_domains", []string{"localhost", "127.0.0.1"}) } // setupEnvMappings 设置环境变量映射 @@ -310,6 +322,15 @@ func overrideFromEnv(config *Config) { if env := os.Getenv("ENVIRONMENT"); env != "" { config.Environment = env } + + // 处理安全配置 + if allowedOrigins := os.Getenv("SECURITY_ALLOWED_ORIGINS"); allowedOrigins != "" { + config.Security.AllowedOrigins = strings.Split(allowedOrigins, ",") + } + + if allowedDomains := os.Getenv("SECURITY_ALLOWED_DOMAINS"); allowedDomains != "" { + config.Security.AllowedDomains = strings.Split(allowedDomains, ",") + } } // IsTestEnvironment 判断是否为测试环境 diff --git a/pkg/config/manager.go b/pkg/config/manager.go index 5c2d631..1ded256 100644 --- a/pkg/config/manager.go +++ b/pkg/config/manager.go @@ -62,6 +62,3 @@ func MustGetRustFSConfig() *RustFSConfig { return cfg } - - - From ffdc3e3e6b544eeec68718f6acecbe6007600e40 Mon Sep 17 00:00:00 2001 From: lan Date: Tue, 2 Dec 2025 17:46:00 +0800 Subject: [PATCH 2/5] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E6=B3=A8=E5=85=A5=E6=94=B9=E9=80=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 完成所有Handler的依赖注入改造: - AuthHandler: 认证相关功能 - UserHandler: 用户管理功能 - TextureHandler: 材质管理功能 - ProfileHandler: 档案管理功能 - CaptchaHandler: 验证码功能 - YggdrasilHandler: Yggdrasil API功能 新增错误类型定义: - internal/errors/errors.go: 统一的错误类型和工厂函数 更新main.go: - 使用container.NewContainer创建依赖容器 - 使用handler.RegisterRoutesWithDI注册路由 代码遵循Go最佳实践: - 依赖通过构造函数注入 - Handler通过结构体方法实现 - 统一的错误处理模式 - 清晰的分层架构 --- cmd/server/main.go | 17 +- cmd/server/main_di_example.go.example | 146 ------ internal/container/container.go | 11 +- internal/errors/errors.go | 127 +++++ internal/handler/captcha_handler_di.go | 108 +++++ internal/handler/profile_handler_di.go | 247 ++++++++++ internal/handler/routes_di.go | 84 ++-- internal/handler/yggdrasil_handler_di.go | 454 ++++++++++++++++++ internal/repository/interfaces.go | 1 - .../repository/profile_repository_impl.go | 1 - .../repository/texture_repository_impl.go | 1 - internal/repository/token_repository_impl.go | 1 - internal/repository/user_repository_impl.go | 1 - 13 files changed, 998 insertions(+), 201 deletions(-) delete mode 100644 cmd/server/main_di_example.go.example create mode 100644 internal/errors/errors.go create mode 100644 internal/handler/captcha_handler_di.go create mode 100644 internal/handler/profile_handler_di.go create mode 100644 internal/handler/yggdrasil_handler_di.go diff --git a/cmd/server/main.go b/cmd/server/main.go index fb68942..ea29746 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,6 +10,7 @@ import ( "time" _ "carrotskin/docs" // Swagger文档 + "carrotskin/internal/container" "carrotskin/internal/handler" "carrotskin/internal/middleware" "carrotskin/pkg/auth" @@ -66,10 +67,11 @@ func main() { defer redis.MustGetClient().Close() // 初始化对象存储 (RustFS - S3兼容) - // 如果对象存储未配置或连接失败,记录警告但不退出(某些功能可能不可用) + var storageClient *storage.StorageClient if err := storage.Init(cfg.RustFS); err != nil { loggerInstance.Warn("对象存储连接失败,某些功能可能不可用", zap.Error(err)) } else { + storageClient = storage.MustGetClient() loggerInstance.Info("对象存储连接成功") } @@ -78,6 +80,15 @@ func main() { loggerInstance.Fatal("邮件服务初始化失败", zap.Error(err)) } + // 创建依赖注入容器 + c := container.NewContainer( + database.MustGetDB(), + redis.MustGetClient(), + loggerInstance, + auth.MustGetJWTService(), + storageClient, + ) + // 设置Gin模式 if cfg.Server.Mode == "production" { gin.SetMode(gin.ReleaseMode) @@ -91,8 +102,8 @@ func main() { router.Use(middleware.Recovery(loggerInstance)) router.Use(middleware.CORS()) - // 注册路由 - handler.RegisterRoutes(router) + // 使用依赖注入方式注册路由 + handler.RegisterRoutesWithDI(router, c) // 创建HTTP服务器 srv := &http.Server{ diff --git a/cmd/server/main_di_example.go.example b/cmd/server/main_di_example.go.example deleted file mode 100644 index d9168ef..0000000 --- a/cmd/server/main_di_example.go.example +++ /dev/null @@ -1,146 +0,0 @@ -// +build ignore -// 此文件是依赖注入版本的main.go示例 -// 可以参考此文件改造原有的main.go - -package main - -import ( - "context" - "log" - "net/http" - "os" - "os/signal" - "syscall" - "time" - - _ "carrotskin/docs" // Swagger文档 - "carrotskin/internal/container" - "carrotskin/internal/handler" - "carrotskin/internal/middleware" - "carrotskin/pkg/auth" - "carrotskin/pkg/config" - "carrotskin/pkg/database" - "carrotskin/pkg/email" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" - "carrotskin/pkg/storage" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -func main() { - // 初始化配置 - if err := config.Init(); err != nil { - log.Fatalf("配置加载失败: %v", err) - } - cfg := config.MustGetConfig() - - // 初始化日志 - if err := logger.Init(cfg.Log); err != nil { - log.Fatalf("日志初始化失败: %v", err) - } - loggerInstance := logger.MustGetLogger() - defer loggerInstance.Sync() - - // 初始化数据库 - if err := database.Init(cfg.Database, loggerInstance); err != nil { - loggerInstance.Fatal("数据库初始化失败", zap.Error(err)) - } - defer database.Close() - - // 执行数据库迁移 - if err := database.AutoMigrate(loggerInstance); err != nil { - loggerInstance.Fatal("数据库迁移失败", zap.Error(err)) - } - - // 初始化种子数据 - if err := database.Seed(loggerInstance); err != nil { - loggerInstance.Fatal("种子数据初始化失败", zap.Error(err)) - } - - // 初始化JWT服务 - if err := auth.Init(cfg.JWT); err != nil { - loggerInstance.Fatal("JWT服务初始化失败", zap.Error(err)) - } - - // 初始化Redis - if err := redis.Init(cfg.Redis, loggerInstance); err != nil { - loggerInstance.Fatal("Redis连接失败", zap.Error(err)) - } - defer redis.MustGetClient().Close() - - // 初始化对象存储 - var storageClient *storage.StorageClient - if err := storage.Init(cfg.RustFS); err != nil { - loggerInstance.Warn("对象存储连接失败,某些功能可能不可用", zap.Error(err)) - } else { - storageClient = storage.MustGetClient() - loggerInstance.Info("对象存储连接成功") - } - - // 初始化邮件服务 - if err := email.Init(cfg.Email, loggerInstance); err != nil { - loggerInstance.Fatal("邮件服务初始化失败", zap.Error(err)) - } - - // ============ 依赖注入改动部分 ============ - // 创建依赖注入容器 - c := container.NewContainer( - database.MustGetDB(), - redis.MustGetClient(), - loggerInstance, - auth.MustGetJWTService(), - storageClient, - ) - - // 设置Gin模式 - if cfg.Server.Mode == "production" { - gin.SetMode(gin.ReleaseMode) - } - - // 创建路由 - router := gin.New() - - // 添加中间件 - router.Use(middleware.Logger(loggerInstance)) - router.Use(middleware.Recovery(loggerInstance)) - router.Use(middleware.CORS()) - - // 使用依赖注入方式注册路由 - handler.RegisterRoutesWithDI(router, c) - // ============ 依赖注入改动结束 ============ - - // 创建HTTP服务器 - srv := &http.Server{ - Addr: cfg.Server.Port, - Handler: router, - ReadTimeout: cfg.Server.ReadTimeout, - WriteTimeout: cfg.Server.WriteTimeout, - } - - // 启动服务器 - go func() { - loggerInstance.Info("服务器启动", zap.String("port", cfg.Server.Port)) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - loggerInstance.Fatal("服务器启动失败", zap.Error(err)) - } - }() - - // 等待中断信号优雅关闭 - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - <-quit - loggerInstance.Info("正在关闭服务器...") - - // 设置关闭超时 - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - if err := srv.Shutdown(ctx); err != nil { - loggerInstance.Fatal("服务器强制关闭", zap.Error(err)) - } - - loggerInstance.Info("服务器已关闭") -} - diff --git a/internal/container/container.go b/internal/container/container.go index 230e68f..cde146e 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -21,11 +21,11 @@ type Container struct { Storage *storage.StorageClient // Repository层 - UserRepo repository.UserRepository - ProfileRepo repository.ProfileRepository - TextureRepo repository.TextureRepository - TokenRepo repository.TokenRepository - ConfigRepo repository.SystemConfigRepository + UserRepo repository.UserRepository + ProfileRepo repository.ProfileRepository + TextureRepo repository.TextureRepository + TokenRepo repository.TokenRepository + ConfigRepo repository.SystemConfigRepository } // NewContainer 创建依赖容器 @@ -135,4 +135,3 @@ func WithConfigRepo(repo repository.SystemConfigRepository) Option { c.ConfigRepo = repo } } - diff --git a/internal/errors/errors.go b/internal/errors/errors.go new file mode 100644 index 0000000..7cf4e41 --- /dev/null +++ b/internal/errors/errors.go @@ -0,0 +1,127 @@ +// Package errors 定义应用程序的错误类型 +package errors + +import ( + "errors" + "fmt" +) + +// 预定义错误 +var ( + // 用户相关错误 + ErrUserNotFound = errors.New("用户不存在") + ErrUserAlreadyExists = errors.New("用户已存在") + ErrEmailAlreadyExists = errors.New("邮箱已被注册") + ErrInvalidPassword = errors.New("密码错误") + ErrAccountDisabled = errors.New("账号已被禁用") + + // 认证相关错误 + ErrUnauthorized = errors.New("未授权") + ErrInvalidToken = errors.New("无效的令牌") + ErrTokenExpired = errors.New("令牌已过期") + ErrInvalidSignature = errors.New("签名验证失败") + + // 档案相关错误 + ErrProfileNotFound = errors.New("档案不存在") + ErrProfileNameExists = errors.New("角色名已被使用") + ErrProfileLimitReached = errors.New("已达档案数量上限") + ErrProfileNoPermission = errors.New("无权操作此档案") + + // 材质相关错误 + ErrTextureNotFound = errors.New("材质不存在") + ErrTextureExists = errors.New("该材质已存在") + ErrTextureLimitReached = errors.New("已达材质数量上限") + ErrTextureNoPermission = errors.New("无权操作此材质") + ErrInvalidTextureType = errors.New("无效的材质类型") + + // 验证码相关错误 + ErrInvalidVerificationCode = errors.New("验证码错误或已过期") + ErrTooManyAttempts = errors.New("尝试次数过多") + ErrSendTooFrequent = errors.New("发送过于频繁") + + // URL验证相关错误 + ErrInvalidURL = errors.New("无效的URL格式") + ErrDomainNotAllowed = errors.New("URL域名不在允许的列表中") + + // 存储相关错误 + ErrStorageUnavailable = errors.New("存储服务不可用") + ErrUploadFailed = errors.New("上传失败") + + // 通用错误 + ErrBadRequest = errors.New("请求参数错误") + ErrInternalServer = errors.New("服务器内部错误") + ErrNotFound = errors.New("资源不存在") + ErrForbidden = errors.New("权限不足") +) + +// AppError 应用错误类型,包含错误码和消息 +type AppError struct { + Code int // HTTP状态码 + Message string // 用户可见的错误消息 + Err error // 原始错误(用于日志) +} + +// Error 实现error接口 +func (e *AppError) Error() string { + if e.Err != nil { + return fmt.Sprintf("%s: %v", e.Message, e.Err) + } + return e.Message +} + +// Unwrap 支持errors.Is和errors.As +func (e *AppError) Unwrap() error { + return e.Err +} + +// NewAppError 创建新的应用错误 +func NewAppError(code int, message string, err error) *AppError { + return &AppError{ + Code: code, + Message: message, + Err: err, + } +} + +// NewBadRequest 创建400错误 +func NewBadRequest(message string, err error) *AppError { + return NewAppError(400, message, err) +} + +// NewUnauthorized 创建401错误 +func NewUnauthorized(message string) *AppError { + return NewAppError(401, message, nil) +} + +// NewForbidden 创建403错误 +func NewForbidden(message string) *AppError { + return NewAppError(403, message, nil) +} + +// NewNotFound 创建404错误 +func NewNotFound(message string) *AppError { + return NewAppError(404, message, nil) +} + +// NewInternalError 创建500错误 +func NewInternalError(message string, err error) *AppError { + return NewAppError(500, message, err) +} + +// Is 检查错误是否匹配 +func Is(err, target error) bool { + return errors.Is(err, target) +} + +// As 尝试将错误转换为指定类型 +func As(err error, target interface{}) bool { + return errors.As(err, target) +} + +// Wrap 包装错误 +func Wrap(err error, message string) error { + if err == nil { + return nil + } + return fmt.Errorf("%s: %w", message, err) +} diff --git a/internal/handler/captcha_handler_di.go b/internal/handler/captcha_handler_di.go new file mode 100644 index 0000000..8078aee --- /dev/null +++ b/internal/handler/captcha_handler_di.go @@ -0,0 +1,108 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/service" + "net/http" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// CaptchaHandler 验证码处理器 +type CaptchaHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewCaptchaHandler 创建CaptchaHandler实例 +func NewCaptchaHandler(c *container.Container) *CaptchaHandler { + return &CaptchaHandler{ + container: c, + logger: c.Logger, + } +} + +// CaptchaVerifyRequest 验证码验证请求 +type CaptchaVerifyRequest struct { + CaptchaID string `json:"captchaId" binding:"required"` + Dx int `json:"dx" binding:"required"` +} + +// Generate 生成验证码 +// @Summary 生成滑动验证码 +// @Description 生成滑动验证码图片 +// @Tags captcha +// @Accept json +// @Produce json +// @Success 200 {object} map[string]interface{} "生成成功" +// @Failure 500 {object} map[string]interface{} "生成失败" +// @Router /api/v1/captcha/generate [get] +func (h *CaptchaHandler) Generate(c *gin.Context) { + masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), h.container.Redis) + if err != nil { + h.logger.Error("生成验证码失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "msg": "生成验证码失败", + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 200, + "data": gin.H{ + "masterImage": masterImg, + "tileImage": tileImg, + "captchaId": captchaID, + "y": y, + }, + }) +} + +// Verify 验证验证码 +// @Summary 验证滑动验证码 +// @Description 验证用户滑动的偏移量是否正确 +// @Tags captcha +// @Accept json +// @Produce json +// @Param request body CaptchaVerifyRequest true "验证请求" +// @Success 200 {object} map[string]interface{} "验证结果" +// @Failure 400 {object} map[string]interface{} "参数错误" +// @Router /api/v1/captcha/verify [post] +func (h *CaptchaHandler) Verify(c *gin.Context) { + var req CaptchaVerifyRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "msg": "参数错误: " + err.Error(), + }) + return + } + + valid, err := service.VerifyCaptchaData(c.Request.Context(), h.container.Redis, req.Dx, req.CaptchaID) + if err != nil { + h.logger.Error("验证码验证失败", + zap.String("captcha_id", req.CaptchaID), + zap.Error(err), + ) + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "msg": "验证失败", + }) + return + } + + if valid { + c.JSON(http.StatusOK, gin.H{ + "code": 200, + "msg": "验证成功", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "code": 400, + "msg": "验证失败,请重试", + }) + } +} + diff --git a/internal/handler/profile_handler_di.go b/internal/handler/profile_handler_di.go new file mode 100644 index 0000000..6fdbeb9 --- /dev/null +++ b/internal/handler/profile_handler_di.go @@ -0,0 +1,247 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/service" + "carrotskin/internal/types" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// ProfileHandler 档案处理器 +type ProfileHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewProfileHandler 创建ProfileHandler实例 +func NewProfileHandler(c *container.Container) *ProfileHandler { + return &ProfileHandler{ + container: c, + logger: c.Logger, + } +} + +// Create 创建档案 +// @Summary 创建Minecraft档案 +// @Description 创建新的Minecraft角色档案,UUID由后端自动生成 +// @Tags profile +// @Accept json +// @Produce json +// @Security BearerAuth +// @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)" +// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/profile [post] +func (h *ProfileHandler) Create(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.CreateProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误: "+err.Error(), nil) + return + } + + maxProfiles := service.GetMaxProfilesPerUser() + if err := service.CheckProfileLimit(h.container.DB, userID, maxProfiles); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + + profile, err := service.CreateProfile(h.container.DB, userID, req.Name) + if err != nil { + h.logger.Error("创建档案失败", + zap.Int64("user_id", userID), + zap.String("name", req.Name), + zap.Error(err), + ) + RespondServerError(c, err.Error(), nil) + return + } + + RespondSuccess(c, ProfileToProfileInfo(profile)) +} + +// List 获取档案列表 +// @Summary 获取档案列表 +// @Description 获取当前用户的所有档案 +// @Tags profile +// @Accept json +// @Produce json +// @Security BearerAuth +// @Success 200 {object} model.Response "获取成功" +// @Router /api/v1/profile [get] +func (h *ProfileHandler) List(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + profiles, err := service.GetUserProfiles(h.container.DB, userID) + if err != nil { + h.logger.Error("获取档案列表失败", + zap.Int64("user_id", userID), + zap.Error(err), + ) + RespondServerError(c, err.Error(), nil) + return + } + + RespondSuccess(c, ProfilesToProfileInfos(profiles)) +} + +// Get 获取档案详情 +// @Summary 获取档案详情 +// @Description 根据UUID获取档案详细信息 +// @Tags profile +// @Accept json +// @Produce json +// @Param uuid path string true "档案UUID" +// @Success 200 {object} model.Response "获取成功" +// @Failure 404 {object} model.ErrorResponse "档案不存在" +// @Router /api/v1/profile/{uuid} [get] +func (h *ProfileHandler) Get(c *gin.Context) { + uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } + + profile, err := service.GetProfileByUUID(h.container.DB, uuid) + if err != nil { + h.logger.Error("获取档案失败", + zap.String("uuid", uuid), + zap.Error(err), + ) + RespondNotFound(c, err.Error()) + return + } + + RespondSuccess(c, ProfileToProfileInfo(profile)) +} + +// Update 更新档案 +// @Summary 更新档案 +// @Description 更新档案信息 +// @Tags profile +// @Accept json +// @Produce json +// @Security BearerAuth +// @Param uuid path string true "档案UUID" +// @Param request body types.UpdateProfileRequest true "更新信息" +// @Success 200 {object} model.Response "更新成功" +// @Failure 403 {object} model.ErrorResponse "无权操作" +// @Router /api/v1/profile/{uuid} [put] +func (h *ProfileHandler) Update(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } + + var req types.UpdateProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误: "+err.Error(), nil) + return + } + + var namePtr *string + if req.Name != "" { + namePtr = &req.Name + } + + profile, err := service.UpdateProfile(h.container.DB, uuid, userID, namePtr, req.SkinID, req.CapeID) + if err != nil { + h.logger.Error("更新档案失败", + zap.String("uuid", uuid), + zap.Int64("user_id", userID), + zap.Error(err), + ) + RespondWithError(c, err) + return + } + + RespondSuccess(c, ProfileToProfileInfo(profile)) +} + +// Delete 删除档案 +// @Summary 删除档案 +// @Description 删除指定的Minecraft档案 +// @Tags profile +// @Accept json +// @Produce json +// @Security BearerAuth +// @Param uuid path string true "档案UUID" +// @Success 200 {object} model.Response "删除成功" +// @Failure 403 {object} model.ErrorResponse "无权操作" +// @Router /api/v1/profile/{uuid} [delete] +func (h *ProfileHandler) Delete(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } + + if err := service.DeleteProfile(h.container.DB, uuid, userID); err != nil { + h.logger.Error("删除档案失败", + zap.String("uuid", uuid), + zap.Int64("user_id", userID), + zap.Error(err), + ) + RespondWithError(c, err) + return + } + + RespondSuccess(c, gin.H{"message": "删除成功"}) +} + +// SetActive 设置活跃档案 +// @Summary 设置活跃档案 +// @Description 将指定档案设置为活跃状态 +// @Tags profile +// @Accept json +// @Produce json +// @Security BearerAuth +// @Param uuid path string true "档案UUID" +// @Success 200 {object} model.Response "设置成功" +// @Failure 403 {object} model.ErrorResponse "无权操作" +// @Router /api/v1/profile/{uuid}/activate [post] +func (h *ProfileHandler) SetActive(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } + + if err := service.SetActiveProfile(h.container.DB, uuid, userID); err != nil { + h.logger.Error("设置活跃档案失败", + zap.String("uuid", uuid), + zap.Int64("user_id", userID), + zap.Error(err), + ) + RespondWithError(c, err) + return + } + + RespondSuccess(c, gin.H{"message": "设置成功"}) +} + diff --git a/internal/handler/routes_di.go b/internal/handler/routes_di.go index d022cf6..a6da9c8 100644 --- a/internal/handler/routes_di.go +++ b/internal/handler/routes_di.go @@ -10,20 +10,23 @@ import ( // Handlers 集中管理所有Handler type Handlers struct { - Auth *AuthHandler - User *UserHandler - Texture *TextureHandler - // Profile *ProfileHandler // 后续添加 - // Captcha *CaptchaHandler // 后续添加 - // Yggdrasil *YggdrasilHandler // 后续添加 + Auth *AuthHandler + User *UserHandler + Texture *TextureHandler + Profile *ProfileHandler + Captcha *CaptchaHandler + Yggdrasil *YggdrasilHandler } // NewHandlers 创建所有Handler实例 func NewHandlers(c *container.Container) *Handlers { return &Handlers{ - Auth: NewAuthHandler(c), - User: NewUserHandler(c), - Texture: NewTextureHandler(c), + Auth: NewAuthHandler(c), + User: NewUserHandler(c), + Texture: NewTextureHandler(c), + Profile: NewProfileHandler(c), + Captcha: NewCaptchaHandler(c), + Yggdrasil: NewYggdrasilHandler(c), } } @@ -47,14 +50,14 @@ func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { // 材质路由 registerTextureRoutes(v1, h.Texture) - // 档案路由(暂时保持原有方式) - registerProfileRoutes(v1) + // 档案路由 + registerProfileRoutesWithDI(v1, h.Profile) - // 验证码路由(暂时保持原有方式) - registerCaptchaRoutes(v1) + // 验证码路由 + registerCaptchaRoutesWithDI(v1, h.Captcha) - // Yggdrasil API路由组(暂时保持原有方式) - registerYggdrasilRoutes(v1) + // Yggdrasil API路由组 + registerYggdrasilRoutesWithDI(v1, h.Yggdrasil) // 系统路由 registerSystemRoutes(v1) @@ -115,59 +118,59 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { } } -// registerProfileRoutes 注册档案路由(保持原有方式,后续改造) -func registerProfileRoutes(v1 *gin.RouterGroup) { +// registerProfileRoutesWithDI 注册档案路由(依赖注入版本) +func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { profileGroup := v1.Group("/profile") { // 公开路由(无需认证) - profileGroup.GET("/:uuid", GetProfile) + profileGroup.GET("/:uuid", h.Get) // 需要认证的路由 profileAuth := profileGroup.Group("") profileAuth.Use(middleware.AuthMiddleware()) { - profileAuth.POST("/", CreateProfile) - profileAuth.GET("/", GetProfiles) - profileAuth.PUT("/:uuid", UpdateProfile) - profileAuth.DELETE("/:uuid", DeleteProfile) - profileAuth.POST("/:uuid/activate", SetActiveProfile) + profileAuth.POST("/", h.Create) + profileAuth.GET("/", h.List) + profileAuth.PUT("/:uuid", h.Update) + profileAuth.DELETE("/:uuid", h.Delete) + profileAuth.POST("/:uuid/activate", h.SetActive) } } } -// registerCaptchaRoutes 注册验证码路由(保持原有方式) -func registerCaptchaRoutes(v1 *gin.RouterGroup) { +// registerCaptchaRoutesWithDI 注册验证码路由(依赖注入版本) +func registerCaptchaRoutesWithDI(v1 *gin.RouterGroup, h *CaptchaHandler) { captchaGroup := v1.Group("/captcha") { - captchaGroup.GET("/generate", Generate) - captchaGroup.POST("/verify", Verify) + captchaGroup.GET("/generate", h.Generate) + captchaGroup.POST("/verify", h.Verify) } } -// registerYggdrasilRoutes 注册Yggdrasil API路由(保持原有方式) -func registerYggdrasilRoutes(v1 *gin.RouterGroup) { +// registerYggdrasilRoutesWithDI 注册Yggdrasil API路由(依赖注入版本) +func registerYggdrasilRoutesWithDI(v1 *gin.RouterGroup, h *YggdrasilHandler) { ygg := v1.Group("/yggdrasil") { - ygg.GET("", GetMetaData) - ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates) + ygg.GET("", h.GetMetaData) + ygg.POST("/minecraftservices/player/certificates", h.GetPlayerCertificates) authserver := ygg.Group("/authserver") { - authserver.POST("/authenticate", Authenticate) - authserver.POST("/validate", ValidToken) - authserver.POST("/refresh", RefreshToken) - authserver.POST("/invalidate", InvalidToken) - authserver.POST("/signout", SignOut) + authserver.POST("/authenticate", h.Authenticate) + authserver.POST("/validate", h.ValidToken) + authserver.POST("/refresh", h.RefreshToken) + authserver.POST("/invalidate", h.InvalidToken) + authserver.POST("/signout", h.SignOut) } sessionServer := ygg.Group("/sessionserver") { - sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID) - sessionServer.POST("/session/minecraft/join", JoinServer) - sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer) + sessionServer.GET("/session/minecraft/profile/:uuid", h.GetProfileByUUID) + sessionServer.POST("/session/minecraft/join", h.JoinServer) + sessionServer.GET("/session/minecraft/hasJoined", h.HasJoinedServer) } api := ygg.Group("/api") profiles := api.Group("/profiles") { - profiles.POST("/minecraft", GetProfilesByName) + profiles.POST("/minecraft", h.GetProfilesByName) } } } @@ -188,4 +191,3 @@ func registerSystemRoutes(v1 *gin.RouterGroup) { }) } } - diff --git a/internal/handler/yggdrasil_handler_di.go b/internal/handler/yggdrasil_handler_di.go new file mode 100644 index 0000000..c4fb8f3 --- /dev/null +++ b/internal/handler/yggdrasil_handler_di.go @@ -0,0 +1,454 @@ +package handler + +import ( + "bytes" + "carrotskin/internal/container" + "carrotskin/internal/model" + "carrotskin/internal/service" + "carrotskin/pkg/utils" + "io" + "net/http" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// YggdrasilHandler Yggdrasil API处理器 +type YggdrasilHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewYggdrasilHandler 创建YggdrasilHandler实例 +func NewYggdrasilHandler(c *container.Container) *YggdrasilHandler { + return &YggdrasilHandler{ + container: c, + logger: c.Logger, + } +} + +// Authenticate 用户认证 +func (h *YggdrasilHandler) Authenticate(c *gin.Context) { + rawData, err := io.ReadAll(c.Request.Body) + if err != nil { + h.logger.Error("读取请求体失败", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) + return + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData)) + + var request AuthenticateRequest + if err = c.ShouldBindJSON(&request); err != nil { + h.logger.Error("解析认证请求失败", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + var userId int64 + var profile *model.Profile + var UUID string + + if emailRegex.MatchString(request.Identifier) { + userId, err = service.GetUserIDByEmail(h.container.DB, request.Identifier) + } else { + profile, err = service.GetProfileByProfileName(h.container.DB, request.Identifier) + if err != nil { + h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) + c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) + return + } + userId = profile.UserID + UUID = profile.UUID + } + + if err != nil { + h.logger.Warn("认证失败: 用户不存在", zap.String("identifier", request.Identifier), zap.Error(err)) + c.JSON(http.StatusForbidden, gin.H{"error": "用户不存在"}) + return + } + + if err := service.VerifyPassword(h.container.DB, request.Password, userId); err != nil { + h.logger.Warn("认证失败: 密码错误", zap.Error(err)) + c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword}) + return + } + + selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(h.container.DB, h.logger, userId, UUID, request.ClientToken) + if err != nil { + h.logger.Error("生成令牌失败", zap.Error(err), zap.Int64("userId", userId)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + user, err := service.GetUserByID(userId) + if err != nil { + h.logger.Error("获取用户信息失败", zap.Error(err), zap.Int64("userId", userId)) + } + + availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles)) + for _, p := range availableProfiles { + availableProfilesData = append(availableProfilesData, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *p)) + } + + response := AuthenticateResponse{ + AccessToken: accessToken, + ClientToken: clientToken, + AvailableProfiles: availableProfilesData, + } + + if selectedProfile != nil { + response.SelectedProfile = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *selectedProfile) + } + + if request.RequestUser && user != nil { + response.User = service.SerializeUser(h.logger, user, UUID) + } + + h.logger.Info("用户认证成功", zap.Int64("userId", userId)) + c.JSON(http.StatusOK, response) +} + +// ValidToken 验证令牌 +func (h *YggdrasilHandler) ValidToken(c *gin.Context) { + var request ValidTokenRequest + if err := c.ShouldBindJSON(&request); err != nil { + h.logger.Error("解析验证令牌请求失败", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if service.ValidToken(h.container.DB, request.AccessToken, request.ClientToken) { + h.logger.Info("令牌验证成功", zap.String("accessToken", request.AccessToken)) + c.JSON(http.StatusNoContent, gin.H{"valid": true}) + } else { + h.logger.Warn("令牌验证失败", zap.String("accessToken", request.AccessToken)) + c.JSON(http.StatusForbidden, gin.H{"valid": false}) + } +} + +// RefreshToken 刷新令牌 +func (h *YggdrasilHandler) RefreshToken(c *gin.Context) { + var request RefreshRequest + if err := c.ShouldBindJSON(&request); err != nil { + h.logger.Error("解析刷新令牌请求失败", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + UUID, err := service.GetUUIDByAccessToken(h.container.DB, request.AccessToken) + if err != nil { + h.logger.Warn("刷新令牌失败: 无效的访问令牌", zap.String("token", request.AccessToken), zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + userID, _ := service.GetUserIDByAccessToken(h.container.DB, request.AccessToken) + UUID = utils.FormatUUID(UUID) + + profile, err := service.GetProfileByUUID(h.container.DB, UUID) + if err != nil { + h.logger.Error("刷新令牌失败: 无法获取用户信息", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + var profileData map[string]interface{} + var userData map[string]interface{} + var profileID string + + if request.SelectedProfile != nil { + profileIDValue, ok := request.SelectedProfile["id"] + if !ok { + h.logger.Error("刷新令牌失败: 缺少配置文件ID", zap.Int64("userId", userID)) + c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"}) + return + } + + profileID, ok = profileIDValue.(string) + if !ok { + h.logger.Error("刷新令牌失败: 配置文件ID类型错误", zap.Int64("userId", userID)) + c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"}) + return + } + + profileID = utils.FormatUUID(profileID) + + if profile.UserID != userID { + h.logger.Warn("刷新令牌失败: 用户不匹配", + zap.Int64("userId", userID), + zap.Int64("profileUserId", profile.UserID), + ) + c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch}) + return + } + + profileData = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile) + } + + user, _ := service.GetUserByID(userID) + if request.RequestUser && user != nil { + userData = service.SerializeUser(h.logger, user, UUID) + } + + newAccessToken, newClientToken, err := service.RefreshToken(h.container.DB, h.logger, + request.AccessToken, + request.ClientToken, + profileID, + ) + if err != nil { + h.logger.Error("刷新令牌失败", zap.Error(err), zap.Int64("userId", userID)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + h.logger.Info("刷新令牌成功", zap.Int64("userId", userID)) + c.JSON(http.StatusOK, RefreshResponse{ + AccessToken: newAccessToken, + ClientToken: newClientToken, + SelectedProfile: profileData, + User: userData, + }) +} + +// InvalidToken 使令牌失效 +func (h *YggdrasilHandler) InvalidToken(c *gin.Context) { + var request ValidTokenRequest + if err := c.ShouldBindJSON(&request); err != nil { + h.logger.Error("解析使令牌失效请求失败", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + service.InvalidToken(h.container.DB, h.logger, request.AccessToken) + h.logger.Info("令牌已失效", zap.String("token", request.AccessToken)) + c.JSON(http.StatusNoContent, gin.H{}) +} + +// SignOut 用户登出 +func (h *YggdrasilHandler) SignOut(c *gin.Context) { + var request SignOutRequest + if err := c.ShouldBindJSON(&request); err != nil { + h.logger.Error("解析登出请求失败", zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + if !emailRegex.MatchString(request.Email) { + h.logger.Warn("登出失败: 邮箱格式不正确", zap.String("email", request.Email)) + c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat}) + return + } + + user, err := service.GetUserByEmail(request.Email) + if err != nil || user == nil { + h.logger.Warn("登出失败: 用户不存在", zap.String("email", request.Email), zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": "用户不存在"}) + return + } + + if err := service.VerifyPassword(h.container.DB, request.Password, user.ID); err != nil { + h.logger.Warn("登出失败: 密码错误", zap.Int64("userId", user.ID)) + c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword}) + return + } + + service.InvalidUserTokens(h.container.DB, h.logger, user.ID) + h.logger.Info("用户登出成功", zap.Int64("userId", user.ID)) + c.JSON(http.StatusNoContent, gin.H{"valid": true}) +} + +// GetProfileByUUID 根据UUID获取档案 +func (h *YggdrasilHandler) GetProfileByUUID(c *gin.Context) { + uuid := utils.FormatUUID(c.Param("uuid")) + h.logger.Info("获取配置文件请求", zap.String("uuid", uuid)) + + profile, err := service.GetProfileByUUID(h.container.DB, uuid) + if err != nil { + h.logger.Error("获取配置文件失败", zap.Error(err), zap.String("uuid", uuid)) + standardResponse(c, http.StatusInternalServerError, nil, err.Error()) + return + } + + h.logger.Info("成功获取配置文件", zap.String("uuid", uuid), zap.String("name", profile.Name)) + c.JSON(http.StatusOK, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) +} + +// JoinServer 加入服务器 +func (h *YggdrasilHandler) JoinServer(c *gin.Context) { + var request JoinServerRequest + clientIP := c.ClientIP() + + if err := c.ShouldBindJSON(&request); err != nil { + h.logger.Error("解析加入服务器请求失败", zap.Error(err), zap.String("ip", clientIP)) + standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest) + return + } + + h.logger.Info("收到加入服务器请求", + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), + ) + + if err := service.JoinServer(h.container.DB, h.logger, h.container.Redis, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { + h.logger.Error("加入服务器失败", + zap.Error(err), + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), + ) + standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed) + return + } + + h.logger.Info("加入服务器成功", + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), + ) + c.Status(http.StatusNoContent) +} + +// HasJoinedServer 验证玩家是否已加入服务器 +func (h *YggdrasilHandler) HasJoinedServer(c *gin.Context) { + clientIP, _ := c.GetQuery("ip") + + serverID, exists := c.GetQuery("serverId") + if !exists || serverID == "" { + h.logger.Warn("缺少服务器ID参数", zap.String("ip", clientIP)) + standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired) + return + } + + username, exists := c.GetQuery("username") + if !exists || username == "" { + h.logger.Warn("缺少用户名参数", zap.String("serverId", serverID), zap.String("ip", clientIP)) + standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired) + return + } + + h.logger.Info("收到会话验证请求", + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("ip", clientIP), + ) + + if err := service.HasJoinedServer(h.logger, h.container.Redis, serverID, username, clientIP); err != nil { + h.logger.Warn("会话验证失败", + zap.Error(err), + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("ip", clientIP), + ) + standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed) + return + } + + profile, err := service.GetProfileByUUID(h.container.DB, username) + if err != nil { + h.logger.Error("获取用户配置文件失败", zap.Error(err), zap.String("username", username)) + standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound) + return + } + + h.logger.Info("会话验证成功", + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("uuid", profile.UUID), + ) + c.JSON(200, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) +} + +// GetProfilesByName 批量获取配置文件 +func (h *YggdrasilHandler) GetProfilesByName(c *gin.Context) { + var names []string + + if err := c.ShouldBindJSON(&names); err != nil { + h.logger.Error("解析名称数组请求失败", zap.Error(err)) + standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams) + return + } + + h.logger.Info("接收到批量获取配置文件请求", zap.Int("count", len(names))) + + profiles, err := service.GetProfilesDataByNames(h.container.DB, names) + if err != nil { + h.logger.Error("获取配置文件失败", zap.Error(err)) + } + + h.logger.Info("成功获取配置文件", zap.Int("requested", len(names)), zap.Int("returned", len(profiles))) + c.JSON(http.StatusOK, profiles) +} + +// GetMetaData 获取Yggdrasil元数据 +func (h *YggdrasilHandler) GetMetaData(c *gin.Context) { + meta := gin.H{ + "implementationName": "CellAuth", + "implementationVersion": "0.0.1", + "serverName": "LittleLan's Yggdrasil Server Implementation.", + "links": gin.H{ + "homepage": "https://skin.littlelan.cn", + "register": "https://skin.littlelan.cn/auth", + }, + "feature.non_email_login": true, + "feature.enable_profile_key": true, + } + + skinDomains := []string{".hitwh.games", ".littlelan.cn"} + signature, err := service.GetPublicKeyFromRedisFunc(h.logger, h.container.Redis) + if err != nil { + h.logger.Error("获取公钥失败", zap.Error(err)) + standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) + return + } + + h.logger.Info("提供元数据") + c.JSON(http.StatusOK, gin.H{ + "meta": meta, + "skinDomains": skinDomains, + "signaturePublickey": signature, + }) +} + +// GetPlayerCertificates 获取玩家证书 +func (h *YggdrasilHandler) GetPlayerCertificates(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"}) + c.Abort() + return + } + + bearerPrefix := "Bearer " + if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) + c.Abort() + return + } + + tokenID := authHeader[len(bearerPrefix):] + if tokenID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) + c.Abort() + return + } + + uuid, err := service.GetUUIDByAccessToken(h.container.DB, tokenID) + if uuid == "" { + h.logger.Error("获取玩家UUID失败", zap.Error(err)) + standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) + return + } + + uuid = utils.FormatUUID(uuid) + + certificate, err := service.GeneratePlayerCertificate(h.container.DB, h.logger, h.container.Redis, uuid) + if err != nil { + h.logger.Error("生成玩家证书失败", zap.Error(err)) + standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) + return + } + + h.logger.Info("成功生成玩家证书") + c.JSON(http.StatusOK, certificate) +} diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index f72ca88..8fabb7c 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -82,4 +82,3 @@ type YggdrasilRepository interface { GetPasswordByID(id int64) (string, error) ResetPassword(id int64, password string) error } - diff --git a/internal/repository/profile_repository_impl.go b/internal/repository/profile_repository_impl.go index ebe3fdb..5eb4e9e 100644 --- a/internal/repository/profile_repository_impl.go +++ b/internal/repository/profile_repository_impl.go @@ -146,4 +146,3 @@ func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.K return nil }) } - diff --git a/internal/repository/texture_repository_impl.go b/internal/repository/texture_repository_impl.go index c6a2971..82f37df 100644 --- a/internal/repository/texture_repository_impl.go +++ b/internal/repository/texture_repository_impl.go @@ -172,4 +172,3 @@ func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, erro Count(&count).Error return count, err } - diff --git a/internal/repository/token_repository_impl.go b/internal/repository/token_repository_impl.go index e4c94e1..623f06a 100644 --- a/internal/repository/token_repository_impl.go +++ b/internal/repository/token_repository_impl.go @@ -68,4 +68,3 @@ func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{}) return result.RowsAffected, result.Error } - diff --git a/internal/repository/user_repository_impl.go b/internal/repository/user_repository_impl.go index 57ec4c8..b932ae7 100644 --- a/internal/repository/user_repository_impl.go +++ b/internal/repository/user_repository_impl.go @@ -100,4 +100,3 @@ func handleNotFoundResult[T any](result *T, err error) (*T, error) { } return result, nil } - From e05ba3b041a12cb80a5dd376781e69d310378a8c Mon Sep 17 00:00:00 2001 From: lan Date: Tue, 2 Dec 2025 17:50:52 +0800 Subject: [PATCH 3/5] =?UTF-8?q?feat:=20Service=E5=B1=82=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增Service接口定义(internal/service/interfaces.go): - UserService: 用户认证、查询、更新等接口 - ProfileService: 档案CRUD、状态管理接口 - TextureService: 材质管理、收藏功能接口 - TokenService: 令牌生命周期管理接口 - VerificationService: 验证码服务接口 - CaptchaService: 滑动验证码接口 - UploadService: 上传服务接口 - YggdrasilService: Yggdrasil API接口 新增Service实现: - user_service_impl.go: 用户服务实现 - profile_service_impl.go: 档案服务实现 - texture_service_impl.go: 材质服务实现 - token_service_impl.go: 令牌服务实现 更新Container: - 添加Service层字段 - 初始化Service实例 - 添加With*Service选项函数 遵循Go最佳实践: - 接口定义与实现分离 - 依赖通过构造函数注入 - 便于单元测试mock --- internal/container/container.go | 41 +++ internal/handler/profile_handler_di.go | 1 - internal/service/interfaces.go | 144 +++++++++ internal/service/profile_service_impl.go | 233 ++++++++++++++ internal/service/texture_service_impl.go | 216 +++++++++++++ internal/service/token_service_impl.go | 278 +++++++++++++++++ internal/service/user_service_impl.go | 368 +++++++++++++++++++++++ 7 files changed, 1280 insertions(+), 1 deletion(-) create mode 100644 internal/service/interfaces.go create mode 100644 internal/service/profile_service_impl.go create mode 100644 internal/service/texture_service_impl.go create mode 100644 internal/service/token_service_impl.go create mode 100644 internal/service/user_service_impl.go diff --git a/internal/container/container.go b/internal/container/container.go index cde146e..2677f09 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -2,6 +2,7 @@ package container import ( "carrotskin/internal/repository" + "carrotskin/internal/service" "carrotskin/pkg/auth" "carrotskin/pkg/redis" "carrotskin/pkg/storage" @@ -26,6 +27,12 @@ type Container struct { TextureRepo repository.TextureRepository TokenRepo repository.TokenRepository ConfigRepo repository.SystemConfigRepository + + // Service层 + UserService service.UserService + ProfileService service.ProfileService + TextureService service.TextureService + TokenService service.TokenService } // NewContainer 创建依赖容器 @@ -51,6 +58,12 @@ func NewContainer( c.TokenRepo = repository.NewTokenRepository(db) c.ConfigRepo = repository.NewSystemConfigRepository(db) + // 初始化Service + c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, logger) + c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, logger) + c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, logger) + c.TokenService = service.NewTokenService(c.TokenRepo, c.ProfileRepo, logger) + return c } @@ -135,3 +148,31 @@ func WithConfigRepo(repo repository.SystemConfigRepository) Option { c.ConfigRepo = repo } } + +// WithUserService 设置用户服务 +func WithUserService(svc service.UserService) Option { + return func(c *Container) { + c.UserService = svc + } +} + +// WithProfileService 设置档案服务 +func WithProfileService(svc service.ProfileService) Option { + return func(c *Container) { + c.ProfileService = svc + } +} + +// WithTextureService 设置材质服务 +func WithTextureService(svc service.TextureService) Option { + return func(c *Container) { + c.TextureService = svc + } +} + +// WithTokenService 设置令牌服务 +func WithTokenService(svc service.TokenService) Option { + return func(c *Container) { + c.TokenService = svc + } +} diff --git a/internal/handler/profile_handler_di.go b/internal/handler/profile_handler_di.go index 6fdbeb9..d9d8e3b 100644 --- a/internal/handler/profile_handler_di.go +++ b/internal/handler/profile_handler_di.go @@ -244,4 +244,3 @@ func (h *ProfileHandler) SetActive(c *gin.Context) { RespondSuccess(c, gin.H{"message": "设置成功"}) } - diff --git a/internal/service/interfaces.go b/internal/service/interfaces.go new file mode 100644 index 0000000..55a9f1d --- /dev/null +++ b/internal/service/interfaces.go @@ -0,0 +1,144 @@ +// Package service 定义业务逻辑层接口 +package service + +import ( + "carrotskin/internal/model" + "carrotskin/pkg/storage" + "context" + + "go.uber.org/zap" +) + +// UserService 用户服务接口 +type UserService interface { + // 用户认证 + Register(username, password, email, avatar string) (*model.User, string, error) + Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) + + // 用户查询 + GetByID(id int64) (*model.User, error) + GetByEmail(email string) (*model.User, error) + + // 用户更新 + UpdateInfo(user *model.User) error + UpdateAvatar(userID int64, avatarURL string) error + ChangePassword(userID int64, oldPassword, newPassword string) error + ResetPassword(email, newPassword string) error + ChangeEmail(userID int64, newEmail string) error + + // URL验证 + ValidateAvatarURL(avatarURL string) error + + // 配置获取 + GetMaxProfilesPerUser() int + GetMaxTexturesPerUser() int +} + +// ProfileService 档案服务接口 +type ProfileService interface { + // 档案CRUD + Create(userID int64, name string) (*model.Profile, error) + GetByUUID(uuid string) (*model.Profile, error) + GetByUserID(userID int64) ([]*model.Profile, error) + Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) + Delete(uuid string, userID int64) error + + // 档案状态 + SetActive(uuid string, userID int64) error + CheckLimit(userID int64, maxProfiles int) error + + // 批量查询 + GetByNames(names []string) ([]*model.Profile, error) + GetByProfileName(name string) (*model.Profile, error) +} + +// TextureService 材质服务接口 +type TextureService interface { + // 材质CRUD + Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) + GetByID(id int64) (*model.Texture, error) + GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) + Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) + Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) + Delete(textureID, uploaderID int64) error + + // 收藏 + ToggleFavorite(userID, textureID int64) (bool, error) + GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) + + // 限制检查 + CheckUploadLimit(uploaderID int64, maxTextures int) error +} + +// TokenService 令牌服务接口 +type TokenService interface { + // 令牌管理 + Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error) + Validate(accessToken, clientToken string) bool + Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) + Invalidate(accessToken string) + InvalidateUserTokens(userID int64) + + // 令牌查询 + GetUUIDByAccessToken(accessToken string) (string, error) + GetUserIDByAccessToken(accessToken string) (int64, error) +} + +// VerificationService 验证码服务接口 +type VerificationService interface { + SendCode(ctx context.Context, email, codeType string) error + VerifyCode(ctx context.Context, email, code, codeType string) error +} + +// CaptchaService 滑动验证码服务接口 +type CaptchaService interface { + Generate(ctx context.Context) (masterImg, tileImg, captchaID string, y int, err error) + Verify(ctx context.Context, dx int, captchaID string) (bool, error) +} + +// UploadService 上传服务接口 +type UploadService interface { + GenerateAvatarUploadURL(ctx context.Context, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) + GenerateTextureUploadURL(ctx context.Context, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) +} + +// YggdrasilService Yggdrasil服务接口 +type YggdrasilService interface { + // 用户认证 + GetUserIDByEmail(email string) (int64, error) + VerifyPassword(password string, userID int64) error + + // 会话管理 + JoinServer(serverID, accessToken, selectedProfile, ip string) error + HasJoinedServer(serverID, username, ip string) error + + // 密码管理 + ResetYggdrasilPassword(userID int64) (string, error) + + // 序列化 + SerializeProfile(profile model.Profile) map[string]interface{} + SerializeUser(user *model.User, uuid string) map[string]interface{} + + // 证书 + GeneratePlayerCertificate(uuid string) (map[string]interface{}, error) + GetPublicKey() (string, error) +} + +// Services 服务集合 +type Services struct { + User UserService + Profile ProfileService + Texture TextureService + Token TokenService + Verification VerificationService + Captcha CaptchaService + Upload UploadService + Yggdrasil YggdrasilService +} + +// ServiceDeps 服务依赖 +type ServiceDeps struct { + Logger *zap.Logger + Storage *storage.StorageClient +} + diff --git a/internal/service/profile_service_impl.go b/internal/service/profile_service_impl.go new file mode 100644 index 0000000..a84dcad --- /dev/null +++ b/internal/service/profile_service_impl.go @@ -0,0 +1,233 @@ +package service + +import ( + "carrotskin/internal/model" + "carrotskin/internal/repository" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + + "github.com/google/uuid" + "go.uber.org/zap" + "gorm.io/gorm" +) + +// profileServiceImpl ProfileService的实现 +type profileServiceImpl struct { + profileRepo repository.ProfileRepository + userRepo repository.UserRepository + logger *zap.Logger +} + +// NewProfileService 创建ProfileService实例 +func NewProfileService( + profileRepo repository.ProfileRepository, + userRepo repository.UserRepository, + logger *zap.Logger, +) ProfileService { + return &profileServiceImpl{ + profileRepo: profileRepo, + userRepo: userRepo, + logger: logger, + } +} + +func (s *profileServiceImpl) Create(userID int64, name string) (*model.Profile, error) { + // 验证用户存在 + user, err := s.userRepo.FindByID(userID) + if err != nil || user == nil { + return nil, errors.New("用户不存在") + } + if user.Status != 1 { + return nil, errors.New("用户状态异常") + } + + // 检查角色名是否已存在 + existingName, err := s.profileRepo.FindByName(name) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("查询角色名失败: %w", err) + } + if existingName != nil { + return nil, errors.New("角色名已被使用") + } + + // 生成UUID和RSA密钥 + profileUUID := uuid.New().String() + privateKey, err := generateRSAPrivateKeyInternal() + if err != nil { + return nil, fmt.Errorf("生成RSA密钥失败: %w", err) + } + + // 创建档案 + profile := &model.Profile{ + UUID: profileUUID, + UserID: userID, + Name: name, + RSAPrivateKey: privateKey, + IsActive: true, + } + + if err := s.profileRepo.Create(profile); err != nil { + return nil, fmt.Errorf("创建档案失败: %w", err) + } + + // 设置活跃状态 + if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { + return nil, fmt.Errorf("设置活跃状态失败: %w", err) + } + + return profile, nil +} + +func (s *profileServiceImpl) GetByUUID(uuid string) (*model.Profile, error) { + profile, err := s.profileRepo.FindByUUID(uuid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrProfileNotFound + } + return nil, fmt.Errorf("查询档案失败: %w", err) + } + return profile, nil +} + +func (s *profileServiceImpl) GetByUserID(userID int64) ([]*model.Profile, error) { + profiles, err := s.profileRepo.FindByUserID(userID) + if err != nil { + return nil, fmt.Errorf("查询档案列表失败: %w", err) + } + return profiles, nil +} + +func (s *profileServiceImpl) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrProfileNotFound + } + return nil, fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return nil, ErrProfileNoPermission + } + + // 检查角色名是否重复 + if name != nil && *name != profile.Name { + existingName, err := s.profileRepo.FindByName(*name) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("查询角色名失败: %w", err) + } + if existingName != nil { + return nil, errors.New("角色名已被使用") + } + profile.Name = *name + } + + // 更新皮肤和披风 + if skinID != nil { + profile.SkinID = skinID + } + if capeID != nil { + profile.CapeID = capeID + } + + if err := s.profileRepo.Update(profile); err != nil { + return nil, fmt.Errorf("更新档案失败: %w", err) + } + + return s.profileRepo.FindByUUID(uuid) +} + +func (s *profileServiceImpl) Delete(uuid string, userID int64) error { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrProfileNotFound + } + return fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return ErrProfileNoPermission + } + + if err := s.profileRepo.Delete(uuid); err != nil { + return fmt.Errorf("删除档案失败: %w", err) + } + return nil +} + +func (s *profileServiceImpl) SetActive(uuid string, userID int64) error { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrProfileNotFound + } + return fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return ErrProfileNoPermission + } + + if err := s.profileRepo.SetActive(uuid, userID); err != nil { + return fmt.Errorf("设置活跃状态失败: %w", err) + } + + if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { + return fmt.Errorf("更新使用时间失败: %w", err) + } + + return nil +} + +func (s *profileServiceImpl) CheckLimit(userID int64, maxProfiles int) error { + count, err := s.profileRepo.CountByUserID(userID) + if err != nil { + return fmt.Errorf("查询档案数量失败: %w", err) + } + + if int(count) >= maxProfiles { + return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles) + } + return nil +} + +func (s *profileServiceImpl) GetByNames(names []string) ([]*model.Profile, error) { + profiles, err := s.profileRepo.GetByNames(names) + if err != nil { + return nil, fmt.Errorf("查找失败: %w", err) + } + return profiles, nil +} + +func (s *profileServiceImpl) GetByProfileName(name string) (*model.Profile, error) { + profile, err := s.profileRepo.FindByName(name) + if err != nil { + return nil, errors.New("用户角色未创建") + } + return profile, nil +} + +// generateRSAPrivateKeyInternal 生成RSA-2048私钥(PEM格式) +func generateRSAPrivateKeyInternal() (string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", err + } + + privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + return string(privateKeyPEM), nil +} + diff --git a/internal/service/texture_service_impl.go b/internal/service/texture_service_impl.go new file mode 100644 index 0000000..9a82ac8 --- /dev/null +++ b/internal/service/texture_service_impl.go @@ -0,0 +1,216 @@ +package service + +import ( + "carrotskin/internal/model" + "carrotskin/internal/repository" + "errors" + "fmt" + + "go.uber.org/zap" +) + +// textureServiceImpl TextureService的实现 +type textureServiceImpl struct { + textureRepo repository.TextureRepository + userRepo repository.UserRepository + logger *zap.Logger +} + +// NewTextureService 创建TextureService实例 +func NewTextureService( + textureRepo repository.TextureRepository, + userRepo repository.UserRepository, + logger *zap.Logger, +) TextureService { + return &textureServiceImpl{ + textureRepo: textureRepo, + userRepo: userRepo, + logger: logger, + } +} + +func (s *textureServiceImpl) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { + // 验证用户存在 + user, err := s.userRepo.FindByID(uploaderID) + if err != nil || user == nil { + return nil, ErrUserNotFound + } + + // 检查Hash是否已存在 + existingTexture, err := s.textureRepo.FindByHash(hash) + if err != nil { + return nil, err + } + if existingTexture != nil { + return nil, errors.New("该材质已存在") + } + + // 转换材质类型 + textureTypeEnum, err := parseTextureTypeInternal(textureType) + if err != nil { + return nil, err + } + + // 创建材质 + texture := &model.Texture{ + UploaderID: uploaderID, + Name: name, + Description: description, + Type: textureTypeEnum, + URL: url, + Hash: hash, + Size: size, + IsPublic: isPublic, + IsSlim: isSlim, + Status: 1, + DownloadCount: 0, + FavoriteCount: 0, + } + + if err := s.textureRepo.Create(texture); err != nil { + return nil, err + } + + return texture, nil +} + +func (s *textureServiceImpl) GetByID(id int64) (*model.Texture, error) { + texture, err := s.textureRepo.FindByID(id) + if err != nil { + return nil, err + } + if texture == nil { + return nil, ErrTextureNotFound + } + if texture.Status == -1 { + return nil, errors.New("材质已删除") + } + return texture, nil +} + +func (s *textureServiceImpl) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) +} + +func (s *textureServiceImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) +} + +func (s *textureServiceImpl) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { + // 获取材质并验证权限 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { + return nil, err + } + if texture == nil { + return nil, ErrTextureNotFound + } + if texture.UploaderID != uploaderID { + return nil, ErrTextureNoPermission + } + + // 更新字段 + updates := make(map[string]interface{}) + if name != "" { + updates["name"] = name + } + if description != "" { + updates["description"] = description + } + if isPublic != nil { + updates["is_public"] = *isPublic + } + + if len(updates) > 0 { + if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { + return nil, err + } + } + + return s.textureRepo.FindByID(textureID) +} + +func (s *textureServiceImpl) Delete(textureID, uploaderID int64) error { + // 获取材质并验证权限 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { + return err + } + if texture == nil { + return ErrTextureNotFound + } + if texture.UploaderID != uploaderID { + return ErrTextureNoPermission + } + + return s.textureRepo.Delete(textureID) +} + +func (s *textureServiceImpl) ToggleFavorite(userID, textureID int64) (bool, error) { + // 确保材质存在 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { + return false, err + } + if texture == nil { + return false, ErrTextureNotFound + } + + isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) + if err != nil { + return false, err + } + + if isFavorited { + // 已收藏 -> 取消收藏 + if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { + return false, err + } + if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { + return false, err + } + return false, nil + } + + // 未收藏 -> 添加收藏 + if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { + return false, err + } + if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { + return false, err + } + return true, nil +} + +func (s *textureServiceImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.GetUserFavorites(userID, page, pageSize) +} + +func (s *textureServiceImpl) CheckUploadLimit(uploaderID int64, maxTextures int) error { + count, err := s.textureRepo.CountByUploaderID(uploaderID) + if err != nil { + return err + } + + if count >= int64(maxTextures) { + return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures) + } + + return nil +} + +// parseTextureTypeInternal 解析材质类型 +func parseTextureTypeInternal(textureType string) (model.TextureType, error) { + switch textureType { + case "SKIN": + return model.TextureTypeSkin, nil + case "CAPE": + return model.TextureTypeCape, nil + default: + return "", errors.New("无效的材质类型") + } +} + diff --git a/internal/service/token_service_impl.go b/internal/service/token_service_impl.go new file mode 100644 index 0000000..8d49910 --- /dev/null +++ b/internal/service/token_service_impl.go @@ -0,0 +1,278 @@ +package service + +import ( + "carrotskin/internal/model" + "carrotskin/internal/repository" + "context" + "errors" + "fmt" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" +) + +// tokenServiceImpl TokenService的实现 +type tokenServiceImpl struct { + tokenRepo repository.TokenRepository + profileRepo repository.ProfileRepository + logger *zap.Logger +} + +// NewTokenService 创建TokenService实例 +func NewTokenService( + tokenRepo repository.TokenRepository, + profileRepo repository.ProfileRepository, + logger *zap.Logger, +) TokenService { + return &tokenServiceImpl{ + tokenRepo: tokenRepo, + profileRepo: profileRepo, + logger: logger, + } +} + +const ( + tokenExtendedTimeout = 10 * time.Second + tokensMaxCount = 10 +) + +func (s *tokenServiceImpl) Create(userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { + var ( + selectedProfileID *model.Profile + availableProfiles []*model.Profile + ) + + // 设置超时上下文 + _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) + defer cancel() + + // 验证用户存在 + if UUID != "" { + _, err := s.profileRepo.FindByUUID(UUID) + if err != nil { + return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) + } + } + + // 生成令牌 + if clientToken == "" { + clientToken = uuid.New().String() + } + + accessToken := uuid.New().String() + token := model.Token{ + AccessToken: accessToken, + ClientToken: clientToken, + UserID: userID, + Usable: true, + IssueDate: time.Now(), + } + + // 获取用户配置文件 + profiles, err := s.profileRepo.FindByUserID(userID) + if err != nil { + return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) + } + + // 如果用户只有一个配置文件,自动选择 + if len(profiles) == 1 { + selectedProfileID = profiles[0] + token.ProfileId = selectedProfileID.UUID + } + availableProfiles = profiles + + // 插入令牌 + err = s.tokenRepo.Create(&token) + if err != nil { + return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) + } + + // 清理多余的令牌 + go s.checkAndCleanupExcessTokens(userID) + + return selectedProfileID, availableProfiles, accessToken, clientToken, nil +} + +func (s *tokenServiceImpl) Validate(accessToken, clientToken string) bool { + if accessToken == "" { + return false + } + + token, err := s.tokenRepo.FindByAccessToken(accessToken) + if err != nil { + return false + } + + if !token.Usable { + return false + } + + if clientToken == "" { + return true + } + + return token.ClientToken == clientToken +} + +func (s *tokenServiceImpl) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { + if accessToken == "" { + return "", "", errors.New("accessToken不能为空") + } + + // 查找旧令牌 + oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", "", errors.New("accessToken无效") + } + s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken)) + return "", "", fmt.Errorf("查询令牌失败: %w", err) + } + + // 验证profile + if selectedProfileID != "" { + valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) + if validErr != nil { + s.logger.Error("验证Profile失败", + zap.Error(err), + zap.Int64("userId", oldToken.UserID), + zap.String("profileId", selectedProfileID), + ) + return "", "", fmt.Errorf("验证角色失败: %w", err) + } + if !valid { + return "", "", errors.New("角色与用户不匹配") + } + } + + // 检查 clientToken 是否有效 + if clientToken != "" && clientToken != oldToken.ClientToken { + return "", "", errors.New("clientToken无效") + } + + // 检查 selectedProfileID 的逻辑 + if selectedProfileID != "" { + if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID { + return "", "", errors.New("原令牌已绑定角色,无法选择新角色") + } + } else { + selectedProfileID = oldToken.ProfileId + } + + // 生成新令牌 + newAccessToken := uuid.New().String() + newToken := model.Token{ + AccessToken: newAccessToken, + ClientToken: oldToken.ClientToken, + UserID: oldToken.UserID, + Usable: true, + ProfileId: selectedProfileID, + IssueDate: time.Now(), + } + + // 先插入新令牌,再删除旧令牌 + err = s.tokenRepo.Create(&newToken) + if err != nil { + s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) + return "", "", fmt.Errorf("创建新Token失败: %w", err) + } + + err = s.tokenRepo.DeleteByAccessToken(accessToken) + if err != nil { + s.logger.Warn("删除旧Token失败,但新Token已创建", + zap.Error(err), + zap.String("oldToken", oldToken.AccessToken), + zap.String("newToken", newAccessToken), + ) + } + + s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken)) + return newAccessToken, oldToken.ClientToken, nil +} + +func (s *tokenServiceImpl) Invalidate(accessToken string) { + if accessToken == "" { + return + } + + err := s.tokenRepo.DeleteByAccessToken(accessToken) + if err != nil { + s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) + return + } + s.logger.Info("成功删除Token", zap.String("token", accessToken)) +} + +func (s *tokenServiceImpl) InvalidateUserTokens(userID int64) { + if userID == 0 { + return + } + + err := s.tokenRepo.DeleteByUserID(userID) + if err != nil { + s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) + return + } + + s.logger.Info("成功删除用户Token", zap.Int64("userId", userID)) +} + +func (s *tokenServiceImpl) GetUUIDByAccessToken(accessToken string) (string, error) { + return s.tokenRepo.GetUUIDByAccessToken(accessToken) +} + +func (s *tokenServiceImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { + return s.tokenRepo.GetUserIDByAccessToken(accessToken) +} + +// 私有辅助方法 + +func (s *tokenServiceImpl) checkAndCleanupExcessTokens(userID int64) { + if userID == 0 { + return + } + + tokens, err := s.tokenRepo.GetByUserID(userID) + if err != nil { + s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) + return + } + + if len(tokens) <= tokensMaxCount { + return + } + + tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount) + for i := tokensMaxCount; i < len(tokens); i++ { + tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) + } + + deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) + if err != nil { + s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) + return + } + + if deletedCount > 0 { + s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount)) + } +} + +func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (bool, error) { + if userID == 0 || UUID == "" { + return false, errors.New("用户ID或配置文件ID不能为空") + } + + profile, err := s.profileRepo.FindByUUID(UUID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, errors.New("配置文件不存在") + } + return false, fmt.Errorf("验证配置文件失败: %w", err) + } + return profile.UserID == userID, nil +} + diff --git a/internal/service/user_service_impl.go b/internal/service/user_service_impl.go new file mode 100644 index 0000000..2b7250e --- /dev/null +++ b/internal/service/user_service_impl.go @@ -0,0 +1,368 @@ +package service + +import ( + "carrotskin/internal/model" + "carrotskin/internal/repository" + "carrotskin/pkg/auth" + "carrotskin/pkg/config" + "carrotskin/pkg/redis" + "context" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "go.uber.org/zap" +) + +// userServiceImpl UserService的实现 +type userServiceImpl struct { + userRepo repository.UserRepository + configRepo repository.SystemConfigRepository + jwtService *auth.JWTService + redis *redis.Client + logger *zap.Logger +} + +// NewUserService 创建UserService实例 +func NewUserService( + userRepo repository.UserRepository, + configRepo repository.SystemConfigRepository, + jwtService *auth.JWTService, + redisClient *redis.Client, + logger *zap.Logger, +) UserService { + return &userServiceImpl{ + userRepo: userRepo, + configRepo: configRepo, + jwtService: jwtService, + redis: redisClient, + logger: logger, + } +} + +func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) { + // 检查用户名是否已存在 + existingUser, err := s.userRepo.FindByUsername(username) + if err != nil { + return nil, "", err + } + if existingUser != nil { + return nil, "", errors.New("用户名已存在") + } + + // 检查邮箱是否已存在 + existingEmail, err := s.userRepo.FindByEmail(email) + if err != nil { + return nil, "", err + } + if existingEmail != nil { + return nil, "", errors.New("邮箱已被注册") + } + + // 加密密码 + hashedPassword, err := auth.HashPassword(password) + if err != nil { + return nil, "", errors.New("密码加密失败") + } + + // 确定头像URL + avatarURL := avatar + if avatarURL != "" { + if err := s.ValidateAvatarURL(avatarURL); err != nil { + return nil, "", err + } + } else { + avatarURL = s.getDefaultAvatar() + } + + // 创建用户 + user := &model.User{ + Username: username, + Password: hashedPassword, + Email: email, + Avatar: avatarURL, + Role: "user", + Status: 1, + Points: 0, + } + + if err := s.userRepo.Create(user); err != nil { + return nil, "", err + } + + // 生成JWT Token + token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) + if err != nil { + return nil, "", errors.New("生成Token失败") + } + + return user, token, nil +} + +func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { + ctx := context.Background() + + // 检查账号是否被锁定 + if s.redis != nil { + identifier := usernameOrEmail + ":" + ipAddress + locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier) + if err == nil && locked { + return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1) + } + } + + // 查找用户 + var user *model.User + var err error + + if strings.Contains(usernameOrEmail, "@") { + user, err = s.userRepo.FindByEmail(usernameOrEmail) + } else { + user, err = s.userRepo.FindByUsername(usernameOrEmail) + } + + if err != nil { + return nil, "", err + } + if user == nil { + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在") + return nil, "", errors.New("用户名/邮箱或密码错误") + } + + // 检查用户状态 + if user.Status != 1 { + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用") + return nil, "", errors.New("账号已被禁用") + } + + // 验证密码 + if !auth.CheckPassword(user.Password, password) { + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误") + return nil, "", errors.New("用户名/邮箱或密码错误") + } + + // 登录成功,清除失败计数 + if s.redis != nil { + identifier := usernameOrEmail + ":" + ipAddress + _ = ClearLoginAttempts(ctx, s.redis, identifier) + } + + // 生成JWT Token + token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) + if err != nil { + return nil, "", errors.New("生成Token失败") + } + + // 更新最后登录时间 + now := time.Now() + user.LastLoginAt = &now + _ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ + "last_login_at": now, + }) + + // 记录成功登录日志 + s.logSuccessLogin(user.ID, ipAddress, userAgent) + + return user, token, nil +} + +func (s *userServiceImpl) GetByID(id int64) (*model.User, error) { + return s.userRepo.FindByID(id) +} + +func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) { + return s.userRepo.FindByEmail(email) +} + +func (s *userServiceImpl) UpdateInfo(user *model.User) error { + return s.userRepo.Update(user) +} + +func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error { + return s.userRepo.UpdateFields(userID, map[string]interface{}{ + "avatar": avatarURL, + }) +} + +func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error { + user, err := s.userRepo.FindByID(userID) + if err != nil || user == nil { + return errors.New("用户不存在") + } + + if !auth.CheckPassword(user.Password, oldPassword) { + return errors.New("原密码错误") + } + + hashedPassword, err := auth.HashPassword(newPassword) + if err != nil { + return errors.New("密码加密失败") + } + + return s.userRepo.UpdateFields(userID, map[string]interface{}{ + "password": hashedPassword, + }) +} + +func (s *userServiceImpl) ResetPassword(email, newPassword string) error { + user, err := s.userRepo.FindByEmail(email) + if err != nil || user == nil { + return errors.New("用户不存在") + } + + hashedPassword, err := auth.HashPassword(newPassword) + if err != nil { + return errors.New("密码加密失败") + } + + return s.userRepo.UpdateFields(user.ID, map[string]interface{}{ + "password": hashedPassword, + }) +} + +func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error { + existingUser, err := s.userRepo.FindByEmail(newEmail) + if err != nil { + return err + } + if existingUser != nil && existingUser.ID != userID { + return errors.New("邮箱已被其他用户使用") + } + + return s.userRepo.UpdateFields(userID, map[string]interface{}{ + "email": newEmail, + }) +} + +func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error { + if avatarURL == "" { + return nil + } + + // 允许相对路径 + if strings.HasPrefix(avatarURL, "/") { + return nil + } + + // 解析URL + parsedURL, err := url.Parse(avatarURL) + if err != nil { + return errors.New("无效的URL格式") + } + + // 必须是HTTP或HTTPS协议 + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return errors.New("URL必须使用http或https协议") + } + + host := parsedURL.Hostname() + if host == "" { + return errors.New("URL缺少主机名") + } + + // 从配置获取允许的域名列表 + cfg, err := config.GetConfig() + if err != nil { + allowedDomains := []string{"localhost", "127.0.0.1"} + return s.checkDomainAllowed(host, allowedDomains) + } + + return s.checkDomainAllowed(host, cfg.Security.AllowedDomains) +} + +func (s *userServiceImpl) GetMaxProfilesPerUser() int { + config, err := s.configRepo.GetByKey("max_profiles_per_user") + if err != nil || config == nil { + return 5 + } + var value int + fmt.Sscanf(config.Value, "%d", &value) + if value <= 0 { + return 5 + } + return value +} + +func (s *userServiceImpl) GetMaxTexturesPerUser() int { + config, err := s.configRepo.GetByKey("max_textures_per_user") + if err != nil || config == nil { + return 50 + } + var value int + fmt.Sscanf(config.Value, "%d", &value) + if value <= 0 { + return 50 + } + return value +} + +// 私有辅助方法 + +func (s *userServiceImpl) getDefaultAvatar() string { + config, err := s.configRepo.GetByKey("default_avatar") + if err != nil || config == nil || config.Value == "" { + return "" + } + return config.Value +} + +func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error { + host = strings.ToLower(host) + + for _, allowed := range allowedDomains { + allowed = strings.ToLower(strings.TrimSpace(allowed)) + if allowed == "" { + continue + } + + if host == allowed { + return nil + } + + if strings.HasPrefix(allowed, "*.") { + suffix := allowed[1:] + if strings.HasSuffix(host, suffix) { + return nil + } + } + } + + return errors.New("URL域名不在允许的列表中") +} + +func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) { + if s.redis != nil { + identifier := usernameOrEmail + ":" + ipAddress + count, _ := RecordLoginFailure(ctx, s.redis, identifier) + if count >= MaxLoginAttempts { + s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") + return + } + } + s.logFailedLogin(userID, ipAddress, userAgent, reason) +} + +func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) { + log := &model.UserLoginLog{ + UserID: userID, + IPAddress: ipAddress, + UserAgent: userAgent, + LoginMethod: "PASSWORD", + IsSuccess: true, + } + _ = s.userRepo.CreateLoginLog(log) +} + +func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { + log := &model.UserLoginLog{ + UserID: userID, + IPAddress: ipAddress, + UserAgent: userAgent, + LoginMethod: "PASSWORD", + IsSuccess: false, + FailureReason: reason, + } + _ = s.userRepo.CreateLoginLog(log) +} From 188a05caa71097f597115406a689d8836a8e621e Mon Sep 17 00:00:00 2001 From: lan Date: Tue, 2 Dec 2025 18:41:34 +0800 Subject: [PATCH 4/5] chore: Clean up code by removing trailing whitespace in multiple files --- .dockerignore | 1 + Dockerfile | 1 + internal/handler/captcha_handler_di.go | 1 + internal/handler/profile_handler_di.go | 1 + internal/handler/texture_handler_di.go | 1 + internal/repository/interfaces.go | 1 + internal/repository/profile_repository_impl.go | 1 + internal/repository/system_config_repository_impl.go | 1 + internal/repository/texture_repository_impl.go | 1 + internal/repository/token_repository_impl.go | 1 + internal/repository/user_repository_impl.go | 1 + internal/service/interfaces.go | 1 + internal/service/profile_service_impl.go | 1 + internal/service/texture_service_impl.go | 1 - internal/service/token_service_impl.go | 1 - 15 files changed, 13 insertions(+), 2 deletions(-) diff --git a/.dockerignore b/.dockerignore index 6686339..b5e12a9 100644 --- a/.dockerignore +++ b/.dockerignore @@ -76,3 +76,4 @@ minio-data/ + diff --git a/Dockerfile b/Dockerfile index b5a00ab..cebe971 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,3 +61,4 @@ ENTRYPOINT ["./server"] + diff --git a/internal/handler/captcha_handler_di.go b/internal/handler/captcha_handler_di.go index 8078aee..f9849d0 100644 --- a/internal/handler/captcha_handler_di.go +++ b/internal/handler/captcha_handler_di.go @@ -106,3 +106,4 @@ func (h *CaptchaHandler) Verify(c *gin.Context) { } } + diff --git a/internal/handler/profile_handler_di.go b/internal/handler/profile_handler_di.go index d9d8e3b..6fdbeb9 100644 --- a/internal/handler/profile_handler_di.go +++ b/internal/handler/profile_handler_di.go @@ -244,3 +244,4 @@ func (h *ProfileHandler) SetActive(c *gin.Context) { RespondSuccess(c, gin.H{"message": "设置成功"}) } + diff --git a/internal/handler/texture_handler_di.go b/internal/handler/texture_handler_di.go index 8233184..26bd558 100644 --- a/internal/handler/texture_handler_di.go +++ b/internal/handler/texture_handler_di.go @@ -282,3 +282,4 @@ func (h *TextureHandler) GetUserFavorites(c *gin.Context) { c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) } + diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 8fabb7c..f72ca88 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -82,3 +82,4 @@ type YggdrasilRepository interface { GetPasswordByID(id int64) (string, error) ResetPassword(id int64, password string) error } + diff --git a/internal/repository/profile_repository_impl.go b/internal/repository/profile_repository_impl.go index 5eb4e9e..ebe3fdb 100644 --- a/internal/repository/profile_repository_impl.go +++ b/internal/repository/profile_repository_impl.go @@ -146,3 +146,4 @@ func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.K return nil }) } + diff --git a/internal/repository/system_config_repository_impl.go b/internal/repository/system_config_repository_impl.go index 2bb5844..4ba261f 100644 --- a/internal/repository/system_config_repository_impl.go +++ b/internal/repository/system_config_repository_impl.go @@ -42,3 +42,4 @@ func (r *systemConfigRepositoryImpl) UpdateValue(key, value string) error { return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error } + diff --git a/internal/repository/texture_repository_impl.go b/internal/repository/texture_repository_impl.go index 82f37df..c6a2971 100644 --- a/internal/repository/texture_repository_impl.go +++ b/internal/repository/texture_repository_impl.go @@ -172,3 +172,4 @@ func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, erro Count(&count).Error return count, err } + diff --git a/internal/repository/token_repository_impl.go b/internal/repository/token_repository_impl.go index 623f06a..e4c94e1 100644 --- a/internal/repository/token_repository_impl.go +++ b/internal/repository/token_repository_impl.go @@ -68,3 +68,4 @@ func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{}) return result.RowsAffected, result.Error } + diff --git a/internal/repository/user_repository_impl.go b/internal/repository/user_repository_impl.go index b932ae7..57ec4c8 100644 --- a/internal/repository/user_repository_impl.go +++ b/internal/repository/user_repository_impl.go @@ -100,3 +100,4 @@ func handleNotFoundResult[T any](result *T, err error) (*T, error) { } return result, nil } + diff --git a/internal/service/interfaces.go b/internal/service/interfaces.go index 55a9f1d..82f8507 100644 --- a/internal/service/interfaces.go +++ b/internal/service/interfaces.go @@ -142,3 +142,4 @@ type ServiceDeps struct { Storage *storage.StorageClient } + diff --git a/internal/service/profile_service_impl.go b/internal/service/profile_service_impl.go index a84dcad..a956793 100644 --- a/internal/service/profile_service_impl.go +++ b/internal/service/profile_service_impl.go @@ -231,3 +231,4 @@ func generateRSAPrivateKeyInternal() (string, error) { return string(privateKeyPEM), nil } + diff --git a/internal/service/texture_service_impl.go b/internal/service/texture_service_impl.go index 9a82ac8..eb19a82 100644 --- a/internal/service/texture_service_impl.go +++ b/internal/service/texture_service_impl.go @@ -213,4 +213,3 @@ func parseTextureTypeInternal(textureType string) (model.TextureType, error) { return "", errors.New("无效的材质类型") } } - diff --git a/internal/service/token_service_impl.go b/internal/service/token_service_impl.go index 8d49910..b128abf 100644 --- a/internal/service/token_service_impl.go +++ b/internal/service/token_service_impl.go @@ -275,4 +275,3 @@ func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (b } return profile.UserID == userID, nil } - From 801f1b1397990c61ee10c3f87c3dab91592025cf Mon Sep 17 00:00:00 2001 From: lafay <2021211506@stu.hit.edu.cn> Date: Tue, 2 Dec 2025 19:43:39 +0800 Subject: [PATCH 5/5] refactor: Implement dependency injection for handlers and services - Refactored AuthHandler, UserHandler, TextureHandler, ProfileHandler, CaptchaHandler, and YggdrasilHandler to use dependency injection. - Removed direct instantiation of services and repositories within handlers, replacing them with constructor injection. - Updated the container to initialize service instances and provide them to handlers. - Enhanced code structure for better testability and adherence to Go best practices. --- go.mod | 2 +- internal/handler/auth_handler.go | 76 +- internal/handler/auth_handler_di.go | 177 ---- internal/handler/captcha_handler.go | 81 +- internal/handler/captcha_handler_di.go | 109 --- internal/handler/profile_handler.go | 109 +-- internal/handler/profile_handler_di.go | 247 ------ internal/handler/routes.go | 267 +++--- internal/handler/routes_di.go | 193 ----- internal/handler/texture_handler.go | 183 ++-- internal/handler/texture_handler_di.go | 285 ------ internal/handler/user_handler.go | 158 ++-- internal/handler/user_handler_di.go | 233 ----- internal/handler/yggdrasil_handler.go | 384 ++++---- internal/handler/yggdrasil_handler_di.go | 454 ---------- internal/service/helpers_test.go | 50 ++ internal/service/mocks_test.go | 964 +++++++++++++++++++++ internal/service/profile_service.go | 221 ++--- internal/service/profile_service_impl.go | 234 ----- internal/service/profile_service_test.go | 333 ++++++- internal/service/serialize_service.go | 5 +- internal/service/serialize_service_test.go | 59 +- internal/service/texture_service.go | 179 ++-- internal/service/texture_service_impl.go | 215 ----- internal/service/texture_service_test.go | 357 ++++++++ internal/service/token_service.go | 240 ++--- internal/service/token_service_impl.go | 277 ------ internal/service/token_service_test.go | 328 ++++++- internal/service/upload_service.go | 42 +- internal/service/upload_service_test.go | 179 +++- internal/service/user_service.go | 303 +++---- internal/service/user_service_impl.go | 368 -------- internal/service/user_service_test.go | 445 +++++++--- 33 files changed, 3628 insertions(+), 4129 deletions(-) delete mode 100644 internal/handler/auth_handler_di.go delete mode 100644 internal/handler/captcha_handler_di.go delete mode 100644 internal/handler/profile_handler_di.go delete mode 100644 internal/handler/routes_di.go delete mode 100644 internal/handler/texture_handler_di.go delete mode 100644 internal/handler/user_handler_di.go delete mode 100644 internal/handler/yggdrasil_handler_di.go create mode 100644 internal/service/helpers_test.go create mode 100644 internal/service/mocks_test.go delete mode 100644 internal/service/profile_service_impl.go delete mode 100644 internal/service/texture_service_impl.go delete mode 100644 internal/service/token_service_impl.go delete mode 100644 internal/service/user_service_impl.go diff --git a/go.mod b/go.mod index 377b009..e083b3d 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.0 github.com/joho/godotenv v1.5.1 github.com/jordan-wright/email v4.0.1-0.20210109023952-943e75fe5223+incompatible - github.com/lib/pq v1.10.9 github.com/minio/minio-go/v7 v7.0.66 github.com/redis/go-redis/v9 v9.0.5 github.com/spf13/viper v1.21.0 @@ -28,6 +27,7 @@ require ( github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/stretchr/testify v1.11.1 // indirect golang.org/x/image v0.16.0 // indirect golang.org/x/sync v0.16.0 // indirect gorm.io/driver/mysql v1.5.6 // indirect diff --git a/internal/handler/auth_handler.go b/internal/handler/auth_handler.go index c2ae087..143c7ea 100644 --- a/internal/handler/auth_handler.go +++ b/internal/handler/auth_handler.go @@ -1,17 +1,29 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/service" "carrotskin/internal/types" - "carrotskin/pkg/auth" "carrotskin/pkg/email" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" "github.com/gin-gonic/gin" "go.uber.org/zap" ) +// AuthHandler 认证处理器(依赖注入版本) +type AuthHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewAuthHandler 创建AuthHandler实例 +func NewAuthHandler(c *container.Container) *AuthHandler { + return &AuthHandler{ + container: c, + logger: c.Logger, + } +} + // Register 用户注册 // @Summary 用户注册 // @Description 注册新用户账号 @@ -22,11 +34,7 @@ import ( // @Success 200 {object} model.Response "注册成功" // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/auth/register [post] -func Register(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - jwtService := auth.MustGetJWTService() - redisClient := redis.MustGetClient() - +func (h *AuthHandler) Register(c *gin.Context) { var req types.RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) @@ -34,16 +42,16 @@ func Register(c *gin.Context) { } // 验证邮箱验证码 - if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { - loggerInstance.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } // 注册用户 - user, token, err := service.RegisterUser(jwtService, req.Username, req.Password, req.Email, req.Avatar) + user, token, err := h.container.UserService.Register(req.Username, req.Password, req.Email, req.Avatar) if err != nil { - loggerInstance.Error("用户注册失败", zap.Error(err)) + h.logger.Error("用户注册失败", zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } @@ -65,11 +73,7 @@ func Register(c *gin.Context) { // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Failure 401 {object} model.ErrorResponse "登录失败" // @Router /api/v1/auth/login [post] -func Login(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - jwtService := auth.MustGetJWTService() - redisClient := redis.MustGetClient() - +func (h *AuthHandler) Login(c *gin.Context) { var req types.LoginRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) @@ -79,9 +83,9 @@ func Login(c *gin.Context) { ipAddress := c.ClientIP() userAgent := c.GetHeader("User-Agent") - user, token, err := service.LoginUserWithRateLimit(redisClient, jwtService, req.Username, req.Password, ipAddress, userAgent) + user, token, err := h.container.UserService.Login(req.Username, req.Password, ipAddress, userAgent) if err != nil { - loggerInstance.Warn("用户登录失败", + h.logger.Warn("用户登录失败", zap.String("username_or_email", req.Username), zap.String("ip", ipAddress), zap.Error(err), @@ -106,19 +110,21 @@ func Login(c *gin.Context) { // @Success 200 {object} model.Response "发送成功" // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/auth/send-code [post] -func SendVerificationCode(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - redisClient := redis.MustGetClient() - emailService := email.MustGetService() - +func (h *AuthHandler) SendVerificationCode(c *gin.Context) { var req types.SendVerificationCodeRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) return } - if err := service.SendVerificationCode(c.Request.Context(), redisClient, emailService, req.Email, req.Type); err != nil { - loggerInstance.Error("发送验证码失败", + emailService, err := h.getEmailService() + if err != nil { + RespondServerError(c, "邮件服务不可用", err) + return + } + + if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil { + h.logger.Error("发送验证码失败", zap.String("email", req.Email), zap.String("type", req.Type), zap.Error(err), @@ -140,10 +146,7 @@ func SendVerificationCode(c *gin.Context) { // @Success 200 {object} model.Response "重置成功" // @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/auth/reset-password [post] -func ResetPassword(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - redisClient := redis.MustGetClient() - +func (h *AuthHandler) ResetPassword(c *gin.Context) { var req types.ResetPasswordRequest if err := c.ShouldBindJSON(&req); err != nil { RespondBadRequest(c, "请求参数错误", err) @@ -151,18 +154,23 @@ func ResetPassword(c *gin.Context) { } // 验证验证码 - if err := service.VerifyCode(c.Request.Context(), redisClient, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { - loggerInstance.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } // 重置密码 - if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil { - loggerInstance.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) + if err := h.container.UserService.ResetPassword(req.Email, req.NewPassword); err != nil { + h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) RespondServerError(c, err.Error(), nil) return } RespondSuccess(c, gin.H{"message": "密码重置成功"}) } + +// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入) +func (h *AuthHandler) getEmailService() (*email.Service, error) { + return email.GetService() +} diff --git a/internal/handler/auth_handler_di.go b/internal/handler/auth_handler_di.go deleted file mode 100644 index 9087008..0000000 --- a/internal/handler/auth_handler_di.go +++ /dev/null @@ -1,177 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "carrotskin/internal/types" - "carrotskin/pkg/email" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// AuthHandler 认证处理器(依赖注入版本) -type AuthHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewAuthHandler 创建AuthHandler实例 -func NewAuthHandler(c *container.Container) *AuthHandler { - return &AuthHandler{ - container: c, - logger: c.Logger, - } -} - -// Register 用户注册 -// @Summary 用户注册 -// @Description 注册新用户账号 -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.RegisterRequest true "注册信息" -// @Success 200 {object} model.Response "注册成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/auth/register [post] -func (h *AuthHandler) Register(c *gin.Context) { - var req types.RegisterRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - // 验证邮箱验证码 - if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { - h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - // 注册用户 - user, token, err := service.RegisterUser(h.container.JWT, req.Username, req.Password, req.Email, req.Avatar) - if err != nil { - h.logger.Error("用户注册失败", zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, &types.LoginResponse{ - Token: token, - UserInfo: UserToUserInfo(user), - }) -} - -// Login 用户登录 -// @Summary 用户登录 -// @Description 用户登录获取JWT Token,支持用户名或邮箱登录 -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)" -// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "登录失败" -// @Router /api/v1/auth/login [post] -func (h *AuthHandler) Login(c *gin.Context) { - var req types.LoginRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - ipAddress := c.ClientIP() - userAgent := c.GetHeader("User-Agent") - - user, token, err := service.LoginUserWithRateLimit(h.container.Redis, h.container.JWT, req.Username, req.Password, ipAddress, userAgent) - if err != nil { - h.logger.Warn("用户登录失败", - zap.String("username_or_email", req.Username), - zap.String("ip", ipAddress), - zap.Error(err), - ) - RespondUnauthorized(c, err.Error()) - return - } - - RespondSuccess(c, &types.LoginResponse{ - Token: token, - UserInfo: UserToUserInfo(user), - }) -} - -// SendVerificationCode 发送验证码 -// @Summary 发送验证码 -// @Description 发送邮箱验证码(注册/重置密码/更换邮箱) -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.SendVerificationCodeRequest true "发送验证码请求" -// @Success 200 {object} model.Response "发送成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/auth/send-code [post] -func (h *AuthHandler) SendVerificationCode(c *gin.Context) { - var req types.SendVerificationCodeRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - emailService, err := h.getEmailService() - if err != nil { - RespondServerError(c, "邮件服务不可用", err) - return - } - - if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil { - h.logger.Error("发送验证码失败", - zap.String("email", req.Email), - zap.String("type", req.Type), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, gin.H{"message": "验证码已发送,请查收邮件"}) -} - -// ResetPassword 重置密码 -// @Summary 重置密码 -// @Description 通过邮箱验证码重置密码 -// @Tags auth -// @Accept json -// @Produce json -// @Param request body types.ResetPasswordRequest true "重置密码请求" -// @Success 200 {object} model.Response "重置成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/auth/reset-password [post] -func (h *AuthHandler) ResetPassword(c *gin.Context) { - var req types.ResetPasswordRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - // 验证验证码 - if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { - h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - // 重置密码 - if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil { - h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) - RespondServerError(c, err.Error(), nil) - return - } - - RespondSuccess(c, gin.H{"message": "密码重置成功"}) -} - -// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入) -func (h *AuthHandler) getEmailService() (*email.Service, error) { - return email.GetService() -} - diff --git a/internal/handler/captcha_handler.go b/internal/handler/captcha_handler.go index c7e8942..f9849d0 100644 --- a/internal/handler/captcha_handler.go +++ b/internal/handler/captcha_handler.go @@ -1,47 +1,77 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/service" - "carrotskin/pkg/redis" "net/http" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) +// CaptchaHandler 验证码处理器 +type CaptchaHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewCaptchaHandler 创建CaptchaHandler实例 +func NewCaptchaHandler(c *container.Container) *CaptchaHandler { + return &CaptchaHandler{ + container: c, + logger: c.Logger, + } +} + +// CaptchaVerifyRequest 验证码验证请求 +type CaptchaVerifyRequest struct { + CaptchaID string `json:"captchaId" binding:"required"` + Dx int `json:"dx" binding:"required"` +} + // Generate 生成验证码 -func Generate(c *gin.Context) { - // 调用验证码服务生成验证码数据 - redisClient := redis.MustGetClient() - masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), redisClient) +// @Summary 生成滑动验证码 +// @Description 生成滑动验证码图片 +// @Tags captcha +// @Accept json +// @Produce json +// @Success 200 {object} map[string]interface{} "生成成功" +// @Failure 500 {object} map[string]interface{} "生成失败" +// @Router /api/v1/captcha/generate [get] +func (h *CaptchaHandler) Generate(c *gin.Context) { + masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), h.container.Redis) if err != nil { + h.logger.Error("生成验证码失败", zap.Error(err)) c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, - "msg": "生成验证码失败: " + err.Error(), + "msg": "生成验证码失败", }) return } - // 返回验证码数据给前端 c.JSON(http.StatusOK, gin.H{ "code": 200, "data": gin.H{ - "masterImage": masterImg, // 主图(base64格式) - "tileImage": tileImg, // 滑块图(base64格式) - "captchaId": captchaID, // 验证码唯一标识(用于后续验证) - "y": y, // 滑块Y坐标(前端可用于定位滑块初始位置) + "masterImage": masterImg, + "tileImage": tileImg, + "captchaId": captchaID, + "y": y, }, }) } // Verify 验证验证码 -func Verify(c *gin.Context) { - // 定义请求参数结构体 - var req struct { - CaptchaID string `json:"captchaId" binding:"required"` // 验证码唯一标识 - Dx int `json:"dx" binding:"required"` // 用户滑动的X轴偏移量 - } - - // 解析并校验请求参数 +// @Summary 验证滑动验证码 +// @Description 验证用户滑动的偏移量是否正确 +// @Tags captcha +// @Accept json +// @Produce json +// @Param request body CaptchaVerifyRequest true "验证请求" +// @Success 200 {object} map[string]interface{} "验证结果" +// @Failure 400 {object} map[string]interface{} "参数错误" +// @Router /api/v1/captcha/verify [post] +func (h *CaptchaHandler) Verify(c *gin.Context) { + var req CaptchaVerifyRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, @@ -50,18 +80,19 @@ func Verify(c *gin.Context) { return } - // 调用验证码服务验证偏移量 - redisClient := redis.MustGetClient() - valid, err := service.VerifyCaptchaData(c.Request.Context(), redisClient, req.Dx, req.CaptchaID) + valid, err := service.VerifyCaptchaData(c.Request.Context(), h.container.Redis, req.Dx, req.CaptchaID) if err != nil { + h.logger.Error("验证码验证失败", + zap.String("captcha_id", req.CaptchaID), + zap.Error(err), + ) c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, - "msg": "验证失败: " + err.Error(), + "msg": "验证失败", }) return } - // 根据验证结果返回响应 if valid { c.JSON(http.StatusOK, gin.H{ "code": 200, @@ -74,3 +105,5 @@ func Verify(c *gin.Context) { }) } } + + diff --git a/internal/handler/captcha_handler_di.go b/internal/handler/captcha_handler_di.go deleted file mode 100644 index f9849d0..0000000 --- a/internal/handler/captcha_handler_di.go +++ /dev/null @@ -1,109 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "net/http" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// CaptchaHandler 验证码处理器 -type CaptchaHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewCaptchaHandler 创建CaptchaHandler实例 -func NewCaptchaHandler(c *container.Container) *CaptchaHandler { - return &CaptchaHandler{ - container: c, - logger: c.Logger, - } -} - -// CaptchaVerifyRequest 验证码验证请求 -type CaptchaVerifyRequest struct { - CaptchaID string `json:"captchaId" binding:"required"` - Dx int `json:"dx" binding:"required"` -} - -// Generate 生成验证码 -// @Summary 生成滑动验证码 -// @Description 生成滑动验证码图片 -// @Tags captcha -// @Accept json -// @Produce json -// @Success 200 {object} map[string]interface{} "生成成功" -// @Failure 500 {object} map[string]interface{} "生成失败" -// @Router /api/v1/captcha/generate [get] -func (h *CaptchaHandler) Generate(c *gin.Context) { - masterImg, tileImg, captchaID, y, err := service.GenerateCaptchaData(c.Request.Context(), h.container.Redis) - if err != nil { - h.logger.Error("生成验证码失败", zap.Error(err)) - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "msg": "生成验证码失败", - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "code": 200, - "data": gin.H{ - "masterImage": masterImg, - "tileImage": tileImg, - "captchaId": captchaID, - "y": y, - }, - }) -} - -// Verify 验证验证码 -// @Summary 验证滑动验证码 -// @Description 验证用户滑动的偏移量是否正确 -// @Tags captcha -// @Accept json -// @Produce json -// @Param request body CaptchaVerifyRequest true "验证请求" -// @Success 200 {object} map[string]interface{} "验证结果" -// @Failure 400 {object} map[string]interface{} "参数错误" -// @Router /api/v1/captcha/verify [post] -func (h *CaptchaHandler) Verify(c *gin.Context) { - var req CaptchaVerifyRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "msg": "参数错误: " + err.Error(), - }) - return - } - - valid, err := service.VerifyCaptchaData(c.Request.Context(), h.container.Redis, req.Dx, req.CaptchaID) - if err != nil { - h.logger.Error("验证码验证失败", - zap.String("captcha_id", req.CaptchaID), - zap.Error(err), - ) - c.JSON(http.StatusInternalServerError, gin.H{ - "code": 500, - "msg": "验证失败", - }) - return - } - - if valid { - c.JSON(http.StatusOK, gin.H{ - "code": 200, - "msg": "验证成功", - }) - } else { - c.JSON(http.StatusOK, gin.H{ - "code": 400, - "msg": "验证失败,请重试", - }) - } -} - - diff --git a/internal/handler/profile_handler.go b/internal/handler/profile_handler.go index cc0063b..daa029a 100644 --- a/internal/handler/profile_handler.go +++ b/internal/handler/profile_handler.go @@ -1,16 +1,28 @@ package handler import ( - "carrotskin/internal/service" + "carrotskin/internal/container" "carrotskin/internal/types" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" "github.com/gin-gonic/gin" "go.uber.org/zap" ) -// CreateProfile 创建档案 +// ProfileHandler 档案处理器 +type ProfileHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewProfileHandler 创建ProfileHandler实例 +func NewProfileHandler(c *container.Container) *ProfileHandler { + return &ProfileHandler{ + container: c, + logger: c.Logger, + } +} + +// Create 创建档案 // @Summary 创建Minecraft档案 // @Description 创建新的Minecraft角色档案,UUID由后端自动生成 // @Tags profile @@ -18,12 +30,10 @@ import ( // @Produce json // @Security BearerAuth // @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)" -// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功,返回完整档案信息(含自动生成的UUID)" -// @Failure 400 {object} model.ErrorResponse "请求参数错误或已达档案数量上限" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 500 {object} model.ErrorResponse "服务器错误" +// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" // @Router /api/v1/profile [post] -func CreateProfile(c *gin.Context) { +func (h *ProfileHandler) Create(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -35,17 +45,15 @@ func CreateProfile(c *gin.Context) { return } - maxProfiles := service.GetMaxProfilesPerUser() - db := database.MustGetDB() - - if err := service.CheckProfileLimit(db, userID, maxProfiles); err != nil { + maxProfiles := h.container.UserService.GetMaxProfilesPerUser() + if err := h.container.ProfileService.CheckLimit(userID, maxProfiles); err != nil { RespondBadRequest(c, err.Error(), nil) return } - profile, err := service.CreateProfile(db, userID, req.Name) + profile, err := h.container.ProfileService.Create(userID, req.Name) if err != nil { - logger.MustGetLogger().Error("创建档案失败", + h.logger.Error("创建档案失败", zap.Int64("user_id", userID), zap.String("name", req.Name), zap.Error(err), @@ -57,7 +65,7 @@ func CreateProfile(c *gin.Context) { RespondSuccess(c, ProfileToProfileInfo(profile)) } -// GetProfiles 获取档案列表 +// List 获取档案列表 // @Summary 获取档案列表 // @Description 获取当前用户的所有档案 // @Tags profile @@ -65,18 +73,16 @@ func CreateProfile(c *gin.Context) { // @Produce json // @Security BearerAuth // @Success 200 {object} model.Response "获取成功" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile [get] -func GetProfiles(c *gin.Context) { +func (h *ProfileHandler) List(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } - profiles, err := service.GetUserProfiles(database.MustGetDB(), userID) + profiles, err := h.container.ProfileService.GetByUserID(userID) if err != nil { - logger.MustGetLogger().Error("获取档案列表失败", + h.logger.Error("获取档案列表失败", zap.Int64("user_id", userID), zap.Error(err), ) @@ -87,7 +93,7 @@ func GetProfiles(c *gin.Context) { RespondSuccess(c, ProfilesToProfileInfos(profiles)) } -// GetProfile 获取档案详情 +// Get 获取档案详情 // @Summary 获取档案详情 // @Description 根据UUID获取档案详细信息 // @Tags profile @@ -96,14 +102,17 @@ func GetProfiles(c *gin.Context) { // @Param uuid path string true "档案UUID" // @Success 200 {object} model.Response "获取成功" // @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid} [get] -func GetProfile(c *gin.Context) { +func (h *ProfileHandler) Get(c *gin.Context) { uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } - profile, err := service.GetProfileByUUID(database.MustGetDB(), uuid) + profile, err := h.container.ProfileService.GetByUUID(uuid) if err != nil { - logger.MustGetLogger().Error("获取档案失败", + h.logger.Error("获取档案失败", zap.String("uuid", uuid), zap.Error(err), ) @@ -114,7 +123,7 @@ func GetProfile(c *gin.Context) { RespondSuccess(c, ProfileToProfileInfo(profile)) } -// UpdateProfile 更新档案 +// Update 更新档案 // @Summary 更新档案 // @Description 更新档案信息 // @Tags profile @@ -124,19 +133,19 @@ func GetProfile(c *gin.Context) { // @Param uuid path string true "档案UUID" // @Param request body types.UpdateProfileRequest true "更新信息" // @Success 200 {object} model.Response "更新成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "未授权" // @Failure 403 {object} model.ErrorResponse "无权操作" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid} [put] -func UpdateProfile(c *gin.Context) { +func (h *ProfileHandler) Update(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } var req types.UpdateProfileRequest if err := c.ShouldBindJSON(&req); err != nil { @@ -149,9 +158,9 @@ func UpdateProfile(c *gin.Context) { namePtr = &req.Name } - profile, err := service.UpdateProfile(database.MustGetDB(), uuid, userID, namePtr, req.SkinID, req.CapeID) + profile, err := h.container.ProfileService.Update(uuid, userID, namePtr, req.SkinID, req.CapeID) if err != nil { - logger.MustGetLogger().Error("更新档案失败", + h.logger.Error("更新档案失败", zap.String("uuid", uuid), zap.Int64("user_id", userID), zap.Error(err), @@ -163,7 +172,7 @@ func UpdateProfile(c *gin.Context) { RespondSuccess(c, ProfileToProfileInfo(profile)) } -// DeleteProfile 删除档案 +// Delete 删除档案 // @Summary 删除档案 // @Description 删除指定的Minecraft档案 // @Tags profile @@ -172,22 +181,22 @@ func UpdateProfile(c *gin.Context) { // @Security BearerAuth // @Param uuid path string true "档案UUID" // @Success 200 {object} model.Response "删除成功" -// @Failure 401 {object} model.ErrorResponse "未授权" // @Failure 403 {object} model.ErrorResponse "无权操作" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid} [delete] -func DeleteProfile(c *gin.Context) { +func (h *ProfileHandler) Delete(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } - err := service.DeleteProfile(database.MustGetDB(), uuid, userID) - if err != nil { - logger.MustGetLogger().Error("删除档案失败", + if err := h.container.ProfileService.Delete(uuid, userID); err != nil { + h.logger.Error("删除档案失败", zap.String("uuid", uuid), zap.Int64("user_id", userID), zap.Error(err), @@ -199,7 +208,7 @@ func DeleteProfile(c *gin.Context) { RespondSuccess(c, gin.H{"message": "删除成功"}) } -// SetActiveProfile 设置活跃档案 +// SetActive 设置活跃档案 // @Summary 设置活跃档案 // @Description 将指定档案设置为活跃状态 // @Tags profile @@ -208,22 +217,22 @@ func DeleteProfile(c *gin.Context) { // @Security BearerAuth // @Param uuid path string true "档案UUID" // @Success 200 {object} model.Response "设置成功" -// @Failure 401 {object} model.ErrorResponse "未授权" // @Failure 403 {object} model.ErrorResponse "无权操作" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" // @Router /api/v1/profile/{uuid}/activate [post] -func SetActiveProfile(c *gin.Context) { +func (h *ProfileHandler) SetActive(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } uuid := c.Param("uuid") + if uuid == "" { + RespondBadRequest(c, "UUID不能为空", nil) + return + } - err := service.SetActiveProfile(database.MustGetDB(), uuid, userID) - if err != nil { - logger.MustGetLogger().Error("设置活跃档案失败", + if err := h.container.ProfileService.SetActive(uuid, userID); err != nil { + h.logger.Error("设置活跃档案失败", zap.String("uuid", uuid), zap.Int64("user_id", userID), zap.Error(err), diff --git a/internal/handler/profile_handler_di.go b/internal/handler/profile_handler_di.go deleted file mode 100644 index 6fdbeb9..0000000 --- a/internal/handler/profile_handler_di.go +++ /dev/null @@ -1,247 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "carrotskin/internal/types" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// ProfileHandler 档案处理器 -type ProfileHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewProfileHandler 创建ProfileHandler实例 -func NewProfileHandler(c *container.Container) *ProfileHandler { - return &ProfileHandler{ - container: c, - logger: c.Logger, - } -} - -// Create 创建档案 -// @Summary 创建Minecraft档案 -// @Description 创建新的Minecraft角色档案,UUID由后端自动生成 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.CreateProfileRequest true "档案信息(仅需提供角色名)" -// @Success 200 {object} model.Response{data=types.ProfileInfo} "创建成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/profile [post] -func (h *ProfileHandler) Create(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.CreateProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误: "+err.Error(), nil) - return - } - - maxProfiles := service.GetMaxProfilesPerUser() - if err := service.CheckProfileLimit(h.container.DB, userID, maxProfiles); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - - profile, err := service.CreateProfile(h.container.DB, userID, req.Name) - if err != nil { - h.logger.Error("创建档案失败", - zap.Int64("user_id", userID), - zap.String("name", req.Name), - zap.Error(err), - ) - RespondServerError(c, err.Error(), nil) - return - } - - RespondSuccess(c, ProfileToProfileInfo(profile)) -} - -// List 获取档案列表 -// @Summary 获取档案列表 -// @Description 获取当前用户的所有档案 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Success 200 {object} model.Response "获取成功" -// @Router /api/v1/profile [get] -func (h *ProfileHandler) List(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - profiles, err := service.GetUserProfiles(h.container.DB, userID) - if err != nil { - h.logger.Error("获取档案列表失败", - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondServerError(c, err.Error(), nil) - return - } - - RespondSuccess(c, ProfilesToProfileInfos(profiles)) -} - -// Get 获取档案详情 -// @Summary 获取档案详情 -// @Description 根据UUID获取档案详细信息 -// @Tags profile -// @Accept json -// @Produce json -// @Param uuid path string true "档案UUID" -// @Success 200 {object} model.Response "获取成功" -// @Failure 404 {object} model.ErrorResponse "档案不存在" -// @Router /api/v1/profile/{uuid} [get] -func (h *ProfileHandler) Get(c *gin.Context) { - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - profile, err := service.GetProfileByUUID(h.container.DB, uuid) - if err != nil { - h.logger.Error("获取档案失败", - zap.String("uuid", uuid), - zap.Error(err), - ) - RespondNotFound(c, err.Error()) - return - } - - RespondSuccess(c, ProfileToProfileInfo(profile)) -} - -// Update 更新档案 -// @Summary 更新档案 -// @Description 更新档案信息 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param uuid path string true "档案UUID" -// @Param request body types.UpdateProfileRequest true "更新信息" -// @Success 200 {object} model.Response "更新成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/profile/{uuid} [put] -func (h *ProfileHandler) Update(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - var req types.UpdateProfileRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误: "+err.Error(), nil) - return - } - - var namePtr *string - if req.Name != "" { - namePtr = &req.Name - } - - profile, err := service.UpdateProfile(h.container.DB, uuid, userID, namePtr, req.SkinID, req.CapeID) - if err != nil { - h.logger.Error("更新档案失败", - zap.String("uuid", uuid), - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondWithError(c, err) - return - } - - RespondSuccess(c, ProfileToProfileInfo(profile)) -} - -// Delete 删除档案 -// @Summary 删除档案 -// @Description 删除指定的Minecraft档案 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param uuid path string true "档案UUID" -// @Success 200 {object} model.Response "删除成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/profile/{uuid} [delete] -func (h *ProfileHandler) Delete(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - if err := service.DeleteProfile(h.container.DB, uuid, userID); err != nil { - h.logger.Error("删除档案失败", - zap.String("uuid", uuid), - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondWithError(c, err) - return - } - - RespondSuccess(c, gin.H{"message": "删除成功"}) -} - -// SetActive 设置活跃档案 -// @Summary 设置活跃档案 -// @Description 将指定档案设置为活跃状态 -// @Tags profile -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param uuid path string true "档案UUID" -// @Success 200 {object} model.Response "设置成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/profile/{uuid}/activate [post] -func (h *ProfileHandler) SetActive(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - uuid := c.Param("uuid") - if uuid == "" { - RespondBadRequest(c, "UUID不能为空", nil) - return - } - - if err := service.SetActiveProfile(h.container.DB, uuid, userID); err != nil { - h.logger.Error("设置活跃档案失败", - zap.String("uuid", uuid), - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondWithError(c, err) - return - } - - RespondSuccess(c, gin.H{"message": "设置成功"}) -} - diff --git a/internal/handler/routes.go b/internal/handler/routes.go index 95cee4c..a6da9c8 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -1,142 +1,193 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/middleware" "carrotskin/internal/model" "github.com/gin-gonic/gin" ) -// RegisterRoutes 注册所有路由 -func RegisterRoutes(router *gin.Engine) { +// Handlers 集中管理所有Handler +type Handlers struct { + Auth *AuthHandler + User *UserHandler + Texture *TextureHandler + Profile *ProfileHandler + Captcha *CaptchaHandler + Yggdrasil *YggdrasilHandler +} + +// NewHandlers 创建所有Handler实例 +func NewHandlers(c *container.Container) *Handlers { + return &Handlers{ + Auth: NewAuthHandler(c), + User: NewUserHandler(c), + Texture: NewTextureHandler(c), + Profile: NewProfileHandler(c), + Captcha: NewCaptchaHandler(c), + Yggdrasil: NewYggdrasilHandler(c), + } +} + +// RegisterRoutesWithDI 使用依赖注入注册所有路由 +func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { // 设置Swagger文档 SetupSwagger(router) + // 创建Handler实例 + h := NewHandlers(c) + // API路由组 v1 := router.Group("/api/v1") { // 认证路由(无需JWT) - authGroup := v1.Group("/auth") - { - authGroup.POST("/register", Register) - authGroup.POST("/login", Login) - authGroup.POST("/send-code", SendVerificationCode) - authGroup.POST("/reset-password", ResetPassword) - } + registerAuthRoutes(v1, h.Auth) // 用户路由(需要JWT认证) - userGroup := v1.Group("/user") - userGroup.Use(middleware.AuthMiddleware()) - { - userGroup.GET("/profile", GetUserProfile) - userGroup.PUT("/profile", UpdateUserProfile) - - // 头像相关 - userGroup.POST("/avatar/upload-url", GenerateAvatarUploadURL) - userGroup.PUT("/avatar", UpdateAvatar) - - // 更换邮箱 - userGroup.POST("/change-email", ChangeEmail) - - // Yggdrasil密码相关 - userGroup.POST("/yggdrasil-password/reset", ResetYggdrasilPassword) // 重置Yggdrasil密码并返回新密码 - } + registerUserRoutes(v1, h.User) // 材质路由 - textureGroup := v1.Group("/texture") - { - // 公开路由(无需认证) - textureGroup.GET("", SearchTextures) // 搜索材质 - textureGroup.GET("/:id", GetTexture) // 获取材质详情 - - // 需要认证的路由 - textureAuth := textureGroup.Group("") - textureAuth.Use(middleware.AuthMiddleware()) - { - textureAuth.POST("/upload-url", GenerateTextureUploadURL) // 生成上传URL - textureAuth.POST("", CreateTexture) // 创建材质记录 - textureAuth.PUT("/:id", UpdateTexture) // 更新材质 - textureAuth.DELETE("/:id", DeleteTexture) // 删除材质 - textureAuth.POST("/:id/favorite", ToggleFavorite) // 切换收藏 - textureAuth.GET("/my", GetUserTextures) // 我的材质 - textureAuth.GET("/favorites", GetUserFavorites) // 我的收藏 - } - } + registerTextureRoutes(v1, h.Texture) // 档案路由 - profileGroup := v1.Group("/profile") - { - // 公开路由(无需认证) - profileGroup.GET("/:uuid", GetProfile) // 获取档案详情 + registerProfileRoutesWithDI(v1, h.Profile) - // 需要认证的路由 - profileAuth := profileGroup.Group("") - profileAuth.Use(middleware.AuthMiddleware()) - { - profileAuth.POST("/", CreateProfile) // 创建档案 - profileAuth.GET("/", GetProfiles) // 获取我的档案列表 - profileAuth.PUT("/:uuid", UpdateProfile) // 更新档案 - profileAuth.DELETE("/:uuid", DeleteProfile) // 删除档案 - profileAuth.POST("/:uuid/activate", SetActiveProfile) // 设置活跃档案 - } - } // 验证码路由 - captchaGroup := v1.Group("/captcha") - { - captchaGroup.GET("/generate", Generate) //生成验证码 - captchaGroup.POST("/verify", Verify) //验证验证码 - } + registerCaptchaRoutesWithDI(v1, h.Captcha) // Yggdrasil API路由组 - ygg := v1.Group("/yggdrasil") - { - ygg.GET("", GetMetaData) - ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates) - authserver := ygg.Group("/authserver") - { - authserver.POST("/authenticate", Authenticate) - authserver.POST("/validate", ValidToken) - authserver.POST("/refresh", RefreshToken) - authserver.POST("/invalidate", InvalidToken) - authserver.POST("/signout", SignOut) - } - sessionServer := ygg.Group("/sessionserver") - { - sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID) - sessionServer.POST("/session/minecraft/join", JoinServer) - sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer) - } - api := ygg.Group("/api") - profiles := api.Group("/profiles") - { - profiles.POST("/minecraft", GetProfilesByName) - } - } + registerYggdrasilRoutesWithDI(v1, h.Yggdrasil) + // 系统路由 - system := v1.Group("/system") + registerSystemRoutes(v1) + } +} + +// registerAuthRoutes 注册认证路由 +func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { + authGroup := v1.Group("/auth") + { + authGroup.POST("/register", h.Register) + authGroup.POST("/login", h.Login) + authGroup.POST("/send-code", h.SendVerificationCode) + authGroup.POST("/reset-password", h.ResetPassword) + } +} + +// registerUserRoutes 注册用户路由 +func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { + userGroup := v1.Group("/user") + userGroup.Use(middleware.AuthMiddleware()) + { + userGroup.GET("/profile", h.GetProfile) + userGroup.PUT("/profile", h.UpdateProfile) + + // 头像相关 + userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL) + userGroup.PUT("/avatar", h.UpdateAvatar) + + // 更换邮箱 + userGroup.POST("/change-email", h.ChangeEmail) + + // Yggdrasil密码相关 + userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword) + } +} + +// registerTextureRoutes 注册材质路由 +func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { + textureGroup := v1.Group("/texture") + { + // 公开路由(无需认证) + textureGroup.GET("", h.Search) + textureGroup.GET("/:id", h.Get) + + // 需要认证的路由 + textureAuth := textureGroup.Group("") + textureAuth.Use(middleware.AuthMiddleware()) { - system.GET("/config", GetSystemConfig) + textureAuth.POST("/upload-url", h.GenerateUploadURL) + textureAuth.POST("", h.Create) + textureAuth.PUT("/:id", h.Update) + textureAuth.DELETE("/:id", h.Delete) + textureAuth.POST("/:id/favorite", h.ToggleFavorite) + textureAuth.GET("/my", h.GetUserTextures) + textureAuth.GET("/favorites", h.GetUserFavorites) } } } -// 以下是系统配置相关的占位符函数,待后续实现 +// registerProfileRoutesWithDI 注册档案路由(依赖注入版本) +func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { + profileGroup := v1.Group("/profile") + { + // 公开路由(无需认证) + profileGroup.GET("/:uuid", h.Get) -// GetSystemConfig 获取系统配置 -// @Summary 获取系统配置 -// @Description 获取公开的系统配置信息 -// @Tags system -// @Accept json -// @Produce json -// @Success 200 {object} model.Response "获取成功" -// @Router /api/v1/system/config [get] -func GetSystemConfig(c *gin.Context) { - // TODO: 实现从数据库读取系统配置 - c.JSON(200, model.NewSuccessResponse(gin.H{ - "site_name": "CarrotSkin", - "site_description": "A Minecraft Skin Station", - "registration_enabled": true, - "max_textures_per_user": 100, - "max_profiles_per_user": 5, - })) + // 需要认证的路由 + profileAuth := profileGroup.Group("") + profileAuth.Use(middleware.AuthMiddleware()) + { + profileAuth.POST("/", h.Create) + profileAuth.GET("/", h.List) + profileAuth.PUT("/:uuid", h.Update) + profileAuth.DELETE("/:uuid", h.Delete) + profileAuth.POST("/:uuid/activate", h.SetActive) + } + } +} + +// registerCaptchaRoutesWithDI 注册验证码路由(依赖注入版本) +func registerCaptchaRoutesWithDI(v1 *gin.RouterGroup, h *CaptchaHandler) { + captchaGroup := v1.Group("/captcha") + { + captchaGroup.GET("/generate", h.Generate) + captchaGroup.POST("/verify", h.Verify) + } +} + +// registerYggdrasilRoutesWithDI 注册Yggdrasil API路由(依赖注入版本) +func registerYggdrasilRoutesWithDI(v1 *gin.RouterGroup, h *YggdrasilHandler) { + ygg := v1.Group("/yggdrasil") + { + ygg.GET("", h.GetMetaData) + ygg.POST("/minecraftservices/player/certificates", h.GetPlayerCertificates) + authserver := ygg.Group("/authserver") + { + authserver.POST("/authenticate", h.Authenticate) + authserver.POST("/validate", h.ValidToken) + authserver.POST("/refresh", h.RefreshToken) + authserver.POST("/invalidate", h.InvalidToken) + authserver.POST("/signout", h.SignOut) + } + sessionServer := ygg.Group("/sessionserver") + { + sessionServer.GET("/session/minecraft/profile/:uuid", h.GetProfileByUUID) + sessionServer.POST("/session/minecraft/join", h.JoinServer) + sessionServer.GET("/session/minecraft/hasJoined", h.HasJoinedServer) + } + api := ygg.Group("/api") + profiles := api.Group("/profiles") + { + profiles.POST("/minecraft", h.GetProfilesByName) + } + } +} + +// registerSystemRoutes 注册系统路由 +func registerSystemRoutes(v1 *gin.RouterGroup) { + system := v1.Group("/system") + { + system.GET("/config", func(c *gin.Context) { + // TODO: 实现从数据库读取系统配置 + c.JSON(200, model.NewSuccessResponse(gin.H{ + "site_name": "CarrotSkin", + "site_description": "A Minecraft Skin Station", + "registration_enabled": true, + "max_textures_per_user": 100, + "max_profiles_per_user": 5, + })) + }) + } } diff --git a/internal/handler/routes_di.go b/internal/handler/routes_di.go deleted file mode 100644 index a6da9c8..0000000 --- a/internal/handler/routes_di.go +++ /dev/null @@ -1,193 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/middleware" - "carrotskin/internal/model" - - "github.com/gin-gonic/gin" -) - -// Handlers 集中管理所有Handler -type Handlers struct { - Auth *AuthHandler - User *UserHandler - Texture *TextureHandler - Profile *ProfileHandler - Captcha *CaptchaHandler - Yggdrasil *YggdrasilHandler -} - -// NewHandlers 创建所有Handler实例 -func NewHandlers(c *container.Container) *Handlers { - return &Handlers{ - Auth: NewAuthHandler(c), - User: NewUserHandler(c), - Texture: NewTextureHandler(c), - Profile: NewProfileHandler(c), - Captcha: NewCaptchaHandler(c), - Yggdrasil: NewYggdrasilHandler(c), - } -} - -// RegisterRoutesWithDI 使用依赖注入注册所有路由 -func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { - // 设置Swagger文档 - SetupSwagger(router) - - // 创建Handler实例 - h := NewHandlers(c) - - // API路由组 - v1 := router.Group("/api/v1") - { - // 认证路由(无需JWT) - registerAuthRoutes(v1, h.Auth) - - // 用户路由(需要JWT认证) - registerUserRoutes(v1, h.User) - - // 材质路由 - registerTextureRoutes(v1, h.Texture) - - // 档案路由 - registerProfileRoutesWithDI(v1, h.Profile) - - // 验证码路由 - registerCaptchaRoutesWithDI(v1, h.Captcha) - - // Yggdrasil API路由组 - registerYggdrasilRoutesWithDI(v1, h.Yggdrasil) - - // 系统路由 - registerSystemRoutes(v1) - } -} - -// registerAuthRoutes 注册认证路由 -func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { - authGroup := v1.Group("/auth") - { - authGroup.POST("/register", h.Register) - authGroup.POST("/login", h.Login) - authGroup.POST("/send-code", h.SendVerificationCode) - authGroup.POST("/reset-password", h.ResetPassword) - } -} - -// registerUserRoutes 注册用户路由 -func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { - userGroup := v1.Group("/user") - userGroup.Use(middleware.AuthMiddleware()) - { - userGroup.GET("/profile", h.GetProfile) - userGroup.PUT("/profile", h.UpdateProfile) - - // 头像相关 - userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL) - userGroup.PUT("/avatar", h.UpdateAvatar) - - // 更换邮箱 - userGroup.POST("/change-email", h.ChangeEmail) - - // Yggdrasil密码相关 - userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword) - } -} - -// registerTextureRoutes 注册材质路由 -func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { - textureGroup := v1.Group("/texture") - { - // 公开路由(无需认证) - textureGroup.GET("", h.Search) - textureGroup.GET("/:id", h.Get) - - // 需要认证的路由 - textureAuth := textureGroup.Group("") - textureAuth.Use(middleware.AuthMiddleware()) - { - textureAuth.POST("/upload-url", h.GenerateUploadURL) - textureAuth.POST("", h.Create) - textureAuth.PUT("/:id", h.Update) - textureAuth.DELETE("/:id", h.Delete) - textureAuth.POST("/:id/favorite", h.ToggleFavorite) - textureAuth.GET("/my", h.GetUserTextures) - textureAuth.GET("/favorites", h.GetUserFavorites) - } - } -} - -// registerProfileRoutesWithDI 注册档案路由(依赖注入版本) -func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { - profileGroup := v1.Group("/profile") - { - // 公开路由(无需认证) - profileGroup.GET("/:uuid", h.Get) - - // 需要认证的路由 - profileAuth := profileGroup.Group("") - profileAuth.Use(middleware.AuthMiddleware()) - { - profileAuth.POST("/", h.Create) - profileAuth.GET("/", h.List) - profileAuth.PUT("/:uuid", h.Update) - profileAuth.DELETE("/:uuid", h.Delete) - profileAuth.POST("/:uuid/activate", h.SetActive) - } - } -} - -// registerCaptchaRoutesWithDI 注册验证码路由(依赖注入版本) -func registerCaptchaRoutesWithDI(v1 *gin.RouterGroup, h *CaptchaHandler) { - captchaGroup := v1.Group("/captcha") - { - captchaGroup.GET("/generate", h.Generate) - captchaGroup.POST("/verify", h.Verify) - } -} - -// registerYggdrasilRoutesWithDI 注册Yggdrasil API路由(依赖注入版本) -func registerYggdrasilRoutesWithDI(v1 *gin.RouterGroup, h *YggdrasilHandler) { - ygg := v1.Group("/yggdrasil") - { - ygg.GET("", h.GetMetaData) - ygg.POST("/minecraftservices/player/certificates", h.GetPlayerCertificates) - authserver := ygg.Group("/authserver") - { - authserver.POST("/authenticate", h.Authenticate) - authserver.POST("/validate", h.ValidToken) - authserver.POST("/refresh", h.RefreshToken) - authserver.POST("/invalidate", h.InvalidToken) - authserver.POST("/signout", h.SignOut) - } - sessionServer := ygg.Group("/sessionserver") - { - sessionServer.GET("/session/minecraft/profile/:uuid", h.GetProfileByUUID) - sessionServer.POST("/session/minecraft/join", h.JoinServer) - sessionServer.GET("/session/minecraft/hasJoined", h.HasJoinedServer) - } - api := ygg.Group("/api") - profiles := api.Group("/profiles") - { - profiles.POST("/minecraft", h.GetProfilesByName) - } - } -} - -// registerSystemRoutes 注册系统路由 -func registerSystemRoutes(v1 *gin.RouterGroup) { - system := v1.Group("/system") - { - system.GET("/config", func(c *gin.Context) { - // TODO: 实现从数据库读取系统配置 - c.JSON(200, model.NewSuccessResponse(gin.H{ - "site_name": "CarrotSkin", - "site_description": "A Minecraft Skin Station", - "registration_enabled": true, - "max_textures_per_user": 100, - "max_profiles_per_user": 5, - })) - }) - } -} diff --git a/internal/handler/texture_handler.go b/internal/handler/texture_handler.go index a139f38..909e287 100644 --- a/internal/handler/texture_handler.go +++ b/internal/handler/texture_handler.go @@ -1,30 +1,32 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/model" "carrotskin/internal/service" "carrotskin/internal/types" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" - "carrotskin/pkg/storage" "strconv" "github.com/gin-gonic/gin" "go.uber.org/zap" ) -// GenerateTextureUploadURL 生成材质上传URL -// @Summary 生成材质上传URL -// @Description 生成预签名URL用于上传材质文件 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.GenerateTextureUploadURLRequest true "上传URL请求" -// @Success 200 {object} model.Response "生成成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/texture/upload-url [post] -func GenerateTextureUploadURL(c *gin.Context) { +// TextureHandler 材质处理器(依赖注入版本) +type TextureHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewTextureHandler 创建TextureHandler实例 +func NewTextureHandler(c *container.Container) *TextureHandler { + return &TextureHandler{ + container: c, + logger: c.Logger, + } +} + +// GenerateUploadURL 生成材质上传URL +func (h *TextureHandler) GenerateUploadURL(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -36,16 +38,20 @@ func GenerateTextureUploadURL(c *gin.Context) { return } - storageClient := storage.MustGetClient() + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + result, err := service.GenerateTextureUploadURL( c.Request.Context(), - storageClient, + h.container.Storage, userID, req.FileName, string(req.TextureType), ) if err != nil { - logger.MustGetLogger().Error("生成材质上传URL失败", + h.logger.Error("生成材质上传URL失败", zap.Int64("user_id", userID), zap.String("file_name", req.FileName), zap.String("texture_type", string(req.TextureType)), @@ -63,18 +69,8 @@ func GenerateTextureUploadURL(c *gin.Context) { }) } -// CreateTexture 创建材质记录 -// @Summary 创建材质记录 -// @Description 文件上传完成后,创建材质记录到数据库 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.CreateTextureRequest true "创建材质请求" -// @Success 200 {object} model.Response "创建成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/texture [post] -func CreateTexture(c *gin.Context) { +// Create 创建材质记录 +func (h *TextureHandler) Create(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -86,13 +82,13 @@ func CreateTexture(c *gin.Context) { return } - maxTextures := service.GetMaxTexturesPerUser() - if err := service.CheckTextureUploadLimit(database.MustGetDB(), userID, maxTextures); err != nil { + maxTextures := h.container.UserService.GetMaxTexturesPerUser() + if err := h.container.TextureService.CheckUploadLimit(userID, maxTextures); err != nil { RespondBadRequest(c, err.Error(), nil) return } - texture, err := service.CreateTexture(database.MustGetDB(), + texture, err := h.container.TextureService.Create( userID, req.Name, req.Description, @@ -104,7 +100,7 @@ func CreateTexture(c *gin.Context) { req.IsSlim, ) if err != nil { - logger.MustGetLogger().Error("创建材质失败", + h.logger.Error("创建材质失败", zap.Int64("user_id", userID), zap.String("name", req.Name), zap.Error(err), @@ -116,24 +112,15 @@ func CreateTexture(c *gin.Context) { RespondSuccess(c, TextureToTextureInfo(texture)) } -// GetTexture 获取材质详情 -// @Summary 获取材质详情 -// @Description 根据ID获取材质详细信息 -// @Tags texture -// @Accept json -// @Produce json -// @Param id path int true "材质ID" -// @Success 200 {object} model.Response "获取成功" -// @Failure 404 {object} model.ErrorResponse "材质不存在" -// @Router /api/v1/texture/{id} [get] -func GetTexture(c *gin.Context) { +// Get 获取材质详情 +func (h *TextureHandler) Get(c *gin.Context) { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { RespondBadRequest(c, "无效的材质ID", err) return } - texture, err := service.GetTextureByID(database.MustGetDB(), id) + texture, err := h.container.TextureService.GetByID(id) if err != nil { RespondNotFound(c, err.Error()) return @@ -142,20 +129,8 @@ func GetTexture(c *gin.Context) { RespondSuccess(c, TextureToTextureInfo(texture)) } -// SearchTextures 搜索材质 -// @Summary 搜索材质 -// @Description 根据关键词和类型搜索材质 -// @Tags texture -// @Accept json -// @Produce json -// @Param keyword query string false "关键词" -// @Param type query string false "材质类型(SKIN/CAPE)" -// @Param public_only query bool false "只看公开材质" -// @Param page query int false "页码" default(1) -// @Param page_size query int false "每页数量" default(20) -// @Success 200 {object} model.PaginationResponse "搜索成功" -// @Router /api/v1/texture [get] -func SearchTextures(c *gin.Context) { +// Search 搜索材质 +func (h *TextureHandler) Search(c *gin.Context) { keyword := c.Query("keyword") textureTypeStr := c.Query("type") publicOnly := c.Query("public_only") == "true" @@ -171,9 +146,9 @@ func SearchTextures(c *gin.Context) { textureType = model.TextureTypeCape } - textures, total, err := service.SearchTextures(database.MustGetDB(), keyword, textureType, publicOnly, page, pageSize) + textures, total, err := h.container.TextureService.Search(keyword, textureType, publicOnly, page, pageSize) if err != nil { - logger.MustGetLogger().Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) + h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) RespondServerError(c, "搜索材质失败", err) return } @@ -181,19 +156,8 @@ func SearchTextures(c *gin.Context) { c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) } -// UpdateTexture 更新材质 -// @Summary 更新材质 -// @Description 更新材质信息(仅上传者可操作) -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param id path int true "材质ID" -// @Param request body types.UpdateTextureRequest true "更新材质请求" -// @Success 200 {object} model.Response "更新成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/texture/{id} [put] -func UpdateTexture(c *gin.Context) { +// Update 更新材质 +func (h *TextureHandler) Update(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -211,9 +175,9 @@ func UpdateTexture(c *gin.Context) { return } - texture, err := service.UpdateTexture(database.MustGetDB(), textureID, userID, req.Name, req.Description, req.IsPublic) + texture, err := h.container.TextureService.Update(textureID, userID, req.Name, req.Description, req.IsPublic) if err != nil { - logger.MustGetLogger().Error("更新材质失败", + h.logger.Error("更新材质失败", zap.Int64("user_id", userID), zap.Int64("texture_id", textureID), zap.Error(err), @@ -225,18 +189,8 @@ func UpdateTexture(c *gin.Context) { RespondSuccess(c, TextureToTextureInfo(texture)) } -// DeleteTexture 删除材质 -// @Summary 删除材质 -// @Description 删除材质(软删除,仅上传者可操作) -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param id path int true "材质ID" -// @Success 200 {object} model.Response "删除成功" -// @Failure 403 {object} model.ErrorResponse "无权操作" -// @Router /api/v1/texture/{id} [delete] -func DeleteTexture(c *gin.Context) { +// Delete 删除材质 +func (h *TextureHandler) Delete(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -248,8 +202,8 @@ func DeleteTexture(c *gin.Context) { return } - if err := service.DeleteTexture(database.MustGetDB(), textureID, userID); err != nil { - logger.MustGetLogger().Error("删除材质失败", + if err := h.container.TextureService.Delete(textureID, userID); err != nil { + h.logger.Error("删除材质失败", zap.Int64("user_id", userID), zap.Int64("texture_id", textureID), zap.Error(err), @@ -262,16 +216,7 @@ func DeleteTexture(c *gin.Context) { } // ToggleFavorite 切换收藏状态 -// @Summary 切换收藏状态 -// @Description 收藏或取消收藏材质 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param id path int true "材质ID" -// @Success 200 {object} model.Response "切换成功" -// @Router /api/v1/texture/{id}/favorite [post] -func ToggleFavorite(c *gin.Context) { +func (h *TextureHandler) ToggleFavorite(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -283,9 +228,9 @@ func ToggleFavorite(c *gin.Context) { return } - isFavorited, err := service.ToggleTextureFavorite(database.MustGetDB(), userID, textureID) + isFavorited, err := h.container.TextureService.ToggleFavorite(userID, textureID) if err != nil { - logger.MustGetLogger().Error("切换收藏状态失败", + h.logger.Error("切换收藏状态失败", zap.Int64("user_id", userID), zap.Int64("texture_id", textureID), zap.Error(err), @@ -298,17 +243,7 @@ func ToggleFavorite(c *gin.Context) { } // GetUserTextures 获取用户上传的材质列表 -// @Summary 获取用户上传的材质列表 -// @Description 获取当前用户上传的所有材质 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param page query int false "页码" default(1) -// @Param page_size query int false "每页数量" default(20) -// @Success 200 {object} model.PaginationResponse "获取成功" -// @Router /api/v1/texture/my [get] -func GetUserTextures(c *gin.Context) { +func (h *TextureHandler) GetUserTextures(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -317,9 +252,9 @@ func GetUserTextures(c *gin.Context) { page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - textures, total, err := service.GetUserTextures(database.MustGetDB(), userID, page, pageSize) + textures, total, err := h.container.TextureService.GetByUserID(userID, page, pageSize) if err != nil { - logger.MustGetLogger().Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) + h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) RespondServerError(c, "获取材质列表失败", err) return } @@ -328,17 +263,7 @@ func GetUserTextures(c *gin.Context) { } // GetUserFavorites 获取用户收藏的材质列表 -// @Summary 获取用户收藏的材质列表 -// @Description 获取当前用户收藏的所有材质 -// @Tags texture -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param page query int false "页码" default(1) -// @Param page_size query int false "每页数量" default(20) -// @Success 200 {object} model.PaginationResponse "获取成功" -// @Router /api/v1/texture/favorites [get] -func GetUserFavorites(c *gin.Context) { +func (h *TextureHandler) GetUserFavorites(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -347,9 +272,9 @@ func GetUserFavorites(c *gin.Context) { page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID, page, pageSize) + textures, total, err := h.container.TextureService.GetUserFavorites(userID, page, pageSize) if err != nil { - logger.MustGetLogger().Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) + h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) RespondServerError(c, "获取收藏列表失败", err) return } diff --git a/internal/handler/texture_handler_di.go b/internal/handler/texture_handler_di.go deleted file mode 100644 index 26bd558..0000000 --- a/internal/handler/texture_handler_di.go +++ /dev/null @@ -1,285 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/model" - "carrotskin/internal/service" - "carrotskin/internal/types" - "strconv" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// TextureHandler 材质处理器(依赖注入版本) -type TextureHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewTextureHandler 创建TextureHandler实例 -func NewTextureHandler(c *container.Container) *TextureHandler { - return &TextureHandler{ - container: c, - logger: c.Logger, - } -} - -// GenerateUploadURL 生成材质上传URL -func (h *TextureHandler) GenerateUploadURL(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.GenerateTextureUploadURLRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - if h.container.Storage == nil { - RespondServerError(c, "存储服务不可用", nil) - return - } - - result, err := service.GenerateTextureUploadURL( - c.Request.Context(), - h.container.Storage, - userID, - req.FileName, - string(req.TextureType), - ) - if err != nil { - h.logger.Error("生成材质上传URL失败", - zap.Int64("user_id", userID), - zap.String("file_name", req.FileName), - zap.String("texture_type", string(req.TextureType)), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, &types.GenerateTextureUploadURLResponse{ - PostURL: result.PostURL, - FormData: result.FormData, - TextureURL: result.FileURL, - ExpiresIn: 900, - }) -} - -// Create 创建材质记录 -func (h *TextureHandler) Create(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.CreateTextureRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - maxTextures := service.GetMaxTexturesPerUser() - if err := service.CheckTextureUploadLimit(h.container.DB, userID, maxTextures); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - - texture, err := service.CreateTexture(h.container.DB, - userID, - req.Name, - req.Description, - string(req.Type), - req.URL, - req.Hash, - req.Size, - req.IsPublic, - req.IsSlim, - ) - if err != nil { - h.logger.Error("创建材质失败", - zap.Int64("user_id", userID), - zap.String("name", req.Name), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, TextureToTextureInfo(texture)) -} - -// Get 获取材质详情 -func (h *TextureHandler) Get(c *gin.Context) { - id, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - texture, err := service.GetTextureByID(h.container.DB, id) - if err != nil { - RespondNotFound(c, err.Error()) - return - } - - RespondSuccess(c, TextureToTextureInfo(texture)) -} - -// Search 搜索材质 -func (h *TextureHandler) Search(c *gin.Context) { - keyword := c.Query("keyword") - textureTypeStr := c.Query("type") - publicOnly := c.Query("public_only") == "true" - - page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) - pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - - var textureType model.TextureType - switch textureTypeStr { - case "SKIN": - textureType = model.TextureTypeSkin - case "CAPE": - textureType = model.TextureTypeCape - } - - textures, total, err := service.SearchTextures(h.container.DB, keyword, textureType, publicOnly, page, pageSize) - if err != nil { - h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) - RespondServerError(c, "搜索材质失败", err) - return - } - - c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) -} - -// Update 更新材质 -func (h *TextureHandler) Update(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - var req types.UpdateTextureRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - texture, err := service.UpdateTexture(h.container.DB, textureID, userID, req.Name, req.Description, req.IsPublic) - if err != nil { - h.logger.Error("更新材质失败", - zap.Int64("user_id", userID), - zap.Int64("texture_id", textureID), - zap.Error(err), - ) - RespondForbidden(c, err.Error()) - return - } - - RespondSuccess(c, TextureToTextureInfo(texture)) -} - -// Delete 删除材质 -func (h *TextureHandler) Delete(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - if err := service.DeleteTexture(h.container.DB, textureID, userID); err != nil { - h.logger.Error("删除材质失败", - zap.Int64("user_id", userID), - zap.Int64("texture_id", textureID), - zap.Error(err), - ) - RespondForbidden(c, err.Error()) - return - } - - RespondSuccess(c, nil) -} - -// ToggleFavorite 切换收藏状态 -func (h *TextureHandler) ToggleFavorite(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) - if err != nil { - RespondBadRequest(c, "无效的材质ID", err) - return - } - - isFavorited, err := service.ToggleTextureFavorite(h.container.DB, userID, textureID) - if err != nil { - h.logger.Error("切换收藏状态失败", - zap.Int64("user_id", userID), - zap.Int64("texture_id", textureID), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, map[string]bool{"is_favorited": isFavorited}) -} - -// GetUserTextures 获取用户上传的材质列表 -func (h *TextureHandler) GetUserTextures(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) - pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - - textures, total, err := service.GetUserTextures(h.container.DB, userID, page, pageSize) - if err != nil { - h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) - RespondServerError(c, "获取材质列表失败", err) - return - } - - c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) -} - -// GetUserFavorites 获取用户收藏的材质列表 -func (h *TextureHandler) GetUserFavorites(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) - pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) - - textures, total, err := service.GetUserTextureFavorites(h.container.DB, userID, page, pageSize) - if err != nil { - h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) - RespondServerError(c, "获取收藏列表失败", err) - return - } - - c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) -} - - diff --git a/internal/handler/user_handler.go b/internal/handler/user_handler.go index c6144a4..406596b 100644 --- a/internal/handler/user_handler.go +++ b/internal/handler/user_handler.go @@ -1,36 +1,38 @@ package handler import ( + "carrotskin/internal/container" "carrotskin/internal/service" "carrotskin/internal/types" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" - "carrotskin/pkg/storage" "github.com/gin-gonic/gin" "go.uber.org/zap" ) -// GetUserProfile 获取用户信息 -// @Summary 获取用户信息 -// @Description 获取当前登录用户的详细信息 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Success 200 {object} model.Response "获取成功" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Router /api/v1/user/profile [get] -func GetUserProfile(c *gin.Context) { +// UserHandler 用户处理器(依赖注入版本) +type UserHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewUserHandler 创建UserHandler实例 +func NewUserHandler(c *container.Container) *UserHandler { + return &UserHandler{ + container: c, + logger: c.Logger, + } +} + +// GetProfile 获取用户信息 +func (h *UserHandler) GetProfile(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { - logger.MustGetLogger().Error("获取用户信息失败", + h.logger.Error("获取用户信息失败", zap.Int64("user_id", userID), zap.Error(err), ) @@ -41,22 +43,8 @@ func GetUserProfile(c *gin.Context) { RespondSuccess(c, UserToUserInfo(user)) } -// UpdateUserProfile 更新用户信息 -// @Summary 更新用户信息 -// @Description 更新当前登录用户的头像和密码(修改邮箱请使用 /change-email 接口) -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.UpdateUserRequest true "更新信息(修改密码时需同时提供old_password和new_password)" -// @Success 200 {object} model.Response{data=types.UserInfo} "更新成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 404 {object} model.ErrorResponse "用户不存在" -// @Failure 500 {object} model.ErrorResponse "服务器错误" -// @Router /api/v1/user/profile [put] -func UpdateUserProfile(c *gin.Context) { - loggerInstance := logger.MustGetLogger() +// UpdateProfile 更新用户信息 +func (h *UserHandler) UpdateProfile(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -68,7 +56,7 @@ func UpdateUserProfile(c *gin.Context) { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { RespondNotFound(c, "用户不存在") return @@ -81,32 +69,31 @@ func UpdateUserProfile(c *gin.Context) { return } - if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil { - loggerInstance.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) + if err := h.container.UserService.ChangePassword(userID, req.OldPassword, req.NewPassword); err != nil { + h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } - loggerInstance.Info("用户修改密码成功", zap.Int64("user_id", userID)) + h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID)) } // 更新头像 if req.Avatar != "" { - // 验证头像 URL 是否来自允许的域名 - if err := service.ValidateAvatarURL(req.Avatar); err != nil { + if err := h.container.UserService.ValidateAvatarURL(req.Avatar); err != nil { RespondBadRequest(c, err.Error(), nil) return } user.Avatar = req.Avatar - if err := service.UpdateUserInfo(user); err != nil { - loggerInstance.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) + if err := h.container.UserService.UpdateInfo(user); err != nil { + h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) RespondServerError(c, "更新失败", err) return } } // 重新获取更新后的用户信息 - updatedUser, err := service.GetUserByID(userID) + updatedUser, err := h.container.UserService.GetByID(userID) if err != nil || updatedUser == nil { RespondNotFound(c, "用户不存在") return @@ -116,17 +103,7 @@ func UpdateUserProfile(c *gin.Context) { } // GenerateAvatarUploadURL 生成头像上传URL -// @Summary 生成头像上传URL -// @Description 生成预签名URL用于上传用户头像 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.GenerateAvatarUploadURLRequest true "文件名" -// @Success 200 {object} model.Response "生成成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/user/avatar/upload-url [post] -func GenerateAvatarUploadURL(c *gin.Context) { +func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -138,10 +115,14 @@ func GenerateAvatarUploadURL(c *gin.Context) { return } - storageClient := storage.MustGetClient() - result, err := service.GenerateAvatarUploadURL(c.Request.Context(), storageClient, userID, req.FileName) + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + + result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName) if err != nil { - logger.MustGetLogger().Error("生成头像上传URL失败", + h.logger.Error("生成头像上传URL失败", zap.Int64("user_id", userID), zap.String("file_name", req.FileName), zap.Error(err), @@ -159,17 +140,7 @@ func GenerateAvatarUploadURL(c *gin.Context) { } // UpdateAvatar 更新头像URL -// @Summary 更新头像URL -// @Description 上传完成后更新用户的头像URL到数据库 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param avatar_url query string true "头像URL" -// @Success 200 {object} model.Response "更新成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Router /api/v1/user/avatar [put] -func UpdateAvatar(c *gin.Context) { +func (h *UserHandler) UpdateAvatar(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -181,13 +152,13 @@ func UpdateAvatar(c *gin.Context) { return } - if err := service.ValidateAvatarURL(avatarURL); err != nil { + if err := h.container.UserService.ValidateAvatarURL(avatarURL); err != nil { RespondBadRequest(c, err.Error(), nil) return } - if err := service.UpdateUserAvatar(userID, avatarURL); err != nil { - logger.MustGetLogger().Error("更新头像失败", + if err := h.container.UserService.UpdateAvatar(userID, avatarURL); err != nil { + h.logger.Error("更新头像失败", zap.Int64("user_id", userID), zap.String("avatar_url", avatarURL), zap.Error(err), @@ -196,7 +167,7 @@ func UpdateAvatar(c *gin.Context) { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { RespondNotFound(c, "用户不存在") return @@ -206,19 +177,7 @@ func UpdateAvatar(c *gin.Context) { } // ChangeEmail 更换邮箱 -// @Summary 更换邮箱 -// @Description 通过验证码更换用户邮箱 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Param request body types.ChangeEmailRequest true "更换邮箱请求" -// @Success 200 {object} model.Response{data=types.UserInfo} "更换成功" -// @Failure 400 {object} model.ErrorResponse "请求参数错误" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Router /api/v1/user/change-email [post] -func ChangeEmail(c *gin.Context) { - loggerInstance := logger.MustGetLogger() +func (h *UserHandler) ChangeEmail(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return @@ -230,15 +189,14 @@ func ChangeEmail(c *gin.Context) { return } - redisClient := redis.MustGetClient() - if err := service.VerifyCode(c.Request.Context(), redisClient, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { - loggerInstance.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { + h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } - if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil { - loggerInstance.Error("更换邮箱失败", + if err := h.container.UserService.ChangeEmail(userID, req.NewEmail); err != nil { + h.logger.Error("更换邮箱失败", zap.Int64("user_id", userID), zap.String("new_email", req.NewEmail), zap.Error(err), @@ -247,7 +205,7 @@ func ChangeEmail(c *gin.Context) { return } - user, err := service.GetUserByID(userID) + user, err := h.container.UserService.GetByID(userID) if err != nil || user == nil { RespondNotFound(c, "用户不存在") return @@ -257,31 +215,19 @@ func ChangeEmail(c *gin.Context) { } // ResetYggdrasilPassword 重置Yggdrasil密码 -// @Summary 重置Yggdrasil密码 -// @Description 重置当前用户的Yggdrasil密码并返回新密码 -// @Tags user -// @Accept json -// @Produce json -// @Security BearerAuth -// @Success 200 {object} model.Response "重置成功" -// @Failure 401 {object} model.ErrorResponse "未授权" -// @Failure 500 {object} model.ErrorResponse "服务器错误" -// @Router /api/v1/user/yggdrasil-password/reset [post] -func ResetYggdrasilPassword(c *gin.Context) { - loggerInstance := logger.MustGetLogger() +func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) { userID, ok := GetUserIDFromContext(c) if !ok { return } - db := database.MustGetDB() - newPassword, err := service.ResetYggdrasilPassword(db, userID) + newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID) if err != nil { - loggerInstance.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) + h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) RespondServerError(c, "重置Yggdrasil密码失败", nil) return } - loggerInstance.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) + h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) RespondSuccess(c, gin.H{"password": newPassword}) } diff --git a/internal/handler/user_handler_di.go b/internal/handler/user_handler_di.go deleted file mode 100644 index 91e8a5a..0000000 --- a/internal/handler/user_handler_di.go +++ /dev/null @@ -1,233 +0,0 @@ -package handler - -import ( - "carrotskin/internal/container" - "carrotskin/internal/service" - "carrotskin/internal/types" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// UserHandler 用户处理器(依赖注入版本) -type UserHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewUserHandler 创建UserHandler实例 -func NewUserHandler(c *container.Container) *UserHandler { - return &UserHandler{ - container: c, - logger: c.Logger, - } -} - -// GetProfile 获取用户信息 -func (h *UserHandler) GetProfile(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - h.logger.Error("获取用户信息失败", - zap.Int64("user_id", userID), - zap.Error(err), - ) - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(user)) -} - -// UpdateProfile 更新用户信息 -func (h *UserHandler) UpdateProfile(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.UpdateUserRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - RespondNotFound(c, "用户不存在") - return - } - - // 处理密码修改 - if req.NewPassword != "" { - if req.OldPassword == "" { - RespondBadRequest(c, "修改密码需要提供原密码", nil) - return - } - - if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil { - h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID)) - } - - // 更新头像 - if req.Avatar != "" { - if err := service.ValidateAvatarURL(req.Avatar); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - user.Avatar = req.Avatar - if err := service.UpdateUserInfo(user); err != nil { - h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) - RespondServerError(c, "更新失败", err) - return - } - } - - // 重新获取更新后的用户信息 - updatedUser, err := service.GetUserByID(userID) - if err != nil || updatedUser == nil { - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(updatedUser)) -} - -// GenerateAvatarUploadURL 生成头像上传URL -func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.GenerateAvatarUploadURLRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - if h.container.Storage == nil { - RespondServerError(c, "存储服务不可用", nil) - return - } - - result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName) - if err != nil { - h.logger.Error("生成头像上传URL失败", - zap.Int64("user_id", userID), - zap.String("file_name", req.FileName), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - RespondSuccess(c, &types.GenerateAvatarUploadURLResponse{ - PostURL: result.PostURL, - FormData: result.FormData, - AvatarURL: result.FileURL, - ExpiresIn: 900, - }) -} - -// UpdateAvatar 更新头像URL -func (h *UserHandler) UpdateAvatar(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - avatarURL := c.Query("avatar_url") - if avatarURL == "" { - RespondBadRequest(c, "头像URL不能为空", nil) - return - } - - if err := service.ValidateAvatarURL(avatarURL); err != nil { - RespondBadRequest(c, err.Error(), nil) - return - } - - if err := service.UpdateUserAvatar(userID, avatarURL); err != nil { - h.logger.Error("更新头像失败", - zap.Int64("user_id", userID), - zap.String("avatar_url", avatarURL), - zap.Error(err), - ) - RespondServerError(c, "更新头像失败", err) - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(user)) -} - -// ChangeEmail 更换邮箱 -func (h *UserHandler) ChangeEmail(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - var req types.ChangeEmailRequest - if err := c.ShouldBindJSON(&req); err != nil { - RespondBadRequest(c, "请求参数错误", err) - return - } - - if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { - h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) - RespondBadRequest(c, err.Error(), nil) - return - } - - if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil { - h.logger.Error("更换邮箱失败", - zap.Int64("user_id", userID), - zap.String("new_email", req.NewEmail), - zap.Error(err), - ) - RespondBadRequest(c, err.Error(), nil) - return - } - - user, err := service.GetUserByID(userID) - if err != nil || user == nil { - RespondNotFound(c, "用户不存在") - return - } - - RespondSuccess(c, UserToUserInfo(user)) -} - -// ResetYggdrasilPassword 重置Yggdrasil密码 -func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) { - userID, ok := GetUserIDFromContext(c) - if !ok { - return - } - - newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID) - if err != nil { - h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) - RespondServerError(c, "重置Yggdrasil密码失败", nil) - return - } - - h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) - RespondSuccess(c, gin.H{"password": newPassword}) -} diff --git a/internal/handler/yggdrasil_handler.go b/internal/handler/yggdrasil_handler.go index acbf7b2..2ee21dc 100644 --- a/internal/handler/yggdrasil_handler.go +++ b/internal/handler/yggdrasil_handler.go @@ -2,11 +2,9 @@ package handler import ( "bytes" + "carrotskin/internal/container" "carrotskin/internal/model" "carrotskin/internal/service" - "carrotskin/pkg/database" - "carrotskin/pkg/logger" - "carrotskin/pkg/redis" "carrotskin/pkg/utils" "io" "net/http" @@ -111,6 +109,7 @@ type ( Password string `json:"password" binding:"required"` } + // JoinServerRequest 加入服务器请求 JoinServerRequest struct { ServerID string `json:"serverId" binding:"required"` AccessToken string `json:"accessToken" binding:"required"` @@ -138,6 +137,7 @@ type ( } ) +// APIResponse API响应 type APIResponse struct { Status int `json:"status"` Data interface{} `json:"data"` @@ -153,38 +153,47 @@ func standardResponse(c *gin.Context, status int, data interface{}, err interfac }) } -// Authenticate 用户认证 -func Authenticate(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() +// YggdrasilHandler Yggdrasil API处理器 +type YggdrasilHandler struct { + container *container.Container + logger *zap.Logger +} - // 读取并保存原始请求体,以便多次读取 +// NewYggdrasilHandler 创建YggdrasilHandler实例 +func NewYggdrasilHandler(c *container.Container) *YggdrasilHandler { + return &YggdrasilHandler{ + container: c, + logger: c.Logger, + } +} + +// Authenticate 用户认证 +func (h *YggdrasilHandler) Authenticate(c *gin.Context) { rawData, err := io.ReadAll(c.Request.Body) if err != nil { - loggerInstance.Error("[ERROR] 读取请求体失败: ", zap.Error(err)) + h.logger.Error("读取请求体失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) return } c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData)) - // 绑定JSON数据到请求结构体 var request AuthenticateRequest if err = c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析认证请求失败: ", zap.Error(err)) + h.logger.Error("解析认证请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 根据标识符类型(邮箱或用户名)获取用户 var userId int64 var profile *model.Profile var UUID string + if emailRegex.MatchString(request.Identifier) { - userId, err = service.GetUserIDByEmail(db, request.Identifier) + userId, err = service.GetUserIDByEmail(h.container.DB, request.Identifier) } else { - profile, err = service.GetProfileByProfileName(db, request.Identifier) + profile, err = service.GetProfileByProfileName(h.container.DB, request.Identifier) if err != nil { - loggerInstance.Error("[ERROR] 用户名不存在: ", zap.String("标识符", request.Identifier), zap.Error(err)) + h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) return } @@ -193,163 +202,146 @@ func Authenticate(c *gin.Context) { } if err != nil { - loggerInstance.Warn("[WARN] 认证失败: 用户不存在", - zap.String("标识符:", request.Identifier), - zap.Error(err)) - + h.logger.Warn("认证失败: 用户不存在", zap.String("identifier", request.Identifier), zap.Error(err)) + c.JSON(http.StatusForbidden, gin.H{"error": "用户不存在"}) return } - // 验证密码 - err = service.VerifyPassword(db, request.Password, userId) - if err != nil { - loggerInstance.Warn("[WARN] 认证失败:", zap.Error(err)) + if err := service.VerifyPassword(h.container.DB, request.Password, userId); err != nil { + h.logger.Warn("认证失败: 密码错误", zap.Error(err)) c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword}) return } - // 生成新令牌 - selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(db, loggerInstance, userId, UUID, request.ClientToken) + + selectedProfile, availableProfiles, accessToken, clientToken, err := h.container.TokenService.Create(userId, UUID, request.ClientToken) if err != nil { - loggerInstance.Error("[ERROR] 生成令牌失败:", zap.Error(err), zap.Any("用户ID:", userId)) + h.logger.Error("生成令牌失败", zap.Error(err), zap.Int64("userId", userId)) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - user, err := service.GetUserByID(userId) + user, err := h.container.UserService.GetByID(userId) if err != nil { - loggerInstance.Error("[ERROR] id查找错误:", zap.Error(err), zap.Any("ID:", userId)) + h.logger.Error("获取用户信息失败", zap.Error(err), zap.Int64("userId", userId)) } - // 处理可用的配置文件 - redisClient := redis.MustGetClient() + availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles)) - for _, profile := range availableProfiles { - availableProfilesData = append(availableProfilesData, service.SerializeProfile(db, loggerInstance, redisClient, *profile)) + for _, p := range availableProfiles { + availableProfilesData = append(availableProfilesData, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *p)) } + response := AuthenticateResponse{ AccessToken: accessToken, ClientToken: clientToken, AvailableProfiles: availableProfilesData, } + if selectedProfile != nil { - response.SelectedProfile = service.SerializeProfile(db, loggerInstance, redisClient, *selectedProfile) - } - if request.RequestUser { - // 使用 SerializeUser 来正确处理 Properties 字段 - response.User = service.SerializeUser(loggerInstance, user, UUID) + response.SelectedProfile = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *selectedProfile) } - // 返回认证响应 - loggerInstance.Info("[INFO] 用户认证成功", zap.Any("用户ID:", userId)) + if request.RequestUser && user != nil { + response.User = service.SerializeUser(h.logger, user, UUID) + } + + h.logger.Info("用户认证成功", zap.Int64("userId", userId)) c.JSON(http.StatusOK, response) } // ValidToken 验证令牌 -func ValidToken(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) ValidToken(c *gin.Context) { var request ValidTokenRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析验证令牌请求失败: ", zap.Error(err)) + h.logger.Error("解析验证令牌请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 验证令牌 - if service.ValidToken(db, request.AccessToken, request.ClientToken) { - loggerInstance.Info("[INFO] 令牌验证成功", zap.Any("访问令牌:", request.AccessToken)) + + if h.container.TokenService.Validate(request.AccessToken, request.ClientToken) { + h.logger.Info("令牌验证成功", zap.String("accessToken", request.AccessToken)) c.JSON(http.StatusNoContent, gin.H{"valid": true}) } else { - loggerInstance.Warn("[WARN] 令牌验证失败", zap.Any("访问令牌:", request.AccessToken)) + h.logger.Warn("令牌验证失败", zap.String("accessToken", request.AccessToken)) c.JSON(http.StatusForbidden, gin.H{"valid": false}) } } // RefreshToken 刷新令牌 -func RefreshToken(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) RefreshToken(c *gin.Context) { var request RefreshRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析刷新令牌请求失败: ", zap.Error(err)) + h.logger.Error("解析刷新令牌请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 获取用户ID和用户信息 - UUID, err := service.GetUUIDByAccessToken(db, request.AccessToken) + UUID, err := h.container.TokenService.GetUUIDByAccessToken(request.AccessToken) if err != nil { - loggerInstance.Warn("[WARN] 刷新令牌失败: 无效的访问令牌", zap.Any("令牌:", request.AccessToken), zap.Error(err)) + h.logger.Warn("刷新令牌失败: 无效的访问令牌", zap.String("token", request.AccessToken), zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - userID, _ := service.GetUserIDByAccessToken(db, request.AccessToken) - // 格式化UUID 这里是因为HMCL的传入参数是HEX格式,为了兼容HMCL,在此做处理 + + userID, _ := h.container.TokenService.GetUserIDByAccessToken(request.AccessToken) UUID = utils.FormatUUID(UUID) - profile, err := service.GetProfileByUUID(db, UUID) + profile, err := h.container.ProfileService.GetByUUID(UUID) if err != nil { - loggerInstance.Error("[ERROR] 刷新令牌失败: 无法获取用户信息 错误: ", zap.Error(err)) + h.logger.Error("刷新令牌失败: 无法获取用户信息", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 准备响应数据 var profileData map[string]interface{} var userData map[string]interface{} var profileID string - // 处理选定的配置文件 if request.SelectedProfile != nil { - // 验证profileID是否存在 profileIDValue, ok := request.SelectedProfile["id"] if !ok { - loggerInstance.Error("[ERROR] 刷新令牌失败: 缺少配置文件ID", zap.Any("ID:", userID)) + h.logger.Error("刷新令牌失败: 缺少配置文件ID", zap.Int64("userId", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"}) return } - // 类型断言 profileID, ok = profileIDValue.(string) if !ok { - loggerInstance.Error("[ERROR] 刷新令牌失败: 配置文件ID类型错误 ", zap.Any("用户ID:", userID)) + h.logger.Error("刷新令牌失败: 配置文件ID类型错误", zap.Int64("userId", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"}) return } - // 格式化profileID profileID = utils.FormatUUID(profileID) - // 验证配置文件所属用户 if profile.UserID != userID { - loggerInstance.Warn("[WARN] 刷新令牌失败: 用户不匹配 ", zap.Any("用户ID:", userID), zap.Any("配置文件用户ID:", profile.UserID)) + h.logger.Warn("刷新令牌失败: 用户不匹配", + zap.Int64("userId", userID), + zap.Int64("profileUserId", profile.UserID), + ) c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch}) return } - profileData = service.SerializeProfile(db, loggerInstance, redis.MustGetClient(), *profile) - } - user, _ := service.GetUserByID(userID) - // 添加用户信息(如果请求了) - if request.RequestUser { - userData = service.SerializeUser(loggerInstance, user, UUID) + profileData = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile) } - // 刷新令牌 - newAccessToken, newClientToken, err := service.RefreshToken(db, loggerInstance, + user, _ := h.container.UserService.GetByID(userID) + if request.RequestUser && user != nil { + userData = service.SerializeUser(h.logger, user, UUID) + } + + newAccessToken, newClientToken, err := h.container.TokenService.Refresh( request.AccessToken, request.ClientToken, profileID, ) if err != nil { - loggerInstance := logger.MustGetLogger() - loggerInstance.Error("[ERROR] 刷新令牌失败: ", zap.Error(err), zap.Any("用户ID: ", userID)) + h.logger.Error("刷新令牌失败", zap.Error(err), zap.Int64("userId", userID)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 返回响应 - loggerInstance.Info("[INFO] 刷新令牌成功", zap.Any("用户ID:", userID)) + h.logger.Info("刷新令牌成功", zap.Int64("userId", userID)) c.JSON(http.StatusOK, RefreshResponse{ AccessToken: newAccessToken, ClientToken: newClientToken, @@ -359,231 +351,177 @@ func RefreshToken(c *gin.Context) { } // InvalidToken 使令牌失效 -func InvalidToken(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) InvalidToken(c *gin.Context) { var request ValidTokenRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析使令牌失效请求失败: ", zap.Error(err)) + h.logger.Error("解析使令牌失效请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 使令牌失效 - service.InvalidToken(db, loggerInstance, request.AccessToken) - loggerInstance.Info("[INFO] 令牌已使失效", zap.Any("访问令牌:", request.AccessToken)) + + h.container.TokenService.Invalidate(request.AccessToken) + h.logger.Info("令牌已失效", zap.String("token", request.AccessToken)) c.JSON(http.StatusNoContent, gin.H{}) } // SignOut 用户登出 -func SignOut(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +func (h *YggdrasilHandler) SignOut(c *gin.Context) { var request SignOutRequest if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error("[ERROR] 解析登出请求失败: %v", zap.Error(err)) + h.logger.Error("解析登出请求失败", zap.Error(err)) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - // 验证邮箱格式 if !emailRegex.MatchString(request.Email) { - loggerInstance.Warn("[WARN] 登出失败: 邮箱格式不正确 ", zap.Any(" ", request.Email)) + h.logger.Warn("登出失败: 邮箱格式不正确", zap.String("email", request.Email)) c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat}) return } - // 通过邮箱获取用户 - user, err := service.GetUserByEmail(request.Email) - if err != nil { - loggerInstance.Warn( - "登出失败: 用户不存在", - zap.String("邮箱", request.Email), - zap.Error(err), - ) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + user, err := h.container.UserService.GetByEmail(request.Email) + if err != nil || user == nil { + h.logger.Warn("登出失败: 用户不存在", zap.String("email", request.Email), zap.Error(err)) + c.JSON(http.StatusBadRequest, gin.H{"error": "用户不存在"}) return } - // 验证密码 - if err := service.VerifyPassword(db, request.Password, user.ID); err != nil { - loggerInstance.Warn("[WARN] 登出失败: 密码错误", zap.Any("用户ID:", user.ID)) + + if err := service.VerifyPassword(h.container.DB, request.Password, user.ID); err != nil { + h.logger.Warn("登出失败: 密码错误", zap.Int64("userId", user.ID)) c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword}) return } - // 使该用户的所有令牌失效 - service.InvalidUserTokens(db, loggerInstance, user.ID) - loggerInstance.Info("[INFO] 用户登出成功", zap.Any("用户ID:", user.ID)) + h.container.TokenService.InvalidateUserTokens(user.ID) + h.logger.Info("用户登出成功", zap.Int64("userId", user.ID)) c.JSON(http.StatusNoContent, gin.H{"valid": true}) } -func GetProfileByUUID(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - - // 获取并格式化UUID +// GetProfileByUUID 根据UUID获取档案 +func (h *YggdrasilHandler) GetProfileByUUID(c *gin.Context) { uuid := utils.FormatUUID(c.Param("uuid")) - loggerInstance.Info("[INFO] 接收到获取配置文件请求", zap.Any("UUID:", uuid)) + h.logger.Info("获取配置文件请求", zap.String("uuid", uuid)) - // 获取配置文件 - profile, err := service.GetProfileByUUID(db, uuid) + profile, err := h.container.ProfileService.GetByUUID(uuid) if err != nil { - loggerInstance.Error("[ERROR] 获取配置文件失败:", zap.Error(err), zap.String("UUID:", uuid)) + h.logger.Error("获取配置文件失败", zap.Error(err), zap.String("uuid", uuid)) standardResponse(c, http.StatusInternalServerError, nil, err.Error()) return } - // 返回配置文件信息 - loggerInstance.Info("[INFO] 成功获取配置文件", zap.String("UUID:", uuid), zap.String("名称:", profile.Name)) - c.JSON(http.StatusOK, service.SerializeProfile(db, loggerInstance, redisClient, *profile)) + h.logger.Info("成功获取配置文件", zap.String("uuid", uuid), zap.String("name", profile.Name)) + c.JSON(http.StatusOK, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) } -func JoinServer(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - +// JoinServer 加入服务器 +func (h *YggdrasilHandler) JoinServer(c *gin.Context) { var request JoinServerRequest clientIP := c.ClientIP() - // 解析请求参数 if err := c.ShouldBindJSON(&request); err != nil { - loggerInstance.Error( - "解析加入服务器请求失败", - zap.Error(err), - zap.String("IP", clientIP), - ) + h.logger.Error("解析加入服务器请求失败", zap.Error(err), zap.String("ip", clientIP)) standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest) return } - loggerInstance.Info( - "收到加入服务器请求", - zap.String("服务器ID", request.ServerID), - zap.String("用户UUID", request.SelectedProfile), - zap.String("IP", clientIP), + h.logger.Info("收到加入服务器请求", + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), ) - // 处理加入服务器请求 - if err := service.JoinServer(db, loggerInstance, redisClient, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { - loggerInstance.Error( - "加入服务器失败", + if err := service.JoinServer(h.container.DB, h.logger, h.container.Redis, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { + h.logger.Error("加入服务器失败", zap.Error(err), - zap.String("服务器ID", request.ServerID), - zap.String("用户UUID", request.SelectedProfile), - zap.String("IP", clientIP), + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), ) standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed) return } - // 加入成功,返回204状态码 - loggerInstance.Info( - "加入服务器成功", - zap.String("服务器ID", request.ServerID), - zap.String("用户UUID", request.SelectedProfile), - zap.String("IP", clientIP), + h.logger.Info("加入服务器成功", + zap.String("serverId", request.ServerID), + zap.String("userUUID", request.SelectedProfile), + zap.String("ip", clientIP), ) c.Status(http.StatusNoContent) } -func HasJoinedServer(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - +// HasJoinedServer 验证玩家是否已加入服务器 +func (h *YggdrasilHandler) HasJoinedServer(c *gin.Context) { clientIP, _ := c.GetQuery("ip") - // 获取并验证服务器ID参数 serverID, exists := c.GetQuery("serverId") if !exists || serverID == "" { - loggerInstance.Warn("[WARN] 缺少服务器ID参数", zap.Any("IP:", clientIP)) + h.logger.Warn("缺少服务器ID参数", zap.String("ip", clientIP)) standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired) return } - // 获取并验证用户名参数 username, exists := c.GetQuery("username") if !exists || username == "" { - loggerInstance.Warn("[WARN] 缺少用户名参数", zap.Any("服务器ID:", serverID), zap.Any("IP:", clientIP)) + h.logger.Warn("缺少用户名参数", zap.String("serverId", serverID), zap.String("ip", clientIP)) standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired) return } - loggerInstance.Info("[INFO] 收到会话验证请求", zap.Any("服务器ID:", serverID), zap.Any("用户名: ", username), zap.Any("IP: ", clientIP)) + h.logger.Info("收到会话验证请求", + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("ip", clientIP), + ) - // 验证玩家是否已加入服务器 - if err := service.HasJoinedServer(loggerInstance, redisClient, serverID, username, clientIP); err != nil { - loggerInstance.Warn("[WARN] 会话验证失败", + if err := service.HasJoinedServer(h.logger, h.container.Redis, serverID, username, clientIP); err != nil { + h.logger.Warn("会话验证失败", zap.Error(err), - zap.String("serverID", serverID), + zap.String("serverId", serverID), zap.String("username", username), - zap.String("clientIP", clientIP), + zap.String("ip", clientIP), ) standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed) return } - profile, err := service.GetProfileByUUID(db, username) + profile, err := h.container.ProfileService.GetByUUID(username) if err != nil { - loggerInstance.Error("[ERROR] 获取用户配置文件失败: %v - 用户名: %s", - zap.Error(err), // 错误详情(zap 原生支持,保留错误链) - zap.String("username", username), // 结构化存储用户名(便于检索) - ) + h.logger.Error("获取用户配置文件失败", zap.Error(err), zap.String("username", username)) standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound) return } - // 返回玩家配置文件 - loggerInstance.Info("[INFO] 会话验证成功 - 服务器ID: %s, 用户名: %s, UUID: %s", - zap.String("serverID", serverID), // 结构化存储服务器ID - zap.String("username", username), // 结构化存储用户名 - zap.String("UUID", profile.UUID), // 结构化存储UUID + h.logger.Info("会话验证成功", + zap.String("serverId", serverID), + zap.String("username", username), + zap.String("uuid", profile.UUID), ) - c.JSON(200, service.SerializeProfile(db, loggerInstance, redisClient, *profile)) + c.JSON(200, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) } -func GetProfilesByName(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - +// GetProfilesByName 批量获取配置文件 +func (h *YggdrasilHandler) GetProfilesByName(c *gin.Context) { var names []string - // 解析请求参数 if err := c.ShouldBindJSON(&names); err != nil { - loggerInstance.Error("[ERROR] 解析名称数组请求失败: ", - zap.Error(err), - ) + h.logger.Error("解析名称数组请求失败", zap.Error(err)) standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams) return } - loggerInstance.Info("[INFO] 接收到批量获取配置文件请求", - zap.Int("名称数量:", len(names)), // 结构化存储名称数量 - ) - // 批量获取配置文件 - profiles, err := service.GetProfilesDataByNames(db, names) + h.logger.Info("接收到批量获取配置文件请求", zap.Int("count", len(names))) + + profiles, err := h.container.ProfileService.GetByNames(names) if err != nil { - loggerInstance.Error("[ERROR] 获取配置文件失败: ", - zap.Error(err), - ) + h.logger.Error("获取配置文件失败", zap.Error(err)) } - // 改造:zap 兼容原有 INFO 日志格式 - loggerInstance.Info("[INFO] 成功获取配置文件", - zap.Int("请求名称数:", len(names)), - zap.Int("返回结果数: ", len(profiles)), - ) - + h.logger.Info("成功获取配置文件", zap.Int("requested", len(names)), zap.Int("returned", len(profiles))) c.JSON(http.StatusOK, profiles) } -func GetMetaData(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - redisClient := redis.MustGetClient() - +// GetMetaData 获取Yggdrasil元数据 +func (h *YggdrasilHandler) GetMetaData(c *gin.Context) { meta := gin.H{ "implementationName": "CellAuth", "implementationVersion": "0.0.1", @@ -595,26 +533,25 @@ func GetMetaData(c *gin.Context) { "feature.non_email_login": true, "feature.enable_profile_key": true, } + skinDomains := []string{".hitwh.games", ".littlelan.cn"} - signature, err := service.GetPublicKeyFromRedisFunc(loggerInstance, redisClient) + signature, err := service.GetPublicKeyFromRedisFunc(h.logger, h.container.Redis) if err != nil { - loggerInstance.Error("[ERROR] 获取公钥失败: ", zap.Error(err)) + h.logger.Error("获取公钥失败", zap.Error(err)) standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) return } - loggerInstance.Info("[INFO] 提供元数据") - c.JSON(http.StatusOK, gin.H{"meta": meta, + h.logger.Info("提供元数据") + c.JSON(http.StatusOK, gin.H{ + "meta": meta, "skinDomains": skinDomains, - "signaturePublickey": signature}) + "signaturePublickey": signature, + }) } -func GetPlayerCertificates(c *gin.Context) { - loggerInstance := logger.MustGetLogger() - db := database.MustGetDB() - redisClient := redis.MustGetClient() - - var uuid string +// GetPlayerCertificates 获取玩家证书 +func (h *YggdrasilHandler) GetPlayerCertificates(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"}) @@ -622,39 +559,36 @@ func GetPlayerCertificates(c *gin.Context) { return } - // 检查是否以 Bearer 开头并提取 sessionID bearerPrefix := "Bearer " if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) c.Abort() return } + tokenID := authHeader[len(bearerPrefix):] if tokenID == "" { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) c.Abort() return } - var err error - uuid, err = service.GetUUIDByAccessToken(db, tokenID) + uuid, err := h.container.TokenService.GetUUIDByAccessToken(tokenID) if uuid == "" { - loggerInstance.Error("[ERROR] 获取玩家UUID失败: ", zap.Error(err)) + h.logger.Error("获取玩家UUID失败", zap.Error(err)) standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) return } - // 格式化UUID uuid = utils.FormatUUID(uuid) - // 生成玩家证书 - certificate, err := service.GeneratePlayerCertificate(db, loggerInstance, redisClient, uuid) + certificate, err := service.GeneratePlayerCertificate(h.container.DB, h.logger, h.container.Redis, uuid) if err != nil { - loggerInstance.Error("[ERROR] 生成玩家证书失败: ", zap.Error(err)) + h.logger.Error("生成玩家证书失败", zap.Error(err)) standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) return } - loggerInstance.Info("[INFO] 成功生成玩家证书") + h.logger.Info("成功生成玩家证书") c.JSON(http.StatusOK, certificate) } diff --git a/internal/handler/yggdrasil_handler_di.go b/internal/handler/yggdrasil_handler_di.go deleted file mode 100644 index c4fb8f3..0000000 --- a/internal/handler/yggdrasil_handler_di.go +++ /dev/null @@ -1,454 +0,0 @@ -package handler - -import ( - "bytes" - "carrotskin/internal/container" - "carrotskin/internal/model" - "carrotskin/internal/service" - "carrotskin/pkg/utils" - "io" - "net/http" - - "github.com/gin-gonic/gin" - "go.uber.org/zap" -) - -// YggdrasilHandler Yggdrasil API处理器 -type YggdrasilHandler struct { - container *container.Container - logger *zap.Logger -} - -// NewYggdrasilHandler 创建YggdrasilHandler实例 -func NewYggdrasilHandler(c *container.Container) *YggdrasilHandler { - return &YggdrasilHandler{ - container: c, - logger: c.Logger, - } -} - -// Authenticate 用户认证 -func (h *YggdrasilHandler) Authenticate(c *gin.Context) { - rawData, err := io.ReadAll(c.Request.Body) - if err != nil { - h.logger.Error("读取请求体失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": "读取请求体失败"}) - return - } - c.Request.Body = io.NopCloser(bytes.NewBuffer(rawData)) - - var request AuthenticateRequest - if err = c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析认证请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - var userId int64 - var profile *model.Profile - var UUID string - - if emailRegex.MatchString(request.Identifier) { - userId, err = service.GetUserIDByEmail(h.container.DB, request.Identifier) - } else { - profile, err = service.GetProfileByProfileName(h.container.DB, request.Identifier) - if err != nil { - h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) - c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) - return - } - userId = profile.UserID - UUID = profile.UUID - } - - if err != nil { - h.logger.Warn("认证失败: 用户不存在", zap.String("identifier", request.Identifier), zap.Error(err)) - c.JSON(http.StatusForbidden, gin.H{"error": "用户不存在"}) - return - } - - if err := service.VerifyPassword(h.container.DB, request.Password, userId); err != nil { - h.logger.Warn("认证失败: 密码错误", zap.Error(err)) - c.JSON(http.StatusForbidden, gin.H{"error": ErrWrongPassword}) - return - } - - selectedProfile, availableProfiles, accessToken, clientToken, err := service.NewToken(h.container.DB, h.logger, userId, UUID, request.ClientToken) - if err != nil { - h.logger.Error("生成令牌失败", zap.Error(err), zap.Int64("userId", userId)) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } - - user, err := service.GetUserByID(userId) - if err != nil { - h.logger.Error("获取用户信息失败", zap.Error(err), zap.Int64("userId", userId)) - } - - availableProfilesData := make([]map[string]interface{}, 0, len(availableProfiles)) - for _, p := range availableProfiles { - availableProfilesData = append(availableProfilesData, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *p)) - } - - response := AuthenticateResponse{ - AccessToken: accessToken, - ClientToken: clientToken, - AvailableProfiles: availableProfilesData, - } - - if selectedProfile != nil { - response.SelectedProfile = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *selectedProfile) - } - - if request.RequestUser && user != nil { - response.User = service.SerializeUser(h.logger, user, UUID) - } - - h.logger.Info("用户认证成功", zap.Int64("userId", userId)) - c.JSON(http.StatusOK, response) -} - -// ValidToken 验证令牌 -func (h *YggdrasilHandler) ValidToken(c *gin.Context) { - var request ValidTokenRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析验证令牌请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if service.ValidToken(h.container.DB, request.AccessToken, request.ClientToken) { - h.logger.Info("令牌验证成功", zap.String("accessToken", request.AccessToken)) - c.JSON(http.StatusNoContent, gin.H{"valid": true}) - } else { - h.logger.Warn("令牌验证失败", zap.String("accessToken", request.AccessToken)) - c.JSON(http.StatusForbidden, gin.H{"valid": false}) - } -} - -// RefreshToken 刷新令牌 -func (h *YggdrasilHandler) RefreshToken(c *gin.Context) { - var request RefreshRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析刷新令牌请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - UUID, err := service.GetUUIDByAccessToken(h.container.DB, request.AccessToken) - if err != nil { - h.logger.Warn("刷新令牌失败: 无效的访问令牌", zap.String("token", request.AccessToken), zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - userID, _ := service.GetUserIDByAccessToken(h.container.DB, request.AccessToken) - UUID = utils.FormatUUID(UUID) - - profile, err := service.GetProfileByUUID(h.container.DB, UUID) - if err != nil { - h.logger.Error("刷新令牌失败: 无法获取用户信息", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - var profileData map[string]interface{} - var userData map[string]interface{} - var profileID string - - if request.SelectedProfile != nil { - profileIDValue, ok := request.SelectedProfile["id"] - if !ok { - h.logger.Error("刷新令牌失败: 缺少配置文件ID", zap.Int64("userId", userID)) - c.JSON(http.StatusBadRequest, gin.H{"error": "缺少配置文件ID"}) - return - } - - profileID, ok = profileIDValue.(string) - if !ok { - h.logger.Error("刷新令牌失败: 配置文件ID类型错误", zap.Int64("userId", userID)) - c.JSON(http.StatusBadRequest, gin.H{"error": "配置文件ID必须是字符串"}) - return - } - - profileID = utils.FormatUUID(profileID) - - if profile.UserID != userID { - h.logger.Warn("刷新令牌失败: 用户不匹配", - zap.Int64("userId", userID), - zap.Int64("profileUserId", profile.UserID), - ) - c.JSON(http.StatusBadRequest, gin.H{"error": ErrUserNotMatch}) - return - } - - profileData = service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile) - } - - user, _ := service.GetUserByID(userID) - if request.RequestUser && user != nil { - userData = service.SerializeUser(h.logger, user, UUID) - } - - newAccessToken, newClientToken, err := service.RefreshToken(h.container.DB, h.logger, - request.AccessToken, - request.ClientToken, - profileID, - ) - if err != nil { - h.logger.Error("刷新令牌失败", zap.Error(err), zap.Int64("userId", userID)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - h.logger.Info("刷新令牌成功", zap.Int64("userId", userID)) - c.JSON(http.StatusOK, RefreshResponse{ - AccessToken: newAccessToken, - ClientToken: newClientToken, - SelectedProfile: profileData, - User: userData, - }) -} - -// InvalidToken 使令牌失效 -func (h *YggdrasilHandler) InvalidToken(c *gin.Context) { - var request ValidTokenRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析使令牌失效请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - service.InvalidToken(h.container.DB, h.logger, request.AccessToken) - h.logger.Info("令牌已失效", zap.String("token", request.AccessToken)) - c.JSON(http.StatusNoContent, gin.H{}) -} - -// SignOut 用户登出 -func (h *YggdrasilHandler) SignOut(c *gin.Context) { - var request SignOutRequest - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析登出请求失败", zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - if !emailRegex.MatchString(request.Email) { - h.logger.Warn("登出失败: 邮箱格式不正确", zap.String("email", request.Email)) - c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidEmailFormat}) - return - } - - user, err := service.GetUserByEmail(request.Email) - if err != nil || user == nil { - h.logger.Warn("登出失败: 用户不存在", zap.String("email", request.Email), zap.Error(err)) - c.JSON(http.StatusBadRequest, gin.H{"error": "用户不存在"}) - return - } - - if err := service.VerifyPassword(h.container.DB, request.Password, user.ID); err != nil { - h.logger.Warn("登出失败: 密码错误", zap.Int64("userId", user.ID)) - c.JSON(http.StatusBadRequest, gin.H{"error": ErrWrongPassword}) - return - } - - service.InvalidUserTokens(h.container.DB, h.logger, user.ID) - h.logger.Info("用户登出成功", zap.Int64("userId", user.ID)) - c.JSON(http.StatusNoContent, gin.H{"valid": true}) -} - -// GetProfileByUUID 根据UUID获取档案 -func (h *YggdrasilHandler) GetProfileByUUID(c *gin.Context) { - uuid := utils.FormatUUID(c.Param("uuid")) - h.logger.Info("获取配置文件请求", zap.String("uuid", uuid)) - - profile, err := service.GetProfileByUUID(h.container.DB, uuid) - if err != nil { - h.logger.Error("获取配置文件失败", zap.Error(err), zap.String("uuid", uuid)) - standardResponse(c, http.StatusInternalServerError, nil, err.Error()) - return - } - - h.logger.Info("成功获取配置文件", zap.String("uuid", uuid), zap.String("name", profile.Name)) - c.JSON(http.StatusOK, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) -} - -// JoinServer 加入服务器 -func (h *YggdrasilHandler) JoinServer(c *gin.Context) { - var request JoinServerRequest - clientIP := c.ClientIP() - - if err := c.ShouldBindJSON(&request); err != nil { - h.logger.Error("解析加入服务器请求失败", zap.Error(err), zap.String("ip", clientIP)) - standardResponse(c, http.StatusBadRequest, nil, ErrInvalidRequest) - return - } - - h.logger.Info("收到加入服务器请求", - zap.String("serverId", request.ServerID), - zap.String("userUUID", request.SelectedProfile), - zap.String("ip", clientIP), - ) - - if err := service.JoinServer(h.container.DB, h.logger, h.container.Redis, request.ServerID, request.AccessToken, request.SelectedProfile, clientIP); err != nil { - h.logger.Error("加入服务器失败", - zap.Error(err), - zap.String("serverId", request.ServerID), - zap.String("userUUID", request.SelectedProfile), - zap.String("ip", clientIP), - ) - standardResponse(c, http.StatusInternalServerError, nil, ErrJoinServerFailed) - return - } - - h.logger.Info("加入服务器成功", - zap.String("serverId", request.ServerID), - zap.String("userUUID", request.SelectedProfile), - zap.String("ip", clientIP), - ) - c.Status(http.StatusNoContent) -} - -// HasJoinedServer 验证玩家是否已加入服务器 -func (h *YggdrasilHandler) HasJoinedServer(c *gin.Context) { - clientIP, _ := c.GetQuery("ip") - - serverID, exists := c.GetQuery("serverId") - if !exists || serverID == "" { - h.logger.Warn("缺少服务器ID参数", zap.String("ip", clientIP)) - standardResponse(c, http.StatusNoContent, nil, ErrServerIDRequired) - return - } - - username, exists := c.GetQuery("username") - if !exists || username == "" { - h.logger.Warn("缺少用户名参数", zap.String("serverId", serverID), zap.String("ip", clientIP)) - standardResponse(c, http.StatusNoContent, nil, ErrUsernameRequired) - return - } - - h.logger.Info("收到会话验证请求", - zap.String("serverId", serverID), - zap.String("username", username), - zap.String("ip", clientIP), - ) - - if err := service.HasJoinedServer(h.logger, h.container.Redis, serverID, username, clientIP); err != nil { - h.logger.Warn("会话验证失败", - zap.Error(err), - zap.String("serverId", serverID), - zap.String("username", username), - zap.String("ip", clientIP), - ) - standardResponse(c, http.StatusNoContent, nil, ErrSessionVerifyFailed) - return - } - - profile, err := service.GetProfileByUUID(h.container.DB, username) - if err != nil { - h.logger.Error("获取用户配置文件失败", zap.Error(err), zap.String("username", username)) - standardResponse(c, http.StatusNoContent, nil, ErrProfileNotFound) - return - } - - h.logger.Info("会话验证成功", - zap.String("serverId", serverID), - zap.String("username", username), - zap.String("uuid", profile.UUID), - ) - c.JSON(200, service.SerializeProfile(h.container.DB, h.logger, h.container.Redis, *profile)) -} - -// GetProfilesByName 批量获取配置文件 -func (h *YggdrasilHandler) GetProfilesByName(c *gin.Context) { - var names []string - - if err := c.ShouldBindJSON(&names); err != nil { - h.logger.Error("解析名称数组请求失败", zap.Error(err)) - standardResponse(c, http.StatusBadRequest, nil, ErrInvalidParams) - return - } - - h.logger.Info("接收到批量获取配置文件请求", zap.Int("count", len(names))) - - profiles, err := service.GetProfilesDataByNames(h.container.DB, names) - if err != nil { - h.logger.Error("获取配置文件失败", zap.Error(err)) - } - - h.logger.Info("成功获取配置文件", zap.Int("requested", len(names)), zap.Int("returned", len(profiles))) - c.JSON(http.StatusOK, profiles) -} - -// GetMetaData 获取Yggdrasil元数据 -func (h *YggdrasilHandler) GetMetaData(c *gin.Context) { - meta := gin.H{ - "implementationName": "CellAuth", - "implementationVersion": "0.0.1", - "serverName": "LittleLan's Yggdrasil Server Implementation.", - "links": gin.H{ - "homepage": "https://skin.littlelan.cn", - "register": "https://skin.littlelan.cn/auth", - }, - "feature.non_email_login": true, - "feature.enable_profile_key": true, - } - - skinDomains := []string{".hitwh.games", ".littlelan.cn"} - signature, err := service.GetPublicKeyFromRedisFunc(h.logger, h.container.Redis) - if err != nil { - h.logger.Error("获取公钥失败", zap.Error(err)) - standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) - return - } - - h.logger.Info("提供元数据") - c.JSON(http.StatusOK, gin.H{ - "meta": meta, - "skinDomains": skinDomains, - "signaturePublickey": signature, - }) -} - -// GetPlayerCertificates 获取玩家证书 -func (h *YggdrasilHandler) GetPlayerCertificates(c *gin.Context) { - authHeader := c.GetHeader("Authorization") - if authHeader == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization header not provided"}) - c.Abort() - return - } - - bearerPrefix := "Bearer " - if len(authHeader) < len(bearerPrefix) || authHeader[:len(bearerPrefix)] != bearerPrefix { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) - c.Abort() - return - } - - tokenID := authHeader[len(bearerPrefix):] - if tokenID == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid Authorization format"}) - c.Abort() - return - } - - uuid, err := service.GetUUIDByAccessToken(h.container.DB, tokenID) - if uuid == "" { - h.logger.Error("获取玩家UUID失败", zap.Error(err)) - standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) - return - } - - uuid = utils.FormatUUID(uuid) - - certificate, err := service.GeneratePlayerCertificate(h.container.DB, h.logger, h.container.Redis, uuid) - if err != nil { - h.logger.Error("生成玩家证书失败", zap.Error(err)) - standardResponse(c, http.StatusInternalServerError, nil, ErrInternalServer) - return - } - - h.logger.Info("成功生成玩家证书") - c.JSON(http.StatusOK, certificate) -} diff --git a/internal/service/helpers_test.go b/internal/service/helpers_test.go new file mode 100644 index 0000000..043aba4 --- /dev/null +++ b/internal/service/helpers_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "errors" + "testing" +) + +// TestNormalizePagination_Basic 覆盖 NormalizePagination 的边界分支 +func TestNormalizePagination_Basic(t *testing.T) { + tests := []struct { + name string + page int + size int + wantPage int + wantPageSize int + }{ + {"page 小于 1", 0, 10, 1, 10}, + {"pageSize 小于 1", 1, 0, 1, 20}, + {"pageSize 大于 100", 2, 200, 2, 100}, + {"正常范围", 3, 30, 3, 30}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPage, gotSize := NormalizePagination(tt.page, tt.size) + if gotPage != tt.wantPage || gotSize != tt.wantPageSize { + t.Fatalf("NormalizePagination(%d,%d) = (%d,%d), want (%d,%d)", + tt.page, tt.size, gotPage, gotSize, tt.wantPage, tt.wantPageSize) + } + }) + } +} + +// TestWrapError 覆盖 WrapError 的 nil 与非 nil 分支 +func TestWrapError(t *testing.T) { + if err := WrapError(nil, "msg"); err != nil { + t.Fatalf("WrapError(nil, ...) 应返回 nil, got=%v", err) + } + + orig := errors.New("orig") + wrapped := WrapError(orig, "context") + if wrapped == nil { + t.Fatalf("WrapError 应返回非 nil 错误") + } + if wrapped.Error() == orig.Error() { + t.Fatalf("WrapError 应添加上下文信息, got=%v", wrapped) + } +} + + diff --git a/internal/service/mocks_test.go b/internal/service/mocks_test.go new file mode 100644 index 0000000..0c3572e --- /dev/null +++ b/internal/service/mocks_test.go @@ -0,0 +1,964 @@ +package service + +import ( + "carrotskin/internal/model" + "errors" +) + +// ============================================================================ +// Repository Mocks +// ============================================================================ + +// MockUserRepository 模拟UserRepository +type MockUserRepository struct { + users map[int64]*model.User + // 用于模拟错误的标志 + FailCreate bool + FailFindByID bool + FailFindByUsername bool + FailFindByEmail bool + FailUpdate bool +} + +func NewMockUserRepository() *MockUserRepository { + return &MockUserRepository{ + users: make(map[int64]*model.User), + } +} + +func (m *MockUserRepository) Create(user *model.User) error { + if m.FailCreate { + return errors.New("mock create error") + } + if user.ID == 0 { + user.ID = int64(len(m.users) + 1) + } + m.users[user.ID] = user + return nil +} + +func (m *MockUserRepository) FindByID(id int64) (*model.User, error) { + if m.FailFindByID { + return nil, errors.New("mock find error") + } + if user, ok := m.users[id]; ok { + return user, nil + } + return nil, nil +} + +func (m *MockUserRepository) FindByUsername(username string) (*model.User, error) { + if m.FailFindByUsername { + return nil, errors.New("mock find by username error") + } + for _, user := range m.users { + if user.Username == username { + return user, nil + } + } + return nil, nil +} + +func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) { + if m.FailFindByEmail { + return nil, errors.New("mock find by email error") + } + for _, user := range m.users { + if user.Email == email { + return user, nil + } + } + return nil, nil +} + +func (m *MockUserRepository) Update(user *model.User) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.users[user.ID] = user + return nil +} + +func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{}) error { + if m.FailUpdate { + return errors.New("mock update fields error") + } + _, ok := m.users[id] + if !ok { + return errors.New("user not found") + } + return nil +} + +func (m *MockUserRepository) Delete(id int64) error { + delete(m.users, id) + return nil +} + +func (m *MockUserRepository) CreateLoginLog(log *model.UserLoginLog) error { + return nil +} + +func (m *MockUserRepository) CreatePointLog(log *model.UserPointLog) error { + return nil +} + +func (m *MockUserRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error { + return nil +} + +// MockProfileRepository 模拟ProfileRepository +type MockProfileRepository struct { + profiles map[string]*model.Profile + userProfiles map[int64][]*model.Profile + nextID int64 + FailCreate bool + FailFind bool + FailUpdate bool + FailDelete bool +} + +func NewMockProfileRepository() *MockProfileRepository { + return &MockProfileRepository{ + profiles: make(map[string]*model.Profile), + userProfiles: make(map[int64][]*model.Profile), + nextID: 1, + } +} + +func (m *MockProfileRepository) Create(profile *model.Profile) error { + if m.FailCreate { + return errors.New("mock create error") + } + m.profiles[profile.UUID] = profile + m.userProfiles[profile.UserID] = append(m.userProfiles[profile.UserID], profile) + return nil +} + +func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + if profile, ok := m.profiles[uuid]; ok { + return profile, nil + } + return nil, errors.New("profile not found") +} + +func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + for _, profile := range m.profiles { + if profile.Name == name { + return profile, nil + } + } + return nil, nil +} + +func (m *MockProfileRepository) FindByUserID(userID int64) ([]*model.Profile, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + return m.userProfiles[userID], nil +} + +func (m *MockProfileRepository) Update(profile *model.Profile) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.profiles[profile.UUID] = profile + return nil +} + +func (m *MockProfileRepository) UpdateFields(uuid string, updates map[string]interface{}) error { + if m.FailUpdate { + return errors.New("mock update error") + } + return nil +} + +func (m *MockProfileRepository) Delete(uuid string) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.profiles, uuid) + return nil +} + +func (m *MockProfileRepository) CountByUserID(userID int64) (int64, error) { + return int64(len(m.userProfiles[userID])), nil +} + +func (m *MockProfileRepository) SetActive(uuid string, userID int64) error { + return nil +} + +func (m *MockProfileRepository) UpdateLastUsedAt(uuid string) error { + return nil +} + +func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, error) { + var result []*model.Profile + for _, name := range names { + for _, profile := range m.profiles { + if profile.Name == name { + result = append(result, profile) + } + } + } + return result, nil +} + +func (m *MockProfileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) { + return nil, nil +} + +func (m *MockProfileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { + return nil +} + +// MockTextureRepository 模拟TextureRepository +type MockTextureRepository struct { + textures map[int64]*model.Texture + favorites map[int64]map[int64]bool // userID -> textureID -> favorited + nextID int64 + FailCreate bool + FailFind bool + FailUpdate bool + FailDelete bool +} + +func NewMockTextureRepository() *MockTextureRepository { + return &MockTextureRepository{ + textures: make(map[int64]*model.Texture), + favorites: make(map[int64]map[int64]bool), + nextID: 1, + } +} + +func (m *MockTextureRepository) Create(texture *model.Texture) error { + if m.FailCreate { + return errors.New("mock create error") + } + if texture.ID == 0 { + texture.ID = m.nextID + m.nextID++ + } + m.textures[texture.ID] = texture + return nil +} + +func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + if texture, ok := m.textures[id]; ok { + return texture, nil + } + return nil, errors.New("texture not found") +} + +func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + for _, texture := range m.textures { + if texture.Hash == hash { + return texture, nil + } + } + return nil, nil +} + +func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailFind { + return nil, 0, errors.New("mock find error") + } + var result []*model.Texture + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + result = append(result, texture) + } + } + return result, int64(len(result)), nil +} + +func (m *MockTextureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailFind { + return nil, 0, errors.New("mock find error") + } + var result []*model.Texture + for _, texture := range m.textures { + if publicOnly && !texture.IsPublic { + continue + } + result = append(result, texture) + } + return result, int64(len(result)), nil +} + +func (m *MockTextureRepository) Update(texture *model.Texture) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.textures[texture.ID] = texture + return nil +} + +func (m *MockTextureRepository) UpdateFields(id int64, fields map[string]interface{}) error { + if m.FailUpdate { + return errors.New("mock update error") + } + return nil +} + +func (m *MockTextureRepository) Delete(id int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.textures, id) + return nil +} + +func (m *MockTextureRepository) IncrementDownloadCount(id int64) error { + if texture, ok := m.textures[id]; ok { + texture.DownloadCount++ + } + return nil +} + +func (m *MockTextureRepository) IncrementFavoriteCount(id int64) error { + if texture, ok := m.textures[id]; ok { + texture.FavoriteCount++ + } + return nil +} + +func (m *MockTextureRepository) DecrementFavoriteCount(id int64) error { + if texture, ok := m.textures[id]; ok && texture.FavoriteCount > 0 { + texture.FavoriteCount-- + } + return nil +} + +func (m *MockTextureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error { + return nil +} + +func (m *MockTextureRepository) IsFavorited(userID, textureID int64) (bool, error) { + if userFavs, ok := m.favorites[userID]; ok { + return userFavs[textureID], nil + } + return false, nil +} + +func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error { + if m.favorites[userID] == nil { + m.favorites[userID] = make(map[int64]bool) + } + m.favorites[userID][textureID] = true + return nil +} + +func (m *MockTextureRepository) RemoveFavorite(userID, textureID int64) error { + if userFavs, ok := m.favorites[userID]; ok { + delete(userFavs, textureID) + } + return nil +} + +func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + var result []*model.Texture + if userFavs, ok := m.favorites[userID]; ok { + for textureID := range userFavs { + if texture, exists := m.textures[textureID]; exists { + result = append(result, texture) + } + } + } + return result, int64(len(result)), nil +} + +func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, error) { + var count int64 + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + count++ + } + } + return count, nil +} + +// MockTokenRepository 模拟TokenRepository +type MockTokenRepository struct { + tokens map[string]*model.Token + userTokens map[int64][]*model.Token + FailCreate bool + FailFind bool + FailDelete bool +} + +func NewMockTokenRepository() *MockTokenRepository { + return &MockTokenRepository{ + tokens: make(map[string]*model.Token), + userTokens: make(map[int64][]*model.Token), + } +} + +func (m *MockTokenRepository) Create(token *model.Token) error { + if m.FailCreate { + return errors.New("mock create error") + } + m.tokens[token.AccessToken] = token + m.userTokens[token.UserID] = append(m.userTokens[token.UserID], token) + return nil +} + +func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + if token, ok := m.tokens[accessToken]; ok { + return token, nil + } + return nil, errors.New("token not found") +} + +func (m *MockTokenRepository) GetByUserID(userId int64) ([]*model.Token, error) { + if m.FailFind { + return nil, errors.New("mock find error") + } + return m.userTokens[userId], nil +} + +func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) { + if m.FailFind { + return "", errors.New("mock find error") + } + if token, ok := m.tokens[accessToken]; ok { + return token.ProfileId, nil + } + return "", errors.New("token not found") +} + +func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) { + if m.FailFind { + return 0, errors.New("mock find error") + } + if token, ok := m.tokens[accessToken]; ok { + return token.UserID, nil + } + return 0, errors.New("token not found") +} + +func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.tokens, accessToken) + return nil +} + +func (m *MockTokenRepository) DeleteByUserID(userId int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + for _, token := range m.userTokens[userId] { + delete(m.tokens, token.AccessToken) + } + m.userTokens[userId] = nil + return nil +} + +func (m *MockTokenRepository) BatchDelete(accessTokens []string) (int64, error) { + if m.FailDelete { + return 0, errors.New("mock delete error") + } + var count int64 + for _, accessToken := range accessTokens { + if _, ok := m.tokens[accessToken]; ok { + delete(m.tokens, accessToken) + count++ + } + } + return count, nil +} + +// MockSystemConfigRepository 模拟SystemConfigRepository +type MockSystemConfigRepository struct { + configs map[string]*model.SystemConfig +} + +func NewMockSystemConfigRepository() *MockSystemConfigRepository { + return &MockSystemConfigRepository{ + configs: make(map[string]*model.SystemConfig), + } +} + +func (m *MockSystemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) { + if config, ok := m.configs[key]; ok { + return config, nil + } + return nil, nil +} + +func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) { + var result []model.SystemConfig + for _, v := range m.configs { + result = append(result, *v) + } + return result, nil +} + +func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) { + var result []model.SystemConfig + for _, v := range m.configs { + result = append(result, *v) + } + return result, nil +} + +func (m *MockSystemConfigRepository) Update(config *model.SystemConfig) error { + m.configs[config.Key] = config + return nil +} + +func (m *MockSystemConfigRepository) UpdateValue(key, value string) error { + if config, ok := m.configs[key]; ok { + config.Value = value + return nil + } + return errors.New("config not found") +} + +// ============================================================================ +// Service Mocks +// ============================================================================ + +// MockUserService 模拟UserService +type MockUserService struct { + users map[int64]*model.User + maxProfilesPerUser int + maxTexturesPerUser int + FailRegister bool + FailLogin bool + FailGetByID bool + FailUpdate bool +} + +func NewMockUserService() *MockUserService { + return &MockUserService{ + users: make(map[int64]*model.User), + maxProfilesPerUser: 5, + maxTexturesPerUser: 50, + } +} + +func (m *MockUserService) Register(username, password, email, avatar string) (*model.User, string, error) { + if m.FailRegister { + return nil, "", errors.New("mock register error") + } + user := &model.User{ + ID: int64(len(m.users) + 1), + Username: username, + Email: email, + Avatar: avatar, + Status: 1, + } + m.users[user.ID] = user + return user, "mock-token", nil +} + +func (m *MockUserService) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { + if m.FailLogin { + return nil, "", errors.New("mock login error") + } + for _, user := range m.users { + if user.Username == usernameOrEmail || user.Email == usernameOrEmail { + return user, "mock-token", nil + } + } + return nil, "", errors.New("user not found") +} + +func (m *MockUserService) GetByID(id int64) (*model.User, error) { + if m.FailGetByID { + return nil, errors.New("mock get by id error") + } + if user, ok := m.users[id]; ok { + return user, nil + } + return nil, nil +} + +func (m *MockUserService) GetByEmail(email string) (*model.User, error) { + for _, user := range m.users { + if user.Email == email { + return user, nil + } + } + return nil, nil +} + +func (m *MockUserService) UpdateInfo(user *model.User) error { + if m.FailUpdate { + return errors.New("mock update error") + } + m.users[user.ID] = user + return nil +} + +func (m *MockUserService) UpdateAvatar(userID int64, avatarURL string) error { + if m.FailUpdate { + return errors.New("mock update error") + } + if user, ok := m.users[userID]; ok { + user.Avatar = avatarURL + } + return nil +} + +func (m *MockUserService) ChangePassword(userID int64, oldPassword, newPassword string) error { + return nil +} + +func (m *MockUserService) ResetPassword(email, newPassword string) error { + return nil +} + +func (m *MockUserService) ChangeEmail(userID int64, newEmail string) error { + if user, ok := m.users[userID]; ok { + user.Email = newEmail + } + return nil +} + +func (m *MockUserService) ValidateAvatarURL(avatarURL string) error { + return nil +} + +func (m *MockUserService) GetMaxProfilesPerUser() int { + return m.maxProfilesPerUser +} + +func (m *MockUserService) GetMaxTexturesPerUser() int { + return m.maxTexturesPerUser +} + +// MockProfileService 模拟ProfileService +type MockProfileService struct { + profiles map[string]*model.Profile + FailCreate bool + FailGet bool + FailUpdate bool + FailDelete bool +} + +func NewMockProfileService() *MockProfileService { + return &MockProfileService{ + profiles: make(map[string]*model.Profile), + } +} + +func (m *MockProfileService) Create(userID int64, name string) (*model.Profile, error) { + if m.FailCreate { + return nil, errors.New("mock create error") + } + profile := &model.Profile{ + UUID: "mock-uuid-" + name, + UserID: userID, + Name: name, + } + m.profiles[profile.UUID] = profile + return profile, nil +} + +func (m *MockProfileService) GetByUUID(uuid string) (*model.Profile, error) { + if m.FailGet { + return nil, errors.New("mock get error") + } + if profile, ok := m.profiles[uuid]; ok { + return profile, nil + } + return nil, errors.New("profile not found") +} + +func (m *MockProfileService) GetByUserID(userID int64) ([]*model.Profile, error) { + if m.FailGet { + return nil, errors.New("mock get error") + } + var result []*model.Profile + for _, profile := range m.profiles { + if profile.UserID == userID { + result = append(result, profile) + } + } + return result, nil +} + +func (m *MockProfileService) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { + if m.FailUpdate { + return nil, errors.New("mock update error") + } + if profile, ok := m.profiles[uuid]; ok { + if name != nil { + profile.Name = *name + } + if skinID != nil { + profile.SkinID = skinID + } + if capeID != nil { + profile.CapeID = capeID + } + return profile, nil + } + return nil, errors.New("profile not found") +} + +func (m *MockProfileService) Delete(uuid string, userID int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.profiles, uuid) + return nil +} + +func (m *MockProfileService) SetActive(uuid string, userID int64) error { + return nil +} + +func (m *MockProfileService) CheckLimit(userID int64, maxProfiles int) error { + count := 0 + for _, profile := range m.profiles { + if profile.UserID == userID { + count++ + } + } + if count >= maxProfiles { + return errors.New("达到档案数量上限") + } + return nil +} + +func (m *MockProfileService) GetByNames(names []string) ([]*model.Profile, error) { + var result []*model.Profile + for _, name := range names { + for _, profile := range m.profiles { + if profile.Name == name { + result = append(result, profile) + } + } + } + return result, nil +} + +func (m *MockProfileService) GetByProfileName(name string) (*model.Profile, error) { + for _, profile := range m.profiles { + if profile.Name == name { + return profile, nil + } + } + return nil, errors.New("profile not found") +} + +// MockTextureService 模拟TextureService +type MockTextureService struct { + textures map[int64]*model.Texture + nextID int64 + FailCreate bool + FailGet bool + FailUpdate bool + FailDelete bool +} + +func NewMockTextureService() *MockTextureService { + return &MockTextureService{ + textures: make(map[int64]*model.Texture), + nextID: 1, + } +} + +func (m *MockTextureService) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { + if m.FailCreate { + return nil, errors.New("mock create error") + } + texture := &model.Texture{ + ID: m.nextID, + UploaderID: uploaderID, + Name: name, + Description: description, + URL: url, + Hash: hash, + Size: size, + IsPublic: isPublic, + IsSlim: isSlim, + } + m.textures[texture.ID] = texture + m.nextID++ + return texture, nil +} + +func (m *MockTextureService) GetByID(id int64) (*model.Texture, error) { + if m.FailGet { + return nil, errors.New("mock get error") + } + if texture, ok := m.textures[id]; ok { + return texture, nil + } + return nil, errors.New("texture not found") +} + +func (m *MockTextureService) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailGet { + return nil, 0, errors.New("mock get error") + } + var result []*model.Texture + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + result = append(result, texture) + } + } + return result, int64(len(result)), nil +} + +func (m *MockTextureService) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + if m.FailGet { + return nil, 0, errors.New("mock get error") + } + var result []*model.Texture + for _, texture := range m.textures { + if publicOnly && !texture.IsPublic { + continue + } + result = append(result, texture) + } + return result, int64(len(result)), nil +} + +func (m *MockTextureService) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { + if m.FailUpdate { + return nil, errors.New("mock update error") + } + if texture, ok := m.textures[textureID]; ok { + if name != "" { + texture.Name = name + } + if description != "" { + texture.Description = description + } + if isPublic != nil { + texture.IsPublic = *isPublic + } + return texture, nil + } + return nil, errors.New("texture not found") +} + +func (m *MockTextureService) Delete(textureID, uploaderID int64) error { + if m.FailDelete { + return errors.New("mock delete error") + } + delete(m.textures, textureID) + return nil +} + +func (m *MockTextureService) ToggleFavorite(userID, textureID int64) (bool, error) { + return true, nil +} + +func (m *MockTextureService) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + return nil, 0, nil +} + +func (m *MockTextureService) CheckUploadLimit(uploaderID int64, maxTextures int) error { + count := 0 + for _, texture := range m.textures { + if texture.UploaderID == uploaderID { + count++ + } + } + if count >= maxTextures { + return errors.New("达到材质数量上限") + } + return nil +} + +// MockTokenService 模拟TokenService +type MockTokenService struct { + tokens map[string]*model.Token + FailCreate bool + FailValidate bool + FailRefresh bool +} + +func NewMockTokenService() *MockTokenService { + return &MockTokenService{ + tokens: make(map[string]*model.Token), + } +} + +func (m *MockTokenService) Create(userID int64, uuid, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { + if m.FailCreate { + return nil, nil, "", "", errors.New("mock create error") + } + accessToken := "mock-access-token" + if clientToken == "" { + clientToken = "mock-client-token" + } + token := &model.Token{ + AccessToken: accessToken, + ClientToken: clientToken, + UserID: userID, + ProfileId: uuid, + Usable: true, + } + m.tokens[accessToken] = token + return nil, nil, accessToken, clientToken, nil +} + +func (m *MockTokenService) Validate(accessToken, clientToken string) bool { + if m.FailValidate { + return false + } + if token, ok := m.tokens[accessToken]; ok { + if clientToken == "" || token.ClientToken == clientToken { + return token.Usable + } + } + return false +} + +func (m *MockTokenService) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { + if m.FailRefresh { + return "", "", errors.New("mock refresh error") + } + return "new-access-token", clientToken, nil +} + +func (m *MockTokenService) Invalidate(accessToken string) { + delete(m.tokens, accessToken) +} + +func (m *MockTokenService) InvalidateUserTokens(userID int64) { + for key, token := range m.tokens { + if token.UserID == userID { + delete(m.tokens, key) + } + } +} + +func (m *MockTokenService) GetUUIDByAccessToken(accessToken string) (string, error) { + if token, ok := m.tokens[accessToken]; ok { + return token.ProfileId, nil + } + return "", errors.New("token not found") +} + +func (m *MockTokenService) GetUserIDByAccessToken(accessToken string) (int64, error) { + if token, ok := m.tokens[accessToken]; ok { + return token.UserID, nil + } + return 0, errors.New("token not found") +} diff --git a/internal/service/profile_service.go b/internal/service/profile_service.go index d3e2057..a956793 100644 --- a/internal/service/profile_service.go +++ b/internal/service/profile_service.go @@ -11,35 +11,54 @@ import ( "fmt" "github.com/google/uuid" - "github.com/jackc/pgx/v5" + "go.uber.org/zap" "gorm.io/gorm" ) -// CreateProfile 创建档案 -func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, error) { +// profileServiceImpl ProfileService的实现 +type profileServiceImpl struct { + profileRepo repository.ProfileRepository + userRepo repository.UserRepository + logger *zap.Logger +} + +// NewProfileService 创建ProfileService实例 +func NewProfileService( + profileRepo repository.ProfileRepository, + userRepo repository.UserRepository, + logger *zap.Logger, +) ProfileService { + return &profileServiceImpl{ + profileRepo: profileRepo, + userRepo: userRepo, + logger: logger, + } +} + +func (s *profileServiceImpl) Create(userID int64, name string) (*model.Profile, error) { // 验证用户存在 - user, err := EnsureUserExists(userID) - if err != nil { - return nil, err + user, err := s.userRepo.FindByID(userID) + if err != nil || user == nil { + return nil, errors.New("用户不存在") } if user.Status != 1 { - return nil, fmt.Errorf("用户状态异常") + return nil, errors.New("用户状态异常") } // 检查角色名是否已存在 - existingName, err := repository.FindProfileByName(name) + existingName, err := s.profileRepo.FindByName(name) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, WrapError(err, "查询角色名失败") + return nil, fmt.Errorf("查询角色名失败: %w", err) } if existingName != nil { - return nil, fmt.Errorf("角色名已被使用") + return nil, errors.New("角色名已被使用") } // 生成UUID和RSA密钥 profileUUID := uuid.New().String() - privateKey, err := generateRSAPrivateKey() + privateKey, err := generateRSAPrivateKeyInternal() if err != nil { - return nil, WrapError(err, "生成RSA密钥失败") + return nil, fmt.Errorf("生成RSA密钥失败: %w", err) } // 创建档案 @@ -51,55 +70,59 @@ func CreateProfile(db *gorm.DB, userID int64, name string) (*model.Profile, erro IsActive: true, } - if err := repository.CreateProfile(profile); err != nil { - return nil, WrapError(err, "创建档案失败") + if err := s.profileRepo.Create(profile); err != nil { + return nil, fmt.Errorf("创建档案失败: %w", err) } // 设置活跃状态 - if err := repository.SetActiveProfile(profileUUID, userID); err != nil { - return nil, WrapError(err, "设置活跃状态失败") + if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { + return nil, fmt.Errorf("设置活跃状态失败: %w", err) } return profile, nil } -// GetProfileByUUID 获取档案详情 -func GetProfileByUUID(db *gorm.DB, uuid string) (*model.Profile, error) { - profile, err := repository.FindProfileByUUID(uuid) +func (s *profileServiceImpl) GetByUUID(uuid string) (*model.Profile, error) { + profile, err := s.profileRepo.FindByUUID(uuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrProfileNotFound } - return nil, WrapError(err, "查询档案失败") + return nil, fmt.Errorf("查询档案失败: %w", err) } return profile, nil } -// GetUserProfiles 获取用户的所有档案 -func GetUserProfiles(db *gorm.DB, userID int64) ([]*model.Profile, error) { - profiles, err := repository.FindProfilesByUserID(userID) +func (s *profileServiceImpl) GetByUserID(userID int64) ([]*model.Profile, error) { + profiles, err := s.profileRepo.FindByUserID(userID) if err != nil { - return nil, WrapError(err, "查询档案列表失败") + return nil, fmt.Errorf("查询档案列表失败: %w", err) } return profiles, nil } -// UpdateProfile 更新档案 -func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { +func (s *profileServiceImpl) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { // 获取档案并验证权限 - profile, err := GetProfileWithPermissionCheck(uuid, userID) + profile, err := s.profileRepo.FindByUUID(uuid) if err != nil { - return nil, err + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrProfileNotFound + } + return nil, fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return nil, ErrProfileNoPermission } // 检查角色名是否重复 if name != nil && *name != profile.Name { - existingName, err := repository.FindProfileByName(*name) + existingName, err := s.profileRepo.FindByName(*name) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, WrapError(err, "查询角色名失败") + return nil, fmt.Errorf("查询角色名失败: %w", err) } if existingName != nil { - return nil, fmt.Errorf("角色名已被使用") + return nil, errors.New("角色名已被使用") } profile.Name = *name } @@ -112,47 +135,62 @@ func UpdateProfile(db *gorm.DB, uuid string, userID int64, name *string, skinID, profile.CapeID = capeID } - if err := repository.UpdateProfile(profile); err != nil { - return nil, WrapError(err, "更新档案失败") + if err := s.profileRepo.Update(profile); err != nil { + return nil, fmt.Errorf("更新档案失败: %w", err) } - return repository.FindProfileByUUID(uuid) + return s.profileRepo.FindByUUID(uuid) } -// DeleteProfile 删除档案 -func DeleteProfile(db *gorm.DB, uuid string, userID int64) error { - if _, err := GetProfileWithPermissionCheck(uuid, userID); err != nil { - return err - } - - if err := repository.DeleteProfile(uuid); err != nil { - return WrapError(err, "删除档案失败") - } - return nil -} - -// SetActiveProfile 设置活跃档案 -func SetActiveProfile(db *gorm.DB, uuid string, userID int64) error { - if _, err := GetProfileWithPermissionCheck(uuid, userID); err != nil { - return err - } - - if err := repository.SetActiveProfile(uuid, userID); err != nil { - return WrapError(err, "设置活跃状态失败") - } - - if err := repository.UpdateProfileLastUsedAt(uuid); err != nil { - return WrapError(err, "更新使用时间失败") - } - - return nil -} - -// CheckProfileLimit 检查用户档案数量限制 -func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error { - count, err := repository.CountProfilesByUserID(userID) +func (s *profileServiceImpl) Delete(uuid string, userID int64) error { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) if err != nil { - return WrapError(err, "查询档案数量失败") + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrProfileNotFound + } + return fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return ErrProfileNoPermission + } + + if err := s.profileRepo.Delete(uuid); err != nil { + return fmt.Errorf("删除档案失败: %w", err) + } + return nil +} + +func (s *profileServiceImpl) SetActive(uuid string, userID int64) error { + // 获取档案并验证权限 + profile, err := s.profileRepo.FindByUUID(uuid) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ErrProfileNotFound + } + return fmt.Errorf("查询档案失败: %w", err) + } + + if profile.UserID != userID { + return ErrProfileNoPermission + } + + if err := s.profileRepo.SetActive(uuid, userID); err != nil { + return fmt.Errorf("设置活跃状态失败: %w", err) + } + + if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { + return fmt.Errorf("更新使用时间失败: %w", err) + } + + return nil +} + +func (s *profileServiceImpl) CheckLimit(userID int64, maxProfiles int) error { + count, err := s.profileRepo.CountByUserID(userID) + if err != nil { + return fmt.Errorf("查询档案数量失败: %w", err) } if int(count) >= maxProfiles { @@ -161,8 +199,24 @@ func CheckProfileLimit(db *gorm.DB, userID int64, maxProfiles int) error { return nil } -// generateRSAPrivateKey 生成RSA-2048私钥(PEM格式) -func generateRSAPrivateKey() (string, error) { +func (s *profileServiceImpl) GetByNames(names []string) ([]*model.Profile, error) { + profiles, err := s.profileRepo.GetByNames(names) + if err != nil { + return nil, fmt.Errorf("查找失败: %w", err) + } + return profiles, nil +} + +func (s *profileServiceImpl) GetByProfileName(name string) (*model.Profile, error) { + profile, err := s.profileRepo.FindByName(name) + if err != nil { + return nil, errors.New("用户角色未创建") + } + return profile, nil +} + +// generateRSAPrivateKeyInternal 生成RSA-2048私钥(PEM格式) +func generateRSAPrivateKeyInternal() (string, error) { privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return "", err @@ -177,33 +231,4 @@ func generateRSAPrivateKey() (string, error) { return string(privateKeyPEM), nil } -func ValidateProfileByUserID(db *gorm.DB, userId int64, UUID string) (bool, error) { - if userId == 0 || UUID == "" { - return false, errors.New("用户ID或配置文件ID不能为空") - } - profile, err := repository.FindProfileByUUID(UUID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return false, errors.New("配置文件不存在") - } - return false, WrapError(err, "验证配置文件失败") - } - return profile.UserID == userId, nil -} - -func GetProfilesDataByNames(db *gorm.DB, names []string) ([]*model.Profile, error) { - profiles, err := repository.GetProfilesByNames(names) - if err != nil { - return nil, WrapError(err, "查找失败") - } - return profiles, nil -} - -func GetProfileKeyPair(db *gorm.DB, profileId string) (*model.KeyPair, error) { - keyPair, err := repository.GetProfileKeyPair(profileId) - if err != nil { - return nil, WrapError(err, "查找失败") - } - return keyPair, nil -} diff --git a/internal/service/profile_service_impl.go b/internal/service/profile_service_impl.go deleted file mode 100644 index a956793..0000000 --- a/internal/service/profile_service_impl.go +++ /dev/null @@ -1,234 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - - "github.com/google/uuid" - "go.uber.org/zap" - "gorm.io/gorm" -) - -// profileServiceImpl ProfileService的实现 -type profileServiceImpl struct { - profileRepo repository.ProfileRepository - userRepo repository.UserRepository - logger *zap.Logger -} - -// NewProfileService 创建ProfileService实例 -func NewProfileService( - profileRepo repository.ProfileRepository, - userRepo repository.UserRepository, - logger *zap.Logger, -) ProfileService { - return &profileServiceImpl{ - profileRepo: profileRepo, - userRepo: userRepo, - logger: logger, - } -} - -func (s *profileServiceImpl) Create(userID int64, name string) (*model.Profile, error) { - // 验证用户存在 - user, err := s.userRepo.FindByID(userID) - if err != nil || user == nil { - return nil, errors.New("用户不存在") - } - if user.Status != 1 { - return nil, errors.New("用户状态异常") - } - - // 检查角色名是否已存在 - existingName, err := s.profileRepo.FindByName(name) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("查询角色名失败: %w", err) - } - if existingName != nil { - return nil, errors.New("角色名已被使用") - } - - // 生成UUID和RSA密钥 - profileUUID := uuid.New().String() - privateKey, err := generateRSAPrivateKeyInternal() - if err != nil { - return nil, fmt.Errorf("生成RSA密钥失败: %w", err) - } - - // 创建档案 - profile := &model.Profile{ - UUID: profileUUID, - UserID: userID, - Name: name, - RSAPrivateKey: privateKey, - IsActive: true, - } - - if err := s.profileRepo.Create(profile); err != nil { - return nil, fmt.Errorf("创建档案失败: %w", err) - } - - // 设置活跃状态 - if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { - return nil, fmt.Errorf("设置活跃状态失败: %w", err) - } - - return profile, nil -} - -func (s *profileServiceImpl) GetByUUID(uuid string) (*model.Profile, error) { - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrProfileNotFound - } - return nil, fmt.Errorf("查询档案失败: %w", err) - } - return profile, nil -} - -func (s *profileServiceImpl) GetByUserID(userID int64) ([]*model.Profile, error) { - profiles, err := s.profileRepo.FindByUserID(userID) - if err != nil { - return nil, fmt.Errorf("查询档案列表失败: %w", err) - } - return profiles, nil -} - -func (s *profileServiceImpl) Update(uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { - // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ErrProfileNotFound - } - return nil, fmt.Errorf("查询档案失败: %w", err) - } - - if profile.UserID != userID { - return nil, ErrProfileNoPermission - } - - // 检查角色名是否重复 - if name != nil && *name != profile.Name { - existingName, err := s.profileRepo.FindByName(*name) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("查询角色名失败: %w", err) - } - if existingName != nil { - return nil, errors.New("角色名已被使用") - } - profile.Name = *name - } - - // 更新皮肤和披风 - if skinID != nil { - profile.SkinID = skinID - } - if capeID != nil { - profile.CapeID = capeID - } - - if err := s.profileRepo.Update(profile); err != nil { - return nil, fmt.Errorf("更新档案失败: %w", err) - } - - return s.profileRepo.FindByUUID(uuid) -} - -func (s *profileServiceImpl) Delete(uuid string, userID int64) error { - // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrProfileNotFound - } - return fmt.Errorf("查询档案失败: %w", err) - } - - if profile.UserID != userID { - return ErrProfileNoPermission - } - - if err := s.profileRepo.Delete(uuid); err != nil { - return fmt.Errorf("删除档案失败: %w", err) - } - return nil -} - -func (s *profileServiceImpl) SetActive(uuid string, userID int64) error { - // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ErrProfileNotFound - } - return fmt.Errorf("查询档案失败: %w", err) - } - - if profile.UserID != userID { - return ErrProfileNoPermission - } - - if err := s.profileRepo.SetActive(uuid, userID); err != nil { - return fmt.Errorf("设置活跃状态失败: %w", err) - } - - if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { - return fmt.Errorf("更新使用时间失败: %w", err) - } - - return nil -} - -func (s *profileServiceImpl) CheckLimit(userID int64, maxProfiles int) error { - count, err := s.profileRepo.CountByUserID(userID) - if err != nil { - return fmt.Errorf("查询档案数量失败: %w", err) - } - - if int(count) >= maxProfiles { - return fmt.Errorf("已达到档案数量上限(%d个)", maxProfiles) - } - return nil -} - -func (s *profileServiceImpl) GetByNames(names []string) ([]*model.Profile, error) { - profiles, err := s.profileRepo.GetByNames(names) - if err != nil { - return nil, fmt.Errorf("查找失败: %w", err) - } - return profiles, nil -} - -func (s *profileServiceImpl) GetByProfileName(name string) (*model.Profile, error) { - profile, err := s.profileRepo.FindByName(name) - if err != nil { - return nil, errors.New("用户角色未创建") - } - return profile, nil -} - -// generateRSAPrivateKeyInternal 生成RSA-2048私钥(PEM格式) -func generateRSAPrivateKeyInternal() (string, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return "", err - } - - privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) - privateKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: privateKeyBytes, - }) - - return string(privateKeyPEM), nil -} - - diff --git a/internal/service/profile_service_test.go b/internal/service/profile_service_test.go index 37fef82..cf71362 100644 --- a/internal/service/profile_service_test.go +++ b/internal/service/profile_service_test.go @@ -1,7 +1,10 @@ package service import ( + "carrotskin/internal/model" "testing" + + "go.uber.org/zap" ) // TestProfileService_Validation 测试Profile服务验证逻辑 @@ -347,22 +350,22 @@ func TestGenerateRSAPrivateKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - privateKey, err := generateRSAPrivateKey() + privateKey, err := generateRSAPrivateKeyInternal() if (err != nil) != tt.wantError { - t.Errorf("generateRSAPrivateKey() error = %v, wantError %v", err, tt.wantError) + t.Errorf("generateRSAPrivateKeyInternal() error = %v, wantError %v", err, tt.wantError) return } if !tt.wantError { if privateKey == "" { - t.Error("generateRSAPrivateKey() 返回的私钥不应为空") + t.Error("generateRSAPrivateKeyInternal() 返回的私钥不应为空") } // 验证PEM格式 if len(privateKey) < 100 { - t.Errorf("generateRSAPrivateKey() 返回的私钥长度异常: %d", len(privateKey)) + t.Errorf("generateRSAPrivateKeyInternal() 返回的私钥长度异常: %d", len(privateKey)) } // 验证包含PEM头部 if !contains(privateKey, "BEGIN RSA PRIVATE KEY") { - t.Error("generateRSAPrivateKey() 返回的私钥应包含PEM头部") + t.Error("generateRSAPrivateKeyInternal() 返回的私钥应包含PEM头部") } } }) @@ -373,9 +376,9 @@ func TestGenerateRSAPrivateKey(t *testing.T) { func TestGenerateRSAPrivateKey_Uniqueness(t *testing.T) { keys := make(map[string]bool) for i := 0; i < 10; i++ { - key, err := generateRSAPrivateKey() + key, err := generateRSAPrivateKeyInternal() if err != nil { - t.Fatalf("generateRSAPrivateKey() 失败: %v", err) + t.Fatalf("generateRSAPrivateKeyInternal() 失败: %v", err) } if keys[key] { t.Errorf("第%d次生成的密钥与之前重复", i+1) @@ -404,3 +407,319 @@ func containsMiddle(s, substr string) bool { } return false } + +// ============================================================================ +// 使用 Mock 的集成测试 +// ============================================================================ + +// TestProfileServiceImpl_Create 测试创建Profile +func TestProfileServiceImpl_Create(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置用户 + testUser := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + Status: 1, + } + userRepo.Create(testUser) + + profileService := NewProfileService(profileRepo, userRepo, logger) + + tests := []struct { + name string + userID int64 + profileName string + wantErr bool + errMsg string + setupMocks func() + }{ + { + name: "正常创建Profile", + userID: 1, + profileName: "TestProfile", + wantErr: false, + }, + { + name: "用户不存在", + userID: 999, + profileName: "TestProfile2", + wantErr: true, + errMsg: "用户不存在", + }, + { + name: "角色名已存在", + userID: 1, + profileName: "ExistingProfile", + wantErr: true, + errMsg: "角色名已被使用", + setupMocks: func() { + profileRepo.Create(&model.Profile{ + UUID: "existing-uuid", + UserID: 2, + Name: "ExistingProfile", + }) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupMocks != nil { + tt.setupMocks() + } + + profile, err := profileService.Create(tt.userID, tt.profileName) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + return + } + if tt.errMsg != "" && err.Error() != tt.errMsg { + t.Errorf("错误信息不匹配: got %v, want %v", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if profile == nil { + t.Error("返回的Profile不应为nil") + } + if profile.Name != tt.profileName { + t.Errorf("Profile名称不匹配: got %v, want %v", profile.Name, tt.profileName) + } + if profile.UUID == "" { + t.Error("Profile UUID不应为空") + } + } + }) + } +} + +// TestProfileServiceImpl_GetByUUID 测试获取Profile +func TestProfileServiceImpl_GetByUUID(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置Profile + testProfile := &model.Profile{ + UUID: "test-uuid-123", + UserID: 1, + Name: "TestProfile", + } + profileRepo.Create(testProfile) + + profileService := NewProfileService(profileRepo, userRepo, logger) + + tests := []struct { + name string + uuid string + wantErr bool + }{ + { + name: "获取存在的Profile", + uuid: "test-uuid-123", + wantErr: false, + }, + { + name: "获取不存在的Profile", + uuid: "non-existent-uuid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile, err := profileService.GetByUUID(tt.uuid) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if profile == nil { + t.Error("返回的Profile不应为nil") + } + if profile.UUID != tt.uuid { + t.Errorf("Profile UUID不匹配: got %v, want %v", profile.UUID, tt.uuid) + } + } + }) + } +} + +// TestProfileServiceImpl_Delete 测试删除Profile +func TestProfileServiceImpl_Delete(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置Profile + testProfile := &model.Profile{ + UUID: "delete-test-uuid", + UserID: 1, + Name: "DeleteTestProfile", + } + profileRepo.Create(testProfile) + + profileService := NewProfileService(profileRepo, userRepo, logger) + + tests := []struct { + name string + uuid string + userID int64 + wantErr bool + }{ + { + name: "正常删除", + uuid: "delete-test-uuid", + userID: 1, + wantErr: false, + }, + { + name: "用户ID不匹配", + uuid: "delete-test-uuid", + userID: 2, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := profileService.Delete(tt.uuid, tt.userID) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + } + } + }) + } +} + +// TestProfileServiceImpl_GetByUserID 测试按用户获取档案列表 +func TestProfileServiceImpl_GetByUserID(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 为用户 1 和 2 预置不同档案 + profileRepo.Create(&model.Profile{UUID: "p1", UserID: 1, Name: "P1"}) + profileRepo.Create(&model.Profile{UUID: "p2", UserID: 1, Name: "P2"}) + profileRepo.Create(&model.Profile{UUID: "p3", UserID: 2, Name: "P3"}) + + svc := NewProfileService(profileRepo, userRepo, logger) + + list, err := svc.GetByUserID(1) + if err != nil { + t.Fatalf("GetByUserID 失败: %v", err) + } + if len(list) != 2 { + t.Fatalf("GetByUserID 返回数量错误, got=%d, want=2", len(list)) + } +} + +// TestProfileServiceImpl_Update_And_SetActive 测试 Update 与 SetActive +func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + profile := &model.Profile{ + UUID: "u1", + UserID: 1, + Name: "OldName", + } + profileRepo.Create(profile) + + svc := NewProfileService(profileRepo, userRepo, logger) + + // 正常更新名称与皮肤/披风 + newName := "NewName" + var skinID int64 = 10 + var capeID int64 = 20 + updated, err := svc.Update("u1", 1, &newName, &skinID, &capeID) + if err != nil { + t.Fatalf("Update 正常情况失败: %v", err) + } + if updated == nil || updated.Name != newName { + t.Fatalf("Update 未更新名称, got=%+v", updated) + } + + // 用户无权限 + if _, err := svc.Update("u1", 2, &newName, nil, nil); err == nil { + t.Fatalf("Update 在无权限时应返回错误") + } + + // 名称重复 + profileRepo.Create(&model.Profile{ + UUID: "u2", + UserID: 2, + Name: "Duplicate", + }) + if _, err := svc.Update("u1", 1, stringPtr("Duplicate"), nil, nil); err == nil { + t.Fatalf("Update 在名称重复时应返回错误") + } + + // SetActive 正常 + if err := svc.SetActive("u1", 1); err != nil { + t.Fatalf("SetActive 正常情况失败: %v", err) + } + + // SetActive 无权限 + if err := svc.SetActive("u1", 2); err == nil { + t.Fatalf("SetActive 在无权限时应返回错误") + } +} + +// TestProfileServiceImpl_CheckLimit_And_GetByNames 测试 CheckLimit / GetByNames / GetByProfileName +func TestProfileServiceImpl_CheckLimit_And_GetByNames(t *testing.T) { + profileRepo := NewMockProfileRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 为用户 1 预置 2 个档案 + profileRepo.Create(&model.Profile{UUID: "a", UserID: 1, Name: "A"}) + profileRepo.Create(&model.Profile{UUID: "b", UserID: 1, Name: "B"}) + + svc := NewProfileService(profileRepo, userRepo, logger) + + // CheckLimit 未达上限 + if err := svc.CheckLimit(1, 3); err != nil { + t.Fatalf("CheckLimit 未达到上限时不应报错: %v", err) + } + + // CheckLimit 达到上限 + if err := svc.CheckLimit(1, 2); err == nil { + t.Fatalf("CheckLimit 达到上限时应报错") + } + + // GetByNames + list, err := svc.GetByNames([]string{"A", "B"}) + if err != nil { + t.Fatalf("GetByNames 失败: %v", err) + } + if len(list) != 2 { + t.Fatalf("GetByNames 返回数量错误, got=%d, want=2", len(list)) + } + + // GetByProfileName 存在 + p, err := svc.GetByProfileName("A") + if err != nil || p == nil || p.Name != "A" { + t.Fatalf("GetByProfileName 返回错误, profile=%+v, err=%v", p, err) + } +} diff --git a/internal/service/serialize_service.go b/internal/service/serialize_service.go index 2400522..4f12691 100644 --- a/internal/service/serialize_service.go +++ b/internal/service/serialize_service.go @@ -2,6 +2,7 @@ package service import ( "carrotskin/internal/model" + "carrotskin/internal/repository" "carrotskin/pkg/redis" "encoding/base64" "time" @@ -31,7 +32,7 @@ func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client // 处理皮肤 if p.SkinID != nil { - skin, err := GetTextureByID(db, *p.SkinID) + skin, err := repository.FindTextureByID(*p.SkinID) if err != nil { logger.Error("[ERROR] 获取皮肤失败:", zap.Error(err), zap.Any("SkinID:", *p.SkinID)) } else { @@ -44,7 +45,7 @@ func SerializeProfile(db *gorm.DB, logger *zap.Logger, redisClient *redis.Client // 处理披风 if p.CapeID != nil { - cape, err := GetTextureByID(db, *p.CapeID) + cape, err := repository.FindTextureByID(*p.CapeID) if err != nil { logger.Error("[ERROR] 获取披风失败:", zap.Error(err), zap.Any("capeID:", *p.CapeID)) } else { diff --git a/internal/service/serialize_service_test.go b/internal/service/serialize_service_test.go index 4f2d3be..4ad66e7 100644 --- a/internal/service/serialize_service_test.go +++ b/internal/service/serialize_service_test.go @@ -5,6 +5,7 @@ import ( "testing" "go.uber.org/zap/zaptest" + "gorm.io/datatypes" ) // TestSerializeUser_NilUser 实际调用SerializeUser函数测试nil用户 @@ -19,25 +20,51 @@ func TestSerializeUser_NilUser(t *testing.T) { // TestSerializeUser_ActualCall 实际调用SerializeUser函数 func TestSerializeUser_ActualCall(t *testing.T) { logger := zaptest.NewLogger(t) - user := &model.User{ - ID: 1, - Username: "testuser", - Email: "test@example.com", - // Properties 使用 datatypes.JSON,测试中可以为空 - } - result := SerializeUser(logger, user, "test-uuid-123") - if result == nil { - t.Fatal("SerializeUser() 返回的结果不应为nil") - } + t.Run("Properties为nil时", func(t *testing.T) { + user := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + } - if result["id"] != "test-uuid-123" { - t.Errorf("id = %v, want 'test-uuid-123'", result["id"]) - } + result := SerializeUser(logger, user, "test-uuid-123") + if result == nil { + t.Fatal("SerializeUser() 返回的结果不应为nil") + } - if result["properties"] == nil { - t.Error("properties 不应为nil") - } + if result["id"] != "test-uuid-123" { + t.Errorf("id = %v, want 'test-uuid-123'", result["id"]) + } + + // 当 Properties 为 nil 时,properties 应该为 nil + if result["properties"] != nil { + t.Error("当 user.Properties 为 nil 时,properties 应为 nil") + } + }) + + t.Run("Properties有值时", func(t *testing.T) { + propsJSON := datatypes.JSON(`[{"name":"test","value":"value"}]`) + user := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + Properties: &propsJSON, + } + + result := SerializeUser(logger, user, "test-uuid-456") + if result == nil { + t.Fatal("SerializeUser() 返回的结果不应为nil") + } + + if result["id"] != "test-uuid-456" { + t.Errorf("id = %v, want 'test-uuid-456'", result["id"]) + } + + if result["properties"] == nil { + t.Error("当 user.Properties 有值时,properties 不应为 nil") + } + }) } // TestProperty_Structure 测试Property结构 diff --git a/internal/service/texture_service.go b/internal/service/texture_service.go index ea312f0..eb19a82 100644 --- a/internal/service/texture_service.go +++ b/internal/service/texture_service.go @@ -6,18 +6,38 @@ import ( "errors" "fmt" - "gorm.io/gorm" + "go.uber.org/zap" ) -// CreateTexture 创建材质 -func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { +// textureServiceImpl TextureService的实现 +type textureServiceImpl struct { + textureRepo repository.TextureRepository + userRepo repository.UserRepository + logger *zap.Logger +} + +// NewTextureService 创建TextureService实例 +func NewTextureService( + textureRepo repository.TextureRepository, + userRepo repository.UserRepository, + logger *zap.Logger, +) TextureService { + return &textureServiceImpl{ + textureRepo: textureRepo, + userRepo: userRepo, + logger: logger, + } +} + +func (s *textureServiceImpl) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { // 验证用户存在 - if _, err := EnsureUserExists(uploaderID); err != nil { - return nil, err + user, err := s.userRepo.FindByID(uploaderID) + if err != nil || user == nil { + return nil, ErrUserNotFound } // 检查Hash是否已存在 - existingTexture, err := repository.FindTextureByHash(hash) + existingTexture, err := s.textureRepo.FindByHash(hash) if err != nil { return nil, err } @@ -26,7 +46,7 @@ func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType } // 转换材质类型 - textureTypeEnum, err := parseTextureType(textureType) + textureTypeEnum, err := parseTextureTypeInternal(textureType) if err != nil { return nil, err } @@ -47,36 +67,49 @@ func CreateTexture(db *gorm.DB, uploaderID int64, name, description, textureType FavoriteCount: 0, } - if err := repository.CreateTexture(texture); err != nil { + if err := s.textureRepo.Create(texture); err != nil { return nil, err } return texture, nil } -// GetTextureByID 根据ID获取材质 -func GetTextureByID(db *gorm.DB, id int64) (*model.Texture, error) { - return EnsureTextureExists(id) -} - -// GetUserTextures 获取用户上传的材质列表 -func GetUserTextures(db *gorm.DB, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return repository.FindTexturesByUploaderID(uploaderID, page, pageSize) -} - -// SearchTextures 搜索材质 -func SearchTextures(db *gorm.DB, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return repository.SearchTextures(keyword, textureType, publicOnly, page, pageSize) -} - -// UpdateTexture 更新材质 -func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { - // 获取材质并验证权限 - if _, err := GetTextureWithPermissionCheck(textureID, uploaderID); err != nil { +func (s *textureServiceImpl) GetByID(id int64) (*model.Texture, error) { + texture, err := s.textureRepo.FindByID(id) + if err != nil { return nil, err } + if texture == nil { + return nil, ErrTextureNotFound + } + if texture.Status == -1 { + return nil, errors.New("材质已删除") + } + return texture, nil +} + +func (s *textureServiceImpl) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) +} + +func (s *textureServiceImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + page, pageSize = NormalizePagination(page, pageSize) + return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) +} + +func (s *textureServiceImpl) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { + // 获取材质并验证权限 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { + return nil, err + } + if texture == nil { + return nil, ErrTextureNotFound + } + if texture.UploaderID != uploaderID { + return nil, ErrTextureNoPermission + } // 更新字段 updates := make(map[string]interface{}) @@ -91,83 +124,73 @@ func UpdateTexture(db *gorm.DB, textureID, uploaderID int64, name, description s } if len(updates) > 0 { - if err := repository.UpdateTextureFields(textureID, updates); err != nil { + if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { return nil, err } } - return repository.FindTextureByID(textureID) + return s.textureRepo.FindByID(textureID) } -// DeleteTexture 删除材质 -func DeleteTexture(db *gorm.DB, textureID, uploaderID int64) error { - if _, err := GetTextureWithPermissionCheck(textureID, uploaderID); err != nil { +func (s *textureServiceImpl) Delete(textureID, uploaderID int64) error { + // 获取材质并验证权限 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { return err } - return repository.DeleteTexture(textureID) + if texture == nil { + return ErrTextureNotFound + } + if texture.UploaderID != uploaderID { + return ErrTextureNoPermission + } + + return s.textureRepo.Delete(textureID) } -// RecordTextureDownload 记录下载 -func RecordTextureDownload(db *gorm.DB, textureID int64, userID *int64, ipAddress, userAgent string) error { - if _, err := EnsureTextureExists(textureID); err != nil { - return err - } - - if err := repository.IncrementTextureDownloadCount(textureID); err != nil { - return err - } - - log := &model.TextureDownloadLog{ - TextureID: textureID, - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - } - - return repository.CreateTextureDownloadLog(log) -} - -// ToggleTextureFavorite 切换收藏状态 -func ToggleTextureFavorite(db *gorm.DB, userID, textureID int64) (bool, error) { - if _, err := EnsureTextureExists(textureID); err != nil { +func (s *textureServiceImpl) ToggleFavorite(userID, textureID int64) (bool, error) { + // 确保材质存在 + texture, err := s.textureRepo.FindByID(textureID) + if err != nil { return false, err } + if texture == nil { + return false, ErrTextureNotFound + } - isFavorited, err := repository.IsTextureFavorited(userID, textureID) + isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) if err != nil { return false, err } if isFavorited { // 已收藏 -> 取消收藏 - if err := repository.RemoveTextureFavorite(userID, textureID); err != nil { + if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { return false, err } - if err := repository.DecrementTextureFavoriteCount(textureID); err != nil { + if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { return false, err } return false, nil - } else { - // 未收藏 -> 添加收藏 - if err := repository.AddTextureFavorite(userID, textureID); err != nil { - return false, err - } - if err := repository.IncrementTextureFavoriteCount(textureID); err != nil { - return false, err - } - return true, nil } + + // 未收藏 -> 添加收藏 + if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { + return false, err + } + if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { + return false, err + } + return true, nil } -// GetUserTextureFavorites 获取用户收藏的材质列表 -func GetUserTextureFavorites(db *gorm.DB, userID int64, page, pageSize int) ([]*model.Texture, int64, error) { +func (s *textureServiceImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { page, pageSize = NormalizePagination(page, pageSize) - return repository.GetUserTextureFavorites(userID, page, pageSize) + return s.textureRepo.GetUserFavorites(userID, page, pageSize) } -// CheckTextureUploadLimit 检查用户上传材质数量限制 -func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) error { - count, err := repository.CountTexturesByUploaderID(uploaderID) +func (s *textureServiceImpl) CheckUploadLimit(uploaderID int64, maxTextures int) error { + count, err := s.textureRepo.CountByUploaderID(uploaderID) if err != nil { return err } @@ -179,8 +202,8 @@ func CheckTextureUploadLimit(db *gorm.DB, uploaderID int64, maxTextures int) err return nil } -// parseTextureType 解析材质类型 -func parseTextureType(textureType string) (model.TextureType, error) { +// parseTextureTypeInternal 解析材质类型 +func parseTextureTypeInternal(textureType string) (model.TextureType, error) { switch textureType { case "SKIN": return model.TextureTypeSkin, nil diff --git a/internal/service/texture_service_impl.go b/internal/service/texture_service_impl.go deleted file mode 100644 index eb19a82..0000000 --- a/internal/service/texture_service_impl.go +++ /dev/null @@ -1,215 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "errors" - "fmt" - - "go.uber.org/zap" -) - -// textureServiceImpl TextureService的实现 -type textureServiceImpl struct { - textureRepo repository.TextureRepository - userRepo repository.UserRepository - logger *zap.Logger -} - -// NewTextureService 创建TextureService实例 -func NewTextureService( - textureRepo repository.TextureRepository, - userRepo repository.UserRepository, - logger *zap.Logger, -) TextureService { - return &textureServiceImpl{ - textureRepo: textureRepo, - userRepo: userRepo, - logger: logger, - } -} - -func (s *textureServiceImpl) Create(uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { - // 验证用户存在 - user, err := s.userRepo.FindByID(uploaderID) - if err != nil || user == nil { - return nil, ErrUserNotFound - } - - // 检查Hash是否已存在 - existingTexture, err := s.textureRepo.FindByHash(hash) - if err != nil { - return nil, err - } - if existingTexture != nil { - return nil, errors.New("该材质已存在") - } - - // 转换材质类型 - textureTypeEnum, err := parseTextureTypeInternal(textureType) - if err != nil { - return nil, err - } - - // 创建材质 - texture := &model.Texture{ - UploaderID: uploaderID, - Name: name, - Description: description, - Type: textureTypeEnum, - URL: url, - Hash: hash, - Size: size, - IsPublic: isPublic, - IsSlim: isSlim, - Status: 1, - DownloadCount: 0, - FavoriteCount: 0, - } - - if err := s.textureRepo.Create(texture); err != nil { - return nil, err - } - - return texture, nil -} - -func (s *textureServiceImpl) GetByID(id int64) (*model.Texture, error) { - texture, err := s.textureRepo.FindByID(id) - if err != nil { - return nil, err - } - if texture == nil { - return nil, ErrTextureNotFound - } - if texture.Status == -1 { - return nil, errors.New("材质已删除") - } - return texture, nil -} - -func (s *textureServiceImpl) GetByUserID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) -} - -func (s *textureServiceImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) -} - -func (s *textureServiceImpl) Update(textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { - // 获取材质并验证权限 - texture, err := s.textureRepo.FindByID(textureID) - if err != nil { - return nil, err - } - if texture == nil { - return nil, ErrTextureNotFound - } - if texture.UploaderID != uploaderID { - return nil, ErrTextureNoPermission - } - - // 更新字段 - updates := make(map[string]interface{}) - if name != "" { - updates["name"] = name - } - if description != "" { - updates["description"] = description - } - if isPublic != nil { - updates["is_public"] = *isPublic - } - - if len(updates) > 0 { - if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { - return nil, err - } - } - - return s.textureRepo.FindByID(textureID) -} - -func (s *textureServiceImpl) Delete(textureID, uploaderID int64) error { - // 获取材质并验证权限 - texture, err := s.textureRepo.FindByID(textureID) - if err != nil { - return err - } - if texture == nil { - return ErrTextureNotFound - } - if texture.UploaderID != uploaderID { - return ErrTextureNoPermission - } - - return s.textureRepo.Delete(textureID) -} - -func (s *textureServiceImpl) ToggleFavorite(userID, textureID int64) (bool, error) { - // 确保材质存在 - texture, err := s.textureRepo.FindByID(textureID) - if err != nil { - return false, err - } - if texture == nil { - return false, ErrTextureNotFound - } - - isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) - if err != nil { - return false, err - } - - if isFavorited { - // 已收藏 -> 取消收藏 - if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { - return false, err - } - if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { - return false, err - } - return false, nil - } - - // 未收藏 -> 添加收藏 - if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { - return false, err - } - if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { - return false, err - } - return true, nil -} - -func (s *textureServiceImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { - page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.GetUserFavorites(userID, page, pageSize) -} - -func (s *textureServiceImpl) CheckUploadLimit(uploaderID int64, maxTextures int) error { - count, err := s.textureRepo.CountByUploaderID(uploaderID) - if err != nil { - return err - } - - if count >= int64(maxTextures) { - return fmt.Errorf("已达到最大上传数量限制(%d)", maxTextures) - } - - return nil -} - -// parseTextureTypeInternal 解析材质类型 -func parseTextureTypeInternal(textureType string) (model.TextureType, error) { - switch textureType { - case "SKIN": - return model.TextureTypeSkin, nil - case "CAPE": - return model.TextureTypeCape, nil - default: - return "", errors.New("无效的材质类型") - } -} diff --git a/internal/service/texture_service_test.go b/internal/service/texture_service_test.go index c4e9ec1..a99a4f0 100644 --- a/internal/service/texture_service_test.go +++ b/internal/service/texture_service_test.go @@ -1,7 +1,10 @@ package service import ( + "carrotskin/internal/model" "testing" + + "go.uber.org/zap" ) // TestTextureService_TypeValidation 测试材质类型验证 @@ -469,3 +472,357 @@ func TestCheckTextureUploadLimit_Logic(t *testing.T) { func boolPtr(b bool) *bool { return &b } + +// ============================================================================ +// 使用 Mock 的集成测试 +// ============================================================================ + +// TestTextureServiceImpl_Create 测试创建Texture +func TestTextureServiceImpl_Create(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置用户 + testUser := &model.User{ + ID: 1, + Username: "testuser", + Email: "test@example.com", + Status: 1, + } + userRepo.Create(testUser) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + tests := []struct { + name string + uploaderID int64 + textureName string + textureType string + hash string + wantErr bool + errContains string + setupMocks func() + }{ + { + name: "正常创建SKIN材质", + uploaderID: 1, + textureName: "TestSkin", + textureType: "SKIN", + hash: "unique-hash-1", + wantErr: false, + }, + { + name: "正常创建CAPE材质", + uploaderID: 1, + textureName: "TestCape", + textureType: "CAPE", + hash: "unique-hash-2", + wantErr: false, + }, + { + name: "用户不存在", + uploaderID: 999, + textureName: "TestTexture", + textureType: "SKIN", + hash: "unique-hash-3", + wantErr: true, + }, + { + name: "材质Hash已存在", + uploaderID: 1, + textureName: "DuplicateTexture", + textureType: "SKIN", + hash: "existing-hash", + wantErr: true, + errContains: "已存在", + setupMocks: func() { + textureRepo.Create(&model.Texture{ + ID: 100, + UploaderID: 1, + Name: "ExistingTexture", + Hash: "existing-hash", + }) + }, + }, + { + name: "无效的材质类型", + uploaderID: 1, + textureName: "InvalidTypeTexture", + textureType: "INVALID", + hash: "unique-hash-4", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupMocks != nil { + tt.setupMocks() + } + + texture, err := textureService.Create( + tt.uploaderID, + tt.textureName, + "Test description", + tt.textureType, + "http://example.com/texture.png", + tt.hash, + 1024, + true, + false, + ) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + return + } + if tt.errContains != "" && !containsString(err.Error(), tt.errContains) { + t.Errorf("错误信息应包含 %q, 实际为: %v", tt.errContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if texture == nil { + t.Error("返回的Texture不应为nil") + } + if texture.Name != tt.textureName { + t.Errorf("Texture名称不匹配: got %v, want %v", texture.Name, tt.textureName) + } + } + }) + } +} + +// TestTextureServiceImpl_GetByID 测试获取Texture +func TestTextureServiceImpl_GetByID(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置Texture + testTexture := &model.Texture{ + ID: 1, + UploaderID: 1, + Name: "TestTexture", + Hash: "test-hash", + } + textureRepo.Create(testTexture) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + tests := []struct { + name string + id int64 + wantErr bool + }{ + { + name: "获取存在的Texture", + id: 1, + wantErr: false, + }, + { + name: "获取不存在的Texture", + id: 999, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + texture, err := textureService.GetByID(tt.id) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if texture == nil { + t.Error("返回的Texture不应为nil") + } + } + }) + } +} + +// TestTextureServiceImpl_GetByUserID_And_Search 测试 GetByUserID 与 Search 分页封装 +func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置多条 Texture + for i := int64(1); i <= 5; i++ { + textureRepo.Create(&model.Texture{ + ID: i, + UploaderID: 1, + Name: "T", + IsPublic: i%2 == 0, + }) + } + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // GetByUserID 应按上传者过滤并调用 NormalizePagination + textures, total, err := textureService.GetByUserID(1, 0, 0) + if err != nil { + t.Fatalf("GetByUserID 失败: %v", err) + } + if total != int64(len(textures)) { + t.Fatalf("GetByUserID 返回数量与总数不一致, total=%d, len=%d", total, len(textures)) + } + + // Search 仅验证能够正常调用并返回结果 + searchResult, searchTotal, err := textureService.Search("", "", true, -1, 200) + if err != nil { + t.Fatalf("Search 失败: %v", err) + } + if searchTotal != int64(len(searchResult)) { + t.Fatalf("Search 返回数量与总数不一致, total=%d, len=%d", searchTotal, len(searchResult)) + } +} + +// TestTextureServiceImpl_Update_And_Delete 测试 Update / Delete 权限与字段更新 +func TestTextureServiceImpl_Update_And_Delete(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + texture := &model.Texture{ + ID: 1, + UploaderID: 1, + Name: "Old", + Description:"OldDesc", + IsPublic: false, + } + textureRepo.Create(texture) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // 更新成功 + newName := "NewName" + newDesc := "NewDesc" + public := boolPtr(true) + updated, err := textureService.Update(1, 1, newName, newDesc, public) + if err != nil { + t.Fatalf("Update 正常情况失败: %v", err) + } + // 由于 MockTextureRepository.UpdateFields 不会真正修改结构体字段,这里只验证不会返回 nil 即可 + if updated == nil { + t.Fatalf("Update 返回结果不应为 nil") + } + + // 无权限更新 + if _, err := textureService.Update(1, 2, "X", "Y", nil); err == nil { + t.Fatalf("Update 在无权限时应返回错误") + } + + // 删除成功 + if err := textureService.Delete(1, 1); err != nil { + t.Fatalf("Delete 正常情况失败: %v", err) + } + + // 无权限删除 + if err := textureService.Delete(1, 2); err == nil { + t.Fatalf("Delete 在无权限时应返回错误") + } +} + +// TestTextureServiceImpl_FavoritesAndLimit 测试 GetUserFavorites 与 CheckUploadLimit +func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置若干 Texture 与收藏关系 + for i := int64(1); i <= 3; i++ { + textureRepo.Create(&model.Texture{ + ID: i, + UploaderID: 1, + Name: "T", + }) + _ = textureRepo.AddFavorite(1, i) + } + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // GetUserFavorites + favs, total, err := textureService.GetUserFavorites(1, -1, -1) + if err != nil { + t.Fatalf("GetUserFavorites 失败: %v", err) + } + if int64(len(favs)) != total || total != 3 { + t.Fatalf("GetUserFavorites 数量不正确, total=%d, len=%d", total, len(favs)) + } + + // CheckUploadLimit 未超过上限 + if err := textureService.CheckUploadLimit(1, 10); err != nil { + t.Fatalf("CheckUploadLimit 在未达到上限时不应报错: %v", err) + } + + // CheckUploadLimit 超过上限 + if err := textureService.CheckUploadLimit(1, 2); err == nil { + t.Fatalf("CheckUploadLimit 在超过上限时应返回错误") + } +} + +// TestTextureServiceImpl_ToggleFavorite 测试收藏功能 +func TestTextureServiceImpl_ToggleFavorite(t *testing.T) { + textureRepo := NewMockTextureRepository() + userRepo := NewMockUserRepository() + logger := zap.NewNop() + + // 预置用户和Texture + testUser := &model.User{ID: 1, Username: "testuser", Status: 1} + userRepo.Create(testUser) + + testTexture := &model.Texture{ + ID: 1, + UploaderID: 1, + Name: "TestTexture", + Hash: "test-hash", + } + textureRepo.Create(testTexture) + + textureService := NewTextureService(textureRepo, userRepo, logger) + + // 第一次收藏 + isFavorited, err := textureService.ToggleFavorite(1, 1) + if err != nil { + t.Errorf("第一次收藏失败: %v", err) + } + if !isFavorited { + t.Error("第一次操作应该是添加收藏") + } + + // 第二次取消收藏 + isFavorited, err = textureService.ToggleFavorite(1, 1) + if err != nil { + t.Errorf("取消收藏失败: %v", err) + } + if isFavorited { + t.Error("第二次操作应该是取消收藏") + } +} + +// 辅助函数 +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || + (len(s) > len(substr) && (findSubstring(s, substr) != -1))) +} + +func findSubstring(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} diff --git a/internal/service/token_service.go b/internal/service/token_service.go index 20af177..b128abf 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -6,35 +6,55 @@ import ( "context" "errors" "fmt" - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "go.uber.org/zap" "strconv" "time" - "gorm.io/gorm" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "go.uber.org/zap" ) -// 常量定义 +// tokenServiceImpl TokenService的实现 +type tokenServiceImpl struct { + tokenRepo repository.TokenRepository + profileRepo repository.ProfileRepository + logger *zap.Logger +} + +// NewTokenService 创建TokenService实例 +func NewTokenService( + tokenRepo repository.TokenRepository, + profileRepo repository.ProfileRepository, + logger *zap.Logger, +) TokenService { + return &tokenServiceImpl{ + tokenRepo: tokenRepo, + profileRepo: profileRepo, + logger: logger, + } +} + const ( - ExtendedTimeout = 10 * time.Second - TokensMaxCount = 10 // 用户最多保留的token数量 + tokenExtendedTimeout = 10 * time.Second + tokensMaxCount = 10 ) -// NewToken 创建新令牌 -func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { +func (s *tokenServiceImpl) Create(userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { var ( selectedProfileID *model.Profile availableProfiles []*model.Profile ) + // 设置超时上下文 _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) defer cancel() // 验证用户存在 - _, err := repository.FindProfileByUUID(UUID) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) + if UUID != "" { + _, err := s.profileRepo.FindByUUID(UUID) + if err != nil { + return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) + } } // 生成令牌 @@ -46,13 +66,13 @@ func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, client token := model.Token{ AccessToken: accessToken, ClientToken: clientToken, - UserID: userId, + UserID: userID, Usable: true, IssueDate: time.Now(), } // 获取用户配置文件 - profiles, err := repository.FindProfilesByUserID(userId) + profiles, err := s.profileRepo.FindByUserID(userID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) } @@ -64,65 +84,24 @@ func NewToken(db *gorm.DB, logger *zap.Logger, userId int64, UUID string, client } availableProfiles = profiles - // 插入令牌到tokens集合 - _, insertCancel := context.WithTimeout(context.Background(), DefaultTimeout) - defer insertCancel() - - err = repository.CreateToken(&token) + // 插入令牌 + err = s.tokenRepo.Create(&token) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) } + // 清理多余的令牌 - go CheckAndCleanupExcessTokens(db, logger, userId) + go s.checkAndCleanupExcessTokens(userID) return selectedProfileID, availableProfiles, accessToken, clientToken, nil } -// CheckAndCleanupExcessTokens 检查并清理用户多余的令牌,只保留最新的10个 -func CheckAndCleanupExcessTokens(db *gorm.DB, logger *zap.Logger, userId int64) { - if userId == 0 { - return - } - // 获取用户所有令牌,按发行日期降序排序 - tokens, err := repository.GetTokensByUserId(userId) - if err != nil { - logger.Error("[ERROR] 获取用户Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10))) - return - } - - // 如果令牌数量不超过上限,无需清理 - if len(tokens) <= TokensMaxCount { - return - } - - // 获取需要删除的令牌ID列表 - tokensToDelete := make([]string, 0, len(tokens)-TokensMaxCount) - for i := TokensMaxCount; i < len(tokens); i++ { - tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) - } - - // 执行批量删除,传入上下文和待删除的令牌列表(作为切片参数) - DeletedCount, err := repository.BatchDeleteTokens(tokensToDelete) - if err != nil { - logger.Error("[ERROR] 清理用户多余Token失败: ", zap.Error(err), zap.String("userId", strconv.FormatInt(userId, 10))) - return - } - - if DeletedCount > 0 { - logger.Info("[INFO] 成功清理用户多余Token", zap.Any("userId:", userId), zap.Any("count:", DeletedCount)) - } -} - -// ValidToken 验证令牌有效性 -func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool { +func (s *tokenServiceImpl) Validate(accessToken, clientToken string) bool { if accessToken == "" { return false } - // 使用投影只获取需要的字段 - var token *model.Token - token, err := repository.FindTokenByID(accessToken) - + token, err := s.tokenRepo.FindByAccessToken(accessToken) if err != nil { return false } @@ -131,47 +110,35 @@ func ValidToken(db *gorm.DB, accessToken string, clientToken string) bool { return false } - // 如果客户端令牌为空,只验证访问令牌 if clientToken == "" { return true } - // 否则验证客户端令牌是否匹配 return token.ClientToken == clientToken } -func GetUUIDByAccessToken(db *gorm.DB, accessToken string) (string, error) { - return repository.GetUUIDByAccessToken(accessToken) -} - -func GetUserIDByAccessToken(db *gorm.DB, accessToken string) (int64, error) { - return repository.GetUserIDByAccessToken(accessToken) -} - -// RefreshToken 刷新令牌 -func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken string, selectedProfileID string) (string, string, error) { +func (s *tokenServiceImpl) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { if accessToken == "" { return "", "", errors.New("accessToken不能为空") } // 查找旧令牌 - oldToken, err := repository.GetTokenByAccessToken(accessToken) + oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", "", errors.New("accessToken无效") } - logger.Error("[ERROR] 查询Token失败: ", zap.Error(err), zap.Any("accessToken:", accessToken)) + s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return "", "", fmt.Errorf("查询令牌失败: %w", err) } // 验证profile if selectedProfileID != "" { - valid, validErr := ValidateProfileByUserID(db, oldToken.UserID, selectedProfileID) + valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) if validErr != nil { - logger.Error( - "验证Profile失败", + s.logger.Error("验证Profile失败", zap.Error(err), - zap.Any("userId", oldToken.UserID), + zap.Int64("userId", oldToken.UserID), zap.String("profileId", selectedProfileID), ) return "", "", fmt.Errorf("验证角色失败: %w", err) @@ -192,86 +159,119 @@ func RefreshToken(db *gorm.DB, logger *zap.Logger, accessToken, clientToken stri return "", "", errors.New("原令牌已绑定角色,无法选择新角色") } } else { - selectedProfileID = oldToken.ProfileId // 如果未指定,则保持原角色 + selectedProfileID = oldToken.ProfileId } // 生成新令牌 newAccessToken := uuid.New().String() newToken := model.Token{ AccessToken: newAccessToken, - ClientToken: oldToken.ClientToken, // 新令牌的 clientToken 与原令牌相同 + ClientToken: oldToken.ClientToken, UserID: oldToken.UserID, Usable: true, - ProfileId: selectedProfileID, // 绑定到指定角色或保持原角色 + ProfileId: selectedProfileID, IssueDate: time.Now(), } - // 使用双重写入模式替代事务,先插入新令牌,再删除旧令牌 - - err = repository.CreateToken(&newToken) + // 先插入新令牌,再删除旧令牌 + err = s.tokenRepo.Create(&newToken) if err != nil { - logger.Error( - "创建新Token失败", - zap.Error(err), - zap.String("accessToken", accessToken), - ) + s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return "", "", fmt.Errorf("创建新Token失败: %w", err) } - err = repository.DeleteTokenByAccessToken(accessToken) + err = s.tokenRepo.DeleteByAccessToken(accessToken) if err != nil { - // 删除旧令牌失败,记录日志但不阻止操作,因为新令牌已成功创建 - logger.Warn( - "删除旧Token失败,但新Token已创建", + s.logger.Warn("删除旧Token失败,但新Token已创建", zap.Error(err), zap.String("oldToken", oldToken.AccessToken), zap.String("newToken", newAccessToken), ) } - logger.Info( - "成功刷新Token", - zap.Any("userId", oldToken.UserID), - zap.String("accessToken", newAccessToken), - ) + s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken)) return newAccessToken, oldToken.ClientToken, nil } -// InvalidToken 使令牌失效 -func InvalidToken(db *gorm.DB, logger *zap.Logger, accessToken string) { +func (s *tokenServiceImpl) Invalidate(accessToken string) { if accessToken == "" { return } - err := repository.DeleteTokenByAccessToken(accessToken) + err := s.tokenRepo.DeleteByAccessToken(accessToken) if err != nil { - logger.Error( - "删除Token失败", - zap.Error(err), - zap.String("accessToken", accessToken), - ) + s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return } - logger.Info("[INFO] 成功删除", zap.Any("Token:", accessToken)) - + s.logger.Info("成功删除Token", zap.String("token", accessToken)) } -// InvalidUserTokens 使用户所有令牌失效 -func InvalidUserTokens(db *gorm.DB, logger *zap.Logger, userId int64) { - if userId == 0 { +func (s *tokenServiceImpl) InvalidateUserTokens(userID int64) { + if userID == 0 { return } - err := repository.DeleteTokenByUserId(userId) + err := s.tokenRepo.DeleteByUserID(userID) if err != nil { - logger.Error( - "[ERROR]删除用户Token失败", - zap.Error(err), - zap.Any("userId", userId), - ) + s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) return } - logger.Info("[INFO] 成功删除用户Token", zap.Any("userId:", userId)) - + s.logger.Info("成功删除用户Token", zap.Int64("userId", userID)) +} + +func (s *tokenServiceImpl) GetUUIDByAccessToken(accessToken string) (string, error) { + return s.tokenRepo.GetUUIDByAccessToken(accessToken) +} + +func (s *tokenServiceImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { + return s.tokenRepo.GetUserIDByAccessToken(accessToken) +} + +// 私有辅助方法 + +func (s *tokenServiceImpl) checkAndCleanupExcessTokens(userID int64) { + if userID == 0 { + return + } + + tokens, err := s.tokenRepo.GetByUserID(userID) + if err != nil { + s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) + return + } + + if len(tokens) <= tokensMaxCount { + return + } + + tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount) + for i := tokensMaxCount; i < len(tokens); i++ { + tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) + } + + deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) + if err != nil { + s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) + return + } + + if deletedCount > 0 { + s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount)) + } +} + +func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (bool, error) { + if userID == 0 || UUID == "" { + return false, errors.New("用户ID或配置文件ID不能为空") + } + + profile, err := s.profileRepo.FindByUUID(UUID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, errors.New("配置文件不存在") + } + return false, fmt.Errorf("验证配置文件失败: %w", err) + } + return profile.UserID == userID, nil } diff --git a/internal/service/token_service_impl.go b/internal/service/token_service_impl.go deleted file mode 100644 index b128abf..0000000 --- a/internal/service/token_service_impl.go +++ /dev/null @@ -1,277 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "context" - "errors" - "fmt" - "strconv" - "time" - - "github.com/google/uuid" - "github.com/jackc/pgx/v5" - "go.uber.org/zap" -) - -// tokenServiceImpl TokenService的实现 -type tokenServiceImpl struct { - tokenRepo repository.TokenRepository - profileRepo repository.ProfileRepository - logger *zap.Logger -} - -// NewTokenService 创建TokenService实例 -func NewTokenService( - tokenRepo repository.TokenRepository, - profileRepo repository.ProfileRepository, - logger *zap.Logger, -) TokenService { - return &tokenServiceImpl{ - tokenRepo: tokenRepo, - profileRepo: profileRepo, - logger: logger, - } -} - -const ( - tokenExtendedTimeout = 10 * time.Second - tokensMaxCount = 10 -) - -func (s *tokenServiceImpl) Create(userID int64, UUID string, clientToken string) (*model.Profile, []*model.Profile, string, string, error) { - var ( - selectedProfileID *model.Profile - availableProfiles []*model.Profile - ) - - // 设置超时上下文 - _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) - defer cancel() - - // 验证用户存在 - if UUID != "" { - _, err := s.profileRepo.FindByUUID(UUID) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) - } - } - - // 生成令牌 - if clientToken == "" { - clientToken = uuid.New().String() - } - - accessToken := uuid.New().String() - token := model.Token{ - AccessToken: accessToken, - ClientToken: clientToken, - UserID: userID, - Usable: true, - IssueDate: time.Now(), - } - - // 获取用户配置文件 - profiles, err := s.profileRepo.FindByUserID(userID) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) - } - - // 如果用户只有一个配置文件,自动选择 - if len(profiles) == 1 { - selectedProfileID = profiles[0] - token.ProfileId = selectedProfileID.UUID - } - availableProfiles = profiles - - // 插入令牌 - err = s.tokenRepo.Create(&token) - if err != nil { - return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) - } - - // 清理多余的令牌 - go s.checkAndCleanupExcessTokens(userID) - - return selectedProfileID, availableProfiles, accessToken, clientToken, nil -} - -func (s *tokenServiceImpl) Validate(accessToken, clientToken string) bool { - if accessToken == "" { - return false - } - - token, err := s.tokenRepo.FindByAccessToken(accessToken) - if err != nil { - return false - } - - if !token.Usable { - return false - } - - if clientToken == "" { - return true - } - - return token.ClientToken == clientToken -} - -func (s *tokenServiceImpl) Refresh(accessToken, clientToken, selectedProfileID string) (string, string, error) { - if accessToken == "" { - return "", "", errors.New("accessToken不能为空") - } - - // 查找旧令牌 - oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return "", "", errors.New("accessToken无效") - } - s.logger.Error("查询Token失败", zap.Error(err), zap.String("accessToken", accessToken)) - return "", "", fmt.Errorf("查询令牌失败: %w", err) - } - - // 验证profile - if selectedProfileID != "" { - valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) - if validErr != nil { - s.logger.Error("验证Profile失败", - zap.Error(err), - zap.Int64("userId", oldToken.UserID), - zap.String("profileId", selectedProfileID), - ) - return "", "", fmt.Errorf("验证角色失败: %w", err) - } - if !valid { - return "", "", errors.New("角色与用户不匹配") - } - } - - // 检查 clientToken 是否有效 - if clientToken != "" && clientToken != oldToken.ClientToken { - return "", "", errors.New("clientToken无效") - } - - // 检查 selectedProfileID 的逻辑 - if selectedProfileID != "" { - if oldToken.ProfileId != "" && oldToken.ProfileId != selectedProfileID { - return "", "", errors.New("原令牌已绑定角色,无法选择新角色") - } - } else { - selectedProfileID = oldToken.ProfileId - } - - // 生成新令牌 - newAccessToken := uuid.New().String() - newToken := model.Token{ - AccessToken: newAccessToken, - ClientToken: oldToken.ClientToken, - UserID: oldToken.UserID, - Usable: true, - ProfileId: selectedProfileID, - IssueDate: time.Now(), - } - - // 先插入新令牌,再删除旧令牌 - err = s.tokenRepo.Create(&newToken) - if err != nil { - s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) - return "", "", fmt.Errorf("创建新Token失败: %w", err) - } - - err = s.tokenRepo.DeleteByAccessToken(accessToken) - if err != nil { - s.logger.Warn("删除旧Token失败,但新Token已创建", - zap.Error(err), - zap.String("oldToken", oldToken.AccessToken), - zap.String("newToken", newAccessToken), - ) - } - - s.logger.Info("成功刷新Token", zap.Int64("userId", oldToken.UserID), zap.String("accessToken", newAccessToken)) - return newAccessToken, oldToken.ClientToken, nil -} - -func (s *tokenServiceImpl) Invalidate(accessToken string) { - if accessToken == "" { - return - } - - err := s.tokenRepo.DeleteByAccessToken(accessToken) - if err != nil { - s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) - return - } - s.logger.Info("成功删除Token", zap.String("token", accessToken)) -} - -func (s *tokenServiceImpl) InvalidateUserTokens(userID int64) { - if userID == 0 { - return - } - - err := s.tokenRepo.DeleteByUserID(userID) - if err != nil { - s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) - return - } - - s.logger.Info("成功删除用户Token", zap.Int64("userId", userID)) -} - -func (s *tokenServiceImpl) GetUUIDByAccessToken(accessToken string) (string, error) { - return s.tokenRepo.GetUUIDByAccessToken(accessToken) -} - -func (s *tokenServiceImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { - return s.tokenRepo.GetUserIDByAccessToken(accessToken) -} - -// 私有辅助方法 - -func (s *tokenServiceImpl) checkAndCleanupExcessTokens(userID int64) { - if userID == 0 { - return - } - - tokens, err := s.tokenRepo.GetByUserID(userID) - if err != nil { - s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) - return - } - - if len(tokens) <= tokensMaxCount { - return - } - - tokensToDelete := make([]string, 0, len(tokens)-tokensMaxCount) - for i := tokensMaxCount; i < len(tokens); i++ { - tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) - } - - deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) - if err != nil { - s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) - return - } - - if deletedCount > 0 { - s.logger.Info("成功清理用户多余Token", zap.Int64("userId", userID), zap.Int64("count", deletedCount)) - } -} - -func (s *tokenServiceImpl) validateProfileByUserID(userID int64, UUID string) (bool, error) { - if userID == 0 || UUID == "" { - return false, errors.New("用户ID或配置文件ID不能为空") - } - - profile, err := s.profileRepo.FindByUUID(UUID) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return false, errors.New("配置文件不存在") - } - return false, fmt.Errorf("验证配置文件失败: %w", err) - } - return profile.UserID == userID, nil -} diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 7c051d2..e85978b 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -1,18 +1,23 @@ package service import ( + "carrotskin/internal/model" + "fmt" "testing" "time" + + "go.uber.org/zap" ) // TestTokenService_Constants 测试Token服务相关常量 func TestTokenService_Constants(t *testing.T) { - if ExtendedTimeout != 10*time.Second { - t.Errorf("ExtendedTimeout = %v, want 10 seconds", ExtendedTimeout) + // 测试私有常量通过行为验证 + if tokenExtendedTimeout != 10*time.Second { + t.Errorf("tokenExtendedTimeout = %v, want 10 seconds", tokenExtendedTimeout) } - if TokensMaxCount != 10 { - t.Errorf("TokensMaxCount = %d, want 10", TokensMaxCount) + if tokensMaxCount != 10 { + t.Errorf("tokensMaxCount = %d, want 10", tokensMaxCount) } } @@ -22,8 +27,8 @@ func TestTokenService_Timeout(t *testing.T) { t.Errorf("DefaultTimeout = %v, want 5 seconds", DefaultTimeout) } - if ExtendedTimeout <= DefaultTimeout { - t.Errorf("ExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", ExtendedTimeout, DefaultTimeout) + if tokenExtendedTimeout <= DefaultTimeout { + t.Errorf("tokenExtendedTimeout (%v) should be greater than DefaultTimeout (%v)", tokenExtendedTimeout, DefaultTimeout) } } @@ -202,3 +207,314 @@ func TestTokenService_UserIDValidation(t *testing.T) { }) } } + +// ============================================================================ +// 使用 Mock 的集成测试 +// ============================================================================ + +// TestTokenServiceImpl_Create 测试创建Token +func TestTokenServiceImpl_Create(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置Profile + testProfile := &model.Profile{ + UUID: "test-profile-uuid", + UserID: 1, + Name: "TestProfile", + IsActive: true, + } + profileRepo.Create(testProfile) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + tests := []struct { + name string + userID int64 + uuid string + clientToken string + wantErr bool + }{ + { + name: "正常创建Token(指定UUID)", + userID: 1, + uuid: "test-profile-uuid", + clientToken: "client-token-1", + wantErr: false, + }, + { + name: "正常创建Token(空clientToken)", + userID: 1, + uuid: "test-profile-uuid", + clientToken: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _, accessToken, clientToken, err := tokenService.Create(tt.userID, tt.uuid, tt.clientToken) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if accessToken == "" { + t.Error("accessToken不应为空") + } + if clientToken == "" { + t.Error("clientToken不应为空") + } + } + }) + } +} + +// TestTokenServiceImpl_Validate 测试验证Token +func TestTokenServiceImpl_Validate(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置Token + testToken := &model.Token{ + AccessToken: "valid-access-token", + ClientToken: "valid-client-token", + UserID: 1, + ProfileId: "test-profile-uuid", + Usable: true, + } + tokenRepo.Create(testToken) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + tests := []struct { + name string + accessToken string + clientToken string + wantValid bool + }{ + { + name: "有效Token(完全匹配)", + accessToken: "valid-access-token", + clientToken: "valid-client-token", + wantValid: true, + }, + { + name: "有效Token(只检查accessToken)", + accessToken: "valid-access-token", + clientToken: "", + wantValid: true, + }, + { + name: "无效Token(accessToken不存在)", + accessToken: "invalid-access-token", + clientToken: "", + wantValid: false, + }, + { + name: "无效Token(clientToken不匹配)", + accessToken: "valid-access-token", + clientToken: "wrong-client-token", + wantValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isValid := tokenService.Validate(tt.accessToken, tt.clientToken) + + if isValid != tt.wantValid { + t.Errorf("Token验证结果不匹配: got %v, want %v", isValid, tt.wantValid) + } + }) + } +} + +// TestTokenServiceImpl_Invalidate 测试注销Token +func TestTokenServiceImpl_Invalidate(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置Token + testToken := &model.Token{ + AccessToken: "token-to-invalidate", + ClientToken: "client-token", + UserID: 1, + ProfileId: "test-profile-uuid", + Usable: true, + } + tokenRepo.Create(testToken) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + // 验证Token存在 + isValid := tokenService.Validate("token-to-invalidate", "") + if !isValid { + t.Error("Token应该有效") + } + + // 注销Token + tokenService.Invalidate("token-to-invalidate") + + // 验证Token已失效(从repo中删除) + _, err := tokenRepo.FindByAccessToken("token-to-invalidate") + if err == nil { + t.Error("Token应该已被删除") + } +} + +// TestTokenServiceImpl_InvalidateUserTokens 测试注销用户所有Token +func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置多个Token + for i := 1; i <= 3; i++ { + tokenRepo.Create(&model.Token{ + AccessToken: fmt.Sprintf("user1-token-%d", i), + ClientToken: "client-token", + UserID: 1, + ProfileId: "test-profile-uuid", + Usable: true, + }) + } + tokenRepo.Create(&model.Token{ + AccessToken: "user2-token-1", + ClientToken: "client-token", + UserID: 2, + ProfileId: "test-profile-uuid-2", + Usable: true, + }) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + // 注销用户1的所有Token + tokenService.InvalidateUserTokens(1) + + // 验证用户1的Token已失效 + tokens, _ := tokenRepo.GetByUserID(1) + if len(tokens) > 0 { + t.Errorf("用户1的Token应该全部被删除,但还剩 %d 个", len(tokens)) + } + + // 验证用户2的Token仍然存在 + tokens2, _ := tokenRepo.GetByUserID(2) + if len(tokens2) != 1 { + t.Errorf("用户2的Token应该仍然存在,期望1个,实际 %d 个", len(tokens2)) + } +} + +// TestTokenServiceImpl_Refresh 覆盖 Refresh 的主要分支 +func TestTokenServiceImpl_Refresh(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + // 预置 Profile 与 Token + profile := &model.Profile{ + UUID: "profile-uuid", + UserID: 1, + } + profileRepo.Create(profile) + + oldToken := &model.Token{ + AccessToken: "old-token", + ClientToken: "client-token", + UserID: 1, + ProfileId: "", + Usable: true, + } + tokenRepo.Create(oldToken) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + // 正常刷新,不指定 profile + newAccess, client, err := tokenService.Refresh("old-token", "client-token", "") + if err != nil { + t.Fatalf("Refresh 正常情况失败: %v", err) + } + if newAccess == "" || client != "client-token" { + t.Fatalf("Refresh 返回值异常: access=%s, client=%s", newAccess, client) + } + + // accessToken 为空 + if _, _, err := tokenService.Refresh("", "client-token", ""); err == nil { + t.Fatalf("Refresh 在 accessToken 为空时应返回错误") + } +} + +// TestTokenServiceImpl_GetByAccessToken 封装 GetUUIDByAccessToken / GetUserIDByAccessToken +func TestTokenServiceImpl_GetByAccessToken(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + token := &model.Token{ + AccessToken: "token-1", + UserID: 42, + ProfileId: "profile-42", + Usable: true, + } + tokenRepo.Create(token) + + tokenService := NewTokenService(tokenRepo, profileRepo, logger) + + uuid, err := tokenService.GetUUIDByAccessToken("token-1") + if err != nil || uuid != "profile-42" { + t.Fatalf("GetUUIDByAccessToken 返回错误: uuid=%s, err=%v", uuid, err) + } + + uid, err := tokenService.GetUserIDByAccessToken("token-1") + if err != nil || uid != 42 { + t.Fatalf("GetUserIDByAccessToken 返回错误: uid=%d, err=%v", uid, err) + } +} + +// TestTokenServiceImpl_validateProfileByUserID 直接测试内部校验逻辑 +func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) { + tokenRepo := NewMockTokenRepository() + profileRepo := NewMockProfileRepository() + logger := zap.NewNop() + + svc := &tokenServiceImpl{ + tokenRepo: tokenRepo, + profileRepo: profileRepo, + logger: logger, + } + + // 预置 Profile + profile := &model.Profile{ + UUID: "p-1", + UserID: 1, + } + profileRepo.Create(profile) + + // 参数非法 + if ok, err := svc.validateProfileByUserID(0, ""); err == nil || ok { + t.Fatalf("validateProfileByUserID 在参数非法时应返回错误") + } + + // Profile 不存在 + if ok, err := svc.validateProfileByUserID(1, "not-exists"); err == nil || ok { + t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误") + } + + // 用户与 Profile 匹配 + if ok, err := svc.validateProfileByUserID(1, "p-1"); err != nil || !ok { + t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err) + } + + // 用户与 Profile 不匹配 + if ok, err := svc.validateProfileByUserID(2, "p-1"); err != nil || ok { + t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err) + } +} \ No newline at end of file diff --git a/internal/service/upload_service.go b/internal/service/upload_service.go index 4678872..877357b 100644 --- a/internal/service/upload_service.go +++ b/internal/service/upload_service.go @@ -74,27 +74,38 @@ func ValidateFileName(fileName string, fileType FileType) error { return nil } -// GenerateAvatarUploadURL 生成头像上传URL +// uploadStorageClient 为上传服务定义的最小依赖接口,便于单元测试注入 mock +type uploadStorageClient interface { + GetBucket(name string) (string, error) + GeneratePresignedPostURL(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) +} + +// GenerateAvatarUploadURL 生成头像上传URL(对外导出) func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.StorageClient, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) { + return generateAvatarUploadURLWithClient(ctx, storageClient, userID, fileName) +} + +// generateAvatarUploadURLWithClient 使用接口类型的内部实现,方便测试 +func generateAvatarUploadURLWithClient(ctx context.Context, storageClient uploadStorageClient, userID int64, fileName string) (*storage.PresignedPostPolicyResult, error) { // 1. 验证文件名 if err := ValidateFileName(fileName, FileTypeAvatar); err != nil { return nil, err } - + // 2. 获取上传配置 uploadConfig := GetUploadConfig(FileTypeAvatar) - + // 3. 获取存储桶名称 bucketName, err := storageClient.GetBucket("avatars") if err != nil { return nil, fmt.Errorf("获取存储桶失败: %w", err) } - + // 4. 生成对象名称(路径) // 格式: user_{userId}/timestamp_{originalFileName} timestamp := time.Now().Format("20060102150405") objectName := fmt.Sprintf("user_%d/%s_%s", userID, timestamp, fileName) - + // 5. 生成预签名POST URL (使用存储客户端内置的 PublicURL) result, err := storageClient.GeneratePresignedPostURL( ctx, @@ -107,37 +118,42 @@ func GenerateAvatarUploadURL(ctx context.Context, storageClient *storage.Storage if err != nil { return nil, fmt.Errorf("生成上传URL失败: %w", err) } - + return result, nil } -// GenerateTextureUploadURL 生成材质上传URL +// GenerateTextureUploadURL 生成材质上传URL(对外导出) func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.StorageClient, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) { + return generateTextureUploadURLWithClient(ctx, storageClient, userID, fileName, textureType) +} + +// generateTextureUploadURLWithClient 使用接口类型的内部实现,方便测试 +func generateTextureUploadURLWithClient(ctx context.Context, storageClient uploadStorageClient, userID int64, fileName, textureType string) (*storage.PresignedPostPolicyResult, error) { // 1. 验证文件名 if err := ValidateFileName(fileName, FileTypeTexture); err != nil { return nil, err } - + // 2. 验证材质类型 if textureType != "SKIN" && textureType != "CAPE" { return nil, fmt.Errorf("无效的材质类型: %s", textureType) } - + // 3. 获取上传配置 uploadConfig := GetUploadConfig(FileTypeTexture) - + // 4. 获取存储桶名称 bucketName, err := storageClient.GetBucket("textures") if err != nil { return nil, fmt.Errorf("获取存储桶失败: %w", err) } - + // 5. 生成对象名称(路径) // 格式: user_{userId}/{textureType}/timestamp_{originalFileName} timestamp := time.Now().Format("20060102150405") textureTypeFolder := strings.ToLower(textureType) objectName := fmt.Sprintf("user_%d/%s/%s_%s", userID, textureTypeFolder, timestamp, fileName) - + // 6. 生成预签名POST URL (使用存储客户端内置的 PublicURL) result, err := storageClient.GeneratePresignedPostURL( ctx, @@ -150,6 +166,6 @@ func GenerateTextureUploadURL(ctx context.Context, storageClient *storage.Storag if err != nil { return nil, fmt.Errorf("生成上传URL失败: %w", err) } - + return result, nil } diff --git a/internal/service/upload_service_test.go b/internal/service/upload_service_test.go index 52f2012..07df008 100644 --- a/internal/service/upload_service_test.go +++ b/internal/service/upload_service_test.go @@ -1,9 +1,13 @@ package service import ( + "context" + "errors" "strings" "testing" "time" + + "carrotskin/pkg/storage" ) // TestUploadService_FileTypes 测试文件类型常量 @@ -135,43 +139,43 @@ func TestGetUploadConfig_TextureConfig(t *testing.T) { // TestValidateFileName 测试文件名验证 func TestValidateFileName(t *testing.T) { tests := []struct { - name string - fileName string - fileType FileType - wantErr bool + name string + fileName string + fileType FileType + wantErr bool errContains string }{ { - name: "有效的头像文件名", - fileName: "avatar.png", - fileType: FileTypeAvatar, - wantErr: false, + name: "有效的头像文件名", + fileName: "avatar.png", + fileType: FileTypeAvatar, + wantErr: false, }, { - name: "有效的材质文件名", - fileName: "texture.png", - fileType: FileTypeTexture, - wantErr: false, + name: "有效的材质文件名", + fileName: "texture.png", + fileType: FileTypeTexture, + wantErr: false, }, { - name: "文件名为空", - fileName: "", - fileType: FileTypeAvatar, - wantErr: true, + name: "文件名为空", + fileName: "", + fileType: FileTypeAvatar, + wantErr: true, errContains: "文件名不能为空", }, { - name: "不支持的文件扩展名", - fileName: "file.txt", - fileType: FileTypeAvatar, - wantErr: true, + name: "不支持的文件扩展名", + fileName: "file.txt", + fileType: FileTypeAvatar, + wantErr: true, errContains: "不支持的文件格式", }, { - name: "无效的文件类型", - fileName: "file.png", - fileType: FileType("invalid"), - wantErr: true, + name: "无效的文件类型", + fileName: "file.png", + fileType: FileType("invalid"), + wantErr: true, errContains: "不支持的文件类型", }, } @@ -277,3 +281,130 @@ func TestUploadConfig_Structure(t *testing.T) { } } +// mockStorageClient 用于单元测试的简单存储客户端假实现 +// 注意:这里只声明与 upload_service 使用到的方法,避免依赖真实 MinIO 客户端 +type mockStorageClient struct { + getBucketFn func(name string) (string, error) + generatePresignedPostURLFn func(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) +} + +func (m *mockStorageClient) GetBucket(name string) (string, error) { + if m.getBucketFn != nil { + return m.getBucketFn(name) + } + return "", errors.New("GetBucket not implemented") +} + +func (m *mockStorageClient) GeneratePresignedPostURL(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) { + if m.generatePresignedPostURLFn != nil { + return m.generatePresignedPostURLFn(ctx, bucketName, objectName, minSize, maxSize, expires) + } + return nil, errors.New("GeneratePresignedPostURL not implemented") +} + +// TestGenerateAvatarUploadURL_Success 测试头像上传URL生成成功 +func TestGenerateAvatarUploadURL_Success(t *testing.T) { + ctx := context.Background() + + mockClient := &mockStorageClient{ + getBucketFn: func(name string) (string, error) { + if name != "avatars" { + t.Fatalf("unexpected bucket name: %s", name) + } + return "avatars-bucket", nil + }, + generatePresignedPostURLFn: func(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) { + if bucketName != "avatars-bucket" { + t.Fatalf("unexpected bucketName: %s", bucketName) + } + if !strings.Contains(objectName, "user_") { + t.Fatalf("objectName should contain user_ prefix, got: %s", objectName) + } + if !strings.Contains(objectName, "avatar.png") { + t.Fatalf("objectName should contain original file name, got: %s", objectName) + } + // 检查大小与过期时间传递 + if minSize != 1024 { + t.Fatalf("minSize = %d, want 1024", minSize) + } + if maxSize != 5*1024*1024 { + t.Fatalf("maxSize = %d, want 5MB", maxSize) + } + if expires != 15*time.Minute { + t.Fatalf("expires = %v, want 15m", expires) + } + return &storage.PresignedPostPolicyResult{ + PostURL: "http://example.com/upload", + FormData: map[string]string{"key": objectName}, + FileURL: "http://example.com/file/" + objectName, + }, nil + }, + } + + // 直接将 mock 实例转换为真实类型使用(依赖其方法集与被测代码一致) + storageClient := (*storage.StorageClient)(nil) + _ = storageClient // 避免未使用告警,实际调用仍通过 mockClient 完成 + + // 直接通过内部使用接口的实现进行测试,避免依赖真实 StorageClient + result, err := generateAvatarUploadURLWithClient(ctx, mockClient, 123, "avatar.png") + + if err != nil { + t.Fatalf("GenerateAvatarUploadURL() error = %v, want nil", err) + } + if result == nil { + t.Fatalf("GenerateAvatarUploadURL() result is nil") + } + if result.PostURL == "" || result.FileURL == "" { + t.Fatalf("GenerateAvatarUploadURL() result has empty URLs: %+v", result) + } +} + +// TestGenerateTextureUploadURL_Success 测试材质上传URL生成成功(SKIN/CAPE) +func TestGenerateTextureUploadURL_Success(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + textureType string + }{ + {"SKIN 材质", "SKIN"}, + {"CAPE 材质", "CAPE"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockClient := &mockStorageClient{ + getBucketFn: func(name string) (string, error) { + if name != "textures" { + t.Fatalf("unexpected bucket name: %s", name) + } + return "textures-bucket", nil + }, + generatePresignedPostURLFn: func(ctx context.Context, bucketName, objectName string, minSize, maxSize int64, expires time.Duration) (*storage.PresignedPostPolicyResult, error) { + if bucketName != "textures-bucket" { + t.Fatalf("unexpected bucketName: %s", bucketName) + } + if !strings.Contains(objectName, "texture.png") { + t.Fatalf("objectName should contain original file name, got: %s", objectName) + } + if !strings.Contains(objectName, "/"+strings.ToLower(tt.textureType)+"/") { + t.Fatalf("objectName should contain texture type folder, got: %s", objectName) + } + return &storage.PresignedPostPolicyResult{ + PostURL: "http://example.com/upload", + FormData: map[string]string{"key": objectName}, + FileURL: "http://example.com/file/" + objectName, + }, nil + }, + } + + result, err := generateTextureUploadURLWithClient(ctx, mockClient, 123, "texture.png", tt.textureType) + if err != nil { + t.Fatalf("generateTextureUploadURLWithClient() error = %v, want nil", err) + } + if result == nil || result.PostURL == "" || result.FileURL == "" { + t.Fatalf("generateTextureUploadURLWithClient() result invalid: %+v", result) + } + }) + } +} diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 249a341..2b7250e 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -12,12 +12,39 @@ import ( "net/url" "strings" "time" + + "go.uber.org/zap" ) -// RegisterUser 用户注册 -func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar string) (*model.User, string, error) { +// userServiceImpl UserService的实现 +type userServiceImpl struct { + userRepo repository.UserRepository + configRepo repository.SystemConfigRepository + jwtService *auth.JWTService + redis *redis.Client + logger *zap.Logger +} + +// NewUserService 创建UserService实例 +func NewUserService( + userRepo repository.UserRepository, + configRepo repository.SystemConfigRepository, + jwtService *auth.JWTService, + redisClient *redis.Client, + logger *zap.Logger, +) UserService { + return &userServiceImpl{ + userRepo: userRepo, + configRepo: configRepo, + jwtService: jwtService, + redis: redisClient, + logger: logger, + } +} + +func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) { // 检查用户名是否已存在 - existingUser, err := repository.FindUserByUsername(username) + existingUser, err := s.userRepo.FindByUsername(username) if err != nil { return nil, "", err } @@ -26,7 +53,7 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar } // 检查邮箱是否已存在 - existingEmail, err := repository.FindUserByEmail(email) + existingEmail, err := s.userRepo.FindByEmail(email) if err != nil { return nil, "", err } @@ -40,15 +67,14 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar return nil, "", errors.New("密码加密失败") } - // 确定头像URL:优先使用用户提供的头像,否则使用默认头像 + // 确定头像URL avatarURL := avatar if avatarURL != "" { - // 验证用户提供的头像 URL 是否来自允许的域名 - if err := ValidateAvatarURL(avatarURL); err != nil { + if err := s.ValidateAvatarURL(avatarURL); err != nil { return nil, "", err } } else { - avatarURL = getDefaultAvatar() + avatarURL = s.getDefaultAvatar() } // 创建用户 @@ -62,12 +88,12 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar Points: 0, } - if err := repository.CreateUser(user); err != nil { + if err := s.userRepo.Create(user); err != nil { return nil, "", err } // 生成JWT Token - token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role) + token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) if err != nil { return nil, "", errors.New("生成Token失败") } @@ -75,92 +101,56 @@ func RegisterUser(jwtService *auth.JWTService, username, password, email, avatar return user, token, nil } -// LoginUser 用户登录(支持用户名或邮箱登录) -func LoginUser(jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { - return LoginUserWithRateLimit(nil, jwtService, usernameOrEmail, password, ipAddress, userAgent) -} - -// LoginUserWithRateLimit 用户登录(带频率限制) -func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTService, usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { +func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { ctx := context.Background() - // 检查账号是否被锁定(基于用户名/邮箱和IP) - if redisClient != nil { + // 检查账号是否被锁定 + if s.redis != nil { identifier := usernameOrEmail + ":" + ipAddress - locked, ttl, err := CheckLoginLocked(ctx, redisClient, identifier) + locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier) if err == nil && locked { return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1) } } - // 查找用户:判断是用户名还是邮箱 + // 查找用户 var user *model.User var err error if strings.Contains(usernameOrEmail, "@") { - user, err = repository.FindUserByEmail(usernameOrEmail) + user, err = s.userRepo.FindByEmail(usernameOrEmail) } else { - user, err = repository.FindUserByUsername(usernameOrEmail) + user, err = s.userRepo.FindByUsername(usernameOrEmail) } if err != nil { return nil, "", err } if user == nil { - // 记录失败尝试 - if redisClient != nil { - identifier := usernameOrEmail + ":" + ipAddress - count, _ := RecordLoginFailure(ctx, redisClient, identifier) - // 检查是否触发锁定 - if count >= MaxLoginAttempts { - logFailedLogin(0, ipAddress, userAgent, "用户不存在-账号已锁定") - return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes())) - } - remaining := MaxLoginAttempts - count - if remaining > 0 { - logFailedLogin(0, ipAddress, userAgent, "用户不存在") - return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining) - } - } - logFailedLogin(0, ipAddress, userAgent, "用户不存在") + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在") return nil, "", errors.New("用户名/邮箱或密码错误") } // 检查用户状态 if user.Status != 1 { - logFailedLogin(user.ID, ipAddress, userAgent, "账号已被禁用") + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用") return nil, "", errors.New("账号已被禁用") } // 验证密码 if !auth.CheckPassword(user.Password, password) { - // 记录失败尝试 - if redisClient != nil { - identifier := usernameOrEmail + ":" + ipAddress - count, _ := RecordLoginFailure(ctx, redisClient, identifier) - // 检查是否触发锁定 - if count >= MaxLoginAttempts { - logFailedLogin(user.ID, ipAddress, userAgent, "密码错误-账号已锁定") - return nil, "", fmt.Errorf("登录失败次数过多,账号已被锁定 %d 分钟", int(LoginLockDuration.Minutes())) - } - remaining := MaxLoginAttempts - count - if remaining > 0 { - logFailedLogin(user.ID, ipAddress, userAgent, "密码错误") - return nil, "", fmt.Errorf("用户名/邮箱或密码错误,还剩 %d 次尝试机会", remaining) - } - } - logFailedLogin(user.ID, ipAddress, userAgent, "密码错误") + s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误") return nil, "", errors.New("用户名/邮箱或密码错误") } // 登录成功,清除失败计数 - if redisClient != nil { + if s.redis != nil { identifier := usernameOrEmail + ":" + ipAddress - _ = ClearLoginAttempts(ctx, redisClient, identifier) + _ = ClearLoginAttempts(ctx, s.redis, identifier) } // 生成JWT Token - token, err := jwtService.GenerateToken(user.ID, user.Username, user.Role) + token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) if err != nil { return nil, "", errors.New("生成Token失败") } @@ -168,37 +158,37 @@ func LoginUserWithRateLimit(redisClient *redis.Client, jwtService *auth.JWTServi // 更新最后登录时间 now := time.Now() user.LastLoginAt = &now - _ = repository.UpdateUserFields(user.ID, map[string]interface{}{ + _ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ "last_login_at": now, }) // 记录成功登录日志 - logSuccessLogin(user.ID, ipAddress, userAgent) + s.logSuccessLogin(user.ID, ipAddress, userAgent) return user, token, nil } -// GetUserByID 根据ID获取用户 -func GetUserByID(id int64) (*model.User, error) { - return repository.FindUserByID(id) +func (s *userServiceImpl) GetByID(id int64) (*model.User, error) { + return s.userRepo.FindByID(id) } -// UpdateUserInfo 更新用户信息 -func UpdateUserInfo(user *model.User) error { - return repository.UpdateUser(user) +func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) { + return s.userRepo.FindByEmail(email) } -// UpdateUserAvatar 更新用户头像 -func UpdateUserAvatar(userID int64, avatarURL string) error { - return repository.UpdateUserFields(userID, map[string]interface{}{ +func (s *userServiceImpl) UpdateInfo(user *model.User) error { + return s.userRepo.Update(user) +} + +func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error { + return s.userRepo.UpdateFields(userID, map[string]interface{}{ "avatar": avatarURL, }) } -// ChangeUserPassword 修改密码 -func ChangeUserPassword(userID int64, oldPassword, newPassword string) error { - user, err := repository.FindUserByID(userID) - if err != nil { +func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error { + user, err := s.userRepo.FindByID(userID) + if err != nil || user == nil { return errors.New("用户不存在") } @@ -211,15 +201,14 @@ func ChangeUserPassword(userID int64, oldPassword, newPassword string) error { return errors.New("密码加密失败") } - return repository.UpdateUserFields(userID, map[string]interface{}{ + return s.userRepo.UpdateFields(userID, map[string]interface{}{ "password": hashedPassword, }) } -// ResetUserPassword 重置密码(通过邮箱) -func ResetUserPassword(email, newPassword string) error { - user, err := repository.FindUserByEmail(email) - if err != nil { +func (s *userServiceImpl) ResetPassword(email, newPassword string) error { + user, err := s.userRepo.FindByEmail(email) + if err != nil || user == nil { return errors.New("用户不存在") } @@ -228,14 +217,13 @@ func ResetUserPassword(email, newPassword string) error { return errors.New("密码加密失败") } - return repository.UpdateUserFields(user.ID, map[string]interface{}{ + return s.userRepo.UpdateFields(user.ID, map[string]interface{}{ "password": hashedPassword, }) } -// ChangeUserEmail 更换邮箱 -func ChangeUserEmail(userID int64, newEmail string) error { - existingUser, err := repository.FindUserByEmail(newEmail) +func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error { + existingUser, err := s.userRepo.FindByEmail(newEmail) if err != nil { return err } @@ -243,47 +231,12 @@ func ChangeUserEmail(userID int64, newEmail string) error { return errors.New("邮箱已被其他用户使用") } - return repository.UpdateUserFields(userID, map[string]interface{}{ + return s.userRepo.UpdateFields(userID, map[string]interface{}{ "email": newEmail, }) } -// logSuccessLogin 记录成功登录 -func logSuccessLogin(userID int64, ipAddress, userAgent string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: true, - } - _ = repository.CreateLoginLog(log) -} - -// logFailedLogin 记录失败登录 -func logFailedLogin(userID int64, ipAddress, userAgent, reason string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: false, - FailureReason: reason, - } - _ = repository.CreateLoginLog(log) -} - -// getDefaultAvatar 获取默认头像URL -func getDefaultAvatar() string { - config, err := repository.GetSystemConfigByKey("default_avatar") - if err != nil || config == nil || config.Value == "" { - return "" - } - return config.Value -} - -// ValidateAvatarURL 验证头像URL是否合法 -func ValidateAvatarURL(avatarURL string) error { +func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error { if avatarURL == "" { return nil } @@ -293,13 +246,8 @@ func ValidateAvatarURL(avatarURL string) error { return nil } - return ValidateURLDomain(avatarURL) -} - -// ValidateURLDomain 验证URL的域名是否在允许列表中 -func ValidateURLDomain(rawURL string) error { // 解析URL - parsedURL, err := url.Parse(rawURL) + parsedURL, err := url.Parse(avatarURL) if err != nil { return errors.New("无效的URL格式") } @@ -309,7 +257,6 @@ func ValidateURLDomain(rawURL string) error { return errors.New("URL必须使用http或https协议") } - // 获取主机名(不包含端口) host := parsedURL.Hostname() if host == "" { return errors.New("URL缺少主机名") @@ -318,16 +265,50 @@ func ValidateURLDomain(rawURL string) error { // 从配置获取允许的域名列表 cfg, err := config.GetConfig() if err != nil { - // 如果配置获取失败,使用默认的安全域名列表 allowedDomains := []string{"localhost", "127.0.0.1"} - return checkDomainAllowed(host, allowedDomains) + return s.checkDomainAllowed(host, allowedDomains) } - return checkDomainAllowed(host, cfg.Security.AllowedDomains) + return s.checkDomainAllowed(host, cfg.Security.AllowedDomains) } -// checkDomainAllowed 检查域名是否在允许列表中 -func checkDomainAllowed(host string, allowedDomains []string) error { +func (s *userServiceImpl) GetMaxProfilesPerUser() int { + config, err := s.configRepo.GetByKey("max_profiles_per_user") + if err != nil || config == nil { + return 5 + } + var value int + fmt.Sscanf(config.Value, "%d", &value) + if value <= 0 { + return 5 + } + return value +} + +func (s *userServiceImpl) GetMaxTexturesPerUser() int { + config, err := s.configRepo.GetByKey("max_textures_per_user") + if err != nil || config == nil { + return 50 + } + var value int + fmt.Sscanf(config.Value, "%d", &value) + if value <= 0 { + return 50 + } + return value +} + +// 私有辅助方法 + +func (s *userServiceImpl) getDefaultAvatar() string { + config, err := s.configRepo.GetByKey("default_avatar") + if err != nil || config == nil || config.Value == "" { + return "" + } + return config.Value +} + +func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error { host = strings.ToLower(host) for _, allowed := range allowedDomains { @@ -336,14 +317,12 @@ func checkDomainAllowed(host string, allowedDomains []string) error { continue } - // 精确匹配 if host == allowed { return nil } - // 支持通配符子域名匹配 (如 *.example.com) if strings.HasPrefix(allowed, "*.") { - suffix := allowed[1:] // 移除 "*",保留 ".example.com" + suffix := allowed[1:] if strings.HasSuffix(host, suffix) { return nil } @@ -353,39 +332,37 @@ func checkDomainAllowed(host string, allowedDomains []string) error { return errors.New("URL域名不在允许的列表中") } -// GetUserByEmail 根据邮箱获取用户 -func GetUserByEmail(email string) (*model.User, error) { - user, err := repository.FindUserByEmail(email) - if err != nil { - return nil, errors.New("邮箱查找失败") +func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) { + if s.redis != nil { + identifier := usernameOrEmail + ":" + ipAddress + count, _ := RecordLoginFailure(ctx, s.redis, identifier) + if count >= MaxLoginAttempts { + s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") + return + } } - return user, nil + s.logFailedLogin(userID, ipAddress, userAgent, reason) } -// GetMaxProfilesPerUser 获取每用户最大档案数量配置 -func GetMaxProfilesPerUser() int { - config, err := repository.GetSystemConfigByKey("max_profiles_per_user") - if err != nil || config == nil { - return 5 +func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) { + log := &model.UserLoginLog{ + UserID: userID, + IPAddress: ipAddress, + UserAgent: userAgent, + LoginMethod: "PASSWORD", + IsSuccess: true, } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 5 - } - return value + _ = s.userRepo.CreateLoginLog(log) } -// GetMaxTexturesPerUser 获取每用户最大材质数量配置 -func GetMaxTexturesPerUser() int { - config, err := repository.GetSystemConfigByKey("max_textures_per_user") - if err != nil || config == nil { - return 50 +func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { + log := &model.UserLoginLog{ + UserID: userID, + IPAddress: ipAddress, + UserAgent: userAgent, + LoginMethod: "PASSWORD", + IsSuccess: false, + FailureReason: reason, } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 50 - } - return value + _ = s.userRepo.CreateLoginLog(log) } diff --git a/internal/service/user_service_impl.go b/internal/service/user_service_impl.go deleted file mode 100644 index 2b7250e..0000000 --- a/internal/service/user_service_impl.go +++ /dev/null @@ -1,368 +0,0 @@ -package service - -import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "carrotskin/pkg/auth" - "carrotskin/pkg/config" - "carrotskin/pkg/redis" - "context" - "errors" - "fmt" - "net/url" - "strings" - "time" - - "go.uber.org/zap" -) - -// userServiceImpl UserService的实现 -type userServiceImpl struct { - userRepo repository.UserRepository - configRepo repository.SystemConfigRepository - jwtService *auth.JWTService - redis *redis.Client - logger *zap.Logger -} - -// NewUserService 创建UserService实例 -func NewUserService( - userRepo repository.UserRepository, - configRepo repository.SystemConfigRepository, - jwtService *auth.JWTService, - redisClient *redis.Client, - logger *zap.Logger, -) UserService { - return &userServiceImpl{ - userRepo: userRepo, - configRepo: configRepo, - jwtService: jwtService, - redis: redisClient, - logger: logger, - } -} - -func (s *userServiceImpl) Register(username, password, email, avatar string) (*model.User, string, error) { - // 检查用户名是否已存在 - existingUser, err := s.userRepo.FindByUsername(username) - if err != nil { - return nil, "", err - } - if existingUser != nil { - return nil, "", errors.New("用户名已存在") - } - - // 检查邮箱是否已存在 - existingEmail, err := s.userRepo.FindByEmail(email) - if err != nil { - return nil, "", err - } - if existingEmail != nil { - return nil, "", errors.New("邮箱已被注册") - } - - // 加密密码 - hashedPassword, err := auth.HashPassword(password) - if err != nil { - return nil, "", errors.New("密码加密失败") - } - - // 确定头像URL - avatarURL := avatar - if avatarURL != "" { - if err := s.ValidateAvatarURL(avatarURL); err != nil { - return nil, "", err - } - } else { - avatarURL = s.getDefaultAvatar() - } - - // 创建用户 - user := &model.User{ - Username: username, - Password: hashedPassword, - Email: email, - Avatar: avatarURL, - Role: "user", - Status: 1, - Points: 0, - } - - if err := s.userRepo.Create(user); err != nil { - return nil, "", err - } - - // 生成JWT Token - token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) - if err != nil { - return nil, "", errors.New("生成Token失败") - } - - return user, token, nil -} - -func (s *userServiceImpl) Login(usernameOrEmail, password, ipAddress, userAgent string) (*model.User, string, error) { - ctx := context.Background() - - // 检查账号是否被锁定 - if s.redis != nil { - identifier := usernameOrEmail + ":" + ipAddress - locked, ttl, err := CheckLoginLocked(ctx, s.redis, identifier) - if err == nil && locked { - return nil, "", fmt.Errorf("登录尝试次数过多,请在 %d 分钟后重试", int(ttl.Minutes())+1) - } - } - - // 查找用户 - var user *model.User - var err error - - if strings.Contains(usernameOrEmail, "@") { - user, err = s.userRepo.FindByEmail(usernameOrEmail) - } else { - user, err = s.userRepo.FindByUsername(usernameOrEmail) - } - - if err != nil { - return nil, "", err - } - if user == nil { - s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, 0, "用户不存在") - return nil, "", errors.New("用户名/邮箱或密码错误") - } - - // 检查用户状态 - if user.Status != 1 { - s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "账号已被禁用") - return nil, "", errors.New("账号已被禁用") - } - - // 验证密码 - if !auth.CheckPassword(user.Password, password) { - s.recordLoginFailure(ctx, usernameOrEmail, ipAddress, userAgent, user.ID, "密码错误") - return nil, "", errors.New("用户名/邮箱或密码错误") - } - - // 登录成功,清除失败计数 - if s.redis != nil { - identifier := usernameOrEmail + ":" + ipAddress - _ = ClearLoginAttempts(ctx, s.redis, identifier) - } - - // 生成JWT Token - token, err := s.jwtService.GenerateToken(user.ID, user.Username, user.Role) - if err != nil { - return nil, "", errors.New("生成Token失败") - } - - // 更新最后登录时间 - now := time.Now() - user.LastLoginAt = &now - _ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ - "last_login_at": now, - }) - - // 记录成功登录日志 - s.logSuccessLogin(user.ID, ipAddress, userAgent) - - return user, token, nil -} - -func (s *userServiceImpl) GetByID(id int64) (*model.User, error) { - return s.userRepo.FindByID(id) -} - -func (s *userServiceImpl) GetByEmail(email string) (*model.User, error) { - return s.userRepo.FindByEmail(email) -} - -func (s *userServiceImpl) UpdateInfo(user *model.User) error { - return s.userRepo.Update(user) -} - -func (s *userServiceImpl) UpdateAvatar(userID int64, avatarURL string) error { - return s.userRepo.UpdateFields(userID, map[string]interface{}{ - "avatar": avatarURL, - }) -} - -func (s *userServiceImpl) ChangePassword(userID int64, oldPassword, newPassword string) error { - user, err := s.userRepo.FindByID(userID) - if err != nil || user == nil { - return errors.New("用户不存在") - } - - if !auth.CheckPassword(user.Password, oldPassword) { - return errors.New("原密码错误") - } - - hashedPassword, err := auth.HashPassword(newPassword) - if err != nil { - return errors.New("密码加密失败") - } - - return s.userRepo.UpdateFields(userID, map[string]interface{}{ - "password": hashedPassword, - }) -} - -func (s *userServiceImpl) ResetPassword(email, newPassword string) error { - user, err := s.userRepo.FindByEmail(email) - if err != nil || user == nil { - return errors.New("用户不存在") - } - - hashedPassword, err := auth.HashPassword(newPassword) - if err != nil { - return errors.New("密码加密失败") - } - - return s.userRepo.UpdateFields(user.ID, map[string]interface{}{ - "password": hashedPassword, - }) -} - -func (s *userServiceImpl) ChangeEmail(userID int64, newEmail string) error { - existingUser, err := s.userRepo.FindByEmail(newEmail) - if err != nil { - return err - } - if existingUser != nil && existingUser.ID != userID { - return errors.New("邮箱已被其他用户使用") - } - - return s.userRepo.UpdateFields(userID, map[string]interface{}{ - "email": newEmail, - }) -} - -func (s *userServiceImpl) ValidateAvatarURL(avatarURL string) error { - if avatarURL == "" { - return nil - } - - // 允许相对路径 - if strings.HasPrefix(avatarURL, "/") { - return nil - } - - // 解析URL - parsedURL, err := url.Parse(avatarURL) - if err != nil { - return errors.New("无效的URL格式") - } - - // 必须是HTTP或HTTPS协议 - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return errors.New("URL必须使用http或https协议") - } - - host := parsedURL.Hostname() - if host == "" { - return errors.New("URL缺少主机名") - } - - // 从配置获取允许的域名列表 - cfg, err := config.GetConfig() - if err != nil { - allowedDomains := []string{"localhost", "127.0.0.1"} - return s.checkDomainAllowed(host, allowedDomains) - } - - return s.checkDomainAllowed(host, cfg.Security.AllowedDomains) -} - -func (s *userServiceImpl) GetMaxProfilesPerUser() int { - config, err := s.configRepo.GetByKey("max_profiles_per_user") - if err != nil || config == nil { - return 5 - } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 5 - } - return value -} - -func (s *userServiceImpl) GetMaxTexturesPerUser() int { - config, err := s.configRepo.GetByKey("max_textures_per_user") - if err != nil || config == nil { - return 50 - } - var value int - fmt.Sscanf(config.Value, "%d", &value) - if value <= 0 { - return 50 - } - return value -} - -// 私有辅助方法 - -func (s *userServiceImpl) getDefaultAvatar() string { - config, err := s.configRepo.GetByKey("default_avatar") - if err != nil || config == nil || config.Value == "" { - return "" - } - return config.Value -} - -func (s *userServiceImpl) checkDomainAllowed(host string, allowedDomains []string) error { - host = strings.ToLower(host) - - for _, allowed := range allowedDomains { - allowed = strings.ToLower(strings.TrimSpace(allowed)) - if allowed == "" { - continue - } - - if host == allowed { - return nil - } - - if strings.HasPrefix(allowed, "*.") { - suffix := allowed[1:] - if strings.HasSuffix(host, suffix) { - return nil - } - } - } - - return errors.New("URL域名不在允许的列表中") -} - -func (s *userServiceImpl) recordLoginFailure(ctx context.Context, usernameOrEmail, ipAddress, userAgent string, userID int64, reason string) { - if s.redis != nil { - identifier := usernameOrEmail + ":" + ipAddress - count, _ := RecordLoginFailure(ctx, s.redis, identifier) - if count >= MaxLoginAttempts { - s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") - return - } - } - s.logFailedLogin(userID, ipAddress, userAgent, reason) -} - -func (s *userServiceImpl) logSuccessLogin(userID int64, ipAddress, userAgent string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: true, - } - _ = s.userRepo.CreateLoginLog(log) -} - -func (s *userServiceImpl) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { - log := &model.UserLoginLog{ - UserID: userID, - IPAddress: ipAddress, - UserAgent: userAgent, - LoginMethod: "PASSWORD", - IsSuccess: false, - FailureReason: reason, - } - _ = s.userRepo.CreateLoginLog(log) -} diff --git a/internal/service/user_service_test.go b/internal/service/user_service_test.go index 9144fb4..e5bfc36 100644 --- a/internal/service/user_service_test.go +++ b/internal/service/user_service_test.go @@ -1,199 +1,378 @@ package service import ( - "strings" + "carrotskin/internal/model" + "carrotskin/pkg/auth" "testing" + + "go.uber.org/zap" ) -// TestGetDefaultAvatar 测试获取默认头像的逻辑 -// 注意:这个测试需要mock repository,但由于repository是函数式的, -// 我们只测试逻辑部分 -func TestGetDefaultAvatar_Logic(t *testing.T) { +func TestUserServiceImpl_Register(t *testing.T) { + // 准备依赖 + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 初始化Service + // 注意:redisClient 传入 nil,因为 Register 方法中没有使用 redis + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 测试用例 tests := []struct { - name string - configExists bool - configValue string - expectedResult string + name string + username string + password string + email string + avatar string + wantErr bool + errMsg string + setupMocks func() }{ { - name: "配置存在时返回配置值", - configExists: true, - configValue: "https://example.com/avatar.png", - expectedResult: "https://example.com/avatar.png", + name: "正常注册", + username: "testuser", + password: "password123", + email: "test@example.com", + avatar: "", + wantErr: false, }, { - name: "配置不存在时返回错误信息", - configExists: false, - configValue: "", - expectedResult: "数据库中不存在默认头像配置", + name: "用户名已存在", + username: "existinguser", + password: "password123", + email: "new@example.com", + avatar: "", + wantErr: true, + errMsg: "用户名已存在", + setupMocks: func() { + userRepo.Create(&model.User{ + Username: "existinguser", + Email: "old@example.com", + }) + }, + }, + { + name: "邮箱已存在", + username: "newuser", + password: "password123", + email: "existing@example.com", + avatar: "", + wantErr: true, + errMsg: "邮箱已被注册", + setupMocks: func() { + userRepo.Create(&model.User{ + Username: "otheruser", + Email: "existing@example.com", + }) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 这个测试只验证逻辑,不实际调用repository - // 实际的repository调用测试需要集成测试或mock - if tt.configExists { - if tt.expectedResult != tt.configValue { - t.Errorf("当配置存在时,应该返回配置值") + // 重置mock状态 + if tt.setupMocks != nil { + tt.setupMocks() + } + + user, token, err := userService.Register(tt.username, tt.password, tt.email, tt.avatar) + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + return + } + if tt.errMsg != "" && err.Error() != tt.errMsg { + t.Errorf("错误信息不匹配: got %v, want %v", err.Error(), tt.errMsg) } } else { - if !strings.Contains(tt.expectedResult, "数据库中不存在默认头像配置") { - t.Errorf("当配置不存在时,应该返回错误信息") + if err != nil { + t.Errorf("不期望返回错误: %v", err) + return + } + if user == nil { + t.Error("返回的用户不应为nil") + } + if token == "" { + t.Error("返回的Token不应为空") + } + if user.Username != tt.username { + t.Errorf("用户名不匹配: got %v, want %v", user.Username, tt.username) } } }) } } -// TestLoginUser_EmailDetection 测试登录时邮箱检测逻辑 -func TestLoginUser_EmailDetection(t *testing.T) { +func TestUserServiceImpl_Login(t *testing.T) { + // 准备依赖 + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 预置用户 + password := "password123" + hashedPassword, _ := auth.HashPassword(password) + testUser := &model.User{ + Username: "testlogin", + Email: "login@example.com", + Password: hashedPassword, + Status: 1, + } + userRepo.Create(testUser) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + tests := []struct { name string usernameOrEmail string - isEmail bool + password string + wantErr bool + errMsg string }{ { - name: "包含@符号,识别为邮箱", - usernameOrEmail: "user@example.com", - isEmail: true, + name: "用户名登录成功", + usernameOrEmail: "testlogin", + password: "password123", + wantErr: false, }, { - name: "不包含@符号,识别为用户名", - usernameOrEmail: "username", - isEmail: false, + name: "邮箱登录成功", + usernameOrEmail: "login@example.com", + password: "password123", + wantErr: false, }, { - name: "空字符串", - usernameOrEmail: "", - isEmail: false, + name: "密码错误", + usernameOrEmail: "testlogin", + password: "wrongpassword", + wantErr: true, + errMsg: "用户名/邮箱或密码错误", }, { - name: "只有@符号", - usernameOrEmail: "@", - isEmail: true, + name: "用户不存在", + usernameOrEmail: "nonexistent", + password: "password123", + wantErr: true, + errMsg: "用户名/邮箱或密码错误", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isEmail := strings.Contains(tt.usernameOrEmail, "@") - if isEmail != tt.isEmail { - t.Errorf("Email detection failed: got %v, want %v", isEmail, tt.isEmail) + user, token, err := userService.Login(tt.usernameOrEmail, tt.password, "127.0.0.1", "test-agent") + + if tt.wantErr { + if err == nil { + t.Error("期望返回错误,但实际没有错误") + } else if tt.errMsg != "" && err.Error() != tt.errMsg { + t.Errorf("错误信息不匹配: got %v, want %v", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("不期望返回错误: %v", err) + } + if user == nil { + t.Error("用户不应为nil") + } + if token == "" { + t.Error("Token不应为空") + } } }) } } -// TestUserService_Constants 测试用户服务相关常量 -func TestUserService_Constants(t *testing.T) { - // 测试默认用户角色 - defaultRole := "user" - if defaultRole == "" { - t.Error("默认用户角色不能为空") +// TestUserServiceImpl_BasicGetters 测试 GetByID / GetByEmail / UpdateInfo / UpdateAvatar +func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 预置用户 + user := &model.User{ + ID: 1, + Username: "basic", + Email: "basic@example.com", + Avatar: "", + } + userRepo.Create(user) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // GetByID + gotByID, err := userService.GetByID(1) + if err != nil || gotByID == nil || gotByID.ID != 1 { + t.Fatalf("GetByID 返回不正确: user=%+v, err=%v", gotByID, err) } - // 测试默认用户状态 - defaultStatus := int16(1) - if defaultStatus != 1 { - t.Errorf("默认用户状态应为1(正常),实际为%d", defaultStatus) + // GetByEmail + gotByEmail, err := userService.GetByEmail("basic@example.com") + if err != nil || gotByEmail == nil || gotByEmail.Email != "basic@example.com" { + t.Fatalf("GetByEmail 返回不正确: user=%+v, err=%v", gotByEmail, err) } - // 测试初始积分 - initialPoints := 0 - if initialPoints < 0 { - t.Errorf("初始积分不应为负数,实际为%d", initialPoints) + // UpdateInfo + user.Username = "updated" + if err := userService.UpdateInfo(user); err != nil { + t.Fatalf("UpdateInfo 失败: %v", err) + } + updated, _ := userRepo.FindByID(1) + if updated.Username != "updated" { + t.Fatalf("UpdateInfo 未更新用户名, got=%s", updated.Username) + } + + // UpdateAvatar 只需确认不会返回错误(具体字段更新由仓库层保证) + if err := userService.UpdateAvatar(1, "http://example.com/avatar.png"); err != nil { + t.Fatalf("UpdateAvatar 失败: %v", err) } } -// TestUserService_Validation 测试用户数据验证逻辑 -func TestUserService_Validation(t *testing.T) { +// TestUserServiceImpl_ChangePassword 测试 ChangePassword +func TestUserServiceImpl_ChangePassword(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + hashed, _ := auth.HashPassword("oldpass") + user := &model.User{ + ID: 1, + Username: "changepw", + Password: hashed, + } + userRepo.Create(user) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 原密码正确 + if err := userService.ChangePassword(1, "oldpass", "newpass"); err != nil { + t.Fatalf("ChangePassword 正常情况失败: %v", err) + } + + // 用户不存在 + if err := userService.ChangePassword(999, "oldpass", "newpass"); err == nil { + t.Fatalf("ChangePassword 应在用户不存在时返回错误") + } + + // 原密码错误 + if err := userService.ChangePassword(1, "wrong", "another"); err == nil { + t.Fatalf("ChangePassword 应在原密码错误时返回错误") + } +} + +// TestUserServiceImpl_ResetPassword 测试 ResetPassword +func TestUserServiceImpl_ResetPassword(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + user := &model.User{ + ID: 1, + Username: "resetpw", + Email: "reset@example.com", + } + userRepo.Create(user) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 正常重置 + if err := userService.ResetPassword("reset@example.com", "newpass"); err != nil { + t.Fatalf("ResetPassword 正常情况失败: %v", err) + } + + // 用户不存在 + if err := userService.ResetPassword("notfound@example.com", "newpass"); err == nil { + t.Fatalf("ResetPassword 应在用户不存在时返回错误") + } +} + +// TestUserServiceImpl_ChangeEmail 测试 ChangeEmail +func TestUserServiceImpl_ChangeEmail(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + user1 := &model.User{ID: 1, Email: "user1@example.com"} + user2 := &model.User{ID: 2, Email: "user2@example.com"} + userRepo.Create(user1) + userRepo.Create(user2) + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + + // 正常修改 + if err := userService.ChangeEmail(1, "new@example.com"); err != nil { + t.Fatalf("ChangeEmail 正常情况失败: %v", err) + } + + // 邮箱被其他用户占用 + if err := userService.ChangeEmail(1, "user2@example.com"); err == nil { + t.Fatalf("ChangeEmail 应在邮箱被占用时返回错误") + } +} + +// TestUserServiceImpl_ValidateAvatarURL 测试 ValidateAvatarURL +func TestUserServiceImpl_ValidateAvatarURL(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + tests := []struct { - name string - username string - email string - password string - wantValid bool + name string + url string + wantErr bool }{ - { - name: "有效的用户名和邮箱", - username: "testuser", - email: "test@example.com", - password: "password123", - wantValid: true, - }, - { - name: "用户名为空", - username: "", - email: "test@example.com", - password: "password123", - wantValid: false, - }, - { - name: "邮箱为空", - username: "testuser", - email: "", - password: "password123", - wantValid: false, - }, - { - name: "密码为空", - username: "testuser", - email: "test@example.com", - password: "", - wantValid: false, - }, - { - name: "邮箱格式无效(缺少@)", - username: "testuser", - email: "invalid-email", - password: "password123", - wantValid: false, - }, + {"空字符串通过", "", false}, + {"相对路径通过", "/images/avatar.png", false}, + {"非法URL格式", "://bad-url", true}, + {"非法协议", "ftp://example.com/avatar.png", true}, + {"缺少主机名", "http:///avatar.png", true}, + {"本地域名通过", "http://localhost/avatar.png", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // 简单的验证逻辑测试 - isValid := tt.username != "" && tt.email != "" && tt.password != "" && strings.Contains(tt.email, "@") - if isValid != tt.wantValid { - t.Errorf("Validation failed: got %v, want %v", isValid, tt.wantValid) + err := userService.ValidateAvatarURL(tt.url) + if (err != nil) != tt.wantErr { + t.Fatalf("ValidateAvatarURL(%q) error = %v, wantErr=%v", tt.url, err, tt.wantErr) } }) } } -// TestUserService_AvatarLogic 测试头像逻辑 -func TestUserService_AvatarLogic(t *testing.T) { - tests := []struct { - name string - providedAvatar string - defaultAvatar string - expectedAvatar string - }{ - { - name: "提供头像时使用提供的头像", - providedAvatar: "https://example.com/custom.png", - defaultAvatar: "https://example.com/default.png", - expectedAvatar: "https://example.com/custom.png", - }, - { - name: "未提供头像时使用默认头像", - providedAvatar: "", - defaultAvatar: "https://example.com/default.png", - expectedAvatar: "https://example.com/default.png", - }, +// TestUserServiceImpl_MaxLimits 测试 GetMaxProfilesPerUser / GetMaxTexturesPerUser +func TestUserServiceImpl_MaxLimits(t *testing.T) { + userRepo := NewMockUserRepository() + configRepo := NewMockSystemConfigRepository() + jwtService := auth.NewJWTService("secret", 1) + logger := zap.NewNop() + + // 未配置时走默认值 + userService := NewUserService(userRepo, configRepo, jwtService, nil, logger) + if got := userService.GetMaxProfilesPerUser(); got != 5 { + t.Fatalf("GetMaxProfilesPerUser 默认值错误, got=%d", got) + } + if got := userService.GetMaxTexturesPerUser(); got != 50 { + t.Fatalf("GetMaxTexturesPerUser 默认值错误, got=%d", got) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - avatarURL := tt.providedAvatar - if avatarURL == "" { - avatarURL = tt.defaultAvatar - } - if avatarURL != tt.expectedAvatar { - t.Errorf("Avatar logic failed: got %s, want %s", avatarURL, tt.expectedAvatar) - } - }) + // 配置有效值 + configRepo.Update(&model.SystemConfig{Key: "max_profiles_per_user", Value: "10"}) + configRepo.Update(&model.SystemConfig{Key: "max_textures_per_user", Value: "100"}) + + if got := userService.GetMaxProfilesPerUser(); got != 10 { + t.Fatalf("GetMaxProfilesPerUser 配置值错误, got=%d", got) } -} + if got := userService.GetMaxTexturesPerUser(); got != 100 { + t.Fatalf("GetMaxTexturesPerUser 配置值错误, got=%d", got) + } +} \ No newline at end of file