feat: 引入依赖注入模式

- 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等)
- 创建Repository接口实现
- 创建依赖注入容器(container.Container)
- 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler)
- 创建新的路由注册方式(RegisterRoutesWithDI)
- 提供main.go示例文件展示如何使用依赖注入

同时包含之前的安全修复:
- CORS配置安全加固
- 头像URL验证安全修复
- JWT algorithm confusion漏洞修复
- Recovery中间件增强
- 敏感错误信息泄露修复
- 类型断言安全修复
This commit is contained in:
lan
2025-12-02 17:40:39 +08:00
parent 373c61f625
commit f7589ebbb8
25 changed files with 2029 additions and 139 deletions

View File

@@ -0,0 +1,85 @@
package repository
import (
"carrotskin/internal/model"
)
// UserRepository 用户仓储接口
type UserRepository interface {
Create(user *model.User) error
FindByID(id int64) (*model.User, error)
FindByUsername(username string) (*model.User, error)
FindByEmail(email string) (*model.User, error)
Update(user *model.User) error
UpdateFields(id int64, fields map[string]interface{}) error
Delete(id int64) error
CreateLoginLog(log *model.UserLoginLog) error
CreatePointLog(log *model.UserPointLog) error
UpdatePoints(userID int64, amount int, changeType, reason string) error
}
// ProfileRepository 档案仓储接口
type ProfileRepository interface {
Create(profile *model.Profile) error
FindByUUID(uuid string) (*model.Profile, error)
FindByName(name string) (*model.Profile, error)
FindByUserID(userID int64) ([]*model.Profile, error)
Update(profile *model.Profile) error
UpdateFields(uuid string, updates map[string]interface{}) error
Delete(uuid string) error
CountByUserID(userID int64) (int64, error)
SetActive(uuid string, userID int64) error
UpdateLastUsedAt(uuid string) error
GetByNames(names []string) ([]*model.Profile, error)
GetKeyPair(profileId string) (*model.KeyPair, error)
UpdateKeyPair(profileId string, keyPair *model.KeyPair) error
}
// TextureRepository 材质仓储接口
type TextureRepository interface {
Create(texture *model.Texture) error
FindByID(id int64) (*model.Texture, error)
FindByHash(hash string) (*model.Texture, error)
FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error)
Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error)
Update(texture *model.Texture) error
UpdateFields(id int64, fields map[string]interface{}) error
Delete(id int64) error
IncrementDownloadCount(id int64) error
IncrementFavoriteCount(id int64) error
DecrementFavoriteCount(id int64) error
CreateDownloadLog(log *model.TextureDownloadLog) error
IsFavorited(userID, textureID int64) (bool, error)
AddFavorite(userID, textureID int64) error
RemoveFavorite(userID, textureID int64) error
GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error)
CountByUploaderID(uploaderID int64) (int64, error)
}
// TokenRepository 令牌仓储接口
type TokenRepository interface {
Create(token *model.Token) error
FindByAccessToken(accessToken string) (*model.Token, error)
GetByUserID(userId int64) ([]*model.Token, error)
GetUUIDByAccessToken(accessToken string) (string, error)
GetUserIDByAccessToken(accessToken string) (int64, error)
DeleteByAccessToken(accessToken string) error
DeleteByUserID(userId int64) error
BatchDelete(accessTokens []string) (int64, error)
}
// SystemConfigRepository 系统配置仓储接口
type SystemConfigRepository interface {
GetByKey(key string) (*model.SystemConfig, error)
GetPublic() ([]model.SystemConfig, error)
GetAll() ([]model.SystemConfig, error)
Update(config *model.SystemConfig) error
UpdateValue(key, value string) error
}
// YggdrasilRepository Yggdrasil仓储接口
type YggdrasilRepository interface {
GetPasswordByID(id int64) (string, error)
ResetPassword(id int64, password string) error
}

View File

