refactor: Update service and repository methods to use context
- Refactored multiple service and repository methods to accept context as a parameter, enhancing consistency and enabling better control over request lifecycles. - Updated handlers to utilize context in method calls, improving error handling and performance. - Cleaned up Dockerfile by removing unnecessary whitespace.
This commit is contained in:
@@ -65,3 +65,5 @@ ENTRYPOINT ["./server"]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ type Container struct {
|
||||
UploadService service.UploadService
|
||||
SecurityService service.SecurityService
|
||||
CaptchaService service.CaptchaService
|
||||
SignatureService *service.SignatureService
|
||||
}
|
||||
|
||||
// NewContainer 创建依赖容器
|
||||
@@ -80,26 +81,27 @@ func NewContainer(
|
||||
c.ConfigRepo = repository.NewSystemConfigRepository(db)
|
||||
c.YggdrasilRepo = repository.NewYggdrasilRepository(db)
|
||||
|
||||
// 初始化SignatureService(用于获取Yggdrasil私钥)
|
||||
signatureService := service.NewSignatureService(c.ProfileRepo, redisClient, logger)
|
||||
|
||||
// 获取Yggdrasil私钥并创建JWT服务
|
||||
_, privateKey, err := signatureService.GetOrCreateYggdrasilKeyPair()
|
||||
if err != nil {
|
||||
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
|
||||
}
|
||||
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
|
||||
// 初始化SignatureService(作为依赖注入,避免在容器中创建并立即调用)
|
||||
// 将SignatureService添加到容器中,供其他服务使用
|
||||
c.SignatureService = service.NewSignatureService(c.ProfileRepo, redisClient, logger)
|
||||
|
||||
// 初始化Service(注入缓存管理器)
|
||||
c.UserService = service.NewUserService(c.UserRepo, c.ConfigRepo, jwtService, redisClient, cacheManager, logger)
|
||||
c.ProfileService = service.NewProfileService(c.ProfileRepo, c.UserRepo, cacheManager, logger)
|
||||
c.TextureService = service.NewTextureService(c.TextureRepo, c.UserRepo, cacheManager, logger)
|
||||
|
||||
// 使用JWT版本的TokenService
|
||||
|
||||
// 获取Yggdrasil私钥并创建JWT服务(TokenService需要)
|
||||
// 注意:这里仍然需要预先初始化,因为TokenService在创建时需要YggdrasilJWT
|
||||
// 但SignatureService已经作为依赖注入,降低了耦合度
|
||||
_, privateKey, err := c.SignatureService.GetOrCreateYggdrasilKeyPair()
|
||||
if err != nil {
|
||||
logger.Fatal("获取Yggdrasil私钥失败", zap.Error(err))
|
||||
}
|
||||
yggdrasilJWT := auth.NewYggdrasilJWTService(privateKey, "carrotskin")
|
||||
c.TokenService = service.NewTokenServiceJWT(c.TokenRepo, c.ClientRepo, c.ProfileRepo, yggdrasilJWT, logger)
|
||||
|
||||
// 使用组合服务(内部包含认证、会话、序列化、证书服务)
|
||||
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, signatureService, redisClient, logger)
|
||||
c.YggdrasilService = service.NewYggdrasilServiceComposite(db, c.UserRepo, c.ProfileRepo, c.TokenRepo, c.YggdrasilRepo, c.SignatureService, redisClient, logger)
|
||||
|
||||
// 初始化其他服务
|
||||
c.SecurityService = service.NewSecurityService(redisClient)
|
||||
|
||||
@@ -219,7 +219,7 @@ func (h *CustomSkinHandler) GetTexture(c *gin.Context) {
|
||||
|
||||
// 增加下载计数(异步)
|
||||
go func() {
|
||||
_ = h.container.TextureRepo.IncrementDownloadCount(texture.ID)
|
||||
_ = h.container.TextureRepo.IncrementDownloadCount(ctx, texture.ID)
|
||||
}()
|
||||
|
||||
// 流式传输文件内容
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"carrotskin/internal/errors"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/types"
|
||||
"net/http"
|
||||
@@ -165,17 +166,46 @@ func RespondSuccess(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, model.NewSuccessResponse(data))
|
||||
}
|
||||
|
||||
// RespondWithError 根据错误消息自动选择状态码
|
||||
// RespondWithError 根据错误类型自动选择状态码
|
||||
func RespondWithError(c *gin.Context, err error) {
|
||||
msg := err.Error()
|
||||
switch msg {
|
||||
case "档案不存在", "用户不存在", "材质不存在":
|
||||
RespondNotFound(c, msg)
|
||||
case "无权操作此档案", "无权操作此材质":
|
||||
RespondForbidden(c, msg)
|
||||
case "未授权":
|
||||
RespondUnauthorized(c, msg)
|
||||
default:
|
||||
RespondServerError(c, msg, nil)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用errors.Is检查预定义错误
|
||||
if errors.Is(err, errors.ErrUserNotFound) ||
|
||||
errors.Is(err, errors.ErrProfileNotFound) ||
|
||||
errors.Is(err, errors.ErrTextureNotFound) ||
|
||||
errors.Is(err, errors.ErrNotFound) {
|
||||
RespondNotFound(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, errors.ErrProfileNoPermission) ||
|
||||
errors.Is(err, errors.ErrTextureNoPermission) ||
|
||||
errors.Is(err, errors.ErrForbidden) {
|
||||
RespondForbidden(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, errors.ErrUnauthorized) ||
|
||||
errors.Is(err, errors.ErrInvalidToken) ||
|
||||
errors.Is(err, errors.ErrTokenExpired) {
|
||||
RespondUnauthorized(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 检查AppError类型
|
||||
var appErr *errors.AppError
|
||||
if errors.As(err, &appErr) {
|
||||
c.JSON(appErr.Code, model.NewErrorResponse(
|
||||
appErr.Code,
|
||||
appErr.Message,
|
||||
appErr.Err,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
// 默认返回500错误
|
||||
RespondServerError(c, err.Error(), err)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"carrotskin/internal/container"
|
||||
"carrotskin/internal/middleware"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/auth"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -47,13 +48,13 @@ func RegisterRoutesWithDI(router *gin.Engine, c *container.Container) {
|
||||
registerAuthRoutes(v1, h.Auth)
|
||||
|
||||
// 用户路由(需要JWT认证)
|
||||
registerUserRoutes(v1, h.User)
|
||||
registerUserRoutes(v1, h.User, c.JWT)
|
||||
|
||||
// 材质路由
|
||||
registerTextureRoutes(v1, h.Texture)
|
||||
registerTextureRoutes(v1, h.Texture, c.JWT)
|
||||
|
||||
// 档案路由
|
||||
registerProfileRoutesWithDI(v1, h.Profile)
|
||||
registerProfileRoutesWithDI(v1, h.Profile, c.JWT)
|
||||
|
||||
// 验证码路由
|
||||
registerCaptchaRoutesWithDI(v1, h.Captcha)
|
||||
@@ -81,9 +82,9 @@ func registerAuthRoutes(v1 *gin.RouterGroup, h *AuthHandler) {
|
||||
}
|
||||
|
||||
// registerUserRoutes 注册用户路由
|
||||
func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) {
|
||||
func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler, jwtService *auth.JWTService) {
|
||||
userGroup := v1.Group("/user")
|
||||
userGroup.Use(middleware.AuthMiddleware())
|
||||
userGroup.Use(middleware.AuthMiddleware(jwtService))
|
||||
{
|
||||
userGroup.GET("/profile", h.GetProfile)
|
||||
userGroup.PUT("/profile", h.UpdateProfile)
|
||||
@@ -101,7 +102,7 @@ func registerUserRoutes(v1 *gin.RouterGroup, h *UserHandler) {
|
||||
}
|
||||
|
||||
// registerTextureRoutes 注册材质路由
|
||||
func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) {
|
||||
func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler, jwtService *auth.JWTService) {
|
||||
textureGroup := v1.Group("/texture")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
@@ -110,7 +111,7 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) {
|
||||
|
||||
// 需要认证的路由
|
||||
textureAuth := textureGroup.Group("")
|
||||
textureAuth.Use(middleware.AuthMiddleware())
|
||||
textureAuth.Use(middleware.AuthMiddleware(jwtService))
|
||||
{
|
||||
textureAuth.POST("/upload-url", h.GenerateUploadURL)
|
||||
textureAuth.POST("", h.Create)
|
||||
@@ -124,7 +125,7 @@ func registerTextureRoutes(v1 *gin.RouterGroup, h *TextureHandler) {
|
||||
}
|
||||
|
||||
// registerProfileRoutesWithDI 注册档案路由(依赖注入版本)
|
||||
func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) {
|
||||
func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler, jwtService *auth.JWTService) {
|
||||
profileGroup := v1.Group("/profile")
|
||||
{
|
||||
// 公开路由(无需认证)
|
||||
@@ -132,7 +133,7 @@ func registerProfileRoutesWithDI(v1 *gin.RouterGroup, h *ProfileHandler) {
|
||||
|
||||
// 需要认证的路由
|
||||
profileAuth := profileGroup.Group("")
|
||||
profileAuth.Use(middleware.AuthMiddleware())
|
||||
profileAuth.Use(middleware.AuthMiddleware(jwtService))
|
||||
{
|
||||
profileAuth.POST("/", h.Create)
|
||||
profileAuth.GET("/", h.List)
|
||||
|
||||
@@ -1,15 +1,95 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/redis"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HealthCheck 健康检查
|
||||
// HealthCheck 健康检查,检查依赖服务状态
|
||||
func HealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"message": "CarrotSkin API is running",
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
checks := make(map[string]string)
|
||||
status := "ok"
|
||||
|
||||
// 检查数据库
|
||||
if err := checkDatabase(ctx); err != nil {
|
||||
checks["database"] = "unhealthy: " + err.Error()
|
||||
status = "degraded"
|
||||
} else {
|
||||
checks["database"] = "healthy"
|
||||
}
|
||||
|
||||
// 检查Redis
|
||||
if err := checkRedis(ctx); err != nil {
|
||||
checks["redis"] = "unhealthy: " + err.Error()
|
||||
status = "degraded"
|
||||
} else {
|
||||
checks["redis"] = "healthy"
|
||||
}
|
||||
|
||||
// 根据状态返回相应的HTTP状态码
|
||||
httpStatus := http.StatusOK
|
||||
if status == "degraded" {
|
||||
httpStatus = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
c.JSON(httpStatus, gin.H{
|
||||
"status": status,
|
||||
"message": "CarrotSkin API health check",
|
||||
"checks": checks,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
// checkDatabase 检查数据库连接
|
||||
func checkDatabase(ctx context.Context) error {
|
||||
db, err := database.GetDB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用Ping检查连接
|
||||
if err := sqlDB.PingContext(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 执行简单查询验证
|
||||
var result int
|
||||
if err := db.WithContext(ctx).Raw("SELECT 1").Scan(&result).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkRedis 检查Redis连接
|
||||
func checkRedis(ctx context.Context) error {
|
||||
client, err := redis.GetClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if client == nil {
|
||||
return errors.New("Redis客户端未初始化")
|
||||
}
|
||||
|
||||
// 使用Ping检查连接
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func (h *YggdrasilHandler) Authenticate(c *gin.Context) {
|
||||
if emailRegex.MatchString(request.Identifier) {
|
||||
userId, err = h.container.YggdrasilService.GetUserIDByEmail(c.Request.Context(), request.Identifier)
|
||||
} else {
|
||||
profile, err = h.container.ProfileRepo.FindByName(request.Identifier)
|
||||
profile, err = h.container.ProfileRepo.FindByName(c.Request.Context(), request.Identifier)
|
||||
if err != nil {
|
||||
h.logger.Error("用户名不存在", zap.String("identifier", request.Identifier), zap.Error(err))
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -9,17 +10,16 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthMiddleware JWT认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
// AuthMiddleware JWT认证中间件(注入JWT服务版本)
|
||||
func AuthMiddleware(jwtService *auth.JWTService) gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "缺少Authorization头",
|
||||
})
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"缺少Authorization头",
|
||||
nil,
|
||||
))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -27,10 +27,11 @@ func AuthMiddleware() gin.HandlerFunc {
|
||||
// Bearer token格式
|
||||
tokenParts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "无效的Authorization头格式",
|
||||
})
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"无效的Authorization头格式",
|
||||
nil,
|
||||
))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -38,10 +39,11 @@ func AuthMiddleware() gin.HandlerFunc {
|
||||
token := tokenParts[1]
|
||||
claims, err := jwtService.ValidateToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "无效的token",
|
||||
})
|
||||
c.JSON(http.StatusUnauthorized, model.NewErrorResponse(
|
||||
model.CodeUnauthorized,
|
||||
"无效的token",
|
||||
err,
|
||||
))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@@ -55,11 +57,9 @@ func AuthMiddleware() gin.HandlerFunc {
|
||||
})
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选的JWT认证中间件
|
||||
func OptionalAuthMiddleware() gin.HandlerFunc {
|
||||
// OptionalAuthMiddleware 可选的JWT认证中间件(注入JWT服务版本)
|
||||
func OptionalAuthMiddleware(jwtService *auth.JWTService) gin.HandlerFunc {
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
jwtService := auth.MustGetJWTService()
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader != "" {
|
||||
tokenParts := strings.SplitN(authHeader, " ", 2)
|
||||
|
||||
@@ -22,3 +22,5 @@ func (Client) TableName() string {
|
||||
return "clients"
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -16,48 +17,48 @@ func NewClientRepository(db *gorm.DB) ClientRepository {
|
||||
return &clientRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *clientRepository) Create(client *model.Client) error {
|
||||
return r.db.Create(client).Error
|
||||
func (r *clientRepository) Create(ctx context.Context, client *model.Client) error {
|
||||
return r.db.WithContext(ctx).Create(client).Error
|
||||
}
|
||||
|
||||
func (r *clientRepository) FindByClientToken(clientToken string) (*model.Client, error) {
|
||||
func (r *clientRepository) FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error) {
|
||||
var client model.Client
|
||||
err := r.db.Where("client_token = ?", clientToken).First(&client).Error
|
||||
err := r.db.WithContext(ctx).Where("client_token = ?", clientToken).First(&client).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
func (r *clientRepository) FindByUUID(uuid string) (*model.Client, error) {
|
||||
func (r *clientRepository) FindByUUID(ctx context.Context, uuid string) (*model.Client, error) {
|
||||
var client model.Client
|
||||
err := r.db.Where("uuid = ?", uuid).First(&client).Error
|
||||
err := r.db.WithContext(ctx).Where("uuid = ?", uuid).First(&client).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
func (r *clientRepository) FindByUserID(userID int64) ([]*model.Client, error) {
|
||||
func (r *clientRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error) {
|
||||
var clients []*model.Client
|
||||
err := r.db.Where("user_id = ?", userID).Find(&clients).Error
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&clients).Error
|
||||
return clients, err
|
||||
}
|
||||
|
||||
func (r *clientRepository) Update(client *model.Client) error {
|
||||
return r.db.Save(client).Error
|
||||
func (r *clientRepository) Update(ctx context.Context, client *model.Client) error {
|
||||
return r.db.WithContext(ctx).Save(client).Error
|
||||
}
|
||||
|
||||
func (r *clientRepository) IncrementVersion(clientUUID string) error {
|
||||
return r.db.Model(&model.Client{}).
|
||||
func (r *clientRepository) IncrementVersion(ctx context.Context, clientUUID string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Client{}).
|
||||
Where("uuid = ?", clientUUID).
|
||||
Update("version", gorm.Expr("version + 1")).Error
|
||||
}
|
||||
|
||||
func (r *clientRepository) DeleteByClientToken(clientToken string) error {
|
||||
return r.db.Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
|
||||
func (r *clientRepository) DeleteByClientToken(ctx context.Context, clientToken string) error {
|
||||
return r.db.WithContext(ctx).Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
|
||||
}
|
||||
|
||||
func (r *clientRepository) DeleteByUserID(userID int64) error {
|
||||
return r.db.Where("user_id = ?", userID).Delete(&model.Client{}).Error
|
||||
func (r *clientRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&model.Client{}).Error
|
||||
}
|
||||
|
||||
@@ -2,95 +2,105 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储接口
|
||||
type UserRepository interface {
|
||||
Create(user *model.User) error
|
||||
FindByID(id int64) (*model.User, error)
|
||||
FindByUsername(username string) (*model.User, error)
|
||||
FindByEmail(email string) (*model.User, error)
|
||||
Update(user *model.User) error
|
||||
UpdateFields(id int64, fields map[string]interface{}) error
|
||||
Delete(id int64) error
|
||||
CreateLoginLog(log *model.UserLoginLog) error
|
||||
CreatePointLog(log *model.UserPointLog) error
|
||||
UpdatePoints(userID int64, amount int, changeType, reason string) error
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
FindByID(ctx context.Context, id int64) (*model.User, error)
|
||||
FindByUsername(ctx context.Context, username string) (*model.User, error)
|
||||
FindByEmail(ctx context.Context, email string) (*model.User, error)
|
||||
FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) // 批量查询
|
||||
Update(ctx context.Context, user *model.User) error
|
||||
UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error
|
||||
BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新
|
||||
Delete(ctx context.Context, id int64) error
|
||||
BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除
|
||||
CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error
|
||||
CreatePointLog(ctx context.Context, log *model.UserPointLog) error
|
||||
UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error
|
||||
}
|
||||
|
||||
// ProfileRepository 档案仓储接口
|
||||
type ProfileRepository interface {
|
||||
Create(profile *model.Profile) error
|
||||
FindByUUID(uuid string) (*model.Profile, error)
|
||||
FindByName(name string) (*model.Profile, error)
|
||||
FindByUserID(userID int64) ([]*model.Profile, error)
|
||||
Update(profile *model.Profile) error
|
||||
UpdateFields(uuid string, updates map[string]interface{}) error
|
||||
Delete(uuid string) error
|
||||
CountByUserID(userID int64) (int64, error)
|
||||
SetActive(uuid string, userID int64) error
|
||||
UpdateLastUsedAt(uuid string) error
|
||||
GetByNames(names []string) ([]*model.Profile, error)
|
||||
GetKeyPair(profileId string) (*model.KeyPair, error)
|
||||
UpdateKeyPair(profileId string, keyPair *model.KeyPair) error
|
||||
Create(ctx context.Context, profile *model.Profile) error
|
||||
FindByUUID(ctx context.Context, uuid string) (*model.Profile, error)
|
||||
FindByName(ctx context.Context, name string) (*model.Profile, error)
|
||||
FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error)
|
||||
FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) // 批量查询
|
||||
Update(ctx context.Context, profile *model.Profile) error
|
||||
UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error
|
||||
BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) // 批量更新
|
||||
Delete(ctx context.Context, uuid string) error
|
||||
BatchDelete(ctx context.Context, uuids []string) (int64, error) // 批量删除
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
SetActive(ctx context.Context, uuid string, userID int64) error
|
||||
UpdateLastUsedAt(ctx context.Context, uuid string) error
|
||||
GetByNames(ctx context.Context, names []string) ([]*model.Profile, error)
|
||||
GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error)
|
||||
UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error
|
||||
}
|
||||
|
||||
// TextureRepository 材质仓储接口
|
||||
type TextureRepository interface {
|
||||
Create(texture *model.Texture) error
|
||||
FindByID(id int64) (*model.Texture, error)
|
||||
FindByHash(hash string) (*model.Texture, error)
|
||||
FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Update(texture *model.Texture) error
|
||||
UpdateFields(id int64, fields map[string]interface{}) error
|
||||
Delete(id int64) error
|
||||
IncrementDownloadCount(id int64) error
|
||||
IncrementFavoriteCount(id int64) error
|
||||
DecrementFavoriteCount(id int64) error
|
||||
CreateDownloadLog(log *model.TextureDownloadLog) error
|
||||
IsFavorited(userID, textureID int64) (bool, error)
|
||||
AddFavorite(userID, textureID int64) error
|
||||
RemoveFavorite(userID, textureID int64) error
|
||||
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
CountByUploaderID(uploaderID int64) (int64, error)
|
||||
Create(ctx context.Context, texture *model.Texture) error
|
||||
FindByID(ctx context.Context, id int64) (*model.Texture, error)
|
||||
FindByHash(ctx context.Context, hash string) (*model.Texture, error)
|
||||
FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) // 批量查询
|
||||
FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Update(ctx context.Context, texture *model.Texture) error
|
||||
UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error
|
||||
BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新
|
||||
Delete(ctx context.Context, id int64) error
|
||||
BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除
|
||||
IncrementDownloadCount(ctx context.Context, id int64) error
|
||||
IncrementFavoriteCount(ctx context.Context, id int64) error
|
||||
DecrementFavoriteCount(ctx context.Context, id int64) error
|
||||
CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error
|
||||
IsFavorited(ctx context.Context, userID, textureID int64) (bool, error)
|
||||
AddFavorite(ctx context.Context, userID, textureID int64) error
|
||||
RemoveFavorite(ctx context.Context, userID, textureID int64) error
|
||||
GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error)
|
||||
}
|
||||
|
||||
// TokenRepository 令牌仓储接口
|
||||
type TokenRepository interface {
|
||||
Create(token *model.Token) error
|
||||
FindByAccessToken(accessToken string) (*model.Token, error)
|
||||
GetByUserID(userId int64) ([]*model.Token, error)
|
||||
GetUUIDByAccessToken(accessToken string) (string, error)
|
||||
GetUserIDByAccessToken(accessToken string) (int64, error)
|
||||
DeleteByAccessToken(accessToken string) error
|
||||
DeleteByUserID(userId int64) error
|
||||
BatchDelete(accessTokens []string) (int64, error)
|
||||
Create(ctx context.Context, token *model.Token) error
|
||||
FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error)
|
||||
GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error)
|
||||
GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error)
|
||||
GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error)
|
||||
DeleteByAccessToken(ctx context.Context, accessToken string) error
|
||||
DeleteByUserID(ctx context.Context, userId int64) error
|
||||
BatchDelete(ctx context.Context, accessTokens []string) (int64, error)
|
||||
}
|
||||
|
||||
// SystemConfigRepository 系统配置仓储接口
|
||||
type SystemConfigRepository interface {
|
||||
GetByKey(key string) (*model.SystemConfig, error)
|
||||
GetPublic() ([]model.SystemConfig, error)
|
||||
GetAll() ([]model.SystemConfig, error)
|
||||
Update(config *model.SystemConfig) error
|
||||
UpdateValue(key, value string) error
|
||||
GetByKey(ctx context.Context, key string) (*model.SystemConfig, error)
|
||||
GetPublic(ctx context.Context) ([]model.SystemConfig, error)
|
||||
GetAll(ctx context.Context) ([]model.SystemConfig, error)
|
||||
Update(ctx context.Context, config *model.SystemConfig) error
|
||||
UpdateValue(ctx context.Context, key, value string) error
|
||||
}
|
||||
|
||||
// YggdrasilRepository Yggdrasil仓储接口
|
||||
type YggdrasilRepository interface {
|
||||
GetPasswordByID(id int64) (string, error)
|
||||
ResetPassword(id int64, password string) error
|
||||
GetPasswordByID(ctx context.Context, id int64) (string, error)
|
||||
ResetPassword(ctx context.Context, id int64, password string) error
|
||||
}
|
||||
|
||||
// ClientRepository Client仓储接口
|
||||
type ClientRepository interface {
|
||||
Create(client *model.Client) error
|
||||
FindByClientToken(clientToken string) (*model.Client, error)
|
||||
FindByUUID(uuid string) (*model.Client, error)
|
||||
FindByUserID(userID int64) ([]*model.Client, error)
|
||||
Update(client *model.Client) error
|
||||
IncrementVersion(clientUUID string) error
|
||||
DeleteByClientToken(clientToken string) error
|
||||
DeleteByUserID(userID int64) error
|
||||
Create(ctx context.Context, client *model.Client) error
|
||||
FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error)
|
||||
FindByUUID(ctx context.Context, uuid string) (*model.Client, error)
|
||||
FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error)
|
||||
Update(ctx context.Context, client *model.Client) error
|
||||
IncrementVersion(ctx context.Context, clientUUID string) error
|
||||
DeleteByClientToken(ctx context.Context, clientToken string) error
|
||||
DeleteByUserID(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
@@ -19,13 +19,13 @@ func NewProfileRepository(db *gorm.DB) ProfileRepository {
|
||||
return &profileRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *profileRepository) Create(profile *model.Profile) error {
|
||||
return r.db.Create(profile).Error
|
||||
func (r *profileRepository) Create(ctx context.Context, profile *model.Profile) error {
|
||||
return r.db.WithContext(ctx).Create(profile).Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) {
|
||||
func (r *profileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) {
|
||||
var profile model.Profile
|
||||
err := r.db.Where("uuid = ?", uuid).
|
||||
err := r.db.WithContext(ctx).Where("uuid = ?", uuid).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
@@ -35,10 +35,10 @@ func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) {
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (r *profileRepository) FindByName(name string) (*model.Profile, error) {
|
||||
func (r *profileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) {
|
||||
var profile model.Profile
|
||||
// 使用 LOWER 函数进行不区分大小写的查询,并预加载 Skin 和 Cape
|
||||
err := r.db.Where("LOWER(name) = LOWER(?)", name).
|
||||
err := r.db.WithContext(ctx).Where("LOWER(name) = LOWER(?)", name).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
@@ -48,9 +48,9 @@ func (r *profileRepository) FindByName(name string) (*model.Profile, error) {
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error) {
|
||||
func (r *profileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) {
|
||||
var profiles []*model.Profile
|
||||
err := r.db.Where("user_id = ?", userID).
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Order("created_at DESC").
|
||||
@@ -58,30 +58,59 @@ func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error)
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepository) Update(profile *model.Profile) error {
|
||||
return r.db.Save(profile).Error
|
||||
func (r *profileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) {
|
||||
if len(uuids) == 0 {
|
||||
return []*model.Profile{}, nil
|
||||
}
|
||||
var profiles []*model.Profile
|
||||
// 使用 IN 查询优化批量查询,并预加载关联
|
||||
err := r.db.WithContext(ctx).Where("uuid IN ?", uuids).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepository) UpdateFields(uuid string, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Profile{}).
|
||||
func (r *profileRepository) Update(ctx context.Context, profile *model.Profile) error {
|
||||
return r.db.WithContext(ctx).Save(profile).Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) Delete(uuid string) error {
|
||||
return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
func (r *profileRepository) Delete(ctx context.Context, uuid string) error {
|
||||
return r.db.WithContext(ctx).Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) CountByUserID(userID int64) (int64, error) {
|
||||
func (r *profileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) {
|
||||
if len(uuids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Model(&model.Profile{}).Where("uuid IN ?", uuids).Updates(updates)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) {
|
||||
if len(uuids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Where("uuid IN ?", uuids).Delete(&model.Profile{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Profile{}).
|
||||
err := r.db.WithContext(ctx).Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *profileRepository) SetActive(uuid string, userID int64) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
func (r *profileRepository) SetActive(ctx context.Context, uuid string, userID int64) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
@@ -94,28 +123,28 @@ func (r *profileRepository) SetActive(uuid string, userID int64) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (r *profileRepository) UpdateLastUsedAt(uuid string) error {
|
||||
return r.db.Model(&model.Profile{}).
|
||||
func (r *profileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
|
||||
}
|
||||
|
||||
func (r *profileRepository) GetByNames(names []string) ([]*model.Profile, error) {
|
||||
func (r *profileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
|
||||
var profiles []*model.Profile
|
||||
err := r.db.Where("name in (?)", names).
|
||||
err := r.db.WithContext(ctx).Where("name in (?)", names).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
func (r *profileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) {
|
||||
if profileId == "" {
|
||||
return nil, errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
var profile model.Profile
|
||||
result := r.db.WithContext(context.Background()).
|
||||
result := r.db.WithContext(ctx).
|
||||
Select("key_pair").
|
||||
Where("id = ?", profileId).
|
||||
First(&profile)
|
||||
@@ -130,7 +159,7 @@ func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error)
|
||||
return &model.KeyPair{}, nil
|
||||
}
|
||||
|
||||
func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
func (r *profileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error {
|
||||
if profileId == "" {
|
||||
return errors.New("profileId 不能为空")
|
||||
}
|
||||
@@ -138,9 +167,8 @@ func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPa
|
||||
return errors.New("keyPair 不能为 nil")
|
||||
}
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.WithContext(context.Background()).
|
||||
Table("profiles").
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Table("profiles").
|
||||
Where("id = ?", profileId).
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"private_key": keyPair.PrivateKey,
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -16,28 +17,28 @@ func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository {
|
||||
return &systemConfigRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *systemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) {
|
||||
func (r *systemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
|
||||
var config model.SystemConfig
|
||||
err := r.db.Where("key = ?", key).First(&config).Error
|
||||
err := r.db.WithContext(ctx).Where("key = ?", key).First(&config).Error
|
||||
return handleNotFoundResult(&config, err)
|
||||
}
|
||||
|
||||
func (r *systemConfigRepository) GetPublic() ([]model.SystemConfig, error) {
|
||||
func (r *systemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var configs []model.SystemConfig
|
||||
err := r.db.Where("is_public = ?", true).Find(&configs).Error
|
||||
err := r.db.WithContext(ctx).Where("is_public = ?", true).Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
func (r *systemConfigRepository) GetAll() ([]model.SystemConfig, error) {
|
||||
func (r *systemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var configs []model.SystemConfig
|
||||
err := r.db.Find(&configs).Error
|
||||
err := r.db.WithContext(ctx).Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
func (r *systemConfigRepository) Update(config *model.SystemConfig) error {
|
||||
return r.db.Save(config).Error
|
||||
func (r *systemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error {
|
||||
return r.db.WithContext(ctx).Save(config).Error
|
||||
}
|
||||
|
||||
func (r *systemConfigRepository) UpdateValue(key, value string) error {
|
||||
return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
|
||||
func (r *systemConfigRepository) UpdateValue(ctx context.Context, key, value string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -16,27 +17,39 @@ func NewTextureRepository(db *gorm.DB) TextureRepository {
|
||||
return &textureRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *textureRepository) Create(texture *model.Texture) error {
|
||||
return r.db.Create(texture).Error
|
||||
func (r *textureRepository) Create(ctx context.Context, texture *model.Texture) error {
|
||||
return r.db.WithContext(ctx).Create(texture).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) FindByID(id int64) (*model.Texture, error) {
|
||||
func (r *textureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) {
|
||||
var texture model.Texture
|
||||
err := r.db.Preload("Uploader").First(&texture, id).Error
|
||||
err := r.db.WithContext(ctx).Preload("Uploader").First(&texture, id).Error
|
||||
return handleNotFoundResult(&texture, err)
|
||||
}
|
||||
|
||||
func (r *textureRepository) FindByHash(hash string) (*model.Texture, error) {
|
||||
func (r *textureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) {
|
||||
var texture model.Texture
|
||||
err := r.db.Where("hash = ?", hash).First(&texture).Error
|
||||
err := r.db.WithContext(ctx).Where("hash = ?", hash).First(&texture).Error
|
||||
return handleNotFoundResult(&texture, err)
|
||||
}
|
||||
|
||||
func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
func (r *textureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*model.Texture{}, nil
|
||||
}
|
||||
var textures []*model.Texture
|
||||
// 使用 IN 查询优化批量查询,并预加载关联
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).
|
||||
Preload("Uploader").
|
||||
Find(&textures).Error
|
||||
return textures, err
|
||||
}
|
||||
|
||||
func (r *textureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
|
||||
query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
@@ -54,11 +67,11 @@ func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize in
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
func (r *textureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Texture{}).Where("status = 1")
|
||||
query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("status = 1")
|
||||
|
||||
if publicOnly {
|
||||
query = query.Where("is_public = ?", true)
|
||||
@@ -86,67 +99,86 @@ func (r *textureRepository) Search(keyword string, textureType model.TextureType
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepository) Update(texture *model.Texture) error {
|
||||
return r.db.Save(texture).Error
|
||||
func (r *textureRepository) Update(ctx context.Context, texture *model.Texture) error {
|
||||
return r.db.WithContext(ctx).Save(texture).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
|
||||
func (r *textureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) Delete(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
|
||||
func (r *textureRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) IncrementDownloadCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
func (r *textureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Updates(fields)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Update("status", -1)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) IncrementDownloadCount(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) IncrementFavoriteCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
func (r *textureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) DecrementFavoriteCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
func (r *textureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error {
|
||||
return r.db.Create(log).Error
|
||||
func (r *textureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) IsFavorited(userID, textureID int64) (bool, error) {
|
||||
func (r *textureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserTextureFavorite{}).
|
||||
// 使用 Select("1") 优化,只查询是否存在,不需要查询所有字段
|
||||
err := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}).
|
||||
Select("1").
|
||||
Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Limit(1).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *textureRepository) AddFavorite(userID, textureID int64) error {
|
||||
func (r *textureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error {
|
||||
favorite := &model.UserTextureFavorite{
|
||||
UserID: userID,
|
||||
TextureID: textureID,
|
||||
}
|
||||
return r.db.Create(favorite).Error
|
||||
return r.db.WithContext(ctx).Create(favorite).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) RemoveFavorite(userID, textureID int64) error {
|
||||
return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
func (r *textureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Delete(&model.UserTextureFavorite{}).Error
|
||||
}
|
||||
|
||||
func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
func (r *textureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.UserTextureFavorite{}).
|
||||
subQuery := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}).
|
||||
Select("texture_id").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
query := r.db.Model(&model.Texture{}).
|
||||
query := r.db.WithContext(ctx).Model(&model.Texture{}).
|
||||
Where("id IN (?) AND status = 1", subQuery)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
@@ -165,9 +197,9 @@ func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) (
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepository) CountByUploaderID(uploaderID int64) (int64, error) {
|
||||
func (r *textureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Texture{}).
|
||||
err := r.db.WithContext(ctx).Model(&model.Texture{}).
|
||||
Where("uploader_id = ? AND status != -1", uploaderID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -16,55 +17,55 @@ func NewTokenRepository(db *gorm.DB) TokenRepository {
|
||||
return &tokenRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *tokenRepository) Create(token *model.Token) error {
|
||||
return r.db.Create(token).Error
|
||||
func (r *tokenRepository) Create(ctx context.Context, token *model.Token) error {
|
||||
return r.db.WithContext(ctx).Create(token).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) {
|
||||
func (r *tokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
err := r.db.WithContext(ctx).Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepository) GetByUserID(userId int64) ([]*model.Token, error) {
|
||||
func (r *tokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
|
||||
var tokens []*model.Token
|
||||
err := r.db.Where("user_id = ?", userId).Find(&tokens).Error
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userId).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func (r *tokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
func (r *tokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
var token model.Token
|
||||
err := r.db.Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
|
||||
err := r.db.WithContext(ctx).Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
func (r *tokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
var token model.Token
|
||||
err := r.db.Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
|
||||
err := r.db.WithContext(ctx).Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return token.UserID, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepository) DeleteByAccessToken(accessToken string) error {
|
||||
return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
func (r *tokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
|
||||
return r.db.WithContext(ctx).Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepository) DeleteByUserID(userId int64) error {
|
||||
return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
func (r *tokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepository) BatchDelete(accessTokens []string) (int64, error) {
|
||||
func (r *tokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
|
||||
if len(accessTokens) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{})
|
||||
result := r.db.WithContext(ctx).Where("access_token IN ?", accessTokens).Delete(&model.Token{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -17,50 +18,76 @@ func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepository) Create(user *model.User) error {
|
||||
return r.db.Create(user).Error
|
||||
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByID(id int64) (*model.User, error) {
|
||||
func (r *userRepository) FindByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("id = ? AND status != -1", id).First(&user).Error
|
||||
err := r.db.WithContext(ctx).Where("id = ? AND status != -1", id).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByUsername(username string) (*model.User, error) {
|
||||
func (r *userRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("username = ? AND status != -1", username).First(&user).Error
|
||||
err := r.db.WithContext(ctx).Where("username = ? AND status != -1", username).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepository) FindByEmail(email string) (*model.User, error) {
|
||||
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("email = ? AND status != -1", email).First(&user).Error
|
||||
err := r.db.WithContext(ctx).Where("email = ? AND status != -1", email).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepository) Update(user *model.User) error {
|
||||
return r.db.Save(user).Error
|
||||
func (r *userRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*model.User{}, nil
|
||||
}
|
||||
var users []*model.User
|
||||
// 使用 IN 查询优化批量查询
|
||||
err := r.db.WithContext(ctx).Where("id IN ? AND status != -1", ids).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
|
||||
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) Delete(id int64) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
|
||||
func (r *userRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) CreateLoginLog(log *model.UserLoginLog) error {
|
||||
return r.db.Create(log).Error
|
||||
func (r *userRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) CreatePointLog(log *model.UserPointLog) error {
|
||||
return r.db.Create(log).Error
|
||||
func (r *userRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Updates(fields)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
func (r *userRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Update("status", -1)
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func (r *userRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var user model.User
|
||||
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -16,15 +17,15 @@ func NewYggdrasilRepository(db *gorm.DB) YggdrasilRepository {
|
||||
return &yggdrasilRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *yggdrasilRepository) GetPasswordByID(id int64) (string, error) {
|
||||
func (r *yggdrasilRepository) GetPasswordByID(ctx context.Context, id int64) (string, error) {
|
||||
var yggdrasil model.Yggdrasil
|
||||
err := r.db.Select("password").Where("id = ?", id).First(&yggdrasil).Error
|
||||
err := r.db.WithContext(ctx).Select("password").Where("id = ?", id).First(&yggdrasil).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return yggdrasil.Password, nil
|
||||
}
|
||||
|
||||
func (r *yggdrasilRepository) ResetPassword(id int64, password string) error {
|
||||
return r.db.Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error
|
||||
func (r *yggdrasilRepository) ResetPassword(ctx context.Context, id int64, password string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
@@ -28,7 +29,7 @@ func NewMockUserRepository() *MockUserRepository {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Create(user *model.User) error {
|
||||
func (m *MockUserRepository) Create(ctx context.Context, user *model.User) error {
|
||||
if m.FailCreate {
|
||||
return errors.New("mock create error")
|
||||
}
|
||||
@@ -39,7 +40,7 @@ func (m *MockUserRepository) Create(user *model.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByID(id int64) (*model.User, error) {
|
||||
func (m *MockUserRepository) FindByID(ctx context.Context, id int64) (*model.User, error) {
|
||||
if m.FailFindByID {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
@@ -49,7 +50,7 @@ func (m *MockUserRepository) FindByID(id int64) (*model.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByUsername(username string) (*model.User, error) {
|
||||
func (m *MockUserRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
|
||||
if m.FailFindByUsername {
|
||||
return nil, errors.New("mock find by username error")
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func (m *MockUserRepository) FindByUsername(username string) (*model.User, error
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) {
|
||||
func (m *MockUserRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
|
||||
if m.FailFindByEmail {
|
||||
return nil, errors.New("mock find by email error")
|
||||
}
|
||||
@@ -73,7 +74,7 @@ func (m *MockUserRepository) FindByEmail(email string) (*model.User, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Update(user *model.User) error {
|
||||
func (m *MockUserRepository) Update(ctx context.Context, user *model.User) error {
|
||||
if m.FailUpdate {
|
||||
return errors.New("mock update error")
|
||||
}
|
||||
@@ -81,7 +82,7 @@ func (m *MockUserRepository) Update(user *model.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
func (m *MockUserRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
|
||||
if m.FailUpdate {
|
||||
return errors.New("mock update fields error")
|
||||
}
|
||||
@@ -92,23 +93,43 @@ func (m *MockUserRepository) UpdateFields(id int64, fields map[string]interface{
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) Delete(id int64) error {
|
||||
func (m *MockUserRepository) Delete(ctx context.Context, id int64) error {
|
||||
delete(m.users, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) CreateLoginLog(log *model.UserLoginLog) error {
|
||||
func (m *MockUserRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) CreatePointLog(log *model.UserPointLog) error {
|
||||
func (m *MockUserRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error {
|
||||
func (m *MockUserRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchUpdate 和 BatchDelete 仅用于满足接口,在测试中不做具体操作
|
||||
func (m *MockUserRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockUserRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// FindByIDs 批量查询用户
|
||||
func (m *MockUserRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) {
|
||||
var result []*model.User
|
||||
for _, id := range ids {
|
||||
if u, ok := m.users[id]; ok {
|
||||
result = append(result, u)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MockProfileRepository 模拟ProfileRepository
|
||||
type MockProfileRepository struct {
|
||||
profiles map[string]*model.Profile
|
||||
@@ -128,7 +149,7 @@ func NewMockProfileRepository() *MockProfileRepository {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) Create(profile *model.Profile) error {
|
||||
func (m *MockProfileRepository) Create(ctx context.Context, profile *model.Profile) error {
|
||||
if m.FailCreate {
|
||||
return errors.New("mock create error")
|
||||
}
|
||||
@@ -137,7 +158,7 @@ func (m *MockProfileRepository) Create(profile *model.Profile) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error) {
|
||||
func (m *MockProfileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
@@ -147,7 +168,7 @@ func (m *MockProfileRepository) FindByUUID(uuid string) (*model.Profile, error)
|
||||
return nil, errors.New("profile not found")
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error) {
|
||||
func (m *MockProfileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
@@ -159,14 +180,14 @@ func (m *MockProfileRepository) FindByName(name string) (*model.Profile, error)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) FindByUserID(userID int64) ([]*model.Profile, error) {
|
||||
func (m *MockProfileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
return m.userProfiles[userID], nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) Update(profile *model.Profile) error {
|
||||
func (m *MockProfileRepository) Update(ctx context.Context, profile *model.Profile) error {
|
||||
if m.FailUpdate {
|
||||
return errors.New("mock update error")
|
||||
}
|
||||
@@ -174,14 +195,14 @@ func (m *MockProfileRepository) Update(profile *model.Profile) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) UpdateFields(uuid string, updates map[string]interface{}) error {
|
||||
func (m *MockProfileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error {
|
||||
if m.FailUpdate {
|
||||
return errors.New("mock update error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) Delete(uuid string) error {
|
||||
func (m *MockProfileRepository) Delete(ctx context.Context, uuid string) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
@@ -189,19 +210,19 @@ func (m *MockProfileRepository) Delete(uuid string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) CountByUserID(userID int64) (int64, error) {
|
||||
func (m *MockProfileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
return int64(len(m.userProfiles[userID])), nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) SetActive(uuid string, userID int64) error {
|
||||
func (m *MockProfileRepository) SetActive(ctx context.Context, uuid string, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) UpdateLastUsedAt(uuid string) error {
|
||||
func (m *MockProfileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, error) {
|
||||
func (m *MockProfileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
|
||||
var result []*model.Profile
|
||||
for _, name := range names {
|
||||
for _, profile := range m.profiles {
|
||||
@@ -213,14 +234,34 @@ func (m *MockProfileRepository) GetByNames(names []string) ([]*model.Profile, er
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
func (m *MockProfileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
func (m *MockProfileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchUpdate / BatchDelete 仅用于满足接口
|
||||
func (m *MockProfileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockProfileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// FindByUUIDs 批量查询 Profile
|
||||
func (m *MockProfileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) {
|
||||
var result []*model.Profile
|
||||
for _, id := range uuids {
|
||||
if p, ok := m.profiles[id]; ok {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MockTextureRepository 模拟TextureRepository
|
||||
type MockTextureRepository struct {
|
||||
textures map[int64]*model.Texture
|
||||
@@ -240,7 +281,7 @@ func NewMockTextureRepository() *MockTextureRepository {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) Create(texture *model.Texture) error {
|
||||
func (m *MockTextureRepository) Create(ctx context.Context, texture *model.Texture) error {
|
||||
if m.FailCreate {
|
||||
return errors.New("mock create error")
|
||||
}
|
||||
@@ -252,7 +293,7 @@ func (m *MockTextureRepository) Create(texture *model.Texture) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) {
|
||||
func (m *MockTextureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
@@ -262,7 +303,7 @@ func (m *MockTextureRepository) FindByID(id int64) (*model.Texture, error) {
|
||||
return nil, errors.New("texture not found")
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error) {
|
||||
func (m *MockTextureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
@@ -274,7 +315,7 @@ func (m *MockTextureRepository) FindByHash(hash string) (*model.Texture, error)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
func (m *MockTextureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if m.FailFind {
|
||||
return nil, 0, errors.New("mock find error")
|
||||
}
|
||||
@@ -287,7 +328,7 @@ func (m *MockTextureRepository) FindByUploaderID(uploaderID int64, page, pageSiz
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
func (m *MockTextureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
if m.FailFind {
|
||||
return nil, 0, errors.New("mock find error")
|
||||
}
|
||||
@@ -301,7 +342,7 @@ func (m *MockTextureRepository) Search(keyword string, textureType model.Texture
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) Update(texture *model.Texture) error {
|
||||
func (m *MockTextureRepository) Update(ctx context.Context, texture *model.Texture) error {
|
||||
if m.FailUpdate {
|
||||
return errors.New("mock update error")
|
||||
}
|
||||
@@ -309,14 +350,14 @@ func (m *MockTextureRepository) Update(texture *model.Texture) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
func (m *MockTextureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
|
||||
if m.FailUpdate {
|
||||
return errors.New("mock update error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) Delete(id int64) error {
|
||||
func (m *MockTextureRepository) Delete(ctx context.Context, id int64) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
@@ -324,39 +365,39 @@ func (m *MockTextureRepository) Delete(id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) IncrementDownloadCount(id int64) error {
|
||||
func (m *MockTextureRepository) IncrementDownloadCount(ctx context.Context, id int64) error {
|
||||
if texture, ok := m.textures[id]; ok {
|
||||
texture.DownloadCount++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) IncrementFavoriteCount(id int64) error {
|
||||
func (m *MockTextureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error {
|
||||
if texture, ok := m.textures[id]; ok {
|
||||
texture.FavoriteCount++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) DecrementFavoriteCount(id int64) error {
|
||||
func (m *MockTextureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error {
|
||||
if texture, ok := m.textures[id]; ok && texture.FavoriteCount > 0 {
|
||||
texture.FavoriteCount--
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error {
|
||||
func (m *MockTextureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) IsFavorited(userID, textureID int64) (bool, error) {
|
||||
func (m *MockTextureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) {
|
||||
if userFavs, ok := m.favorites[userID]; ok {
|
||||
return userFavs[textureID], nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error {
|
||||
func (m *MockTextureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error {
|
||||
if m.favorites[userID] == nil {
|
||||
m.favorites[userID] = make(map[int64]bool)
|
||||
}
|
||||
@@ -364,14 +405,14 @@ func (m *MockTextureRepository) AddFavorite(userID, textureID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) RemoveFavorite(userID, textureID int64) error {
|
||||
func (m *MockTextureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error {
|
||||
if userFavs, ok := m.favorites[userID]; ok {
|
||||
delete(userFavs, textureID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
func (m *MockTextureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var result []*model.Texture
|
||||
if userFavs, ok := m.favorites[userID]; ok {
|
||||
for textureID := range userFavs {
|
||||
@@ -383,7 +424,7 @@ func (m *MockTextureRepository) GetUserFavorites(userID int64, page, pageSize in
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, error) {
|
||||
func (m *MockTextureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) {
|
||||
var count int64
|
||||
for _, texture := range m.textures {
|
||||
if texture.UploaderID == uploaderID {
|
||||
@@ -393,6 +434,34 @@ func (m *MockTextureRepository) CountByUploaderID(uploaderID int64) (int64, erro
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// FindByIDs 批量查询 Texture
|
||||
func (m *MockTextureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) {
|
||||
var result []*model.Texture
|
||||
for _, id := range ids {
|
||||
if tex, ok := m.textures[id]; ok {
|
||||
result = append(result, tex)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BatchUpdate 仅用于满足接口
|
||||
func (m *MockTextureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// BatchDelete 仅用于满足接口
|
||||
func (m *MockTextureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
|
||||
var deleted int64
|
||||
for _, id := range ids {
|
||||
if _, ok := m.textures[id]; ok {
|
||||
delete(m.textures, id)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// MockTokenRepository 模拟TokenRepository
|
||||
type MockTokenRepository struct {
|
||||
tokens map[string]*model.Token
|
||||
@@ -409,7 +478,7 @@ func NewMockTokenRepository() *MockTokenRepository {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) Create(token *model.Token) error {
|
||||
func (m *MockTokenRepository) Create(ctx context.Context, token *model.Token) error {
|
||||
if m.FailCreate {
|
||||
return errors.New("mock create error")
|
||||
}
|
||||
@@ -418,7 +487,7 @@ func (m *MockTokenRepository) Create(token *model.Token) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) {
|
||||
func (m *MockTokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
@@ -428,14 +497,14 @@ func (m *MockTokenRepository) FindByAccessToken(accessToken string) (*model.Toke
|
||||
return nil, errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetByUserID(userId int64) ([]*model.Token, error) {
|
||||
func (m *MockTokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
|
||||
if m.FailFind {
|
||||
return nil, errors.New("mock find error")
|
||||
}
|
||||
return m.userTokens[userId], nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
func (m *MockTokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
if m.FailFind {
|
||||
return "", errors.New("mock find error")
|
||||
}
|
||||
@@ -445,7 +514,7 @@ func (m *MockTokenRepository) GetUUIDByAccessToken(accessToken string) (string,
|
||||
return "", errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
func (m *MockTokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
if m.FailFind {
|
||||
return 0, errors.New("mock find error")
|
||||
}
|
||||
@@ -455,7 +524,7 @@ func (m *MockTokenRepository) GetUserIDByAccessToken(accessToken string) (int64,
|
||||
return 0, errors.New("token not found")
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error {
|
||||
func (m *MockTokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
@@ -463,7 +532,7 @@ func (m *MockTokenRepository) DeleteByAccessToken(accessToken string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) DeleteByUserID(userId int64) error {
|
||||
func (m *MockTokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
|
||||
if m.FailDelete {
|
||||
return errors.New("mock delete error")
|
||||
}
|
||||
@@ -474,7 +543,7 @@ func (m *MockTokenRepository) DeleteByUserID(userId int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockTokenRepository) BatchDelete(accessTokens []string) (int64, error) {
|
||||
func (m *MockTokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
|
||||
if m.FailDelete {
|
||||
return 0, errors.New("mock delete error")
|
||||
}
|
||||
@@ -499,14 +568,14 @@ func NewMockSystemConfigRepository() *MockSystemConfigRepository {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) {
|
||||
func (m *MockSystemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
|
||||
if config, ok := m.configs[key]; ok {
|
||||
return config, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) {
|
||||
func (m *MockSystemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var result []model.SystemConfig
|
||||
for _, v := range m.configs {
|
||||
result = append(result, *v)
|
||||
@@ -514,7 +583,7 @@ func (m *MockSystemConfigRepository) GetPublic() ([]model.SystemConfig, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) {
|
||||
func (m *MockSystemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) {
|
||||
var result []model.SystemConfig
|
||||
for _, v := range m.configs {
|
||||
result = append(result, *v)
|
||||
@@ -522,12 +591,12 @@ func (m *MockSystemConfigRepository) GetAll() ([]model.SystemConfig, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) Update(config *model.SystemConfig) error {
|
||||
func (m *MockSystemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error {
|
||||
m.configs[config.Key] = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockSystemConfigRepository) UpdateValue(key, value string) error {
|
||||
func (m *MockSystemConfigRepository) UpdateValue(ctx context.Context, key, value string) error {
|
||||
if config, ok := m.configs[key]; ok {
|
||||
config.Value = value
|
||||
return nil
|
||||
|
||||
@@ -47,7 +47,7 @@ func NewProfileService(
|
||||
|
||||
func (s *profileService) Create(ctx context.Context, userID int64, name string) (*model.Profile, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
user, err := s.userRepo.FindByID(ctx, userID)
|
||||
if err != nil || user == nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (s *profileService) Create(ctx context.Context, userID int64, name string)
|
||||
}
|
||||
|
||||
// 检查角色名是否已存在
|
||||
existingName, err := s.profileRepo.FindByName(name)
|
||||
existingName, err := s.profileRepo.FindByName(ctx, name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
@@ -80,12 +80,12 @@ func (s *profileService) Create(ctx context.Context, userID int64, name string)
|
||||
IsActive: true,
|
||||
}
|
||||
|
||||
if err := s.profileRepo.Create(profile); err != nil {
|
||||
if err := s.profileRepo.Create(ctx, profile); err != nil {
|
||||
return nil, fmt.Errorf("创建档案失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置活跃状态
|
||||
if err := s.profileRepo.SetActive(profileUUID, userID); err != nil {
|
||||
if err := s.profileRepo.SetActive(ctx, profileUUID, userID); err != nil {
|
||||
return nil, fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func (s *profileService) GetByUUID(ctx context.Context, uuid string) (*model.Pro
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
profile2, err := s.profileRepo.FindByUUID(uuid)
|
||||
profile2, err := s.profileRepo.FindByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrProfileNotFound
|
||||
@@ -131,7 +131,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询档案列表失败: %w", err)
|
||||
}
|
||||
@@ -148,7 +148,7 @@ func (s *profileService) GetByUserID(ctx context.Context, userID int64) ([]*mode
|
||||
|
||||
func (s *profileService) Update(ctx context.Context, uuid string, userID int64, name *string, skinID, capeID *int64) (*model.Profile, error) {
|
||||
// 获取档案并验证权限
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrProfileNotFound
|
||||
@@ -162,7 +162,7 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64,
|
||||
|
||||
// 检查角色名是否重复
|
||||
if name != nil && *name != profile.Name {
|
||||
existingName, err := s.profileRepo.FindByName(*name)
|
||||
existingName, err := s.profileRepo.FindByName(ctx, *name)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("查询角色名失败: %w", err)
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64,
|
||||
profile.CapeID = capeID
|
||||
}
|
||||
|
||||
if err := s.profileRepo.Update(profile); err != nil {
|
||||
if err := s.profileRepo.Update(ctx, profile); err != nil {
|
||||
return nil, fmt.Errorf("更新档案失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -190,12 +190,12 @@ func (s *profileService) Update(ctx context.Context, uuid string, userID int64,
|
||||
s.cacheKeys.ProfileList(userID),
|
||||
)
|
||||
|
||||
return s.profileRepo.FindByUUID(uuid)
|
||||
return s.profileRepo.FindByUUID(ctx, uuid)
|
||||
}
|
||||
|
||||
func (s *profileService) Delete(ctx context.Context, uuid string, userID int64) error {
|
||||
// 获取档案并验证权限
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrProfileNotFound
|
||||
@@ -207,7 +207,7 @@ func (s *profileService) Delete(ctx context.Context, uuid string, userID int64)
|
||||
return ErrProfileNoPermission
|
||||
}
|
||||
|
||||
if err := s.profileRepo.Delete(uuid); err != nil {
|
||||
if err := s.profileRepo.Delete(ctx, uuid); err != nil {
|
||||
return fmt.Errorf("删除档案失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -222,7 +222,7 @@ func (s *profileService) Delete(ctx context.Context, uuid string, userID int64)
|
||||
|
||||
func (s *profileService) SetActive(ctx context.Context, uuid string, userID int64) error {
|
||||
// 获取档案并验证权限
|
||||
profile, err := s.profileRepo.FindByUUID(uuid)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return ErrProfileNotFound
|
||||
@@ -234,11 +234,11 @@ func (s *profileService) SetActive(ctx context.Context, uuid string, userID int6
|
||||
return ErrProfileNoPermission
|
||||
}
|
||||
|
||||
if err := s.profileRepo.SetActive(uuid, userID); err != nil {
|
||||
if err := s.profileRepo.SetActive(ctx, uuid, userID); err != nil {
|
||||
return fmt.Errorf("设置活跃状态失败: %w", err)
|
||||
}
|
||||
|
||||
if err := s.profileRepo.UpdateLastUsedAt(uuid); err != nil {
|
||||
if err := s.profileRepo.UpdateLastUsedAt(ctx, uuid); err != nil {
|
||||
return fmt.Errorf("更新使用时间失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -249,7 +249,7 @@ func (s *profileService) SetActive(ctx context.Context, uuid string, userID int6
|
||||
}
|
||||
|
||||
func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfiles int) error {
|
||||
count, err := s.profileRepo.CountByUserID(userID)
|
||||
count, err := s.profileRepo.CountByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询档案数量失败: %w", err)
|
||||
}
|
||||
@@ -261,7 +261,7 @@ func (s *profileService) CheckLimit(ctx context.Context, userID int64, maxProfil
|
||||
}
|
||||
|
||||
func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
|
||||
profiles, err := s.profileRepo.GetByNames(names)
|
||||
profiles, err := s.profileRepo.GetByNames(ctx, names)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查找失败: %w", err)
|
||||
}
|
||||
@@ -270,7 +270,7 @@ func (s *profileService) GetByNames(ctx context.Context, names []string) ([]*mod
|
||||
|
||||
func (s *profileService) GetByProfileName(ctx context.Context, name string) (*model.Profile, error) {
|
||||
// Profile name 查询通常不会频繁缓存,但为了一致性也添加
|
||||
profile, err := s.profileRepo.FindByName(name)
|
||||
profile, err := s.profileRepo.FindByName(ctx, name)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户角色未创建")
|
||||
}
|
||||
|
||||
@@ -426,7 +426,7 @@ func TestProfileServiceImpl_Create(t *testing.T) {
|
||||
Email: "test@example.com",
|
||||
Status: 1,
|
||||
}
|
||||
userRepo.Create(testUser)
|
||||
_ = userRepo.Create(context.Background(), testUser)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger)
|
||||
@@ -459,7 +459,7 @@ func TestProfileServiceImpl_Create(t *testing.T) {
|
||||
wantErr: true,
|
||||
errMsg: "角色名已被使用",
|
||||
setupMocks: func() {
|
||||
profileRepo.Create(&model.Profile{
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{
|
||||
UUID: "existing-uuid",
|
||||
UserID: 2,
|
||||
Name: "ExistingProfile",
|
||||
@@ -516,7 +516,7 @@ func TestProfileServiceImpl_GetByUUID(t *testing.T) {
|
||||
UserID: 1,
|
||||
Name: "TestProfile",
|
||||
}
|
||||
profileRepo.Create(testProfile)
|
||||
_ = profileRepo.Create(context.Background(), testProfile)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger)
|
||||
@@ -575,7 +575,7 @@ func TestProfileServiceImpl_Delete(t *testing.T) {
|
||||
UserID: 1,
|
||||
Name: "DeleteTestProfile",
|
||||
}
|
||||
profileRepo.Create(testProfile)
|
||||
_ = profileRepo.Create(context.Background(), testProfile)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
profileService := NewProfileService(profileRepo, userRepo, cacheManager, logger)
|
||||
@@ -625,9 +625,9 @@ func TestProfileServiceImpl_GetByUserID(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 为用户 1 和 2 预置不同档案
|
||||
profileRepo.Create(&model.Profile{UUID: "p1", UserID: 1, Name: "P1"})
|
||||
profileRepo.Create(&model.Profile{UUID: "p2", UserID: 1, Name: "P2"})
|
||||
profileRepo.Create(&model.Profile{UUID: "p3", UserID: 2, Name: "P3"})
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p1", UserID: 1, Name: "P1"})
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p2", UserID: 1, Name: "P2"})
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{UUID: "p3", UserID: 2, Name: "P3"})
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
svc := NewProfileService(profileRepo, userRepo, cacheManager, logger)
|
||||
@@ -653,7 +653,7 @@ func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) {
|
||||
UserID: 1,
|
||||
Name: "OldName",
|
||||
}
|
||||
profileRepo.Create(profile)
|
||||
_ = profileRepo.Create(context.Background(), profile)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
svc := NewProfileService(profileRepo, userRepo, cacheManager, logger)
|
||||
@@ -678,7 +678,7 @@ func TestProfileServiceImpl_Update_And_SetActive(t *testing.T) {
|
||||
}
|
||||
|
||||
// 名称重复
|
||||
profileRepo.Create(&model.Profile{
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{
|
||||
UUID: "u2",
|
||||
UserID: 2,
|
||||
Name: "Duplicate",
|
||||
@@ -705,8 +705,8 @@ func TestProfileServiceImpl_CheckLimit_And_GetByNames(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
// 为用户 1 预置 2 个档案
|
||||
profileRepo.Create(&model.Profile{UUID: "a", UserID: 1, Name: "A"})
|
||||
profileRepo.Create(&model.Profile{UUID: "b", UserID: 1, Name: "B"})
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{UUID: "a", UserID: 1, Name: "A"})
|
||||
_ = profileRepo.Create(context.Background(), &model.Profile{UUID: "b", UserID: 1, Name: "B"})
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
svc := NewProfileService(profileRepo, userRepo, cacheManager, logger)
|
||||
|
||||
@@ -32,8 +32,8 @@ const (
|
||||
RedisTTL = 0 // 永不过期,由应用程序管理过期时间
|
||||
)
|
||||
|
||||
// signatureService 签名服务实现
|
||||
type signatureService struct {
|
||||
// SignatureService 签名服务(导出以便依赖注入)
|
||||
type SignatureService struct {
|
||||
profileRepo repository.ProfileRepository
|
||||
redis *redis.Client
|
||||
logger *zap.Logger
|
||||
@@ -44,8 +44,8 @@ func NewSignatureService(
|
||||
profileRepo repository.ProfileRepository,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
) *signatureService {
|
||||
return &signatureService{
|
||||
) *SignatureService {
|
||||
return &SignatureService{
|
||||
profileRepo: profileRepo,
|
||||
redis: redisClient,
|
||||
logger: logger,
|
||||
@@ -53,7 +53,7 @@ func NewSignatureService(
|
||||
}
|
||||
|
||||
// NewKeyPair 生成新的RSA密钥对
|
||||
func (s *signatureService) NewKeyPair() (*model.KeyPair, error) {
|
||||
func (s *SignatureService) NewKeyPair() (*model.KeyPair, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, KeySize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA密钥对失败: %w", err)
|
||||
@@ -132,7 +132,7 @@ func (s *signatureService) NewKeyPair() (*model.KeyPair, error) {
|
||||
}
|
||||
|
||||
// GetOrCreateYggdrasilKeyPair 获取或创建Yggdrasil根密钥对
|
||||
func (s *signatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKey, error) {
|
||||
func (s *SignatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKey, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 尝试从Redis获取密钥
|
||||
@@ -201,7 +201,7 @@ func (s *signatureService) GetOrCreateYggdrasilKeyPair() (string, *rsa.PrivateKe
|
||||
}
|
||||
|
||||
// GetPublicKeyFromRedis 从Redis获取公钥
|
||||
func (s *signatureService) GetPublicKeyFromRedis() (string, error) {
|
||||
func (s *SignatureService) GetPublicKeyFromRedis() (string, error) {
|
||||
ctx := context.Background()
|
||||
publicKey, err := s.redis.Get(ctx, PublicKeyRedisKey)
|
||||
if err != nil {
|
||||
@@ -218,7 +218,7 @@ func (s *signatureService) GetPublicKeyFromRedis() (string, error) {
|
||||
}
|
||||
|
||||
// SignStringWithSHA1withRSA 使用SHA1withRSA签名字符串
|
||||
func (s *signatureService) SignStringWithSHA1withRSA(data string) (string, error) {
|
||||
func (s *SignatureService) SignStringWithSHA1withRSA(data string) (string, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 从Redis获取私钥
|
||||
|
||||
@@ -41,13 +41,13 @@ func NewTextureService(
|
||||
|
||||
func (s *textureService) Create(ctx context.Context, uploaderID int64, name, description, textureType, url, hash string, size int, isPublic, isSlim bool) (*model.Texture, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.FindByID(uploaderID)
|
||||
user, err := s.userRepo.FindByID(ctx, uploaderID)
|
||||
if err != nil || user == nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
// 检查Hash是否已存在
|
||||
existingTexture, err := s.textureRepo.FindByHash(hash)
|
||||
existingTexture, err := s.textureRepo.FindByHash(ctx, hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (s *textureService) Create(ctx context.Context, uploaderID int64, name, des
|
||||
FavoriteCount: 0,
|
||||
}
|
||||
|
||||
if err := s.textureRepo.Create(texture); err != nil {
|
||||
if err := s.textureRepo.Create(ctx, texture); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -99,7 +99,7 @@ func (s *textureService) GetByID(ctx context.Context, id int64) (*model.Texture,
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
texture2, err := s.textureRepo.FindByID(id)
|
||||
texture2, err := s.textureRepo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func (s *textureService) GetByHash(ctx context.Context, hash string) (*model.Tex
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
texture2, err := s.textureRepo.FindByHash(hash)
|
||||
texture2, err := s.textureRepo.FindByHash(ctx, hash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
textures, total, err := s.textureRepo.FindByUploaderID(uploaderID, page, pageSize)
|
||||
textures, total, err := s.textureRepo.FindByUploaderID(ctx, uploaderID, page, pageSize)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -184,12 +184,12 @@ func (s *textureService) GetByUserID(ctx context.Context, uploaderID int64, page
|
||||
|
||||
func (s *textureService) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
page, pageSize = NormalizePagination(page, pageSize)
|
||||
return s.textureRepo.Search(keyword, textureType, publicOnly, page, pageSize)
|
||||
return s.textureRepo.Search(ctx, keyword, textureType, publicOnly, page, pageSize)
|
||||
}
|
||||
|
||||
func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64, name, description string, isPublic *bool) (*model.Texture, error) {
|
||||
// 获取材质并验证权限
|
||||
texture, err := s.textureRepo.FindByID(textureID)
|
||||
texture, err := s.textureRepo.FindByID(ctx, textureID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -213,7 +213,7 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := s.textureRepo.UpdateFields(textureID, updates); err != nil {
|
||||
if err := s.textureRepo.UpdateFields(ctx, textureID, updates); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -222,12 +222,12 @@ func (s *textureService) Update(ctx context.Context, textureID, uploaderID int64
|
||||
s.cacheInv.OnUpdate(ctx, s.cacheKeys.Texture(textureID))
|
||||
s.cacheInv.BatchInvalidate(ctx, fmt.Sprintf("texture:user:%d:*", uploaderID))
|
||||
|
||||
return s.textureRepo.FindByID(textureID)
|
||||
return s.textureRepo.FindByID(ctx, textureID)
|
||||
}
|
||||
|
||||
func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64) error {
|
||||
// 获取材质并验证权限
|
||||
texture, err := s.textureRepo.FindByID(textureID)
|
||||
texture, err := s.textureRepo.FindByID(ctx, textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -238,7 +238,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
|
||||
return ErrTextureNoPermission
|
||||
}
|
||||
|
||||
err = s.textureRepo.Delete(textureID)
|
||||
err = s.textureRepo.Delete(ctx, textureID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -252,7 +252,7 @@ func (s *textureService) Delete(ctx context.Context, textureID, uploaderID int64
|
||||
|
||||
func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID int64) (bool, error) {
|
||||
// 确保材质存在
|
||||
texture, err := s.textureRepo.FindByID(textureID)
|
||||
texture, err := s.textureRepo.FindByID(ctx, textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -260,27 +260,27 @@ func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID i
|
||||
return false, ErrTextureNotFound
|
||||
}
|
||||
|
||||
isFavorited, err := s.textureRepo.IsFavorited(userID, textureID)
|
||||
isFavorited, err := s.textureRepo.IsFavorited(ctx, userID, textureID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if isFavorited {
|
||||
// 已收藏 -> 取消收藏
|
||||
if err := s.textureRepo.RemoveFavorite(userID, textureID); err != nil {
|
||||
if err := s.textureRepo.RemoveFavorite(ctx, userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := s.textureRepo.DecrementFavoriteCount(textureID); err != nil {
|
||||
if err := s.textureRepo.DecrementFavoriteCount(ctx, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 未收藏 -> 添加收藏
|
||||
if err := s.textureRepo.AddFavorite(userID, textureID); err != nil {
|
||||
if err := s.textureRepo.AddFavorite(ctx, userID, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := s.textureRepo.IncrementFavoriteCount(textureID); err != nil {
|
||||
if err := s.textureRepo.IncrementFavoriteCount(ctx, textureID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
@@ -288,11 +288,11 @@ func (s *textureService) ToggleFavorite(ctx context.Context, userID, textureID i
|
||||
|
||||
func (s *textureService) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
page, pageSize = NormalizePagination(page, pageSize)
|
||||
return s.textureRepo.GetUserFavorites(userID, page, pageSize)
|
||||
return s.textureRepo.GetUserFavorites(ctx, userID, page, pageSize)
|
||||
}
|
||||
|
||||
func (s *textureService) CheckUploadLimit(ctx context.Context, uploaderID int64, maxTextures int) error {
|
||||
count, err := s.textureRepo.CountByUploaderID(uploaderID)
|
||||
count, err := s.textureRepo.CountByUploaderID(ctx, uploaderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -491,7 +491,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
|
||||
Email: "test@example.com",
|
||||
Status: 1,
|
||||
}
|
||||
userRepo.Create(testUser)
|
||||
_ = userRepo.Create(context.Background(), testUser)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
@@ -539,7 +539,7 @@ func TestTextureServiceImpl_Create(t *testing.T) {
|
||||
wantErr: true,
|
||||
errContains: "已存在",
|
||||
setupMocks: func() {
|
||||
textureRepo.Create(&model.Texture{
|
||||
_ = textureRepo.Create(context.Background(), &model.Texture{
|
||||
ID: 100,
|
||||
UploaderID: 1,
|
||||
Name: "ExistingTexture",
|
||||
@@ -614,7 +614,7 @@ func TestTextureServiceImpl_GetByID(t *testing.T) {
|
||||
Name: "TestTexture",
|
||||
Hash: "test-hash",
|
||||
}
|
||||
textureRepo.Create(testTexture)
|
||||
_ = textureRepo.Create(context.Background(), testTexture)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
@@ -666,7 +666,7 @@ func TestTextureServiceImpl_GetByUserID_And_Search(t *testing.T) {
|
||||
|
||||
// 预置多条 Texture
|
||||
for i := int64(1); i <= 5; i++ {
|
||||
textureRepo.Create(&model.Texture{
|
||||
_ = textureRepo.Create(context.Background(), &model.Texture{
|
||||
ID: i,
|
||||
UploaderID: 1,
|
||||
Name: "T",
|
||||
@@ -711,7 +711,7 @@ func TestTextureServiceImpl_Update_And_Delete(t *testing.T) {
|
||||
Description: "OldDesc",
|
||||
IsPublic: false,
|
||||
}
|
||||
textureRepo.Create(texture)
|
||||
_ = textureRepo.Create(context.Background(), texture)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
@@ -755,12 +755,12 @@ func TestTextureServiceImpl_FavoritesAndLimit(t *testing.T) {
|
||||
|
||||
// 预置若干 Texture 与收藏关系
|
||||
for i := int64(1); i <= 3; i++ {
|
||||
textureRepo.Create(&model.Texture{
|
||||
_ = textureRepo.Create(context.Background(), &model.Texture{
|
||||
ID: i,
|
||||
UploaderID: 1,
|
||||
Name: "T",
|
||||
})
|
||||
_ = textureRepo.AddFavorite(1, i)
|
||||
_ = textureRepo.AddFavorite(context.Background(), 1, i)
|
||||
}
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
@@ -796,7 +796,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) {
|
||||
|
||||
// 预置用户和Texture
|
||||
testUser := &model.User{ID: 1, Username: "testuser", Status: 1}
|
||||
userRepo.Create(testUser)
|
||||
_ = userRepo.Create(context.Background(), testUser)
|
||||
|
||||
testTexture := &model.Texture{
|
||||
ID: 1,
|
||||
@@ -804,7 +804,7 @@ func TestTextureServiceImpl_ToggleFavorite(t *testing.T) {
|
||||
Name: "TestTexture",
|
||||
Hash: "test-hash",
|
||||
}
|
||||
textureRepo.Create(testTexture)
|
||||
_ = textureRepo.Create(context.Background(), testTexture)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
textureService := NewTextureService(textureRepo, userRepo, cacheManager, logger)
|
||||
|
||||
@@ -46,12 +46,12 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(UUID)
|
||||
_, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
@@ -85,23 +85,27 @@ func (s *tokenService) Create(ctx context.Context, userID int64, UUID string, cl
|
||||
availableProfiles = profiles
|
||||
|
||||
// 插入令牌
|
||||
err = s.tokenRepo.Create(&token)
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Token失败: %w", err)
|
||||
}
|
||||
|
||||
// 清理多余的令牌
|
||||
go s.checkAndCleanupExcessTokens(userID)
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -118,12 +122,16 @@ func (s *tokenService) Validate(ctx context.Context, accessToken, clientToken st
|
||||
}
|
||||
|
||||
func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
|
||||
// 查找旧令牌
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
oldToken, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", "", errors.New("accessToken无效")
|
||||
@@ -134,7 +142,7 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
|
||||
|
||||
// 验证profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(oldToken.UserID, selectedProfileID)
|
||||
valid, validErr := s.validateProfileByUserID(ctx, oldToken.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(err),
|
||||
@@ -174,13 +182,13 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
|
||||
}
|
||||
|
||||
// 先插入新令牌,再删除旧令牌
|
||||
err = s.tokenRepo.Create(&newToken)
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Error("创建新Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return "", "", fmt.Errorf("创建新Token失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.tokenRepo.DeleteByAccessToken(accessToken)
|
||||
err = s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("删除旧Token失败,但新Token已创建",
|
||||
zap.Error(err),
|
||||
@@ -194,11 +202,15 @@ func (s *tokenService) Refresh(ctx context.Context, accessToken, clientToken, se
|
||||
}
|
||||
|
||||
func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByAccessToken(accessToken)
|
||||
err := s.tokenRepo.DeleteByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("删除Token失败", zap.Error(err), zap.String("accessToken", accessToken))
|
||||
return
|
||||
@@ -207,11 +219,15 @@ func (s *tokenService) Invalidate(ctx context.Context, accessToken string) {
|
||||
}
|
||||
|
||||
func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := s.tokenRepo.DeleteByUserID(userID)
|
||||
err := s.tokenRepo.DeleteByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("删除用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
@@ -221,21 +237,33 @@ func (s *tokenService) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
|
||||
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
func (s *tokenService) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
|
||||
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
|
||||
func (s *tokenService) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(userID)
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -250,7 +278,7 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -261,12 +289,12 @@ func (s *tokenService) checkAndCleanupExcessTokens(userID int64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenService) validateProfileByUserID(userID int64, UUID string) (bool, error) {
|
||||
func (s *tokenService) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(UUID)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
|
||||
@@ -55,12 +55,12 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
)
|
||||
|
||||
// 设置超时上下文
|
||||
_, cancel := context.WithTimeout(context.Background(), DefaultTimeout)
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
// 验证用户存在
|
||||
if UUID != "" {
|
||||
_, err := s.profileRepo.FindByUUID(UUID)
|
||||
_, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户信息失败: %w", err)
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
|
||||
// 获取或创建Client
|
||||
var client *model.Client
|
||||
existingClient, err := s.clientRepo.FindByClientToken(clientToken)
|
||||
existingClient, err := s.clientRepo.FindByClientToken(ctx, clientToken)
|
||||
if err != nil {
|
||||
// Client不存在,创建新的
|
||||
clientUUID := uuid.New().String()
|
||||
@@ -90,7 +90,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
client.ProfileID = UUID
|
||||
}
|
||||
|
||||
if err := s.clientRepo.Create(client); err != nil {
|
||||
if err := s.clientRepo.Create(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("创建Client失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
@@ -103,14 +103,14 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
client.UpdatedAt = time.Now()
|
||||
if UUID != "" {
|
||||
client.ProfileID = UUID
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("更新Client失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取用户配置文件
|
||||
profiles, err := s.profileRepo.FindByUserID(userID)
|
||||
profiles, err := s.profileRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return selectedProfileID, availableProfiles, "", "", fmt.Errorf("获取用户配置文件失败: %w", err)
|
||||
}
|
||||
@@ -122,7 +122,7 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
if profileID == "" {
|
||||
profileID = selectedProfileID.UUID
|
||||
client.ProfileID = profileID
|
||||
s.clientRepo.Update(client)
|
||||
_ = s.clientRepo.Update(ctx, client)
|
||||
}
|
||||
}
|
||||
availableProfiles = profiles
|
||||
@@ -170,20 +170,23 @@ func (s *tokenServiceJWT) Create(ctx context.Context, userID int64, UUID string,
|
||||
StaleAt: &staleAt,
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(&token)
|
||||
err = s.tokenRepo.Create(ctx, &token)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存Token记录失败,但JWT已生成", zap.Error(err))
|
||||
// 不返回错误,因为JWT本身已经生成成功
|
||||
}
|
||||
|
||||
// 清理多余的令牌
|
||||
go s.checkAndCleanupExcessTokens(userID)
|
||||
// 清理多余的令牌(使用独立的后台上下文)
|
||||
go s.checkAndCleanupExcessTokens(context.Background(), userID)
|
||||
|
||||
return selectedProfileID, availableProfiles, accessToken, clientToken, nil
|
||||
}
|
||||
|
||||
// Validate 验证Token(使用JWT验证)
|
||||
func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken string) bool {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
if accessToken == "" {
|
||||
return false
|
||||
}
|
||||
@@ -195,7 +198,7 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
}
|
||||
|
||||
// 查找Client
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -215,6 +218,9 @@ func (s *tokenServiceJWT) Validate(ctx context.Context, accessToken, clientToken
|
||||
|
||||
// Refresh 刷新Token(使用Version机制,无需删除旧Token)
|
||||
func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken, selectedProfileID string) (string, string, error) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
if accessToken == "" {
|
||||
return "", "", errors.New("accessToken不能为空")
|
||||
}
|
||||
@@ -226,7 +232,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
}
|
||||
|
||||
// 查找Client
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return "", "", errors.New("无法找到对应的Client")
|
||||
}
|
||||
@@ -243,7 +249,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
|
||||
// 验证Profile
|
||||
if selectedProfileID != "" {
|
||||
valid, validErr := s.validateProfileByUserID(client.UserID, selectedProfileID)
|
||||
valid, validErr := s.validateProfileByUserID(ctx, client.UserID, selectedProfileID)
|
||||
if validErr != nil {
|
||||
s.logger.Error("验证Profile失败",
|
||||
zap.Error(validErr),
|
||||
@@ -269,7 +275,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
// 增加Version(这是关键:通过Version失效所有旧Token)
|
||||
client.Version++
|
||||
client.UpdatedAt = time.Now()
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
return "", "", fmt.Errorf("更新Client版本失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -315,7 +321,7 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
StaleAt: &staleAt,
|
||||
}
|
||||
|
||||
err = s.tokenRepo.Create(&newToken)
|
||||
err = s.tokenRepo.Create(ctx, &newToken)
|
||||
if err != nil {
|
||||
s.logger.Warn("保存新Token记录失败,但JWT已生成", zap.Error(err))
|
||||
}
|
||||
@@ -326,6 +332,10 @@ func (s *tokenServiceJWT) Refresh(ctx context.Context, accessToken, clientToken,
|
||||
|
||||
// Invalidate 使Token失效(通过增加Version)
|
||||
func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if accessToken == "" {
|
||||
return
|
||||
}
|
||||
@@ -338,7 +348,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
}
|
||||
|
||||
// 查找Client并增加Version
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
s.logger.Warn("无法找到对应的Client", zap.Error(err))
|
||||
return
|
||||
@@ -347,7 +357,7 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
// 增加Version以失效所有旧Token
|
||||
client.Version++
|
||||
client.UpdatedAt = time.Now()
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
s.logger.Error("失效Token失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
@@ -357,12 +367,16 @@ func (s *tokenServiceJWT) Invalidate(ctx context.Context, accessToken string) {
|
||||
|
||||
// InvalidateUserTokens 使用户所有Token失效
|
||||
func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64) {
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(ctx, DefaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户所有Client
|
||||
clients, err := s.clientRepo.FindByUserID(userID)
|
||||
clients, err := s.clientRepo.FindByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Client失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
return
|
||||
@@ -372,7 +386,7 @@ func (s *tokenServiceJWT) InvalidateUserTokens(ctx context.Context, userID int64
|
||||
for _, client := range clients {
|
||||
client.Version++
|
||||
client.UpdatedAt = time.Now()
|
||||
if err := s.clientRepo.Update(client); err != nil {
|
||||
if err := s.clientRepo.Update(ctx, client); err != nil {
|
||||
s.logger.Error("失效用户Token失败", zap.Error(err), zap.Int64("userId", userID))
|
||||
}
|
||||
}
|
||||
@@ -385,7 +399,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUUIDByAccessToken(accessToken)
|
||||
return s.tokenRepo.GetUUIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
if claims.ProfileID != "" {
|
||||
@@ -393,7 +407,7 @@ func (s *tokenServiceJWT) GetUUIDByAccessToken(ctx context.Context, accessToken
|
||||
}
|
||||
|
||||
// 如果没有ProfileID,从Client获取
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("无法找到对应的Client: %w", err)
|
||||
}
|
||||
@@ -410,11 +424,11 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
|
||||
claims, err := s.yggdrasilJWT.ParseAccessToken(accessToken, auth.StalePolicyAllow)
|
||||
if err != nil {
|
||||
// 如果JWT解析失败,尝试从数据库查询(向后兼容)
|
||||
return s.tokenRepo.GetUserIDByAccessToken(accessToken)
|
||||
return s.tokenRepo.GetUserIDByAccessToken(ctx, accessToken)
|
||||
}
|
||||
|
||||
// 从Client获取UserID
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("无法找到对应的Client: %w", err)
|
||||
}
|
||||
@@ -429,12 +443,16 @@ func (s *tokenServiceJWT) GetUserIDByAccessToken(ctx context.Context, accessToke
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||
func (s *tokenServiceJWT) checkAndCleanupExcessTokens(ctx context.Context, userID int64) {
|
||||
if userID == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(userID)
|
||||
// 为清理操作设置更长的超时时间
|
||||
ctx, cancel := context.WithTimeout(ctx, tokenExtendedTimeout)
|
||||
defer cancel()
|
||||
|
||||
tokens, err := s.tokenRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取用户Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -449,7 +467,7 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||
tokensToDelete = append(tokensToDelete, tokens[i].AccessToken)
|
||||
}
|
||||
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(tokensToDelete)
|
||||
deletedCount, err := s.tokenRepo.BatchDelete(ctx, tokensToDelete)
|
||||
if err != nil {
|
||||
s.logger.Error("清理用户多余Token失败", zap.Error(err), zap.String("userId", strconv.FormatInt(userID, 10)))
|
||||
return
|
||||
@@ -460,12 +478,12 @@ func (s *tokenServiceJWT) checkAndCleanupExcessTokens(userID int64) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tokenServiceJWT) validateProfileByUserID(userID int64, UUID string) (bool, error) {
|
||||
func (s *tokenServiceJWT) validateProfileByUserID(ctx context.Context, userID int64, UUID string) (bool, error) {
|
||||
if userID == 0 || UUID == "" {
|
||||
return false, errors.New("用户ID或配置文件ID不能为空")
|
||||
}
|
||||
|
||||
profile, err := s.profileRepo.FindByUUID(UUID)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, UUID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return false, errors.New("配置文件不存在")
|
||||
@@ -482,7 +500,7 @@ func (s *tokenServiceJWT) GetClientFromToken(ctx context.Context, accessToken st
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := s.clientRepo.FindByUUID(claims.Subject)
|
||||
client, err := s.clientRepo.FindByUUID(ctx, claims.Subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -208,7 +208,7 @@ func TestTokenServiceImpl_Create(t *testing.T) {
|
||||
Name: "TestProfile",
|
||||
IsActive: true,
|
||||
}
|
||||
profileRepo.Create(testProfile)
|
||||
_ = profileRepo.Create(context.Background(), testProfile)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
@@ -274,7 +274,7 @@ func TestTokenServiceImpl_Validate(t *testing.T) {
|
||||
ProfileId: "test-profile-uuid",
|
||||
Usable: true,
|
||||
}
|
||||
tokenRepo.Create(testToken)
|
||||
_ = tokenRepo.Create(context.Background(), testToken)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
@@ -336,7 +336,7 @@ func TestTokenServiceImpl_Invalidate(t *testing.T) {
|
||||
ProfileId: "test-profile-uuid",
|
||||
Usable: true,
|
||||
}
|
||||
tokenRepo.Create(testToken)
|
||||
_ = tokenRepo.Create(context.Background(), testToken)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
@@ -352,7 +352,7 @@ func TestTokenServiceImpl_Invalidate(t *testing.T) {
|
||||
tokenService.Invalidate(ctx, "token-to-invalidate")
|
||||
|
||||
// 验证Token已失效(从repo中删除)
|
||||
_, err := tokenRepo.FindByAccessToken("token-to-invalidate")
|
||||
_, err := tokenRepo.FindByAccessToken(context.Background(), "token-to-invalidate")
|
||||
if err == nil {
|
||||
t.Error("Token应该已被删除")
|
||||
}
|
||||
@@ -366,7 +366,7 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
|
||||
|
||||
// 预置多个Token
|
||||
for i := 1; i <= 3; i++ {
|
||||
tokenRepo.Create(&model.Token{
|
||||
_ = tokenRepo.Create(context.Background(), &model.Token{
|
||||
AccessToken: fmt.Sprintf("user1-token-%d", i),
|
||||
ClientToken: "client-token",
|
||||
UserID: 1,
|
||||
@@ -374,7 +374,7 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
|
||||
Usable: true,
|
||||
})
|
||||
}
|
||||
tokenRepo.Create(&model.Token{
|
||||
_ = tokenRepo.Create(context.Background(), &model.Token{
|
||||
AccessToken: "user2-token-1",
|
||||
ClientToken: "client-token",
|
||||
UserID: 2,
|
||||
@@ -390,13 +390,13 @@ func TestTokenServiceImpl_InvalidateUserTokens(t *testing.T) {
|
||||
tokenService.InvalidateUserTokens(ctx, 1)
|
||||
|
||||
// 验证用户1的Token已失效
|
||||
tokens, _ := tokenRepo.GetByUserID(1)
|
||||
tokens, _ := tokenRepo.GetByUserID(context.Background(), 1)
|
||||
if len(tokens) > 0 {
|
||||
t.Errorf("用户1的Token应该全部被删除,但还剩 %d 个", len(tokens))
|
||||
}
|
||||
|
||||
// 验证用户2的Token仍然存在
|
||||
tokens2, _ := tokenRepo.GetByUserID(2)
|
||||
tokens2, _ := tokenRepo.GetByUserID(context.Background(), 2)
|
||||
if len(tokens2) != 1 {
|
||||
t.Errorf("用户2的Token应该仍然存在,期望1个,实际 %d 个", len(tokens2))
|
||||
}
|
||||
@@ -413,7 +413,7 @@ func TestTokenServiceImpl_Refresh(t *testing.T) {
|
||||
UUID: "profile-uuid",
|
||||
UserID: 1,
|
||||
}
|
||||
profileRepo.Create(profile)
|
||||
_ = profileRepo.Create(context.Background(), profile)
|
||||
|
||||
oldToken := &model.Token{
|
||||
AccessToken: "old-token",
|
||||
@@ -422,7 +422,7 @@ func TestTokenServiceImpl_Refresh(t *testing.T) {
|
||||
ProfileId: "",
|
||||
Usable: true,
|
||||
}
|
||||
tokenRepo.Create(oldToken)
|
||||
_ = tokenRepo.Create(context.Background(), oldToken)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
@@ -455,7 +455,7 @@ func TestTokenServiceImpl_GetByAccessToken(t *testing.T) {
|
||||
ProfileId: "profile-42",
|
||||
Usable: true,
|
||||
}
|
||||
tokenRepo.Create(token)
|
||||
_ = tokenRepo.Create(context.Background(), token)
|
||||
|
||||
tokenService := NewTokenService(tokenRepo, profileRepo, logger)
|
||||
|
||||
@@ -489,25 +489,25 @@ func TestTokenServiceImpl_validateProfileByUserID(t *testing.T) {
|
||||
UUID: "p-1",
|
||||
UserID: 1,
|
||||
}
|
||||
profileRepo.Create(profile)
|
||||
_ = profileRepo.Create(context.Background(), profile)
|
||||
|
||||
// 参数非法
|
||||
if ok, err := svc.validateProfileByUserID(0, ""); err == nil || ok {
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 0, ""); err == nil || ok {
|
||||
t.Fatalf("validateProfileByUserID 在参数非法时应返回错误")
|
||||
}
|
||||
|
||||
// Profile 不存在
|
||||
if ok, err := svc.validateProfileByUserID(1, "not-exists"); err == nil || ok {
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 1, "not-exists"); err == nil || ok {
|
||||
t.Fatalf("validateProfileByUserID 在 Profile 不存在时应返回错误")
|
||||
}
|
||||
|
||||
// 用户与 Profile 匹配
|
||||
if ok, err := svc.validateProfileByUserID(1, "p-1"); err != nil || !ok {
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 1, "p-1"); err != nil || !ok {
|
||||
t.Fatalf("validateProfileByUserID 匹配时应返回 true, err=%v", err)
|
||||
}
|
||||
|
||||
// 用户与 Profile 不匹配
|
||||
if ok, err := svc.validateProfileByUserID(2, "p-1"); err != nil || ok {
|
||||
if ok, err := svc.validateProfileByUserID(context.Background(), 2, "p-1"); err != nil || ok {
|
||||
t.Fatalf("validateProfileByUserID 不匹配时应返回 false, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/redis"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -14,6 +8,14 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
apperrors "carrotskin/internal/errors"
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/internal/repository"
|
||||
"carrotskin/pkg/auth"
|
||||
"carrotskin/pkg/config"
|
||||
"carrotskin/pkg/database"
|
||||
"carrotskin/pkg/redis"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -54,21 +56,21 @@ func NewUserService(
|
||||
|
||||
func (s *userService) Register(ctx context.Context, username, password, email, avatar string) (*model.User, string, error) {
|
||||
// 检查用户名是否已存在
|
||||
existingUser, err := s.userRepo.FindByUsername(username)
|
||||
existingUser, err := s.userRepo.FindByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingUser != nil {
|
||||
return nil, "", errors.New("用户名已存在")
|
||||
return nil, "", apperrors.ErrUserAlreadyExists
|
||||
}
|
||||
|
||||
// 检查邮箱是否已存在
|
||||
existingEmail, err := s.userRepo.FindByEmail(email)
|
||||
existingEmail, err := s.userRepo.FindByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if existingEmail != nil {
|
||||
return nil, "", errors.New("邮箱已被注册")
|
||||
return nil, "", apperrors.ErrEmailAlreadyExists
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
@@ -98,7 +100,7 @@ func (s *userService) Register(ctx context.Context, username, password, email, a
|
||||
Points: 0,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -126,9 +128,9 @@ func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAd
|
||||
var err error
|
||||
|
||||
if strings.Contains(usernameOrEmail, "@") {
|
||||
user, err = s.userRepo.FindByEmail(usernameOrEmail)
|
||||
user, err = s.userRepo.FindByEmail(ctx, usernameOrEmail)
|
||||
} else {
|
||||
user, err = s.userRepo.FindByUsername(usernameOrEmail)
|
||||
user, err = s.userRepo.FindByUsername(ctx, usernameOrEmail)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -166,12 +168,12 @@ func (s *userService) Login(ctx context.Context, usernameOrEmail, password, ipAd
|
||||
// 更新最后登录时间
|
||||
now := time.Now()
|
||||
user.LastLoginAt = &now
|
||||
_ = s.userRepo.UpdateFields(user.ID, map[string]interface{}{
|
||||
_ = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
|
||||
"last_login_at": now,
|
||||
})
|
||||
|
||||
// 记录成功登录日志
|
||||
s.logSuccessLogin(user.ID, ipAddress, userAgent)
|
||||
s.logSuccessLogin(ctx, user.ID, ipAddress, userAgent)
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
@@ -180,7 +182,7 @@ func (s *userService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
// 使用 Cached 装饰器自动处理缓存
|
||||
cacheKey := s.cacheKeys.User(id)
|
||||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||||
return s.userRepo.FindByID(id)
|
||||
return s.userRepo.FindByID(ctx, id)
|
||||
}, 5*time.Minute)
|
||||
}
|
||||
|
||||
@@ -188,12 +190,12 @@ func (s *userService) GetByEmail(ctx context.Context, email string) (*model.User
|
||||
// 使用 Cached 装饰器自动处理缓存
|
||||
cacheKey := s.cacheKeys.UserByEmail(email)
|
||||
return database.Cached(ctx, s.cache, cacheKey, func() (*model.User, error) {
|
||||
return s.userRepo.FindByEmail(email)
|
||||
return s.userRepo.FindByEmail(ctx, email)
|
||||
}, 5*time.Minute)
|
||||
}
|
||||
|
||||
func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
|
||||
err := s.userRepo.Update(user)
|
||||
err := s.userRepo.Update(ctx, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -209,7 +211,7 @@ func (s *userService) UpdateInfo(ctx context.Context, user *model.User) error {
|
||||
}
|
||||
|
||||
func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL string) error {
|
||||
err := s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
err := s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
|
||||
"avatar": avatarURL,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -223,7 +225,7 @@ func (s *userService) UpdateAvatar(ctx context.Context, userID int64, avatarURL
|
||||
}
|
||||
|
||||
func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
user, err := s.userRepo.FindByID(ctx, userID)
|
||||
if err != nil || user == nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
@@ -237,7 +239,7 @@ func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
err = s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -251,7 +253,7 @@ func (s *userService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
}
|
||||
|
||||
func (s *userService) ResetPassword(ctx context.Context, email, newPassword string) error {
|
||||
user, err := s.userRepo.FindByEmail(email)
|
||||
user, err := s.userRepo.FindByEmail(ctx, email)
|
||||
if err != nil || user == nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
@@ -261,7 +263,7 @@ func (s *userService) ResetPassword(ctx context.Context, email, newPassword stri
|
||||
return errors.New("密码加密失败")
|
||||
}
|
||||
|
||||
err = s.userRepo.UpdateFields(user.ID, map[string]interface{}{
|
||||
err = s.userRepo.UpdateFields(ctx, user.ID, map[string]interface{}{
|
||||
"password": hashedPassword,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -279,17 +281,17 @@ func (s *userService) ResetPassword(ctx context.Context, email, newPassword stri
|
||||
|
||||
func (s *userService) ChangeEmail(ctx context.Context, userID int64, newEmail string) error {
|
||||
// 获取旧邮箱
|
||||
oldUser, _ := s.userRepo.FindByID(userID)
|
||||
oldUser, _ := s.userRepo.FindByID(ctx, userID)
|
||||
|
||||
existingUser, err := s.userRepo.FindByEmail(newEmail)
|
||||
existingUser, err := s.userRepo.FindByEmail(ctx, newEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existingUser != nil && existingUser.ID != userID {
|
||||
return errors.New("邮箱已被其他用户使用")
|
||||
return apperrors.ErrEmailAlreadyExists
|
||||
}
|
||||
|
||||
err = s.userRepo.UpdateFields(userID, map[string]interface{}{
|
||||
err = s.userRepo.UpdateFields(ctx, userID, map[string]interface{}{
|
||||
"email": newEmail,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -346,7 +348,7 @@ func (s *userService) ValidateAvatarURL(ctx context.Context, avatarURL string) e
|
||||
}
|
||||
|
||||
func (s *userService) GetMaxProfilesPerUser() int {
|
||||
config, err := s.configRepo.GetByKey("max_profiles_per_user")
|
||||
config, err := s.configRepo.GetByKey(context.Background(), "max_profiles_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 5
|
||||
}
|
||||
@@ -359,7 +361,7 @@ func (s *userService) GetMaxProfilesPerUser() int {
|
||||
}
|
||||
|
||||
func (s *userService) GetMaxTexturesPerUser() int {
|
||||
config, err := s.configRepo.GetByKey("max_textures_per_user")
|
||||
config, err := s.configRepo.GetByKey(context.Background(), "max_textures_per_user")
|
||||
if err != nil || config == nil {
|
||||
return 50
|
||||
}
|
||||
@@ -374,7 +376,7 @@ func (s *userService) GetMaxTexturesPerUser() int {
|
||||
// 私有辅助方法
|
||||
|
||||
func (s *userService) getDefaultAvatar() string {
|
||||
config, err := s.configRepo.GetByKey("default_avatar")
|
||||
config, err := s.configRepo.GetByKey(context.Background(), "default_avatar")
|
||||
if err != nil || config == nil || config.Value == "" {
|
||||
return ""
|
||||
}
|
||||
@@ -410,14 +412,14 @@ func (s *userService) recordLoginFailure(ctx context.Context, usernameOrEmail, i
|
||||
identifier := usernameOrEmail + ":" + ipAddress
|
||||
count, _ := RecordLoginFailure(ctx, s.redis, identifier)
|
||||
if count >= MaxLoginAttempts {
|
||||
s.logFailedLogin(userID, ipAddress, userAgent, reason+"-账号已锁定")
|
||||
s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason+"-账号已锁定")
|
||||
return
|
||||
}
|
||||
}
|
||||
s.logFailedLogin(userID, ipAddress, userAgent, reason)
|
||||
s.logFailedLogin(ctx, userID, ipAddress, userAgent, reason)
|
||||
}
|
||||
|
||||
func (s *userService) logSuccessLogin(userID int64, ipAddress, userAgent string) {
|
||||
func (s *userService) logSuccessLogin(ctx context.Context, userID int64, ipAddress, userAgent string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
@@ -425,10 +427,10 @@ func (s *userService) logSuccessLogin(userID int64, ipAddress, userAgent string)
|
||||
LoginMethod: "PASSWORD",
|
||||
IsSuccess: true,
|
||||
}
|
||||
_ = s.userRepo.CreateLoginLog(log)
|
||||
_ = s.userRepo.CreateLoginLog(ctx, log)
|
||||
}
|
||||
|
||||
func (s *userService) logFailedLogin(userID int64, ipAddress, userAgent, reason string) {
|
||||
func (s *userService) logFailedLogin(ctx context.Context, userID int64, ipAddress, userAgent, reason string) {
|
||||
log := &model.UserLoginLog{
|
||||
UserID: userID,
|
||||
IPAddress: ipAddress,
|
||||
@@ -437,5 +439,5 @@ func (s *userService) logFailedLogin(userID int64, ipAddress, userAgent, reason
|
||||
IsSuccess: false,
|
||||
FailureReason: reason,
|
||||
}
|
||||
_ = s.userRepo.CreateLoginLog(log)
|
||||
_ = s.userRepo.CreateLoginLog(ctx, log)
|
||||
}
|
||||
|
||||
@@ -49,9 +49,10 @@ func TestUserServiceImpl_Register(t *testing.T) {
|
||||
email: "new@example.com",
|
||||
avatar: "",
|
||||
wantErr: true,
|
||||
errMsg: "用户名已存在",
|
||||
// 服务实现现已统一使用 apperrors.ErrUserAlreadyExists,错误信息为“用户已存在”
|
||||
errMsg: "用户已存在",
|
||||
setupMocks: func() {
|
||||
userRepo.Create(&model.User{
|
||||
_ = userRepo.Create(context.Background(), &model.User{
|
||||
Username: "existinguser",
|
||||
Email: "old@example.com",
|
||||
})
|
||||
@@ -66,7 +67,7 @@ func TestUserServiceImpl_Register(t *testing.T) {
|
||||
wantErr: true,
|
||||
errMsg: "邮箱已被注册",
|
||||
setupMocks: func() {
|
||||
userRepo.Create(&model.User{
|
||||
_ = userRepo.Create(context.Background(), &model.User{
|
||||
Username: "otheruser",
|
||||
Email: "existing@example.com",
|
||||
})
|
||||
@@ -126,7 +127,7 @@ func TestUserServiceImpl_Login(t *testing.T) {
|
||||
Password: hashedPassword,
|
||||
Status: 1,
|
||||
}
|
||||
userRepo.Create(testUser)
|
||||
_ = userRepo.Create(context.Background(), testUser)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
|
||||
@@ -207,7 +208,7 @@ func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) {
|
||||
Email: "basic@example.com",
|
||||
Avatar: "",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
_ = userRepo.Create(context.Background(), user)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
|
||||
@@ -231,7 +232,7 @@ func TestUserServiceImpl_BasicGettersAndUpdates(t *testing.T) {
|
||||
if err := userService.UpdateInfo(ctx, user); err != nil {
|
||||
t.Fatalf("UpdateInfo 失败: %v", err)
|
||||
}
|
||||
updated, _ := userRepo.FindByID(1)
|
||||
updated, _ := userRepo.FindByID(context.Background(), 1)
|
||||
if updated.Username != "updated" {
|
||||
t.Fatalf("UpdateInfo 未更新用户名, got=%s", updated.Username)
|
||||
}
|
||||
@@ -255,7 +256,7 @@ func TestUserServiceImpl_ChangePassword(t *testing.T) {
|
||||
Username: "changepw",
|
||||
Password: hashed,
|
||||
}
|
||||
userRepo.Create(user)
|
||||
_ = userRepo.Create(context.Background(), user)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
|
||||
@@ -290,7 +291,7 @@ func TestUserServiceImpl_ResetPassword(t *testing.T) {
|
||||
Username: "resetpw",
|
||||
Email: "reset@example.com",
|
||||
}
|
||||
userRepo.Create(user)
|
||||
_ = userRepo.Create(context.Background(), user)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
|
||||
@@ -317,8 +318,8 @@ func TestUserServiceImpl_ChangeEmail(t *testing.T) {
|
||||
|
||||
user1 := &model.User{ID: 1, Email: "user1@example.com"}
|
||||
user2 := &model.User{ID: 2, Email: "user2@example.com"}
|
||||
userRepo.Create(user1)
|
||||
userRepo.Create(user2)
|
||||
_ = userRepo.Create(context.Background(), user1)
|
||||
_ = userRepo.Create(context.Background(), user2)
|
||||
|
||||
cacheManager := NewMockCacheManager()
|
||||
userService := NewUserService(userRepo, configRepo, jwtService, nil, cacheManager, logger)
|
||||
@@ -389,8 +390,8 @@ func TestUserServiceImpl_MaxLimits(t *testing.T) {
|
||||
}
|
||||
|
||||
// 配置有效值
|
||||
configRepo.Update(&model.SystemConfig{Key: "max_profiles_per_user", Value: "10"})
|
||||
configRepo.Update(&model.SystemConfig{Key: "max_textures_per_user", Value: "100"})
|
||||
_ = configRepo.Update(context.Background(), &model.SystemConfig{Key: "max_profiles_per_user", Value: "10"})
|
||||
_ = configRepo.Update(context.Background(), &model.SystemConfig{Key: "max_textures_per_user", Value: "100"})
|
||||
|
||||
if got := userService.GetMaxProfilesPerUser(); got != 10 {
|
||||
t.Fatalf("GetMaxProfilesPerUser 配置值错误, got=%d", got)
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewYggdrasilAuthService(
|
||||
}
|
||||
|
||||
func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email string) (int64, error) {
|
||||
user, err := s.userRepo.FindByEmail(email)
|
||||
user, err := s.userRepo.FindByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return 0, apperrors.ErrUserNotFound
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func (s *yggdrasilAuthService) GetUserIDByEmail(ctx context.Context, email strin
|
||||
}
|
||||
|
||||
func (s *yggdrasilAuthService) VerifyPassword(ctx context.Context, password string, userID int64) error {
|
||||
passwordStore, err := s.yggdrasilRepo.GetPasswordByID(userID)
|
||||
passwordStore, err := s.yggdrasilRepo.GetPasswordByID(ctx, userID)
|
||||
if err != nil {
|
||||
return apperrors.ErrPasswordNotSet
|
||||
}
|
||||
@@ -68,7 +68,7 @@ func (s *yggdrasilAuthService) ResetYggdrasilPassword(ctx context.Context, userI
|
||||
}
|
||||
|
||||
// 检查Yggdrasil记录是否存在
|
||||
_, err = s.yggdrasilRepo.GetPasswordByID(userID)
|
||||
_, err = s.yggdrasilRepo.GetPasswordByID(ctx, userID)
|
||||
if err != nil {
|
||||
// 如果不存在,创建新记录
|
||||
yggdrasil := model.Yggdrasil{
|
||||
@@ -82,7 +82,7 @@ func (s *yggdrasilAuthService) ResetYggdrasilPassword(ctx context.Context, userI
|
||||
}
|
||||
|
||||
// 如果存在,更新密码(存储加密后的密码)
|
||||
if err := s.yggdrasilRepo.ResetPassword(userID, hashedPassword); err != nil {
|
||||
if err := s.yggdrasilRepo.ResetPassword(ctx, userID, hashedPassword); err != nil {
|
||||
return "", fmt.Errorf("重置Yggdrasil密码失败: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,14 +21,14 @@ type CertificateService interface {
|
||||
// yggdrasilCertificateService 证书服务实现
|
||||
type yggdrasilCertificateService struct {
|
||||
profileRepo repository.ProfileRepository
|
||||
signatureService *signatureService
|
||||
signatureService *SignatureService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCertificateService 创建证书服务实例
|
||||
func NewCertificateService(
|
||||
profileRepo repository.ProfileRepository,
|
||||
signatureService *signatureService,
|
||||
signatureService *SignatureService,
|
||||
logger *zap.Logger,
|
||||
) CertificateService {
|
||||
return &yggdrasilCertificateService{
|
||||
@@ -49,7 +49,7 @@ func (s *yggdrasilCertificateService) GeneratePlayerCertificate(ctx context.Cont
|
||||
)
|
||||
|
||||
// 获取密钥对
|
||||
keyPair, err := s.profileRepo.GetKeyPair(uuid)
|
||||
keyPair, err := s.profileRepo.GetKeyPair(ctx, uuid)
|
||||
if err != nil {
|
||||
s.logger.Info("获取用户密钥对失败,将创建新密钥对",
|
||||
zap.Error(err),
|
||||
@@ -74,7 +74,7 @@ func (s *yggdrasilCertificateService) GeneratePlayerCertificate(ctx context.Cont
|
||||
}
|
||||
|
||||
// 保存密钥对到数据库
|
||||
err = s.profileRepo.UpdateKeyPair(uuid, keyPair)
|
||||
err = s.profileRepo.UpdateKeyPair(ctx, uuid, keyPair)
|
||||
if err != nil {
|
||||
s.logger.Warn("更新用户密钥对失败",
|
||||
zap.Error(err),
|
||||
|
||||
@@ -28,14 +28,14 @@ type Property struct {
|
||||
// yggdrasilSerializationService 序列化服务实现
|
||||
type yggdrasilSerializationService struct {
|
||||
textureRepo repository.TextureRepository
|
||||
signatureService *signatureService
|
||||
signatureService *SignatureService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewSerializationService 创建序列化服务实例
|
||||
func NewSerializationService(
|
||||
textureRepo repository.TextureRepository,
|
||||
signatureService *signatureService,
|
||||
signatureService *SignatureService,
|
||||
logger *zap.Logger,
|
||||
) SerializationService {
|
||||
return &yggdrasilSerializationService{
|
||||
@@ -58,7 +58,7 @@ func (s *yggdrasilSerializationService) SerializeProfile(ctx context.Context, pr
|
||||
|
||||
// 处理皮肤
|
||||
if profile.SkinID != nil {
|
||||
skin, err := s.textureRepo.FindByID(*profile.SkinID)
|
||||
skin, err := s.textureRepo.FindByID(ctx, *profile.SkinID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取皮肤失败",
|
||||
zap.Error(err),
|
||||
@@ -74,7 +74,7 @@ func (s *yggdrasilSerializationService) SerializeProfile(ctx context.Context, pr
|
||||
|
||||
// 处理披风
|
||||
if profile.CapeID != nil {
|
||||
cape, err := s.textureRepo.FindByID(*profile.CapeID)
|
||||
cape, err := s.textureRepo.FindByID(ctx, *profile.CapeID)
|
||||
if err != nil {
|
||||
s.logger.Error("获取披风失败",
|
||||
zap.Error(err),
|
||||
|
||||
@@ -33,7 +33,7 @@ func NewYggdrasilServiceComposite(
|
||||
profileRepo repository.ProfileRepository,
|
||||
tokenRepo repository.TokenRepository,
|
||||
yggdrasilRepo repository.YggdrasilRepository,
|
||||
signatureService *signatureService,
|
||||
signatureService *SignatureService,
|
||||
redisClient *redis.Client,
|
||||
logger *zap.Logger,
|
||||
) YggdrasilService {
|
||||
@@ -76,7 +76,7 @@ func (s *yggdrasilServiceComposite) ResetYggdrasilPassword(ctx context.Context,
|
||||
// JoinServer 加入服务器
|
||||
func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, accessToken, selectedProfile, ip string) error {
|
||||
// 验证Token
|
||||
token, err := s.tokenRepo.FindByAccessToken(accessToken)
|
||||
token, err := s.tokenRepo.FindByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
s.logger.Error("验证Token失败",
|
||||
zap.Error(err),
|
||||
@@ -92,7 +92,7 @@ func (s *yggdrasilServiceComposite) JoinServer(ctx context.Context, serverID, ac
|
||||
}
|
||||
|
||||
// 获取Profile以获取用户名
|
||||
profile, err := s.profileRepo.FindByUUID(formattedProfile)
|
||||
profile, err := s.profileRepo.FindByUUID(ctx, formattedProfile)
|
||||
if err != nil {
|
||||
s.logger.Error("获取Profile失败",
|
||||
zap.Error(err),
|
||||
|
||||
Reference in New Issue
Block a user