refactor: Update service and repository methods to use context

- Refactored multiple service and repository methods to accept context as a parameter, enhancing consistency and enabling better control over request lifecycles.
- Updated handlers to utilize context in method calls, improving error handling and performance.
- Cleaned up Dockerfile by removing unnecessary whitespace.
This commit is contained in:
lan
2025-12-03 15:27:12 +08:00
parent 4824a997dd
commit 0bcd9336c4
32 changed files with 833 additions and 497 deletions

View File

@@ -65,3 +65,5 @@ ENTRYPOINT ["./server"]

View File

@@ -44,6 +44,7 @@ type Container struct {
UploadService service.UploadService UploadService service.UploadService
SecurityService service.SecurityService SecurityService service.SecurityService
CaptchaService service.CaptchaService CaptchaService service.CaptchaService
SignatureService *service.SignatureService
} }
// NewContainer 创建依赖容器 // NewContainer 创建依赖容器
@@ -80,26 +81,27 @@ func NewContainer(
c.ConfigRepo = repository.NewSystemConfigRepository(db) c.ConfigRepo = repository.NewSystemConfigRepository(db)
c.YggdrasilRepo = repository.NewYggdrasilRepository(db) c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
// 初始化SignatureService用于获取Yggdrasil私钥 // 初始化SignatureService作为依赖注入,避免在容器中创建并立即调用
signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger) // 将SignatureService添加到容器中供其他服务使用
c.SignatureService = service.NewSignatureService(c.ProfileRepo, redisClient, logger)
// 获取Yggdrasil私钥并创建JWT服务
_, privateKey, err := signatureService.GetOrCreateYggdrasilKeyPair()
if err != nil {
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
}
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
// 初始化Service注入缓存管理器 // 初始化Service注入缓存管理器
c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger) c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger)
c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger) c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger)
c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger) c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger)
// 使用JWT版本的TokenService // 获取Yggdrasil私钥并创建JWT服务TokenService需要)
// 注意这里仍然需要预先初始化因为TokenService在创建时需要YggdrasilJWT
// 但SignatureService已经作为依赖注入降低了耦合度
_, privateKey, err := c.SignatureService.GetOrCreateYggdrasilKeyPair()
if err != nil {
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
}
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger) c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
// 使用组合服务(内部包含认证、会话、序列化、证书服务) // 使用组合服务(内部包含认证、会话、序列化、证书服务)
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, signatureService, redisClient, logger) c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger)
// 初始化其他服务 // 初始化其他服务
c.SecurityService = service.NewSecurityService(redisClient) c.SecurityService = service.NewSecurityService(redisClient)

View File

@@ -219,7 +219,7 @@ func (h *CustomSkinHandler) GetTexture(c *gin.Context) {
// 增加下载计数(异步) // 增加下载计数(异步)
go func() { go func() {
_ = h.container.TextureRepo.IncrementDownloadCount(texture.ID) _ = h.container.TextureRepo.IncrementDownloadCount(ctx, texture.ID)
}() }()
// 流式传输文件内容 // 流式传输文件内容

View File

@@ -1,6 +1,7 @@
package handler package handler
import ( import (
"carrotskin/internal/errors"
"carrotskin/internal/model" "carrotskin/internal/model"
"carrotskin/internal/types" "carrotskin/internal/types"
"net/http" "net/http"
@@ -165,17 +166,46 @@ func RespondSuccess(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, model.NewSuccessResponse(data)) c.JSON(http.StatusOK, model.NewSuccessResponse(data))
} }
// RespondWithError 根据错误消息自动选择状态码 // RespondWithError 根据错误类型自动选择状态码
func RespondWithError(c *gin.Context, err error) { func RespondWithError(c *gin.Context, err error) {
msg := err.Error() if err == nil {
switch msg { return
case "档案不存在", "用户不存在", "材质不存在":
RespondNotFound(c, msg)
case "无权操作此档案", "无权操作此材质":
RespondForbidden(c, msg)
case "未授权":
RespondUnauthorized(c, msg)
default:
RespondServerError(c, msg, nil)
} }
// 使用errors.Is检查预定义错误
if errors.Is(err, errors.ErrUserNotFound) ||
errors.Is(err, errors.ErrProfileNotFound) ||
errors.Is(err, errors.ErrTextureNotFound) ||
errors.Is(err, errors.ErrNotFound) {
RespondNotFound(c, err.Error())
return
}
if errors.Is(err, errors.ErrProfileNoPermission) ||
errors.Is(err, errors.ErrTextureNoPermission) ||
errors.Is(err, errors.ErrForbidden) {
RespondForbidden(c, err.Error())
return
}
if errors.Is(err, errors.ErrUnauthorized) ||
errors.Is(err, errors.ErrInvalidToken) ||
errors.Is(err, errors.ErrTokenExpired) {
RespondUnauthorized(c, err.Error())
return
}
// 检查AppError类型
var appErr *errors.AppError
if errors.As(err, &appErr) {
c.JSON(appErr.Code, model.NewErrorResponse(
appErr.Code,
appErr.Message,
appErr.Err,
))
return
}
// 默认返回500错误
RespondServerError(c, err.Error(), err)
} }

View File

@@ -4,6 +4,7 @@ import (
"carrotskin/internal/container" "carrotskin/internal/container"
"carrotskin/internal/middleware" "carrotskin/internal/middleware"
"carrotskin/internal/model" "carrotskin/internal/model"
"carrotskin/pkg/auth"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -47,13 +48,13 @@ func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) {
registerAuthRoutes(v1, h.Auth) registerAuthRoutes(v1, h.Auth)
// 用户路由需要JWT认证 // 用户路由需要JWT认证
registerUserRoutes(v1, h.User) registerUserRoutes(v1, h.User, c.JWT)
// 材质路由 // 材质路由
registerTextureRoutes(v1, h.Texture) registerTextureRoutes(v1, h.Texture, c.JWT)
// 档案路由 // 档案路由
registerProfileRoutesWithDI(v1, h.Profile) registerProfileRoutesWithDI(v1, h.Profile, c.JWT)
// 验证码路由 // 验证码路由
registerCaptchaRoutesWithDI(v1, h.Captcha) registerCaptchaRoutesWithDI(v1, h.Captcha)
@@ -81,9 +82,9 @@ func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) {
} }
// registerUserRoutes 注册用户路由 // registerUserRoutes 注册用户路由
func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) { func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler, jwtService *auth.JWTService) {
userGroup := v1.Group("/user") userGroup := v1.Group("/user")
userGroup.Use(middleware.AuthMiddleware()) userGroup.Use(middleware.AuthMiddleware(jwtService))
{ {
userGroup.GET("/profile", h.GetProfile) userGroup.GET("/profile", h.GetProfile)
userGroup.PUT("/profile", h.UpdateProfile) userGroup.PUT("/profile", h.UpdateProfile)
@@ -101,7 +102,7 @@ func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) {
} }
// registerTextureRoutes 注册材质路由 // registerTextureRoutes 注册材质路由
func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) { func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler, jwtService *auth.JWTService) {
textureGroup := v1.Group("/texture") textureGroup := v1.Group("/texture")
{ {
// 公开路由(无需认证) // 公开路由(无需认证)
@@ -110,7 +111,7 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) {
// 需要认证的路由 // 需要认证的路由
textureAuth := textureGroup.Group("") textureAuth := textureGroup.Group("")
textureAuth.Use(middleware.AuthMiddleware()) textureAuth.Use(middleware.AuthMiddleware(jwtService))
{ {
textureAuth.POST("/upload-url", h.GenerateUploadURL) textureAuth.POST("/upload-url", h.GenerateUploadURL)
textureAuth.POST("", h.Create) textureAuth.POST("", h.Create)
@@ -124,7 +125,7 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) {
} }
// registerProfileRoutesWithDI 注册档案路由(依赖注入版本) // registerProfileRoutesWithDI 注册档案路由(依赖注入版本)
func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) { func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler, jwtService *auth.JWTService) {
profileGroup := v1.Group("/profile") profileGroup := v1.Group("/profile")
{ {
// 公开路由(无需认证) // 公开路由(无需认证)
@@ -132,7 +133,7 @@ func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) {
// 需要认证的路由 // 需要认证的路由
profileAuth := profileGroup.Group("") profileAuth := profileGroup.Group("")
profileAuth.Use(middleware.AuthMiddleware()) profileAuth.Use(middleware.AuthMiddleware(jwtService))
{ {
profileAuth.POST("/", h.Create) profileAuth.POST("/", h.Create)
profileAuth.GET("/", h.List) profileAuth.GET("/", h.List)

View File

@@ -1,15 +1,95 @@
package handler package handler
import ( import (
"context"
"errors"
"net/http" "net/http"
"time"
"carrotskin/pkg/database"
"carrotskin/pkg/redis"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// HealthCheck 健康检查 // HealthCheck 健康检查,检查依赖服务状态
func HealthCheck(c *gin.Context) { func HealthCheck(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
"status": "ok", defer cancel()
"message": "CarrotSkin API is running",
checks := make(map[string]string)
status := "ok"
// 检查数据库
if err := checkDatabase(ctx); err != nil {
checks["database"] = "unhealthy: " + err.Error()
status = "degraded"
} else {
checks["database"] = "healthy"
}
// 检查Redis
if err := checkRedis(ctx); err != nil {
checks["redis"] = "unhealthy: " + err.Error()
status = "degraded"
} else {
checks["redis"] = "healthy"
}
// 根据状态返回相应的HTTP状态码
httpStatus := http.StatusOK
if status == "degraded" {
httpStatus = http.StatusServiceUnavailable
}
c.JSON(httpStatus, gin.H{
"status": status,
"message": "CarrotSkin API health check",
"checks": checks,
"timestamp": time.Now().Unix(),
}) })
} }
// checkDatabase 检查数据库连接
func checkDatabase(ctx context.Context) error {
db, err := database.GetDB()
if err != nil {
return err
}
sqlDB, err := db.DB()
if err != nil {
return err
}
// 使用Ping检查连接
if err := sqlDB.PingContext(ctx); err != nil {
return err
}
// 执行简单查询验证
var result int
if err := db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error; err != nil {
return err
}
return nil
}
// checkRedis 检查Redis连接
func checkRedis(ctx context.Context) error {
client, err := redis.GetClient()
if err != nil {
return err
}
if client == nil {
return errors.New("Redis客户端未初始化")
}
// 使用Ping检查连接
if err := client.Ping(ctx).Err(); err != nil {
return err
}
return nil
}

View File

@@ -190,7 +190,7 @@ func (h *YggdrasilHandler) Authenticate(c *gin.Context) {
if emailRegex.MatchString(request.Identifier) { if emailRegex.MatchString(request.Identifier) {
userId, err = h.container.YggdrasilService.GetUserIDByEmail(c.Request.Context(), request.Identifier) userId, err = h.container.YggdrasilService.GetUserIDByEmail(c.Request.Context(), request.Identifier)
} else { } else {
profile, err = h.container.ProfileRepo.FindByName(request.Identifier) profile, err = h.container.ProfileRepo.FindByName(c.Request.Context(), request.Identifier)
if err != nil { if err != nil {
h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err)) h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err))
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()}) c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"carrotskin/internal/model"
"net/http" "net/http"
"strings" "strings"
@@ -9,17 +10,16 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// AuthMiddleware JWT认证中间件 // AuthMiddleware JWT认证中间件注入JWT服务版本
func AuthMiddleware() gin.HandlerFunc { func AuthMiddleware(jwtService *auth.JWTService) gin.HandlerFunc {
return gin.HandlerFunc(func(c *gin.Context) { return gin.HandlerFunc(func(c *gin.Context) {
jwtService := auth.MustGetJWTService()
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
"code": 401, model.CodeUnauthorized,
"message": "缺少Authorization头", "缺少Authorization头",
}) nil,
))
c.Abort() c.Abort()
return return
} }
@@ -27,10 +27,11 @@ func AuthMiddleware() gin.HandlerFunc {
// Bearer token格式 // Bearer token格式
tokenParts := strings.SplitN(authHeader, " ", 2) tokenParts := strings.SplitN(authHeader, " ", 2)
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" { if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
"code": 401, model.CodeUnauthorized,
"message": "无效的Authorization头格式", "无效的Authorization头格式",
}) nil,
))
c.Abort() c.Abort()
return return
} }
@@ -38,10 +39,11 @@ func AuthMiddleware() gin.HandlerFunc {
token := tokenParts[1] token := tokenParts[1]
claims, err := jwtService.ValidateToken(token) claims, err := jwtService.ValidateToken(token)
if err != nil { if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{ c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
"code": 401, model.CodeUnauthorized,
"message": "无效的token", "无效的token",
}) err,
))
c.Abort() c.Abort()
return return
} }
@@ -55,11 +57,9 @@ func AuthMiddleware() gin.HandlerFunc {
}) })
} }
// OptionalAuthMiddleware 可选的JWT认证中间件 // OptionalAuthMiddleware 可选的JWT认证中间件注入JWT服务版本
func OptionalAuthMiddleware() gin.HandlerFunc { func OptionalAuthMiddleware(jwtService *auth.JWTService) gin.HandlerFunc {
return gin.HandlerFunc(func(c *gin.Context) { return gin.HandlerFunc(func(c *gin.Context) {
jwtService := auth.MustGetJWTService()
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader != "" { if authHeader != "" {
tokenParts := strings.SplitN(authHeader, " ", 2) tokenParts := strings.SplitN(authHeader, " ", 2)

View File

@@ -22,3 +22,5 @@ func (Client) TableName() string {
return "clients" return "clients"
} }

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,48 +17,48 @@ func NewClientRepository(db *gorm.DB) ClientRepository {
return &clientRepository{db: db} return &clientRepository{db: db}
} }
func (r *clientRepository) Create(client *model.Client) error { func (r *clientRepository) Create(ctx context.Context, client *model.Client) error {
return r.db.Create(client).Error return r.db.WithContext(ctx).Create(client).Error
} }
func (r *clientRepository) FindByClientToken(clientToken string) (*model.Client, error) { func (r *clientRepository) FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error) {
var client model.Client var client model.Client
err := r.db.Where("client_token = ?", clientToken).First(&client).Error err := r.db.WithContext(ctx).Where("client_token = ?", clientToken).First(&client).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &client, nil return &client, nil
} }
func (r *clientRepository) FindByUUID(uuid string) (*model.Client, error) { func (r *clientRepository) FindByUUID(ctx context.Context, uuid string) (*model.Client, error) {
var client model.Client var client model.Client
err := r.db.Where("uuid = ?", uuid).First(&client).Error err := r.db.WithContext(ctx).Where("uuid = ?", uuid).First(&client).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &client, nil return &client, nil
} }
func (r *clientRepository) FindByUserID(userID int64) ([]*model.Client, error) { func (r *clientRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error) {
var clients []*model.Client var clients []*model.Client
err := r.db.Where("user_id = ?", userID).Find(&clients).Error err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&clients).Error
return clients, err return clients, err
} }
func (r *clientRepository) Update(client *model.Client) error { func (r *clientRepository) Update(ctx context.Context, client *model.Client) error {
return r.db.Save(client).Error return r.db.WithContext(ctx).Save(client).Error
} }
func (r *clientRepository) IncrementVersion(clientUUID string) error { func (r *clientRepository) IncrementVersion(ctx context.Context, clientUUID string) error {
return r.db.Model(&model.Client{}). return r.db.WithContext(ctx).Model(&model.Client{}).
Where("uuid = ?", clientUUID). Where("uuid = ?", clientUUID).
Update("version", gorm.Expr("version + 1")).Error Update("version", gorm.Expr("version + 1")).Error
} }
func (r *clientRepository) DeleteByClientToken(clientToken string) error { func (r *clientRepository) DeleteByClientToken(ctx context.Context, clientToken string) error {
return r.db.Where("client_token = ?", clientToken).Delete(&model.Client{}).Error return r.db.WithContext(ctx).Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
} }
func (r *clientRepository) DeleteByUserID(userID int64) error { func (r *clientRepository) DeleteByUserID(ctx context.Context, userID int64) error {
return r.db.Where("user_id = ?", userID).Delete(&model.Client{}).Error return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&model.Client{}).Error
} }

View File

@@ -2,95 +2,105 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
) )
// UserRepository 用户仓储接口 // UserRepository 用户仓储接口
type UserRepository interface { type UserRepository interface {
Create(user *model.User) error Create(ctx context.Context, user *model.User) error
FindByID(id int64) (*model.User, error) FindByID(ctx context.Context, id int64) (*model.User, error)
FindByUsername(username string) (*model.User, error) FindByUsername(ctx context.Context, username string) (*model.User, error)
FindByEmail(email string) (*model.User, error) FindByEmail(ctx context.Context, email string) (*model.User, error)
Update(user *model.User) error FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) // 批量查询
UpdateFields(id int64, fields map[string]interface{}) error Update(ctx context.Context, user *model.User) error
Delete(id int64) error UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error
CreateLoginLog(log *model.UserLoginLog) error BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新
CreatePointLog(log *model.UserPointLog) error Delete(ctx context.Context, id int64) error
UpdatePoints(userID int64, amount int, changeType, reason string) error BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除
CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error
CreatePointLog(ctx context.Context, log *model.UserPointLog) error
UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error
} }
// ProfileRepository 档案仓储接口 // ProfileRepository 档案仓储接口
type ProfileRepository interface { type ProfileRepository interface {
Create(profile *model.Profile) error Create(ctx context.Context, profile *model.Profile) error
FindByUUID(uuid string) (*model.Profile, error) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error)
FindByName(name string) (*model.Profile, error) FindByName(ctx context.Context, name string) (*model.Profile, error)
FindByUserID(userID int64) ([]*model.Profile, error) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error)
Update(profile *model.Profile) error FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) // 批量查询
UpdateFields(uuid string, updates map[string]interface{}) error Update(ctx context.Context, profile *model.Profile) error
Delete(uuid string) error UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error
CountByUserID(userID int64) (int64, error) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) // 批量更新
SetActive(uuid string, userID int64) error Delete(ctx context.Context, uuid string) error
UpdateLastUsedAt(uuid string) error BatchDelete(ctx context.Context, uuids []string) (int64, error) // 批量删除
GetByNames(names []string) ([]*model.Profile, error) CountByUserID(ctx context.Context, userID int64) (int64, error)
GetKeyPair(profileId string) (*model.KeyPair, error) SetActive(ctx context.Context, uuid string, userID int64) error
UpdateKeyPair(profileId string, keyPair *model.KeyPair) error UpdateLastUsedAt(ctx context.Context, uuid string) error
GetByNames(ctx context.Context, names []string) ([]*model.Profile, error)
GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error)
UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error
} }
// TextureRepository 材质仓储接口 // TextureRepository 材质仓储接口
type TextureRepository interface { type TextureRepository interface {
Create(texture *model.Texture) error Create(ctx context.Context, texture *model.Texture) error
FindByID(id int64) (*model.Texture, error) FindByID(ctx context.Context, id int64) (*model.Texture, error)
FindByHash(hash string) (*model.Texture, error) FindByHash(ctx context.Context, hash string) (*model.Texture, error)
FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) // 批量查询
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
Update(texture *model.Texture) error Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
UpdateFields(id int64, fields map[string]interface{}) error Update(ctx context.Context, texture *model.Texture) error
Delete(id int64) error UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error
IncrementDownloadCount(id int64) error BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新
IncrementFavoriteCount(id int64) error Delete(ctx context.Context, id int64) error
DecrementFavoriteCount(id int64) error BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除
CreateDownloadLog(log *model.TextureDownloadLog) error IncrementDownloadCount(ctx context.Context, id int64) error
IsFavorited(userID, textureID int64) (bool, error) IncrementFavoriteCount(ctx context.Context, id int64) error
AddFavorite(userID, textureID int64) error DecrementFavoriteCount(ctx context.Context, id int64) error
RemoveFavorite(userID, textureID int64) error CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error)
CountByUploaderID(uploaderID int64) (int64, error) AddFavorite(ctx context.Context, userID, textureID int64) error
RemoveFavorite(ctx context.Context, userID, textureID int64) error
GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error)
CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error)
} }
// TokenRepository 令牌仓储接口 // TokenRepository 令牌仓储接口
type TokenRepository interface { type TokenRepository interface {
Create(token *model.Token) error Create(ctx context.Context, token *model.Token) error
FindByAccessToken(accessToken string) (*model.Token, error) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error)
GetByUserID(userId int64) ([]*model.Token, error) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error)
GetUUIDByAccessToken(accessToken string) (string, error) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error)
GetUserIDByAccessToken(accessToken string) (int64, error) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error)
DeleteByAccessToken(accessToken string) error DeleteByAccessToken(ctx context.Context, accessToken string) error
DeleteByUserID(userId int64) error DeleteByUserID(ctx context.Context, userId int64) error
BatchDelete(accessTokens []string) (int64, error) BatchDelete(ctx context.Context, accessTokens []string) (int64, error)
} }
// SystemConfigRepository 系统配置仓储接口 // SystemConfigRepository 系统配置仓储接口
type SystemConfigRepository interface { type SystemConfigRepository interface {
GetByKey(key string) (*model.SystemConfig, error) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error)
GetPublic() ([]model.SystemConfig, error) GetPublic(ctx context.Context) ([]model.SystemConfig, error)
GetAll() ([]model.SystemConfig, error) GetAll(ctx context.Context) ([]model.SystemConfig, error)
Update(config *model.SystemConfig) error Update(ctx context.Context, config *model.SystemConfig) error
UpdateValue(key, value string) error UpdateValue(ctx context.Context, key, value string) error
} }
// YggdrasilRepository Yggdrasil仓储接口 // YggdrasilRepository Yggdrasil仓储接口
type YggdrasilRepository interface { type YggdrasilRepository interface {
GetPasswordByID(id int64) (string, error) GetPasswordByID(ctx context.Context, id int64) (string, error)
ResetPassword(id int64, password string) error ResetPassword(ctx context.Context, id int64, password string) error
} }
// ClientRepository Client仓储接口 // ClientRepository Client仓储接口
type ClientRepository interface { type ClientRepository interface {
Create(client *model.Client) error Create(ctx context.Context, client *model.Client) error
FindByClientToken(clientToken string) (*model.Client, error) FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error)
FindByUUID(uuid string) (*model.Client, error) FindByUUID(ctx context.Context, uuid string) (*model.Client, error)
FindByUserID(userID int64) ([]*model.Client, error) FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error)
Update(client *model.Client) error Update(ctx context.Context, client *model.Client) error
IncrementVersion(clientUUID string) error IncrementVersion(ctx context.Context, clientUUID string) error
DeleteByClientToken(clientToken string) error DeleteByClientToken(ctx context.Context, clientToken string) error
DeleteByUserID(userID int64) error DeleteByUserID(ctx context.Context, userID int64) error
} }

