Replace websocket flow with SSE support in backend.
Update handlers, services, router, and data conversion logic to support server-sent events and related message pipeline changes. Made-with: Cursor
This commit is contained in:
@@ -18,32 +18,7 @@ func NewCommentRepository(db *gorm.DB) *CommentRepository {
|
||||
|
||||
// Create 创建评论
|
||||
func (r *CommentRepository) Create(comment *model.Comment) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 创建评论
|
||||
err := tx.Create(comment).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 增加帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,增加父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return r.db.Create(comment).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取评论
|
||||
@@ -87,23 +62,52 @@ func (r *CommentRepository) Delete(id string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 减少帖子的评论数并同步热度分
|
||||
// 仅已发布评论才参与统计,避免 pending/rejected 影响计数
|
||||
if comment.Status == model.CommentStatusPublished {
|
||||
// 减少帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count - 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,减少父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ApplyPublishedStats 在评论审核通过后更新帖子评论数/回复数
|
||||
func (r *CommentRepository) ApplyPublishedStats(comment *model.Comment) error {
|
||||
if comment == nil {
|
||||
return nil
|
||||
}
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 增加帖子的评论数并同步热度分
|
||||
if err := tx.Model(&model.Post{}).Where("id = ?", comment.PostID).
|
||||
Updates(map[string]interface{}{
|
||||
"comments_count": gorm.Expr("comments_count - 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count - 1) * 3 + views_count * 0.1"),
|
||||
"comments_count": gorm.Expr("comments_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + (comments_count + 1) * 3 + views_count * 0.1"),
|
||||
}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是回复,减少父评论的回复数
|
||||
// 如果是回复,增加父评论的回复数
|
||||
if comment.ParentID != nil && *comment.ParentID != "" {
|
||||
if err := tx.Model(&model.Comment{}).Where("id = ?", *comment.ParentID).
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count - 1")).Error; err != nil {
|
||||
UpdateColumn("replies_count", gorm.Expr("replies_count + 1")).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"carrot_bbs/internal/model"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -52,9 +53,41 @@ func (r *PostRepository) GetByID(id string) (*model.Post, error) {
|
||||
|
||||
// Update 更新帖子
|
||||
func (r *PostRepository) Update(post *model.Post) error {
|
||||
post.UpdatedAt = time.Now()
|
||||
return r.db.Save(post).Error
|
||||
}
|
||||
|
||||
// UpdateWithImages 更新帖子及其图片(images=nil 表示不更新图片)
|
||||
func (r *PostRepository) UpdateWithImages(post *model.Post, images *[]string) error {
|
||||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||||
post.UpdatedAt = time.Now()
|
||||
if err := tx.Save(post).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if images == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := tx.Where("post_id = ?", post.ID).Delete(&model.PostImage{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, url := range *images {
|
||||
image := &model.PostImage{
|
||||
PostID: post.ID,
|
||||
URL: url,
|
||||
SortOrder: i,
|
||||
}
|
||||
if err := tx.Create(image).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateModerationStatus 更新帖子审核状态
|
||||
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
|
||||
updates := map[string]interface{}{
|
||||
@@ -100,15 +133,24 @@ func (r *PostRepository) Delete(id string) error {
|
||||
}
|
||||
|
||||
// List 分页获取帖子列表
|
||||
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
|
||||
// includePending=true 时,仅在指定 userID 下额外返回 pending(用于作者查看自己待审核帖子)
|
||||
func (r *PostRepository) List(page, pageSize int, userID string, includePending bool) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
|
||||
query := r.db.Model(&model.Post{})
|
||||
|
||||
if userID != "" {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if includePending && userID != "" {
|
||||
query = query.Where("status IN ?", []model.PostStatus{
|
||||
model.PostStatusPublished,
|
||||
model.PostStatusPending,
|
||||
})
|
||||
} else {
|
||||
query = query.Where("status = ?", model.PostStatusPublished)
|
||||
}
|
||||
|
||||
query.Count(&total)
|
||||
|
||||
@@ -119,14 +161,32 @@ func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post,
|
||||
}
|
||||
|
||||
// GetUserPosts 获取用户帖子
|
||||
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
|
||||
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int, includePending bool) ([]*model.Post, int64, error) {
|
||||
var posts []*model.Post
|
||||
var total int64
|
||||
|
||||
r.db.Model(&model.Post{}).Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Count(&total)
|
||||
statusQuery := r.db.Model(&model.Post{}).Where("user_id = ?", userID)
|
||||
if includePending {
|
||||
statusQuery = statusQuery.Where("status IN ?", []model.PostStatus{
|
||||
model.PostStatusPublished,
|
||||
model.PostStatusPending,
|
||||
})
|
||||
} else {
|
||||
statusQuery = statusQuery.Where("status = ?", model.PostStatusPublished)
|
||||
}
|
||||
statusQuery.Count(&total)
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
err := r.db.Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
listQuery := r.db.Where("user_id = ?", userID)
|
||||
if includePending {
|
||||
listQuery = listQuery.Where("status IN ?", []model.PostStatus{
|
||||
model.PostStatusPublished,
|
||||
model.PostStatusPending,
|
||||
})
|
||||
} else {
|
||||
listQuery = listQuery.Where("status = ?", model.PostStatusPublished)
|
||||
}
|
||||
err := listQuery.Preload("User").Preload("Images").Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&posts).Error
|
||||
|
||||
return posts, total, err
|
||||
}
|
||||
@@ -256,7 +316,8 @@ func (r *PostRepository) IsFavorited(postID, userID string) bool {
|
||||
// IncrementViews 增加帖子观看量
|
||||
func (r *PostRepository) IncrementViews(postID string) error {
|
||||
return r.db.Model(&model.Post{}).Where("id = ?", postID).
|
||||
Updates(map[string]interface{}{
|
||||
// 浏览量属于统计字段,不应影响帖子内容更新时间(updated_at)
|
||||
UpdateColumns(map[string]interface{}{
|
||||
"views_count": gorm.Expr("views_count + 1"),
|
||||
"hot_score": gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
|
||||
}).Error
|
||||
|
||||
@@ -177,7 +177,9 @@ func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
||||
// GetPostsCount 获取用户帖子数(实时计算)
|
||||
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
err := r.db.Model(&model.Post{}).
|
||||
Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -202,7 +204,7 @@ func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64,
|
||||
var counts []CountResult
|
||||
err := r.db.Model(&model.Post{}).
|
||||
Select("user_id, count(*) as count").
|
||||
Where("user_id IN ?", userIDs).
|
||||
Where("user_id IN ? AND status = ?", userIDs, model.PostStatusPublished).
|
||||
Group("user_id").
|
||||
Scan(&counts).Error
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user