package repository

import (
	"fmt"

	"carrot_bbs/internal/model"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// UserRepository groups the persistence operations for users and for the
// follow and block relations between them. All queries run on the GORM
// handle supplied at construction time.
type UserRepository struct {
	db *gorm.DB
}

// NewUserRepository returns a UserRepository that issues its queries on db.
func NewUserRepository(db *gorm.DB) *UserRepository {
	return &UserRepository{db: db}
}

// Create inserts user as a new row.
func (r *UserRepository) Create(user *model.User) error {
	return r.db.Create(user).Error
}

// firstBy loads the first user matching cond, a condition with exactly one
// placeholder that is bound to value. The error from GORM's First is returned
// unchanged so callers can keep comparing it against GORM's sentinel errors.
func (r *UserRepository) firstBy(cond, value string) (*model.User, error) {
	var user model.User
	if err := r.db.First(&user, cond, value).Error; err != nil {
		return nil, err
	}
	return &user, nil
}

// GetByID returns the user whose id column equals id.
func (r *UserRepository) GetByID(id string) (*model.User, error) {
	return r.firstBy("id = ?", id)
}

// GetByUsername returns the user whose username column equals username.
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
	return r.firstBy("username = ?", username)
}

// GetByEmail returns the user whose email column equals email.
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
	return r.firstBy("email = ?", email)
}

// GetByPhone returns the user whose phone column equals phone.
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
	return r.firstBy("phone = ?", phone)
}

// Update persists user through GORM's Save.
func (r *UserRepository) Update(user *model.User) error {
	return r.db.Save(user).Error
}

// Delete removes the user whose id column equals id through GORM's Delete.
func (r *UserRepository) Delete(id string) error {
	return r.db.Delete(&model.User{}, "id = ?", id).Error
}

// List returns one page of users ordered by created_at and id, newest first,
// together with the total number of users. page is 1-based: the offset is
// (page-1)*pageSize.
//
// A failure of the count query is returned as an error instead of being
// reported as a total of zero.
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
	var total int64
	if err := r.db.Model(&model.User{}).Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("count users: %w", err)
	}

	var users []*model.User
	err := r.db.Order("created_at DESC, id DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&users).Error
	return users, total, err
}

// listUsersIn returns one page of the users whose id is produced by idQuery,
// ordered by created_at and id (newest first), plus the total number of such
// users. what names the listing in the error returned when counting fails.
func (r *UserRepository) listUsersIn(idQuery *gorm.DB, what string, page, pageSize int) ([]*model.User, int64, error) {
	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", idQuery).Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("count %s: %w", what, err)
	}

	var users []*model.User
	err := r.db.Where("id IN (?)", idQuery).
		Order("created_at DESC, id DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&users).Error
	return users, total, err
}

// GetFollowers returns one page of the users that follow userID and the total
// number of such users. page is 1-based.
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
	followerIDs := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")
	return r.listUsersIn(followerIDs, "followers", page, pageSize)
}

// GetFollowing returns one page of the users that userID follows and the
// total number of such users. page is 1-based.
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
	followingIDs := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
	return r.listUsersIn(followingIDs, "following", page, pageSize)
}

// CreateFollow inserts follow as a new follow relation.
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
	return r.db.Create(follow).Error
}

// DeleteFollow removes the follow relation from followerID to followingID.
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
	return r.db.
		Where("follower_id = ? AND following_id = ?", followerID, followingID).
		Delete(&model.Follow{}).Error
}

// IsFollowing reports whether a follow relation from followerID to
// followingID exists.
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
	var count int64
	err := r.db.Model(&model.Follow{}).
		Where("follower_id = ? AND following_id = ?", followerID, followingID).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// IncrementFollowersCount adds one to the followers_count column of userID.
func (r *UserRepository) IncrementFollowersCount(userID string) error {
	return r.db.Model(&model.User{}).Where("id = ?", userID).
		UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error
}

// DecrementFollowersCount subtracts one from the followers_count column of
// userID. The extra "followers_count > 0" condition keeps the counter from
// dropping below zero.
func (r *UserRepository) DecrementFollowersCount(userID string) error {
	return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID).
		UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error
}

// IncrementFollowingCount adds one to the following_count column of userID.
func (r *UserRepository) IncrementFollowingCount(userID string) error {
	return r.db.Model(&model.User{}).Where("id = ?", userID).
		UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error
}

// DecrementFollowingCount subtracts one from the following_count column of
// userID. The extra "following_count > 0" condition keeps the counter from
// dropping below zero.
func (r *UserRepository) DecrementFollowingCount(userID string) error {
	return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID).
		UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error
}

// syncFollowersCount counts, on db, the follow rows whose following_id is
// userID and writes that number to the user's followers_count column.
func (r *UserRepository) syncFollowersCount(db *gorm.DB, userID string) error {
	var count int64
	if err := db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error; err != nil {
		return err
	}
	return db.Model(&model.User{}).Where("id = ?", userID).
		UpdateColumn("followers_count", count).Error
}

// syncFollowingCount counts, on db, the follow rows whose follower_id is
// userID and writes that number to the user's following_count column.
func (r *UserRepository) syncFollowingCount(db *gorm.DB, userID string) error {
	var count int64
	if err := db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error; err != nil {
		return err
	}
	return db.Model(&model.User{}).Where("id = ?", userID).
		UpdateColumn("following_count", count).Error
}