View File

@@ -19,13 +19,13 @@ func NewProfileRepository(db *gorm.DB) ProfileRepository {
return &profileRepository{db: db} return &profileRepository{db: db}
} }
func (r *profileRepository) Create(profile *model.Profile) error { func (r *profileRepository) Create(ctx context.Context, profile *model.Profile) error {
return r.db.Create(profile).Error return r.db.WithContext(ctx).Create(profile).Error
} }
func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) { func (r *profileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) {
var profile model.Profile var profile model.Profile
err := r.db.Where("uuid = ?", uuid). err := r.db.WithContext(ctx).Where("uuid = ?", uuid).
Preload("Skin"). Preload("Skin").
Preload("Cape"). Preload("Cape").
First(&profile).Error First(&profile).Error
@@ -35,10 +35,10 @@ func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) {
return &profile, nil return &profile, nil
} }
func (r *profileRepository) FindByName(name string) (*model.Profile, error) { func (r *profileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) {
var profile model.Profile var profile model.Profile
// 使用 LOWER 函数进行不区分大小写的查询,并预加载 Skin 和 Cape // 使用 LOWER 函数进行不区分大小写的查询,并预加载 Skin 和 Cape
err := r.db.Where("LOWER(name) = LOWER(?)", name). err := r.db.WithContext(ctx).Where("LOWER(name) = LOWER(?)", name).
Preload("Skin"). Preload("Skin").
Preload("Cape"). Preload("Cape").
First(&profile).Error First(&profile).Error
@@ -48,9 +48,9 @@ func (r *profileRepository) FindByName(name string) (*model.Profile, error) {
return &profile, nil return &profile, nil
} }
func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error) { func (r *profileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) {
var profiles []*model.Profile var profiles []*model.Profile
err := r.db.Where("user_id = ?", userID). err := r.db.WithContext(ctx).Where("user_id = ?", userID).
Preload("Skin"). Preload("Skin").
Preload("Cape"). Preload("Cape").
Order("created_at DESC"). Order("created_at DESC").
@@ -58,30 +58,59 @@ func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error)
return profiles, err return profiles, err
} }
func (r *profileRepository) Update(profile *model.Profile) error { func (r *profileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) {
return r.db.Save(profile).Error if len(uuids) == 0 {
return []*model.Profile{}, nil
}
var profiles []*model.Profile
// 使用 IN 查询优化批量查询,并预加载关联
err := r.db.WithContext(ctx).Where("uuid IN ?", uuids).
Preload("Skin").
Preload("Cape").
Find(&profiles).Error
return profiles, err
} }
func (r *profileRepository) UpdateFields(uuid string, updates map[string]interface{}) error { func (r *profileRepository) Update(ctx context.Context, profile *model.Profile) error {
return r.db.Model(&model.Profile{}). return r.db.WithContext(ctx).Save(profile).Error
}
func (r *profileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error {
return r.db.WithContext(ctx).Model(&model.Profile{}).
Where("uuid = ?", uuid). Where("uuid = ?", uuid).
Updates(updates).Error Updates(updates).Error
} }
func (r *profileRepository) Delete(uuid string) error { func (r *profileRepository) Delete(ctx context.Context, uuid string) error {
return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error return r.db.WithContext(ctx).Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
} }
func (r *profileRepository) CountByUserID(userID int64) (int64, error) { func (r *profileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) {
if len(uuids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.Profile{}).Where("uuid IN ?", uuids).Updates(updates)
return result.RowsAffected, result.Error
}
func (r *profileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) {
if len(uuids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Where("uuid IN ?", uuids).Delete(&model.Profile{})
return result.RowsAffected, result.Error
}
func (r *profileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64 var count int64
err := r.db.Model(&model.Profile{}). err := r.db.WithContext(ctx).Model(&model.Profile{}).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Count(&count).Error Count(&count).Error
return count, err return count, err
} }
func (r *profileRepository) SetActive(uuid string, userID int64) error { func (r *profileRepository) SetActive(ctx context.Context, uuid string, userID int64) error {
return r.db.Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&model.Profile{}). if err := tx.Model(&model.Profile{}).
Where("user_id = ?", userID). Where("user_id = ?", userID).
Update("is_active", false).Error; err != nil { Update("is_active", false).Error; err != nil {
@@ -94,28 +123,28 @@ func (r *profileRepository) SetActive(uuid string, userID int64) error {
}) })
} }
func (r *profileRepository) UpdateLastUsedAt(uuid string) error { func (r *profileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error {
return r.db.Model(&model.Profile{}). return r.db.WithContext(ctx).Model(&model.Profile{}).
Where("uuid = ?", uuid). Where("uuid = ?", uuid).
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
} }
func (r *profileRepository) GetByNames(names []string) ([]*model.Profile, error) { func (r *profileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
var profiles []*model.Profile var profiles []*model.Profile
err := r.db.Where("name in (?)", names). err := r.db.WithContext(ctx).Where("name in (?)", names).
Preload("Skin"). Preload("Skin").
Preload("Cape"). Preload("Cape").
Find(&profiles).Error Find(&profiles).Error
return profiles, err return profiles, err
} }
func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) { func (r *profileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) {
if profileId == "" { if profileId == "" {
return nil, errors.New("参数不能为空") return nil, errors.New("参数不能为空")
} }
var profile model.Profile var profile model.Profile
result := r.db.WithContext(context.Background()). result := r.db.WithContext(ctx).
Select("key_pair"). Select("key_pair").
Where("id = ?", profileId). Where("id = ?", profileId).
First(&profile) First(&profile)
@@ -130,7 +159,7 @@ func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error)
return &model.KeyPair{}, nil return &model.KeyPair{}, nil
} }
func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { func (r *profileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error {
if profileId == "" { if profileId == "" {
return errors.New("profileId 不能为空") return errors.New("profileId 不能为空")
} }
@@ -138,9 +167,8 @@ func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPa
return errors.New("keyPair 不能为 nil") return errors.New("keyPair 不能为 nil")
} }
return r.db.Transaction(func(tx *gorm.DB) error { return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
result := tx.WithContext(context.Background()). result := tx.Table("profiles").
Table("profiles").
Where("id = ?", profileId). Where("id = ?", profileId).
UpdateColumns(map[string]interface{}{ UpdateColumns(map[string]interface{}{
"private_key": keyPair.PrivateKey, "private_key": keyPair.PrivateKey,

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,28 +17,28 @@ func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository {
return &systemConfigRepository{db: db} return &systemConfigRepository{db: db}
} }
func (r *systemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) { func (r *systemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
var config model.SystemConfig var config model.SystemConfig
err := r.db.Where("key = ?", key).First(&config).Error err := r.db.WithContext(ctx).Where("key = ?", key).First(&config).Error
return handleNotFoundResult(&config, err) return handleNotFoundResult(&config, err)
} }
func (r *systemConfigRepository) GetPublic() ([]model.SystemConfig, error) { func (r *systemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) {
var configs []model.SystemConfig var configs []model.SystemConfig
err := r.db.Where("is_public = ?", true).Find(&configs).Error err := r.db.WithContext(ctx).Where("is_public = ?", true).Find(&configs).Error
return configs, err return configs, err
} }
func (r *systemConfigRepository) GetAll() ([]model.SystemConfig, error) { func (r *systemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) {
var configs []model.SystemConfig var configs []model.SystemConfig
err := r.db.Find(&configs).Error err := r.db.WithContext(ctx).Find(&configs).Error
return configs, err return configs, err
} }
func (r *systemConfigRepository) Update(config *model.SystemConfig) error { func (r *systemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error {
return r.db.Save(config).Error return r.db.WithContext(ctx).Save(config).Error
} }
func (r *systemConfigRepository) UpdateValue(key, value string) error { func (r *systemConfigRepository) UpdateValue(ctx context.Context, key, value string) error {
return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error return r.db.WithContext(ctx).Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
} }

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,27 +17,39 @@ func NewTextureRepository(db *gorm.DB) TextureRepository {
return &textureRepository{db: db} return &textureRepository{db: db}
} }
func (r *textureRepository) Create(texture *model.Texture) error { func (r *textureRepository) Create(ctx context.Context, texture *model.Texture) error {
return r.db.Create(texture).Error return r.db.WithContext(ctx).Create(texture).Error
} }
func (r *textureRepository) FindByID(id int64) (*model.Texture, error) { func (r *textureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) {
var texture model.Texture var texture model.Texture
err := r.db.Preload("Uploader").First(&texture, id).Error err := r.db.WithContext(ctx).Preload("Uploader").First(&texture, id).Error
return handleNotFoundResult(&texture, err) return handleNotFoundResult(&texture, err)
} }
func (r *textureRepository) FindByHash(hash string) (*model.Texture, error) { func (r *textureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) {
var texture model.Texture var texture model.Texture
err := r.db.Where("hash = ?", hash).First(&texture).Error err := r.db.WithContext(ctx).Where("hash = ?", hash).First(&texture).Error
return handleNotFoundResult(&texture, err) return handleNotFoundResult(&texture, err)
} }
func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { func (r *textureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) {
if len(ids) == 0 {
return []*model.Texture{}, nil
}
var textures []*model.Texture
// 使用 IN 查询优化批量查询,并预加载关联
err := r.db.WithContext(ctx).Where("id IN ?", ids).
Preload("Uploader").
Find(&textures).Error
return textures, err
}
func (r *textureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture var textures []*model.Texture
var total int64 var total int64
query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID) query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
return nil, 0, err return nil, 0, err
@@ -54,11 +67,11 @@ func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize in
return textures, total, nil return textures, total, nil
} }
func (r *textureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { func (r *textureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture var textures []*model.Texture
var total int64 var total int64
query := r.db.Model(&model.Texture{}).Where("status = 1") query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("status = 1")
if publicOnly { if publicOnly {
query = query.Where("is_public = ?", true) query = query.Where("is_public = ?", true)
@@ -86,67 +99,86 @@ func (r *textureRepository) Search(keyword string, textureType model.TextureType
return textures, total, nil return textures, total, nil
} }
func (r *textureRepository) Update(texture *model.Texture) error { func (r *textureRepository) Update(ctx context.Context, texture *model.Texture) error {
return r.db.Save(texture).Error return r.db.WithContext(ctx).Save(texture).Error
} }
func (r *textureRepository) UpdateFields(id int64, fields map[string]interface{}) error { func (r *textureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
} }
func (r *textureRepository) Delete(id int64) error { func (r *textureRepository) Delete(ctx context.Context, id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
} }
func (r *textureRepository) IncrementDownloadCount(id int64) error { func (r *textureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
return r.db.Model(&model.Texture{}).Where("id = ?", id). if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Updates(fields)
return result.RowsAffected, result.Error
}
func (r *textureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Update("status", -1)
return result.RowsAffected, result.Error
}
func (r *textureRepository) IncrementDownloadCount(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
} }
func (r *textureRepository) IncrementFavoriteCount(id int64) error { func (r *textureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
} }
func (r *textureRepository) DecrementFavoriteCount(id int64) error { func (r *textureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id). return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
} }
func (r *textureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error { func (r *textureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error {
return r.db.Create(log).Error return r.db.WithContext(ctx).Create(log).Error
} }
func (r *textureRepository) IsFavorited(userID, textureID int64) (bool, error) { func (r *textureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) {
var count int64 var count int64
err := r.db.Model(&model.UserTextureFavorite{}). // 使用 Select("1") 优化,只查询是否存在,不需要查询所有字段
err := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}).
Select("1").
Where("user_id = ? AND texture_id = ?", userID, textureID). Where("user_id = ? AND texture_id = ?", userID, textureID).
Limit(1).
Count(&count).Error Count(&count).Error
return count > 0, err return count > 0, err
} }
func (r *textureRepository) AddFavorite(userID, textureID int64) error { func (r *textureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error {
favorite := &model.UserTextureFavorite{ favorite := &model.UserTextureFavorite{
UserID: userID, UserID: userID,
TextureID: textureID, TextureID: textureID,
} }
return r.db.Create(favorite).Error return r.db.WithContext(ctx).Create(favorite).Error
} }
func (r *textureRepository) RemoveFavorite(userID, textureID int64) error { func (r *textureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error {
return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID). return r.db.WithContext(ctx).Where("user_id = ? AND texture_id = ?", userID, textureID).
Delete(&model.UserTextureFavorite{}).Error Delete(&model.UserTextureFavorite{}).Error
} }
func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { func (r *textureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture var textures []*model.Texture
var total int64 var total int64
subQuery := r.db.Model(&model.UserTextureFavorite{}). subQuery := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}).
Select("texture_id"). Select("texture_id").
Where("user_id = ?", userID) Where("user_id = ?", userID)
query := r.db.Model(&model.Texture{}). query := r.db.WithContext(ctx).Model(&model.Texture{}).
Where("id IN (?) AND status = 1", subQuery) Where("id IN (?) AND status = 1", subQuery)
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
@@ -165,9 +197,9 @@ func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) (
return textures, total, nil return textures, total, nil
} }
func (r *textureRepository) CountByUploaderID(uploaderID int64) (int64, error) { func (r *textureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) {
var count int64 var count int64
err := r.db.Model(&model.Texture{}). err := r.db.WithContext(ctx).Model(&model.Texture{}).
Where("uploader_id = ? AND status != -1", uploaderID). Where("uploader_id = ? AND status != -1", uploaderID).
Count(&count).Error Count(&count).Error
return count, err return count, err

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,55 +17,55 @@ func NewTokenRepository(db *gorm.DB) TokenRepository {
return &tokenRepository{db: db} return &tokenRepository{db: db}
} }
func (r *tokenRepository) Create(token *model.Token) error { func (r *tokenRepository) Create(ctx context.Context, token *model.Token) error {
return r.db.Create(token).Error return r.db.WithContext(ctx).Create(token).Error
} }
func (r *tokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) { func (r *tokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
var token model.Token var token model.Token
err := r.db.Where("access_token = ?", accessToken).First(&token).Error err := r.db.WithContext(ctx).Where("access_token = ?", accessToken).First(&token).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &token, nil return &token, nil
} }
func (r *tokenRepository) GetByUserID(userId int64) ([]*model.Token, error) { func (r *tokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
var tokens []*model.Token var tokens []*model.Token
err := r.db.Where("user_id = ?", userId).Find(&tokens).Error err := r.db.WithContext(ctx).Where("user_id = ?", userId).Find(&tokens).Error
return tokens, err return tokens, err
} }
func (r *tokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) { func (r *tokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
var token model.Token var token model.Token
err := r.db.Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error err := r.db.WithContext(ctx).Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
if err != nil { if err != nil {
return "", err return "", err
} }
return token.ProfileId, nil return token.ProfileId, nil
} }
func (r *tokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) { func (r *tokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
var token model.Token var token model.Token
err := r.db.Select("user_id").Where("access_token = ?", accessToken).First(&token).Error err := r.db.WithContext(ctx).Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
if err != nil { if err != nil {
return 0, err return 0, err
} }
return token.UserID, nil return token.UserID, nil
} }
func (r *tokenRepository) DeleteByAccessToken(accessToken string) error { func (r *tokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error return r.db.WithContext(ctx).Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
} }
func (r *tokenRepository) DeleteByUserID(userId int64) error { func (r *tokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error return r.db.WithContext(ctx).Where("user_id = ?", userId).Delete(&model.Token{}).Error
} }
func (r *tokenRepository) BatchDelete(accessTokens []string) (int64, error) { func (r *tokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
if len(accessTokens) == 0 { if len(accessTokens) == 0 {
return 0, nil return 0, nil
} }
result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{}) result := r.db.WithContext(ctx).Where("access_token IN ?", accessTokens).Delete(&model.Token{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
"errors" "errors"
"gorm.io/gorm" "gorm.io/gorm"
@@ -17,50 +18,76 @@ func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepository{db: db} return &userRepository{db: db}
} }
func (r *userRepository) Create(user *model.User) error { func (r *userRepository) Create(ctx context.Context, user *model.User) error {
return r.db.Create(user).Error return r.db.WithContext(ctx).Create(user).Error
} }
func (r *userRepository) FindByID(id int64) (*model.User, error) { func (r *userRepository) FindByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User var user model.User
err := r.db.Where("id = ? AND status != -1", id).First(&user).Error err := r.db.WithContext(ctx).Where("id = ? AND status != -1", id).First(&user).Error
return handleNotFoundResult(&user, err) return handleNotFoundResult(&user, err)
} }
func (r *userRepository) FindByUsername(username string) (*model.User, error) { func (r *userRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
var user model.User var user model.User
err := r.db.Where("username = ? AND status != -1", username).First(&user).Error err := r.db.WithContext(ctx).Where("username = ? AND status != -1", username).First(&user).Error
return handleNotFoundResult(&user, err) return handleNotFoundResult(&user, err)
} }
func (r *userRepository) FindByEmail(email string) (*model.User, error) { func (r *userRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User var user model.User
err := r.db.Where("email = ? AND status != -1", email).First(&user).Error err := r.db.WithContext(ctx).Where("email = ? AND status != -1", email).First(&user).Error
return handleNotFoundResult(&user, err) return handleNotFoundResult(&user, err)
} }
func (r *userRepository) Update(user *model.User) error { func (r *userRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) {
return r.db.Save(user).Error if len(ids) == 0 {
return []*model.User{}, nil
}
var users []*model.User
// 使用 IN 查询优化批量查询
err := r.db.WithContext(ctx).Where("id IN ? AND status != -1", ids).Find(&users).Error
return users, err
} }
func (r *userRepository) UpdateFields(id int64, fields map[string]interface{}) error { func (r *userRepository) Update(ctx context.Context, user *model.User) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error return r.db.WithContext(ctx).Save(user).Error
} }
func (r *userRepository) Delete(id int64) error { func (r *userRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
} }
func (r *userRepository) CreateLoginLog(log *model.UserLoginLog) error { func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.Create(log).Error return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
} }
func (r *userRepository) CreatePointLog(log *model.UserPointLog) error { func (r *userRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
return r.db.Create(log).Error if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Updates(fields)
return result.RowsAffected, result.Error
} }
func (r *userRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error { func (r *userRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
return r.db.Transaction(func(tx *gorm.DB) error { if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Update("status", -1)
return result.RowsAffected, result.Error
}
func (r *userRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
func (r *userRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
func (r *userRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var user model.User var user model.User
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil { if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
return err return err

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"context"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,15 +17,15 @@ func NewYggdrasilRepository(db *gorm.DB) YggdrasilRepository {
return &yggdrasilRepository{db: db} return &yggdrasilRepository{db: db}
} }
func (r *yggdrasilRepository) GetPasswordByID(id int64) (string, error) { func (r *yggdrasilRepository) GetPasswordByID(ctx context.Context, id int64) (string, error) {
var yggdrasil model.Yggdrasil var yggdrasil model.Yggdrasil
err := r.db.Select("password").Where("id = ?", id).First(&yggdrasil).Error err := r.db.WithContext(ctx).Select("password").Where("id = ?", id).First(&yggdrasil).Error
if err != nil { if err != nil {
return "", err return "", err
} }
return yggdrasil.Password, nil return yggdrasil.Password, nil
} }
func (r *yggdrasilRepository) ResetPassword(id int64, password string) error { func (r *yggdrasilRepository) ResetPassword(ctx context.Context, id int64, password string) error {
return r.db.Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error return r.db.WithContext(ctx).Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error
} }

View File

@@ -3,6 +3,7 @@ package service
import ( import (
"carrotskin/internal/model" "carrotskin/internal/model"
"carrotskin/pkg/database" "carrotskin/pkg/database"
"context"
"errors" "errors"
"time" "time"
) )
@@ -28,7 +29,7 @@ func NewMockUserRepository() *MockUserRepository {
} }
} }
func (m *MockUserRepository) Create(user *model.User) error { func (m *MockUserRepository) Create(ctx context.Context, user *model.User) error {
if m.FailCreate { if m.FailCreate {
return errors.New("mock create error") return errors.New("mock create error")
} }
@@ -39,7 +40,7 @@ func (m *MockUserRepository) Create(user *model.User) error {
return nil return nil
} }
func (m *MockUserRepository) FindByID(id int64) (*model.User, error) { func (m *MockUserRepository) FindByID(ctx context.Context, id int64) (*model.User, error) {
if m.FailFindByID { if m.FailFindByID {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
@@ -49,7 +50,7 @@ func (m *MockUserRepository) FindByID(id int64) (*model.User, error) {
return nil, nil return nil, nil
} }
func (m *MockUserRepository) FindByUsername(username string) (*model.User, error) { func (m *MockUserRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
if m.FailFindByUsername { if m.FailFindByUsername {
return nil, errors.New("mock find by username error") return nil, errors.New("mock find by username error")
} }
@@ -61,7 +62,7 @@ func (m *MockUserRepository) FindByUsername(username string) (*model.User, error
return nil, nil return nil, nil
} }
func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) { func (m *MockUserRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
if m.FailFindByEmail { if m.FailFindByEmail {
return nil, errors.New("mock find by email error") return nil, errors.New("mock find by email error")
} }
@@ -73,7 +74,7 @@ func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) {
return nil, nil return nil, nil
} }
func (m *MockUserRepository) Update(user *model.User) error { func (m *MockUserRepository) Update(ctx context.Context, user *model.User) error {
if m.FailUpdate { if m.FailUpdate {
return errors.New("mock update error") return errors.New("mock update error")
} }
@@ -81,7 +82,7 @@ func (m *MockUserRepository) Update(user *model.User) error {
return nil return nil
} }
func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{}) error { func (m *MockUserRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
if m.FailUpdate { if m.FailUpdate {
return errors.New("mock update fields error") return errors.New("mock update fields error")
} }
@@ -92,23 +93,43 @@ func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{
return nil return nil
} }
func (m *MockUserRepository) Delete(id int64) error { func (m *MockUserRepository) Delete(ctx context.Context, id int64) error {
delete(m.users, id) delete(m.users, id)
return nil return nil
} }
func (m *MockUserRepository) CreateLoginLog(log *model.UserLoginLog) error { func (m *MockUserRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error {
return nil return nil
} }
func (m *MockUserRepository) CreatePointLog(log *model.UserPointLog) error { func (m *MockUserRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error {
return nil return nil
} }
func (m *MockUserRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error { func (m *MockUserRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error {
return nil return nil
} }
// BatchUpdate 和 BatchDelete 仅用于满足接口,在测试中不做具体操作
func (m *MockUserRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
return 0, nil
}
func (m *MockUserRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
return 0, nil
}
// FindByIDs 批量查询用户
func (m *MockUserRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) {
var result []*model.User
for _, id := range ids {
if u, ok := m.users[id]; ok {
result = append(result, u)
}
}
return result, nil
}
// MockProfileRepository 模拟ProfileRepository // MockProfileRepository 模拟ProfileRepository
type MockProfileRepository struct { type MockProfileRepository struct {
profiles map[string]*model.Profile profiles map[string]*model.Profile
@@ -128,7 +149,7 @@ func NewMockProfileRepository() *MockProfileRepository {
} }
} }
func (m *MockProfileRepository) Create(profile *model.Profile) error { func (m *MockProfileRepository) Create(ctx context.Context, profile *model.Profile) error {
if m.FailCreate { if m.FailCreate {
return errors.New("mock create error") return errors.New("mock create error")
} }
@@ -137,7 +158,7 @@ func (m *MockProfileRepository) Create(profile *model.Profile) error {
return nil return nil
} }
func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error) { func (m *MockProfileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
@@ -147,7 +168,7 @@ func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error)
return nil, errors.New("profile not found") return nil, errors.New("profile not found")
} }
func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error) { func (m *MockProfileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
@@ -159,14 +180,14 @@ func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error)
return nil, nil return nil, nil
} }
func (m *MockProfileRepository) FindByUserID(userID int64) ([]*model.Profile, error) { func (m *MockProfileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
return m.userProfiles[userID], nil return m.userProfiles[userID], nil
} }
func (m *MockProfileRepository) Update(profile *model.Profile) error { func (m *MockProfileRepository) Update(ctx context.Context, profile *model.Profile) error {
if m.FailUpdate { if m.FailUpdate {
return errors.New("mock update error") return errors.New("mock update error")
} }
@@ -174,14 +195,14 @@ func (m *MockProfileRepository) Update(profile *model.Profile) error {
return nil return nil
} }
func (m *MockProfileRepository) UpdateFields(uuid string, updates map[string]interface{}) error { func (m *MockProfileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error {
if m.FailUpdate { if m.FailUpdate {
return errors.New("mock update error") return errors.New("mock update error")
} }
return nil return nil
} }
func (m *MockProfileRepository) Delete(uuid string) error { func (m *MockProfileRepository) Delete(ctx context.Context, uuid string) error {
if m.FailDelete { if m.FailDelete {
return errors.New("mock delete error") return errors.New("mock delete error")
} }
@@ -189,19 +210,19 @@ func (m *MockProfileRepository) Delete(uuid string) error {
return nil return nil
} }
func (m *MockProfileRepository) CountByUserID(userID int64) (int64, error) { func (m *MockProfileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
return int64(len(m.userProfiles[userID])), nil return int64(len(m.userProfiles[userID])), nil
} }
func (m *MockProfileRepository) SetActive(uuid string, userID int64) error { func (m *MockProfileRepository) SetActive(ctx context.Context, uuid string, userID int64) error {
return nil return nil
} }
func (m *MockProfileRepository) UpdateLastUsedAt(uuid string) error { func (m *MockProfileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error {
return nil return nil
} }
func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, error) { func (m *MockProfileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
var result []*model.Profile var result []*model.Profile
for _, name := range names { for _, name := range names {
for _, profile := range m.profiles { for _, profile := range m.profiles {
@@ -213,14 +234,34 @@ func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, er
return result, nil return result, nil
} }
func (m *MockProfileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) { func (m *MockProfileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) {
return nil, nil return nil, nil
} }
func (m *MockProfileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error { func (m *MockProfileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error {
return nil return nil
} }
// BatchUpdate / BatchDelete 仅用于满足接口
func (m *MockProfileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) {
return 0, nil
}
func (m *MockProfileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) {
return 0, nil
}
// FindByUUIDs 批量查询 Profile
func (m *MockProfileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) {
var result []*model.Profile
for _, id := range uuids {
if p, ok := m.profiles[id]; ok {
result = append(result, p)
}
}
return result, nil
}
// MockTextureRepository 模拟TextureRepository // MockTextureRepository 模拟TextureRepository
type MockTextureRepository struct { type MockTextureRepository struct {
textures map[int64]*model.Texture textures map[int64]*model.Texture
@@ -240,7 +281,7 @@ func NewMockTextureRepository() *MockTextureRepository {
} }
} }
func (m *MockTextureRepository) Create(texture *model.Texture) error { func (m *MockTextureRepository) Create(ctx context.Context, texture *model.Texture) error {
if m.FailCreate { if m.FailCreate {
return errors.New("mock create error") return errors.New("mock create error")
} }
@@ -252,7 +293,7 @@ func (m *MockTextureRepository) Create(texture *model.Texture) error {
return nil return nil
} }
func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) { func (m *MockTextureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
@@ -262,7 +303,7 @@ func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) {
return nil, errors.New("texture not found") return nil, errors.New("texture not found")
} }
func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error) { func (m *MockTextureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
@@ -274,7 +315,7 @@ func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error)
return nil, nil return nil, nil
} }
func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) { func (m *MockTextureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
if m.FailFind { if m.FailFind {
return nil, 0, errors.New("mock find error") return nil, 0, errors.New("mock find error")
} }
@@ -287,7 +328,7 @@ func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSiz
return result, int64(len(result)), nil return result, int64(len(result)), nil
} }
func (m *MockTextureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { func (m *MockTextureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
if m.FailFind { if m.FailFind {
return nil, 0, errors.New("mock find error") return nil, 0, errors.New("mock find error")
} }
@@ -301,7 +342,7 @@ func (m *MockTextureRepository) Search(keyword string, textureType model.Texture
return result, int64(len(result)), nil return result, int64(len(result)), nil
} }
func (m *MockTextureRepository) Update(texture *model.Texture) error { func (m *MockTextureRepository) Update(ctx context.Context, texture *model.Texture) error {
if m.FailUpdate { if m.FailUpdate {
return errors.New("mock update error") return errors.New("mock update error")
} }
@@ -309,14 +350,14 @@ func (m *MockTextureRepository) Update(texture *model.Texture) error {
return nil return nil
} }
func (m *MockTextureRepository) UpdateFields(id int64, fields map[string]interface{}) error { func (m *MockTextureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
if m.FailUpdate { if m.FailUpdate {
return errors.New("mock update error") return errors.New("mock update error")
} }
return nil return nil
} }
func (m *MockTextureRepository) Delete(id int64) error { func (m *MockTextureRepository) Delete(ctx context.Context, id int64) error {
if m.FailDelete { if m.FailDelete {
return errors.New("mock delete error") return errors.New("mock delete error")
} }
@@ -324,39 +365,39 @@ func (m *MockTextureRepository) Delete(id int64) error {
return nil return nil
} }
func (m *MockTextureRepository) IncrementDownloadCount(id int64) error { func (m *MockTextureRepository) IncrementDownloadCount(ctx context.Context, id int64) error {
if texture, ok := m.textures[id]; ok { if texture, ok := m.textures[id]; ok {
texture.DownloadCount++ texture.DownloadCount++
} }
return nil return nil
} }
func (m *MockTextureRepository) IncrementFavoriteCount(id int64) error { func (m *MockTextureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error {
if texture, ok := m.textures[id]; ok { if texture, ok := m.textures[id]; ok {
texture.FavoriteCount++ texture.FavoriteCount++
} }
return nil return nil
} }
func (m *MockTextureRepository) DecrementFavoriteCount(id int64) error { func (m *MockTextureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error {
if texture, ok := m.textures[id]; ok && texture.FavoriteCount > 0 { if texture, ok := m.textures[id]; ok && texture.FavoriteCount > 0 {
texture.FavoriteCount-- texture.FavoriteCount--
} }
return nil return nil
} }
func (m *MockTextureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error { func (m *MockTextureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error {
return nil return nil
} }
func (m *MockTextureRepository) IsFavorited(userID, textureID int64) (bool, error) { func (m *MockTextureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) {
if userFavs, ok := m.favorites[userID]; ok { if userFavs, ok := m.favorites[userID]; ok {
return userFavs[textureID], nil return userFavs[textureID], nil
} }
return false, nil return false, nil
} }
func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error { func (m *MockTextureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error {
if m.favorites[userID] == nil { if m.favorites[userID] == nil {
m.favorites[userID] = make(map[int64]bool) m.favorites[userID] = make(map[int64]bool)
} }
@@ -364,14 +405,14 @@ func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error {
return nil return nil
} }
func (m *MockTextureRepository) RemoveFavorite(userID, textureID int64) error { func (m *MockTextureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error {
if userFavs, ok := m.favorites[userID]; ok { if userFavs, ok := m.favorites[userID]; ok {
delete(userFavs, textureID) delete(userFavs, textureID)
} }
return nil return nil
} }
func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) { func (m *MockTextureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var result []*model.Texture var result []*model.Texture
if userFavs, ok := m.favorites[userID]; ok { if userFavs, ok := m.favorites[userID]; ok {
for textureID := range userFavs { for textureID := range userFavs {
@@ -383,7 +424,7 @@ func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize in
return result, int64(len(result)), nil return result, int64(len(result)), nil
} }
func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, error) { func (m *MockTextureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) {
var count int64 var count int64
for _, texture := range m.textures { for _, texture := range m.textures {
if texture.UploaderID == uploaderID { if texture.UploaderID == uploaderID {
@@ -393,6 +434,34 @@ func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, erro
return count, nil return count, nil
} }
// FindByIDs 批量查询 Texture
func (m *MockTextureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) {
var result []*model.Texture
for _, id := range ids {
if tex, ok := m.textures[id]; ok {
result = append(result, tex)
}
}
return result, nil
}
// BatchUpdate 仅用于满足接口
func (m *MockTextureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
return 0, nil
}
// BatchDelete 仅用于满足接口
func (m *MockTextureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
var deleted int64
for _, id := range ids {
if _, ok := m.textures[id]; ok {
delete(m.textures, id)
deleted++
}
}
return deleted, nil
}
// MockTokenRepository 模拟TokenRepository // MockTokenRepository 模拟TokenRepository
type MockTokenRepository struct { type MockTokenRepository struct {
tokens map[string]*model.Token tokens map[string]*model.Token
@@ -409,7 +478,7 @@ func NewMockTokenRepository() *MockTokenRepository {
} }
} }
func (m *MockTokenRepository) Create(token *model.Token) error { func (m *MockTokenRepository) Create(ctx context.Context, token *model.Token) error {
if m.FailCreate { if m.FailCreate {
return errors.New("mock create error") return errors.New("mock create error")
} }
@@ -418,7 +487,7 @@ func (m *MockTokenRepository) Create(token *model.Token) error {
return nil return nil
} }
func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) { func (m *MockTokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
@@ -428,14 +497,14 @@ func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Toke
return nil, errors.New("token not found") return nil, errors.New("token not found")
} }
func (m *MockTokenRepository) GetByUserID(userId int64) ([]*model.Token, error) { func (m *MockTokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
if m.FailFind { if m.FailFind {
return nil, errors.New("mock find error") return nil, errors.New("mock find error")
} }
return m.userTokens[userId], nil return m.userTokens[userId], nil
} }
func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) { func (m *MockTokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
if m.FailFind { if m.FailFind {
return "", errors.New("mock find error") return "", errors.New("mock find error")
} }
@@ -445,7 +514,7 @@ func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string,
return "", errors.New("token not found") return "", errors.New("token not found")
} }
func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) { func (m *MockTokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
if m.FailFind { if m.FailFind {
return 0, errors.New("mock find error") return 0, errors.New("mock find error")
} }
@@ -455,7 +524,7 @@ func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64,
return 0, errors.New("token not found") return 0, errors.New("token not found")
} }
func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error { func (m *MockTokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
if m.FailDelete { if m.FailDelete {
return errors.New("mock delete error") return errors.New("mock delete error")
} }
@@ -463,7 +532,7 @@ func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error {
return nil return nil
} }
func (m *MockTokenRepository) DeleteByUserID(userId int64) error { func (m *MockTokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
if m.FailDelete { if m.FailDelete {
return errors.New("mock delete error") return errors.New("mock delete error")
} }
@@ -474,7 +543,7 @@ func (m *MockTokenRepository) DeleteByUserID(userId int64) error {
return nil return nil
} }
func (m *MockTokenRepository) BatchDelete(accessTokens []string) (int64, error) { func (m *MockTokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
if m.FailDelete { if m.FailDelete {
return 0, errors.New("mock delete error") return 0, errors.New("mock delete error")
} }
@@ -499,14 +568,14 @@ func NewMockSystemConfigRepository() *MockSystemConfigRepository {
} }
} }
func (m *MockSystemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) { func (m *MockSystemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
if config, ok := m.configs[key]; ok { if config, ok := m.configs[key]; ok {
return config, nil return config, nil
} }
return nil, nil return nil, nil
} }
func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) { func (m *MockSystemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) {
var result []model.SystemConfig var result []model.SystemConfig
for _, v := range m.configs { for _, v := range m.configs {
result = append(result, *v) result = append(result, *v)
@@ -514,7 +583,7 @@ func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) {
return result, nil return result, nil
} }
func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) { func (m *MockSystemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) {
var result []model.SystemConfig var result []model.SystemConfig
for _, v := range m.configs { for _, v := range m.configs {
result = append(result, *v) result = append(result, *v)
@@ -522,12 +591,12 @@ func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) {
return result, nil return result, nil
} }
func (m *MockSystemConfigRepository) Update(config *model.SystemConfig) error { func (m *MockSystemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error {
m.configs[config.Key] = config m.configs[config.Key] = config
return nil return nil
} }
func (m *MockSystemConfigRepository) UpdateValue(key, value string) error { func (m *MockSystemConfigRepository) UpdateValue(ctx context.Context, key, value string) error {
if config, ok := m.configs[key]; ok { if config, ok := m.configs[key]; ok {
config.Value = value config.Value = value
return nil return nil

View File

@@ -47,7 +47,7 @@ func NewProfileService(
func (s *profileService) Create(ctx context.Context, userID int64, name string) (*model.Profile, error) { func (s *profileService) Create(ctx context.Context, userID int64, name string) (*model.Profile, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.FindByID(userID) user, err := s.userRepo.FindByID(ctx, userID)
if err != nil || user == nil { if err != nil || user == nil {
return nil, errors.New("用户不存在") return nil, errors.New("用户不存在")
} }
@@ -56,7 +56,7 @@ func (s *profileService) Create(ctx context.Context, userID int64, name string)
} }
// 检查角色名是否已存在 // 检查角色名是否已存在
existingName, err := s.profileRepo.FindByName(name) existingName, err := s.profileRepo.FindByName(ctx, name)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("查询角色名失败: %w", err) return nil, fmt.Errorf("查询角色名失败: %w", err)
} }
@@ -80,12 +80,12 @@ func (s *profileService) Create(ctx context.Context, userID int64, name string)
IsActive: true, IsActive: true,
} }
if err := s.profileRepo.Create(profile); err != nil { if err := s.profileRepo.Create(ctx, profile); err != nil {
return nil, fmt.Errorf("创建档案失败: %w", err) return nil, fmt.Errorf("创建档案失败: %w", err)
} }
// 设置活跃状态 // 设置活跃状态
if err := s.profileRepo.SetActive(profileUUID, userID); err != nil { if err := s.profileRepo.SetActive(ctx, profileUUID, userID); err != nil {
return nil, fmt.Errorf("设置活跃状态失败: %w", err) return nil, fmt.Errorf("设置活跃状态失败: %w", err)
} }
@@ -104,7 +104,7 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
} }
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
profile2, err := s.profileRepo.FindByUUID(uuid) profile2, err := s.profileRepo.FindByUUID(ctx, uuid)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProfileNotFound return nil, ErrProfileNotFound
@@ -131,7 +131,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
} }
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
profiles, err := s.profileRepo.FindByUserID(userID) profiles, err := s.profileRepo.FindByUserID(ctx, userID)
if err != nil { if err != nil {
return nil, fmt.Errorf("查询档案列表失败: %w", err) return nil, fmt.Errorf("查询档案列表失败: %w", err)
} }
@@ -148,7 +148,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
func (s *profileService) Update(ctx context.Context, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) { func (s *profileService) Update(ctx context.Context, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
// 获取档案并验证权限 // 获取档案并验证权限
profile, err := s.profileRepo.FindByUUID(uuid) profile, err := s.profileRepo.FindByUUID(ctx, uuid)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrProfileNotFound return nil, ErrProfileNotFound
@@ -162,7 +162,7 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64,
// 检查角色名是否重复 // 检查角色名是否重复
if name != nil && *name != profile.Name { if name != nil && *name != profile.Name {
existingName, err := s.profileRepo.FindByName(*name) existingName, err := s.profileRepo.FindByName(ctx, *name)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("查询角色名失败: %w", err) return nil, fmt.Errorf("查询角色名失败: %w", err)
} }
@@ -180,7 +180,7 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64,
profile.CapeID = capeID profile.CapeID = capeID
} }
if err := s.profileRepo.Update(profile); err != nil { if err := s.profileRepo.Update(ctx, profile); err != nil {
return nil, fmt.Errorf("更新档案失败: %w", err) return nil, fmt.Errorf("更新档案失败: %w", err)
} }
@@ -190,12 +190,12 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64,
s.cacheKeys.ProfileList(userID), s.cacheKeys.ProfileList(userID),
) )
return s.profileRepo.FindByUUID(uuid) return s.profileRepo.FindByUUID(ctx, uuid)
} }
func (s *profileService) Delete(ctx context.Context, uuid string, userID int64) error { func (s *profileService) Delete(ctx context.Context, uuid string, userID int64) error {
// 获取档案并验证权限 // 获取档案并验证权限
profile, err := s.profileRepo.FindByUUID(uuid) profile, err := s.profileRepo.FindByUUID(ctx, uuid)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProfileNotFound return ErrProfileNotFound
@@ -207,7 +207,7 @@ func (s *profileService) Delete(ctx context.Context, uuid string, userID int64)
return ErrProfileNoPermission return ErrProfileNoPermission
} }
if err := s.profileRepo.Delete(uuid); err != nil { if err := s.profileRepo.Delete(ctx, uuid); err != nil {
return fmt.Errorf("删除档案失败: %w", err) return fmt.Errorf("删除档案失败: %w", err)
} }
@@ -222,7 +222,7 @@ func (s *profileService) Delete(ctx context.Context, uuid string, userID int64)
func (s *profileService) SetActive(ctx context.Context, uuid string, userID int64) error { func (s *profileService) SetActive(ctx context.Context, uuid string, userID int64) error {
// 获取档案并验证权限 // 获取档案并验证权限
profile, err := s.profileRepo.FindByUUID(uuid) profile, err := s.profileRepo.FindByUUID(ctx, uuid)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return ErrProfileNotFound return ErrProfileNotFound
@@ -234,11 +234,11 @@ func (s *profileService) SetActive(ctx context.Context, uuid string, userID int6
return ErrProfileNoPermission return ErrProfileNoPermission
} }
if err := s.profileRepo.SetActive(uuid, userID); err != nil { if err := s.profileRepo.SetActive(ctx, uuid, userID); err != nil {
return fmt.Errorf("设置活跃状态失败: %w", err) return fmt.Errorf("设置活跃状态失败: %w", err)
} }
if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil { if err := s.profileRepo.UpdateLastUsedAt(ctx, uuid); err != nil {
return fmt.Errorf("更新使用时间失败: %w", err) return fmt.Errorf("更新使用时间失败: %w", err)
} }
@@ -249,7 +249,7 @@ func (s *profileService) SetActive(ctx context.Context, uuid string, userID int6
} }
func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfiles int) error { func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfiles int) error {
count, err := s.profileRepo.CountByUserID(userID) count, err := s.profileRepo.CountByUserID(ctx, userID)
if err != nil { if err != nil {
return fmt.Errorf("查询档案数量失败: %w", err) return fmt.Errorf("查询档案数量失败: %w", err)
} }
@@ -261,7 +261,7 @@ func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfil
} }
func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) { func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
profiles, err := s.profileRepo.GetByNames(names) profiles, err := s.profileRepo.GetByNames(ctx, names)
if err != nil { if err != nil {
return nil, fmt.Errorf("查找失败: %w", err) return nil, fmt.Errorf("查找失败: %w", err)
} }
@@ -270,7 +270,7 @@ func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*mod
func (s *profileService) GetByProfileName(ctx context.Context, name string) (*model.Profile, error) { func (s *profileService) GetByProfileName(ctx context.Context, name string) (*model.Profile, error) {
// Profile name 查询通常不会频繁缓存,但为了一致性也添加 // Profile name 查询通常不会频繁缓存,但为了一致性也添加
profile, err := s.profileRepo.FindByName(name) profile, err := s.profileRepo.FindByName(ctx, name)
if err != nil { if err != nil {
return nil, errors.New("用户角色未创建") return nil, errors.New("用户角色未创建")
} }

View File

@@ -426,7 +426,7 @@ func TestProfileServiceImpl_Create(t *testing.T) {
Email: "test@example.com", Email: "test@example.com",
Status: 1, Status: 1,
} }
userRepo.Create(testUser) _ = userRepo.Create(context.Background(), testUser)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger) profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger)
@@ -459,7 +459,7 @@ func TestProfileServiceImpl_Create(t *testing.T) {
wantErr: true, wantErr: true,
errMsg: "角色名已被使用", errMsg: "角色名已被使用",
setupMocks: func() { setupMocks: func() {
profileRepo.Create(&model.Profile{ _ = profileRepo.Create(context.Background(), &model.Profile{
UUID: "existing-uuid", UUID: "existing-uuid",
UserID: 2, UserID: 2,
Name: "ExistingProfile", Name: "ExistingProfile",
@@ -516,7 +516,7 @@ func TestProfileServiceImpl_GetByUUID(t *testing.T) {
UserID: 1, UserID: 1,
Name: "TestProfile", Name: "TestProfile",
} }
profileRepo.Create(testProfile) _ = profileRepo.Create(context.Background(), testProfile)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger) profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger)
@@ -575,7 +575,7 @@ func TestProfileServiceImpl_Delete(t *testing.T) {
UserID: 1, UserID: 1,
Name: "DeleteTestProfile", Name: "DeleteTestProfile",
} }
profileRepo.Create(testProfile) _ = profileRepo.Create(context.Background(), testProfile)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger) profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger)
@@ -625,9 +625,9 @@ func TestProfileServiceImpl_GetByUserID(t *testing.T) {
logger := zap.NewNop() logger := zap.NewNop()
// 为用户 1 和 2 预置不同档案 // 为用户 1 和 2 预置不同档案
profileRepo.Create(&model.Profile{UUID: "p1", UserID: 1, Name: "P1"}) _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p1", UserID: 1, Name: "P1"})
profileRepo.Create(&model.Profile{UUID: "p2", UserID: 1, Name: "P2"}) _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p2", UserID: 1, Name: "P2"})
profileRepo.Create(&model.Profile{UUID: "p3", UserID: 2, Name: "P3"}) _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p3", UserID: 2, Name: "P3"})
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
svc := NewProfileService(profileRepo, userRepo, cacheManager, logger) svc := NewProfileService(profileRepo, userRepo, cacheManager, logger)
@@ -653,7 +653,7 @@ func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) {
UserID: 1, UserID: 1,
Name: "OldName", Name: "OldName",
} }
profileRepo.Create(profile) _ = profileRepo.Create(context.Background(), profile)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
svc := NewProfileService(profileRepo, userRepo, cacheManager, logger) svc := NewProfileService(profileRepo, userRepo, cacheManager, logger)
@@ -678,7 +678,7 @@ func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) {
} }
// 名称重复 // 名称重复
profileRepo.Create(&model.Profile{ _ = profileRepo.Create(context.Background(), &model.Profile{
UUID: "u2", UUID: "u2",
UserID: 2, UserID: 2,
Name: "Duplicate", Name: "Duplicate",
@@ -705,8 +705,8 @@ func TestProfileServiceImpl_CheckLimit_And_GetByNames(t *testing.T) {
logger := zap.NewNop() logger := zap.NewNop()
// 为用户 1 预置 2 个档案 // 为用户 1 预置 2 个档案
profileRepo.Create(&model.Profile{UUID: "a", UserID: 1, Name: "A"}) _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "a", UserID: 1, Name: "A"})
profileRepo.Create(&model.Profile{UUID: "b", UserID: 1, Name: "B"}) _ = profileRepo.Create(context.Background(), &model.Profile{UUID: "b", UserID: 1, Name: "B"})
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
svc := NewProfileService(profileRepo, userRepo, cacheManager, logger) svc := NewProfileService(profileRepo, userRepo, cacheManager, logger)

View File

@@ -32,8 +32,8 @@ const (
RedisTTL = 0 // 永不过期,由应用程序管理过期时间 RedisTTL = 0 // 永不过期,由应用程序管理过期时间
) )
// signatureService 签名服务实现 // SignatureService 签名服务(导出以便依赖注入)
type signatureService struct { type SignatureService struct {
profileRepo repository.ProfileRepository profileRepo repository.ProfileRepository
redis *redis.Client redis *redis.Client
logger *zap.Logger logger *zap.Logger
@@ -44,8 +44,8 @@ func NewSignatureService(
profileRepo repository.ProfileRepository, profileRepo repository.ProfileRepository,
redisClient *redis.Client, redisClient *redis.Client,
logger *zap.Logger, logger *zap.Logger,
) *signatureService { ) *SignatureService {
return &signatureService{ return &SignatureService{
profileRepo: profileRepo, profileRepo: profileRepo,
redis: redisClient, redis: redisClient,
logger: logger, logger: logger,
@@ -53,7 +53,7 @@ func NewSignatureService(
} }
// NewKeyPair 生成新的RSA密钥对 // NewKeyPair 生成新的RSA密钥对
func (s *signatureService) NewKeyPair() (*model.KeyPair, error) { func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, KeySize) privateKey, err := rsa.GenerateKey(rand.Reader, KeySize)
if err != nil { if err != nil {
return nil, fmt.Errorf("生成RSA密钥对失败: %w", err) return nil, fmt.Errorf("生成RSA密钥对失败: %w", err)
@@ -132,7 +132,7 @@ func (s *signatureService) NewKeyPair() (*model.KeyPair, error) {
} }
// GetOrCreateYggdrasilKeyPair 获取或创建Yggdrasil根密钥对 // GetOrCreateYggdrasilKeyPair 获取或创建Yggdrasil根密钥对
func (s *signatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKey, error) { func (s *SignatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKey, error) {
ctx := context.Background() ctx := context.Background()
// 尝试从Redis获取密钥 // 尝试从Redis获取密钥
@@ -201,7 +201,7 @@ func (s *signatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKe
} }
// GetPublicKeyFromRedis 从Redis获取公钥 // GetPublicKeyFromRedis 从Redis获取公钥
func (s *signatureService) GetPublicKeyFromRedis() (string, error) { func (s *SignatureService) GetPublicKeyFromRedis() (string, error) {
ctx := context.Background() ctx := context.Background()
publicKey, err := s.redis.Get(ctx, PublicKeyRedisKey) publicKey, err := s.redis.Get(ctx, PublicKeyRedisKey)
if err != nil { if err != nil {
@@ -218,7 +218,7 @@ func (s *signatureService) GetPublicKeyFromRedis() (string, error) {
} }
// SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串 // SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串
func (s *signatureService) SignStringWithSHA1withRSA(data string) (string, error) { func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) {
ctx := context.Background() ctx := context.Background()
// 从Redis获取私钥 // 从Redis获取私钥

View File

@@ -41,13 +41,13 @@ func NewTextureService(
func (s *textureService) Create(ctx context.Context, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) { func (s *textureService) Create(ctx context.Context, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
// 验证用户存在 // 验证用户存在
user, err := s.userRepo.FindByID(uploaderID) user, err := s.userRepo.FindByID(ctx, uploaderID)
if err != nil || user == nil { if err != nil || user == nil {
return nil, ErrUserNotFound return nil, ErrUserNotFound
} }
// 检查Hash是否已存在 // 检查Hash是否已存在
existingTexture, err := s.textureRepo.FindByHash(hash) existingTexture, err := s.textureRepo.FindByHash(ctx, hash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -77,7 +77,7 @@ func (s *textureService) Create(ctx context.Context, uploaderID int64, name, des
FavoriteCount: 0, FavoriteCount: 0,
} }
if err := s.textureRepo.Create(texture); err != nil { if err := s.textureRepo.Create(ctx, texture); err != nil {
return nil, err return nil, err
} }
@@ -99,7 +99,7 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
} }
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
texture2, err := s.textureRepo.FindByID(id) texture2, err := s.textureRepo.FindByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -132,7 +132,7 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
} }
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
texture2, err := s.textureRepo.FindByHash(hash) texture2, err := s.textureRepo.FindByHash(ctx, hash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -165,7 +165,7 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
} }
// 缓存未命中,从数据库查询 // 缓存未命中,从数据库查询
textures, total, err := s.textureRepo.FindByUploaderID(uploaderID, page, pageSize) textures, total, err := s.textureRepo.FindByUploaderID(ctx, uploaderID, page, pageSize)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@@ -184,12 +184,12 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
func (s *textureService) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) { func (s *textureService) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
page, pageSize = NormalizePagination(page, pageSize) page, pageSize = NormalizePagination(page, pageSize)
return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize) return s.textureRepo.Search(ctx, keyword, textureType, publicOnly, page, pageSize)
} }
func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) { func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
// 获取材质并验证权限 // 获取材质并验证权限
texture, err := s.textureRepo.FindByID(textureID) texture, err := s.textureRepo.FindByID(ctx, textureID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -213,7 +213,7 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
} }
if len(updates) > 0 { if len(updates) > 0 {
if err := s.textureRepo.UpdateFields(textureID, updates); err != nil { if err := s.textureRepo.UpdateFields(ctx, textureID, updates); err != nil {
return nil, err return nil, err
} }
} }
@@ -222,12 +222,12 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID)) s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID))
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID)) s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID))
return s.textureRepo.FindByID(textureID) return s.textureRepo.FindByID(ctx, textureID)
} }
func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64) error { func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64) error {
// 获取材质并验证权限 // 获取材质并验证权限
texture, err := s.textureRepo.FindByID(textureID) texture, err := s.textureRepo.FindByID(ctx, textureID)
if err != nil { if err != nil {
return err return err
} }
@@ -238,7 +238,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
return ErrTextureNoPermission return ErrTextureNoPermission
} }
err = s.textureRepo.Delete(textureID) err = s.textureRepo.Delete(ctx, textureID)
if err != nil { if err != nil {
return err return err
} }
@@ -252,7 +252,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID int64) (bool, error) { func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID int64) (bool, error) {
// 确保材质存在 // 确保材质存在
texture, err := s.textureRepo.FindByID(textureID) texture, err := s.textureRepo.FindByID(ctx, textureID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -260,27 +260,27 @@ func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID i
return false, ErrTextureNotFound return false, ErrTextureNotFound
} }
isFavorited, err := s.textureRepo.IsFavorited(userID, textureID) isFavorited, err := s.textureRepo.IsFavorited(ctx, userID, textureID)
if err != nil { if err != nil {
return false, err return false, err
} }
if isFavorited { if isFavorited {
// 已收藏 -> 取消收藏 // 已收藏 -> 取消收藏
if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil { if err := s.textureRepo.RemoveFavorite(ctx, userID, textureID); err != nil {
return false, err return false, err
} }
if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil { if err := s.textureRepo.DecrementFavoriteCount(ctx, textureID); err != nil {
return false, err return false, err
} }
return false, nil return false, nil
} }
// 未收藏 -> 添加收藏 // 未收藏 -> 添加收藏
if err := s.textureRepo.AddFavorite(userID, textureID); err != nil { if err := s.textureRepo.AddFavorite(ctx, userID, textureID); err != nil {
return false, err return false, err
} }
if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil { if err := s.textureRepo.IncrementFavoriteCount(ctx, textureID); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
@@ -288,11 +288,11 @@ func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID i
func (s *textureService) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) { func (s *textureService) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
page, pageSize = NormalizePagination(page, pageSize) page, pageSize = NormalizePagination(page, pageSize)
return s.textureRepo.GetUserFavorites(userID, page, pageSize) return s.textureRepo.GetUserFavorites(ctx, userID, page, pageSize)
} }
func (s *textureService) CheckUploadLimit(ctx context.Context, uploaderID int64, maxTextures int) error { func (s *textureService) CheckUploadLimit(ctx context.Context, uploaderID int64, maxTextures int) error {
count, err := s.textureRepo.CountByUploaderID(uploaderID) count, err := s.textureRepo.CountByUploaderID(ctx, uploaderID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -491,7 +491,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
Email: "test@example.com", Email: "test@example.com",
Status: 1, Status: 1,
} }
userRepo.Create(testUser) _ = userRepo.Create(context.Background(), testUser)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
@@ -539,7 +539,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
wantErr: true, wantErr: true,
errContains: "已存在", errContains: "已存在",
setupMocks: func() { setupMocks: func() {
textureRepo.Create(&model.Texture{ _ = textureRepo.Create(context.Background(), &model.Texture{
ID: 100, ID: 100,
UploaderID: 1, UploaderID: 1,
Name: "ExistingTexture", Name: "ExistingTexture",
@@ -614,7 +614,7 @@ func TestTextureServiceImpl_GetByID(t *testing.T) {
Name: "TestTexture", Name: "TestTexture",
Hash: "test-hash", Hash: "test-hash",
} }
textureRepo.Create(testTexture) _ = textureRepo.Create(context.Background(), testTexture)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
@@ -666,7 +666,7 @@ func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) {
// 预置多条 Texture // 预置多条 Texture
for i := int64(1); i <= 5; i++ { for i := int64(1); i <= 5; i++ {
textureRepo.Create(&model.Texture{ _ = textureRepo.Create(context.Background(), &model.Texture{
ID: i, ID: i,
UploaderID: 1, UploaderID: 1,
Name: "T", Name: "T",
@@ -711,7 +711,7 @@ func TestTextureServiceImpl_Update_And_Delete(t *testing.T) {
Description: "OldDesc", Description: "OldDesc",
IsPublic: false, IsPublic: false,
} }
textureRepo.Create(texture) _ = textureRepo.Create(context.Background(), texture)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
@@ -755,12 +755,12 @@ func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) {
// 预置若干 Texture 与收藏关系 // 预置若干 Texture 与收藏关系
for i := int64(1); i <= 3; i++ { for i := int64(1); i <= 3; i++ {
textureRepo.Create(&model.Texture{ _ = textureRepo.Create(context.Background(), &model.Texture{
ID: i, ID: i,
UploaderID: 1, UploaderID: 1,
Name: "T", Name: "T",
}) })
_ = textureRepo.AddFavorite(1, i) _ = textureRepo.AddFavorite(context.Background(), 1, i)
} }
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
@@ -796,7 +796,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) {
// 预置用户和Texture // 预置用户和Texture
testUser := &model.User{ID: 1, Username: "testuser", Status: 1} testUser := &model.User{ID: 1, Username: "testuser", Status: 1}
userRepo.Create(testUser) _ = userRepo.Create(context.Background(), testUser)
testTexture := &model.Texture{ testTexture := &model.Texture{
ID: 1, ID: 1,
@@ -804,7 +804,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) {
Name: "TestTexture", Name: "TestTexture",
Hash: "test-hash", Hash: "test-hash",
} }
textureRepo.Create(testTexture) _ = textureRepo.Create(context.Background(), testTexture)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger) textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)

View File

@@ -46,12 +46,12 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
) )
// 设置超时上下文 // 设置超时上下文
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout) ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel() defer cancel()
// 验证用户存在 // 验证用户存在
if UUID != "" { if UUID != "" {
_, err := s.profileRepo.FindByUUID(UUID) _, err := s.profileRepo.FindByUUID(ctx, UUID)
if err != nil { if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
} }
@@ -72,7 +72,7 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
} }
// 获取用户配置文件 // 获取用户配置文件
profiles, err := s.profileRepo.FindByUserID(userID) profiles, err := s.profileRepo.FindByUserID(ctx, userID)
if err != nil { if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
} }
@@ -85,23 +85,27 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
availableProfiles = profiles availableProfiles = profiles
// 插入令牌 // 插入令牌
err = s.tokenRepo.Create(&token) err = s.tokenRepo.Create(ctx, &token)
if err != nil { if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
} }
// 清理多余的令牌 // 清理多余的令牌(使用独立的后台上下文)
go s.checkAndCleanupExcessTokens(userID) go s.checkAndCleanupExcessTokens(context.Background(), userID)
return selectedProfileID, availableProfiles, accessToken, clientToken, nil return selectedProfileID, availableProfiles, accessToken, clientToken, nil
} }
func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool { func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" { if accessToken == "" {
return false return false
} }
token, err := s.tokenRepo.FindByAccessToken(accessToken) token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
if err != nil { if err != nil {
return false return false
} }
@@ -118,12 +122,16 @@ func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken st
} }
func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) { func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("accessToken不能为空") return "", "", errors.New("accessToken不能为空")
} }
// 查找旧令牌 // 查找旧令牌
oldToken, err := s.tokenRepo.FindByAccessToken(accessToken) oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return "", "", errors.New("accessToken无效") return "", "", errors.New("accessToken无效")
@@ -134,7 +142,7 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
// 验证profile // 验证profile
if selectedProfileID != "" { if selectedProfileID != "" {
valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID) valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID)
if validErr != nil { if validErr != nil {
s.logger.Error("验证Profile失败", s.logger.Error("验证Profile失败",
zap.Error(err), zap.Error(err),
@@ -174,13 +182,13 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
} }
// 先插入新令牌,再删除旧令牌 // 先插入新令牌,再删除旧令牌
err = s.tokenRepo.Create(&newToken) err = s.tokenRepo.Create(ctx, &newToken)
if err != nil { if err != nil {
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken)) s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
return "", "", fmt.Errorf("创建新Token失败: %w", err) return "", "", fmt.Errorf("创建新Token失败: %w", err)
} }
err = s.tokenRepo.DeleteByAccessToken(accessToken) err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
if err != nil { if err != nil {
s.logger.Warn("删除旧Token失败但新Token已创建", s.logger.Warn("删除旧Token失败但新Token已创建",
zap.Error(err), zap.Error(err),
@@ -194,11 +202,15 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
} }
func (s *tokenService) Invalidate(ctx context.Context, accessToken string) { func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" { if accessToken == "" {
return return
} }
err := s.tokenRepo.DeleteByAccessToken(accessToken) err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
if err != nil { if err != nil {
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken)) s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
return return
@@ -207,11 +219,15 @@ func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
} }
func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) { func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if userID == 0 { if userID == 0 {
return return
} }
err := s.tokenRepo.DeleteByUserID(userID) err := s.tokenRepo.DeleteByUserID(ctx, userID)
if err != nil { if err != nil {
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID)) s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
return return
@@ -221,21 +237,33 @@ func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
} }
func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) { func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
return s.tokenRepo.GetUUIDByAccessToken(accessToken) // 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
} }
func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) { func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
return s.tokenRepo.GetUserIDByAccessToken(accessToken) // 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
} }
// 私有辅助方法 // 私有辅助方法
func (s *tokenService) checkAndCleanupExcessTokens(userID int64) { func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
if userID == 0 { if userID == 0 {
return return
} }
tokens, err := s.tokenRepo.GetByUserID(userID) // 为清理操作设置更长的超时时间
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
defer cancel()
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
if err != nil { if err != nil {
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return return
@@ -250,7 +278,7 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
} }
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
if err != nil { if err != nil {
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return return
@@ -261,12 +289,12 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
} }
} }
func (s *tokenService) validateProfileByUserID(userID int64, UUID string) (bool, error) { func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
if userID == 0 || UUID == "" { if userID == 0 || UUID == "" {
return false, errors.New("用户ID或配置文件ID不能为空") return false, errors.New("用户ID或配置文件ID不能为空")
} }
profile, err := s.profileRepo.FindByUUID(UUID) profile, err := s.profileRepo.FindByUUID(ctx, UUID)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return false, errors.New("配置文件不存在") return false, errors.New("配置文件不存在")

View File

@@ -55,12 +55,12 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
) )
// 设置超时上下文 // 设置超时上下文
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout) ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel() defer cancel()
// 验证用户存在 // 验证用户存在
if UUID != "" { if UUID != "" {
_, err := s.profileRepo.FindByUUID(UUID) _, err := s.profileRepo.FindByUUID(ctx, UUID)
if err != nil { if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
} }
@@ -73,7 +73,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
// 获取或创建Client // 获取或创建Client
var client *model.Client var client *model.Client
existingClient, err := s.clientRepo.FindByClientToken(clientToken) existingClient, err := s.clientRepo.FindByClientToken(ctx, clientToken)
if err != nil { if err != nil {
// Client不存在创建新的 // Client不存在创建新的
clientUUID := uuid.New().String() clientUUID := uuid.New().String()
@@ -90,7 +90,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
client.ProfileID = UUID client.ProfileID = UUID
} }
if err := s.clientRepo.Create(client); err != nil { if err := s.clientRepo.Create(ctx, client); err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
} }
} else { } else {
@@ -103,14 +103,14 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
client.UpdatedAt = time.Now() client.UpdatedAt = time.Now()
if UUID != "" { if UUID != "" {
client.ProfileID = UUID client.ProfileID = UUID
if err := s.clientRepo.Update(client); err != nil { if err := s.clientRepo.Update(ctx, client); err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
} }
} }
} }
// 获取用户配置文件 // 获取用户配置文件
profiles, err := s.profileRepo.FindByUserID(userID) profiles, err := s.profileRepo.FindByUserID(ctx, userID)
if err != nil { if err != nil {
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err) return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
} }
@@ -122,7 +122,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
if profileID == "" { if profileID == "" {
profileID = selectedProfileID.UUID profileID = selectedProfileID.UUID
client.ProfileID = profileID client.ProfileID = profileID
s.clientRepo.Update(client) _ = s.clientRepo.Update(ctx, client)
} }
} }
availableProfiles = profiles availableProfiles = profiles
@@ -170,20 +170,23 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
StaleAt: &staleAt, StaleAt: &staleAt,
} }
err = s.tokenRepo.Create(&token) err = s.tokenRepo.Create(ctx, &token)
if err != nil { if err != nil {
s.logger.Warn("保存Token记录失败但JWT已生成", zap.Error(err)) s.logger.Warn("保存Token记录失败但JWT已生成", zap.Error(err))
// 不返回错误因为JWT本身已经生成成功 // 不返回错误因为JWT本身已经生成成功
} }
// 清理多余的令牌 // 清理多余的令牌(使用独立的后台上下文)
go s.checkAndCleanupExcessTokens(userID) go s.checkAndCleanupExcessTokens(context.Background(), userID)
return selectedProfileID, availableProfiles, accessToken, clientToken, nil return selectedProfileID, availableProfiles, accessToken, clientToken, nil
} }
// Validate 验证Token使用JWT验证 // Validate 验证Token使用JWT验证
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool { func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" { if accessToken == "" {
return false return false
} }
@@ -195,7 +198,7 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
} }
// 查找Client // 查找Client
client, err := s.clientRepo.FindByUUID(claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
return false return false
} }
@@ -215,6 +218,9 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
// Refresh 刷新Token使用Version机制无需删除旧Token // Refresh 刷新Token使用Version机制无需删除旧Token
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) { func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" { if accessToken == "" {
return "", "", errors.New("accessToken不能为空") return "", "", errors.New("accessToken不能为空")
} }
@@ -226,7 +232,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
} }
// 查找Client // 查找Client
client, err := s.clientRepo.FindByUUID(claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
return "", "", errors.New("无法找到对应的Client") return "", "", errors.New("无法找到对应的Client")
} }
@@ -243,7 +249,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
// 验证Profile // 验证Profile
if selectedProfileID != "" { if selectedProfileID != "" {
valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID) valid, validErr := s.validateProfileByUserID(ctx, client.UserID, selectedProfileID)
if validErr != nil { if validErr != nil {
s.logger.Error("验证Profile失败", s.logger.Error("验证Profile失败",
zap.Error(validErr), zap.Error(validErr),
@@ -269,7 +275,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
// 增加Version这是关键通过Version失效所有旧Token // 增加Version这是关键通过Version失效所有旧Token
client.Version++ client.Version++
client.UpdatedAt = time.Now() client.UpdatedAt = time.Now()
if err := s.clientRepo.Update(client); err != nil { if err := s.clientRepo.Update(ctx, client); err != nil {
return "", "", fmt.Errorf("更新Client版本失败: %w", err) return "", "", fmt.Errorf("更新Client版本失败: %w", err)
} }
@@ -315,7 +321,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
StaleAt: &staleAt, StaleAt: &staleAt,
} }
err = s.tokenRepo.Create(&newToken) err = s.tokenRepo.Create(ctx, &newToken)
if err != nil { if err != nil {
s.logger.Warn("保存新Token记录失败但JWT已生成", zap.Error(err)) s.logger.Warn("保存新Token记录失败但JWT已生成", zap.Error(err))
} }
@@ -326,6 +332,10 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
// Invalidate 使Token失效通过增加Version // Invalidate 使Token失效通过增加Version
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) { func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if accessToken == "" { if accessToken == "" {
return return
} }
@@ -338,7 +348,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
} }
// 查找Client并增加Version // 查找Client并增加Version
client, err := s.clientRepo.FindByUUID(claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
s.logger.Warn("无法找到对应的Client", zap.Error(err)) s.logger.Warn("无法找到对应的Client", zap.Error(err))
return return
@@ -347,7 +357,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
// 增加Version以失效所有旧Token // 增加Version以失效所有旧Token
client.Version++ client.Version++
client.UpdatedAt = time.Now() client.UpdatedAt = time.Now()
if err := s.clientRepo.Update(client); err != nil { if err := s.clientRepo.Update(ctx, client); err != nil {
s.logger.Error("失效Token失败", zap.Error(err)) s.logger.Error("失效Token失败", zap.Error(err))
return return
} }
@@ -357,12 +367,16 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
// InvalidateUserTokens 使用户所有Token失效 // InvalidateUserTokens 使用户所有Token失效
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) { func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
// 设置超时上下文
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
defer cancel()
if userID == 0 { if userID == 0 {
return return
} }
// 获取用户所有Client // 获取用户所有Client
clients, err := s.clientRepo.FindByUserID(userID) clients, err := s.clientRepo.FindByUserID(ctx, userID)
if err != nil { if err != nil {
s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID)) s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID))
return return
@@ -372,7 +386,7 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64
for _, client := range clients { for _, client := range clients {
client.Version++ client.Version++
client.UpdatedAt = time.Now() client.UpdatedAt = time.Now()
if err := s.clientRepo.Update(client); err != nil { if err := s.clientRepo.Update(ctx, client); err != nil {
s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID)) s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID))
} }
} }
@@ -385,7 +399,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil { if err != nil {
// 如果JWT解析失败尝试从数据库查询向后兼容 // 如果JWT解析失败尝试从数据库查询向后兼容
return s.tokenRepo.GetUUIDByAccessToken(accessToken) return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
} }
if claims.ProfileID != "" { if claims.ProfileID != "" {
@@ -393,7 +407,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
} }
// 如果没有ProfileID从Client获取 // 如果没有ProfileID从Client获取
client, err := s.clientRepo.FindByUUID(claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
return "", fmt.Errorf("无法找到对应的Client: %w", err) return "", fmt.Errorf("无法找到对应的Client: %w", err)
} }
@@ -410,11 +424,11 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow) claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
if err != nil { if err != nil {
// 如果JWT解析失败尝试从数据库查询向后兼容 // 如果JWT解析失败尝试从数据库查询向后兼容
return s.tokenRepo.GetUserIDByAccessToken(accessToken) return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
} }
// 从Client获取UserID // 从Client获取UserID
client, err := s.clientRepo.FindByUUID(claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
return 0, fmt.Errorf("无法找到对应的Client: %w", err) return 0, fmt.Errorf("无法找到对应的Client: %w", err)
} }
@@ -429,12 +443,16 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
// 私有辅助方法 // 私有辅助方法
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) { func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
if userID == 0 { if userID == 0 {
return return
} }
tokens, err := s.tokenRepo.GetByUserID(userID) // 为清理操作设置更长的超时时间
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
defer cancel()
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
if err != nil { if err != nil {
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return return
@@ -449,7 +467,7 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken) tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
} }
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete) deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
if err != nil { if err != nil {
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10))) s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
return return
@@ -460,12 +478,12 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
} }
} }
func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) { func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
if userID == 0 || UUID == "" { if userID == 0 || UUID == "" {
return false, errors.New("用户ID或配置文件ID不能为空") return false, errors.New("用户ID或配置文件ID不能为空")
} }
profile, err := s.profileRepo.FindByUUID(UUID) profile, err := s.profileRepo.FindByUUID(ctx, UUID)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return false, errors.New("配置文件不存在") return false, errors.New("配置文件不存在")
@@ -482,7 +500,7 @@ func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken st
return nil, err return nil, err
} }
client, err := s.clientRepo.FindByUUID(claims.Subject) client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -208,7 +208,7 @@ func TestTokenServiceImpl_Create(t *testing.T) {
Name: "TestProfile", Name: "TestProfile",
IsActive: true, IsActive: true,
} }
profileRepo.Create(testProfile) _ = profileRepo.Create(context.Background(), testProfile)
tokenService := NewTokenService(tokenRepo, profileRepo, logger) tokenService := NewTokenService(tokenRepo, profileRepo, logger)
@@ -274,7 +274,7 @@ func TestTokenServiceImpl_Validate(t *testing.T) {
ProfileId: "test-profile-uuid", ProfileId: "test-profile-uuid",
Usable: true, Usable: true,
} }
tokenRepo.Create(testToken) _ = tokenRepo.Create(context.Background(), testToken)
tokenService := NewTokenService(tokenRepo, profileRepo, logger) tokenService := NewTokenService(tokenRepo, profileRepo, logger)
@@ -336,7 +336,7 @@ func TestTokenServiceImpl_Invalidate(t *testing.T) {
ProfileId: "test-profile-uuid", ProfileId: "test-profile-uuid",
Usable: true, Usable: true,
} }
tokenRepo.Create(testToken) _ = tokenRepo.Create(context.Background(), testToken)
tokenService := NewTokenService(tokenRepo, profileRepo, logger) tokenService := NewTokenService(tokenRepo, profileRepo, logger)
@@ -352,7 +352,7 @@ func TestTokenServiceImpl_Invalidate(t *testing.T) {
tokenService.Invalidate(ctx, "token-to-invalidate") tokenService.Invalidate(ctx, "token-to-invalidate")
// 验证Token已失效从repo中删除 // 验证Token已失效从repo中删除
_, err := tokenRepo.FindByAccessToken("token-to-invalidate") _, err := tokenRepo.FindByAccessToken(context.Background(), "token-to-invalidate")
if err == nil { if err == nil {
t.Error("Token应该已被删除") t.Error("Token应该已被删除")
} }
@@ -366,7 +366,7 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
// 预置多个Token // 预置多个Token
for i := 1; i <= 3; i++ { for i := 1; i <= 3; i++ {
tokenRepo.Create(&model.Token{ _ = tokenRepo.Create(context.Background(), &model.Token{
AccessToken: fmt.Sprintf("user1-token-%d", i), AccessToken: fmt.Sprintf("user1-token-%d", i),
ClientToken: "client-token", ClientToken: "client-token",
UserID: 1, UserID: 1,
@@ -374,7 +374,7 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
Usable: true, Usable: true,
}) })
} }
tokenRepo.Create(&model.Token{ _ = tokenRepo.Create(context.Background(), &model.Token{
AccessToken: "user2-token-1", AccessToken: "user2-token-1",
ClientToken: "client-token", ClientToken: "client-token",
UserID: 2, UserID: 2,
@@ -390,13 +390,13 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
tokenService.InvalidateUserTokens(ctx, 1) tokenService.InvalidateUserTokens(ctx, 1)
// 验证用户1的Token已失效 // 验证用户1的Token已失效
tokens, _ := tokenRepo.GetByUserID(1) tokens, _ := tokenRepo.GetByUserID(context.Background(), 1)
if len(tokens) > 0 { if len(tokens) > 0 {
t.Errorf("用户1的Token应该全部被删除但还剩 %d 个", len(tokens)) t.Errorf("用户1的Token应该全部被删除但还剩 %d 个", len(tokens))
} }
// 验证用户2的Token仍然存在 // 验证用户2的Token仍然存在
tokens2, _ := tokenRepo.GetByUserID(2) tokens2, _ := tokenRepo.GetByUserID(context.Background(), 2)
if len(tokens2) != 1 { if len(tokens2) != 1 {
t.Errorf("用户2的Token应该仍然存在期望1个实际 %d 个", len(tokens2)) t.Errorf("用户2的Token应该仍然存在期望1个实际 %d 个", len(tokens2))
} }
@@ -413,7 +413,7 @@ func TestTokenServiceImpl_Refresh(t *testing.T) {
UUID: "profile-uuid", UUID: "profile-uuid",
UserID: 1, UserID: 1,
} }
profileRepo.Create(profile) _ = profileRepo.Create(context.Background(), profile)
oldToken := &model.Token{ oldToken := &model.Token{
AccessToken: "old-token", AccessToken: "old-token",
@@ -422,7 +422,7 @@ func TestTokenServiceImpl_Refresh(t *testing.T) {
ProfileId: "", ProfileId: "",
Usable: true, Usable: true,
} }
tokenRepo.Create(oldToken) _ = tokenRepo.Create(context.Background(), oldToken)
tokenService := NewTokenService(tokenRepo, profileRepo, logger) tokenService := NewTokenService(tokenRepo, profileRepo, logger)
@@ -455,7 +455,7 @@ func TestTokenServiceImpl_GetByAccessToken(t *testing.T) {
ProfileId: "profile-42", ProfileId: "profile-42",
Usable: true, Usable: true,
} }
tokenRepo.Create(token) _ = tokenRepo.Create(context.Background(), token)
tokenService := NewTokenService(tokenRepo, profileRepo, logger) tokenService := NewTokenService(tokenRepo, profileRepo, logger)
@@ -489,25 +489,25 @@ func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) {
UUID: "p-1", UUID: "p-1",
UserID: 1, UserID: 1,
} }
profileRepo.Create(profile) _ = profileRepo.Create(context.Background(), profile)
// 参数非法 // 参数非法
if ok, err := svc.validateProfileByUserID(0, ""); err == nil || ok { if ok, err := svc.validateProfileByUserID(context.Background(), 0, ""); err == nil || ok {
t.Fatalf("validateProfileByUserID 在参数非法时应返回错误") t.Fatalf("validateProfileByUserID 在参数非法时应返回错误")
} }
// Profile 不存在 // Profile 不存在
if ok, err := svc.validateProfileByUserID(1, "not-exists"); err == nil || ok { if ok, err := svc.validateProfileByUserID(context.Background(), 1, "not-exists"); err == nil || ok {
t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误") t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误")
} }
// 用户与 Profile 匹配 // 用户与 Profile 匹配
if ok, err := svc.validateProfileByUserID(1, "p-1"); err != nil || !ok { if ok, err := svc.validateProfileByUserID(context.Background(), 1, "p-1"); err != nil || !ok {
t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err) t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err)
} }
// 用户与 Profile 不匹配 // 用户与 Profile 不匹配
if ok, err := svc.validateProfileByUserID(2, "p-1"); err != nil || ok { if ok, err := svc.validateProfileByUserID(context.Background(), 2, "p-1"); err != nil || ok {
t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err) t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err)
} }
} }

