Set up project files and add .gitignore to exclude local build/runtime artifacts. Made-with: Cursor
142 lines
3.8 KiB
Go
142 lines
3.8 KiB
Go
package repository
|
|
|
|
import (
|
|
"carrot_bbs/internal/model"
|
|
"errors"
|
|
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// VoteRepository 投票仓储
|
|
type VoteRepository struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// NewVoteRepository 创建投票仓储
|
|
func NewVoteRepository(db *gorm.DB) *VoteRepository {
|
|
return &VoteRepository{db: db}
|
|
}
|
|
|
|
// CreateOptions 批量创建投票选项
|
|
func (r *VoteRepository) CreateOptions(postID string, options []string) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
for i, content := range options {
|
|
option := &model.VoteOption{
|
|
PostID: postID,
|
|
Content: content,
|
|
SortOrder: i,
|
|
}
|
|
if err := tx.Create(option).Error; err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// GetOptionsByPostID 获取帖子的所有投票选项
|
|
func (r *VoteRepository) GetOptionsByPostID(postID string) ([]model.VoteOption, error) {
|
|
var options []model.VoteOption
|
|
err := r.db.Where("post_id = ?", postID).Order("sort_order ASC").Find(&options).Error
|
|
return options, err
|
|
}
|
|
|
|
// Vote 用户投票
|
|
func (r *VoteRepository) Vote(postID, userID, optionID string) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// 检查用户是否已投票
|
|
var existing model.UserVote
|
|
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
|
|
if err == nil {
|
|
// 已经投票,返回错误
|
|
return errors.New("user already voted")
|
|
}
|
|
if err != gorm.ErrRecordNotFound {
|
|
return err
|
|
}
|
|
|
|
// 验证选项是否属于该帖子
|
|
var option model.VoteOption
|
|
if err := tx.Where("id = ? AND post_id = ?", optionID, postID).First(&option).Error; err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
return errors.New("invalid option")
|
|
}
|
|
return err
|
|
}
|
|
|
|
// 创建投票记录
|
|
if err := tx.Create(&model.UserVote{
|
|
PostID: postID,
|
|
UserID: userID,
|
|
OptionID: optionID,
|
|
}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// 原子增加选项投票数
|
|
return tx.Model(&model.VoteOption{}).Where("id = ?", optionID).
|
|
UpdateColumn("votes_count", gorm.Expr("votes_count + 1")).Error
|
|
})
|
|
}
|
|
|
|
// Unvote 取消投票
|
|
func (r *VoteRepository) Unvote(postID, userID string) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// 获取用户的投票记录
|
|
var userVote model.UserVote
|
|
err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
|
|
if err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
return nil // 没有投票记录,直接返回
|
|
}
|
|
return err
|
|
}
|
|
|
|
// 删除投票记录
|
|
result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.UserVote{})
|
|
if result.Error != nil {
|
|
return result.Error
|
|
}
|
|
|
|
if result.RowsAffected > 0 {
|
|
// 原子减少选项投票数
|
|
return tx.Model(&model.VoteOption{}).Where("id = ?", userVote.OptionID).
|
|
UpdateColumn("votes_count", gorm.Expr("votes_count - 1")).Error
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// GetUserVote 获取用户在指定帖子的投票
|
|
func (r *VoteRepository) GetUserVote(postID, userID string) (*model.UserVote, error) {
|
|
var userVote model.UserVote
|
|
err := r.db.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error
|
|
if err != nil {
|
|
if err == gorm.ErrRecordNotFound {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &userVote, nil
|
|
}
|
|
|
|
// UpdateOption 更新选项内容
|
|
func (r *VoteRepository) UpdateOption(optionID, content string) error {
|
|
return r.db.Model(&model.VoteOption{}).Where("id = ?", optionID).
|
|
Update("content", content).Error
|
|
}
|
|
|
|
// DeleteOptionsByPostID 删除帖子的所有投票选项
|
|
func (r *VoteRepository) DeleteOptionsByPostID(postID string) error {
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|
// 删除关联的用户投票记录
|
|
if err := tx.Where("post_id = ?", postID).Delete(&model.UserVote{}).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// 删除投票选项
|
|
return tx.Where("post_id = ?", postID).Delete(&model.VoteOption{}).Error
|
|
})
|
|
}
|