package repository import ( "carrot_bbs/internal/model" "gorm.io/gorm" ) // CommentRepository 评论仓储 type CommentRepository struct { db *gorm.DB } // NewCommentRepository 创建评论仓储 func NewCommentRepository(db *gorm.DB) *CommentRepository { return &CommentRepository{db: db} } // Create 创建评论 func (r *CommentRepository) Create(comment *model.Comment) error { return r.db.Transaction(func(tx *gorm.DB) error { // 创建评论 err := tx.Create(comment).Error if err != nil { return err } // 增加帖子的评论数并同步热度分 if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). Updates(map[string]interface{}{ "comments_count": gorm.Expr("comments_count + 1"), "hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"), }).Error; err != nil { return err } // 如果是回复,增加父评论的回复数 if comment.ParentID != nil && *comment.ParentID != "" { if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil { return err } } return nil }) } // GetByID 根据ID获取评论 func (r *CommentRepository) GetByID(id string) (*model.Comment, error) { var comment model.Comment err := r.db.Preload("User").First(&comment, "id = ?", id).Error if err != nil { return nil, err } return &comment, nil } // Update 更新评论 func (r *CommentRepository) Update(comment *model.Comment) error { return r.db.Save(comment).Error } // UpdateModerationStatus 更新评论审核状态 func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error { return r.db.Model(&model.Comment{}). Where("id = ?", commentID). Update("status", status).Error } // Delete 删除评论(软删除,同时清理关联数据) func (r *CommentRepository) Delete(id string) error { return r.db.Transaction(func(tx *gorm.DB) error { // 先查询评论获取post_id和parent_id var comment model.Comment if err := tx.First(&comment, "id = ?", id).Error; err != nil { return err } // 删除评论点赞记录 if err := tx.Where("comment_id = ?", id).Delete(&model.CommentLike{}).Error; err != nil { return err } // 删除评论(软删除) if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil { return err } // 减少帖子的评论数并同步热度分 if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID). Updates(map[string]interface{}{ "comments_count": gorm.Expr("comments_count - 1"), "hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"), }).Error; err != nil { return err } // 如果是回复,减少父评论的回复数 if comment.ParentID != nil && *comment.ParentID != "" { if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID). UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil { return err } } return nil }) } // GetByPostID 获取帖子评论 func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) { var comments []*model.Comment var total int64 r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total) offset := (page - 1) * pageSize err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished). Preload("User"). Offset(offset).Limit(pageSize). Order("created_at ASC"). Find(&comments).Error return comments, total, err } // GetByPostIDWithReplies 获取帖子评论(包含回复,扁平化结构) // 所有层级的回复都扁平化展示在顶级评论的 replies 中 func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) { var comments []*model.Comment var total int64 r.db.Model(&model.Comment{}).Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished).Count(&total) offset := (page - 1) * pageSize err := r.db.Where("post_id = ? AND parent_id IS NULL AND status = ?", postID, model.CommentStatusPublished). Preload("User"). Offset(offset).Limit(pageSize). Order("created_at ASC"). Find(&comments).Error if err != nil { return nil, 0, err } if len(comments) == 0 { return comments, total, nil } rootIDs := make([]string, 0, len(comments)) commentsByID := make(map[string]*model.Comment, len(comments)) for _, comment := range comments { rootIDs = append(rootIDs, comment.ID) commentsByID[comment.ID] = comment } // 批量加载所有回复,内存中按 root_id 分组并裁剪每个根评论的返回条数 var allReplies []*model.Comment if err := r.db.Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished). Preload("User"). Order("created_at ASC"). Find(&allReplies).Error; err != nil { return nil, 0, err } repliesByRoot := make(map[string][]*model.Comment, len(rootIDs)) for _, reply := range allReplies { if reply.RootID == nil { continue } rootID := *reply.RootID if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit { repliesByRoot[rootID] = append(repliesByRoot[rootID], reply) } } type replyCountRow struct { RootID string Total int64 } var replyCountRows []replyCountRow if err := r.db.Model(&model.Comment{}). Select("root_id, COUNT(*) AS total"). Where("root_id IN ? AND status = ?", rootIDs, model.CommentStatusPublished). Group("root_id"). Scan(&replyCountRows).Error; err != nil { return nil, 0, err } replyCountMap := make(map[string]int64, len(replyCountRows)) for _, row := range replyCountRows { replyCountMap[row.RootID] = row.Total } for _, rootID := range rootIDs { comment := commentsByID[rootID] comment.Replies = repliesByRoot[rootID] comment.RepliesCount = int(replyCountMap[rootID]) } return comments, total, nil } // loadFlatReplies 加载评论的所有回复(扁平化,所有层级都在同一层) func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) { var allReplies []*model.Comment // 查询所有以该评论为根评论的回复(不包括顶级评论本身) r.db.Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished). Preload("User"). Order("created_at ASC"). Limit(limit). Find(&allReplies) rootComment.Replies = allReplies } // GetRepliesByRootID 根据根评论ID分页获取回复(扁平化) func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) { var replies []*model.Comment var total int64 // 统计总数 r.db.Model(&model.Comment{}).Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished).Count(&total) // 分页查询 offset := (page - 1) * pageSize err := r.db.Where("root_id = ? AND status = ?", rootID, model.CommentStatusPublished). Preload("User"). Order("created_at ASC"). Offset(offset). Limit(pageSize). Find(&replies).Error return replies, total, err } // GetReplies 获取回复 func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) { var comments []*model.Comment err := r.db.Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished). Preload("User"). Order("created_at ASC"). Find(&comments).Error return comments, err } // Like 点赞评论 func (r *CommentRepository) Like(commentID, userID string) error { // 检查是否已经点赞 var existing model.CommentLike err := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).First(&existing).Error if err == nil { // 已经点赞 return nil } if err != gorm.ErrRecordNotFound { return err } // 创建点赞记录 like := &model.CommentLike{ CommentID: commentID, UserID: userID, } err = r.db.Create(like).Error if err != nil { return err } // 增加评论点赞数 return r.db.Model(&model.Comment{}).Where("id = ?", commentID). UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).Error } // Unlike 取消点赞评论 func (r *CommentRepository) Unlike(commentID, userID string) error { result := r.db.Where("comment_id = ? AND user_id = ?", commentID, userID).Delete(&model.CommentLike{}) if result.Error != nil { return result.Error } if result.RowsAffected > 0 { // 减少评论点赞数 return r.db.Model(&model.Comment{}).Where("id = ?", commentID). UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).Error } return nil } // IsLiked 检查是否已点赞 func (r *CommentRepository) IsLiked(commentID, userID string) bool { var count int64 r.db.Model(&model.CommentLike{}).Where("comment_id = ? AND user_id = ?", commentID, userID).Count(&count) return count > 0 }