package repository

import (
	"context"
	"errors"

	"carrotskin/internal/model"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// userRepository is the GORM-backed implementation of UserRepository.
type userRepository struct {
	db *gorm.DB
}

// NewUserRepository creates a UserRepository instance that runs its queries on db.
func NewUserRepository(db *gorm.DB) UserRepository {
	return &userRepository{db: db}
}

// Create inserts user as a new row.
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
	return r.db.WithContext(ctx).Create(user).Error
}

// FindByID returns the user with the given id, skipping rows whose status is
// -1 (the value Delete writes). It returns (nil, nil) when no row matches.
func (r *userRepository) FindByID(ctx context.Context, id int64) (*model.User, error) {
	var user model.User
	err := r.db.WithContext(ctx).
		Where("id = ? AND status != -1", id).
		First(&user).Error
	return handleNotFoundResult(&user, err)
}

// FindByUsername returns the user with the given username, skipping rows whose
// status is -1. It returns (nil, nil) when no row matches.
func (r *userRepository) FindByUsername(ctx context.Context, username string) (*model.User, error) {
	var user model.User
	err := r.db.WithContext(ctx).
		Where("username = ? AND status != -1", username).
		First(&user).Error
	return handleNotFoundResult(&user, err)
}

// FindByEmail returns the user with the given email, skipping rows whose
// status is -1. It returns (nil, nil) when no row matches.
func (r *userRepository) FindByEmail(ctx context.Context, email string) (*model.User, error) {
	var user model.User
	err := r.db.WithContext(ctx).
		Where("email = ? AND status != -1", email).
		First(&user).Error
	return handleNotFoundResult(&user, err)
}

// FindByIDs returns the users whose id is in ids, skipping rows whose status
// is -1. An empty ids slice yields an empty, non-nil result without touching
// the database. On a query error the returned slice is nil.
func (r *userRepository) FindByIDs(ctx context.Context, ids []int64) ([]*model.User, error) {
	if len(ids) == 0 {
		return []*model.User{}, nil
	}

	var users []*model.User
	// Fetch the whole batch with a single IN query instead of one query per id.
	err := r.db.WithContext(ctx).
		Where("id IN ? AND status != -1", ids).
		Find(&users).Error
	if err != nil {
		// Do not hand a result slice back alongside an error.
		return nil, err
	}
	return users, nil
}

// Update persists user through GORM's Save.
func (r *userRepository) Update(ctx context.Context, user *model.User) error {
	return r.db.WithContext(ctx).Save(user).Error
}

// UpdateFields applies the column/value pairs in fields to the user row with
// the given id. The row's status is not checked.
func (r *userRepository) UpdateFields(ctx context.Context, id int64, fields map[string]any) error {
	return r.db.WithContext(ctx).
		Model(&model.User{}).
		Where("id = ?", id).
		Updates(fields).Error
}

// Delete soft-deletes the user with the given id by setting its status to -1;
// the row itself is kept.
func (r *userRepository) Delete(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).
		Model(&model.User{}).
		Where("id = ?", id).
		Update("status", -1).Error
}

// BatchUpdate applies fields to every user whose id is in ids and returns the
// number of affected rows. An empty ids slice is a no-op that returns (0, nil).
func (r *userRepository) BatchUpdate(ctx context.Context, ids []int64, fields map[string]any) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}

	result := r.db.WithContext(ctx).
		Model(&model.User{}).
		Where("id IN ?", ids).
		Updates(fields)
	return result.RowsAffected, result.Error
}

// BatchDelete soft-deletes every user whose id is in ids by setting status to
// -1 and returns the number of affected rows. An empty ids slice is a no-op
// that returns (0, nil).
func (r *userRepository) BatchDelete(ctx context.Context, ids []int64) (int64, error) {
	if len(ids) == 0 {
		return 0, nil
	}

	result := r.db.WithContext(ctx).
		Model(&model.User{}).
		Where("id IN ?", ids).
		Update("status", -1)
	return result.RowsAffected, result.Error
}

// CreateLoginLog inserts a login log entry.
func (r *userRepository) CreateLoginLog(ctx context.Context, log *model.UserLoginLog) error {
	return r.db.WithContext(ctx).Create(log).Error
}

// CreatePointLog inserts a point log entry.
func (r *userRepository) CreatePointLog(ctx context.Context, log *model.UserPointLog) error {
	return r.db.WithContext(ctx).Create(log).Error
}

// UpdatePoints adds amount (which may be negative) to the points balance of
// the user identified by userID and records the change in a UserPointLog row.
// Both writes happen in one transaction, so either both are committed or
// neither is.
//
// It returns the lookup error (gorm.ErrRecordNotFound when the user does not
// exist), an error with the message "积分不足" (insufficient points) when the
// resulting balance would be negative, or any error from the writes.
func (r *userRepository) UpdatePoints(ctx context.Context, userID int64, amount int, changeType, reason string) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var user model.User
		// Read the row with SELECT ... FOR UPDATE. Without the lock two
		// concurrent transactions can read the same balance, and the later
		// write silently discards the earlier adjustment while both point
		// logs record a stale "before" value.
		// NOTE(review): assumes the configured database/driver accepts
		// FOR UPDATE row locks - confirm for the deployment target.
		err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
			Where("id = ?", userID).
			First(&user).Error
		if err != nil {
			return err
		}

		balanceBefore := user.Points
		balanceAfter := balanceBefore + amount
		if balanceAfter < 0 {
			return errors.New("积分不足")
		}

		if err := tx.Model(&user).Update("points", balanceAfter).Error; err != nil {
			return err
		}

		entry := &model.UserPointLog{
			UserID:        userID,
			ChangeType:    changeType,
			Amount:        amount,
			BalanceBefore: balanceBefore,
			BalanceAfter:  balanceAfter,
			Reason:        reason,
		}
		return tx.Create(entry).Error
	})
}

// handleNotFoundResult normalizes the outcome of a single-row lookup:
// gorm.ErrRecordNotFound becomes (nil, nil), any other error is returned with
// a nil result, and a nil error returns result unchanged.
func handleNotFoundResult[T any](result *T, err error) (*T, error) {
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, err
	}
	return result, nil
}