refactor: Update service and repository methods to use context

- Refactored multiple service and repository methods to accept context as a parameter, enhancing consistency and enabling better control over request lifecycles.
- Updated handlers to utilize context in method calls, improving error handling and performance.
- Cleaned up Dockerfile by removing unnecessary whitespace.
This commit is contained in:
lan
2025-12-03 15:27:12 +08:00
parent 4824a997dd
commit 0bcd9336c4
32 changed files with 833 additions and 497 deletions

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrotskin/internal/model"
"context"
"gorm.io/gorm"
)
@@ -16,48 +17,48 @@ func NewClientRepository(db *gorm.DB) ClientRepository {
return &clientRepository{db: db}
}
func (r *clientRepository) Create(client *model.Client) error {
return r.db.Create(client).Error
func (r *clientRepository) Create(ctx context.Context, client *model.Client) error {
return r.db.WithContext(ctx).Create(client).Error
}
func (r *clientRepository) FindByClientToken(clientToken string) (*model.Client, error) {
func (r *clientRepository) FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error) {
var client model.Client
err := r.db.Where("client_token = ?", clientToken).First(&client).Error
err := r.db.WithContext(ctx).Where("client_token = ?", clientToken).First(&client).Error
if err != nil {
return nil, err
}
return &client, nil
}
func (r *clientRepository) FindByUUID(uuid string) (*model.Client, error) {
func (r *clientRepository) FindByUUID(ctx context.Context, uuid string) (*model.Client, error) {
var client model.Client
err := r.db.Where("uuid = ?", uuid).First(&client).Error
err := r.db.WithContext(ctx).Where("uuid = ?", uuid).First(&client).Error
if err != nil {
return nil, err
}
return &client, nil
}
func (r *clientRepository) FindByUserID(userID int64) ([]*model.Client, error) {
func (r *clientRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error) {
var clients []*model.Client
err := r.db.Where("user_id = ?", userID).Find(&clients).Error
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&clients).Error
return clients, err
}
func (r *clientRepository) Update(client *model.Client) error {
return r.db.Save(client).Error
func (r *clientRepository) Update(ctx context.Context, client *model.Client) error {
return r.db.WithContext(ctx).Save(client).Error
}
func (r *clientRepository) IncrementVersion(clientUUID string) error {
return r.db.Model(&model.Client{}).
func (r *clientRepository) IncrementVersion(ctx context.Context, clientUUID string) error {
return r.db.WithContext(ctx).Model(&model.Client{}).
Where("uuid = ?", clientUUID).
Update("version", gorm.Expr("version + 1")).Error
}
func (r *clientRepository) DeleteByClientToken(clientToken string) error {
return r.db.Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
func (r *clientRepository) DeleteByClientToken(ctx context.Context, clientToken string) error {
return r.db.WithContext(ctx).Where("client_token = ?", clientToken).Delete(&model.Client{}).Error
}
func (r *clientRepository) DeleteByUserID(userID int64) error {
return r.db.Where("user_id = ?", userID).Delete(&model.Client{}).Error
func (r *clientRepository) DeleteByUserID(ctx context.Context, userID int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&model.Client{}).Error
}

View File

@@ -2,95 +2,105 @@ package repository
import (
"carrotskin/internal/model"
"context"
)
// UserRepository 用户仓储接口
type UserRepository interface {
Create(user *model.User) error
FindByID(id int64) (*model.User, error)
FindByUsername(username string) (*model.User, error)
FindByEmail(email string) (*model.User, error)
Update(user *model.User) error
UpdateFields(id int64, fields map[string]interface{}) error
Delete(id int64) error
CreateLoginLog(log *model.UserLoginLog) error
CreatePointLog(log *model.UserPointLog) error
UpdatePoints(userID int64, amount int, changeType, reason string) error
Create(ctx context.Context, user *model.User) error
FindByID(ctx context.Context, id int64) (*model.User, error)
FindByUsername(ctx context.Context, username string) (*model.User, error)
FindByEmail(ctx context.Context, email string) (*model.User, error)
FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) // 批量查询
Update(ctx context.Context, user *model.User) error
UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error
BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新
Delete(ctx context.Context, id int64) error
BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除
CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error
CreatePointLog(ctx context.Context, log *model.UserPointLog) error
UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error
}
// ProfileRepository 档案仓储接口
type ProfileRepository interface {
Create(profile *model.Profile) error
FindByUUID(uuid string) (*model.Profile, error)
FindByName(name string) (*model.Profile, error)
FindByUserID(userID int64) ([]*model.Profile, error)
Update(profile *model.Profile) error
UpdateFields(uuid string, updates map[string]interface{}) error
Delete(uuid string) error
CountByUserID(userID int64) (int64, error)
SetActive(uuid string, userID int64) error
UpdateLastUsedAt(uuid string) error
GetByNames(names []string) ([]*model.Profile, error)
GetKeyPair(profileId string) (*model.KeyPair, error)
UpdateKeyPair(profileId string, keyPair *model.KeyPair) error
Create(ctx context.Context, profile *model.Profile) error
FindByUUID(ctx context.Context, uuid string) (*model.Profile, error)
FindByName(ctx context.Context, name string) (*model.Profile, error)
FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error)
FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) // 批量查询
Update(ctx context.Context, profile *model.Profile) error
UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error
BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) // 批量更新
Delete(ctx context.Context, uuid string) error
BatchDelete(ctx context.Context, uuids []string) (int64, error) // 批量删除
CountByUserID(ctx context.Context, userID int64) (int64, error)
SetActive(ctx context.Context, uuid string, userID int64) error
UpdateLastUsedAt(ctx context.Context, uuid string) error
GetByNames(ctx context.Context, names []string) ([]*model.Profile, error)
GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error)
UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error
}
// TextureRepository 材质仓储接口
type TextureRepository interface {
Create(texture *model.Texture) error
FindByID(id int64) (*model.Texture, error)
FindByHash(hash string) (*model.Texture, error)
FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
Update(texture *model.Texture) error
UpdateFields(id int64, fields map[string]interface{}) error
Delete(id int64) error
IncrementDownloadCount(id int64) error
IncrementFavoriteCount(id int64) error
DecrementFavoriteCount(id int64) error
CreateDownloadLog(log *model.TextureDownloadLog) error
IsFavorited(userID, textureID int64) (bool, error)
AddFavorite(userID, textureID int64) error
RemoveFavorite(userID, textureID int64) error
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error)
CountByUploaderID(uploaderID int64) (int64, error)
Create(ctx context.Context, texture *model.Texture) error
FindByID(ctx context.Context, id int64) (*model.Texture, error)
FindByHash(ctx context.Context, hash string) (*model.Texture, error)
FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) // 批量查询
FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
Update(ctx context.Context, texture *model.Texture) error
UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error
BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) // 批量更新
Delete(ctx context.Context, id int64) error
BatchDelete(ctx context.Context, ids []int64) (int64, error) // 批量删除
IncrementDownloadCount(ctx context.Context, id int64) error
IncrementFavoriteCount(ctx context.Context, id int64) error
DecrementFavoriteCount(ctx context.Context, id int64) error
CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error
IsFavorited(ctx context.Context, userID, textureID int64) (bool, error)
AddFavorite(ctx context.Context, userID, textureID int64) error
RemoveFavorite(ctx context.Context, userID, textureID int64) error
GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error)
CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error)
}
// TokenRepository 令牌仓储接口
type TokenRepository interface {
Create(token *model.Token) error
FindByAccessToken(accessToken string) (*model.Token, error)
GetByUserID(userId int64) ([]*model.Token, error)
GetUUIDByAccessToken(accessToken string) (string, error)
GetUserIDByAccessToken(accessToken string) (int64, error)
DeleteByAccessToken(accessToken string) error
DeleteByUserID(userId int64) error
BatchDelete(accessTokens []string) (int64, error)
Create(ctx context.Context, token *model.Token) error
FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error)
GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error)
GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error)
GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error)
DeleteByAccessToken(ctx context.Context, accessToken string) error
DeleteByUserID(ctx context.Context, userId int64) error
BatchDelete(ctx context.Context, accessTokens []string) (int64, error)
}
// SystemConfigRepository 系统配置仓储接口
type SystemConfigRepository interface {
GetByKey(key string) (*model.SystemConfig, error)
GetPublic() ([]model.SystemConfig, error)
GetAll() ([]model.SystemConfig, error)
Update(config *model.SystemConfig) error
UpdateValue(key, value string) error
GetByKey(ctx context.Context, key string) (*model.SystemConfig, error)
GetPublic(ctx context.Context) ([]model.SystemConfig, error)
GetAll(ctx context.Context) ([]model.SystemConfig, error)
Update(ctx context.Context, config *model.SystemConfig) error
UpdateValue(ctx context.Context, key, value string) error
}
// YggdrasilRepository Yggdrasil仓储接口
type YggdrasilRepository interface {
GetPasswordByID(id int64) (string, error)
ResetPassword(id int64, password string) error
GetPasswordByID(ctx context.Context, id int64) (string, error)
ResetPassword(ctx context.Context, id int64, password string) error
}
// ClientRepository Client仓储接口
type ClientRepository interface {
Create(client *model.Client) error
FindByClientToken(clientToken string) (*model.Client, error)
FindByUUID(uuid string) (*model.Client, error)
FindByUserID(userID int64) ([]*model.Client, error)
Update(client *model.Client) error
IncrementVersion(clientUUID string) error
DeleteByClientToken(clientToken string) error
DeleteByUserID(userID int64) error
Create(ctx context.Context, client *model.Client) error
FindByClientToken(ctx context.Context, clientToken string) (*model.Client, error)
FindByUUID(ctx context.Context, uuid string) (*model.Client, error)
FindByUserID(ctx context.Context, userID int64) ([]*model.Client, error)
Update(ctx context.Context, client *model.Client) error
IncrementVersion(ctx context.Context, clientUUID string) error
DeleteByClientToken(ctx context.Context, clientToken string) error
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -19,13 +19,13 @@ func NewProfileRepository(db *gorm.DB) ProfileRepository {
return &profileRepository{db: db}
}
func (r *profileRepository) Create(profile *model.Profile) error {
return r.db.Create(profile).Error
func (r *profileRepository) Create(ctx context.Context, profile *model.Profile) error {
return r.db.WithContext(ctx).Create(profile).Error
}
func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) {
func (r *profileRepository) FindByUUID(ctx context.Context, uuid string) (*model.Profile, error) {
var profile model.Profile
err := r.db.Where("uuid = ?", uuid).
err := r.db.WithContext(ctx).Where("uuid = ?", uuid).
Preload("Skin").
Preload("Cape").
First(&profile).Error
@@ -35,10 +35,10 @@ func (r *profileRepository) FindByUUID(uuid string) (*model.Profile, error) {
return &profile, nil
}
func (r *profileRepository) FindByName(name string) (*model.Profile, error) {
func (r *profileRepository) FindByName(ctx context.Context, name string) (*model.Profile, error) {
var profile model.Profile
// 使用 LOWER 函数进行不区分大小写的查询,并预加载 Skin 和 Cape
err := r.db.Where("LOWER(name) = LOWER(?)", name).
err := r.db.WithContext(ctx).Where("LOWER(name) = LOWER(?)", name).
Preload("Skin").
Preload("Cape").
First(&profile).Error
@@ -48,9 +48,9 @@ func (r *profileRepository) FindByName(name string) (*model.Profile, error) {
return &profile, nil
}
func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error) {
func (r *profileRepository) FindByUserID(ctx context.Context, userID int64) ([]*model.Profile, error) {
var profiles []*model.Profile
err := r.db.Where("user_id = ?", userID).
err := r.db.WithContext(ctx).Where("user_id = ?", userID).
Preload("Skin").
Preload("Cape").
Order("created_at DESC").
@@ -58,30 +58,59 @@ func (r *profileRepository) FindByUserID(userID int64) ([]*model.Profile, error)
return profiles, err
}
func (r *profileRepository) Update(profile *model.Profile) error {
return r.db.Save(profile).Error
func (r *profileRepository) FindByUUIDs(ctx context.Context, uuids []string) ([]*model.Profile, error) {
if len(uuids) == 0 {
return []*model.Profile{}, nil
}
var profiles []*model.Profile
// 使用 IN 查询优化批量查询,并预加载关联
err := r.db.WithContext(ctx).Where("uuid IN ?", uuids).
Preload("Skin").
Preload("Cape").
Find(&profiles).Error
return profiles, err
}
func (r *profileRepository) UpdateFields(uuid string, updates map[string]interface{}) error {
return r.db.Model(&model.Profile{}).
func (r *profileRepository) Update(ctx context.Context, profile *model.Profile) error {
return r.db.WithContext(ctx).Save(profile).Error
}
func (r *profileRepository) UpdateFields(ctx context.Context, uuid string, updates map[string]interface{}) error {
return r.db.WithContext(ctx).Model(&model.Profile{}).
Where("uuid = ?", uuid).
Updates(updates).Error
}
func (r *profileRepository) Delete(uuid string) error {
return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
func (r *profileRepository) Delete(ctx context.Context, uuid string) error {
return r.db.WithContext(ctx).Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
}
func (r *profileRepository) CountByUserID(userID int64) (int64, error) {
func (r *profileRepository) BatchUpdate(ctx context.Context, uuids []string, updates map[string]interface{}) (int64, error) {
if len(uuids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.Profile{}).Where("uuid IN ?", uuids).Updates(updates)
return result.RowsAffected, result.Error
}
func (r *profileRepository) BatchDelete(ctx context.Context, uuids []string) (int64, error) {
if len(uuids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Where("uuid IN ?", uuids).Delete(&model.Profile{})
return result.RowsAffected, result.Error
}
func (r *profileRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
var count int64
err := r.db.Model(&model.Profile{}).
err := r.db.WithContext(ctx).Model(&model.Profile{}).
Where("user_id = ?", userID).
Count(&count).Error
return count, err
}
func (r *profileRepository) SetActive(uuid string, userID int64) error {
return r.db.Transaction(func(tx *gorm.DB) error {
func (r *profileRepository) SetActive(ctx context.Context, uuid string, userID int64) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&model.Profile{}).
Where("user_id = ?", userID).
Update("is_active", false).Error; err != nil {
@@ -94,28 +123,28 @@ func (r *profileRepository) SetActive(uuid string, userID int64) error {
})
}
func (r *profileRepository) UpdateLastUsedAt(uuid string) error {
return r.db.Model(&model.Profile{}).
func (r *profileRepository) UpdateLastUsedAt(ctx context.Context, uuid string) error {
return r.db.WithContext(ctx).Model(&model.Profile{}).
Where("uuid = ?", uuid).
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
}
func (r *profileRepository) GetByNames(names []string) ([]*model.Profile, error) {
func (r *profileRepository) GetByNames(ctx context.Context, names []string) ([]*model.Profile, error) {
var profiles []*model.Profile
err := r.db.Where("name in (?)", names).
err := r.db.WithContext(ctx).Where("name in (?)", names).
Preload("Skin").
Preload("Cape").
Find(&profiles).Error
return profiles, err
}
func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error) {
func (r *profileRepository) GetKeyPair(ctx context.Context, profileId string) (*model.KeyPair, error) {
if profileId == "" {
return nil, errors.New("参数不能为空")
}
var profile model.Profile
result := r.db.WithContext(context.Background()).
result := r.db.WithContext(ctx).
Select("key_pair").
Where("id = ?", profileId).
First(&profile)
@@ -130,7 +159,7 @@ func (r *profileRepository) GetKeyPair(profileId string) (*model.KeyPair, error)
return &model.KeyPair{}, nil
}
func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error {
func (r *profileRepository) UpdateKeyPair(ctx context.Context, profileId string, keyPair *model.KeyPair) error {
if profileId == "" {
return errors.New("profileId 不能为空")
}
@@ -138,9 +167,8 @@ func (r *profileRepository) UpdateKeyPair(profileId string, keyPair *model.KeyPa
return errors.New("keyPair 不能为 nil")
}
return r.db.Transaction(func(tx *gorm.DB) error {
result := tx.WithContext(context.Background()).
Table("profiles").
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
result := tx.Table("profiles").
Where("id = ?", profileId).
UpdateColumns(map[string]interface{}{
"private_key": keyPair.PrivateKey,

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrotskin/internal/model"
"context"
"gorm.io/gorm"
)
@@ -16,28 +17,28 @@ func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository {
return &systemConfigRepository{db: db}
}
func (r *systemConfigRepository) GetByKey(key string) (*model.SystemConfig, error) {
func (r *systemConfigRepository) GetByKey(ctx context.Context, key string) (*model.SystemConfig, error) {
var config model.SystemConfig
err := r.db.Where("key = ?", key).First(&config).Error
err := r.db.WithContext(ctx).Where("key = ?", key).First(&config).Error
return handleNotFoundResult(&config, err)
}
func (r *systemConfigRepository) GetPublic() ([]model.SystemConfig, error) {
func (r *systemConfigRepository) GetPublic(ctx context.Context) ([]model.SystemConfig, error) {
var configs []model.SystemConfig
err := r.db.Where("is_public = ?", true).Find(&configs).Error
err := r.db.WithContext(ctx).Where("is_public = ?", true).Find(&configs).Error
return configs, err
}
func (r *systemConfigRepository) GetAll() ([]model.SystemConfig, error) {
func (r *systemConfigRepository) GetAll(ctx context.Context) ([]model.SystemConfig, error) {
var configs []model.SystemConfig
err := r.db.Find(&configs).Error
err := r.db.WithContext(ctx).Find(&configs).Error
return configs, err
}
func (r *systemConfigRepository) Update(config *model.SystemConfig) error {
return r.db.Save(config).Error
func (r *systemConfigRepository) Update(ctx context.Context, config *model.SystemConfig) error {
return r.db.WithContext(ctx).Save(config).Error
}
func (r *systemConfigRepository) UpdateValue(key, value string) error {
return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
func (r *systemConfigRepository) UpdateValue(ctx context.Context, key, value string) error {
return r.db.WithContext(ctx).Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
}

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrotskin/internal/model"
"context"
"gorm.io/gorm"
)
@@ -16,27 +17,39 @@ func NewTextureRepository(db *gorm.DB) TextureRepository {
return &textureRepository{db: db}
}
func (r *textureRepository) Create(texture *model.Texture) error {
return r.db.Create(texture).Error
func (r *textureRepository) Create(ctx context.Context, texture *model.Texture) error {
return r.db.WithContext(ctx).Create(texture).Error
}
func (r *textureRepository) FindByID(id int64) (*model.Texture, error) {
func (r *textureRepository) FindByID(ctx context.Context, id int64) (*model.Texture, error) {
var texture model.Texture
err := r.db.Preload("Uploader").First(&texture, id).Error
err := r.db.WithContext(ctx).Preload("Uploader").First(&texture, id).Error
return handleNotFoundResult(&texture, err)
}
func (r *textureRepository) FindByHash(hash string) (*model.Texture, error) {
func (r *textureRepository) FindByHash(ctx context.Context, hash string) (*model.Texture, error) {
var texture model.Texture
err := r.db.Where("hash = ?", hash).First(&texture).Error
err := r.db.WithContext(ctx).Where("hash = ?", hash).First(&texture).Error
return handleNotFoundResult(&texture, err)
}
func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
func (r *textureRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.Texture, error) {
if len(ids) == 0 {
return []*model.Texture{}, nil
}
var textures []*model.Texture
// 使用 IN 查询优化批量查询,并预加载关联
err := r.db.WithContext(ctx).Where("id IN ?", ids).
Preload("Uploader").
Find(&textures).Error
return textures, err
}
func (r *textureRepository) FindByUploaderID(ctx context.Context, uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture
var total int64
query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
@@ -54,11 +67,11 @@ func (r *textureRepository) FindByUploaderID(uploaderID int64, page, pageSize in
return textures, total, nil
}
func (r *textureRepository) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
func (r *textureRepository) Search(ctx context.Context, keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture
var total int64
query := r.db.Model(&model.Texture{}).Where("status = 1")
query := r.db.WithContext(ctx).Model(&model.Texture{}).Where("status = 1")
if publicOnly {
query = query.Where("is_public = ?", true)
@@ -86,67 +99,86 @@ func (r *textureRepository) Search(keyword string, textureType model.TextureType
return textures, total, nil
}
func (r *textureRepository) Update(texture *model.Texture) error {
return r.db.Save(texture).Error
func (r *textureRepository) Update(ctx context.Context, texture *model.Texture) error {
return r.db.WithContext(ctx).Save(texture).Error
}
func (r *textureRepository) UpdateFields(id int64, fields map[string]interface{}) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
func (r *textureRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
}
func (r *textureRepository) Delete(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
func (r *textureRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
}
func (r *textureRepository) IncrementDownloadCount(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).
func (r *textureRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Updates(fields)
return result.RowsAffected, result.Error
}
func (r *textureRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.Texture{}).Where("id IN ?", ids).Update("status", -1)
return result.RowsAffected, result.Error
}
func (r *textureRepository) IncrementDownloadCount(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
}
func (r *textureRepository) IncrementFavoriteCount(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).
func (r *textureRepository) IncrementFavoriteCount(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
}
func (r *textureRepository) DecrementFavoriteCount(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).
func (r *textureRepository) DecrementFavoriteCount(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
}
func (r *textureRepository) CreateDownloadLog(log *model.TextureDownloadLog) error {
return r.db.Create(log).Error
func (r *textureRepository) CreateDownloadLog(ctx context.Context, log *model.TextureDownloadLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
func (r *textureRepository) IsFavorited(userID, textureID int64) (bool, error) {
func (r *textureRepository) IsFavorited(ctx context.Context, userID, textureID int64) (bool, error) {
var count int64
err := r.db.Model(&model.UserTextureFavorite{}).
// 使用 Select("1") 优化,只查询是否存在,不需要查询所有字段
err := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}).
Select("1").
Where("user_id = ? AND texture_id = ?", userID, textureID).
Limit(1).
Count(&count).Error
return count > 0, err
}
func (r *textureRepository) AddFavorite(userID, textureID int64) error {
func (r *textureRepository) AddFavorite(ctx context.Context, userID, textureID int64) error {
favorite := &model.UserTextureFavorite{
UserID: userID,
TextureID: textureID,
}
return r.db.Create(favorite).Error
return r.db.WithContext(ctx).Create(favorite).Error
}
func (r *textureRepository) RemoveFavorite(userID, textureID int64) error {
return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID).
func (r *textureRepository) RemoveFavorite(ctx context.Context, userID, textureID int64) error {
return r.db.WithContext(ctx).Where("user_id = ? AND texture_id = ?", userID, textureID).
Delete(&model.UserTextureFavorite{}).Error
}
func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
func (r *textureRepository) GetUserFavorites(ctx context.Context, userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture
var total int64
subQuery := r.db.Model(&model.UserTextureFavorite{}).
subQuery := r.db.WithContext(ctx).Model(&model.UserTextureFavorite{}).
Select("texture_id").
Where("user_id = ?", userID)
query := r.db.Model(&model.Texture{}).
query := r.db.WithContext(ctx).Model(&model.Texture{}).
Where("id IN (?) AND status = 1", subQuery)
if err := query.Count(&total).Error; err != nil {
@@ -165,9 +197,9 @@ func (r *textureRepository) GetUserFavorites(userID int64, page, pageSize int) (
return textures, total, nil
}
func (r *textureRepository) CountByUploaderID(uploaderID int64) (int64, error) {
func (r *textureRepository) CountByUploaderID(ctx context.Context, uploaderID int64) (int64, error) {
var count int64
err := r.db.Model(&model.Texture{}).
err := r.db.WithContext(ctx).Model(&model.Texture{}).
Where("uploader_id = ? AND status != -1", uploaderID).
Count(&count).Error
return count, err

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrotskin/internal/model"
"context"
"gorm.io/gorm"
)
@@ -16,55 +17,55 @@ func NewTokenRepository(db *gorm.DB) TokenRepository {
return &tokenRepository{db: db}
}
func (r *tokenRepository) Create(token *model.Token) error {
return r.db.Create(token).Error
func (r *tokenRepository) Create(ctx context.Context, token *model.Token) error {
return r.db.WithContext(ctx).Create(token).Error
}
func (r *tokenRepository) FindByAccessToken(accessToken string) (*model.Token, error) {
func (r *tokenRepository) FindByAccessToken(ctx context.Context, accessToken string) (*model.Token, error) {
var token model.Token
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
err := r.db.WithContext(ctx).Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return nil, err
}
return &token, nil
}
func (r *tokenRepository) GetByUserID(userId int64) ([]*model.Token, error) {
func (r *tokenRepository) GetByUserID(ctx context.Context, userId int64) ([]*model.Token, error) {
var tokens []*model.Token
err := r.db.Where("user_id = ?", userId).Find(&tokens).Error
err := r.db.WithContext(ctx).Where("user_id = ?", userId).Find(&tokens).Error
return tokens, err
}
func (r *tokenRepository) GetUUIDByAccessToken(accessToken string) (string, error) {
func (r *tokenRepository) GetUUIDByAccessToken(ctx context.Context, accessToken string) (string, error) {
var token model.Token
err := r.db.Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
err := r.db.WithContext(ctx).Select("profile_id").Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return "", err
}
return token.ProfileId, nil
}
func (r *tokenRepository) GetUserIDByAccessToken(accessToken string) (int64, error) {
func (r *tokenRepository) GetUserIDByAccessToken(ctx context.Context, accessToken string) (int64, error) {
var token model.Token
err := r.db.Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
err := r.db.WithContext(ctx).Select("user_id").Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return 0, err
}
return token.UserID, nil
}
func (r *tokenRepository) DeleteByAccessToken(accessToken string) error {
return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
func (r *tokenRepository) DeleteByAccessToken(ctx context.Context, accessToken string) error {
return r.db.WithContext(ctx).Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
}
func (r *tokenRepository) DeleteByUserID(userId int64) error {
return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
func (r *tokenRepository) DeleteByUserID(ctx context.Context, userId int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userId).Delete(&model.Token{}).Error
}
func (r *tokenRepository) BatchDelete(accessTokens []string) (int64, error) {
func (r *tokenRepository) BatchDelete(ctx context.Context, accessTokens []string) (int64, error) {
if len(accessTokens) == 0 {
return 0, nil
}
result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{})
result := r.db.WithContext(ctx).Where("access_token IN ?", accessTokens).Delete(&model.Token{})
return result.RowsAffected, result.Error
}

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrotskin/internal/model"
"context"
"errors"
"gorm.io/gorm"
@@ -17,50 +18,76 @@ func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepository{db: db}
}
func (r *userRepository) Create(user *model.User) error {
return r.db.Create(user).Error
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Create(user).Error
}
func (r *userRepository) FindByID(id int64) (*model.User, error) {
func (r *userRepository) FindByID(ctx context.Context, id int64) (*model.User, error) {
var user model.User
err := r.db.Where("id = ? AND status != -1", id).First(&user).Error
err := r.db.WithContext(ctx).Where("id = ? AND status != -1", id).First(&user).Error
return handleNotFoundResult(&user, err)
}
func (r *userRepository) FindByUsername(username string) (*model.User, error) {
func (r *userRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
var user model.User
err := r.db.Where("username = ? AND status != -1", username).First(&user).Error
err := r.db.WithContext(ctx).Where("username = ? AND status != -1", username).First(&user).Error
return handleNotFoundResult(&user, err)
}
func (r *userRepository) FindByEmail(email string) (*model.User, error) {
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
var user model.User
err := r.db.Where("email = ? AND status != -1", email).First(&user).Error
err := r.db.WithContext(ctx).Where("email = ? AND status != -1", email).First(&user).Error
return handleNotFoundResult(&user, err)
}
func (r *userRepository) Update(user *model.User) error {
return r.db.Save(user).Error
func (r *userRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) {
if len(ids) == 0 {
return []*model.User{}, nil
}
var users []*model.User
// 使用 IN 查询优化批量查询
err := r.db.WithContext(ctx).Where("id IN ? AND status != -1", ids).Find(&users).Error
return users, err
}
func (r *userRepository) UpdateFields(id int64, fields map[string]interface{}) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
return r.db.WithContext(ctx).Save(user).Error
}
func (r *userRepository) Delete(id int64) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
func (r *userRepository) UpdateFields(ctx context.Context, id int64, fields map[string]interface{}) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
}
func (r *userRepository) CreateLoginLog(log *model.UserLoginLog) error {
return r.db.Create(log).Error
func (r *userRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
}
func (r *userRepository) CreatePointLog(log *model.UserPointLog) error {
return r.db.Create(log).Error
func (r *userRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]interface{}) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Updates(fields)
return result.RowsAffected, result.Error
}
func (r *userRepository) UpdatePoints(userID int64, amount int, changeType, reason string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
func (r *userRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
if len(ids) == 0 {
return 0, nil
}
result := r.db.WithContext(ctx).Model(&model.User{}).Where("id IN ?", ids).Update("status", -1)
return result.RowsAffected, result.Error
}
func (r *userRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
func (r *userRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
func (r *userRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var user model.User
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
return err

View File

@@ -2,6 +2,7 @@ package repository
import (
"carrotskin/internal/model"
"context"
"gorm.io/gorm"
)
@@ -16,15 +17,15 @@ func NewYggdrasilRepository(db *gorm.DB) YggdrasilRepository {
return &yggdrasilRepository{db: db}
}
func (r *yggdrasilRepository) GetPasswordByID(id int64) (string, error) {
func (r *yggdrasilRepository) GetPasswordByID(ctx context.Context, id int64) (string, error) {
var yggdrasil model.Yggdrasil
err := r.db.Select("password").Where("id = ?", id).First(&yggdrasil).Error
err := r.db.WithContext(ctx).Select("password").Where("id = ?", id).First(&yggdrasil).Error
if err != nil {
return "", err
}
return yggdrasil.Password, nil
}
func (r *yggdrasilRepository) ResetPassword(id int64, password string) error {
return r.db.Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error
func (r *yggdrasilRepository) ResetPassword(ctx context.Context, id int64, password string) error {
return r.db.WithContext(ctx).Model(&model.Yggdrasil{}).Where("id = ?", id).Update("password", password).Error
}