2026-03-09 21:28:58 +08:00
package repository
import (
"carrot_bbs/internal/model"
2026-03-10 12:58:23 +08:00
"time"
2026-03-09 21:28:58 +08:00
"gorm.io/gorm"
)
// PostRepository is the data-access layer for posts and the rows that hang
// off them (images, likes, favorites, comments, comment likes). Every method
// goes through the wrapped GORM handle.
type PostRepository struct {
	db *gorm.DB // GORM handle shared by all repository methods
}
// NewPostRepository builds a PostRepository around db. The handle is stored
// as given; no queries are issued here.
func NewPostRepository(db *gorm.DB) *PostRepository {
	repo := &PostRepository{}
	repo.db = db
	return repo
}
// Create inserts post and one PostImage row per URL in images, all inside a
// single transaction; any failure rolls the whole thing back.
//
// The post is written first because the image rows reference post.ID
// (NOTE(review): presumably populated by GORM or a model hook during Create —
// confirm). Each image's SortOrder is its index in images, so the caller's
// ordering is preserved.
func (r *PostRepository) Create(post *model.Post, images []string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(post).Error; err != nil {
			return err
		}
		for idx := 0; idx < len(images); idx++ {
			img := &model.PostImage{PostID: post.ID, URL: images[idx], SortOrder: idx}
			if err := tx.Create(img).Error; err != nil {
				return err
			}
		}
		return nil
	})
}
// GetByID loads one post by its id, eager-loading the User and Images
// associations. No status filter is applied, so posts in any moderation
// state are returned. On any error (including GORM's record-not-found) the
// returned post is nil.
func (r *PostRepository) GetByID(id string) (*model.Post, error) {
	post := new(model.Post)
	if err := r.db.Preload("User").Preload("Images").First(post, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return post, nil
}
// Update 更新帖子
func ( r * PostRepository ) Update ( post * model . Post ) error {
2026-03-10 12:58:23 +08:00
post . UpdatedAt = time . Now ( )
2026-03-09 21:28:58 +08:00
return r . db . Save ( post ) . Error
}
2026-03-10 12:58:23 +08:00
// UpdateWithImages 更新帖子及其图片( images=nil 表示不更新图片)
func ( r * PostRepository ) UpdateWithImages ( post * model . Post , images * [ ] string ) error {
return r . db . Transaction ( func ( tx * gorm . DB ) error {
post . UpdatedAt = time . Now ( )
if err := tx . Save ( post ) . Error ; err != nil {
return err
}
if images == nil {
return nil
}
if err := tx . Where ( "post_id = ?" , post . ID ) . Delete ( & model . PostImage { } ) . Error ; err != nil {
return err
}
for i , url := range * images {
image := & model . PostImage {
PostID : post . ID ,
URL : url ,
SortOrder : i ,
}
if err := tx . Create ( image ) . Error ; err != nil {
return err
}
}
return nil
} )
}
2026-03-09 21:28:58 +08:00
// UpdateModerationStatus 更新帖子审核状态
func ( r * PostRepository ) UpdateModerationStatus ( postID string , status model . PostStatus , rejectReason string , reviewedBy string ) error {
updates := map [ string ] interface { } {
"status" : status ,
"reviewed_at" : gorm . Expr ( "CURRENT_TIMESTAMP" ) ,
"reviewed_by" : reviewedBy ,
"reject_reason" : rejectReason ,
}
return r . db . Model ( & model . Post { } ) . Where ( "id = ?" , postID ) . Updates ( updates ) . Error
}
// Delete removes the post with the given id together with its dependent rows,
// all in one transaction.
//
// Dependents are removed in this order: images, post likes, favorites, likes
// on the post's comments, then the comments. The post row goes last; the
// original design intends this to be a soft delete (NOTE(review): relies on
// model.Post carrying a GORM DeletedAt field — confirm). Whether each
// dependent delete is soft or hard likewise depends on its model.
func (r *PostRepository) Delete(id string) error {
	// Each step deletes target rows matching cond, with id bound to "?".
	// Comment likes are located through the comments table, so that step
	// must run before the comments themselves are deleted.
	cleanup := []struct {
		cond   string
		target interface{}
	}{
		{"post_id = ?", &model.PostImage{}},
		{"post_id = ?", &model.PostLike{}},
		{"post_id = ?", &model.Favorite{}},
		{"comment_id IN (SELECT id FROM comments WHERE post_id = ?)", &model.CommentLike{}},
		{"post_id = ?", &model.Comment{}},
	}
	return r.db.Transaction(func(tx *gorm.DB) error {
		for _, step := range cleanup {
			if err := tx.Where(step.cond, id).Delete(step.target).Error; err != nil {
				return err
			}
		}
		return tx.Delete(&model.Post{}, "id = ?", id).Error
	})
}
// List 分页获取帖子列表
2026-03-10 12:58:23 +08:00
// includePending=true 时,仅在指定 userID 下额外返回 pending( 用于作者查看自己待审核帖子)
func ( r * PostRepository ) List ( page , pageSize int , userID string , includePending bool ) ( [ ] * model . Post , int64 , error ) {
2026-03-09 21:28:58 +08:00
var posts [ ] * model . Post
var total int64
2026-03-10 12:58:23 +08:00
query := r . db . Model ( & model . Post { } )
2026-03-09 21:28:58 +08:00
if userID != "" {
query = query . Where ( "user_id = ?" , userID )
}
2026-03-10 12:58:23 +08:00
if includePending && userID != "" {
query = query . Where ( "status IN ?" , [ ] model . PostStatus {
model . PostStatusPublished ,
model . PostStatusPending ,
} )
} else {
query = query . Where ( "status = ?" , model . PostStatusPublished )
}
2026-03-09 21:28:58 +08:00
query . Count ( & total )
offset := ( page - 1 ) * pageSize
err := query . Preload ( "User" ) . Preload ( "Images" ) . Offset ( offset ) . Limit ( pageSize ) . Order ( "created_at DESC" ) . Find ( & posts ) . Error
return posts , total , err
}
// GetUserPosts 获取用户帖子
2026-03-10 12:58:23 +08:00
func ( r * PostRepository ) GetUserPosts ( userID string , page , pageSize int , includePending bool ) ( [ ] * model . Post , int64 , error ) {
2026-03-09 21:28:58 +08:00
var posts [ ] * model . Post
var total int64
2026-03-10 12:58:23 +08:00
statusQuery := r . db . Model ( & model . Post { } ) . Where ( "user_id = ?" , userID )
if includePending {
statusQuery = statusQuery . Where ( "status IN ?" , [ ] model . PostStatus {
model . PostStatusPublished ,
model . PostStatusPending ,
} )
} else {
statusQuery = statusQuery . Where ( "status = ?" , model . PostStatusPublished )
}
statusQuery . Count ( & total )
2026-03-09 21:28:58 +08:00
offset := ( page - 1 ) * pageSize
2026-03-10 12:58:23 +08:00
listQuery := r . db . Where ( "user_id = ?" , userID )
if includePending {
listQuery = listQuery . Where ( "status IN ?" , [ ] model . PostStatus {
model . PostStatusPublished ,
model . PostStatusPending ,
} )
} else {
listQuery = listQuery . Where ( "status = ?" , model . PostStatusPublished )
}
err := listQuery . Preload ( "User" ) . Preload ( "Images" ) . Offset ( offset ) . Limit ( pageSize ) . Order ( "created_at DESC" ) . Find ( & posts ) . Error
2026-03-09 21:28:58 +08:00
return posts , total , err
}
// GetFavorites returns one page of the published posts that userID has
// favorited, together with the total count.
//
// Ordering is by the post's created_at (newest first), not by when the
// favorite was added. page is 1-based. User and Images are eager-loaded.
// Errors from either the count query or the page query are returned.
func (r *PostRepository) GetFavorites(userID string, page, pageSize int) ([]*model.Post, int64, error) {
	var posts []*model.Post
	var total int64

	// Subquery: IDs of every post this user has favorited.
	subQuery := r.db.Model(&model.Favorite{}).Where("user_id = ?", userID).Select("post_id")

	// Check the count error: ignoring it would pair a page of results with a
	// bogus total of 0.
	if err := r.db.Model(&model.Post{}).
		Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := r.db.Where("id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error
	return posts, total, err
}
// Like records that userID likes postID and bumps the post's counters, all
// in one transaction. Liking an already-liked post is a no-op.
//
// The existence check and the insert are separate statements with no
// explicit locking. NOTE(review): concurrent duplicate likes are presumably
// prevented by a unique index on (post_id, user_id) — confirm.
//
// hot_score is recomputed as likes*2 + comments*3 + views*0.1, using the
// incremented like count.
func (r *PostRepository) Like(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var existing model.PostLike
		err := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&existing).Error
		if err == nil {
			return nil // already liked
		}
		if err != gorm.ErrRecordNotFound {
			return err
		}

		if err := tx.Create(&model.PostLike{
			PostID: postID,
			UserID: userID,
		}).Error; err != nil {
			return err
		}

		// UpdateColumns rather than Updates: like counts are statistics and,
		// like views_count in IncrementViews and favorites_count in Favorite,
		// must not bump the post's updated_at (its content-modification time).
		// hot_score is written as (likes_count + 1), which presumes the SET
		// expression sees the pre-update likes_count — confirm for the target DB.
		return tx.Model(&model.Post{}).Where("id = ?", postID).
			UpdateColumns(map[string]interface{}{
				"likes_count": gorm.Expr("likes_count + 1"),
				"hot_score":   gorm.Expr("(likes_count + 1) * 2 + comments_count * 3 + views_count * 0.1"),
			}).Error
	})
}
// Unlike removes userID's like on postID. The counters are adjusted only when
// a like row was actually deleted, so unliking a post that was never liked
// is a no-op. Everything runs in one transaction.
//
// hot_score is recomputed as likes*2 + comments*3 + views*0.1, using the
// decremented like count.
func (r *PostRepository) Unlike(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		result := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.PostLike{})
		if result.Error != nil {
			return result.Error
		}
		if result.RowsAffected > 0 {
			// UpdateColumns rather than Updates: like counts are statistics
			// and, like views_count in IncrementViews and favorites_count in
			// Unfavorite, must not bump the post's updated_at.
			return tx.Model(&model.Post{}).Where("id = ?", postID).
				UpdateColumns(map[string]interface{}{
					"likes_count": gorm.Expr("likes_count - 1"),
					"hot_score":   gorm.Expr("(likes_count - 1) * 2 + comments_count * 3 + views_count * 0.1"),
				}).Error
		}
		return nil
	})
}
// IsLiked reports whether a like row exists for (postID, userID).
// The query error is not checked; because n starts at 0, a failed query
// normally reads as "not liked".
func (r *PostRepository) IsLiked(postID, userID string) bool {
	var n int64
	r.db.Model(&model.PostLike{}).
		Where("post_id = ? AND user_id = ?", postID, userID).
		Count(&n)
	liked := n > 0
	return liked
}
// Favorite records that userID favorited postID and increments the post's
// favorites_count, all in one transaction. Favoriting an already-favorited
// post is a no-op. UpdateColumn is used, so updated_at is not touched.
func (r *PostRepository) Favorite(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		var found model.Favorite
		lookupErr := tx.Where("post_id = ? AND user_id = ?", postID, userID).First(&found).Error
		switch {
		case lookupErr == nil:
			return nil // already favorited
		case lookupErr != gorm.ErrRecordNotFound:
			return lookupErr
		}

		fav := &model.Favorite{PostID: postID, UserID: userID}
		if err := tx.Create(fav).Error; err != nil {
			return err
		}

		bump := gorm.Expr("favorites_count + 1")
		return tx.Model(&model.Post{}).Where("id = ?", postID).UpdateColumn("favorites_count", bump).Error
	})
}
// Unfavorite removes userID's favorite on postID. It decrements
// favorites_count only when a favorite row was actually deleted, all in one
// transaction.
func (r *PostRepository) Unfavorite(postID, userID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		del := tx.Where("post_id = ? AND user_id = ?", postID, userID).Delete(&model.Favorite{})
		if del.Error != nil {
			return del.Error
		}
		if del.RowsAffected <= 0 {
			return nil // nothing was favorited; leave the counter alone
		}
		drop := gorm.Expr("favorites_count - 1")
		return tx.Model(&model.Post{}).Where("id = ?", postID).UpdateColumn("favorites_count", drop).Error
	})
}
// IsFavorited reports whether a favorite row exists for (postID, userID).
// The query error is not checked; because n starts at 0, a failed query
// normally reads as "not favorited".
func (r *PostRepository) IsFavorited(postID, userID string) bool {
	var n int64
	r.db.Model(&model.Favorite{}).
		Where("post_id = ? AND user_id = ?", postID, userID).
		Count(&n)
	favorited := n > 0
	return favorited
}
// IncrementViews 增加帖子观看量
func ( r * PostRepository ) IncrementViews ( postID string ) error {
return r . db . Model ( & model . Post { } ) . Where ( "id = ?" , postID ) .
2026-03-10 12:58:23 +08:00
// 浏览量属于统计字段, 不应影响帖子内容更新时间( updated_at)
UpdateColumns ( map [ string ] interface { } {
2026-03-09 21:28:58 +08:00
"views_count" : gorm . Expr ( "views_count + 1" ) ,
"hot_score" : gorm . Expr ( "likes_count * 2 + comments_count * 3 + (views_count + 1) * 0.1" ) ,
} ) . Error
}
// Search returns one page of published posts matching keyword in title or
// content, newest first, together with the total match count.
//
// On PostgreSQL a full-text match is used (to_tsvector / plainto_tsquery with
// the 'simple' configuration). On other dialects a substring LIKE match is
// used. There, the LIKE metacharacters % and _ in keyword are escaped so
// user input is matched literally instead of acting as wildcards. An empty
// keyword matches every published post. page is 1-based. Errors from either
// the count query or the page query are returned.
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]*model.Post, int64, error) {
	var posts []*model.Post
	var total int64

	query := r.db.Model(&model.Post{}).Where("status = ?", model.PostStatusPublished)
	if keyword != "" {
		if r.db.Dialector.Name() == "postgres" {
			// Full-text expression; leaves room for a later pg_trgm/GIN index.
			query = query.Where(
				"to_tsvector('simple', COALESCE(title, '') || ' ' || COALESCE(content, '')) @@ plainto_tsquery('simple', ?)",
				keyword,
			)
		} else {
			// Escape with '!' rather than backslash: '!' needs no quoting in
			// SQL string literals, so the same clause works on MySQL and
			// SQLite. Escaping byte-wise is UTF-8 safe because '!', '%' and
			// '_' are ASCII and never appear inside multi-byte sequences.
			escaped := make([]byte, 0, len(keyword)+4)
			for i := 0; i < len(keyword); i++ {
				c := keyword[i]
				if c == '!' || c == '%' || c == '_' {
					escaped = append(escaped, '!')
				}
				escaped = append(escaped, c)
			}
			searchPattern := "%" + string(escaped) + "%"
			query = query.Where("title LIKE ? ESCAPE '!' OR content LIKE ? ESCAPE '!'", searchPattern, searchPattern)
		}
	}

	// Check the count error: ignoring it would pair a page of results with a
	// bogus total of 0.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := query.Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error
	return posts, total, err
}
// GetFollowingPosts returns one page of published posts written by users
// that userID follows, newest first, together with the total count.
//
// page is 1-based. User and Images are eager-loaded. Errors from either the
// count query or the page query are returned.
func (r *PostRepository) GetFollowingPosts(userID string, page, pageSize int) ([]*model.Post, int64, error) {
	var posts []*model.Post
	var total int64

	// Subquery: IDs of every user that userID follows.
	subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")

	// Check the count error: ignoring it would pair a page of results with a
	// bogus total of 0.
	if err := r.db.Model(&model.Post{}).
		Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := r.db.Where("user_id IN (?) AND status = ?", subQuery, model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("created_at DESC").
		Find(&posts).Error
	return posts, total, err
}
// GetHotPosts returns one page of published posts ordered by their
// precomputed hot_score (highest first), with ties broken by created_at
// (newest first), together with the total number of published posts.
//
// hot_score is maintained by Like, Unlike and IncrementViews as
// likes*2 + comments*3 + views*0.1. Sorting on the stored column avoids
// evaluating that expression on every request. page is 1-based. Errors from
// either the count query or the page query are returned.
func (r *PostRepository) GetHotPosts(page, pageSize int) ([]*model.Post, int64, error) {
	var posts []*model.Post
	var total int64

	// Check the count error: ignoring it would pair a page of results with a
	// bogus total of 0.
	if err := r.db.Model(&model.Post{}).
		Where("status = ?", model.PostStatusPublished).
		Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := r.db.Where("status = ?", model.PostStatusPublished).
		Preload("User").Preload("Images").
		Offset(offset).Limit(pageSize).
		Order("hot_score DESC, created_at DESC").
		Find(&posts).Error
	return posts, total, err
}
// GetByIDs fetches the published posts among ids and returns them in the
// same order as ids.
//
// IDs that do not exist or are not published are skipped. An ID repeated in
// ids yields the same *model.Post more than once. An empty ids returns an
// empty, non-nil slice without querying. User and Images are eager-loaded.
func (r *PostRepository) GetByIDs(ids []string) ([]*model.Post, error) {
	if len(ids) == 0 {
		return []*model.Post{}, nil
	}

	var fetched []*model.Post
	if err := r.db.Preload("User").Preload("Images").
		Where("id IN ? AND status = ?", ids, model.PostStatusPublished).
		Find(&fetched).Error; err != nil {
		return nil, err
	}

	// The IN query returns rows in arbitrary order; index them so the
	// caller's ordering can be restored.
	byID := make(map[string]*model.Post, len(fetched))
	for _, p := range fetched {
		byID[p.ID] = p
	}

	result := make([]*model.Post, 0, len(ids))
	for _, id := range ids {
		p, ok := byID[id]
		if !ok {
			continue
		}
		result = append(result, p)
	}
	return result, nil
}