feat: 引入依赖注入模式
- 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等) - 创建Repository接口实现 - 创建依赖注入容器(container.Container) - 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler) - 创建新的路由注册方式(RegisterRoutesWithDI) - 提供main.go示例文件展示如何使用依赖注入 同时包含之前的安全修复: - CORS配置安全加固 - 头像URL验证安全修复 - JWT algorithm confusion漏洞修复 - Recovery中间件增强 - 敏感错误信息泄露修复 - 类型断言安全修复
This commit is contained in:
85
internal/repository/interfaces.go
Normal file
85
internal/repository/interfaces.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储接口
|
||||
type UserRepository interface {
|
||||
Create(user *model.User) error
|
||||
FindByID(id int64) (*model.User, error)
|
||||
FindByUsername(username string) (*model.User, error)
|
||||
FindByEmail(email string) (*model.User, error)
|
||||
Update(user *model.User) error
|
||||
UpdateFields(id int64, fields map[string]interface{}) error
|
||||
Delete(id int64) error
|
||||
CreateLoginLog(log *model.UserLoginLog) error
|
||||
CreatePointLog(log *model.UserPointLog) error
|
||||
UpdatePoints(userID int64, amount int, changeType, reason string) error
|
||||
}
|
||||
|
||||
// ProfileRepository 档案仓储接口
|
||||
type ProfileRepository interface {
|
||||
Create(profile *model.Profile) error
|
||||
FindByUUID(uuid string) (*model.Profile, error)
|
||||
FindByName(name string) (*model.Profile, error)
|
||||
FindByUserID(userID int64) ([]*model.Profile, error)
|
||||
Update(profile *model.Profile) error
|
||||
UpdateFields(uuid string, updates map[string]interface{}) error
|
||||
Delete(uuid string) error
|
||||
CountByUserID(userID int64) (int64, error)
|
||||
SetActive(uuid string, userID int64) error
|
||||
UpdateLastUsedAt(uuid string) error
|
||||
GetByNames(names []string) ([]*model.Profile, error)
|
||||
GetKeyPair(profileId string) (*model.KeyPair, error)
|
||||
UpdateKeyPair(profileId string, keyPair *model.KeyPair) error
|
||||
}
|
||||
|
||||
// TextureRepository 材质仓储接口
|
||||
type TextureRepository interface {
|
||||
Create(texture *model.Texture) error
|
||||
FindByID(id int64) (*model.Texture, error)
|
||||
FindByHash(hash string) (*model.Texture, error)
|
||||
FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
Update(texture *model.Texture) error
|
||||
UpdateFields(id int64, fields map[string]interface{}) error
|
||||
Delete(id int64) error
|
||||
IncrementDownloadCount(id int64) error
|
||||
IncrementFavoriteCount(id int64) error
|
||||
DecrementFavoriteCount(id int64) error
|
||||
CreateDownloadLog(log *model.TextureDownloadLog) error
|
||||
IsFavorited(userID, textureID int64) (bool, error)
|
||||
AddFavorite(userID, textureID int64) error
|
||||
RemoveFavorite(userID, textureID int64) error
|
||||
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error)
|
||||
CountByUploaderID(uploaderID int64) (int64, error)
|
||||
}
|
||||
|
||||
// TokenRepository 令牌仓储接口
|
||||
type TokenRepository interface {
|
||||
Create(token *model.Token) error
|
||||
FindByAccessToken(accessToken string) (*model.Token, error)
|
||||
GetByUserID(userId int64) ([]*model.Token, error)
|
||||
GetUUIDByAccessToken(accessToken string) (string, error)
|
||||
GetUserIDByAccessToken(accessToken string) (int64, error)
|
||||
DeleteByAccessToken(accessToken string) error
|
||||
DeleteByUserID(userId int64) error
|
||||
BatchDelete(accessTokens []string) (int64, error)
|
||||
}
|
||||
|
||||
// SystemConfigRepository 系统配置仓储接口
|
||||
type SystemConfigRepository interface {
|
||||
GetByKey(key string) (*model.SystemConfig, error)
|
||||
GetPublic() ([]model.SystemConfig, error)
|
||||
GetAll() ([]model.SystemConfig, error)
|
||||
Update(config *model.SystemConfig) error
|
||||
UpdateValue(key, value string) error
|
||||
}
|
||||
|
||||
// YggdrasilRepository Yggdrasil仓储接口
|
||||
type YggdrasilRepository interface {
|
||||
GetPasswordByID(id int64) (string, error)
|
||||
ResetPassword(id int64, password string) error
|
||||
}
|
||||
|
||||
149
internal/repository/profile_repository_impl.go
Normal file
149
internal/repository/profile_repository_impl.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// profileRepositoryImpl ProfileRepository的实现
|
||||
type profileRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewProfileRepository 创建ProfileRepository实例
|
||||
func NewProfileRepository(db *gorm.DB) ProfileRepository {
|
||||
return &profileRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) Create(profile *model.Profile) error {
|
||||
return r.db.Create(profile).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) FindByUUID(uuid string) (*model.Profile, error) {
|
||||
var profile model.Profile
|
||||
err := r.db.Where("uuid = ?", uuid).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) FindByName(name string) (*model.Profile, error) {
|
||||
var profile model.Profile
|
||||
err := r.db.Where("name = ?", name).First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) FindByUserID(userID int64) ([]*model.Profile, error) {
|
||||
var profiles []*model.Profile
|
||||
err := r.db.Where("user_id = ?", userID).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Order("created_at DESC").
|
||||
Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) Update(profile *model.Profile) error {
|
||||
return r.db.Save(profile).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) UpdateFields(uuid string, updates map[string]interface{}) error {
|
||||
return r.db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) Delete(uuid string) error {
|
||||
return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) CountByUserID(userID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) SetActive(uuid string, userID int64) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Model(&model.Profile{}).
|
||||
Where("uuid = ? AND user_id = ?", uuid, userID).
|
||||
Update("is_active", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) UpdateLastUsedAt(uuid string) error {
|
||||
return r.db.Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) GetByNames(names []string) ([]*model.Profile, error) {
|
||||
var profiles []*model.Profile
|
||||
err := r.db.Where("name in (?)", names).Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) GetKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
if profileId == "" {
|
||||
return nil, errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
var profile model.Profile
|
||||
result := r.db.WithContext(context.Background()).
|
||||
Select("key_pair").
|
||||
Where("id = ?", profileId).
|
||||
First(&profile)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("key pair未找到")
|
||||
}
|
||||
return nil, fmt.Errorf("获取key pair失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &model.KeyPair{}, nil
|
||||
}
|
||||
|
||||
func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
if profileId == "" {
|
||||
return errors.New("profileId 不能为空")
|
||||
}
|
||||
if keyPair == nil {
|
||||
return errors.New("keyPair 不能为 nil")
|
||||
}
|
||||
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.WithContext(context.Background()).
|
||||
Table("profiles").
|
||||
Where("id = ?", profileId).
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"private_key": keyPair.PrivateKey,
|
||||
"public_key": keyPair.PublicKey,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新 keyPair 失败: %w", result.Error)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
44
internal/repository/system_config_repository_impl.go
Normal file
44
internal/repository/system_config_repository_impl.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// systemConfigRepositoryImpl SystemConfigRepository的实现
|
||||
type systemConfigRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSystemConfigRepository 创建SystemConfigRepository实例
|
||||
func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository {
|
||||
return &systemConfigRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) GetByKey(key string) (*model.SystemConfig, error) {
|
||||
var config model.SystemConfig
|
||||
err := r.db.Where("key = ?", key).First(&config).Error
|
||||
return handleNotFoundResult(&config, err)
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) GetPublic() ([]model.SystemConfig, error) {
|
||||
var configs []model.SystemConfig
|
||||
err := r.db.Where("is_public = ?", true).Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) GetAll() ([]model.SystemConfig, error) {
|
||||
var configs []model.SystemConfig
|
||||
err := r.db.Find(&configs).Error
|
||||
return configs, err
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) Update(config *model.SystemConfig) error {
|
||||
return r.db.Save(config).Error
|
||||
}
|
||||
|
||||
func (r *systemConfigRepositoryImpl) UpdateValue(key, value string) error {
|
||||
return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
|
||||
}
|
||||
|
||||
175
internal/repository/texture_repository_impl.go
Normal file
175
internal/repository/texture_repository_impl.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// textureRepositoryImpl TextureRepository的实现
|
||||
type textureRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTextureRepository 创建TextureRepository实例
|
||||
func NewTextureRepository(db *gorm.DB) TextureRepository {
|
||||
return &textureRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Create(texture *model.Texture) error {
|
||||
return r.db.Create(texture).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) FindByID(id int64) (*model.Texture, error) {
|
||||
var texture model.Texture
|
||||
err := r.db.Preload("Uploader").First(&texture, id).Error
|
||||
return handleNotFoundResult(&texture, err)
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) FindByHash(hash string) (*model.Texture, error) {
|
||||
var texture model.Texture
|
||||
err := r.db.Where("hash = ?", hash).First(&texture).Error
|
||||
return handleNotFoundResult(&texture, err)
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Scopes(Paginate(page, pageSize)).
|
||||
Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Texture{}).Where("status = 1")
|
||||
|
||||
if publicOnly {
|
||||
query = query.Where("is_public = ?", true)
|
||||
}
|
||||
if textureType != "" {
|
||||
query = query.Where("type = ?", textureType)
|
||||
}
|
||||
if keyword != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Scopes(Paginate(page, pageSize)).
|
||||
Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Update(texture *model.Texture) error {
|
||||
return r.db.Save(texture).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) Delete(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) IncrementDownloadCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) IncrementFavoriteCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) DecrementFavoriteCount(id int64) error {
|
||||
return r.db.Model(&model.Texture{}).Where("id = ?", id).
|
||||
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) CreateDownloadLog(log *model.TextureDownloadLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) IsFavorited(userID, textureID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.UserTextureFavorite{}).
|
||||
Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) AddFavorite(userID, textureID int64) error {
|
||||
favorite := &model.UserTextureFavorite{
|
||||
UserID: userID,
|
||||
TextureID: textureID,
|
||||
}
|
||||
return r.db.Create(favorite).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) RemoveFavorite(userID, textureID int64) error {
|
||||
return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID).
|
||||
Delete(&model.UserTextureFavorite{}).Error
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
|
||||
var textures []*model.Texture
|
||||
var total int64
|
||||
|
||||
subQuery := r.db.Model(&model.UserTextureFavorite{}).
|
||||
Select("texture_id").
|
||||
Where("user_id = ?", userID)
|
||||
|
||||
query := r.db.Model(&model.Texture{}).
|
||||
Where("id IN (?) AND status = 1", subQuery)
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err := query.Scopes(Paginate(page, pageSize)).
|
||||
Preload("Uploader").
|
||||
Order("created_at DESC").
|
||||
Find(&textures).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return textures, total, nil
|
||||
}
|
||||
|
||||
func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Texture{}).
|
||||
Where("uploader_id = ? AND status != -1", uploaderID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
71
internal/repository/token_repository_impl.go
Normal file
71
internal/repository/token_repository_impl.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// tokenRepositoryImpl TokenRepository的实现
|
||||
type tokenRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewTokenRepository 创建TokenRepository实例
|
||||
func NewTokenRepository(db *gorm.DB) TokenRepository {
|
||||
return &tokenRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) Create(token *model.Token) error {
|
||||
return r.db.Create(token).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) FindByAccessToken(accessToken string) (*model.Token, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) GetByUserID(userId int64) ([]*model.Token, error) {
|
||||
var tokens []*model.Token
|
||||
err := r.db.Where("user_id = ?", userId).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) GetUUIDByAccessToken(accessToken string) (string, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token.ProfileId, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) GetUserIDByAccessToken(accessToken string) (int64, error) {
|
||||
var token model.Token
|
||||
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return token.UserID, nil
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) DeleteByAccessToken(accessToken string) error {
|
||||
return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) DeleteByUserID(userId int64) error {
|
||||
return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
|
||||
}
|
||||
|
||||
func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) {
|
||||
if len(accessTokens) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
103
internal/repository/user_repository_impl.go
Normal file
103
internal/repository/user_repository_impl.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// userRepositoryImpl UserRepository的实现
|
||||
type userRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建UserRepository实例
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepositoryImpl{db: db}
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) Create(user *model.User) error {
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) FindByID(id int64) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("id = ? AND status != -1", id).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) FindByUsername(username string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("username = ? AND status != -1", username).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) FindByEmail(email string) (*model.User, error) {
|
||||
var user model.User
|
||||
err := r.db.Where("email = ? AND status != -1", email).First(&user).Error
|
||||
return handleNotFoundResult(&user, err)
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) Update(user *model.User) error {
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) Delete(id int64) error {
|
||||
return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) CreateLoginLog(log *model.UserLoginLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) CreatePointLog(log *model.UserPointLog) error {
|
||||
return r.db.Create(log).Error
|
||||
}
|
||||
|
||||
func (r *userRepositoryImpl) UpdatePoints(userID int64, amount int, changeType, reason string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
var user model.User
|
||||
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
balanceBefore := user.Points
|
||||
balanceAfter := balanceBefore + amount
|
||||
|
||||
if balanceAfter < 0 {
|
||||
return errors.New("积分不足")
|
||||
}
|
||||
|
||||
if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log := &model.UserPointLog{
|
||||
UserID: userID,
|
||||
ChangeType: changeType,
|
||||
Amount: amount,
|
||||
BalanceBefore: balanceBefore,
|
||||
BalanceAfter: balanceAfter,
|
||||
Reason: reason,
|
||||
}
|
||||
|
||||
return tx.Create(log).Error
|
||||
})
|
||||
}
|
||||
|
||||
// handleNotFoundResult 处理记录未找到的情况
|
||||
func handleNotFoundResult[T any](result *T, err error) (*T, error) {
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user