diff --git a/Dockerfile b/Dockerfile index 512bf9d..6dd5d6e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -65,3 +65,5 @@ ENTRYPOINT ["./server"] + + diff --git a/internal/container/container.go b/internal/container/container.go index 4dfce6c..70edc4d 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -44,6 +44,7 @@ type Container struct { UploadService service.UploadService SecurityService service.SecurityService CaptchaService service.CaptchaService + SignatureService *service.SignatureService } // NewContainer 创建依赖容器 @@ -80,26 +81,27 @@ func NewContainer( c.ConfigRepo = repository.NewSystemConfigRepository(db) c.YggdrasilRepo = repository.NewYggdrasilRepository(db) - // 初始化SignatureService(用于获取Yggdrasil私钥) - signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger) - - // 获取Yggdrasil私钥并创建JWT服务 - _, privateKey, err := signatureService.GetOrCreateYggdrasilKeyPair() - if err != nil { - logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err)) - } - yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin") + // 初始化SignatureService(作为依赖注入,避免在容器中创建并立即调用) + // 将SignatureService添加到容器中,供其他服务使用 + c.SignatureService = service.NewSignatureService(c.ProfileRepo, redisClient, logger) // 初始化Service(注入缓存管理器) c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger) c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger) c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger) - - // 使用JWT版本的TokenService + + // 获取Yggdrasil私钥并创建JWT服务(TokenService需要) + // 注意:这里仍然需要预先初始化,因为TokenService在创建时需要YggdrasilJWT + // 但SignatureService已经作为依赖注入,降低了耦合度 + _, privateKey, err := c.SignatureService.GetOrCreateYggdrasilKeyPair() + if err != nil { + logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err)) + } + yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin") c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger) // 使用组合服务(内部包含认证、会话、序列化、证书服务) - c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, signatureService, redisClient, logger) + c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger) // 初始化其他服务 c.SecurityService = service.NewSecurityService(redisClient) diff --git a/internal/handler/customskin_handler.go b/internal/handler/customskin_handler.go index ed2123a..87cebd4 100644 --- a/internal/handler/customskin_handler.go +++ b/internal/handler/customskin_handler.go @@ -219,7 +219,7 @@ func (h *CustomSkinHandler) GetTexture(c *gin.Context) { // 增加下载计数(异步) go func() { - _ = h.container.TextureRepo.IncrementDownloadCount(texture.ID) + _ = h.container.TextureRepo.IncrementDownloadCount(ctx, texture.ID) }() // 流式传输文件内容 diff --git a/internal/handler/helpers.go b/internal/handler/helpers.go index 390b162..202a50b 100644 --- a/internal/handler/helpers.go +++ b/internal/handler/helpers.go @@ -1,6 +1,7 @@ package handler import ( + "carrotskin/internal/errors" "carrotskin/internal/model" "carrotskin/internal/types" "net/http" @@ -165,17 +166,46 @@ func RespondSuccess(c *gin.Context, data interface{}) { c.JSON(http.StatusOK, model.NewSuccessResponse(data)) } -// RespondWithError 根据错误消息自动选择状态码 +// RespondWithError 根据错误类型自动选择状态码 func RespondWithError(c *gin.Context, err error) { - msg := err.Error() - switch msg { - case "档案不存在", "用户不存在", "材质不存在": - RespondNotFound(c, msg) - case "无权操作此档案", "无权操作此材质": - RespondForbidden(c, msg) - case "未授权": - RespondUnauthorized(c, msg) - default: - RespondServerError(c, msg, nil) + if err == nil { + return } + + // 使用errors.Is检查预定义错误 + if errors.Is(err, errors.ErrUserNotFound) || + errors.Is(err, errors.ErrProfileNotFound) || + errors.Is(err, errors.ErrTextureNotFound) || + errors.Is(err, errors.ErrNotFound) { + RespondNotFound(c, err.Error()) + return + } + + if errors.Is(err, errors.ErrProfileNoPermission) || + errors.Is(err, errors.ErrTextureNoPermission) || + errors.Is(err, errors.ErrForbidden) { + RespondForbidden(c, err.Error()) + return + } + + if errors.Is(err, errors.ErrUnauthorized) || + errors.Is(err, errors.ErrInvalidToken) || + errors.Is(err, errors.ErrTokenExpired) { + RespondUnauthorized(c, err.Error()) + return + } + + // 检查AppError类型 + var appErr *errors.AppError + if errors.As(err, &appErr) { + c.JSON(appErr.Code, model.NewErrorResponse( + appErr.Code, + appErr.Message, + appErr.Err, + )) + return + } + + // 默认返回500错误 + RespondServerError(c, err.Error(), err) } diff --git a/internal/handler/routes.go b/internal/handler/routes.go index 2b5ef72..4d62899 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -4,6 +4,7 @@ import ( "carrotskin/internal/container" "carrotskin/internal/middleware" "carrotskin/internal/model" + "carrotskin/pkg/auth" "github.com/gin-gonic/gin" ) @@ -47,13 +48,13 @@ func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) { registerAuthRoutes(v1, h.Auth) // 用户路由(需要JWT认证) - registerUserRoutes(v1, h.User) + registerUserRoutes(v1, h.User, c.JWT) // 材质路由 - registerTextureRoutes(v1, h.Texture) + registerTextureRoutes(v1, h.Texture, c.JWT) // 档案路由 - registerProfileRoutesWithDI(v1, h.Profile) + registerProfileRoutesWithDI(v1, h.Profile, c.JWT) // 验证码路由 registerCaptchaRoutesWithDI(v1, h.Captcha) @@ -81,9 +82,9 @@ func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) { } // registerUserRoutes 注册用户路由 -func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { +func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler, jwtService *auth.JWTService) { userGroup := v1.Group("/user") - userGroup.Use(middleware.AuthMiddleware()) + userGroup.Use(middleware.AuthMiddleware(jwtService)) { userGroup.GET("/profile", h.GetProfile) userGroup.PUT("/profile", h.UpdateProfile) @@ -101,7 +102,7 @@ func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { } // registerTextureRoutes 注册材质路由 -func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { +func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler, jwtService *auth.JWTService) { textureGroup := v1.Group("/texture") { // 公开路由(无需认证) @@ -110,7 +111,7 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { // 需要认证的路由 textureAuth := textureGroup.Group("") - textureAuth.Use(middleware.AuthMiddleware()) + textureAuth.Use(middleware.AuthMiddleware(jwtService)) { textureAuth.POST("/upload-url", h.GenerateUploadURL) textureAuth.POST("", h.Create) @@ -124,7 +125,7 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { } // registerProfileRoutesWithDI 注册档案路由(依赖注入版本) -func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { +func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler, jwtService *auth.JWTService) { profileGroup := v1.Group("/profile") { // 公开路由(无需认证) @@ -132,7 +133,7 @@ func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { // 需要认证的路由 profileAuth := profileGroup.Group("") - profileAuth.Use(middleware.AuthMiddleware()) + profileAuth.Use(middleware.AuthMiddleware(jwtService)) { profileAuth.POST("/", h.Create) profileAuth.GET("/", h.List) diff --git a/internal/handler/swagger.go b/internal/handler/swagger.go index 11ed54f..419a148 100644 --- a/internal/handler/swagger.go +++ b/internal/handler/swagger.go @@ -1,15 +1,95 @@ package handler import ( + "context" + "errors" "net/http" + "time" + + "carrotskin/pkg/database" + "carrotskin/pkg/redis" "github.com/gin-gonic/gin" ) -// HealthCheck 健康检查 +// HealthCheck 健康检查,检查依赖服务状态 func HealthCheck(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "status": "ok", - "message": "CarrotSkin API is running", + ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second) + defer cancel() + + checks := make(map[string]string) + status := "ok" + + // 检查数据库 + if err := checkDatabase(ctx); err != nil { + checks["database"] = "unhealthy: " + err.Error() + status = "degraded" + } else { + checks["database"] = "healthy" + } + + // 检查Redis + if err := checkRedis(ctx); err != nil { + checks["redis"] = "unhealthy: " + err.Error() + status = "degraded" + } else { + checks["redis"] = "healthy" + } + + // 根据状态返回相应的HTTP状态码 + httpStatus := http.StatusOK + if status == "degraded" { + httpStatus = http.StatusServiceUnavailable + } + + c.JSON(httpStatus, gin.H{ + "status": status, + "message": "CarrotSkin API health check", + "checks": checks, + "timestamp": time.Now().Unix(), }) } + +// checkDatabase 检查数据库连接 +func checkDatabase(ctx context.Context) error { + db, err := database.GetDB() + if err != nil { + return err + } + + sqlDB, err := db.DB() + if err != nil { + return err + } + + // 使用Ping检查连接 + if err := sqlDB.PingContext(ctx); err != nil { + return err + } + + // 执行简单查询验证 + var result int + if err := db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error; err != nil { + return err + } + + return nil +} + +// checkRedis 检查Redis连接 +func checkRedis(ctx context.Context) error { + client, err := redis.GetClient() + if err != nil { + return err + } + if client == nil { + return errors.New("Redis客户端未初始化") + } + + // 使用Ping检查连接 + if err := client.Ping(ctx).Err(); err != nil { + return err + } + + return nil +} diff --git a/internal/handler/yggdrasil_handler.go b/internal/handler/yggdrasil_handler.go index f873f0a..fc63566 100644 --- a/internal/handler/yggdrasil_handler.go +++ b/internal/handler/yggdrasil_handler.go @@ -190,7 +190,7 @@ func (h *YggdrasilHandler) Authenticate(c *gin.Context) { if emailRegex.MatchString(request.Identifier) { userId, err = h.container.YggdrasilService.GetUserIDByEmail(c.Request.Context(), request.Identifier) } else { - profile, err = h.container.ProfileRepo.FindByName(request.Identifier) + profile, err = h.container.ProfileRepo.FindByName(c.Request.Context(), request.Identifier) if err != nil { h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 9187b7c..fd48b9d 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "carrotskin/internal/model" "net/http" "strings" @@ -9,17 +10,16 @@ import ( "github.com/gin-gonic/gin" ) -// AuthMiddleware JWT认证中间件 -func AuthMiddleware() gin.HandlerFunc { +// AuthMiddleware JWT认证中间件(注入JWT服务版本) +func AuthMiddleware(jwtService *auth.JWTService) gin.HandlerFunc { return gin.HandlerFunc(func(c *gin.Context) { - jwtService := auth.MustGetJWTService() - authHeader := c.GetHeader("Authorization") if authHeader == "" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "缺少Authorization头", - }) + c.JSON(http.StatusUnauthorized, model.NewErrorResponse( + model.CodeUnauthorized, + "缺少Authorization头", + nil, + )) c.Abort() return } @@ -27,10 +27,11 @@ func AuthMiddleware() gin.HandlerFunc { // Bearer token格式 tokenParts := strings.SplitN(authHeader, " ", 2) if len(tokenParts) != 2 || tokenParts[0] != "Bearer" { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "无效的Authorization头格式", - }) + c.JSON(http.StatusUnauthorized, model.NewErrorResponse( + model.CodeUnauthorized, + "无效的Authorization头格式", + nil, + )) c.Abort() return } @@ -38,10 +39,11 @@ func AuthMiddleware() gin.HandlerFunc { token := tokenParts[1] claims, err := jwtService.ValidateToken(token) if err != nil { - c.JSON(http.StatusUnauthorized, gin.H{ - "code": 401, - "message": "无效的token", - }) + c.JSON(http.StatusUnauthorized, model.NewErrorResponse( + model.CodeUnauthorized, + "无效的token", + err, + )) c.Abort() return } @@ -55,11 +57,9 @@ func AuthMiddleware() gin.HandlerFunc { }) } -// OptionalAuthMiddleware 可选的JWT认证中间件 -func OptionalAuthMiddleware() gin.HandlerFunc { +// OptionalAuthMiddleware 可选的JWT认证中间件(注入JWT服务版本) +func OptionalAuthMiddleware(jwtService *auth.JWTService) gin.HandlerFunc { return gin.HandlerFunc(func(c *gin.Context) { - jwtService := auth.MustGetJWTService() - authHeader := c.GetHeader("Authorization") if authHeader != "" { tokenParts := strings.SplitN(authHeader, " ", 2) diff --git a/internal/model/client.go b/internal/model/client.go index b1b461a..a71dc2d 100644 --- a/internal/model/client.go +++ b/internal/model/client.go @@ -22,3 +22,5 @@ func (Client) TableName() string { return "clients" } + + diff --git a/internal/repository/client_repository.go b/internal/repository/client_repository.go index 199d735..20a6435 100644 --- a/internal/repository/client_repository.go +++ b/internal/repository/client_repository.go @@ -2,6 +2,7 @@ package repository import ( "carrotskin/internal/model" + "context" "gorm.io/gorm" ) @@ -16,48 +17,48 @@ func NewClientRepository(db *gorm.DB) ClientRepository { return &clientRepository{db: db} } -func (r *clientRepository) Create(client *model.Client) error { - return r.db.Create(client).Error +func (r *clientRepository) Create(ctx context.Context, client *model.Client) error { + return r.db.WithContext(ctx).Create(client).Error } -func (r *clientRepository) FindByClientToken(clientToken string) (*model.Client, error) { +func (r *clientRepository) FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error) { var client model.Client - err := r.db.Where("client_token = ?", clientToken).First(&client).Error + err := r.db.WithContext(ctx).Where("client_token = ?", clientToken).First(&client).Error if err != nil { return nil, err } return &client, nil } -func (r *clientRepository) FindByUUID(uuid string) (*model.Client, error) { +func (r *clientRepository) FindByUUID(ctx context.Context, uuid string) (*model.Client, error) { var client model.Client - err := r.db.Where("uuid = ?", uuid).First(&client).Error + err := r.db.WithContext(ctx).Where("uuid = ?", uuid).First(&client).Error if err != nil { return nil, err } return &client, nil } -func (r *clientRepository) FindByUserID(userID int64) ([]*model.Client, error) { +func (r *clientRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error) { var clients []*model.Client - err := r.db.Where("user_id = ?", userID).Find(&clients).Error + err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&clients).Error return clients, err } -func (r *clientRepository) Update(client *model.Client) error { - return r.db.Save(client).Error +func (r *clientRepository) Update(ctx context.Context, client *model.Client) error { + return r.db.WithContext(ctx).Save(client).Error } -func (r *clientRepository) IncrementVersion(clientUUID string) error { - return r.db.Model(&model.Client{}). +func (r *clientRepository) IncrementVersion(ctx context.Context, clientUUID string) error { + return r.db.WithContext(ctx).Model(&model.Client{}). Where("uuid = ?", clientUUID). Update("version", gorm.Expr("version + 1")).Error } -func (r *clientRepository) DeleteByClientToken(clientToken string) error { - return r.db.Where("client_token = ?", clientToken).Delete(&model.Client{}).Error +func (r *clientRepository) DeleteByClientToken(ctx context.Context, clientToken string) error { + return r.db.WithContext(ctx).Where("client_token = ?", clientToken).Delete(&model.Client{}).Error } -func (r *clientRepository) DeleteByUserID(userID int64) error { - return r.db.Where("user_id = ?", userID).Delete(&model.Client{}).Error +func (r *clientRepository) DeleteByUserID(ctx context.Context, userID int64) error { + return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&model.Client{}).Error } diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 64d1e23..f2ec4a8 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -2,95 +2,105 @@ package repository import ( "carrotskin/internal/model" + "context" ) // UserRepository 用户仓储接口 type UserRepository interface { - Create(user *model.User) error - FindByID(id int64) (*model.User, error) - FindByUsername(username string) (*model.User, error) - FindByEmail(email string) (*model.User, error) - Update(user *model.User) error - UpdateFields(id int64, fields map[string]interface{}) error - Delete(id int64) error - CreateLoginLog(log *model.UserLoginLog) error - CreatePointLog(log *model.UserPointLog) error - UpdatePoints(userID int64, amount int, changeType, reason string) error + Create(ctx context.Context, user *model.User) error + FindByID(ctx context.Context, id int64) (*model.User, error) + FindByUsername(ctx context.Context, username string) (*model.User, error) + FindByEmail(ctx context.Context, email string) (*model.User, error) + FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) // 批量查询 + Update(ctx context.Context, user *model.User) error + UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error + BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新 + Delete(ctx context.Context, id int64) error + BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除 + CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error + CreatePointLog(ctx context.Context, log *model.UserPointLog) error + UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error } // ProfileRepository 档案仓储接口 type ProfileRepository interface { - Create(profile *model.Profile) error - FindByUUID(uuid string) (*model.Profile, error) - FindByName(name string) (*model.Profile, error) - FindByUserID(userID int64) ([]*model.Profile, error) - Update(profile *model.Profile) error - UpdateFields(uuid string, updates map[string]interface{}) error - Delete(uuid string) error - CountByUserID(userID int64) (int64, error) - SetActive(uuid string, userID int64) error - UpdateLastUsedAt(uuid string) error - GetByNames(names []string) ([]*model.Profile, error) - GetKeyPair(profileId string) (*model.KeyPair, error) - UpdateKeyPair(profileId string, keyPair *model.KeyPair) error + Create(ctx context.Context, profile *model.Profile) error + FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) + FindByName(ctx context.Context, name string) (*model.Profile, error) + FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) + FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) // 批量查询 + Update(ctx context.Context, profile *model.Profile) error + UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error + BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) // 批量更新 + Delete(ctx context.Context, uuid string) error + BatchDelete(ctx context.Context, uuids []string) (int64, error) // 批量删除 + CountByUserID(ctx context.Context, userID int64) (int64, error) + SetActive(ctx context.Context, uuid string, userID int64) error + UpdateLastUsedAt(ctx context.Context, uuid string) error + GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) + GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) + UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error } // TextureRepository 材质仓储接口 type TextureRepository interface { - Create(texture *model.Texture) error - FindByID(id int64) (*model.Texture, error) - FindByHash(hash string) (*model.Texture, error) - FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) - Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) - Update(texture *model.Texture) error - UpdateFields(id int64, fields map[string]interface{}) error - Delete(id int64) error - IncrementDownloadCount(id int64) error - IncrementFavoriteCount(id int64) error - DecrementFavoriteCount(id int64) error - CreateDownloadLog(log *model.TextureDownloadLog) error - IsFavorited(userID, textureID int64) (bool, error) - AddFavorite(userID, textureID int64) error - RemoveFavorite(userID, textureID int64) error - GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) - CountByUploaderID(uploaderID int64) (int64, error) + Create(ctx context.Context, texture *model.Texture) error + FindByID(ctx context.Context, id int64) (*model.Texture, error) + FindByHash(ctx context.Context, hash string) (*model.Texture, error) + FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) // 批量查询 + FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) + Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) + Update(ctx context.Context, texture *model.Texture) error + UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error + BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新 + Delete(ctx context.Context, id int64) error + BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除 + IncrementDownloadCount(ctx context.Context, id int64) error + IncrementFavoriteCount(ctx context.Context, id int64) error + DecrementFavoriteCount(ctx context.Context, id int64) error + CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error + IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) + AddFavorite(ctx context.Context, userID, textureID int64) error + RemoveFavorite(ctx context.Context, userID, textureID int64) error + GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) + CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) } // TokenRepository 令牌仓储接口 type TokenRepository interface { - Create(token *model.Token) error - FindByAccessToken(accessToken string) (*model.Token, error) - GetByUserID(userId int64) ([]*model.Token, error) - GetUUIDByAccessToken(accessToken string) (string, error) - GetUserIDByAccessToken(accessToken string) (int64, error) - DeleteByAccessToken(accessToken string) error - DeleteByUserID(userId int64) error - BatchDelete(accessTokens []string) (int64, error) + Create(ctx context.Context, token *model.Token) error + FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) + GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) + GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) + GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) + DeleteByAccessToken(ctx context.Context, accessToken string) error + DeleteByUserID(ctx context.Context, userId int64) error + BatchDelete(ctx context.Context, accessTokens []string) (int64, error) } // SystemConfigRepository 系统配置仓储接口 type SystemConfigRepository interface { - GetByKey(key string) (*model.SystemConfig, error) - GetPublic() ([]model.SystemConfig, error) - GetAll() ([]model.SystemConfig, error) - Update(config *model.SystemConfig) error - UpdateValue(key, value string) error + GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) + GetPublic(ctx context.Context) ([]model.SystemConfig, error) + GetAll(ctx context.Context) ([]model.SystemConfig, error) + Update(ctx context.Context, config *model.SystemConfig) error + UpdateValue(ctx context.Context, key, value string) error } // YggdrasilRepository Yggdrasil仓储接口 type YggdrasilRepository interface { - GetPasswordByID(id int64) (string, error) - ResetPassword(id int64, password string) error + GetPasswordByID(ctx context.Context, id int64) (string, error) + ResetPassword(ctx context.Context, id int64, password string) error } // ClientRepository Client仓储接口 type ClientRepository interface { - Create(client *model.Client) error - FindByClientToken(clientToken string) (*model.Client, error) - FindByUUID(uuid string) (*model.Client, error) - FindByUserID(userID int64) ([]*model.Client, error) - Update(client *model.Client) error - IncrementVersion(clientUUID string) error - DeleteByClientToken(clientToken string) error - DeleteByUserID(userID int64) error + Create(ctx context.Context, client *model.Client) error + FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error) + FindByUUID(ctx context.Context, uuid string) (*model.Client, error) + FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error) + Update(ctx context.Context, client *model.Client) error + IncrementVersion(ctx context.Context, clientUUID string) error + DeleteByClientToken(ctx context.Context, clientToken string) error + DeleteByUserID(ctx context.Context, userID int64) error } diff --git a/internal/repository/profile_repository.go b/internal/repository/profile_repository.go index fd1558c..6f9cded 100644 --- a/internal/repository/profile_repository.go +++ b/internal/repository/profile_repository.go @@ -19,13 +19,13 @@ func NewProfileRepository(db *gorm.DB) ProfileRepository { return &profileRepository{db: db} } -func (r *profileRepository) Create(profile *model.Profile) error { - return r.db.Create(profile).Error +func (r *profileRepository) Create(ctx context.Context, profile *model.Profile) error { + return r.db.WithContext(ctx).Create(profile).Error } -func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) { +func (r *profileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) { var profile model.Profile - err := r.db.Where("uuid = ?", uuid). + err := r.db.WithContext(ctx).Where("uuid = ?", uuid). Preload("Skin"). Preload("Cape"). First(&profile).Error @@ -35,10 +35,10 @@ func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) { return &profile, nil } -func (r *profileRepository) FindByName(name string) (*model.Profile, error) { +func (r *profileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) { var profile model.Profile // 使用 LOWER 函数进行不区分大小写的查询,并预加载 Skin 和 Cape - err := r.db.Where("LOWER(name) = LOWER(?)", name). + err := r.db.WithContext(ctx).Where("LOWER(name) = LOWER(?)", name). Preload("Skin"). Preload("Cape"). First(&profile).Error @@ -48,9 +48,9 @@ func (r *profileRepository) FindByName(name string) (*model.Profile, error) { return &profile, nil } -func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error) { +func (r *profileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) { var profiles []*model.Profile - err := r.db.Where("user_id = ?", userID). + err := r.db.WithContext(ctx).Where("user_id = ?", userID). Preload("Skin"). Preload("Cape"). Order("created_at DESC"). @@ -58,30 +58,59 @@ func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error) return profiles, err } -func (r *profileRepository) Update(profile *model.Profile) error { - return r.db.Save(profile).Error +func (r *profileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) { + if len(uuids) == 0 { + return []*model.Profile{}, nil + } + var profiles []*model.Profile + // 使用 IN 查询优化批量查询,并预加载关联 + err := r.db.WithContext(ctx).Where("uuid IN ?", uuids). + Preload("Skin"). + Preload("Cape"). + Find(&profiles).Error + return profiles, err } -func (r *profileRepository) UpdateFields(uuid string, updates map[string]interface{}) error { - return r.db.Model(&model.Profile{}). +func (r *profileRepository) Update(ctx context.Context, profile *model.Profile) error { + return r.db.WithContext(ctx).Save(profile).Error +} + +func (r *profileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error { + return r.db.WithContext(ctx).Model(&model.Profile{}). Where("uuid = ?", uuid). Updates(updates).Error } -func (r *profileRepository) Delete(uuid string) error { - return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error +func (r *profileRepository) Delete(ctx context.Context, uuid string) error { + return r.db.WithContext(ctx).Where("uuid = ?", uuid).Delete(&model.Profile{}).Error } -func (r *profileRepository) CountByUserID(userID int64) (int64, error) { +func (r *profileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) { + if len(uuids) == 0 { + return 0, nil + } + result := r.db.WithContext(ctx).Model(&model.Profile{}).Where("uuid IN ?", uuids).Updates(updates) + return result.RowsAffected, result.Error +} + +func (r *profileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) { + if len(uuids) == 0 { + return 0, nil + } + result := r.db.WithContext(ctx).Where("uuid IN ?", uuids).Delete(&model.Profile{}) + return result.RowsAffected, result.Error +} + +func (r *profileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { var count int64 - err := r.db.Model(&model.Profile{}). + err := r.db.WithContext(ctx).Model(&model.Profile{}). Where("user_id = ?", userID). Count(&count).Error return count, err } -func (r *profileRepository) SetActive(uuid string, userID int64) error { - return r.db.Transaction(func(tx *gorm.DB) error { +func (r *profileRepository) SetActive(ctx context.Context, uuid string, userID int64) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Model(&model.Profile{}). Where("user_id = ?", userID). Update("is_active", false).Error; err != nil { @@ -94,28 +123,28 @@ func (r *profileRepository) SetActive(uuid string, userID int64) error { }) } -func (r *profileRepository) UpdateLastUsedAt(uuid string) error { - return r.db.Model(&model.Profile{}). +func (r *profileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error { + return r.db.WithContext(ctx).Model(&model.Profile{}). Where("uuid = ?", uuid). Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error } -func (r *profileRepository) GetByNames(names []string) ([]*model.Profile, error) { +func (r *profileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) { var profiles []*model.Profile - err := r.db.Where("name in (?)", names). + err := r.db.WithContext(ctx).Where("name in (?)", names). Preload("Skin"). Preload("Cape"). Find(&profiles).Error return profiles, err } -func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) { +func (r *profileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) { if profileId == "" { return nil, errors.New("参数不能为空") } var profile model.Profile - result := r.db.WithContext(context.Background()). + result := r.db.WithContext(ctx). Select("key_pair"). Where("id = ?", profileId). First(&profile) @@ -130,7 +159,7 @@ func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) return &model.KeyPair{}, nil } -func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { +func (r *profileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error { if profileId == "" { return errors.New("profileId 不能为空") } @@ -138,9 +167,8 @@ func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPa return errors.New("keyPair 不能为 nil") } - return r.db.Transaction(func(tx *gorm.DB) error { - result := tx.WithContext(context.Background()). - Table("profiles"). + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + result := tx.Table("profiles"). Where("id = ?", profileId). UpdateColumns(map[string]interface{}{ "private_key": keyPair.PrivateKey, diff --git a/internal/repository/system_config_repository.go b/internal/repository/system_config_repository.go index 174ad45..41b2cc1 100644 --- a/internal/repository/system_config_repository.go +++ b/internal/repository/system_config_repository.go @@ -2,6 +2,7 @@ package repository import ( "carrotskin/internal/model" + "context" "gorm.io/gorm" ) @@ -16,28 +17,28 @@ func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository { return &systemConfigRepository{db: db} } -func (r *systemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) { +func (r *systemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) { var config model.SystemConfig - err := r.db.Where("key = ?", key).First(&config).Error + err := r.db.WithContext(ctx).Where("key = ?", key).First(&config).Error return handleNotFoundResult(&config, err) } -func (r *systemConfigRepository) GetPublic() ([]model.SystemConfig, error) { +func (r *systemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) { var configs []model.SystemConfig - err := r.db.Where("is_public = ?", true).Find(&configs).Error + err := r.db.WithContext(ctx).Where("is_public = ?", true).Find(&configs).Error return configs, err } -func (r *systemConfigRepository) GetAll() ([]model.SystemConfig, error) { +func (r *systemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) { var configs []model.SystemConfig - err := r.db.Find(&configs).Error + err := r.db.WithContext(ctx).Find(&configs).Error return configs, err } -func (r *systemConfigRepository) Update(config *model.SystemConfig) error { - return r.db.Save(config).Error +func (r *systemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error { + return r.db.WithContext(ctx).Save(config).Error } -func (r *systemConfigRepository) UpdateValue(key, value string) error { - return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error +func (r *systemConfigRepository) UpdateValue(ctx context.Context, key, value string) error { + return r.db.WithContext(ctx).Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error } diff --git a/internal/repository/texture_repository.go b/internal/repository/texture_repository.go index 5c6dc43..a2b9827 100644 --- a/internal/repository/texture_repository.go +++ b/internal/repository/texture_repository.go @@ -2,6 +2,7 @@ package repository import ( "carrotskin/internal/model" + "context" "gorm.io/gorm" ) @@ -16,27 +17,39 @@ func NewTextureRepository(db *gorm.DB) TextureRepository { return &textureRepository{db: db} } -func (r *textureRepository) Create(texture *model.Texture) error { - return r.db.Create(texture).Error +func (r *textureRepository) Create(ctx context.Context, texture *model.Texture) error { + return r.db.WithContext(ctx).Create(texture).Error } -func (r *textureRepository) FindByID(id int64) (*model.Texture, error) { +func (r *textureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) { var texture model.Texture - err := r.db.Preload("Uploader").First(&texture, id).Error + err := r.db.WithContext(ctx).Preload("Uploader").First(&texture, id).Error return handleNotFoundResult(&texture, err) } -func (r *textureRepository) FindByHash(hash string) (*model.Texture, error) { +func (r *textureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) { var texture model.Texture - err := r.db.Where("hash = ?", hash).First(&texture).Error + err := r.db.WithContext(ctx).Where("hash = ?", hash).First(&texture).Error return handleNotFoundResult(&texture, err) } -func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { +func (r *textureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) { + if len(ids) == 0 { + return []*model.Texture{}, nil + } + var textures []*model.Texture + // 使用 IN 查询优化批量查询,并预加载关联 + err := r.db.WithContext(ctx).Where("id IN ?", ids). + Preload("Uploader"). + Find(&textures).Error + return textures, err +} + +func (r *textureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { var textures []*model.Texture var total int64 - query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID) + query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID) if err := query.Count(&total).Error; err != nil { return nil, 0, err @@ -54,11 +67,11 @@ func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize in return textures, total, nil } -func (r *textureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { +func (r *textureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { var textures []*model.Texture var total int64 - query := r.db.Model(&model.Texture{}).Where("status = 1") + query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("status = 1") if publicOnly { query = query.Where("is_public = ?", true) @@ -86,67 +99,86 @@ func (r *textureRepository) Search(keyword string, textureType model.TextureType return textures, total, nil } -func (r *textureRepository) Update(texture *model.Texture) error { - return r.db.Save(texture).Error +func (r *textureRepository) Update(ctx context.Context, texture *model.Texture) error { + return r.db.WithContext(ctx).Save(texture).Error } -func (r *textureRepository) UpdateFields(id int64, fields map[string]interface{}) error { - return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error +func (r *textureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error { + return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error } -func (r *textureRepository) Delete(id int64) error { - return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error +func (r *textureRepository) Delete(ctx context.Context, id int64) error { + return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error } -func (r *textureRepository) IncrementDownloadCount(id int64) error { - return r.db.Model(&model.Texture{}).Where("id = ?", id). +func (r *textureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Updates(fields) + return result.RowsAffected, result.Error +} + +func (r *textureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Update("status", -1) + return result.RowsAffected, result.Error +} + +func (r *textureRepository) IncrementDownloadCount(ctx context.Context, id int64) error { + return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id). UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error } -func (r *textureRepository) IncrementFavoriteCount(id int64) error { - return r.db.Model(&model.Texture{}).Where("id = ?", id). +func (r *textureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error { + return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id). UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error } -func (r *textureRepository) DecrementFavoriteCount(id int64) error { - return r.db.Model(&model.Texture{}).Where("id = ?", id). +func (r *textureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error { + return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id). UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error } -func (r *textureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error { - return r.db.Create(log).Error +func (r *textureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error { + return r.db.WithContext(ctx).Create(log).Error } -func (r *textureRepository) IsFavorited(userID, textureID int64) (bool, error) { +func (r *textureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) { var count int64 - err := r.db.Model(&model.UserTextureFavorite{}). + // 使用 Select("1") 优化,只查询是否存在,不需要查询所有字段 + err := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}). + Select("1"). Where("user_id = ? AND texture_id = ?", userID, textureID). + Limit(1). Count(&count).Error return count > 0, err } -func (r *textureRepository) AddFavorite(userID, textureID int64) error { +func (r *textureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error { favorite := &model.UserTextureFavorite{ UserID: userID, TextureID: textureID, } - return r.db.Create(favorite).Error + return r.db.WithContext(ctx).Create(favorite).Error } -func (r *textureRepository) RemoveFavorite(userID, textureID int64) error { - return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID). +func (r *textureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error { + return r.db.WithContext(ctx).Where("user_id = ? AND texture_id = ?", userID, textureID). Delete(&model.UserTextureFavorite{}).Error } -func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { +func (r *textureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) { var textures []*model.Texture var total int64 - subQuery := r.db.Model(&model.UserTextureFavorite{}). + subQuery := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}). Select("texture_id"). Where("user_id = ?", userID) - query := r.db.Model(&model.Texture{}). + query := r.db.WithContext(ctx).Model(&model.Texture{}). Where("id IN (?) AND status = 1", subQuery) if err := query.Count(&total).Error; err != nil { @@ -165,9 +197,9 @@ func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) ( return textures, total, nil } -func (r *textureRepository) CountByUploaderID(uploaderID int64) (int64, error) { +func (r *textureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) { var count int64 - err := r.db.Model(&model.Texture{}). + err := r.db.WithContext(ctx).Model(&model.Texture{}). Where("uploader_id = ? AND status != -1", uploaderID). Count(&count).Error return count, err diff --git a/internal/repository/token_repository.go b/internal/repository/token_repository.go index ecf2cca..ebc7968 100644 --- a/internal/repository/token_repository.go +++ b/internal/repository/token_repository.go @@ -2,6 +2,7 @@ package repository import ( "carrotskin/internal/model" + "context" "gorm.io/gorm" ) @@ -16,55 +17,55 @@ func NewTokenRepository(db *gorm.DB) TokenRepository { return &tokenRepository{db: db} } -func (r *tokenRepository) Create(token *model.Token) error { - return r.db.Create(token).Error +func (r *tokenRepository) Create(ctx context.Context, token *model.Token) error { + return r.db.WithContext(ctx).Create(token).Error } -func (r *tokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) { +func (r *tokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) { var token model.Token - err := r.db.Where("access_token = ?", accessToken).First(&token).Error + err := r.db.WithContext(ctx).Where("access_token = ?", accessToken).First(&token).Error if err != nil { return nil, err } return &token, nil } -func (r *tokenRepository) GetByUserID(userId int64) ([]*model.Token, error) { +func (r *tokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) { var tokens []*model.Token - err := r.db.Where("user_id = ?", userId).Find(&tokens).Error + err := r.db.WithContext(ctx).Where("user_id = ?", userId).Find(&tokens).Error return tokens, err } -func (r *tokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) { +func (r *tokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) { var token model.Token - err := r.db.Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error + err := r.db.WithContext(ctx).Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error if err != nil { return "", err } return token.ProfileId, nil } -func (r *tokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) { +func (r *tokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) { var token model.Token - err := r.db.Select("user_id").Where("access_token = ?", accessToken).First(&token).Error + err := r.db.WithContext(ctx).Select("user_id").Where("access_token = ?", accessToken).First(&token).Error if err != nil { return 0, err } return token.UserID, nil } -func (r *tokenRepository) DeleteByAccessToken(accessToken string) error { - return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error +func (r *tokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error { + return r.db.WithContext(ctx).Where("access_token = ?", accessToken).Delete(&model.Token{}).Error } -func (r *tokenRepository) DeleteByUserID(userId int64) error { - return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error +func (r *tokenRepository) DeleteByUserID(ctx context.Context, userId int64) error { + return r.db.WithContext(ctx).Where("user_id = ?", userId).Delete(&model.Token{}).Error } -func (r *tokenRepository) BatchDelete(accessTokens []string) (int64, error) { +func (r *tokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) { if len(accessTokens) == 0 { return 0, nil } - result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{}) + result := r.db.WithContext(ctx).Where("access_token IN ?", accessTokens).Delete(&model.Token{}) return result.RowsAffected, result.Error } diff --git a/internal/repository/user_repository.go b/internal/repository/user_repository.go index 1362fa6..b104d51 100644 --- a/internal/repository/user_repository.go +++ b/internal/repository/user_repository.go @@ -2,6 +2,7 @@ package repository import ( "carrotskin/internal/model" + "context" "errors" "gorm.io/gorm" @@ -17,50 +18,76 @@ func NewUserRepository(db *gorm.DB) UserRepository { return &userRepository{db: db} } -func (r *userRepository) Create(user *model.User) error { - return r.db.Create(user).Error +func (r *userRepository) Create(ctx context.Context, user *model.User) error { + return r.db.WithContext(ctx).Create(user).Error } -func (r *userRepository) FindByID(id int64) (*model.User, error) { +func (r *userRepository) FindByID(ctx context.Context, id int64) (*model.User, error) { var user model.User - err := r.db.Where("id = ? AND status != -1", id).First(&user).Error + err := r.db.WithContext(ctx).Where("id = ? AND status != -1", id).First(&user).Error return handleNotFoundResult(&user, err) } -func (r *userRepository) FindByUsername(username string) (*model.User, error) { +func (r *userRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) { var user model.User - err := r.db.Where("username = ? AND status != -1", username).First(&user).Error + err := r.db.WithContext(ctx).Where("username = ? AND status != -1", username).First(&user).Error return handleNotFoundResult(&user, err) } -func (r *userRepository) FindByEmail(email string) (*model.User, error) { +func (r *userRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) { var user model.User - err := r.db.Where("email = ? AND status != -1", email).First(&user).Error + err := r.db.WithContext(ctx).Where("email = ? AND status != -1", email).First(&user).Error return handleNotFoundResult(&user, err) } -func (r *userRepository) Update(user *model.User) error { - return r.db.Save(user).Error +func (r *userRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) { + if len(ids) == 0 { + return []*model.User{}, nil + } + var users []*model.User + // 使用 IN 查询优化批量查询 + err := r.db.WithContext(ctx).Where("id IN ? AND status != -1", ids).Find(&users).Error + return users, err } -func (r *userRepository) UpdateFields(id int64, fields map[string]interface{}) error { - return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error +func (r *userRepository) Update(ctx context.Context, user *model.User) error { + return r.db.WithContext(ctx).Save(user).Error } -func (r *userRepository) Delete(id int64) error { - return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error +func (r *userRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error { + return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Updates(fields).Error } -func (r *userRepository) CreateLoginLog(log *model.UserLoginLog) error { - return r.db.Create(log).Error +func (r *userRepository) Delete(ctx context.Context, id int64) error { + return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error } -func (r *userRepository) CreatePointLog(log *model.UserPointLog) error { - return r.db.Create(log).Error +func (r *userRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Updates(fields) + return result.RowsAffected, result.Error } -func (r *userRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error { - return r.db.Transaction(func(tx *gorm.DB) error { +func (r *userRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) { + if len(ids) == 0 { + return 0, nil + } + result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Update("status", -1) + return result.RowsAffected, result.Error +} + +func (r *userRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error { + return r.db.WithContext(ctx).Create(log).Error +} + +func (r *userRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error { + return r.db.WithContext(ctx).Create(log).Error +} + +func (r *userRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var user model.User if err := tx.Where("id = ?", userID).First(&user).Error; err != nil { return err diff --git a/internal/repository/yggdrasil_repository.go b/internal/repository/yggdrasil_repository.go index 83af4ff..aa053e3 100644 --- a/internal/repository/yggdrasil_repository.go +++ b/internal/repository/yggdrasil_repository.go @@ -2,6 +2,7 @@ package repository import ( "carrotskin/internal/model" + "context" "gorm.io/gorm" ) @@ -16,15 +17,15 @@ func NewYggdrasilRepository(db *gorm.DB) YggdrasilRepository { return &yggdrasilRepository{db: db} } -func (r *yggdrasilRepository) GetPasswordByID(id int64) (string, error) { +func (r *yggdrasilRepository) GetPasswordByID(ctx context.Context, id int64) (string, error) { var yggdrasil model.Yggdrasil - err := r.db.Select("password").Where("id = ?", id).First(&yggdrasil).Error + err := r.db.WithContext(ctx).Select("password").Where("id = ?", id).First(&yggdrasil).Error if err != nil { return "", err } return yggdrasil.Password, nil } -func (r *yggdrasilRepository) ResetPassword(id int64, password string) error { - return r.db.Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error +func (r *yggdrasilRepository) ResetPassword(ctx context.Context, id int64, password string) error { + return r.db.WithContext(ctx).Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error } diff --git a/internal/service/mocks_test.go b/internal/service/mocks_test.go index 694dfe7..6872fcd 100644 --- a/internal/service/mocks_test.go +++ b/internal/service/mocks_test.go @@ -3,6 +3,7 @@ package service import ( "carrotskin/internal/model" "carrotskin/pkg/database" + "context" "errors" "time" ) @@ -28,7 +29,7 @@ func NewMockUserRepository() *MockUserRepository { } } -func (m *MockUserRepository) Create(user *model.User) error { +func (m *MockUserRepository) Create(ctx context.Context, user *model.User) error { if m.FailCreate { return errors.New("mock create error") } @@ -39,7 +40,7 @@ func (m *MockUserRepository) Create(user *model.User) error { return nil } -func (m *MockUserRepository) FindByID(id int64) (*model.User, error) { +func (m *MockUserRepository) FindByID(ctx context.Context, id int64) (*model.User, error) { if m.FailFindByID { return nil, errors.New("mock find error") } @@ -49,7 +50,7 @@ func (m *MockUserRepository) FindByID(id int64) (*model.User, error) { return nil, nil } -func (m *MockUserRepository) FindByUsername(username string) (*model.User, error) { +func (m *MockUserRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) { if m.FailFindByUsername { return nil, errors.New("mock find by username error") } @@ -61,7 +62,7 @@ func (m *MockUserRepository) FindByUsername(username string) (*model.User, error return nil, nil } -func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) { +func (m *MockUserRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) { if m.FailFindByEmail { return nil, errors.New("mock find by email error") } @@ -73,7 +74,7 @@ func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) { return nil, nil } -func (m *MockUserRepository) Update(user *model.User) error { +func (m *MockUserRepository) Update(ctx context.Context, user *model.User) error { if m.FailUpdate { return errors.New("mock update error") } @@ -81,7 +82,7 @@ func (m *MockUserRepository) Update(user *model.User) error { return nil } -func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{}) error { +func (m *MockUserRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error { if m.FailUpdate { return errors.New("mock update fields error") } @@ -92,23 +93,43 @@ func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{ return nil } -func (m *MockUserRepository) Delete(id int64) error { +func (m *MockUserRepository) Delete(ctx context.Context, id int64) error { delete(m.users, id) return nil } -func (m *MockUserRepository) CreateLoginLog(log *model.UserLoginLog) error { +func (m *MockUserRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error { return nil } -func (m *MockUserRepository) CreatePointLog(log *model.UserPointLog) error { +func (m *MockUserRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error { return nil } -func (m *MockUserRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error { +func (m *MockUserRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error { return nil } +// BatchUpdate 和 BatchDelete 仅用于满足接口,在测试中不做具体操作 +func (m *MockUserRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) { + return 0, nil +} + +func (m *MockUserRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) { + return 0, nil +} + +// FindByIDs 批量查询用户 +func (m *MockUserRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) { + var result []*model.User + for _, id := range ids { + if u, ok := m.users[id]; ok { + result = append(result, u) + } + } + return result, nil +} + // MockProfileRepository 模拟ProfileRepository type MockProfileRepository struct { profiles map[string]*model.Profile @@ -128,7 +149,7 @@ func NewMockProfileRepository() *MockProfileRepository { } } -func (m *MockProfileRepository) Create(profile *model.Profile) error { +func (m *MockProfileRepository) Create(ctx context.Context, profile *model.Profile) error { if m.FailCreate { return errors.New("mock create error") } @@ -137,7 +158,7 @@ func (m *MockProfileRepository) Create(profile *model.Profile) error { return nil } -func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error) { +func (m *MockProfileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) { if m.FailFind { return nil, errors.New("mock find error") } @@ -147,7 +168,7 @@ func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error) return nil, errors.New("profile not found") } -func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error) { +func (m *MockProfileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) { if m.FailFind { return nil, errors.New("mock find error") } @@ -159,14 +180,14 @@ func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error) return nil, nil } -func (m *MockProfileRepository) FindByUserID(userID int64) ([]*model.Profile, error) { +func (m *MockProfileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) { if m.FailFind { return nil, errors.New("mock find error") } return m.userProfiles[userID], nil } -func (m *MockProfileRepository) Update(profile *model.Profile) error { +func (m *MockProfileRepository) Update(ctx context.Context, profile *model.Profile) error { if m.FailUpdate { return errors.New("mock update error") } @@ -174,14 +195,14 @@ func (m *MockProfileRepository) Update(profile *model.Profile) error { return nil } -func (m *MockProfileRepository) UpdateFields(uuid string, updates map[string]interface{}) error { +func (m *MockProfileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error { if m.FailUpdate { return errors.New("mock update error") } return nil } -func (m *MockProfileRepository) Delete(uuid string) error { +func (m *MockProfileRepository) Delete(ctx context.Context, uuid string) error { if m.FailDelete { return errors.New("mock delete error") } @@ -189,19 +210,19 @@ func (m *MockProfileRepository) Delete(uuid string) error { return nil } -func (m *MockProfileRepository) CountByUserID(userID int64) (int64, error) { +func (m *MockProfileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) { return int64(len(m.userProfiles[userID])), nil } -func (m *MockProfileRepository) SetActive(uuid string, userID int64) error { +func (m *MockProfileRepository) SetActive(ctx context.Context, uuid string, userID int64) error { return nil } -func (m *MockProfileRepository) UpdateLastUsedAt(uuid string) error { +func (m *MockProfileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error { return nil } -func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, error) { +func (m *MockProfileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) { var result []*model.Profile for _, name := range names { for _, profile := range m.profiles { @@ -213,14 +234,34 @@ func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, er return result, nil } -func (m *MockProfileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) { +func (m *MockProfileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) { return nil, nil } -func (m *MockProfileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { +func (m *MockProfileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error { return nil } +// BatchUpdate / BatchDelete 仅用于满足接口 +func (m *MockProfileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) { + return 0, nil +} + +func (m *MockProfileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) { + return 0, nil +} + +// FindByUUIDs 批量查询 Profile +func (m *MockProfileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) { + var result []*model.Profile + for _, id := range uuids { + if p, ok := m.profiles[id]; ok { + result = append(result, p) + } + } + return result, nil +} + // MockTextureRepository 模拟TextureRepository type MockTextureRepository struct { textures map[int64]*model.Texture @@ -240,7 +281,7 @@ func NewMockTextureRepository() *MockTextureRepository { } } -func (m *MockTextureRepository) Create(texture *model.Texture) error { +func (m *MockTextureRepository) Create(ctx context.Context, texture *model.Texture) error { if m.FailCreate { return errors.New("mock create error") } @@ -252,7 +293,7 @@ func (m *MockTextureRepository) Create(texture *model.Texture) error { return nil } -func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) { +func (m *MockTextureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) { if m.FailFind { return nil, errors.New("mock find error") } @@ -262,7 +303,7 @@ func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) { return nil, errors.New("texture not found") } -func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error) { +func (m *MockTextureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) { if m.FailFind { return nil, errors.New("mock find error") } @@ -274,7 +315,7 @@ func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error) return nil, nil } -func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { +func (m *MockTextureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { if m.FailFind { return nil, 0, errors.New("mock find error") } @@ -287,7 +328,7 @@ func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSiz return result, int64(len(result)), nil } -func (m *MockTextureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { +func (m *MockTextureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { if m.FailFind { return nil, 0, errors.New("mock find error") } @@ -301,7 +342,7 @@ func (m *MockTextureRepository) Search(keyword string, textureType model.Texture return result, int64(len(result)), nil } -func (m *MockTextureRepository) Update(texture *model.Texture) error { +func (m *MockTextureRepository) Update(ctx context.Context, texture *model.Texture) error { if m.FailUpdate { return errors.New("mock update error") } @@ -309,14 +350,14 @@ func (m *MockTextureRepository) Update(texture *model.Texture) error { return nil } -func (m *MockTextureRepository) UpdateFields(id int64, fields map[string]interface{}) error { +func (m *MockTextureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error { if m.FailUpdate { return errors.New("mock update error") } return nil } -func (m *MockTextureRepository) Delete(id int64) error { +func (m *MockTextureRepository) Delete(ctx context.Context, id int64) error { if m.FailDelete { return errors.New("mock delete error") } @@ -324,39 +365,39 @@ func (m *MockTextureRepository) Delete(id int64) error { return nil } -func (m *MockTextureRepository) IncrementDownloadCount(id int64) error { +func (m *MockTextureRepository) IncrementDownloadCount(ctx context.Context, id int64) error { if texture, ok := m.textures[id]; ok { texture.DownloadCount++ } return nil } -func (m *MockTextureRepository) IncrementFavoriteCount(id int64) error { +func (m *MockTextureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error { if texture, ok := m.textures[id]; ok { texture.FavoriteCount++ } return nil } -func (m *MockTextureRepository) DecrementFavoriteCount(id int64) error { +func (m *MockTextureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error { if texture, ok := m.textures[id]; ok && texture.FavoriteCount > 0 { texture.FavoriteCount-- } return nil } -func (m *MockTextureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error { +func (m *MockTextureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error { return nil } -func (m *MockTextureRepository) IsFavorited(userID, textureID int64) (bool, error) { +func (m *MockTextureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) { if userFavs, ok := m.favorites[userID]; ok { return userFavs[textureID], nil } return false, nil } -func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error { +func (m *MockTextureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error { if m.favorites[userID] == nil { m.favorites[userID] = make(map[int64]bool) } @@ -364,14 +405,14 @@ func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error { return nil } -func (m *MockTextureRepository) RemoveFavorite(userID, textureID int64) error { +func (m *MockTextureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error { if userFavs, ok := m.favorites[userID]; ok { delete(userFavs, textureID) } return nil } -func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { +func (m *MockTextureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) { var result []*model.Texture if userFavs, ok := m.favorites[userID]; ok { for textureID := range userFavs { @@ -383,7 +424,7 @@ func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize in return result, int64(len(result)), nil } -func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, error) { +func (m *MockTextureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) { var count int64 for _, texture := range m.textures { if texture.UploaderID == uploaderID { @@ -393,6 +434,34 @@ func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, erro return count, nil } +// FindByIDs 批量查询 Texture +func (m *MockTextureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) { + var result []*model.Texture + for _, id := range ids { + if tex, ok := m.textures[id]; ok { + result = append(result, tex) + } + } + return result, nil +} + +// BatchUpdate 仅用于满足接口 +func (m *MockTextureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) { + return 0, nil +} + +// BatchDelete 仅用于满足接口 +func (m *MockTextureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) { + var deleted int64 + for _, id := range ids { + if _, ok := m.textures[id]; ok { + delete(m.textures, id) + deleted++ + } + } + return deleted, nil +} + // MockTokenRepository 模拟TokenRepository type MockTokenRepository struct { tokens map[string]*model.Token @@ -409,7 +478,7 @@ func NewMockTokenRepository() *MockTokenRepository { } } -func (m *MockTokenRepository) Create(token *model.Token) error { +func (m *MockTokenRepository) Create(ctx context.Context, token *model.Token) error { if m.FailCreate { return errors.New("mock create error") } @@ -418,7 +487,7 @@ func (m *MockTokenRepository) Create(token *model.Token) error { return nil } -func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) { +func (m *MockTokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) { if m.FailFind { return nil, errors.New("mock find error") } @@ -428,14 +497,14 @@ func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Toke return nil, errors.New("token not found") } -func (m *MockTokenRepository) GetByUserID(userId int64) ([]*model.Token, error) { +func (m *MockTokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) { if m.FailFind { return nil, errors.New("mock find error") } return m.userTokens[userId], nil } -func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) { +func (m *MockTokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) { if m.FailFind { return "", errors.New("mock find error") } @@ -445,7 +514,7 @@ func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string, return "", errors.New("token not found") } -func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) { +func (m *MockTokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) { if m.FailFind { return 0, errors.New("mock find error") } @@ -455,7 +524,7 @@ func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64, return 0, errors.New("token not found") } -func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error { +func (m *MockTokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error { if m.FailDelete { return errors.New("mock delete error") } @@ -463,7 +532,7 @@ func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error { return nil } -func (m *MockTokenRepository) DeleteByUserID(userId int64) error { +func (m *MockTokenRepository) DeleteByUserID(ctx context.Context, userId int64) error { if m.FailDelete { return errors.New("mock delete error") } @@ -474,7 +543,7 @@ func (m *MockTokenRepository) DeleteByUserID(userId int64) error { return nil } -func (m *MockTokenRepository) BatchDelete(accessTokens []string) (int64, error) { +func (m *MockTokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) { if m.FailDelete { return 0, errors.New("mock delete error") } @@ -499,14 +568,14 @@ func NewMockSystemConfigRepository() *MockSystemConfigRepository { } } -func (m *MockSystemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) { +func (m *MockSystemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) { if config, ok := m.configs[key]; ok { return config, nil } return nil, nil } -func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) { +func (m *MockSystemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) { var result []model.SystemConfig for _, v := range m.configs { result = append(result, *v) @@ -514,7 +583,7 @@ func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) { return result, nil } -func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) { +func (m *MockSystemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) { var result []model.SystemConfig for _, v := range m.configs { result = append(result, *v) @@ -522,12 +591,12 @@ func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) { return result, nil } -func (m *MockSystemConfigRepository) Update(config *model.SystemConfig) error { +func (m *MockSystemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error { m.configs[config.Key] = config return nil } -func (m *MockSystemConfigRepository) UpdateValue(key, value string) error { +func (m *MockSystemConfigRepository) UpdateValue(ctx context.Context, key, value string) error { if config, ok := m.configs[key]; ok { config.Value = value return nil diff --git a/internal/service/profile_service.go b/internal/service/profile_service.go index 2279135..eda9a53 100644 --- a/internal/service/profile_service.go +++ b/internal/service/profile_service.go @@ -47,7 +47,7 @@ func NewProfileService( func (s *profileService) Create(ctx context.Context, userID int64, name string) (*model.Profile, error) { // 验证用户存在 - user, err := s.userRepo.FindByID(userID) + user, err := s.userRepo.FindByID(ctx, userID) if err != nil || user == nil { return nil, errors.New("用户不存在") } @@ -56,7 +56,7 @@ func (s *profileService) Create(ctx context.Context, userID int64, name string) } // 检查角色名是否已存在 - existingName, err := s.profileRepo.FindByName(name) + existingName, err := s.profileRepo.FindByName(ctx, name) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, fmt.Errorf("查询角色名失败: %w", err) } @@ -80,12 +80,12 @@ func (s *profileService) Create(ctx context.Context, userID int64, name string) IsActive: true, } - if err := s.profileRepo.Create(profile); err != nil { + if err := s.profileRepo.Create(ctx, profile); err != nil { return nil, fmt.Errorf("创建档案失败: %w", err) } // 设置活跃状态 - if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { + if err := s.profileRepo.SetActive(ctx, profileUUID, userID); err != nil { return nil, fmt.Errorf("设置活跃状态失败: %w", err) } @@ -104,7 +104,7 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro } // 缓存未命中,从数据库查询 - profile2, err := s.profileRepo.FindByUUID(uuid) + profile2, err := s.profileRepo.FindByUUID(ctx, uuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrProfileNotFound @@ -131,7 +131,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode } // 缓存未命中,从数据库查询 - profiles, err := s.profileRepo.FindByUserID(userID) + profiles, err := s.profileRepo.FindByUserID(ctx, userID) if err != nil { return nil, fmt.Errorf("查询档案列表失败: %w", err) } @@ -148,7 +148,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode func (s *profileService) Update(ctx context.Context, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) + profile, err := s.profileRepo.FindByUUID(ctx, uuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrProfileNotFound @@ -162,7 +162,7 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64, // 检查角色名是否重复 if name != nil && *name != profile.Name { - existingName, err := s.profileRepo.FindByName(*name) + existingName, err := s.profileRepo.FindByName(ctx, *name) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, fmt.Errorf("查询角色名失败: %w", err) } @@ -180,7 +180,7 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64, profile.CapeID = capeID } - if err := s.profileRepo.Update(profile); err != nil { + if err := s.profileRepo.Update(ctx, profile); err != nil { return nil, fmt.Errorf("更新档案失败: %w", err) } @@ -190,12 +190,12 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64, s.cacheKeys.ProfileList(userID), ) - return s.profileRepo.FindByUUID(uuid) + return s.profileRepo.FindByUUID(ctx, uuid) } func (s *profileService) Delete(ctx context.Context, uuid string, userID int64) error { // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) + profile, err := s.profileRepo.FindByUUID(ctx, uuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrProfileNotFound @@ -207,7 +207,7 @@ func (s *profileService) Delete(ctx context.Context, uuid string, userID int64) return ErrProfileNoPermission } - if err := s.profileRepo.Delete(uuid); err != nil { + if err := s.profileRepo.Delete(ctx, uuid); err != nil { return fmt.Errorf("删除档案失败: %w", err) } @@ -222,7 +222,7 @@ func (s *profileService) Delete(ctx context.Context, uuid string, userID int64) func (s *profileService) SetActive(ctx context.Context, uuid string, userID int64) error { // 获取档案并验证权限 - profile, err := s.profileRepo.FindByUUID(uuid) + profile, err := s.profileRepo.FindByUUID(ctx, uuid) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ErrProfileNotFound @@ -234,11 +234,11 @@ func (s *profileService) SetActive(ctx context.Context, uuid string, userID int6 return ErrProfileNoPermission } - if err := s.profileRepo.SetActive(uuid, userID); err != nil { + if err := s.profileRepo.SetActive(ctx, uuid, userID); err != nil { return fmt.Errorf("设置活跃状态失败: %w", err) } - if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { + if err := s.profileRepo.UpdateLastUsedAt(ctx, uuid); err != nil { return fmt.Errorf("更新使用时间失败: %w", err) } @@ -249,7 +249,7 @@ func (s *profileService) SetActive(ctx context.Context, uuid string, userID int6 } func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfiles int) error { - count, err := s.profileRepo.CountByUserID(userID) + count, err := s.profileRepo.CountByUserID(ctx, userID) if err != nil { return fmt.Errorf("查询档案数量失败: %w", err) } @@ -261,7 +261,7 @@ func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfil } func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) { - profiles, err := s.profileRepo.GetByNames(names) + profiles, err := s.profileRepo.GetByNames(ctx, names) if err != nil { return nil, fmt.Errorf("查找失败: %w", err) } @@ -270,7 +270,7 @@ func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*mod func (s *profileService) GetByProfileName(ctx context.Context, name string) (*model.Profile, error) { // Profile name 查询通常不会频繁缓存,但为了一致性也添加 - profile, err := s.profileRepo.FindByName(name) + profile, err := s.profileRepo.FindByName(ctx, name) if err != nil { return nil, errors.New("用户角色未创建") } diff --git a/internal/service/profile_service_test.go b/internal/service/profile_service_test.go index d199c43..960f090 100644 --- a/internal/service/profile_service_test.go +++ b/internal/service/profile_service_test.go @@ -426,7 +426,7 @@ func TestProfileServiceImpl_Create(t *testing.T) { Email: "test@example.com", Status: 1, } - userRepo.Create(testUser) + _ = userRepo.Create(context.Background(), testUser) cacheManager := NewMockCacheManager() profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger) @@ -459,7 +459,7 @@ func TestProfileServiceImpl_Create(t *testing.T) { wantErr: true, errMsg: "角色名已被使用", setupMocks: func() { - profileRepo.Create(&model.Profile{ + _ = profileRepo.Create(context.Background(), &model.Profile{ UUID: "existing-uuid", UserID: 2, Name: "ExistingProfile", @@ -516,7 +516,7 @@ func TestProfileServiceImpl_GetByUUID(t *testing.T) { UserID: 1, Name: "TestProfile", } - profileRepo.Create(testProfile) + _ = profileRepo.Create(context.Background(), testProfile) cacheManager := NewMockCacheManager() profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger) @@ -575,7 +575,7 @@ func TestProfileServiceImpl_Delete(t *testing.T) { UserID: 1, Name: "DeleteTestProfile", } - profileRepo.Create(testProfile) + _ = profileRepo.Create(context.Background(), testProfile) cacheManager := NewMockCacheManager() profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger) @@ -625,9 +625,9 @@ func TestProfileServiceImpl_GetByUserID(t *testing.T) { logger := zap.NewNop() // 为用户 1 和 2 预置不同档案 - profileRepo.Create(&model.Profile{UUID: "p1", UserID: 1, Name: "P1"}) - profileRepo.Create(&model.Profile{UUID: "p2", UserID: 1, Name: "P2"}) - profileRepo.Create(&model.Profile{UUID: "p3", UserID: 2, Name: "P3"}) + _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p1", UserID: 1, Name: "P1"}) + _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p2", UserID: 1, Name: "P2"}) + _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p3", UserID: 2, Name: "P3"}) cacheManager := NewMockCacheManager() svc := NewProfileService(profileRepo, userRepo, cacheManager, logger) @@ -653,7 +653,7 @@ func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) { UserID: 1, Name: "OldName", } - profileRepo.Create(profile) + _ = profileRepo.Create(context.Background(), profile) cacheManager := NewMockCacheManager() svc := NewProfileService(profileRepo, userRepo, cacheManager, logger) @@ -678,7 +678,7 @@ func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) { } // 名称重复 - profileRepo.Create(&model.Profile{ + _ = profileRepo.Create(context.Background(), &model.Profile{ UUID: "u2", UserID: 2, Name: "Duplicate", @@ -705,8 +705,8 @@ func TestProfileServiceImpl_CheckLimit_And_GetByNames(t *testing.T) { logger := zap.NewNop() // 为用户 1 预置 2 个档案 - profileRepo.Create(&model.Profile{UUID: "a", UserID: 1, Name: "A"}) - profileRepo.Create(&model.Profile{UUID: "b", UserID: 1, Name: "B"}) + _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "a", UserID: 1, Name: "A"}) + _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "b", UserID: 1, Name: "B"}) cacheManager := NewMockCacheManager() svc := NewProfileService(profileRepo, userRepo, cacheManager, logger) diff --git a/internal/service/signature_service.go b/internal/service/signature_service.go index b1f8134..817736b 100644 --- a/internal/service/signature_service.go +++ b/internal/service/signature_service.go @@ -32,8 +32,8 @@ const ( RedisTTL = 0 // 永不过期,由应用程序管理过期时间 ) -// signatureService 签名服务实现 -type signatureService struct { +// SignatureService 签名服务(导出以便依赖注入) +type SignatureService struct { profileRepo repository.ProfileRepository redis *redis.Client logger *zap.Logger @@ -44,8 +44,8 @@ func NewSignatureService( profileRepo repository.ProfileRepository, redisClient *redis.Client, logger *zap.Logger, -) *signatureService { - return &signatureService{ +) *SignatureService { + return &SignatureService{ profileRepo: profileRepo, redis: redisClient, logger: logger, @@ -53,7 +53,7 @@ func NewSignatureService( } // NewKeyPair 生成新的RSA密钥对 -func (s *signatureService) NewKeyPair() (*model.KeyPair, error) { +func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) { privateKey, err := rsa.GenerateKey(rand.Reader, KeySize) if err != nil { return nil, fmt.Errorf("生成RSA密钥对失败: %w", err) @@ -132,7 +132,7 @@ func (s *signatureService) NewKeyPair() (*model.KeyPair, error) { } // GetOrCreateYggdrasilKeyPair 获取或创建Yggdrasil根密钥对 -func (s *signatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKey, error) { +func (s *SignatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKey, error) { ctx := context.Background() // 尝试从Redis获取密钥 @@ -201,7 +201,7 @@ func (s *signatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKe } // GetPublicKeyFromRedis 从Redis获取公钥 -func (s *signatureService) GetPublicKeyFromRedis() (string, error) { +func (s *SignatureService) GetPublicKeyFromRedis() (string, error) { ctx := context.Background() publicKey, err := s.redis.Get(ctx, PublicKeyRedisKey) if err != nil { @@ -218,7 +218,7 @@ func (s *signatureService) GetPublicKeyFromRedis() (string, error) { } // SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串 -func (s *signatureService) SignStringWithSHA1withRSA(data string) (string, error) { +func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) { ctx := context.Background() // 从Redis获取私钥 diff --git a/internal/service/texture_service.go b/internal/service/texture_service.go index c4a3521..146fe6f 100644 --- a/internal/service/texture_service.go +++ b/internal/service/texture_service.go @@ -41,13 +41,13 @@ func NewTextureService( func (s *textureService) Create(ctx context.Context, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { // 验证用户存在 - user, err := s.userRepo.FindByID(uploaderID) + user, err := s.userRepo.FindByID(ctx, uploaderID) if err != nil || user == nil { return nil, ErrUserNotFound } // 检查Hash是否已存在 - existingTexture, err := s.textureRepo.FindByHash(hash) + existingTexture, err := s.textureRepo.FindByHash(ctx, hash) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (s *textureService) Create(ctx context.Context, uploaderID int64, name, des FavoriteCount: 0, } - if err := s.textureRepo.Create(texture); err != nil { + if err := s.textureRepo.Create(ctx, texture); err != nil { return nil, err } @@ -99,7 +99,7 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture, } // 缓存未命中,从数据库查询 - texture2, err := s.textureRepo.FindByID(id) + texture2, err := s.textureRepo.FindByID(ctx, id) if err != nil { return nil, err } @@ -132,7 +132,7 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex } // 缓存未命中,从数据库查询 - texture2, err := s.textureRepo.FindByHash(hash) + texture2, err := s.textureRepo.FindByHash(ctx, hash) if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page } // 缓存未命中,从数据库查询 - textures, total, err := s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) + textures, total, err := s.textureRepo.FindByUploaderID(ctx, uploaderID, page, pageSize) if err != nil { return nil, 0, err } @@ -184,12 +184,12 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page func (s *textureService) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) + return s.textureRepo.Search(ctx, keyword, textureType, publicOnly, page, pageSize) } func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { // 获取材质并验证权限 - texture, err := s.textureRepo.FindByID(textureID) + texture, err := s.textureRepo.FindByID(ctx, textureID) if err != nil { return nil, err } @@ -213,7 +213,7 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64 } if len(updates) > 0 { - if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { + if err := s.textureRepo.UpdateFields(ctx, textureID, updates); err != nil { return nil, err } } @@ -222,12 +222,12 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64 s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID)) s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID)) - return s.textureRepo.FindByID(textureID) + return s.textureRepo.FindByID(ctx, textureID) } func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64) error { // 获取材质并验证权限 - texture, err := s.textureRepo.FindByID(textureID) + texture, err := s.textureRepo.FindByID(ctx, textureID) if err != nil { return err } @@ -238,7 +238,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64 return ErrTextureNoPermission } - err = s.textureRepo.Delete(textureID) + err = s.textureRepo.Delete(ctx, textureID) if err != nil { return err } @@ -252,7 +252,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64 func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID int64) (bool, error) { // 确保材质存在 - texture, err := s.textureRepo.FindByID(textureID) + texture, err := s.textureRepo.FindByID(ctx, textureID) if err != nil { return false, err } @@ -260,27 +260,27 @@ func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID i return false, ErrTextureNotFound } - isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) + isFavorited, err := s.textureRepo.IsFavorited(ctx, userID, textureID) if err != nil { return false, err } if isFavorited { // 已收藏 -> 取消收藏 - if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { + if err := s.textureRepo.RemoveFavorite(ctx, userID, textureID); err != nil { return false, err } - if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { + if err := s.textureRepo.DecrementFavoriteCount(ctx, textureID); err != nil { return false, err } return false, nil } // 未收藏 -> 添加收藏 - if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { + if err := s.textureRepo.AddFavorite(ctx, userID, textureID); err != nil { return false, err } - if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { + if err := s.textureRepo.IncrementFavoriteCount(ctx, textureID); err != nil { return false, err } return true, nil @@ -288,11 +288,11 @@ func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID i func (s *textureService) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) { page, pageSize = NormalizePagination(page, pageSize) - return s.textureRepo.GetUserFavorites(userID, page, pageSize) + return s.textureRepo.GetUserFavorites(ctx, userID, page, pageSize) } func (s *textureService) CheckUploadLimit(ctx context.Context, uploaderID int64, maxTextures int) error { - count, err := s.textureRepo.CountByUploaderID(uploaderID) + count, err := s.textureRepo.CountByUploaderID(ctx, uploaderID) if err != nil { return err } diff --git a/internal/service/texture_service_test.go b/internal/service/texture_service_test.go index 43504fb..990d2c6 100644 --- a/internal/service/texture_service_test.go +++ b/internal/service/texture_service_test.go @@ -491,7 +491,7 @@ func TestTextureServiceImpl_Create(t *testing.T) { Email: "test@example.com", Status: 1, } - userRepo.Create(testUser) + _ = userRepo.Create(context.Background(), testUser) cacheManager := NewMockCacheManager() textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) @@ -539,7 +539,7 @@ func TestTextureServiceImpl_Create(t *testing.T) { wantErr: true, errContains: "已存在", setupMocks: func() { - textureRepo.Create(&model.Texture{ + _ = textureRepo.Create(context.Background(), &model.Texture{ ID: 100, UploaderID: 1, Name: "ExistingTexture", @@ -614,7 +614,7 @@ func TestTextureServiceImpl_GetByID(t *testing.T) { Name: "TestTexture", Hash: "test-hash", } - textureRepo.Create(testTexture) + _ = textureRepo.Create(context.Background(), testTexture) cacheManager := NewMockCacheManager() textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) @@ -666,7 +666,7 @@ func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) { // 预置多条 Texture for i := int64(1); i <= 5; i++ { - textureRepo.Create(&model.Texture{ + _ = textureRepo.Create(context.Background(), &model.Texture{ ID: i, UploaderID: 1, Name: "T", @@ -711,7 +711,7 @@ func TestTextureServiceImpl_Update_And_Delete(t *testing.T) { Description: "OldDesc", IsPublic: false, } - textureRepo.Create(texture) + _ = textureRepo.Create(context.Background(), texture) cacheManager := NewMockCacheManager() textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) @@ -755,12 +755,12 @@ func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) { // 预置若干 Texture 与收藏关系 for i := int64(1); i <= 3; i++ { - textureRepo.Create(&model.Texture{ + _ = textureRepo.Create(context.Background(), &model.Texture{ ID: i, UploaderID: 1, Name: "T", }) - _ = textureRepo.AddFavorite(1, i) + _ = textureRepo.AddFavorite(context.Background(), 1, i) } cacheManager := NewMockCacheManager() @@ -796,7 +796,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) { // 预置用户和Texture testUser := &model.User{ID: 1, Username: "testuser", Status: 1} - userRepo.Create(testUser) + _ = userRepo.Create(context.Background(), testUser) testTexture := &model.Texture{ ID: 1, @@ -804,7 +804,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) { Name: "TestTexture", Hash: "test-hash", } - textureRepo.Create(testTexture) + _ = textureRepo.Create(context.Background(), testTexture) cacheManager := NewMockCacheManager() textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) diff --git a/internal/service/token_service.go b/internal/service/token_service.go index 1dca6d5..840a597 100644 --- a/internal/service/token_service.go +++ b/internal/service/token_service.go @@ -46,12 +46,12 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl ) // 设置超时上下文 - _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) defer cancel() // 验证用户存在 if UUID != "" { - _, err := s.profileRepo.FindByUUID(UUID) + _, err := s.profileRepo.FindByUUID(ctx, UUID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) } @@ -72,7 +72,7 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl } // 获取用户配置文件 - profiles, err := s.profileRepo.FindByUserID(userID) + profiles, err := s.profileRepo.FindByUserID(ctx, userID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) } @@ -85,23 +85,27 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl availableProfiles = profiles // 插入令牌 - err = s.tokenRepo.Create(&token) + err = s.tokenRepo.Create(ctx, &token) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) } - // 清理多余的令牌 - go s.checkAndCleanupExcessTokens(userID) + // 清理多余的令牌(使用独立的后台上下文) + go s.checkAndCleanupExcessTokens(context.Background(), userID) return selectedProfileID, availableProfiles, accessToken, clientToken, nil } func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + if accessToken == "" { return false } - token, err := s.tokenRepo.FindByAccessToken(accessToken) + token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken) if err != nil { return false } @@ -118,12 +122,16 @@ func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken st } func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + if accessToken == "" { return "", "", errors.New("accessToken不能为空") } // 查找旧令牌 - oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) + oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", "", errors.New("accessToken无效") @@ -134,7 +142,7 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se // 验证profile if selectedProfileID != "" { - valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) + valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID) if validErr != nil { s.logger.Error("验证Profile失败", zap.Error(err), @@ -174,13 +182,13 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se } // 先插入新令牌,再删除旧令牌 - err = s.tokenRepo.Create(&newToken) + err = s.tokenRepo.Create(ctx, &newToken) if err != nil { s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return "", "", fmt.Errorf("创建新Token失败: %w", err) } - err = s.tokenRepo.DeleteByAccessToken(accessToken) + err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken) if err != nil { s.logger.Warn("删除旧Token失败,但新Token已创建", zap.Error(err), @@ -194,11 +202,15 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se } func (s *tokenService) Invalidate(ctx context.Context, accessToken string) { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + if accessToken == "" { return } - err := s.tokenRepo.DeleteByAccessToken(accessToken) + err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken) if err != nil { s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) return @@ -207,11 +219,15 @@ func (s *tokenService) Invalidate(ctx context.Context, accessToken string) { } func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + if userID == 0 { return } - err := s.tokenRepo.DeleteByUserID(userID) + err := s.tokenRepo.DeleteByUserID(ctx, userID) if err != nil { s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) return @@ -221,21 +237,33 @@ func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) { } func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) { - return s.tokenRepo.GetUUIDByAccessToken(accessToken) + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + + return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken) } func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) { - return s.tokenRepo.GetUserIDByAccessToken(accessToken) + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + + return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken) } // 私有辅助方法 -func (s *tokenService) checkAndCleanupExcessTokens(userID int64) { +func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) { if userID == 0 { return } - tokens, err := s.tokenRepo.GetByUserID(userID) + // 为清理操作设置更长的超时时间 + ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout) + defer cancel() + + tokens, err := s.tokenRepo.GetByUserID(ctx, userID) if err != nil { s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) return @@ -250,7 +278,7 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) { tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) } - deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) + deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete) if err != nil { s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) return @@ -261,12 +289,12 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) { } } -func (s *tokenService) validateProfileByUserID(userID int64, UUID string) (bool, error) { +func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) { if userID == 0 || UUID == "" { return false, errors.New("用户ID或配置文件ID不能为空") } - profile, err := s.profileRepo.FindByUUID(UUID) + profile, err := s.profileRepo.FindByUUID(ctx, UUID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return false, errors.New("配置文件不存在") diff --git a/internal/service/token_service_jwt.go b/internal/service/token_service_jwt.go index dd6014a..caabe16 100644 --- a/internal/service/token_service_jwt.go +++ b/internal/service/token_service_jwt.go @@ -55,12 +55,12 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, ) // 设置超时上下文 - _, cancel := context.WithTimeout(context.Background(), DefaultTimeout) + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) defer cancel() // 验证用户存在 if UUID != "" { - _, err := s.profileRepo.FindByUUID(UUID) + _, err := s.profileRepo.FindByUUID(ctx, UUID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) } @@ -73,7 +73,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, // 获取或创建Client var client *model.Client - existingClient, err := s.clientRepo.FindByClientToken(clientToken) + existingClient, err := s.clientRepo.FindByClientToken(ctx, clientToken) if err != nil { // Client不存在,创建新的 clientUUID := uuid.New().String() @@ -90,7 +90,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, client.ProfileID = UUID } - if err := s.clientRepo.Create(client); err != nil { + if err := s.clientRepo.Create(ctx, client); err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err) } } else { @@ -103,14 +103,14 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, client.UpdatedAt = time.Now() if UUID != "" { client.ProfileID = UUID - if err := s.clientRepo.Update(client); err != nil { + if err := s.clientRepo.Update(ctx, client); err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err) } } } // 获取用户配置文件 - profiles, err := s.profileRepo.FindByUserID(userID) + profiles, err := s.profileRepo.FindByUserID(ctx, userID) if err != nil { return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) } @@ -122,7 +122,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, if profileID == "" { profileID = selectedProfileID.UUID client.ProfileID = profileID - s.clientRepo.Update(client) + _ = s.clientRepo.Update(ctx, client) } } availableProfiles = profiles @@ -170,20 +170,23 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string, StaleAt: &staleAt, } - err = s.tokenRepo.Create(&token) + err = s.tokenRepo.Create(ctx, &token) if err != nil { s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err)) // 不返回错误,因为JWT本身已经生成成功 } - // 清理多余的令牌 - go s.checkAndCleanupExcessTokens(userID) + // 清理多余的令牌(使用独立的后台上下文) + go s.checkAndCleanupExcessTokens(context.Background(), userID) return selectedProfileID, availableProfiles, accessToken, clientToken, nil } // Validate 验证Token(使用JWT验证) func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() if accessToken == "" { return false } @@ -195,7 +198,7 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken } // 查找Client - client, err := s.clientRepo.FindByUUID(claims.Subject) + client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) if err != nil { return false } @@ -215,6 +218,9 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken // Refresh 刷新Token(使用Version机制,无需删除旧Token) func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() if accessToken == "" { return "", "", errors.New("accessToken不能为空") } @@ -226,7 +232,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, } // 查找Client - client, err := s.clientRepo.FindByUUID(claims.Subject) + client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) if err != nil { return "", "", errors.New("无法找到对应的Client") } @@ -243,7 +249,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, // 验证Profile if selectedProfileID != "" { - valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID) + valid, validErr := s.validateProfileByUserID(ctx, client.UserID, selectedProfileID) if validErr != nil { s.logger.Error("验证Profile失败", zap.Error(validErr), @@ -269,7 +275,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, // 增加Version(这是关键:通过Version失效所有旧Token) client.Version++ client.UpdatedAt = time.Now() - if err := s.clientRepo.Update(client); err != nil { + if err := s.clientRepo.Update(ctx, client); err != nil { return "", "", fmt.Errorf("更新Client版本失败: %w", err) } @@ -315,7 +321,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, StaleAt: &staleAt, } - err = s.tokenRepo.Create(&newToken) + err = s.tokenRepo.Create(ctx, &newToken) if err != nil { s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err)) } @@ -326,6 +332,10 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, // Invalidate 使Token失效(通过增加Version) func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + if accessToken == "" { return } @@ -338,7 +348,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { } // 查找Client并增加Version - client, err := s.clientRepo.FindByUUID(claims.Subject) + client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) if err != nil { s.logger.Warn("无法找到对应的Client", zap.Error(err)) return @@ -347,7 +357,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { // 增加Version以失效所有旧Token client.Version++ client.UpdatedAt = time.Now() - if err := s.clientRepo.Update(client); err != nil { + if err := s.clientRepo.Update(ctx, client); err != nil { s.logger.Error("失效Token失败", zap.Error(err)) return } @@ -357,12 +367,16 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { // InvalidateUserTokens 使用户所有Token失效 func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) { + // 设置超时上下文 + ctx, cancel := context.WithTimeout(ctx, DefaultTimeout) + defer cancel() + if userID == 0 { return } // 获取用户所有Client - clients, err := s.clientRepo.FindByUserID(userID) + clients, err := s.clientRepo.FindByUserID(ctx, userID) if err != nil { s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID)) return @@ -372,7 +386,7 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64 for _, client := range clients { client.Version++ client.UpdatedAt = time.Now() - if err := s.clientRepo.Update(client); err != nil { + if err := s.clientRepo.Update(ctx, client); err != nil { s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID)) } } @@ -385,7 +399,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) if err != nil { // 如果JWT解析失败,尝试从数据库查询(向后兼容) - return s.tokenRepo.GetUUIDByAccessToken(accessToken) + return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken) } if claims.ProfileID != "" { @@ -393,7 +407,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken } // 如果没有ProfileID,从Client获取 - client, err := s.clientRepo.FindByUUID(claims.Subject) + client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) if err != nil { return "", fmt.Errorf("无法找到对应的Client: %w", err) } @@ -410,11 +424,11 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) if err != nil { // 如果JWT解析失败,尝试从数据库查询(向后兼容) - return s.tokenRepo.GetUserIDByAccessToken(accessToken) + return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken) } // 从Client获取UserID - client, err := s.clientRepo.FindByUUID(claims.Subject) + client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) if err != nil { return 0, fmt.Errorf("无法找到对应的Client: %w", err) } @@ -429,12 +443,16 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke // 私有辅助方法 -func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) { +func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) { if userID == 0 { return } - tokens, err := s.tokenRepo.GetByUserID(userID) + // 为清理操作设置更长的超时时间 + ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout) + defer cancel() + + tokens, err := s.tokenRepo.GetByUserID(ctx, userID) if err != nil { s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) return @@ -449,7 +467,7 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) { tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) } - deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) + deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete) if err != nil { s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) return @@ -460,12 +478,12 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) { } } -func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) { +func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) { if userID == 0 || UUID == "" { return false, errors.New("用户ID或配置文件ID不能为空") } - profile, err := s.profileRepo.FindByUUID(UUID) + profile, err := s.profileRepo.FindByUUID(ctx, UUID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return false, errors.New("配置文件不存在") @@ -482,7 +500,7 @@ func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken st return nil, err } - client, err := s.clientRepo.FindByUUID(claims.Subject) + client, err := s.clientRepo.FindByUUID(ctx, claims.Subject) if err != nil { return nil, err } diff --git a/internal/service/token_service_test.go b/internal/service/token_service_test.go index 826e281..c3c6e98 100644 --- a/internal/service/token_service_test.go +++ b/internal/service/token_service_test.go @@ -208,7 +208,7 @@ func TestTokenServiceImpl_Create(t *testing.T) { Name: "TestProfile", IsActive: true, } - profileRepo.Create(testProfile) + _ = profileRepo.Create(context.Background(), testProfile) tokenService := NewTokenService(tokenRepo, profileRepo, logger) @@ -274,7 +274,7 @@ func TestTokenServiceImpl_Validate(t *testing.T) { ProfileId: "test-profile-uuid", Usable: true, } - tokenRepo.Create(testToken) + _ = tokenRepo.Create(context.Background(), testToken) tokenService := NewTokenService(tokenRepo, profileRepo, logger) @@ -336,7 +336,7 @@ func TestTokenServiceImpl_Invalidate(t *testing.T) { ProfileId: "test-profile-uuid", Usable: true, } - tokenRepo.Create(testToken) + _ = tokenRepo.Create(context.Background(), testToken) tokenService := NewTokenService(tokenRepo, profileRepo, logger) @@ -352,7 +352,7 @@ func TestTokenServiceImpl_Invalidate(t *testing.T) { tokenService.Invalidate(ctx, "token-to-invalidate") // 验证Token已失效(从repo中删除) - _, err := tokenRepo.FindByAccessToken("token-to-invalidate") + _, err := tokenRepo.FindByAccessToken(context.Background(), "token-to-invalidate") if err == nil { t.Error("Token应该已被删除") } @@ -366,7 +366,7 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) { // 预置多个Token for i := 1; i <= 3; i++ { - tokenRepo.Create(&model.Token{ + _ = tokenRepo.Create(context.Background(), &model.Token{ AccessToken: fmt.Sprintf("user1-token-%d", i), ClientToken: "client-token", UserID: 1, @@ -374,7 +374,7 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) { Usable: true, }) } - tokenRepo.Create(&model.Token{ + _ = tokenRepo.Create(context.Background(), &model.Token{ AccessToken: "user2-token-1", ClientToken: "client-token", UserID: 2, @@ -390,13 +390,13 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) { tokenService.InvalidateUserTokens(ctx, 1) // 验证用户1的Token已失效 - tokens, _ := tokenRepo.GetByUserID(1) + tokens, _ := tokenRepo.GetByUserID(context.Background(), 1) if len(tokens) > 0 { t.Errorf("用户1的Token应该全部被删除,但还剩 %d 个", len(tokens)) } // 验证用户2的Token仍然存在 - tokens2, _ := tokenRepo.GetByUserID(2) + tokens2, _ := tokenRepo.GetByUserID(context.Background(), 2) if len(tokens2) != 1 { t.Errorf("用户2的Token应该仍然存在,期望1个,实际 %d 个", len(tokens2)) } @@ -413,7 +413,7 @@ func TestTokenServiceImpl_Refresh(t *testing.T) { UUID: "profile-uuid", UserID: 1, } - profileRepo.Create(profile) + _ = profileRepo.Create(context.Background(), profile) oldToken := &model.Token{ AccessToken: "old-token", @@ -422,7 +422,7 @@ func TestTokenServiceImpl_Refresh(t *testing.T) { ProfileId: "", Usable: true, } - tokenRepo.Create(oldToken) + _ = tokenRepo.Create(context.Background(), oldToken) tokenService := NewTokenService(tokenRepo, profileRepo, logger) @@ -455,7 +455,7 @@ func TestTokenServiceImpl_GetByAccessToken(t *testing.T) { ProfileId: "profile-42", Usable: true, } - tokenRepo.Create(token) + _ = tokenRepo.Create(context.Background(), token) tokenService := NewTokenService(tokenRepo, profileRepo, logger) @@ -489,25 +489,25 @@ func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) { UUID: "p-1", UserID: 1, } - profileRepo.Create(profile) + _ = profileRepo.Create(context.Background(), profile) // 参数非法 - if ok, err := svc.validateProfileByUserID(0, ""); err == nil || ok { + if ok, err := svc.validateProfileByUserID(context.Background(), 0, ""); err == nil || ok { t.Fatalf("validateProfileByUserID 在参数非法时应返回错误") } // Profile 不存在 - if ok, err := svc.validateProfileByUserID(1, "not-exists"); err == nil || ok { + if ok, err := svc.validateProfileByUserID(context.Background(), 1, "not-exists"); err == nil || ok { t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误") } // 用户与 Profile 匹配 - if ok, err := svc.validateProfileByUserID(1, "p-1"); err != nil || !ok { + if ok, err := svc.validateProfileByUserID(context.Background(), 1, "p-1"); err != nil || !ok { t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err) } // 用户与 Profile 不匹配 - if ok, err := svc.validateProfileByUserID(2, "p-1"); err != nil || ok { + if ok, err := svc.validateProfileByUserID(context.Background(), 2, "p-1"); err != nil || ok { t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err) } } diff --git a/internal/service/user_service.go b/internal/service/user_service.go index 599a46e..4a556b8 100644 --- a/internal/service/user_service.go +++ b/internal/service/user_service.go @@ -1,12 +1,6 @@ package service import ( - "carrotskin/internal/model" - "carrotskin/internal/repository" - "carrotskin/pkg/auth" - "carrotskin/pkg/config" - "carrotskin/pkg/database" - "carrotskin/pkg/redis" "context" "errors" "fmt" @@ -14,6 +8,14 @@ import ( "strings" "time" + apperrors "carrotskin/internal/errors" + "carrotskin/internal/model" + "carrotskin/internal/repository" + "carrotskin/pkg/auth" + "carrotskin/pkg/config" + "carrotskin/pkg/database" + "carrotskin/pkg/redis" + "go.uber.org/zap" ) @@ -54,21 +56,21 @@ func NewUserService( func (s *userService) Register(ctx context.Context, username, password, email, avatar string) (*model.User, string, error) { // 检查用户名是否已存在 - existingUser, err := s.userRepo.FindByUsername(username) + existingUser, err := s.userRepo.FindByUsername(ctx, username) if err != nil { return nil, "", err } if existingUser != nil { - return nil, "", errors.New("用户名已存在") + return nil, "", apperrors.ErrUserAlreadyExists } // 检查邮箱是否已存在 - existingEmail, err := s.userRepo.FindByEmail(email) + existingEmail, err := s.userRepo.FindByEmail(ctx, email) if err != nil { return nil, "", err } if existingEmail != nil { - return nil, "", errors.New("邮箱已被注册") + return nil, "", apperrors.ErrEmailAlreadyExists } // 加密密码 @@ -98,7 +100,7 @@ func (s *userService) Register(ctx context.Context, username, password, email, a Points: 0, } - if err := s.userRepo.Create(user); err != nil { + if err := s.userRepo.Create(ctx, user); err != nil { return nil, "", err } @@ -126,9 +128,9 @@ func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAd var err error if strings.Contains(usernameOrEmail, "@") { - user, err = s.userRepo.FindByEmail(usernameOrEmail) + user, err = s.userRepo.FindByEmail(ctx, usernameOrEmail) } else { - user, err = s.userRepo.FindByUsername(usernameOrEmail) + user, err = s.userRepo.FindByUsername(ctx, usernameOrEmail) } if err != nil { @@ -166,12 +168,12 @@ func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAd // 更新最后登录时间 now := time.Now() user.LastLoginAt = &now - _ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ + _ = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{ "last_login_at": now, }) // 记录成功登录日志 - s.logSuccessLogin(user.ID, ipAddress, userAgent) + s.logSuccessLogin(ctx, user.ID, ipAddress, userAgent) return user, token, nil } @@ -180,7 +182,7 @@ func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error // 使用 Cached 装饰器自动处理缓存 cacheKey := s.cacheKeys.User(id) return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) { - return s.userRepo.FindByID(id) + return s.userRepo.FindByID(ctx, id) }, 5*time.Minute) } @@ -188,12 +190,12 @@ func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User // 使用 Cached 装饰器自动处理缓存 cacheKey := s.cacheKeys.UserByEmail(email) return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) { - return s.userRepo.FindByEmail(email) + return s.userRepo.FindByEmail(ctx, email) }, 5*time.Minute) } func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error { - err := s.userRepo.Update(user) + err := s.userRepo.Update(ctx, user) if err != nil { return err } @@ -209,7 +211,7 @@ func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error { } func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL string) error { - err := s.userRepo.UpdateFields(userID, map[string]interface{}{ + err := s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{ "avatar": avatarURL, }) if err != nil { @@ -223,7 +225,7 @@ func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL } func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error { - user, err := s.userRepo.FindByID(userID) + user, err := s.userRepo.FindByID(ctx, userID) if err != nil || user == nil { return errors.New("用户不存在") } @@ -237,7 +239,7 @@ func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassw return errors.New("密码加密失败") } - err = s.userRepo.UpdateFields(userID, map[string]interface{}{ + err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{ "password": hashedPassword, }) if err != nil { @@ -251,7 +253,7 @@ func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassw } func (s *userService) ResetPassword(ctx context.Context, email, newPassword string) error { - user, err := s.userRepo.FindByEmail(email) + user, err := s.userRepo.FindByEmail(ctx, email) if err != nil || user == nil { return errors.New("用户不存在") } @@ -261,7 +263,7 @@ func (s *userService) ResetPassword(ctx context.Context, email, newPassword stri return errors.New("密码加密失败") } - err = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ + err = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{ "password": hashedPassword, }) if err != nil { @@ -279,17 +281,17 @@ func (s *userService) ResetPassword(ctx context.Context, email, newPassword stri func (s *userService) ChangeEmail(ctx context.Context, userID int64, newEmail string) error { // 获取旧邮箱 - oldUser, _ := s.userRepo.FindByID(userID) + oldUser, _ := s.userRepo.FindByID(ctx, userID) - existingUser, err := s.userRepo.FindByEmail(newEmail) + existingUser, err := s.userRepo.FindByEmail(ctx, newEmail) if err != nil { return err } if existingUser != nil && existingUser.ID != userID { - return errors.New("邮箱已被其他用户使用") + return apperrors.ErrEmailAlreadyExists } - err = s.userRepo.UpdateFields(userID, map[string]interface{}{ + err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{ "email": newEmail, }) if err != nil { @@ -346,7 +348,7 @@ func (s *userService) ValidateAvatarURL(ctx context.Context, avatarURL string) e } func (s *userService) GetMaxProfilesPerUser() int { - config, err := s.configRepo.GetByKey("max_profiles_per_user") + config, err := s.configRepo.GetByKey(context.Background(), "max_profiles_per_user") if err != nil || config == nil { return 5 } @@ -359,7 +361,7 @@ func (s *userService) GetMaxProfilesPerUser() int { } func (s *userService) GetMaxTexturesPerUser() int { - config, err := s.configRepo.GetByKey("max_textures_per_user") + config, err := s.configRepo.GetByKey(context.Background(), "max_textures_per_user") if err != nil || config == nil { return 50 } @@ -374,7 +376,7 @@ func (s *userService) GetMaxTexturesPerUser() int { // 私有辅助方法 func (s *userService) getDefaultAvatar() string { - config, err := s.configRepo.GetByKey("default_avatar") + config, err := s.configRepo.GetByKey(context.Background(), "default_avatar") if err != nil || config == nil || config.Value == "" { return "" } @@ -410,14 +412,14 @@ func (s *userService) recordLoginFailure(ctx context.Context, usernameOrEmail, i identifier := usernameOrEmail + ":" + ipAddress count, _ := RecordLoginFailure(ctx, s.redis, identifier) if count >= MaxLoginAttempts { - s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") + s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason+"-账号已锁定") return } } - s.logFailedLogin(userID, ipAddress, userAgent, reason) + s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason) } -func (s *userService) logSuccessLogin(userID int64, ipAddress, userAgent string) { +func (s *userService) logSuccessLogin(ctx context.Context, userID int64, ipAddress, userAgent string) { log := &model.UserLoginLog{ UserID: userID, IPAddress: ipAddress, @@ -425,10 +427,10 @@ func (s *userService) logSuccessLogin(userID int64, ipAddress, userAgent string) LoginMethod: "PASSWORD", IsSuccess: true, } - _ = s.userRepo.CreateLoginLog(log) + _ = s.userRepo.CreateLoginLog(ctx, log) } -func (s *userService) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { +func (s *userService) logFailedLogin(ctx context.Context, userID int64, ipAddress, userAgent, reason string) { log := &model.UserLoginLog{ UserID: userID, IPAddress: ipAddress, @@ -437,5 +439,5 @@ func (s *userService) logFailedLogin(userID int64, ipAddress, userAgent, reason IsSuccess: false, FailureReason: reason, } - _ = s.userRepo.CreateLoginLog(log) + _ = s.userRepo.CreateLoginLog(ctx, log) } diff --git a/internal/service/user_service_test.go b/internal/service/user_service_test.go index 91ff893..5ca4abf 100644 --- a/internal/service/user_service_test.go +++ b/internal/service/user_service_test.go @@ -49,9 +49,10 @@ func TestUserServiceImpl_Register(t *testing.T) { email: "new@example.com", avatar: "", wantErr: true, - errMsg: "用户名已存在", + // 服务实现现已统一使用 apperrors.ErrUserAlreadyExists,错误信息为“用户已存在” + errMsg: "用户已存在", setupMocks: func() { - userRepo.Create(&model.User{ + _ = userRepo.Create(context.Background(), &model.User{ Username: "existinguser", Email: "old@example.com", }) @@ -66,7 +67,7 @@ func TestUserServiceImpl_Register(t *testing.T) { wantErr: true, errMsg: "邮箱已被注册", setupMocks: func() { - userRepo.Create(&model.User{ + _ = userRepo.Create(context.Background(), &model.User{ Username: "otheruser", Email: "existing@example.com", }) @@ -126,7 +127,7 @@ func TestUserServiceImpl_Login(t *testing.T) { Password: hashedPassword, Status: 1, } - userRepo.Create(testUser) + _ = userRepo.Create(context.Background(), testUser) cacheManager := NewMockCacheManager() userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) @@ -207,7 +208,7 @@ func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) { Email: "basic@example.com", Avatar: "", } - userRepo.Create(user) + _ = userRepo.Create(context.Background(), user) cacheManager := NewMockCacheManager() userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) @@ -231,7 +232,7 @@ func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) { if err := userService.UpdateInfo(ctx, user); err != nil { t.Fatalf("UpdateInfo 失败: %v", err) } - updated, _ := userRepo.FindByID(1) + updated, _ := userRepo.FindByID(context.Background(), 1) if updated.Username != "updated" { t.Fatalf("UpdateInfo 未更新用户名, got=%s", updated.Username) } @@ -255,7 +256,7 @@ func TestUserServiceImpl_ChangePassword(t *testing.T) { Username: "changepw", Password: hashed, } - userRepo.Create(user) + _ = userRepo.Create(context.Background(), user) cacheManager := NewMockCacheManager() userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) @@ -290,7 +291,7 @@ func TestUserServiceImpl_ResetPassword(t *testing.T) { Username: "resetpw", Email: "reset@example.com", } - userRepo.Create(user) + _ = userRepo.Create(context.Background(), user) cacheManager := NewMockCacheManager() userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) @@ -317,8 +318,8 @@ func TestUserServiceImpl_ChangeEmail(t *testing.T) { user1 := &model.User{ID: 1, Email: "user1@example.com"} user2 := &model.User{ID: 2, Email: "user2@example.com"} - userRepo.Create(user1) - userRepo.Create(user2) + _ = userRepo.Create(context.Background(), user1) + _ = userRepo.Create(context.Background(), user2) cacheManager := NewMockCacheManager() userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) @@ -389,8 +390,8 @@ func TestUserServiceImpl_MaxLimits(t *testing.T) { } // 配置有效值 - configRepo.Update(&model.SystemConfig{Key: "max_profiles_per_user", Value: "10"}) - configRepo.Update(&model.SystemConfig{Key: "max_textures_per_user", Value: "100"}) + _ = configRepo.Update(context.Background(), &model.SystemConfig{Key: "max_profiles_per_user", Value: "10"}) + _ = configRepo.Update(context.Background(), &model.SystemConfig{Key: "max_textures_per_user", Value: "100"}) if got := userService.GetMaxProfilesPerUser(); got != 10 { t.Fatalf("GetMaxProfilesPerUser 配置值错误, got=%d", got) diff --git a/internal/service/yggdrasil_auth_service.go b/internal/service/yggdrasil_auth_service.go index 6cba18a..fed0593 100644 --- a/internal/service/yggdrasil_auth_service.go +++ b/internal/service/yggdrasil_auth_service.go @@ -38,7 +38,7 @@ func NewYggdrasilAuthService( } func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email string) (int64, error) { - user, err := s.userRepo.FindByEmail(email) + user, err := s.userRepo.FindByEmail(ctx, email) if err != nil { return 0, apperrors.ErrUserNotFound } @@ -46,7 +46,7 @@ func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email strin } func (s *yggdrasilAuthService) VerifyPassword(ctx context.Context, password string, userID int64) error { - passwordStore, err := s.yggdrasilRepo.GetPasswordByID(userID) + passwordStore, err := s.yggdrasilRepo.GetPasswordByID(ctx, userID) if err != nil { return apperrors.ErrPasswordNotSet } @@ -68,7 +68,7 @@ func (s *yggdrasilAuthService) ResetYggdrasilPassword(ctx context.Context, userI } // 检查Yggdrasil记录是否存在 - _, err = s.yggdrasilRepo.GetPasswordByID(userID) + _, err = s.yggdrasilRepo.GetPasswordByID(ctx, userID) if err != nil { // 如果不存在,创建新记录 yggdrasil := model.Yggdrasil{ @@ -82,7 +82,7 @@ func (s *yggdrasilAuthService) ResetYggdrasilPassword(ctx context.Context, userI } // 如果存在,更新密码(存储加密后的密码) - if err := s.yggdrasilRepo.ResetPassword(userID, hashedPassword); err != nil { + if err := s.yggdrasilRepo.ResetPassword(ctx, userID, hashedPassword); err != nil { return "", fmt.Errorf("重置Yggdrasil密码失败: %w", err) } diff --git a/internal/service/yggdrasil_certificate_service.go b/internal/service/yggdrasil_certificate_service.go index eacb54b..4605368 100644 --- a/internal/service/yggdrasil_certificate_service.go +++ b/internal/service/yggdrasil_certificate_service.go @@ -21,14 +21,14 @@ type CertificateService interface { // yggdrasilCertificateService 证书服务实现 type yggdrasilCertificateService struct { profileRepo repository.ProfileRepository - signatureService *signatureService + signatureService *SignatureService logger *zap.Logger } // NewCertificateService 创建证书服务实例 func NewCertificateService( profileRepo repository.ProfileRepository, - signatureService *signatureService, + signatureService *SignatureService, logger *zap.Logger, ) CertificateService { return &yggdrasilCertificateService{ @@ -49,7 +49,7 @@ func (s *yggdrasilCertificateService) GeneratePlayerCertificate(ctx context.Cont ) // 获取密钥对 - keyPair, err := s.profileRepo.GetKeyPair(uuid) + keyPair, err := s.profileRepo.GetKeyPair(ctx, uuid) if err != nil { s.logger.Info("获取用户密钥对失败,将创建新密钥对", zap.Error(err), @@ -74,7 +74,7 @@ func (s *yggdrasilCertificateService) GeneratePlayerCertificate(ctx context.Cont } // 保存密钥对到数据库 - err = s.profileRepo.UpdateKeyPair(uuid, keyPair) + err = s.profileRepo.UpdateKeyPair(ctx, uuid, keyPair) if err != nil { s.logger.Warn("更新用户密钥对失败", zap.Error(err), diff --git a/internal/service/yggdrasil_serialization_service.go b/internal/service/yggdrasil_serialization_service.go index 34d0b71..7d403ed 100644 --- a/internal/service/yggdrasil_serialization_service.go +++ b/internal/service/yggdrasil_serialization_service.go @@ -28,14 +28,14 @@ type Property struct { // yggdrasilSerializationService 序列化服务实现 type yggdrasilSerializationService struct { textureRepo repository.TextureRepository - signatureService *signatureService + signatureService *SignatureService logger *zap.Logger } // NewSerializationService 创建序列化服务实例 func NewSerializationService( textureRepo repository.TextureRepository, - signatureService *signatureService, + signatureService *SignatureService, logger *zap.Logger, ) SerializationService { return &yggdrasilSerializationService{ @@ -58,7 +58,7 @@ func (s *yggdrasilSerializationService) SerializeProfile(ctx context.Context, pr // 处理皮肤 if profile.SkinID != nil { - skin, err := s.textureRepo.FindByID(*profile.SkinID) + skin, err := s.textureRepo.FindByID(ctx, *profile.SkinID) if err != nil { s.logger.Error("获取皮肤失败", zap.Error(err), @@ -74,7 +74,7 @@ func (s *yggdrasilSerializationService) SerializeProfile(ctx context.Context, pr // 处理披风 if profile.CapeID != nil { - cape, err := s.textureRepo.FindByID(*profile.CapeID) + cape, err := s.textureRepo.FindByID(ctx, *profile.CapeID) if err != nil { s.logger.Error("获取披风失败", zap.Error(err), diff --git a/internal/service/yggdrasil_service_composite.go b/internal/service/yggdrasil_service_composite.go index 498f508..4bf87b0 100644 --- a/internal/service/yggdrasil_service_composite.go +++ b/internal/service/yggdrasil_service_composite.go @@ -33,7 +33,7 @@ func NewYggdrasilServiceComposite( profileRepo repository.ProfileRepository, tokenRepo repository.TokenRepository, yggdrasilRepo repository.YggdrasilRepository, - signatureService *signatureService, + signatureService *SignatureService, redisClient *redis.Client, logger *zap.Logger, ) YggdrasilService { @@ -76,7 +76,7 @@ func (s *yggdrasilServiceComposite) ResetYggdrasilPassword(ctx context.Context, // JoinServer 加入服务器 func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error { // 验证Token - token, err := s.tokenRepo.FindByAccessToken(accessToken) + token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken) if err != nil { s.logger.Error("验证Token失败", zap.Error(err), @@ -92,7 +92,7 @@ func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, ac } // 获取Profile以获取用户名 - profile, err := s.profileRepo.FindByUUID(formattedProfile) + profile, err := s.profileRepo.FindByUUID(ctx, formattedProfile) if err != nil { s.logger.Error("获取Profile失败", zap.Error(err),