@@ -0,0 +1,149 @@
package repository
import (
"carrotskin/internal/model"
"context"
"errors"
"fmt"
"gorm.io/gorm"
)
// profileRepositoryImpl ProfileRepository的实现
type profileRepositoryImpl struct {
db *gorm.DB
}
// NewProfileRepository 创建ProfileRepository实例
func NewProfileRepository(db *gorm.DB) ProfileRepository {
return &profileRepositoryImpl{db: db}
}
func (r *profileRepositoryImpl) Create(profile *model.Profile) error {
return r.db.Create(profile).Error
}
func (r *profileRepositoryImpl) FindByUUID(uuid string) (*model.Profile, error) {
var profile model.Profile
err := r.db.Where("uuid = ?", uuid).
Preload("Skin").
Preload("Cape").
First(&profile).Error
if err != nil {
return nil, err
}
return &profile, nil
}
func (r *profileRepositoryImpl) FindByName(name string) (*model.Profile, error) {
var profile model.Profile
err := r.db.Where("name = ?", name).First(&profile).Error
if err != nil {
return nil, err
}
return &profile, nil
}
func (r *profileRepositoryImpl) FindByUserID(userID int64) ([]*model.Profile, error) {
var profiles []*model.Profile
err := r.db.Where("user_id = ?", userID).
Preload("Skin").
Preload("Cape").
Order("created_at DESC").
Find(&profiles).Error
return profiles, err
}
func (r *profileRepositoryImpl) Update(profile *model.Profile) error {
return r.db.Save(profile).Error
}
func (r *profileRepositoryImpl) UpdateFields(uuid string, updates map[string]interface{}) error {
return r.db.Model(&model.Profile{}).
Where("uuid = ?", uuid).
Updates(updates).Error
}
func (r *profileRepositoryImpl) Delete(uuid string) error {
return r.db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
}
func (r *profileRepositoryImpl) CountByUserID(userID int64) (int64, error) {
var count int64
err := r.db.Model(&model.Profile{}).
Where("user_id = ?", userID).
Count(&count).Error
return count, err
}
func (r *profileRepositoryImpl) SetActive(uuid string, userID int64) error {
return r.db.Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&model.Profile{}).
Where("user_id = ?", userID).
Update("is_active", false).Error; err != nil {
return err
}
return tx.Model(&model.Profile{}).
Where("uuid = ? AND user_id = ?", uuid, userID).
Update("is_active", true).Error
})
}
func (r *profileRepositoryImpl) UpdateLastUsedAt(uuid string) error {
return r.db.Model(&model.Profile{}).
Where("uuid = ?", uuid).
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
}
func (r *profileRepositoryImpl) GetByNames(names []string) ([]*model.Profile, error) {
var profiles []*model.Profile
err := r.db.Where("name in (?)", names).Find(&profiles).Error
return profiles, err
}
func (r *profileRepositoryImpl) GetKeyPair(profileId string) (*model.KeyPair, error) {
if profileId == "" {
return nil, errors.New("参数不能为空")
}
var profile model.Profile
result := r.db.WithContext(context.Background()).
Select("key_pair").
Where("id = ?", profileId).
First(&profile)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, errors.New("key pair未找到")
}
return nil, fmt.Errorf("获取key pair失败: %w", result.Error)
}
return &model.KeyPair{}, nil
}
func (r *profileRepositoryImpl) UpdateKeyPair(profileId string, keyPair *model.KeyPair) error {
if profileId == "" {
return errors.New("profileId 不能为空")
}
if keyPair == nil {
return errors.New("keyPair 不能为 nil")
}
return r.db.Transaction(func(tx *gorm.DB) error {
result := tx.WithContext(context.Background()).
Table("profiles").
Where("id = ?", profileId).
UpdateColumns(map[string]interface{}{
"private_key": keyPair.PrivateKey,
"public_key": keyPair.PublicKey,
})
if result.Error != nil {
return fmt.Errorf("更新 keyPair 失败: %w", result.Error)
}
return nil
})
}

View File

@@ -0,0 +1,44 @@
package repository
import (
"carrotskin/internal/model"
"gorm.io/gorm"
)
// systemConfigRepositoryImpl SystemConfigRepository的实现
type systemConfigRepositoryImpl struct {
db *gorm.DB
}
// NewSystemConfigRepository 创建SystemConfigRepository实例
func NewSystemConfigRepository(db *gorm.DB) SystemConfigRepository {
return &systemConfigRepositoryImpl{db: db}
}
func (r *systemConfigRepositoryImpl) GetByKey(key string) (*model.SystemConfig, error) {
var config model.SystemConfig
err := r.db.Where("key = ?", key).First(&config).Error
return handleNotFoundResult(&config, err)
}
func (r *systemConfigRepositoryImpl) GetPublic() ([]model.SystemConfig, error) {
var configs []model.SystemConfig
err := r.db.Where("is_public = ?", true).Find(&configs).Error
return configs, err
}
func (r *systemConfigRepositoryImpl) GetAll() ([]model.SystemConfig, error) {
var configs []model.SystemConfig
err := r.db.Find(&configs).Error
return configs, err
}
func (r *systemConfigRepositoryImpl) Update(config *model.SystemConfig) error {
return r.db.Save(config).Error
}
func (r *systemConfigRepositoryImpl) UpdateValue(key, value string) error {
return r.db.Model(&model.SystemConfig{}).Where("key = ?", key).Update("value", value).Error
}