// RefreshFollowersCount recomputes followers_count for userID from the actual
// follow rows.
func (r *UserRepository) RefreshFollowersCount(userID string) error {
	return r.syncFollowersCount(r.db, userID)
}

// GetPostsCount returns the number of posts whose user_id is userID, counted
// at call time.
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
	var count int64
	err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
	return count, err
}

// GetPostsCountBatch returns the number of posts per user for userIDs, counted
// at call time with a single GROUP BY query. The returned map has an entry for
// every requested ID; users without posts map to 0. An empty userIDs yields an
// empty map without querying.
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
	result := make(map[string]int64, len(userIDs))
	if len(userIDs) == 0 {
		return result, nil
	}

	// Seed every requested ID with zero: the grouped query only returns rows
	// for users that have at least one post.
	for _, userID := range userIDs {
		result[userID] = 0
	}

	type CountResult struct {
		UserID string
		Count  int64
	}
	var counts []CountResult
	err := r.db.Model(&model.Post{}).
		Select("user_id, count(*) as count").
		Where("user_id IN ?", userIDs).
		Group("user_id").
		Scan(&counts).Error
	if err != nil {
		return nil, err
	}

	for _, c := range counts {
		result[c.UserID] = c.Count
	}
	return result, nil
}

// RefreshFollowingCount recomputes following_count for userID from the actual
// follow rows.
func (r *UserRepository) RefreshFollowingCount(userID string) error {
	return r.syncFollowingCount(r.db, userID)
}

// IsBlocked reports whether blockerID has blocked blockedID. Only this
// direction is checked.
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
	var count int64
	err := r.db.Model(&model.UserBlock{}).
		Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// IsBlockedEitherDirection reports whether userA has blocked userB or userB
// has blocked userA.
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
	var count int64
	err := r.db.Model(&model.UserBlock{}).
		Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
			userA, userB, userB, userA).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// BlockUserAndCleanupRelations records that blockerID blocks blockedID and
// removes the follow relations between the two users in both directions, all
// inside one transaction. Afterwards followers_count and following_count of
// both users are recomputed from the remaining follow rows.
//
// The block row is inserted with ON CONFLICT DO NOTHING on
// (blocker_id, blocked_id), so an already existing block is not an error.
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		block := &model.UserBlock{
			BlockerID: blockerID,
			BlockedID: blockedID,
		}
		ignoreDuplicate := clause.OnConflict{
			Columns:   []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
			DoNothing: true,
		}
		if err := tx.Clauses(ignoreDuplicate).Create(block).Error; err != nil {
			return err
		}

		// Drop the follow relation in both directions.
		if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID).
			Delete(&model.Follow{}).Error; err != nil {
			return err
		}
		if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID).
			Delete(&model.Follow{}).Error; err != nil {
			return err
		}

		// Recount on tx so the numbers reflect the deletions above.
		for _, uid := range []string{blockerID, blockedID} {
			if err := r.syncFollowersCount(tx, uid); err != nil {
				return err
			}
			if err := r.syncFollowingCount(tx, uid); err != nil {
				return err
			}
		}
		return nil
	})
}

// UnblockUser removes the block of blockedID by blockerID.
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
	return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
		Delete(&model.UserBlock{}).Error
}

// GetBlockedUsers returns one page of the users blocked by blockerID and the
// total number of such users. page is 1-based.
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
	blockedIDs := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")
	return r.listUsersIn(blockedIDs, "blocked users", page, pageSize)
}

// Search returns one page of users matching keyword in username, nickname or
// bio, together with the total number of matches. page is 1-based. An empty
// keyword matches every user.
//
// On the "postgres" dialect the match is a full-text query using the 'simple'
// configuration; on any other dialect it is a LIKE '%keyword%' on each of the
// three columns. Results are ordered by created_at and id, newest first, so
// that rows sharing a timestamp keep a stable position across pages.
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
	query := r.db.Model(&model.User{})

	if keyword != "" {
		if r.db.Dialector.Name() == "postgres" {
			query = query.Where(
				"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
				keyword,
			)
		} else {
			searchPattern := "%" + keyword + "%"
			query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?",
				searchPattern, searchPattern, searchPattern)
		}
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("count search results: %w", err)
	}

	var users []*model.User
	err := query.
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Order("created_at DESC, id DESC").
		Find(&users).Error
	return users, total, err
}

// GetMutualFollowStatus returns, for every ID in targetUserIDs, the follow
// relation between that user and currentUserID. In each map value, index 0 is
// true when currentUserID follows the target and index 1 is true when the
// target follows currentUserID. Every requested ID is present in the result.
// An empty targetUserIDs yields an empty map without querying.
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
	result := make(map[string][2]bool, len(targetUserIDs))
	if len(targetUserIDs) == 0 {
		return result, nil
	}

	// Start from "no relation" so targets without any follow row still appear.
	for _, userID := range targetUserIDs {
		result[userID] = [2]bool{false, false}
	}

	// Targets that the current user follows.
	var followingIDs []string
	err := r.db.Model(&model.Follow{}).
		Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
		Pluck("following_id", &followingIDs).Error
	if err != nil {
		return nil, err
	}
	for _, id := range followingIDs {
		status := result[id]
		status[0] = true
		result[id] = status
	}

	// Targets that follow the current user.
	var followerIDs []string
	err = r.db.Model(&model.Follow{}).
		Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
		Pluck("follower_id", &followerIDs).Error
	if err != nil {
		return nil, err
	}
	for _, id := range followerIDs {
		status := result[id]
		status[1] = true
		result[id] = status
	}

	return result, nil
}