package repository import ( "carrot_bbs/internal/model" "errors" "gorm.io/gorm" ) // VoteRepository 投票仓储 type VoteRepository struct { db *gorm.DB } // NewVoteRepository 创建投票仓储 func NewVoteRepository(db *gorm.DB) *VoteRepository { return &VoteRepository{db: db} } // CreateOptions 批量创建投票选项 func (r *VoteRepository) CreateOptions(postID string, options []string) error { return r.db.Transaction(func(tx *gorm.DB) error { for i, content := range options { option := &model.VoteOption{ PostID: postID, Content: content, SortOrder: i, } if err := tx.Create(option).Error; err != nil { return err } } return nil }) } // GetOptionsByPostID 获取帖子的所有投票选项 func (r *VoteRepository) GetOptionsByPostID(postID string) ([]model.VoteOption, error) { var options []model.VoteOption err := r.db.Where("post_id = ?", postID).Order("sort_order ASC").Find(&options).Error return options, err } // Vote 用户投票 func (r *VoteRepository) Vote(postID, userID, optionID string) error { return r.db.Transaction(func(tx *gorm.DB) error { // 检查用户是否已投票 var existing model.UserVote err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error if err == nil { // 已经投票,返回错误 return errors.New("user already voted") } if err != gorm.ErrRecordNotFound { return err } // 验证选项是否属于该帖子 var option model.VoteOption if err := tx.Where("id = ? AND post_id = ?", optionID, postID).First(&option).Error; err != nil { if err == gorm.ErrRecordNotFound { return errors.New("invalid option") } return err } // 创建投票记录 if err := tx.Create(&model.UserVote{ PostID: postID, UserID: userID, OptionID: optionID, }).Error; err != nil { return err } // 原子增加选项投票数 return tx.Model(&model.VoteOption{}).Where("id = ?", optionID). UpdateColumn("votes_count", gorm.Expr("votes_count + 1")).Error }) } // Unvote 取消投票 func (r *VoteRepository) Unvote(postID, userID string) error { return r.db.Transaction(func(tx *gorm.DB) error { // 获取用户的投票记录 var userVote model.UserVote err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil // 没有投票记录,直接返回 } return err } // 删除投票记录 result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.UserVote{}) if result.Error != nil { return result.Error } if result.RowsAffected > 0 { // 原子减少选项投票数 return tx.Model(&model.VoteOption{}).Where("id = ?", userVote.OptionID). UpdateColumn("votes_count", gorm.Expr("votes_count - 1")).Error } return nil }) } // GetUserVote 获取用户在指定帖子的投票 func (r *VoteRepository) GetUserVote(postID, userID string) (*model.UserVote, error) { var userVote model.UserVote err := r.db.Where("post_id = ? AND user_id = ?", postID, userID).First(&userVote).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, nil } return nil, err } return &userVote, nil } // UpdateOption 更新选项内容 func (r *VoteRepository) UpdateOption(optionID, content string) error { return r.db.Model(&model.VoteOption{}).Where("id = ?", optionID). Update("content", content).Error } // DeleteOptionsByPostID 删除帖子的所有投票选项 func (r *VoteRepository) DeleteOptionsByPostID(postID string) error { return r.db.Transaction(func(tx *gorm.DB) error { // 删除关联的用户投票记录 if err := tx.Where("post_id = ?", postID).Delete(&model.UserVote{}).Error; err != nil { return err } // 删除投票选项 return tx.Where("post_id = ?", postID).Delete(&model.VoteOption{}).Error }) }