View File

@@ -0,0 +1,175 @@
package repository
import (
"carrotskin/internal/model"
"gorm.io/gorm"
)
// textureRepositoryImpl TextureRepository的实现
type textureRepositoryImpl struct {
db *gorm.DB
}
// NewTextureRepository 创建TextureRepository实例
func NewTextureRepository(db *gorm.DB) TextureRepository {
return &textureRepositoryImpl{db: db}
}
func (r *textureRepositoryImpl) Create(texture *model.Texture) error {
return r.db.Create(texture).Error
}
func (r *textureRepositoryImpl) FindByID(id int64) (*model.Texture, error) {
var texture model.Texture
err := r.db.Preload("Uploader").First(&texture, id).Error
return handleNotFoundResult(&texture, err)
}
func (r *textureRepositoryImpl) FindByHash(hash string) (*model.Texture, error) {
var texture model.Texture
err := r.db.Where("hash = ?", hash).First(&texture).Error
return handleNotFoundResult(&texture, err)
}
func (r *textureRepositoryImpl) FindByUploaderID(uploaderID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture
var total int64
query := r.db.Model(&model.Texture{}).Where("uploader_id = ? AND status != -1", uploaderID)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
err := query.Scopes(Paginate(page, pageSize)).
Preload("Uploader").
Order("created_at DESC").
Find(&textures).Error
if err != nil {
return nil, 0, err
}
return textures, total, nil
}
func (r *textureRepositoryImpl) Search(keyword string, textureType model.TextureType, publicOnly bool, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture
var total int64
query := r.db.Model(&model.Texture{}).Where("status = 1")
if publicOnly {
query = query.Where("is_public = ?", true)
}
if textureType != "" {
query = query.Where("type = ?", textureType)
}
if keyword != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%")
}
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
err := query.Scopes(Paginate(page, pageSize)).
Preload("Uploader").
Order("created_at DESC").
Find(&textures).Error
if err != nil {
return nil, 0, err
}
return textures, total, nil
}
func (r *textureRepositoryImpl) Update(texture *model.Texture) error {
return r.db.Save(texture).Error
}
func (r *textureRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).Updates(fields).Error
}
func (r *textureRepositoryImpl) Delete(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).Update("status", -1).Error
}
func (r *textureRepositoryImpl) IncrementDownloadCount(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
}
func (r *textureRepositoryImpl) IncrementFavoriteCount(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("favorite_count", gorm.Expr("favorite_count + ?", 1)).Error
}
func (r *textureRepositoryImpl) DecrementFavoriteCount(id int64) error {
return r.db.Model(&model.Texture{}).Where("id = ?", id).
UpdateColumn("favorite_count", gorm.Expr("favorite_count - ?", 1)).Error
}
func (r *textureRepositoryImpl) CreateDownloadLog(log *model.TextureDownloadLog) error {
return r.db.Create(log).Error
}
func (r *textureRepositoryImpl) IsFavorited(userID, textureID int64) (bool, error) {
var count int64
err := r.db.Model(&model.UserTextureFavorite{}).
Where("user_id = ? AND texture_id = ?", userID, textureID).
Count(&count).Error
return count > 0, err
}
func (r *textureRepositoryImpl) AddFavorite(userID, textureID int64) error {
favorite := &model.UserTextureFavorite{
UserID: userID,
TextureID: textureID,
}
return r.db.Create(favorite).Error
}
func (r *textureRepositoryImpl) RemoveFavorite(userID, textureID int64) error {
return r.db.Where("user_id = ? AND texture_id = ?", userID, textureID).
Delete(&model.UserTextureFavorite{}).Error
}
func (r *textureRepositoryImpl) GetUserFavorites(userID int64, page, pageSize int) ([]*model.Texture, int64, error) {
var textures []*model.Texture
var total int64
subQuery := r.db.Model(&model.UserTextureFavorite{}).
Select("texture_id").
Where("user_id = ?", userID)
query := r.db.Model(&model.Texture{}).
Where("id IN (?) AND status = 1", subQuery)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
err := query.Scopes(Paginate(page, pageSize)).
Preload("Uploader").
Order("created_at DESC").
Find(&textures).Error
if err != nil {
return nil, 0, err
}
return textures, total, nil
}
func (r *textureRepositoryImpl) CountByUploaderID(uploaderID int64) (int64, error) {
var count int64
err := r.db.Model(&model.Texture{}).
Where("uploader_id = ? AND status != -1", uploaderID).
Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,71 @@
package repository
import (
"carrotskin/internal/model"
"gorm.io/gorm"
)
// tokenRepositoryImpl TokenRepository的实现
type tokenRepositoryImpl struct {
db *gorm.DB
}
// NewTokenRepository 创建TokenRepository实例
func NewTokenRepository(db *gorm.DB) TokenRepository {
return &tokenRepositoryImpl{db: db}
}
func (r *tokenRepositoryImpl) Create(token *model.Token) error {
return r.db.Create(token).Error
}
func (r *tokenRepositoryImpl) FindByAccessToken(accessToken string) (*model.Token, error) {
var token model.Token
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return nil, err
}
return &token, nil
}
func (r *tokenRepositoryImpl) GetByUserID(userId int64) ([]*model.Token, error) {
var tokens []*model.Token
err := r.db.Where("user_id = ?", userId).Find(&tokens).Error
return tokens, err
}
func (r *tokenRepositoryImpl) GetUUIDByAccessToken(accessToken string) (string, error) {
var token model.Token
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return "", err
}
return token.ProfileId, nil
}
func (r *tokenRepositoryImpl) GetUserIDByAccessToken(accessToken string) (int64, error) {
var token model.Token
err := r.db.Where("access_token = ?", accessToken).First(&token).Error
if err != nil {
return 0, err
}
return token.UserID, nil
}
func (r *tokenRepositoryImpl) DeleteByAccessToken(accessToken string) error {
return r.db.Where("access_token = ?", accessToken).Delete(&model.Token{}).Error
}
func (r *tokenRepositoryImpl) DeleteByUserID(userId int64) error {
return r.db.Where("user_id = ?", userId).Delete(&model.Token{}).Error
}
func (r *tokenRepositoryImpl) BatchDelete(accessTokens []string) (int64, error) {
if len(accessTokens) == 0 {
return 0, nil
}
result := r.db.Where("access_token IN ?", accessTokens).Delete(&model.Token{})
return result.RowsAffected, result.Error
}

