diff --git a/.dockerignore b/.dockerignore index e375c38..6686339 100644 --- a/.dockerignore +++ b/.dockerignore @@ -74,3 +74,5 @@ local/ dev/ minio-data/ + + diff --git a/.gitea/workflows/docker.yml b/.gitea/workflows/docker.yml deleted file mode 100644 index 4595bfb..0000000 --- a/.gitea/workflows/docker.yml +++ /dev/null @@ -1,84 +0,0 @@ -name: Build and Push Docker Image - -on: - push: - branches: - - main - - master - - dev - tags: - - 'v*' - workflow_dispatch: - -env: - REGISTRY: code.littlelan.cn - IMAGE_NAME: carrotskin/backend - -jobs: - build-and-push: - runs-on: ubuntu-latest - container: - image: quay.io/buildah/stable:latest - options: --privileged - - steps: - - name: Install dependencies - run: | - dnf install -y git nodejs - - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Login to registry - run: | - buildah login \ - -u "${{ secrets.REGISTRY_USERNAME }}" \ - -p "${{ secrets.REGISTRY_PASSWORD }}" \ - ${{ env.REGISTRY }} - echo "Registry 登录成功" - - - name: Build image - run: | - buildah bud \ - --format docker \ - --layers \ - -t ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - -f Dockerfile \ - . - echo "镜像构建完成" - - - name: Tag and push image - run: | - SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7) - REF_NAME="${{ github.ref_name }}" - REF="${{ github.ref }}" - - # 推送分支/标签名 - buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${REF_NAME} - buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${REF_NAME} - echo "✓ 推送: ${REF_NAME}" - - # 推送 SHA 标签 - buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA} - buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${SHORT_SHA} - echo "✓ 推送: sha-${SHORT_SHA}" - - # main/master 推送 latest - if [ "$REF" = "refs/heads/main" ] || [ "$REF" = "refs/heads/master" ]; then - buildah tag ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:build \ - ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest - buildah push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:latest - echo "✓ 推送: latest" - fi - - - name: Build summary - run: | - echo "==============================" - echo "✅ 镜像构建完成!" - echo "仓库: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}" - echo "分支: ${{ github.ref_name }}" - echo "==============================" diff --git a/Dockerfile b/Dockerfile index 118c7a4..b5a00ab 100644 --- a/Dockerfile +++ b/Dockerfile @@ -59,3 +59,5 @@ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ # 启动应用 ENTRYPOINT ["./server"] + + diff --git a/cmd/server/main_di_example.go.example b/cmd/server/main_di_example.go.example new file mode 100644 index 0000000..d9168ef --- /dev/null +++ b/cmd/server/main_di_example.go.example @@ -0,0 +1,146 @@ +// +build ignore +// 此文件是依赖注入版本的main.go示例 +// 可以参考此文件改造原有的main.go + +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + _ "carrotskin/docs" // Swagger文档 + "carrotskin/internal/container" + "carrotskin/internal/handler" + "carrotskin/internal/middleware" + "carrotskin/pkg/auth" + "carrotskin/pkg/config" + "carrotskin/pkg/database" + "carrotskin/pkg/email" + "carrotskin/pkg/logger" + "carrotskin/pkg/redis" + "carrotskin/pkg/storage" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func main() { + // 初始化配置 + if err := config.Init(); err != nil { + log.Fatalf("配置加载失败: %v", err) + } + cfg := config.MustGetConfig() + + // 初始化日志 + if err := logger.Init(cfg.Log); err != nil { + log.Fatalf("日志初始化失败: %v", err) + } + loggerInstance := logger.MustGetLogger() + defer loggerInstance.Sync() + + // 初始化数据库 + if err := database.Init(cfg.Database, loggerInstance); err != nil { + loggerInstance.Fatal("数据库初始化失败", zap.Error(err)) + } + defer database.Close() + + // 执行数据库迁移 + if err := database.AutoMigrate(loggerInstance); err != nil { + loggerInstance.Fatal("数据库迁移失败", zap.Error(err)) + } + + // 初始化种子数据 + if err := database.Seed(loggerInstance); err != nil { + loggerInstance.Fatal("种子数据初始化失败", zap.Error(err)) + } + + // 初始化JWT服务 + if err := auth.Init(cfg.JWT); err != nil { + loggerInstance.Fatal("JWT服务初始化失败", zap.Error(err)) + } + + // 初始化Redis + if err := redis.Init(cfg.Redis, loggerInstance); err != nil { + loggerInstance.Fatal("Redis连接失败", zap.Error(err)) + } + defer redis.MustGetClient().Close() + + // 初始化对象存储 + var storageClient *storage.StorageClient + if err := storage.Init(cfg.RustFS); err != nil { + loggerInstance.Warn("对象存储连接失败,某些功能可能不可用", zap.Error(err)) + } else { + storageClient = storage.MustGetClient() + loggerInstance.Info("对象存储连接成功") + } + + // 初始化邮件服务 + if err := email.Init(cfg.Email, loggerInstance); err != nil { + loggerInstance.Fatal("邮件服务初始化失败", zap.Error(err)) + } + + // ============ 依赖注入改动部分 ============ + // 创建依赖注入容器 + c := container.NewContainer( + database.MustGetDB(), + redis.MustGetClient(), + loggerInstance, + auth.MustGetJWTService(), + storageClient, + ) + + // 设置Gin模式 + if cfg.Server.Mode == "production" { + gin.SetMode(gin.ReleaseMode) + } + + // 创建路由 + router := gin.New() + + // 添加中间件 + router.Use(middleware.Logger(loggerInstance)) + router.Use(middleware.Recovery(loggerInstance)) + router.Use(middleware.CORS()) + + // 使用依赖注入方式注册路由 + handler.RegisterRoutesWithDI(router, c) + // ============ 依赖注入改动结束 ============ + + // 创建HTTP服务器 + srv := &http.Server{ + Addr: cfg.Server.Port, + Handler: router, + ReadTimeout: cfg.Server.ReadTimeout, + WriteTimeout: cfg.Server.WriteTimeout, + } + + // 启动服务器 + go func() { + loggerInstance.Info("服务器启动", zap.String("port", cfg.Server.Port)) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + loggerInstance.Fatal("服务器启动失败", zap.Error(err)) + } + }() + + // 等待中断信号优雅关闭 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + loggerInstance.Info("正在关闭服务器...") + + // 设置关闭超时 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + loggerInstance.Fatal("服务器强制关闭", zap.Error(err)) + } + + loggerInstance.Info("服务器已关闭") +} + diff --git a/internal/container/container.go b/internal/container/container.go new file mode 100644 index 0000000..230e68f --- /dev/null +++ b/internal/container/container.go @@ -0,0 +1,138 @@ +package container + +import ( + "carrotskin/internal/repository" + "carrotskin/pkg/auth" + "carrotskin/pkg/redis" + "carrotskin/pkg/storage" + + "go.uber.org/zap" + "gorm.io/gorm" +) + +// Container 依赖注入容器 +// 集中管理所有依赖,便于测试和维护 +type Container struct { + // 基础设施依赖 + DB *gorm.DB + Redis *redis.Client + Logger *zap.Logger + JWT *auth.JWTService + Storage *storage.StorageClient + + // Repository层 + UserRepo repository.UserRepository + ProfileRepo repository.ProfileRepository + TextureRepo repository.TextureRepository + TokenRepo repository.TokenRepository + ConfigRepo repository.SystemConfigRepository +} + +// NewContainer 创建依赖容器 +func NewContainer( + db *gorm.DB, + redisClient *redis.Client, + logger *zap.Logger, + jwtService *auth.JWTService, + storageClient *storage.StorageClient, +) *Container { + c := &Container{ + DB: db, + Redis: redisClient, + Logger: logger, + JWT: jwtService, + Storage: storageClient, + } + + // 初始化Repository + c.UserRepo = repository.NewUserRepository(db) + c.ProfileRepo = repository.NewProfileRepository(db) + c.TextureRepo = repository.NewTextureRepository(db) + c.TokenRepo = repository.NewTokenRepository(db) + c.ConfigRepo = repository.NewSystemConfigRepository(db) + + return c +} + +// NewTestContainer 创建测试用容器(可注入mock依赖) +func NewTestContainer(opts ...Option) *Container { + c := &Container{} + for _, opt := range opts { + opt(c) + } + return c +} + +// Option 容器配置选项 +type Option func(*Container) + +// WithDB 设置数据库连接 +func WithDB(db *gorm.DB) Option { + return func(c *Container) { + c.DB = db + } +} + +// WithRedis 设置Redis客户端 +func WithRedis(redis *redis.Client) Option { + return func(c *Container) { + c.Redis = redis + } +} + +// WithLogger 设置日志 +func WithLogger(logger *zap.Logger) Option { + return func(c *Container) { + c.Logger = logger + } +} + +// WithJWT 设置JWT服务 +func WithJWT(jwt *auth.JWTService) Option { + return func(c *Container) { + c.JWT = jwt + } +} + +// WithStorage 设置存储客户端 +func WithStorage(storage *storage.StorageClient) Option { + return func(c *Container) { + c.Storage = storage + } +} + +// WithUserRepo 设置用户仓储 +func WithUserRepo(repo repository.UserRepository) Option { + return func(c *Container) { + c.UserRepo = repo + } +} + +// WithProfileRepo 设置档案仓储 +func WithProfileRepo(repo repository.ProfileRepository) Option { + return func(c *Container) { + c.ProfileRepo = repo + } +} + +// WithTextureRepo 设置材质仓储 +func WithTextureRepo(repo repository.TextureRepository) Option { + return func(c *Container) { + c.TextureRepo = repo + } +} + +// WithTokenRepo 设置令牌仓储 +func WithTokenRepo(repo repository.TokenRepository) Option { + return func(c *Container) { + c.TokenRepo = repo + } +} + +// WithConfigRepo 设置系统配置仓储 +func WithConfigRepo(repo repository.SystemConfigRepository) Option { + return func(c *Container) { + c.ConfigRepo = repo + } +} + diff --git a/internal/handler/auth_handler_di.go b/internal/handler/auth_handler_di.go new file mode 100644 index 0000000..9087008 --- /dev/null +++ b/internal/handler/auth_handler_di.go @@ -0,0 +1,177 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/service" + "carrotskin/internal/types" + "carrotskin/pkg/email" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// AuthHandler 认证处理器(依赖注入版本) +type AuthHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewAuthHandler 创建AuthHandler实例 +func NewAuthHandler(c *container.Container) *AuthHandler { + return &AuthHandler{ + container: c, + logger: c.Logger, + } +} + +// Register 用户注册 +// @Summary 用户注册 +// @Description 注册新用户账号 +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.RegisterRequest true "注册信息" +// @Success 200 {object} model.Response "注册成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/auth/register [post] +func (h *AuthHandler) Register(c *gin.Context) { + var req types.RegisterRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + // 验证邮箱验证码 + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + // 注册用户 + user, token, err := service.RegisterUser(h.container.JWT, req.Username, req.Password, req.Email, req.Avatar) + if err != nil { + h.logger.Error("用户注册失败", zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, &types.LoginResponse{ + Token: token, + UserInfo: UserToUserInfo(user), + }) +} + +// Login 用户登录 +// @Summary 用户登录 +// @Description 用户登录获取JWT Token,支持用户名或邮箱登录 +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.LoginRequest true "登录信息(username字段支持用户名或邮箱)" +// @Success 200 {object} model.Response{data=types.LoginResponse} "登录成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Failure 401 {object} model.ErrorResponse "登录失败" +// @Router /api/v1/auth/login [post] +func (h *AuthHandler) Login(c *gin.Context) { + var req types.LoginRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + ipAddress := c.ClientIP() + userAgent := c.GetHeader("User-Agent") + + user, token, err := service.LoginUserWithRateLimit(h.container.Redis, h.container.JWT, req.Username, req.Password, ipAddress, userAgent) + if err != nil { + h.logger.Warn("用户登录失败", + zap.String("username_or_email", req.Username), + zap.String("ip", ipAddress), + zap.Error(err), + ) + RespondUnauthorized(c, err.Error()) + return + } + + RespondSuccess(c, &types.LoginResponse{ + Token: token, + UserInfo: UserToUserInfo(user), + }) +} + +// SendVerificationCode 发送验证码 +// @Summary 发送验证码 +// @Description 发送邮箱验证码(注册/重置密码/更换邮箱) +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.SendVerificationCodeRequest true "发送验证码请求" +// @Success 200 {object} model.Response "发送成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/auth/send-code [post] +func (h *AuthHandler) SendVerificationCode(c *gin.Context) { + var req types.SendVerificationCodeRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + emailService, err := h.getEmailService() + if err != nil { + RespondServerError(c, "邮件服务不可用", err) + return + } + + if err := service.SendVerificationCode(c.Request.Context(), h.container.Redis, emailService, req.Email, req.Type); err != nil { + h.logger.Error("发送验证码失败", + zap.String("email", req.Email), + zap.String("type", req.Type), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, gin.H{"message": "验证码已发送,请查收邮件"}) +} + +// ResetPassword 重置密码 +// @Summary 重置密码 +// @Description 通过邮箱验证码重置密码 +// @Tags auth +// @Accept json +// @Produce json +// @Param request body types.ResetPasswordRequest true "重置密码请求" +// @Success 200 {object} model.Response "重置成功" +// @Failure 400 {object} model.ErrorResponse "请求参数错误" +// @Router /api/v1/auth/reset-password [post] +func (h *AuthHandler) ResetPassword(c *gin.Context) { + var req types.ResetPasswordRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + // 验证验证码 + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.Email, req.VerificationCode, service.VerificationTypeResetPassword); err != nil { + h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + // 重置密码 + if err := service.ResetUserPassword(req.Email, req.NewPassword); err != nil { + h.logger.Error("重置密码失败", zap.String("email", req.Email), zap.Error(err)) + RespondServerError(c, err.Error(), nil) + return + } + + RespondSuccess(c, gin.H{"message": "密码重置成功"}) +} + +// getEmailService 获取邮件服务(暂时使用全局方式,后续可改为依赖注入) +func (h *AuthHandler) getEmailService() (*email.Service, error) { + return email.GetService() +} + diff --git a/internal/handler/helpers.go b/internal/handler/helpers.go index 3e4489a..390b162 100644 --- a/internal/handler/helpers.go +++ b/internal/handler/helpers.go @@ -4,14 +4,24 @@ import ( "carrotskin/internal/model" "carrotskin/internal/types" "net/http" + "strconv" "github.com/gin-gonic/gin" ) +// parseIntWithDefault 将字符串解析为整数,解析失败返回默认值 +func parseIntWithDefault(s string, defaultVal int) int { + val, err := strconv.Atoi(s) + if err != nil { + return defaultVal + } + return val +} + // GetUserIDFromContext 从上下文获取用户ID,如果不存在返回未授权响应 // 返回值: userID, ok (如果ok为false,已经发送了错误响应) func GetUserIDFromContext(c *gin.Context) (int64, bool) { - userID, exists := c.Get("user_id") + userIDValue, exists := c.Get("user_id") if !exists { c.JSON(http.StatusUnauthorized, model.NewErrorResponse( model.CodeUnauthorized, @@ -20,7 +30,19 @@ func GetUserIDFromContext(c *gin.Context) (int64, bool) { )) return 0, false } - return userID.(int64), true + + // 安全的类型断言 + userID, ok := userIDValue.(int64) + if !ok { + c.JSON(http.StatusInternalServerError, model.NewErrorResponse( + model.CodeServerError, + "用户ID类型错误", + nil, + )) + return 0, false + } + + return userID, true } // UserToUserInfo 将 User 模型转换为 UserInfo 响应 @@ -157,4 +179,3 @@ func RespondWithError(c *gin.Context, err error) { RespondServerError(c, msg, nil) } } - diff --git a/internal/handler/routes_di.go b/internal/handler/routes_di.go new file mode 100644 index 0000000..d022cf6 --- /dev/null +++ b/internal/handler/routes_di.go @@ -0,0 +1,191 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/middleware" + "carrotskin/internal/model" + + "github.com/gin-gonic/gin" +) + +// Handlers 集中管理所有Handler +type Handlers struct { + Auth *AuthHandler + User *UserHandler + Texture *TextureHandler + // Profile *ProfileHandler // 后续添加 + // Captcha *CaptchaHandler // 后续添加 + // Yggdrasil *YggdrasilHandler // 后续添加 +} + +// NewHandlers 创建所有Handler实例 +func NewHandlers(c *container.Container) *Handlers { + return &Handlers{ + Auth: NewAuthHandler(c), + User: NewUserHandler(c), + Texture: NewTextureHandler(c), + } +} + +// RegisterRoutesWithDI 使用依赖注入注册所有路由 +func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { + // 设置Swagger文档 + SetupSwagger(router) + + // 创建Handler实例 + h := NewHandlers(c) + + // API路由组 + v1 := router.Group("/api/v1") + { + // 认证路由(无需JWT) + registerAuthRoutes(v1, h.Auth) + + // 用户路由(需要JWT认证) + registerUserRoutes(v1, h.User) + + // 材质路由 + registerTextureRoutes(v1, h.Texture) + + // 档案路由(暂时保持原有方式) + registerProfileRoutes(v1) + + // 验证码路由(暂时保持原有方式) + registerCaptchaRoutes(v1) + + // Yggdrasil API路由组(暂时保持原有方式) + registerYggdrasilRoutes(v1) + + // 系统路由 + registerSystemRoutes(v1) + } +} + +// registerAuthRoutes 注册认证路由 +func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { + authGroup := v1.Group("/auth") + { + authGroup.POST("/register", h.Register) + authGroup.POST("/login", h.Login) + authGroup.POST("/send-code", h.SendVerificationCode) + authGroup.POST("/reset-password", h.ResetPassword) + } +} + +// registerUserRoutes 注册用户路由 +func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { + userGroup := v1.Group("/user") + userGroup.Use(middleware.AuthMiddleware()) + { + userGroup.GET("/profile", h.GetProfile) + userGroup.PUT("/profile", h.UpdateProfile) + + // 头像相关 + userGroup.POST("/avatar/upload-url", h.GenerateAvatarUploadURL) + userGroup.PUT("/avatar", h.UpdateAvatar) + + // 更换邮箱 + userGroup.POST("/change-email", h.ChangeEmail) + + // Yggdrasil密码相关 + userGroup.POST("/yggdrasil-password/reset", h.ResetYggdrasilPassword) + } +} + +// registerTextureRoutes 注册材质路由 +func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { + textureGroup := v1.Group("/texture") + { + // 公开路由(无需认证) + textureGroup.GET("", h.Search) + textureGroup.GET("/:id", h.Get) + + // 需要认证的路由 + textureAuth := textureGroup.Group("") + textureAuth.Use(middleware.AuthMiddleware()) + { + textureAuth.POST("/upload-url", h.GenerateUploadURL) + textureAuth.POST("", h.Create) + textureAuth.PUT("/:id", h.Update) + textureAuth.DELETE("/:id", h.Delete) + textureAuth.POST("/:id/favorite", h.ToggleFavorite) + textureAuth.GET("/my", h.GetUserTextures) + textureAuth.GET("/favorites", h.GetUserFavorites) + } + } +} + +// registerProfileRoutes 注册档案路由(保持原有方式,后续改造) +func registerProfileRoutes(v1 *gin.RouterGroup) { + profileGroup := v1.Group("/profile") + { + // 公开路由(无需认证) + profileGroup.GET("/:uuid", GetProfile) + + // 需要认证的路由 + profileAuth := profileGroup.Group("") + profileAuth.Use(middleware.AuthMiddleware()) + { + profileAuth.POST("/", CreateProfile) + profileAuth.GET("/", GetProfiles) + profileAuth.PUT("/:uuid", UpdateProfile) + profileAuth.DELETE("/:uuid", DeleteProfile) + profileAuth.POST("/:uuid/activate", SetActiveProfile) + } + } +} + +// registerCaptchaRoutes 注册验证码路由(保持原有方式) +func registerCaptchaRoutes(v1 *gin.RouterGroup) { + captchaGroup := v1.Group("/captcha") + { + captchaGroup.GET("/generate", Generate) + captchaGroup.POST("/verify", Verify) + } +} + +// registerYggdrasilRoutes 注册Yggdrasil API路由(保持原有方式) +func registerYggdrasilRoutes(v1 *gin.RouterGroup) { + ygg := v1.Group("/yggdrasil") + { + ygg.GET("", GetMetaData) + ygg.POST("/minecraftservices/player/certificates", GetPlayerCertificates) + authserver := ygg.Group("/authserver") + { + authserver.POST("/authenticate", Authenticate) + authserver.POST("/validate", ValidToken) + authserver.POST("/refresh", RefreshToken) + authserver.POST("/invalidate", InvalidToken) + authserver.POST("/signout", SignOut) + } + sessionServer := ygg.Group("/sessionserver") + { + sessionServer.GET("/session/minecraft/profile/:uuid", GetProfileByUUID) + sessionServer.POST("/session/minecraft/join", JoinServer) + sessionServer.GET("/session/minecraft/hasJoined", HasJoinedServer) + } + api := ygg.Group("/api") + profiles := api.Group("/profiles") + { + profiles.POST("/minecraft", GetProfilesByName) + } + } +} + +// registerSystemRoutes 注册系统路由 +func registerSystemRoutes(v1 *gin.RouterGroup) { + system := v1.Group("/system") + { + system.GET("/config", func(c *gin.Context) { + // TODO: 实现从数据库读取系统配置 + c.JSON(200, model.NewSuccessResponse(gin.H{ + "site_name": "CarrotSkin", + "site_description": "A Minecraft Skin Station", + "registration_enabled": true, + "max_textures_per_user": 100, + "max_profiles_per_user": 5, + })) + }) + } +} + diff --git a/internal/handler/texture_handler.go b/internal/handler/texture_handler.go index c7a5184..a139f38 100644 --- a/internal/handler/texture_handler.go +++ b/internal/handler/texture_handler.go @@ -160,8 +160,8 @@ func SearchTextures(c *gin.Context) { textureTypeStr := c.Query("type") publicOnly := c.Query("public_only") == "true" - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) var textureType model.TextureType switch textureTypeStr { @@ -314,8 +314,8 @@ func GetUserTextures(c *gin.Context) { return } - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) textures, total, err := service.GetUserTextures(database.MustGetDB(), userID, page, pageSize) if err != nil { @@ -344,8 +344,8 @@ func GetUserFavorites(c *gin.Context) { return } - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) textures, total, err := service.GetUserTextureFavorites(database.MustGetDB(), userID, page, pageSize) if err != nil { diff --git a/internal/handler/texture_handler_di.go b/internal/handler/texture_handler_di.go new file mode 100644 index 0000000..8233184 --- /dev/null +++ b/internal/handler/texture_handler_di.go @@ -0,0 +1,284 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/model" + "carrotskin/internal/service" + "carrotskin/internal/types" + "strconv" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// TextureHandler 材质处理器(依赖注入版本) +type TextureHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewTextureHandler 创建TextureHandler实例 +func NewTextureHandler(c *container.Container) *TextureHandler { + return &TextureHandler{ + container: c, + logger: c.Logger, + } +} + +// GenerateUploadURL 生成材质上传URL +func (h *TextureHandler) GenerateUploadURL(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.GenerateTextureUploadURLRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + + result, err := service.GenerateTextureUploadURL( + c.Request.Context(), + h.container.Storage, + userID, + req.FileName, + string(req.TextureType), + ) + if err != nil { + h.logger.Error("生成材质上传URL失败", + zap.Int64("user_id", userID), + zap.String("file_name", req.FileName), + zap.String("texture_type", string(req.TextureType)), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, &types.GenerateTextureUploadURLResponse{ + PostURL: result.PostURL, + FormData: result.FormData, + TextureURL: result.FileURL, + ExpiresIn: 900, + }) +} + +// Create 创建材质记录 +func (h *TextureHandler) Create(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.CreateTextureRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + maxTextures := service.GetMaxTexturesPerUser() + if err := service.CheckTextureUploadLimit(h.container.DB, userID, maxTextures); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + + texture, err := service.CreateTexture(h.container.DB, + userID, + req.Name, + req.Description, + string(req.Type), + req.URL, + req.Hash, + req.Size, + req.IsPublic, + req.IsSlim, + ) + if err != nil { + h.logger.Error("创建材质失败", + zap.Int64("user_id", userID), + zap.String("name", req.Name), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, TextureToTextureInfo(texture)) +} + +// Get 获取材质详情 +func (h *TextureHandler) Get(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + texture, err := service.GetTextureByID(h.container.DB, id) + if err != nil { + RespondNotFound(c, err.Error()) + return + } + + RespondSuccess(c, TextureToTextureInfo(texture)) +} + +// Search 搜索材质 +func (h *TextureHandler) Search(c *gin.Context) { + keyword := c.Query("keyword") + textureTypeStr := c.Query("type") + publicOnly := c.Query("public_only") == "true" + + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) + + var textureType model.TextureType + switch textureTypeStr { + case "SKIN": + textureType = model.TextureTypeSkin + case "CAPE": + textureType = model.TextureTypeCape + } + + textures, total, err := service.SearchTextures(h.container.DB, keyword, textureType, publicOnly, page, pageSize) + if err != nil { + h.logger.Error("搜索材质失败", zap.String("keyword", keyword), zap.Error(err)) + RespondServerError(c, "搜索材质失败", err) + return + } + + c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) +} + +// Update 更新材质 +func (h *TextureHandler) Update(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + var req types.UpdateTextureRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + texture, err := service.UpdateTexture(h.container.DB, textureID, userID, req.Name, req.Description, req.IsPublic) + if err != nil { + h.logger.Error("更新材质失败", + zap.Int64("user_id", userID), + zap.Int64("texture_id", textureID), + zap.Error(err), + ) + RespondForbidden(c, err.Error()) + return + } + + RespondSuccess(c, TextureToTextureInfo(texture)) +} + +// Delete 删除材质 +func (h *TextureHandler) Delete(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + if err := service.DeleteTexture(h.container.DB, textureID, userID); err != nil { + h.logger.Error("删除材质失败", + zap.Int64("user_id", userID), + zap.Int64("texture_id", textureID), + zap.Error(err), + ) + RespondForbidden(c, err.Error()) + return + } + + RespondSuccess(c, nil) +} + +// ToggleFavorite 切换收藏状态 +func (h *TextureHandler) ToggleFavorite(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + textureID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + RespondBadRequest(c, "无效的材质ID", err) + return + } + + isFavorited, err := service.ToggleTextureFavorite(h.container.DB, userID, textureID) + if err != nil { + h.logger.Error("切换收藏状态失败", + zap.Int64("user_id", userID), + zap.Int64("texture_id", textureID), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, map[string]bool{"is_favorited": isFavorited}) +} + +// GetUserTextures 获取用户上传的材质列表 +func (h *TextureHandler) GetUserTextures(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) + + textures, total, err := service.GetUserTextures(h.container.DB, userID, page, pageSize) + if err != nil { + h.logger.Error("获取用户材质列表失败", zap.Int64("user_id", userID), zap.Error(err)) + RespondServerError(c, "获取材质列表失败", err) + return + } + + c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) +} + +// GetUserFavorites 获取用户收藏的材质列表 +func (h *TextureHandler) GetUserFavorites(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + page := parseIntWithDefault(c.DefaultQuery("page", "1"), 1) + pageSize := parseIntWithDefault(c.DefaultQuery("page_size", "20"), 20) + + textures, total, err := service.GetUserTextureFavorites(h.container.DB, userID, page, pageSize) + if err != nil { + h.logger.Error("获取用户收藏列表失败", zap.Int64("user_id", userID), zap.Error(err)) + RespondServerError(c, "获取收藏列表失败", err) + return + } + + c.JSON(200, model.NewPaginationResponse(TexturesToTextureInfos(textures), total, page, pageSize)) +} + diff --git a/internal/handler/user_handler_di.go b/internal/handler/user_handler_di.go new file mode 100644 index 0000000..91e8a5a --- /dev/null +++ b/internal/handler/user_handler_di.go @@ -0,0 +1,233 @@ +package handler + +import ( + "carrotskin/internal/container" + "carrotskin/internal/service" + "carrotskin/internal/types" + + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +// UserHandler 用户处理器(依赖注入版本) +type UserHandler struct { + container *container.Container + logger *zap.Logger +} + +// NewUserHandler 创建UserHandler实例 +func NewUserHandler(c *container.Container) *UserHandler { + return &UserHandler{ + container: c, + logger: c.Logger, + } +} + +// GetProfile 获取用户信息 +func (h *UserHandler) GetProfile(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + h.logger.Error("获取用户信息失败", + zap.Int64("user_id", userID), + zap.Error(err), + ) + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(user)) +} + +// UpdateProfile 更新用户信息 +func (h *UserHandler) UpdateProfile(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.UpdateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + RespondNotFound(c, "用户不存在") + return + } + + // 处理密码修改 + if req.NewPassword != "" { + if req.OldPassword == "" { + RespondBadRequest(c, "修改密码需要提供原密码", nil) + return + } + + if err := service.ChangeUserPassword(userID, req.OldPassword, req.NewPassword); err != nil { + h.logger.Error("修改密码失败", zap.Int64("user_id", userID), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + h.logger.Info("用户修改密码成功", zap.Int64("user_id", userID)) + } + + // 更新头像 + if req.Avatar != "" { + if err := service.ValidateAvatarURL(req.Avatar); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + user.Avatar = req.Avatar + if err := service.UpdateUserInfo(user); err != nil { + h.logger.Error("更新用户信息失败", zap.Int64("user_id", user.ID), zap.Error(err)) + RespondServerError(c, "更新失败", err) + return + } + } + + // 重新获取更新后的用户信息 + updatedUser, err := service.GetUserByID(userID) + if err != nil || updatedUser == nil { + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(updatedUser)) +} + +// GenerateAvatarUploadURL 生成头像上传URL +func (h *UserHandler) GenerateAvatarUploadURL(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.GenerateAvatarUploadURLRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + if h.container.Storage == nil { + RespondServerError(c, "存储服务不可用", nil) + return + } + + result, err := service.GenerateAvatarUploadURL(c.Request.Context(), h.container.Storage, userID, req.FileName) + if err != nil { + h.logger.Error("生成头像上传URL失败", + zap.Int64("user_id", userID), + zap.String("file_name", req.FileName), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + RespondSuccess(c, &types.GenerateAvatarUploadURLResponse{ + PostURL: result.PostURL, + FormData: result.FormData, + AvatarURL: result.FileURL, + ExpiresIn: 900, + }) +} + +// UpdateAvatar 更新头像URL +func (h *UserHandler) UpdateAvatar(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + avatarURL := c.Query("avatar_url") + if avatarURL == "" { + RespondBadRequest(c, "头像URL不能为空", nil) + return + } + + if err := service.ValidateAvatarURL(avatarURL); err != nil { + RespondBadRequest(c, err.Error(), nil) + return + } + + if err := service.UpdateUserAvatar(userID, avatarURL); err != nil { + h.logger.Error("更新头像失败", + zap.Int64("user_id", userID), + zap.String("avatar_url", avatarURL), + zap.Error(err), + ) + RespondServerError(c, "更新头像失败", err) + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(user)) +} + +// ChangeEmail 更换邮箱 +func (h *UserHandler) ChangeEmail(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + var req types.ChangeEmailRequest + if err := c.ShouldBindJSON(&req); err != nil { + RespondBadRequest(c, "请求参数错误", err) + return + } + + if err := service.VerifyCode(c.Request.Context(), h.container.Redis, req.NewEmail, req.VerificationCode, service.VerificationTypeChangeEmail); err != nil { + h.logger.Warn("验证码验证失败", zap.String("new_email", req.NewEmail), zap.Error(err)) + RespondBadRequest(c, err.Error(), nil) + return + } + + if err := service.ChangeUserEmail(userID, req.NewEmail); err != nil { + h.logger.Error("更换邮箱失败", + zap.Int64("user_id", userID), + zap.String("new_email", req.NewEmail), + zap.Error(err), + ) + RespondBadRequest(c, err.Error(), nil) + return + } + + user, err := service.GetUserByID(userID) + if err != nil || user == nil { + RespondNotFound(c, "用户不存在") + return + } + + RespondSuccess(c, UserToUserInfo(user)) +} + +// ResetYggdrasilPassword 重置Yggdrasil密码 +func (h *UserHandler) ResetYggdrasilPassword(c *gin.Context) { + userID, ok := GetUserIDFromContext(c) + if !ok { + return + } + + newPassword, err := service.ResetYggdrasilPassword(h.container.DB, userID) + if err != nil { + h.logger.Error("重置Yggdrasil密码失败", zap.Error(err), zap.Int64("userId", userID)) + RespondServerError(c, "重置Yggdrasil密码失败", nil) + return + } + + h.logger.Info("Yggdrasil密码重置成功", zap.Int64("userId", userID)) + RespondSuccess(c, gin.H{"password": newPassword}) +} diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index aaf1847..a806368 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -1,16 +1,48 @@ package middleware import ( + "carrotskin/pkg/config" + "github.com/gin-gonic/gin" ) // CORS 跨域中间件 func CORS() gin.HandlerFunc { + // 获取配置,如果配置未初始化则使用默认值 + var allowedOrigins []string + if cfg, err := config.GetConfig(); err == nil { + allowedOrigins = cfg.Security.AllowedOrigins + } else { + // 默认允许所有来源(向后兼容) + allowedOrigins = []string{"*"} + } + return gin.HandlerFunc(func(c *gin.Context) { - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Credentials", "true") + origin := c.GetHeader("Origin") + + // 检查是否允许该来源 + allowOrigin := "*" + if len(allowedOrigins) > 0 && allowedOrigins[0] != "*" { + allowOrigin = "" + for _, allowed := range allowedOrigins { + if allowed == origin || allowed == "*" { + allowOrigin = origin + break + } + } + } + + if allowOrigin != "" { + c.Header("Access-Control-Allow-Origin", allowOrigin) + // 只有在非通配符模式下才允许credentials + if allowOrigin != "*" { + c.Header("Access-Control-Allow-Credentials", "true") + } + } + c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") + c.Header("Access-Control-Max-Age", "86400") // 缓存预检请求结果24小时 if c.Request.Method == "OPTIONS" { c.AbortWithStatus(204) diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go index a07f6c7..833e76d 100644 --- a/internal/middleware/cors_test.go +++ b/internal/middleware/cors_test.go @@ -24,10 +24,11 @@ func TestCORS_Headers(t *testing.T) { router.ServeHTTP(w, req) // 验证CORS响应头 + // 注意:当 Access-Control-Allow-Origin 为 "*" 时,根据CORS规范, + // 不应该设置 Access-Control-Allow-Credentials 为 "true" expectedHeaders := map[string]string{ - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Credentials": "true", - "Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE", } for header, expectedValue := range expectedHeaders { @@ -37,6 +38,11 @@ func TestCORS_Headers(t *testing.T) { } } + // 验证在通配符模式下不设置Credentials(这是正确的安全行为) + if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" { + t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials) + } + // 验证Access-Control-Allow-Headers包含必要字段 allowHeaders := w.Header().Get("Access-Control-Allow-Headers") if allowHeaders == "" { @@ -117,6 +123,30 @@ func TestCORS_AllowHeaders(t *testing.T) { } } +// TestCORS_WithSpecificOrigin 测试配置了具体origin时的CORS行为 +func TestCORS_WithSpecificOrigin(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 注意:此测试验证的是在配置了具体allowed origins时的行为 + // 在没有配置初始化的情况下,默认使用通配符模式 + router := gin.New() + router.Use(CORS()) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "http://example.com") + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + // 默认配置下使用通配符,所以不应该设置credentials + if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" { + t.Logf("当前模式下 Access-Control-Allow-Credentials = %q (通配符模式不设置)", credentials) + } +} + // 辅助函数:检查字符串是否包含子字符串(简单实现) func contains(s, substr string) bool { if len(substr) == 0 { diff --git a/internal/middleware/recovery.go b/internal/middleware/recovery.go index 8277182..3293f42 100644 --- a/internal/middleware/recovery.go +++ b/internal/middleware/recovery.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "runtime/debug" @@ -11,16 +12,26 @@ import ( // Recovery 恢复中间件 func Recovery(logger *zap.Logger) gin.HandlerFunc { return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { - if err, ok := recovered.(string); ok { - logger.Error("服务器恐慌", - zap.String("error", err), - zap.String("path", c.Request.URL.Path), - zap.String("method", c.Request.Method), - zap.String("ip", c.ClientIP()), - zap.String("stack", string(debug.Stack())), - ) + // 将任意类型的panic转换为字符串 + var errMsg string + switch v := recovered.(type) { + case string: + errMsg = v + case error: + errMsg = v.Error() + default: + errMsg = fmt.Sprintf("%v", v) } + logger.Error("服务器恐慌", + zap.String("error", errMsg), + zap.String("path", c.Request.URL.Path), + zap.String("method", c.Request.Method), + zap.String("ip", c.ClientIP()), + zap.String("user_agent", c.GetHeader("User-Agent")), + zap.String("stack", string(debug.Stack())), + ) + c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, "message": "服务器内部错误", diff --git a/internal/model/response.go b/internal/model/response.go index e76dac0..26865ed 100644 --- a/internal/model/response.go +++ b/internal/model/response.go @@ -1,10 +1,12 @@ package model +import "os" + // Response 通用API响应结构 type Response struct { - Code int `json:"code"` // 业务状态码 - Message string `json:"message"` // 响应消息 - Data interface{} `json:"data,omitempty"` // 响应数据 + Code int `json:"code"` // 业务状态码 + Message string `json:"message"` // 响应消息 + Data interface{} `json:"data,omitempty"` // 响应数据 } // PaginationResponse 分页响应结构 @@ -12,9 +14,9 @@ type PaginationResponse struct { Code int `json:"code"` Message string `json:"message"` Data interface{} `json:"data"` - Total int64 `json:"total"` // 总记录数 - Page int `json:"page"` // 当前页码 - PerPage int `json:"per_page"` // 每页数量 + Total int64 `json:"total"` // 总记录数 + Page int `json:"page"` // 当前页码 + PerPage int `json:"per_page"` // 每页数量 } // ErrorResponse 错误响应 @@ -26,14 +28,14 @@ type ErrorResponse struct { // 常用状态码 const ( - CodeSuccess = 200 // 成功 - CodeCreated = 201 // 创建成功 - CodeBadRequest = 400 // 请求参数错误 - CodeUnauthorized = 401 // 未授权 - CodeForbidden = 403 // 禁止访问 - CodeNotFound = 404 // 资源不存在 - CodeConflict = 409 // 资源冲突 - CodeServerError = 500 // 服务器错误 + CodeSuccess = 200 // 成功 + CodeCreated = 201 // 创建成功 + CodeBadRequest = 400 // 请求参数错误 + CodeUnauthorized = 401 // 未授权 + CodeForbidden = 403 // 禁止访问 + CodeNotFound = 404 // 资源不存在 + CodeConflict = 409 // 资源冲突 + CodeServerError = 500 // 服务器错误 ) // 常用响应消息 @@ -61,17 +63,26 @@ func NewSuccessResponse(data interface{}) *Response { } // NewErrorResponse 创建错误响应 +// 注意:err参数仅在开发环境下显示,生产环境不应暴露详细错误信息 func NewErrorResponse(code int, message string, err error) *ErrorResponse { resp := &ErrorResponse{ Code: code, Message: message, } - if err != nil { + // 仅在非生产环境下返回详细错误信息 + // 可以通过环境变量 ENVIRONMENT 控制 + if err != nil && !isProductionEnvironment() { resp.Error = err.Error() } return resp } +// isProductionEnvironment 检查是否为生产环境 +func isProductionEnvironment() bool { + env := os.Getenv("ENVIRONMENT") + return env == "production" || env == "prod" +} + // NewPaginationResponse 创建分页响应 func NewPaginationResponse(data interface{}, total int64, page, perPage int) *PaginationResponse { return &PaginationResponse{ diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go new file mode 100644 index 0000000..f72ca88 --- /dev/null +++ b/internal/repository/interfaces.go @@ -0,0 +1,85 @@ +package repository + +import ( + "carrotskin/internal/model" +) + +// UserRepository 用户仓储接口 +type UserRepository interface { + Create(user *model.User) error + FindByID(id int64) (*model.User, error) + FindByUsername(username string) (*model.User, error) + FindByEmail(email string) (*model.User, error) + Update(user *model.User) error + UpdateFields(id int64, fields map[string]interface{}) error + Delete(id int64) error + CreateLoginLog(log *model.UserLoginLog) error + CreatePointLog(log *model.UserPointLog) error + UpdatePoints(userID int64, amount int, changeType, reason string) error +} + +// ProfileRepository 档案仓储接口 +type ProfileRepository interface { + Create(profile *model.Profile) error + FindByUUID(uuid string) (*model.Profile, error) + FindByName(name string) (*model.Profile, error) + FindByUserID(userID int64) ([]*model.Profile, error) + Update(profile *model.Profile) error + UpdateFields(uuid string, updates map[string]interface{}) error + Delete(uuid string) error + CountByUserID(userID int64) (int64, error) + SetActive(uuid string, userID int64) error + UpdateLastUsedAt(uuid string) error + GetByNames(names []string) ([]*model.Profile, error) + GetKeyPair(profileId string) (*model.KeyPair, error) + UpdateKeyPair(profileId string, keyPair *model.KeyPair) error +} + +// TextureRepository 材质仓储接口 +type TextureRepository interface { + Create(texture *model.Texture) error + FindByID(id int64) (*model.Texture, error) + FindByHash(hash string) (*model.Texture, error) + FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) + Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) + Update(texture *model.Texture) error + UpdateFields(id int64, fields map[string]interface{}) error + Delete(id int64) error + IncrementDownloadCount(id int64) error + IncrementFavoriteCount(id int64) error + DecrementFavoriteCount(id int64) error + CreateDownloadLog(log *model.TextureDownloadLog) error + IsFavorited(userID, textureID int64) (bool, error) + AddFavorite(userID, textureID int64) error + RemoveFavorite(userID, textureID int64) error + GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) + CountByUploaderID(uploaderID int64) (int64, error) +} + +// TokenRepository 令牌仓储接口 +type TokenRepository interface { + Create(token *model.Token) error + FindByAccessToken(accessToken string) (*model.Token, error) + GetByUserID(userId int64) ([]*model.Token, error) + GetUUIDByAccessToken(accessToken string) (string, error) + GetUserIDByAccessToken(accessToken string) (int64, error) + DeleteByAccessToken(accessToken string) error + DeleteByUserID(userId int64) error + BatchDelete(accessTokens []string) (int64, error) +} + +// SystemConfigRepository 系统配置仓储接口 +type SystemConfigRepository interface { + GetByKey(key string) (*model.SystemConfig, error) + GetPublic() ([]model.SystemConfig, error) + GetAll() ([]model.SystemConfig, error) + Update(config *model.SystemConfig) error + UpdateValue(key, value string) error +} + +// YggdrasilRepository Yggdrasil仓储接口 +type YggdrasilRepository interface { + GetPasswordByID(id int64) (string, error) + ResetPassword(id int64, password string) error +} + diff --git a/internal/repository/profile_repository_impl.go b/internal/repository/profile_repository_impl.go new file mode 100644 index 0000000..ebe3fdb --- /dev/null +++ b/internal/repository/profile_repository_impl.go @@ -0,0 +1,149 @@ +package repository + +import ( + "carrotskin/internal/model" + "context" + "errors" + "fmt" + + "gorm.io/gorm" +) + +// profileRepositoryImpl ProfileRepository的实现 +type profileRepositoryImpl struct { + db *gorm.DB +} + +// NewProfileRepository 创建ProfileRepository实例 +func NewProfileRepository(db *gorm.DB) ProfileRepository { + return &profileRepositoryImpl{db: db} +} + +func (r *profileRepositoryImpl) Create(profile *model.Profile) error { + return r.db.Create(profile).Error +} + +func (r *profileRepositoryImpl) FindByUUID(uuid string) (*model.Profile, error) { + var profile model.Profile + err := r.db.Where("uuid = ?", uuid). + Preload("Skin"). + Preload("Cape"). + First(&profile).Error + if err != nil { + return nil, err + } + return &profile, nil +} + +func (r *profileRepositoryImpl) FindByName(name string) (*model.Profile, error) { + var profile model.Profile + err := r.db.Where("name = ?", name).First(&profile).Error + if err != nil { + return nil, err + } + return &profile, nil +} + +func (r *profileRepositoryImpl) FindByUserID(userID int64) ([]*model.Profile, error) { + var profiles []*model.Profile + err := r.db.Where("user_id = ?", userID). + Preload("Skin"). + Preload("Cape"). + Order("created_at DESC"). + Find(&profiles).Error + return profiles, err +} + +func (r *profileRepositoryImpl) Update(profile *model.Profile) error { + return r.db.Save(profile).Error +} + +func (r *profileRepositoryImpl) UpdateFields(uuid string, updates map[string]interface{}) error { + return r.db.Model(&model.Profile{}). + Where("uuid = ?", uuid). + Updates(updates).Error +} + +func (r *profileRepositoryImpl) Delete(uuid string) error { + return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error +} + +func (r *profileRepositoryImpl) CountByUserID(userID int64) (int64, error) { + var count int64 + err := r.db.Model(&model.Profile{}). + Where("user_id = ?", userID). + Count(&count).Error + return count, err +} + +func (r *profileRepositoryImpl) SetActive(uuid string, userID int64) error { + return r.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&model.Profile{}). + Where("user_id = ?", userID). + Update("is_active", false).Error; err != nil { + return err + } + + return tx.Model(&model.Profile{}). + Where("uuid = ? AND user_id = ?", uuid, userID). + Update("is_active", true).Error + }) +} + +func (r *profileRepositoryImpl) UpdateLastUsedAt(uuid string) error { + return r.db.Model(&model.Profile{}). + Where("uuid = ?", uuid). + Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error +} + +func (r *profileRepositoryImpl) GetByNames(names []string) ([]*model.Profile, error) { + var profiles []*model.Profile + err := r.db.Where("name in (?)", names).Find(&profiles).Error + return profiles, err +} + +func (r *profileRepositoryImpl) GetKeyPair(profileId string) (*model.KeyPair, error) { + if profileId == "" { + return nil, errors.New("参数不能为空") + } + + var profile model.Profile + result := r.db.WithContext(context.Background()). + Select("key_pair"). + Where("id = ?", profileId). + First(&profile) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, errors.New("key pair未找到") + } + return nil, fmt.Errorf("获取key pair失败: %w", result.Error) + } + + return &model.KeyPair{}, nil +} + +func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { + if profileId == "" { + return errors.New("profileId 不能为空") + } + if keyPair == nil { + return errors.New("keyPair 不能为 nil") + } + + return r.db.Transaction(func(tx *gorm.DB) error { + result := tx.WithContext(context.Background()). + Table("profiles"). + Where("id = ?", profileId). + UpdateColumns(map[string]interface{}{ + "private_key": keyPair.PrivateKey, + "public_key": keyPair.PublicKey, + }) + + if result.Error != nil { + return fmt.Errorf("更新 keyPair 失败: %w", result.Error) + } + return nil + }) +} + diff --git a/internal/repository/system_config_repository_impl.go b/internal/repository/system_config_repository_impl.go new file mode 100644 index 0000000..2bb5844 --- /dev/null +++ b/internal/repository/system_config_repository_impl.go @@ -0,0 +1,44 @@ +package repository + +import ( + "carrotskin/internal/model" + + "gorm.io/gorm" +) + +// systemConfigRepositoryImpl SystemConfigRepository的实现 +type systemConfigRepositoryImpl struct { + db *gorm.DB +} + +// NewSystemConfigRepository 创建SystemConfigRepository实例 +func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository { + return &systemConfigRepositoryImpl{db: db} +} + +func (r *systemConfigRepositoryImpl) GetByKey(key string) (*model.SystemConfig, error) { + var config model.SystemConfig + err := r.db.Where("key = ?", key).First(&config).Error + return handleNotFoundResult(&config, err) +} + +func (r *systemConfigRepositoryImpl) GetPublic() ([]model.SystemConfig, error) { + var configs []model.SystemConfig + err := r.db.Where("is_public = ?", true).Find(&configs).Error + return configs, err +} + +func (r *systemConfigRepositoryImpl) GetAll() ([]model.SystemConfig, error) { + var configs []model.SystemConfig + err := r.db.Find(&configs).Error + return configs, err +} + +func (r *systemConfigRepositoryImpl) Update(config *model.SystemConfig) error { + return r.db.Save(config).Error +} + +func (r *systemConfigRepositoryImpl) UpdateValue(key, value string) error { + return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error +} + diff --git a/internal/repository/texture_repository_impl.go b/internal/repository/texture_repository_impl.go new file mode 100644 index 0000000..c6a2971 --- /dev/null +++ b/internal/repository/texture_repository_impl.go @@ -0,0 +1,175 @@ +package repository + +import ( + "carrotskin/internal/model" + + "gorm.io/gorm" +) + +// textureRepositoryImpl TextureRepository的实现 +type textureRepositoryImpl struct { + db *gorm.DB +} + +// NewTextureRepository 创建TextureRepository实例 +func NewTextureRepository(db *gorm.DB) TextureRepository { + return &textureRepositoryImpl{db: db} +} + +func (r *textureRepositoryImpl) Create(texture *model.Texture) error { + return r.db.Create(texture).Error +} + +func (r *textureRepositoryImpl) FindByID(id int64) (*model.Texture, error) { + var texture model.Texture + err := r.db.Preload("Uploader").First(&texture, id).Error + return handleNotFoundResult(&texture, err) +} + +func (r *textureRepositoryImpl) FindByHash(hash string) (*model.Texture, error) { + var texture model.Texture + err := r.db.Where("hash = ?", hash).First(&texture).Error + return handleNotFoundResult(&texture, err) +} + +func (r *textureRepositoryImpl) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { + var textures []*model.Texture + var total int64 + + query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Scopes(Paginate(page, pageSize)). + Preload("Uploader"). + Order("created_at DESC"). + Find(&textures).Error + + if err != nil { + return nil, 0, err + } + + return textures, total, nil +} + +func (r *textureRepositoryImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { + var textures []*model.Texture + var total int64 + + query := r.db.Model(&model.Texture{}).Where("status = 1") + + if publicOnly { + query = query.Where("is_public = ?", true) + } + if textureType != "" { + query = query.Where("type = ?", textureType) + } + if keyword != "" { + query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%") + } + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Scopes(Paginate(page, pageSize)). + Preload("Uploader"). + Order("created_at DESC"). + Find(&textures).Error + + if err != nil { + return nil, 0, err + } + + return textures, total, nil +} + +func (r *textureRepositoryImpl) Update(texture *model.Texture) error { + return r.db.Save(texture).Error +} + +func (r *textureRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error +} + +func (r *textureRepositoryImpl) Delete(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error +} + +func (r *textureRepositoryImpl) IncrementDownloadCount(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id). + UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error +} + +func (r *textureRepositoryImpl) IncrementFavoriteCount(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id). + UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error +} + +func (r *textureRepositoryImpl) DecrementFavoriteCount(id int64) error { + return r.db.Model(&model.Texture{}).Where("id = ?", id). + UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error +} + +func (r *textureRepositoryImpl) CreateDownloadLog(log *model.TextureDownloadLog) error { + return r.db.Create(log).Error +} + +func (r *textureRepositoryImpl) IsFavorited(userID, textureID int64) (bool, error) { + var count int64 + err := r.db.Model(&model.UserTextureFavorite{}). + Where("user_id = ? AND texture_id = ?", userID, textureID). + Count(&count).Error + return count > 0, err +} + +func (r *textureRepositoryImpl) AddFavorite(userID, textureID int64) error { + favorite := &model.UserTextureFavorite{ + UserID: userID, + TextureID: textureID, + } + return r.db.Create(favorite).Error +} + +func (r *textureRepositoryImpl) RemoveFavorite(userID, textureID int64) error { + return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID). + Delete(&model.UserTextureFavorite{}).Error +} + +func (r *textureRepositoryImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { + var textures []*model.Texture + var total int64 + + subQuery := r.db.Model(&model.UserTextureFavorite{}). + Select("texture_id"). + Where("user_id = ?", userID) + + query := r.db.Model(&model.Texture{}). + Where("id IN (?) AND status = 1", subQuery) + + if err := query.Count(&total).Error; err != nil { + return nil, 0, err + } + + err := query.Scopes(Paginate(page, pageSize)). + Preload("Uploader"). + Order("created_at DESC"). + Find(&textures).Error + + if err != nil { + return nil, 0, err + } + + return textures, total, nil +} + +func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, error) { + var count int64 + err := r.db.Model(&model.Texture{}). + Where("uploader_id = ? AND status != -1", uploaderID). + Count(&count).Error + return count, err +} + diff --git a/internal/repository/token_repository_impl.go b/internal/repository/token_repository_impl.go new file mode 100644 index 0000000..e4c94e1 --- /dev/null +++ b/internal/repository/token_repository_impl.go @@ -0,0 +1,71 @@ +package repository + +import ( + "carrotskin/internal/model" + + "gorm.io/gorm" +) + +// tokenRepositoryImpl TokenRepository的实现 +type tokenRepositoryImpl struct { + db *gorm.DB +} + +// NewTokenRepository 创建TokenRepository实例 +func NewTokenRepository(db *gorm.DB) TokenRepository { + return &tokenRepositoryImpl{db: db} +} + +func (r *tokenRepositoryImpl) Create(token *model.Token) error { + return r.db.Create(token).Error +} + +func (r *tokenRepositoryImpl) FindByAccessToken(accessToken string) (*model.Token, error) { + var token model.Token + err := r.db.Where("access_token = ?", accessToken).First(&token).Error + if err != nil { + return nil, err + } + return &token, nil +} + +func (r *tokenRepositoryImpl) GetByUserID(userId int64) ([]*model.Token, error) { + var tokens []*model.Token + err := r.db.Where("user_id = ?", userId).Find(&tokens).Error + return tokens, err +} + +func (r *tokenRepositoryImpl) GetUUIDByAccessToken(accessToken string) (string, error) { + var token model.Token + err := r.db.Where("access_token = ?", accessToken).First(&token).Error + if err != nil { + return "", err + } + return token.ProfileId, nil +} + +func (r *tokenRepositoryImpl) GetUserIDByAccessToken(accessToken string) (int64, error) { + var token model.Token + err := r.db.Where("access_token = ?", accessToken).First(&token).Error + if err != nil { + return 0, err + } + return token.UserID, nil +} + +func (r *tokenRepositoryImpl) DeleteByAccessToken(accessToken string) error { + return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error +} + +func (r *tokenRepositoryImpl) DeleteByUserID(userId int64) error { + return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error +} + +func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) { + if len(accessTokens) == 0 { + return 0, nil + } + result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{}) + return result.RowsAffected, result.Error +} + diff --git a/internal/repository/user_repository_impl.go b/internal/repository/user_repository_impl.go new file mode 100644 index 0000000..57ec4c8 --- /dev/null +++ b/internal/repository/user_repository_impl.go @@ -0,0 +1,103 @@ +package repository + +import ( + "carrotskin/internal/model" + "errors" + + "gorm.io/gorm" +) + +// userRepositoryImpl UserRepository的实现 +type userRepositoryImpl struct { + db *gorm.DB +} + +// NewUserRepository 创建UserRepository实例 +func NewUserRepository(db *gorm.DB) UserRepository { + return &userRepositoryImpl{db: db} +} + +func (r *userRepositoryImpl) Create(user *model.User) error { + return r.db.Create(user).Error +} + +func (r *userRepositoryImpl) FindByID(id int64) (*model.User, error) { + var user model.User + err := r.db.Where("id = ? AND status != -1", id).First(&user).Error + return handleNotFoundResult(&user, err) +} + +func (r *userRepositoryImpl) FindByUsername(username string) (*model.User, error) { + var user model.User + err := r.db.Where("username = ? AND status != -1", username).First(&user).Error + return handleNotFoundResult(&user, err) +} + +func (r *userRepositoryImpl) FindByEmail(email string) (*model.User, error) { + var user model.User + err := r.db.Where("email = ? AND status != -1", email).First(&user).Error + return handleNotFoundResult(&user, err) +} + +func (r *userRepositoryImpl) Update(user *model.User) error { + return r.db.Save(user).Error +} + +func (r *userRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error { + return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error +} + +func (r *userRepositoryImpl) Delete(id int64) error { + return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error +} + +func (r *userRepositoryImpl) CreateLoginLog(log *model.UserLoginLog) error { + return r.db.Create(log).Error +} + +func (r *userRepositoryImpl) CreatePointLog(log *model.UserPointLog) error { + return r.db.Create(log).Error +} + +func (r *userRepositoryImpl) UpdatePoints(userID int64, amount int, changeType, reason string) error { + return r.db.Transaction(func(tx *gorm.DB) error { + var user model.User + if err := tx.Where("id = ?", userID).First(&user).Error; err != nil { + return err + } + + balanceBefore := user.Points + balanceAfter := balanceBefore + amount + + if balanceAfter < 0 { + return errors.New("积分不足") + } + + if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil { + return err + } + + log := &model.UserPointLog{ + UserID: userID, + ChangeType: changeType, + Amount: amount, + BalanceBefore: balanceBefore, + BalanceAfter: balanceAfter, + Reason: reason, + } + + return tx.Create(log).Error + }) +} + +// handleNotFoundResult 处理记录未找到的情况 +func handleNotFoundResult[T any](result *T, err error) (*T, error) { + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, err + } + return result, nil +} + diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 4a98ca7..249a341 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -4,10 +4,12 @@ import ( "carrotskin/internal/model" "carrotskin/internal/repository" "carrotskin/pkg/auth" + "carrotskin/pkg/config" "carrotskin/pkg/redis" "context" "errors" "fmt" + "net/url" "strings" "time" ) @@ -286,24 +288,69 @@ func ValidateAvatarURL(avatarURL string) error { return nil } - // 允许的域名列表 - allowedDomains := []string{ - "rustfs.example.com", - "localhost", - "127.0.0.1", - } - - for _, domain := range allowedDomains { - if strings.Contains(avatarURL, domain) { - return nil - } - } - + // 允许相对路径 if strings.HasPrefix(avatarURL, "/") { return nil } - return errors.New("头像URL不在允许的域名列表中") + return ValidateURLDomain(avatarURL) +} + +// ValidateURLDomain 验证URL的域名是否在允许列表中 +func ValidateURLDomain(rawURL string) error { + // 解析URL + parsedURL, err := url.Parse(rawURL) + if err != nil { + return errors.New("无效的URL格式") + } + + // 必须是HTTP或HTTPS协议 + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return errors.New("URL必须使用http或https协议") + } + + // 获取主机名(不包含端口) + host := parsedURL.Hostname() + if host == "" { + return errors.New("URL缺少主机名") + } + + // 从配置获取允许的域名列表 + cfg, err := config.GetConfig() + if err != nil { + // 如果配置获取失败,使用默认的安全域名列表 + allowedDomains := []string{"localhost", "127.0.0.1"} + return checkDomainAllowed(host, allowedDomains) + } + + return checkDomainAllowed(host, cfg.Security.AllowedDomains) +} + +// checkDomainAllowed 检查域名是否在允许列表中 +func checkDomainAllowed(host string, allowedDomains []string) error { + host = strings.ToLower(host) + + for _, allowed := range allowedDomains { + allowed = strings.ToLower(strings.TrimSpace(allowed)) + if allowed == "" { + continue + } + + // 精确匹配 + if host == allowed { + return nil + } + + // 支持通配符子域名匹配 (如 *.example.com) + if strings.HasPrefix(allowed, "*.") { + suffix := allowed[1:] // 移除 "*",保留 ".example.com" + if strings.HasSuffix(host, suffix) { + return nil + } + } + } + + return errors.New("URL域名不在允许的列表中") } // GetUserByEmail 根据邮箱获取用户 diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 275ee86..b509c70 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -55,6 +55,10 @@ func (j *JWTService) GenerateToken(userID int64, username, role string) (string, // ValidateToken 验证JWT Token func (j *JWTService) ValidateToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + // 验证签名算法,防止algorithm confusion攻击 + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("不支持的签名算法") + } return []byte(j.secretKey), nil }) diff --git a/pkg/config/config.go b/pkg/config/config.go index 919e4c3..83c5188 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strconv" + "strings" "time" "github.com/joho/godotenv" @@ -22,6 +23,7 @@ type Config struct { Log LogConfig `mapstructure:"log"` Upload UploadConfig `mapstructure:"upload"` Email EmailConfig `mapstructure:"email"` + Security SecurityConfig `mapstructure:"security"` } // ServerConfig 服务器配置 @@ -107,6 +109,12 @@ type EmailConfig struct { FromName string `mapstructure:"from_name"` } +// SecurityConfig 安全配置 +type SecurityConfig struct { + AllowedOrigins []string `mapstructure:"allowed_origins"` // 允许的CORS来源 + AllowedDomains []string `mapstructure:"allowed_domains"` // 允许的头像/材质URL域名 +} + // Load 加载配置 - 完全从环境变量加载,不依赖YAML文件 func Load() (*Config, error) { // 加载.env文件(如果存在) @@ -160,7 +168,7 @@ func setDefaults() { // RustFS默认配置 viper.SetDefault("rustfs.endpoint", "127.0.0.1:9000") - viper.SetDefault("rustfs.public_url", "") // 为空时使用 endpoint 构建 URL + viper.SetDefault("rustfs.public_url", "") // 为空时使用 endpoint 构建 URL viper.SetDefault("rustfs.use_ssl", false) // JWT默认配置 @@ -188,6 +196,10 @@ func setDefaults() { // 邮件默认配置 viper.SetDefault("email.enabled", false) viper.SetDefault("email.smtp_port", 587) + + // 安全默认配置 + viper.SetDefault("security.allowed_origins", []string{"*"}) + viper.SetDefault("security.allowed_domains", []string{"localhost", "127.0.0.1"}) } // setupEnvMappings 设置环境变量映射 @@ -310,6 +322,15 @@ func overrideFromEnv(config *Config) { if env := os.Getenv("ENVIRONMENT"); env != "" { config.Environment = env } + + // 处理安全配置 + if allowedOrigins := os.Getenv("SECURITY_ALLOWED_ORIGINS"); allowedOrigins != "" { + config.Security.AllowedOrigins = strings.Split(allowedOrigins, ",") + } + + if allowedDomains := os.Getenv("SECURITY_ALLOWED_DOMAINS"); allowedDomains != "" { + config.Security.AllowedDomains = strings.Split(allowedDomains, ",") + } } // IsTestEnvironment 判断是否为测试环境 diff --git a/pkg/config/manager.go b/pkg/config/manager.go index 5c2d631..1ded256 100644 --- a/pkg/config/manager.go +++ b/pkg/config/manager.go @@ -62,6 +62,3 @@ func MustGetRustFSConfig() *RustFSConfig { return cfg } - - -