Files
backend/internal/repository/comment_repo.go

297 lines
8.8 KiB
Go
Raw Permalink Normal View History

package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
)
// CommentRepository provides persistence operations for comments and comment
// likes, issuing every query through a single GORM database handle.
type CommentRepository struct {
	db *gorm.DB
}

// NewCommentRepository returns a CommentRepository whose queries all go
// through db.
func NewCommentRepository(db *gorm.DB) *CommentRepository {
	repo := &CommentRepository{}
	repo.db = db
	return repo
}
// Create inserts comment and, in the same transaction, updates the counters
// that depend on it:
//   - the parent post's comments_count is incremented and its hot_score is
//     recomputed as likes_count*2 + comments_count*3 + views_count*0.1;
//   - when comment is a reply (ParentID non-nil and non-empty), the parent
//     comment's replies_count is incremented.
//
// Any error rolls back the whole transaction and is returned unchanged.
func (r *CommentRepository) Create(comment *model.Comment) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(comment).Error; err != nil {
			return err
		}

		// Bump the count first, then derive hot_score from the stored count in a
		// separate statement. Doing both in one UPDATE is dialect-dependent:
		// MySQL evaluates SET assignments left to right (so hot_score can see the
		// already-incremented count and add the new comment twice), whereas
		// PostgreSQL/SQLite use the pre-update values.
		if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
			Update("comments_count", gorm.Expr("comments_count + 1")).Error; err != nil {
			return err
		}
		// UpdateColumn so hooks / auto update-time are not applied a second time.
		if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
			UpdateColumn("hot_score", gorm.Expr("likes_count * 2 + comments_count * 3 + views_count * 0.1")).Error; err != nil {
			return err
		}

		// A reply also bumps its direct parent's reply counter.
		if comment.ParentID != nil && *comment.ParentID != "" {
			if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
				UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// GetByID loads the comment with the given id, with its User association
// preloaded. Any error from the lookup, including GORM's not-found error, is
// returned with a nil comment.
func (r *CommentRepository) GetByID(id string) (*model.Comment, error) {
	found := new(model.Comment)
	if err := r.db.Preload("User").First(found, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// Update persists comment via GORM's Save and returns any database error.
func (r *CommentRepository) Update(comment *model.Comment) error {
	result := r.db.Save(comment)
	return result.Error
}
// UpdateModerationStatus sets the status column of the comment identified by
// commentID to status and returns any database error.
func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error {
	target := r.db.Model(&model.Comment{}).Where("id = ?", commentID)
	return target.Update("status", status).Error
}
// Delete removes the comment with the given id and, in one transaction,
// cleans up data tied to it:
//   - all CommentLike rows for the comment are deleted;
//   - the comment row is deleted (a soft delete if model.Comment carries a
//     gorm.DeletedAt field — NOTE(review): confirm against the model);
//   - the post's comments_count is decremented and its hot_score recomputed as
//     likes_count*2 + comments_count*3 + views_count*0.1;
//   - when the comment is a reply, its parent's replies_count is decremented.
//
// Replies of the deleted comment are not touched. If the comment does not
// exist, the lookup error from First is returned and nothing is changed.
func (r *CommentRepository) Delete(id string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		// Load the comment first: its post_id and parent_id drive the counter updates.
		var comment model.Comment
		if err := tx.First(&comment, "id = ?", id).Error; err != nil {
			return err
		}

		if err := tx.Where("comment_id = ?", id).Delete(&model.CommentLike{}).Error; err != nil {
			return err
		}
		if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil {
			return err
		}

		// Decrement first, then derive hot_score from the stored count in a
		// separate statement. Combining both in one UPDATE is dialect-dependent:
		// MySQL evaluates SET assignments left to right (hot_score can see the
		// already-decremented count and subtract twice), whereas
		// PostgreSQL/SQLite use the pre-update values.
		if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
			Update("comments_count", gorm.Expr("comments_count - 1")).Error; err != nil {
			return err
		}
		// UpdateColumn so hooks / auto update-time are not applied a second time.
		if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
			UpdateColumn("hot_score", gorm.Expr("likes_count * 2 + comments_count * 3 + views_count * 0.1")).Error; err != nil {
			return err
		}

		// A deleted reply also lowers its direct parent's reply counter.
		if comment.ParentID != nil && *comment.ParentID != "" {
			if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
				UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// GetByPostID returns one page of published top-level comments (parent_id IS
// NULL) on postID, oldest first, with each comment's User preloaded. It also
// returns the total number of such comments across all pages.
//
// page is 1-based; pageSize is the page length. Replies are not loaded.
// On any database error (count or fetch) it returns nil, 0 and the error.
func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) {
	var total int64
	if err := r.db.Model(&model.Comment{}).
		Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var comments []*model.Comment
	offset := (page - 1) * pageSize
	if err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
		Preload("User").
		Offset(offset).Limit(pageSize).
		Order("created_at ASC").
		Find(&comments).Error; err != nil {
		return nil, 0, err
	}
	return comments, total, nil
}
// GetByPostIDWithReplies returns one page of published top-level comments on
// postID, oldest first, together with a flattened view of their replies.
//
// Every published comment whose root_id is one of the page's comments is
// attached to that root's Replies. Replies of all nesting levels share this one
// flat list. Each list is ordered oldest first and truncated to replyLimit
// entries per root; replyLimit <= 0 means no truncation. Each root's
// RepliesCount is overwritten with its untruncated number of such replies.
// User is preloaded on roots and replies.
//
// page is 1-based. The returned int64 is the total number of published
// top-level comments on the post. On any database error it returns nil, 0 and
// the error.
func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) {
	var total int64
	if err := r.db.Model(&model.Comment{}).
		Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var comments []*model.Comment
	offset := (page - 1) * pageSize
	if err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).
		Preload("User").
		Offset(offset).Limit(pageSize).
		Order("created_at ASC").
		Find(&comments).Error; err != nil {
		return nil, 0, err
	}
	if len(comments) == 0 {
		return comments, total, nil
	}

	if err := r.attachFlatReplies(comments, replyLimit); err != nil {
		return nil, 0, err
	}
	return comments, total, nil
}

// attachFlatReplies loads, in one query, all published replies rooted at roots.
// It sets each root's Replies (oldest first, truncated to replyLimit when
// replyLimit > 0) and RepliesCount (the untruncated total).
func (r *CommentRepository) attachFlatReplies(roots []*model.Comment, replyLimit int) error {
	rootIDs := make([]string, 0, len(roots))
	for _, root := range roots {
		rootIDs = append(rootIDs, root.ID)
	}

	var allReplies []*model.Comment
	if err := r.db.Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Find(&allReplies).Error; err != nil {
		return err
	}

	// Every reply matching the filter is already in memory, so count per root
	// here instead of issuing a second COUNT ... GROUP BY root_id query.
	repliesByRoot := make(map[string][]*model.Comment, len(rootIDs))
	totalsByRoot := make(map[string]int, len(rootIDs))
	for _, reply := range allReplies {
		if reply.RootID == nil {
			continue
		}
		rootID := *reply.RootID
		totalsByRoot[rootID]++
		if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit {
			repliesByRoot[rootID] = append(repliesByRoot[rootID], reply)
		}
	}

	for _, root := range roots {
		root.Replies = repliesByRoot[root.ID]
		root.RepliesCount = totalsByRoot[root.ID]
	}
	return nil
}
// loadFlatReplies sets rootComment.Replies to the published comments whose
// root_id equals rootComment.ID. Replies of all nesting levels end up in this
// one flat slice. Results are ordered oldest first, capped at limit via
// Limit(limit), and have User preloaded.
//
// The query error is discarded: on failure Replies receives whatever Find left
// in the destination slice.
func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) {
	var flat []*model.Comment
	query := r.db.
		Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Limit(limit)
	query.Find(&flat)
	rootComment.Replies = flat
}
// GetRepliesByRootID returns one page of the published comments whose root_id
// is rootID. Replies of all nesting levels come back in one flat list, oldest
// first, with User preloaded. It also returns the total number of such replies
// across all pages.
//
// page is 1-based; pageSize is the page length. On any database error (count
// or fetch) it returns nil, 0 and the error.
func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
	var total int64
	if err := r.db.Model(&model.Comment{}).
		Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var replies []*model.Comment
	offset := (page - 1) * pageSize
	if err := r.db.Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Offset(offset).
		Limit(pageSize).
		Find(&replies).Error; err != nil {
		return nil, 0, err
	}
	return replies, total, nil
}
// GetReplies returns the published comments whose parent_id is parentID.
// Only direct children are returned, not deeper descendants. Results are
// ordered oldest first with User preloaded.
func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) {
	var children []*model.Comment
	result := r.db.
		Preload("User").
		Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished).
		Order("created_at ASC").
		Find(&children)
	return children, result.Error
}
// Like records that userID likes commentID and increments the comment's
// likes_count. Liking an already-liked comment is a no-op that returns nil.
//
// The existence check, the insert and the counter update run in a single
// transaction. A failure part-way therefore cannot leave a like row without its
// matching likes_count increment.
// NOTE(review): two concurrent Like calls for the same pair can both pass the
// existence check. Only a unique index on (comment_id, user_id) prevents a
// duplicate row; confirm the schema has one.
func (r *CommentRepository) Like(commentID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var existing int64
		if err := tx.Model(&model.CommentLike{}).
			Where("comment_id = ? AND user_id = ?", commentID, userID).
			Count(&existing).Error; err != nil {
			return err
		}
		if existing > 0 {
			return nil // already liked
		}

		like := &model.CommentLike{
			CommentID: commentID,
			UserID:    userID,
		}
		if err := tx.Create(like).Error; err != nil {
			return err
		}

		return tx.Model(&model.Comment{}).Where("id = ?", commentID).
			UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).Error
	})
}
// Unlike removes userID's like on commentID. If a like row was actually
// deleted, it also decrements the comment's likes_count. Unliking a comment
// that was not liked is a no-op returning nil.
//
// The delete and the decrement run in one transaction, so the like row and the
// counter cannot drift apart if the second statement fails.
func (r *CommentRepository) Unlike(commentID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		result := tx.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{})
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			return nil // nothing was liked, so there is nothing to undo
		}
		return tx.Model(&model.Comment{}).Where("id = ?", commentID).
			UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).Error
	})
}
// IsLiked reports whether at least one CommentLike row links userID to
// commentID. The query error is discarded, so a failed lookup is expected to
// report false.
func (r *CommentRepository) IsLiked(commentID, userID string) bool {
	var matches int64
	r.db.Model(&model.CommentLike{}).
		Where("comment_id = ? AND user_id = ?", commentID, userID).
		Count(&matches)
	return 0 < matches
}