feat: 添加种子数据初始化功能,重构多个处理程序以简化错误响应和用户验证
This commit is contained in:
@@ -2,7 +2,6 @@ package repository
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/database"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -12,15 +11,13 @@ import (
|
||||
|
||||
// CreateProfile 创建档案
|
||||
func CreateProfile(profile *model.Profile) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Create(profile).Error
|
||||
return getDB().Create(profile).Error
|
||||
}
|
||||
|
||||
// FindProfileByUUID 根据UUID查找档案
|
||||
func FindProfileByUUID(uuid string) (*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profile model.Profile
|
||||
err := db.Where("uuid = ?", uuid).
|
||||
err := getDB().Where("uuid = ?", uuid).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
First(&profile).Error
|
||||
@@ -32,9 +29,8 @@ func FindProfileByUUID(uuid string) (*model.Profile, error) {
|
||||
|
||||
// FindProfileByName 根据角色名查找档案
|
||||
func FindProfileByName(name string) (*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profile model.Profile
|
||||
err := db.Where("name = ?", name).First(&profile).Error
|
||||
err := getDB().Where("name = ?", name).First(&profile).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -43,44 +39,36 @@ func FindProfileByName(name string) (*model.Profile, error) {
|
||||
|
||||
// FindProfilesByUserID 获取用户的所有档案
|
||||
func FindProfilesByUserID(userID int64) ([]*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profiles []*model.Profile
|
||||
err := db.Where("user_id = ?", userID).
|
||||
err := getDB().Where("user_id = ?", userID).
|
||||
Preload("Skin").
|
||||
Preload("Cape").
|
||||
Order("created_at DESC").
|
||||
Find(&profiles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profiles, nil
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
// UpdateProfile 更新档案
|
||||
func UpdateProfile(profile *model.Profile) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Save(profile).Error
|
||||
return getDB().Save(profile).Error
|
||||
}
|
||||
|
||||
// UpdateProfileFields 更新指定字段
|
||||
func UpdateProfileFields(uuid string, updates map[string]interface{}) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Profile{}).
|
||||
return getDB().Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteProfile 删除档案
|
||||
func DeleteProfile(uuid string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
return getDB().Where("uuid = ?", uuid).Delete(&model.Profile{}).Error
|
||||
}
|
||||
|
||||
// CountProfilesByUserID 统计用户的档案数量
|
||||
func CountProfilesByUserID(userID int64) (int64, error) {
|
||||
db := database.MustGetDB()
|
||||
var count int64
|
||||
err := db.Model(&model.Profile{}).
|
||||
err := getDB().Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
@@ -88,30 +76,22 @@ func CountProfilesByUserID(userID int64) (int64, error) {
|
||||
|
||||
// SetActiveProfile 设置档案为活跃状态(同时将用户的其他档案设置为非活跃)
|
||||
func SetActiveProfile(uuid string, userID int64) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 将用户的所有档案设置为非活跃
|
||||
return getDB().Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
Where("user_id = ?", userID).
|
||||
Update("is_active", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将指定档案设置为活跃
|
||||
if err := tx.Model(&model.Profile{}).
|
||||
return tx.Model(&model.Profile{}).
|
||||
Where("uuid = ? AND user_id = ?", uuid, userID).
|
||||
Update("is_active", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
Update("is_active", true).Error
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateProfileLastUsedAt 更新最后使用时间
|
||||
func UpdateProfileLastUsedAt(uuid string) error {
|
||||
db := database.MustGetDB()
|
||||
return db.Model(&model.Profile{}).
|
||||
return getDB().Model(&model.Profile{}).
|
||||
Where("uuid = ?", uuid).
|
||||
Update("last_used_at", gorm.Expr("CURRENT_TIMESTAMP")).Error
|
||||
}
|
||||
@@ -122,53 +102,40 @@ func FindOneProfileByUserID(userID int64) (*model.Profile, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
profile := profiles[0]
|
||||
return profile, nil
|
||||
if len(profiles) == 0 {
|
||||
return nil, errors.New("未找到角色")
|
||||
}
|
||||
return profiles[0], nil
|
||||
}
|
||||
|
||||
func GetProfilesByNames(names []string) ([]*model.Profile, error) {
|
||||
db := database.MustGetDB()
|
||||
var profiles []*model.Profile
|
||||
err := db.Where("name in (?)", names).Find(&profiles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profiles, nil
|
||||
err := getDB().Where("name in (?)", names).Find(&profiles).Error
|
||||
return profiles, err
|
||||
}
|
||||
|
||||
func GetProfileKeyPair(profileId string) (*model.KeyPair, error) {
|
||||
db := database.MustGetDB()
|
||||
// 1. 参数校验(保持原逻辑)
|
||||
if profileId == "" {
|
||||
return nil, errors.New("参数不能为空")
|
||||
}
|
||||
|
||||
// 2. GORM 查询:只查询 key_pair 字段(对应原 mongo 投影)
|
||||
var profile *model.Profile
|
||||
// 条件:id = profileId(PostgreSQL 主键),只选择 key_pair 字段
|
||||
result := db.WithContext(context.Background()).
|
||||
Select("key_pair"). // 只查询需要的字段(投影)
|
||||
Where("id = ?", profileId). // 查询条件(GORM 自动处理占位符,避免 SQL 注入)
|
||||
First(&profile) // 查单条记录
|
||||
var profile model.Profile
|
||||
result := getDB().WithContext(context.Background()).
|
||||
Select("key_pair").
|
||||
Where("id = ?", profileId).
|
||||
First(&profile)
|
||||
|
||||
// 3. 错误处理(适配 GORM 错误类型)
|
||||
if result.Error != nil {
|
||||
// 空结果判断(对应原 mongo.ErrNoDocuments / pgx.ErrNoRows)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
if IsNotFound(result.Error) {
|
||||
return nil, errors.New("key pair未找到")
|
||||
}
|
||||
// 保持原错误封装格式
|
||||
return nil, fmt.Errorf("获取key pair失败: %w", result.Error)
|
||||
}
|
||||
|
||||
// 4. JSONB 反序列化为 model.KeyPair
|
||||
keyPair := &model.KeyPair{}
|
||||
return keyPair, nil
|
||||
return &model.KeyPair{}, nil
|
||||
}
|
||||
|
||||
func UpdateProfileKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
db := database.MustGetDB()
|
||||
// 仅保留最必要的入参校验(避免无效数据库请求)
|
||||
if profileId == "" {
|
||||
return errors.New("profileId 不能为空")
|
||||
}
|
||||
@@ -176,24 +143,18 @@ func UpdateProfileKeyPair(profileId string, keyPair *model.KeyPair) error {
|
||||
return errors.New("keyPair 不能为 nil")
|
||||
}
|
||||
|
||||
// 事务内执行核心更新(保证原子性,出错自动回滚)
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// 核心更新逻辑:按 profileId 匹配,直接更新 key_pair 相关字段
|
||||
return getDB().Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.WithContext(context.Background()).
|
||||
Table("profiles"). // 目标表名(与 PostgreSQL 表一致)
|
||||
Where("id = ?", profileId). // 更新条件:profileId 匹配
|
||||
// 直接映射字段(无需序列化,依赖 GORM 自动字段匹配)
|
||||
Table("profiles").
|
||||
Where("id = ?", profileId).
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"private_key": keyPair.PrivateKey, // 数据库 private_key 字段
|
||||
"public_key": keyPair.PublicKey, // 数据库 public_key 字段
|
||||
// 若 key_pair 是单个字段(非拆分),替换为:"key_pair": keyPair
|
||||
"private_key": keyPair.PrivateKey,
|
||||
"public_key": keyPair.PublicKey,
|
||||
})
|
||||
|
||||
// 仅处理数据库层面的致命错误
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新 keyPair 失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user