View File

@@ -0,0 +1,103 @@
package repository
import (
"carrotskin/internal/model"
"errors"
"gorm.io/gorm"
)
// userRepositoryImpl UserRepository的实现
type userRepositoryImpl struct {
db *gorm.DB
}
// NewUserRepository 创建UserRepository实例
func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepositoryImpl{db: db}
}
func (r *userRepositoryImpl) Create(user *model.User) error {
return r.db.Create(user).Error
}
func (r *userRepositoryImpl) FindByID(id int64) (*model.User, error) {
var user model.User
err := r.db.Where("id = ? AND status != -1", id).First(&user).Error
return handleNotFoundResult(&user, err)
}
func (r *userRepositoryImpl) FindByUsername(username string) (*model.User, error) {
var user model.User
err := r.db.Where("username = ? AND status != -1", username).First(&user).Error
return handleNotFoundResult(&user, err)
}
func (r *userRepositoryImpl) FindByEmail(email string) (*model.User, error) {
var user model.User
err := r.db.Where("email = ? AND status != -1", email).First(&user).Error
return handleNotFoundResult(&user, err)
}
func (r *userRepositoryImpl) Update(user *model.User) error {
return r.db.Save(user).Error
}
func (r *userRepositoryImpl) UpdateFields(id int64, fields map[string]interface{}) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Updates(fields).Error
}
func (r *userRepositoryImpl) Delete(id int64) error {
return r.db.Model(&model.User{}).Where("id = ?", id).Update("status", -1).Error
}
func (r *userRepositoryImpl) CreateLoginLog(log *model.UserLoginLog) error {
return r.db.Create(log).Error
}
func (r *userRepositoryImpl) CreatePointLog(log *model.UserPointLog) error {
return r.db.Create(log).Error
}
func (r *userRepositoryImpl) UpdatePoints(userID int64, amount int, changeType, reason string) error {
return r.db.Transaction(func(tx *gorm.DB) error {
var user model.User
if err := tx.Where("id = ?", userID).First(&user).Error; err != nil {
return err
}
balanceBefore := user.Points
balanceAfter := balanceBefore + amount
if balanceAfter < 0 {
return errors.New("积分不足")
}
if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil {
return err
}
log := &model.UserPointLog{
UserID: userID,
ChangeType: changeType,
Amount: amount,
BalanceBefore: balanceBefore,
BalanceAfter: balanceAfter,
Reason: reason,
}
return tx.Create(log).Error
})
}
// handleNotFoundResult 处理记录未找到的情况
func handleNotFoundResult[T any](result *T, err error) (*T, error) {
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return result, nil
}