View File

@@ -1,12 +1,6 @@
package service package service
import ( import (
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"carrotskin/pkg/config"
"carrotskin/pkg/database"
"carrotskin/pkg/redis"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@@ -14,6 +8,14 @@ import (
"strings" "strings"
"time" "time"
apperrors "carrotskin/internal/errors"
"carrotskin/internal/model"
"carrotskin/internal/repository"
"carrotskin/pkg/auth"
"carrotskin/pkg/config"
"carrotskin/pkg/database"
"carrotskin/pkg/redis"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -54,21 +56,21 @@ func NewUserService(
func (s *userService) Register(ctx context.Context, username, password, email, avatar string) (*model.User, string, error) { func (s *userService) Register(ctx context.Context, username, password, email, avatar string) (*model.User, string, error) {
// 检查用户名是否已存在 // 检查用户名是否已存在
existingUser, err := s.userRepo.FindByUsername(username) existingUser, err := s.userRepo.FindByUsername(ctx, username)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
if existingUser != nil { if existingUser != nil {
return nil, "", errors.New("用户名已存在") return nil, "", apperrors.ErrUserAlreadyExists
} }
// 检查邮箱是否已存在 // 检查邮箱是否已存在
existingEmail, err := s.userRepo.FindByEmail(email) existingEmail, err := s.userRepo.FindByEmail(ctx, email)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
if existingEmail != nil { if existingEmail != nil {
return nil, "", errors.New("邮箱已被注册") return nil, "", apperrors.ErrEmailAlreadyExists
} }
// 加密密码 // 加密密码
@@ -98,7 +100,7 @@ func (s *userService) Register(ctx context.Context, username, password, email, a
Points: 0, Points: 0,
} }
if err := s.userRepo.Create(user); err != nil { if err := s.userRepo.Create(ctx, user); err != nil {
return nil, "", err return nil, "", err
} }
@@ -126,9 +128,9 @@ func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAd
var err error var err error
if strings.Contains(usernameOrEmail, "@") { if strings.Contains(usernameOrEmail, "@") {
user, err = s.userRepo.FindByEmail(usernameOrEmail) user, err = s.userRepo.FindByEmail(ctx, usernameOrEmail)
} else { } else {
user, err = s.userRepo.FindByUsername(usernameOrEmail) user, err = s.userRepo.FindByUsername(ctx, usernameOrEmail)
} }
if err != nil { if err != nil {
@@ -166,12 +168,12 @@ func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAd
// 更新最后登录时间 // 更新最后登录时间
now := time.Now() now := time.Now()
user.LastLoginAt = &now user.LastLoginAt = &now
_ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ _ = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
"last_login_at": now, "last_login_at": now,
}) })
// 记录成功登录日志 // 记录成功登录日志
s.logSuccessLogin(user.ID, ipAddress, userAgent) s.logSuccessLogin(ctx, user.ID, ipAddress, userAgent)
return user, token, nil return user, token, nil
} }
@@ -180,7 +182,7 @@ func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error
// 使用 Cached 装饰器自动处理缓存 // 使用 Cached 装饰器自动处理缓存
cacheKey := s.cacheKeys.User(id) cacheKey := s.cacheKeys.User(id)
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) { return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
return s.userRepo.FindByID(id) return s.userRepo.FindByID(ctx, id)
}, 5*time.Minute) }, 5*time.Minute)
} }
@@ -188,12 +190,12 @@ func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User
// 使用 Cached 装饰器自动处理缓存 // 使用 Cached 装饰器自动处理缓存
cacheKey := s.cacheKeys.UserByEmail(email) cacheKey := s.cacheKeys.UserByEmail(email)
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) { return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
return s.userRepo.FindByEmail(email) return s.userRepo.FindByEmail(ctx, email)
}, 5*time.Minute) }, 5*time.Minute)
} }
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error { func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
err := s.userRepo.Update(user) err := s.userRepo.Update(ctx, user)
if err != nil { if err != nil {
return err return err
} }
@@ -209,7 +211,7 @@ func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
} }
func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL string) error { func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL string) error {
err := s.userRepo.UpdateFields(userID, map[string]interface{}{ err := s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
"avatar": avatarURL, "avatar": avatarURL,
}) })
if err != nil { if err != nil {
@@ -223,7 +225,7 @@ func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL
} }
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error { func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
user, err := s.userRepo.FindByID(userID) user, err := s.userRepo.FindByID(ctx, userID)
if err != nil || user == nil { if err != nil || user == nil {
return errors.New("用户不存在") return errors.New("用户不存在")
} }
@@ -237,7 +239,7 @@ func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassw
return errors.New("密码加密失败") return errors.New("密码加密失败")
} }
err = s.userRepo.UpdateFields(userID, map[string]interface{}{ err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
"password": hashedPassword, "password": hashedPassword,
}) })
if err != nil { if err != nil {
@@ -251,7 +253,7 @@ func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassw
} }
func (s *userService) ResetPassword(ctx context.Context, email, newPassword string) error { func (s *userService) ResetPassword(ctx context.Context, email, newPassword string) error {
user, err := s.userRepo.FindByEmail(email) user, err := s.userRepo.FindByEmail(ctx, email)
if err != nil || user == nil { if err != nil || user == nil {
return errors.New("用户不存在") return errors.New("用户不存在")
} }
@@ -261,7 +263,7 @@ func (s *userService) ResetPassword(ctx context.Context, email, newPassword stri
return errors.New("密码加密失败") return errors.New("密码加密失败")
} }
err = s.userRepo.UpdateFields(user.ID, map[string]interface{}{ err = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
"password": hashedPassword, "password": hashedPassword,
}) })
if err != nil { if err != nil {
@@ -279,17 +281,17 @@ func (s *userService) ResetPassword(ctx context.Context, email, newPassword stri
func (s *userService) ChangeEmail(ctx context.Context, userID int64, newEmail string) error { func (s *userService) ChangeEmail(ctx context.Context, userID int64, newEmail string) error {
// 获取旧邮箱 // 获取旧邮箱
oldUser, _ := s.userRepo.FindByID(userID) oldUser, _ := s.userRepo.FindByID(ctx, userID)
existingUser, err := s.userRepo.FindByEmail(newEmail) existingUser, err := s.userRepo.FindByEmail(ctx, newEmail)
if err != nil { if err != nil {
return err return err
} }
if existingUser != nil && existingUser.ID != userID { if existingUser != nil && existingUser.ID != userID {
return errors.New("邮箱已被其他用户使用") return apperrors.ErrEmailAlreadyExists
} }
err = s.userRepo.UpdateFields(userID, map[string]interface{}{ err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
"email": newEmail, "email": newEmail,
}) })
if err != nil { if err != nil {
@@ -346,7 +348,7 @@ func (s *userService) ValidateAvatarURL(ctx context.Context, avatarURL string) e
} }
func (s *userService) GetMaxProfilesPerUser() int { func (s *userService) GetMaxProfilesPerUser() int {
config, err := s.configRepo.GetByKey("max_profiles_per_user") config, err := s.configRepo.GetByKey(context.Background(), "max_profiles_per_user")
if err != nil || config == nil { if err != nil || config == nil {
return 5 return 5
} }
@@ -359,7 +361,7 @@ func (s *userService) GetMaxProfilesPerUser() int {
} }
func (s *userService) GetMaxTexturesPerUser() int { func (s *userService) GetMaxTexturesPerUser() int {
config, err := s.configRepo.GetByKey("max_textures_per_user") config, err := s.configRepo.GetByKey(context.Background(), "max_textures_per_user")
if err != nil || config == nil { if err != nil || config == nil {
return 50 return 50
} }
@@ -374,7 +376,7 @@ func (s *userService) GetMaxTexturesPerUser() int {
// 私有辅助方法 // 私有辅助方法
func (s *userService) getDefaultAvatar() string { func (s *userService) getDefaultAvatar() string {
config, err := s.configRepo.GetByKey("default_avatar") config, err := s.configRepo.GetByKey(context.Background(), "default_avatar")
if err != nil || config == nil || config.Value == "" { if err != nil || config == nil || config.Value == "" {
return "" return ""
} }
@@ -410,14 +412,14 @@ func (s *userService) recordLoginFailure(ctx context.Context, usernameOrEmail, i
identifier := usernameOrEmail + ":" + ipAddress identifier := usernameOrEmail + ":" + ipAddress
count, _ := RecordLoginFailure(ctx, s.redis, identifier) count, _ := RecordLoginFailure(ctx, s.redis, identifier)
if count >= MaxLoginAttempts { if count >= MaxLoginAttempts {
s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定") s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason+"-账号已锁定")
return return
} }
} }
s.logFailedLogin(userID, ipAddress, userAgent, reason) s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason)
} }
func (s *userService) logSuccessLogin(userID int64, ipAddress, userAgent string) { func (s *userService) logSuccessLogin(ctx context.Context, userID int64, ipAddress, userAgent string) {
log := &model.UserLoginLog{ log := &model.UserLoginLog{
UserID: userID, UserID: userID,
IPAddress: ipAddress, IPAddress: ipAddress,
@@ -425,10 +427,10 @@ func (s *userService) logSuccessLogin(userID int64, ipAddress, userAgent string)
LoginMethod: "PASSWORD", LoginMethod: "PASSWORD",
IsSuccess: true, IsSuccess: true,
} }
_ = s.userRepo.CreateLoginLog(log) _ = s.userRepo.CreateLoginLog(ctx, log)
} }
func (s *userService) logFailedLogin(userID int64, ipAddress, userAgent, reason string) { func (s *userService) logFailedLogin(ctx context.Context, userID int64, ipAddress, userAgent, reason string) {
log := &model.UserLoginLog{ log := &model.UserLoginLog{
UserID: userID, UserID: userID,
IPAddress: ipAddress, IPAddress: ipAddress,
@@ -437,5 +439,5 @@ func (s *userService) logFailedLogin(userID int64, ipAddress, userAgent, reason
IsSuccess: false, IsSuccess: false,
FailureReason: reason, FailureReason: reason,
} }
_ = s.userRepo.CreateLoginLog(log) _ = s.userRepo.CreateLoginLog(ctx, log)
} }

View File

@@ -49,9 +49,10 @@ func TestUserServiceImpl_Register(t *testing.T) {
email: "new@example.com", email: "new@example.com",
avatar: "", avatar: "",
wantErr: true, wantErr: true,
errMsg: "用户已存在", // 服务实现现已统一使用 apperrors.ErrUserAlreadyExists错误信息为“用户已存在
errMsg: "用户已存在",
setupMocks: func() { setupMocks: func() {
userRepo.Create(&model.User{ _ = userRepo.Create(context.Background(), &model.User{
Username: "existinguser", Username: "existinguser",
Email: "old@example.com", Email: "old@example.com",
}) })
@@ -66,7 +67,7 @@ func TestUserServiceImpl_Register(t *testing.T) {
wantErr: true, wantErr: true,
errMsg: "邮箱已被注册", errMsg: "邮箱已被注册",
setupMocks: func() { setupMocks: func() {
userRepo.Create(&model.User{ _ = userRepo.Create(context.Background(), &model.User{
Username: "otheruser", Username: "otheruser",
Email: "existing@example.com", Email: "existing@example.com",
}) })
@@ -126,7 +127,7 @@ func TestUserServiceImpl_Login(t *testing.T) {
Password: hashedPassword, Password: hashedPassword,
Status: 1, Status: 1,
} }
userRepo.Create(testUser) _ = userRepo.Create(context.Background(), testUser)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
@@ -207,7 +208,7 @@ func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) {
Email: "basic@example.com", Email: "basic@example.com",
Avatar: "", Avatar: "",
} }
userRepo.Create(user) _ = userRepo.Create(context.Background(), user)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
@@ -231,7 +232,7 @@ func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) {
if err := userService.UpdateInfo(ctx, user); err != nil { if err := userService.UpdateInfo(ctx, user); err != nil {
t.Fatalf("UpdateInfo 失败: %v", err) t.Fatalf("UpdateInfo 失败: %v", err)
} }
updated, _ := userRepo.FindByID(1) updated, _ := userRepo.FindByID(context.Background(), 1)
if updated.Username != "updated" { if updated.Username != "updated" {
t.Fatalf("UpdateInfo 未更新用户名, got=%s", updated.Username) t.Fatalf("UpdateInfo 未更新用户名, got=%s", updated.Username)
} }
@@ -255,7 +256,7 @@ func TestUserServiceImpl_ChangePassword(t *testing.T) {
Username: "changepw", Username: "changepw",
Password: hashed, Password: hashed,
} }
userRepo.Create(user) _ = userRepo.Create(context.Background(), user)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
@@ -290,7 +291,7 @@ func TestUserServiceImpl_ResetPassword(t *testing.T) {
Username: "resetpw", Username: "resetpw",
Email: "reset@example.com", Email: "reset@example.com",
} }
userRepo.Create(user) _ = userRepo.Create(context.Background(), user)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
@@ -317,8 +318,8 @@ func TestUserServiceImpl_ChangeEmail(t *testing.T) {
user1 := &model.User{ID: 1, Email: "user1@example.com"} user1 := &model.User{ID: 1, Email: "user1@example.com"}
user2 := &model.User{ID: 2, Email: "user2@example.com"} user2 := &model.User{ID: 2, Email: "user2@example.com"}
userRepo.Create(user1) _ = userRepo.Create(context.Background(), user1)
userRepo.Create(user2) _ = userRepo.Create(context.Background(), user2)
cacheManager := NewMockCacheManager() cacheManager := NewMockCacheManager()
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger) userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
@@ -389,8 +390,8 @@ func TestUserServiceImpl_MaxLimits(t *testing.T) {
} }
// 配置有效值 // 配置有效值
configRepo.Update(&model.SystemConfig{Key: "max_profiles_per_user", Value: "10"}) _ = configRepo.Update(context.Background(), &model.SystemConfig{Key: "max_profiles_per_user", Value: "10"})
configRepo.Update(&model.SystemConfig{Key: "max_textures_per_user", Value: "100"}) _ = configRepo.Update(context.Background(), &model.SystemConfig{Key: "max_textures_per_user", Value: "100"})
if got := userService.GetMaxProfilesPerUser(); got != 10 { if got := userService.GetMaxProfilesPerUser(); got != 10 {
t.Fatalf("GetMaxProfilesPerUser 配置值错误, got=%d", got) t.Fatalf("GetMaxProfilesPerUser 配置值错误, got=%d", got)

View File

@@ -38,7 +38,7 @@ func NewYggdrasilAuthService(
} }
func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email string) (int64, error) { func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email string) (int64, error) {
user, err := s.userRepo.FindByEmail(email) user, err := s.userRepo.FindByEmail(ctx, email)
if err != nil { if err != nil {
return 0, apperrors.ErrUserNotFound return 0, apperrors.ErrUserNotFound
} }
@@ -46,7 +46,7 @@ func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email strin
} }
func (s *yggdrasilAuthService) VerifyPassword(ctx context.Context, password string, userID int64) error { func (s *yggdrasilAuthService) VerifyPassword(ctx context.Context, password string, userID int64) error {
passwordStore, err := s.yggdrasilRepo.GetPasswordByID(userID) passwordStore, err := s.yggdrasilRepo.GetPasswordByID(ctx, userID)
if err != nil { if err != nil {
return apperrors.ErrPasswordNotSet return apperrors.ErrPasswordNotSet
} }
@@ -68,7 +68,7 @@ func (s *yggdrasilAuthService) ResetYggdrasilPassword(ctx context.Context, userI
} }
// 检查Yggdrasil记录是否存在 // 检查Yggdrasil记录是否存在
_, err = s.yggdrasilRepo.GetPasswordByID(userID) _, err = s.yggdrasilRepo.GetPasswordByID(ctx, userID)
if err != nil { if err != nil {
// 如果不存在,创建新记录 // 如果不存在,创建新记录
yggdrasil := model.Yggdrasil{ yggdrasil := model.Yggdrasil{
@@ -82,7 +82,7 @@ func (s *yggdrasilAuthService) ResetYggdrasilPassword(ctx context.Context, userI
} }
// 如果存在,更新密码(存储加密后的密码) // 如果存在,更新密码(存储加密后的密码)
if err := s.yggdrasilRepo.ResetPassword(userID, hashedPassword); err != nil { if err := s.yggdrasilRepo.ResetPassword(ctx, userID, hashedPassword); err != nil {
return "", fmt.Errorf("重置Yggdrasil密码失败: %w", err) return "", fmt.Errorf("重置Yggdrasil密码失败: %w", err)
} }

View File

@@ -21,14 +21,14 @@ type CertificateService interface {
// yggdrasilCertificateService 证书服务实现 // yggdrasilCertificateService 证书服务实现
type yggdrasilCertificateService struct { type yggdrasilCertificateService struct {
profileRepo repository.ProfileRepository profileRepo repository.ProfileRepository
signatureService *signatureService signatureService *SignatureService
logger *zap.Logger logger *zap.Logger
} }
// NewCertificateService 创建证书服务实例 // NewCertificateService 创建证书服务实例
func NewCertificateService( func NewCertificateService(
profileRepo repository.ProfileRepository, profileRepo repository.ProfileRepository,
signatureService *signatureService, signatureService *SignatureService,
logger *zap.Logger, logger *zap.Logger,
) CertificateService { ) CertificateService {
return &yggdrasilCertificateService{ return &yggdrasilCertificateService{
@@ -49,7 +49,7 @@ func (s *yggdrasilCertificateService) GeneratePlayerCertificate(ctx context.Cont
) )
// 获取密钥对 // 获取密钥对
keyPair, err := s.profileRepo.GetKeyPair(uuid) keyPair, err := s.profileRepo.GetKeyPair(ctx, uuid)
if err != nil { if err != nil {
s.logger.Info("获取用户密钥对失败,将创建新密钥对", s.logger.Info("获取用户密钥对失败,将创建新密钥对",
zap.Error(err), zap.Error(err),
@@ -74,7 +74,7 @@ func (s *yggdrasilCertificateService) GeneratePlayerCertificate(ctx context.Cont
} }
// 保存密钥对到数据库 // 保存密钥对到数据库
err = s.profileRepo.UpdateKeyPair(uuid, keyPair) err = s.profileRepo.UpdateKeyPair(ctx, uuid, keyPair)
if err != nil { if err != nil {
s.logger.Warn("更新用户密钥对失败", s.logger.Warn("更新用户密钥对失败",
zap.Error(err), zap.Error(err),

View File

@@ -28,14 +28,14 @@ type Property struct {
// yggdrasilSerializationService 序列化服务实现 // yggdrasilSerializationService 序列化服务实现
type yggdrasilSerializationService struct { type yggdrasilSerializationService struct {
textureRepo repository.TextureRepository textureRepo repository.TextureRepository
signatureService *signatureService signatureService *SignatureService
logger *zap.Logger logger *zap.Logger
} }
// NewSerializationService 创建序列化服务实例 // NewSerializationService 创建序列化服务实例
func NewSerializationService( func NewSerializationService(
textureRepo repository.TextureRepository, textureRepo repository.TextureRepository,
signatureService *signatureService, signatureService *SignatureService,
logger *zap.Logger, logger *zap.Logger,
) SerializationService { ) SerializationService {
return &yggdrasilSerializationService{ return &yggdrasilSerializationService{
@@ -58,7 +58,7 @@ func (s *yggdrasilSerializationService) SerializeProfile(ctx context.Context, pr
// 处理皮肤 // 处理皮肤
if profile.SkinID != nil { if profile.SkinID != nil {
skin, err := s.textureRepo.FindByID(*profile.SkinID) skin, err := s.textureRepo.FindByID(ctx, *profile.SkinID)
if err != nil { if err != nil {
s.logger.Error("获取皮肤失败", s.logger.Error("获取皮肤失败",
zap.Error(err), zap.Error(err),
@@ -74,7 +74,7 @@ func (s *yggdrasilSerializationService) SerializeProfile(ctx context.Context, pr
// 处理披风 // 处理披风
if profile.CapeID != nil { if profile.CapeID != nil {
cape, err := s.textureRepo.FindByID(*profile.CapeID) cape, err := s.textureRepo.FindByID(ctx, *profile.CapeID)
if err != nil { if err != nil {
s.logger.Error("获取披风失败", s.logger.Error("获取披风失败",
zap.Error(err), zap.Error(err),

View File

@@ -33,7 +33,7 @@ func NewYggdrasilServiceComposite(
profileRepo repository.ProfileRepository, profileRepo repository.ProfileRepository,
tokenRepo repository.TokenRepository, tokenRepo repository.TokenRepository,
yggdrasilRepo repository.YggdrasilRepository, yggdrasilRepo repository.YggdrasilRepository,
signatureService *signatureService, signatureService *SignatureService,
redisClient *redis.Client, redisClient *redis.Client,
logger *zap.Logger, logger *zap.Logger,
) YggdrasilService { ) YggdrasilService {
@@ -76,7 +76,7 @@ func (s *yggdrasilServiceComposite) ResetYggdrasilPassword(ctx context.Context,
// JoinServer 加入服务器 // JoinServer 加入服务器
func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error { func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error {
// 验证Token // 验证Token
token, err := s.tokenRepo.FindByAccessToken(accessToken) token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
if err != nil { if err != nil {
s.logger.Error("验证Token失败", s.logger.Error("验证Token失败",
zap.Error(err), zap.Error(err),
@@ -92,7 +92,7 @@ func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, ac
} }
// 获取Profile以获取用户名 // 获取Profile以获取用户名
profile, err := s.profileRepo.FindByUUID(formattedProfile) profile, err := s.profileRepo.FindByUUID(ctx, formattedProfile)
if err != nil { if err != nil {
s.logger.Error("获取Profile失败", s.logger.Error("获取Profile失败",
zap.Error(err), zap.Error(err),