Files
backend/internal/repository/comment_repo.go
lan 4d8f2ec997 Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
2026-03-09 21:28:58 +08:00

297 lines
8.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
	"errors"

	"carrot_bbs/internal/model"

	"gorm.io/gorm"
)
// CommentRepository is the repository for comments. Its methods run their
// queries through the wrapped GORM handle.
type CommentRepository struct {
// db is the GORM database handle every repository method queries through.
db *gorm.DB
}
// NewCommentRepository returns a CommentRepository whose queries are issued
// through db.
func NewCommentRepository(db *gorm.DB) *CommentRepository {
	repo := new(CommentRepository)
	repo.db = db
	return repo
}
// Create inserts comment and adjusts the dependent counters inside a single
// transaction, so a failure in any step rolls back all of them:
//
//   - the comment row is inserted;
//   - the owning post's comments_count is incremented and its hot_score is
//     recomputed from likes_count, comments_count and views_count;
//   - when the comment has a non-empty ParentID, the parent comment's
//     replies_count is incremented.
//
// The first error encountered is returned unchanged.
func (r *CommentRepository) Create(comment *model.Comment) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(comment).Error; err != nil {
			return err
		}

		// Bump the post's comment counter and keep the hot score in step with it.
		postChanges := map[string]interface{}{
			"comments_count": gorm.Expr("comments_count + 1"),
			"hot_score":      gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
		}
		postQuery := tx.Model(&model.Post{}).Where("id = ?", comment.PostID)
		if err := postQuery.Updates(postChanges).Error; err != nil {
			return err
		}

		// A comment without a parent has no reply counter to maintain.
		if comment.ParentID == nil || *comment.ParentID == "" {
			return nil
		}

		return tx.Model(&model.Comment{}).
			Where("id = ?", *comment.ParentID).
			UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).
			Error
	})
}
// GetByID loads the comment whose id matches, preloading its User
// association. It returns a nil comment together with the query error when
// the lookup fails.
func (r *CommentRepository) GetByID(id string) (*model.Comment, error) {
	found := new(model.Comment)
	if err := r.db.Preload("User").First(found, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// Update persists comment via GORM's Save and returns the resulting error.
func (r *CommentRepository) Update(comment *model.Comment) error {
	result := r.db.Save(comment)
	return result.Error
}
// UpdateModerationStatus sets the status column of the comment identified by
// commentID to status and returns the query error, if any.
func (r *CommentRepository) UpdateModerationStatus(commentID string, status model.CommentStatus) error {
	target := r.db.Model(&model.Comment{}).Where("id = ?", commentID)
	return target.Update("status", status).Error
}
// Delete removes the comment identified by id and cleans up its related data
// inside a single transaction:
//
//   - the comment is loaded first to learn its PostID and ParentID;
//   - its CommentLike rows are deleted;
//   - the comment itself is deleted (described by the original author as a
//     soft delete — that depends on how model.Comment is declared; confirm);
//   - the owning post's comments_count is decremented and its hot_score
//     recomputed;
//   - when the comment has a non-empty ParentID, the parent's replies_count
//     is decremented.
//
// The first error encountered (including the lookup failing) is returned and
// rolls the transaction back.
func (r *CommentRepository) Delete(id string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var target model.Comment
		if err := tx.First(&target, "id = ?", id).Error; err != nil {
			return err
		}

		// Likes attached to the comment go first.
		likes := tx.Where("comment_id = ?", id)
		if err := likes.Delete(&model.CommentLike{}).Error; err != nil {
			return err
		}

		if err := tx.Delete(&model.Comment{}, "id = ?", id).Error; err != nil {
			return err
		}

		// Drop the post's comment counter and keep the hot score in step with it.
		postChanges := map[string]interface{}{
			"comments_count": gorm.Expr("comments_count - 1"),
			"hot_score":      gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
		}
		postQuery := tx.Model(&model.Post{}).Where("id = ?", target.PostID)
		if err := postQuery.Updates(postChanges).Error; err != nil {
			return err
		}

		// Only replies have a parent whose reply counter must shrink.
		if target.ParentID == nil || *target.ParentID == "" {
			return nil
		}

		return tx.Model(&model.Comment{}).
			Where("id = ?", *target.ParentID).
			UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).
			Error
	})
}
// GetByPostID returns one page of published top-level comments (parent_id IS
// NULL) for the given post, oldest first, with each comment's User
// association preloaded.
//
// page is 1-based; the query offset is (page-1)*pageSize. The second return
// value is the total number of matching top-level comments across all pages.
//
// A failure of the count query is returned as an error instead of being
// reported as a total of zero.
func (r *CommentRepository) GetByPostID(postID string, page, pageSize int) ([]*model.Comment, int64, error) {
	const filter = "post_id = ? AND parent_id IS NULL AND status = ?"

	var total int64
	if err := r.db.Model(&model.Comment{}).
		Where(filter, postID, model.CommentStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var comments []*model.Comment
	offset := (page - 1) * pageSize
	err := r.db.Where(filter, postID, model.CommentStatusPublished).
		Preload("User").
		Offset(offset).Limit(pageSize).
		Order("created_at ASC").
		Find(&comments).Error
	return comments, total, err
}
// GetByPostIDWithReplies returns one page of published top-level comments for
// the given post, oldest first, each carrying a flattened list of its
// replies: every published comment whose root_id points at the top-level
// comment is placed directly in that comment's Replies, regardless of nesting
// depth.
//
// page is 1-based. replyLimit caps how many replies are attached per
// top-level comment; a value <= 0 attaches all of them. RepliesCount on each
// returned comment is overwritten with the number of published replies under
// that root, independent of replyLimit. The second return value is the total
// number of matching top-level comments across all pages.
//
// Any query failure, including the count queries, is returned as an error
// with a nil slice and a zero total.
func (r *CommentRepository) GetByPostIDWithReplies(postID string, page, pageSize, replyLimit int) ([]*model.Comment, int64, error) {
	const rootFilter = "post_id = ? AND parent_id IS NULL AND status = ?"
	const replyFilter = "root_id IN ? AND status = ?"

	var total int64
	if err := r.db.Model(&model.Comment{}).
		Where(rootFilter, postID, model.CommentStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var comments []*model.Comment
	offset := (page - 1) * pageSize
	if err := r.db.Where(rootFilter, postID, model.CommentStatusPublished).
		Preload("User").
		Offset(offset).Limit(pageSize).
		Order("created_at ASC").
		Find(&comments).Error; err != nil {
		return nil, 0, err
	}
	if len(comments) == 0 {
		return comments, total, nil
	}

	rootIDs := make([]string, 0, len(comments))
	for _, comment := range comments {
		rootIDs = append(rootIDs, comment.ID)
	}

	// Load the replies of every root on this page in one query, then group
	// them by root in memory and trim each group to replyLimit.
	var allReplies []*model.Comment
	if err := r.db.Where(replyFilter, rootIDs, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Find(&allReplies).Error; err != nil {
		return nil, 0, err
	}

	repliesByRoot := make(map[string][]*model.Comment, len(rootIDs))
	for _, reply := range allReplies {
		if reply.RootID == nil {
			continue
		}
		rootID := *reply.RootID
		if replyLimit <= 0 || len(repliesByRoot[rootID]) < replyLimit {
			repliesByRoot[rootID] = append(repliesByRoot[rootID], reply)
		}
	}

	// The per-root totals are counted separately because the attached lists
	// may have been trimmed above.
	type replyCountRow struct {
		RootID string
		Total  int64
	}
	var replyCountRows []replyCountRow
	if err := r.db.Model(&model.Comment{}).
		Select("root_id, COUNT(*) AS total").
		Where(replyFilter, rootIDs, model.CommentStatusPublished).
		Group("root_id").
		Scan(&replyCountRows).Error; err != nil {
		return nil, 0, err
	}

	replyCountMap := make(map[string]int64, len(replyCountRows))
	for _, row := range replyCountRows {
		replyCountMap[row.RootID] = row.Total
	}

	for _, comment := range comments {
		comment.Replies = repliesByRoot[comment.ID]
		comment.RepliesCount = int(replyCountMap[comment.ID])
	}
	return comments, total, nil
}
// loadFlatReplies fetches up to limit published comments whose root_id is
// rootComment.ID (replies at every nesting level, not the root itself),
// oldest first with User preloaded, and stores them in rootComment.Replies.
//
// NOTE(review): the query's error is not inspected; whatever Find produced is
// assigned as-is, because this helper has no error return.
func (r *CommentRepository) loadFlatReplies(rootComment *model.Comment, limit int) {
	var replies []*model.Comment
	query := r.db.
		Where("root_id = ? AND status = ?", rootComment.ID, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Limit(limit)
	query.Find(&replies)
	rootComment.Replies = replies
}
// GetRepliesByRootID returns one page of the published replies under the root
// comment rootID, flattened (matched by root_id, so every nesting level is
// included), oldest first, with each reply's User association preloaded.
//
// page is 1-based; the query offset is (page-1)*pageSize. The second return
// value is the total number of published replies under that root.
//
// A failure of the count query is returned as an error instead of being
// reported as a total of zero.
func (r *CommentRepository) GetRepliesByRootID(rootID string, page, pageSize int) ([]*model.Comment, int64, error) {
	const filter = "root_id = ? AND status = ?"

	var total int64
	if err := r.db.Model(&model.Comment{}).
		Where(filter, rootID, model.CommentStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var replies []*model.Comment
	offset := (page - 1) * pageSize
	err := r.db.Where(filter, rootID, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Offset(offset).
		Limit(pageSize).
		Find(&replies).Error
	return replies, total, err
}
// GetReplies returns the published comments whose parent_id equals parentID,
// oldest first, with each comment's User association preloaded. The slice
// populated by the query is returned alongside the query error.
func (r *CommentRepository) GetReplies(parentID string) ([]*model.Comment, error) {
	var replies []*model.Comment
	result := r.db.
		Where("parent_id = ? AND status = ?", parentID, model.CommentStatusPublished).
		Preload("User").
		Order("created_at ASC").
		Find(&replies)
	return replies, result.Error
}
// Like records that userID likes commentID and increments the comment's
// likes_count. It is a no-op returning nil when a like row for that pair
// already exists.
//
// The existence check, the insert and the counter increment run in one
// transaction so that a failed increment rolls the like row back and the
// counter cannot drift away from the stored likes.
//
// NOTE(review): two concurrent calls for the same pair can still both pass
// the existence check; ruling that out needs a unique index on
// (comment_id, user_id) — not visible from this file, confirm in the schema.
func (r *CommentRepository) Like(commentID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var existing model.CommentLike
		err := tx.Where("comment_id = ? AND user_id = ?", commentID, userID).
			First(&existing).Error
		if err == nil {
			// Already liked: nothing to do.
			return nil
		}
		// errors.Is also matches the sentinel when it has been wrapped.
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}

		like := &model.CommentLike{
			CommentID: commentID,
			UserID:    userID,
		}
		if err := tx.Create(like).Error; err != nil {
			return err
		}

		return tx.Model(&model.Comment{}).
			Where("id = ?", commentID).
			UpdateColumn("likes_count", gorm.Expr("likes_count + 1")).
			Error
	})
}
// Unlike removes userID's like of commentID and, only when a like row was
// actually deleted, decrements the comment's likes_count. Calling it for a
// pair that has no like row is a no-op returning nil.
//
// The delete and the counter decrement run in one transaction so that a
// failed decrement restores the like row and the counter cannot drift away
// from the stored likes.
func (r *CommentRepository) Unlike(commentID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		result := tx.Where("comment_id = ? AND user_id = ?", commentID, userID).
			Delete(&model.CommentLike{})
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected == 0 {
			// Nothing was liked, so the counter must stay untouched.
			return nil
		}

		return tx.Model(&model.Comment{}).
			Where("id = ?", commentID).
			UpdateColumn("likes_count", gorm.Expr("likes_count - 1")).
			Error
	})
}
// IsLiked reports whether a CommentLike row exists for the given comment and
// user.
//
// NOTE(review): the count query's error is not inspected, so a failed query
// is indistinguishable from "not liked" (the counter stays at zero).
func (r *CommentRepository) IsLiked(commentID, userID string) bool {
	var matches int64
	r.db.Model(&model.CommentLike{}).
		Where("comment_id = ? AND user_id = ?", commentID, userID).
		Count(&matches)
	return matches >= 1
}