Files
backend/internal/repository/post_repo.go

361 lines
11 KiB
Go
Raw Normal View History

package repository
import (
	"errors"

	"carrot_bbs/internal/model"
	"gorm.io/gorm"
)
// PostRepository is the data-access layer for posts and the rows attached to
// them (images, likes, favorites, comments). All queries go through db.
type PostRepository struct {
	db *gorm.DB
}

// NewPostRepository returns a PostRepository that issues its queries via db.
func NewPostRepository(db *gorm.DB) *PostRepository {
	repo := &PostRepository{db: db}
	return repo
}
// Create inserts post and one PostImage row per entry in images, all inside a
// single transaction; any failure rolls the whole thing back.
//
// images holds image URLs. Each image's SortOrder is its index in the slice,
// so the caller's ordering is preserved.
func (r *PostRepository) Create(post *model.Post, images []string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(post).Error; err != nil {
			return err
		}
		if len(images) == 0 {
			return nil
		}

		// Insert all images with one batch INSERT instead of one round-trip
		// per image. post.ID is read after the insert above, as before.
		records := make([]model.PostImage, 0, len(images))
		for i, url := range images {
			records = append(records, model.PostImage{
				PostID:    post.ID,
				URL:       url,
				SortOrder: i,
			})
		}
		return tx.Create(&records).Error
	})
}
// GetByID loads the post with the given id, with its User and Images
// associations preloaded. Any query error (including GORM's not-found error
// returned by First) is passed back unchanged with a nil post.
func (r *PostRepository) GetByID(id string) (*model.Post, error) {
	post := new(model.Post)
	if err := r.db.Preload("User").Preload("Images").First(post, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return post, nil
}
// Update persists post using GORM's Save and returns any database error.
func (r *PostRepository) Update(post *model.Post) error {
	result := r.db.Save(post)
	return result.Error
}
// UpdateModerationStatus records a moderation decision on a post: it sets
// status, reviewed_by and reject_reason, and stamps reviewed_at with the
// database's CURRENT_TIMESTAMP.
//
// No error is returned if postID matches no row (RowsAffected is not checked).
func (r *PostRepository) UpdateModerationStatus(postID string, status model.PostStatus, rejectReason string, reviewedBy string) error {
	return r.db.Model(&model.Post{}).
		Where("id = ?", postID).
		Updates(map[string]interface{}{
			"status":        status,
			"reviewed_at":   gorm.Expr("CURRENT_TIMESTAMP"),
			"reviewed_by":   reviewedBy,
			"reject_reason": rejectReason,
		}).Error
}
// Delete removes a post and its dependent rows in one transaction: images,
// post likes, favorites, likes on the post's comments, the comments, and
// finally the post itself. The first failing delete aborts and rolls back.
//
// The original intent is a soft delete of the post; whether each individual
// delete is soft or hard depends on the models' definitions (not visible here).
func (r *PostRepository) Delete(id string) error {
	// Children are removed in this exact order. Comment likes are located via a
	// raw subquery on comments, so they are removed before the comments.
	dependents := []struct {
		where  string
		target interface{}
	}{
		{"post_id = ?", &model.PostImage{}},
		{"post_id = ?", &model.PostLike{}},
		{"post_id = ?", &model.Favorite{}},
		{"comment_id IN (SELECT id FROM comments WHERE post_id = ?)", &model.CommentLike{}},
		{"post_id = ?", &model.Comment{}},
	}

	return r.db.Transaction(func(tx *gorm.DB) error {
		for _, dep := range dependents {
			if err := tx.Where(dep.where, id).Delete(dep.target).Error; err != nil {
				return err
			}
		}
		// The post row itself goes last.
		return tx.Delete(&model.Post{}, "id = ?", id).Error
	})
}
// List returns one page of published posts, newest first.
//
// page is 1-based and pageSize is the page length. When userID is non-empty,
// only that user's posts are included. The int64 result is the total number
// of matching posts across all pages. User and Images are preloaded.
// On any error, nil posts and a zero total are returned.
func (r *PostRepository) List(page, pageSize int, userID string) ([]*model.Post, int64, error) {
	query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
	if userID != "" {
		query = query.Where("user_id = ?", userID)
	}

	// Previously a failed Count was ignored and reported as total = 0.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var posts []*model.Post
	offset := (page - 1) * pageSize
	if err := query.Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error; err != nil {
		return nil, 0, err
	}
	return posts, total, nil
}
// GetUserPosts returns one page (1-based) of userID's published posts, newest
// first, with User and Images preloaded, plus the total number of such posts.
// On any error, nil posts and a zero total are returned.
func (r *PostRepository) GetUserPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
	// Shared by the count and the page query so they cannot drift apart.
	const cond = "user_id = ? AND status = ?"

	var total int64
	if err := r.db.Model(&model.Post{}).
		Where(cond, userID, model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var posts []*model.Post
	offset := (page - 1) * pageSize
	if err := r.db.Where(cond, userID, model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error; err != nil {
		return nil, 0, err
	}
	return posts, total, nil
}
// GetFavorites returns one page (1-based) of the published posts userID has
// favorited, with User and Images preloaded, plus the total count.
// On any error, nil posts and a zero total are returned.
//
// Posts are ordered by the post's own created_at, not by when they were
// favorited.
func (r *PostRepository) GetFavorites(userID string, page, pageSize int) ([]*model.Post, int64, error) {
	const cond = "id IN (?) AND status = ?"
	favorited := r.db.Model(&model.Favorite{}).Where("user_id = ?", userID).Select("post_id")

	var total int64
	if err := r.db.Model(&model.Post{}).
		Where(cond, favorited, model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var posts []*model.Post
	offset := (page - 1) * pageSize
	if err := r.db.Where(cond, favorited, model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error; err != nil {
		return nil, 0, err
	}
	return posts, total, nil
}
// Like records that userID likes postID and bumps the post's likes_count and
// hot_score in the same transaction. Liking an already-liked post is a no-op
// that returns nil.
func (r *PostRepository) Like(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var existing model.PostLike
		err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
		if err == nil {
			return nil // already liked
		}
		// errors.Is is the GORM-documented way to detect "not found" and
		// still works if the error is wrapped.
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}

		// NOTE(review): this check-then-insert can race between concurrent
		// requests. Whether a duplicate is rejected depends on a unique index
		// on (post_id, user_id), which is not visible here — confirm.
		if err := tx.Create(&model.PostLike{
			PostID: postID,
			UserID: userID,
		}).Error; err != nil {
			return err
		}

		// hot_score is written in terms of the pre-update likes_count, hence
		// the explicit +1.
		return tx.Model(&model.Post{}).Where("id = ?", postID).
			Updates(map[string]interface{}{
				"likes_count": gorm.Expr("likes_count + 1"),
				"hot_score":   gorm.Expr("(likes_count + 1) * 2 + comments_count * 3 + views_count * 0.1"),
			}).Error
	})
}
// Unlike removes userID's like on postID. Only when a like row was actually
// deleted are likes_count and hot_score decremented, so unliking a post that
// was not liked leaves the counters untouched and returns nil.
func (r *PostRepository) Unlike(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		removed := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.PostLike{})
		switch {
		case removed.Error != nil:
			return removed.Error
		case removed.RowsAffected <= 0:
			return nil
		}

		// hot_score is written in terms of the pre-update likes_count, hence
		// the explicit -1.
		return tx.Model(&model.Post{}).Where("id = ?", postID).
			Updates(map[string]interface{}{
				"likes_count": gorm.Expr("likes_count - 1"),
				"hot_score":   gorm.Expr("(likes_count - 1) * 2 + comments_count * 3 + views_count * 0.1"),
			}).Error
	})
}
// IsLiked reports whether userID has a like row for postID.
// Query errors are not surfaced. The counter starts at 0, so a failed query
// presumably yields false.
func (r *PostRepository) IsLiked(postID, userID string) bool {
	var likes int64
	_ = r.db.Model(&model.PostLike{}).
		Where("post_id = ? AND user_id = ?", postID, userID).
		Count(&likes)
	return likes > 0
}
// Favorite records that userID favorited postID and increments the post's
// favorites_count in the same transaction. Favoriting an already-favorited
// post is a no-op that returns nil.
func (r *PostRepository) Favorite(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var existing model.Favorite
		err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
		if err == nil {
			return nil // already favorited
		}
		// errors.Is is the GORM-documented way to detect "not found" and
		// still works if the error is wrapped.
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return err
		}

		// NOTE(review): check-then-insert can race between concurrent
		// requests. Whether a duplicate is rejected depends on a unique index
		// on (post_id, user_id), which is not visible here — confirm.
		if err := tx.Create(&model.Favorite{
			PostID: postID,
			UserID: userID,
		}).Error; err != nil {
			return err
		}

		return tx.Model(&model.Post{}).Where("id = ?", postID).
			UpdateColumn("favorites_count", gorm.Expr("favorites_count + 1")).Error
	})
}
// Unfavorite removes userID's favorite on postID. favorites_count is
// decremented only when a favorite row was actually deleted. Otherwise it is
// a no-op that returns nil.
func (r *PostRepository) Unfavorite(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		removed := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.Favorite{})
		if removed.Error != nil {
			return removed.Error
		}
		if removed.RowsAffected <= 0 {
			return nil
		}
		return tx.Model(&model.Post{}).Where("id = ?", postID).
			UpdateColumn("favorites_count", gorm.Expr("favorites_count - 1")).Error
	})
}
// IsFavorited reports whether userID has a favorite row for postID.
// Query errors are not surfaced. The counter starts at 0, so a failed query
// presumably yields false.
func (r *PostRepository) IsFavorited(postID, userID string) bool {
	var favorites int64
	_ = r.db.Model(&model.Favorite{}).
		Where("post_id = ? AND user_id = ?", postID, userID).
		Count(&favorites)
	return favorites > 0
}
// IncrementViews adds one view to postID and recomputes hot_score in the same
// UPDATE. The hot-score expression uses the pre-update views_count, hence the
// explicit +1.
func (r *PostRepository) IncrementViews(postID string) error {
	changes := map[string]interface{}{
		"views_count": gorm.Expr("views_count + 1"),
		"hot_score":   gorm.Expr("likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1"),
	}
	return r.db.Model(&model.Post{}).
		Where("id = ?", postID).
		Updates(changes).Error
}
// Search returns one page (1-based) of published posts whose title or content
// matches keyword, newest first, with User and Images preloaded, plus the
// total match count. An empty keyword matches every published post.
// On any error, nil posts and a zero total are returned.
//
// On PostgreSQL, a full-text expression is used. Other dialects fall back to a
// substring LIKE.
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]*model.Post, int64, error) {
	query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)

	if keyword != "" {
		if r.db.Dialector.Name() == "postgres" {
			// Full-text search expression. This leaves room to move to
			// pg_trgm/GIN indexes later.
			query = query.Where(
				"to_tsvector('simple', COALESCE(title, '') || ' ' || COALESCE(content, '')) @@ plainto_tsquery('simple', ?)",
				keyword,
			)
		} else {
			// NOTE(review): '%' and '_' inside keyword are not escaped and act
			// as LIKE wildcards.
			pattern := "%" + keyword + "%"
			// Explicit parentheses keep the OR from ever binding with the
			// status condition.
			query = query.Where("(title LIKE ? OR content LIKE ?)", pattern, pattern)
		}
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var posts []*model.Post
	offset := (page - 1) * pageSize
	if err := query.Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error; err != nil {
		return nil, 0, err
	}
	return posts, total, nil
}
// GetFollowingPosts returns one page (1-based) of published posts written by
// users that userID follows, newest first, with User and Images preloaded,
// plus the total count. On any error, nil posts and a zero total are returned.
func (r *PostRepository) GetFollowingPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
	const cond = "user_id IN (?) AND status = ?"
	// IDs of every user that userID follows, used as an IN subquery.
	following := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")

	var total int64
	if err := r.db.Model(&model.Post{}).
		Where(cond, following, model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var posts []*model.Post
	offset := (page - 1) * pageSize
	if err := r.db.Where(cond, following, model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error; err != nil {
		return nil, 0, err
	}
	return posts, total, nil
}
// GetHotPosts returns one page (1-based) of published posts ordered by the
// precomputed hot_score (ties broken by created_at), with User and Images
// preloaded, plus the total number of published posts.
// On any error, nil posts and a zero total are returned.
func (r *PostRepository) GetHotPosts(page, pageSize int) ([]*model.Post, int64, error) {
	var total int64
	if err := r.db.Model(&model.Post{}).
		Where("status = ?", model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var posts []*model.Post
	offset := (page - 1) * pageSize
	// Sorting by the stored hot_score column avoids evaluating the score
	// expression on every request.
	if err := r.db.Where("status = ?", model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("hot_score DESC, created_at DESC").
		Find(&posts).Error; err != nil {
		return nil, 0, err
	}
	return posts, total, nil
}
// GetByIDs fetches the published posts whose IDs appear in ids and returns them
// in the same order as ids, with User and Images preloaded. IDs that match no
// published post are skipped. An empty ids yields an empty, non-nil slice
// without querying.
func (r *PostRepository) GetByIDs(ids []string) ([]*model.Post, error) {
	if len(ids) == 0 {
		return []*model.Post{}, nil
	}

	var found []*model.Post
	if err := r.db.Preload("User").Preload("Images").
		Where("id IN ? AND status = ?", ids, model.PostStatusPublished).
		Find(&found).Error; err != nil {
		return nil, err
	}

	// The database returns rows in arbitrary order; re-order them to follow ids.
	byID := make(map[string]*model.Post, len(found))
	for _, p := range found {
		byID[p.ID] = p
	}
	result := make([]*model.Post, 0, len(ids))
	for _, id := range ids {
		p, ok := byID[id]
		if !ok {
			continue
		}
		result = append(result, p)
	}
	return result, nil
}