405 lines
12 KiB
Go
405 lines
12 KiB
Go
|
|
package repository
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"carrot_bbs/internal/model"
|
|||
|
|
"fmt"
|
|||
|
|
|
|||
|
|
"gorm.io/gorm"
|
|||
|
|
"gorm.io/gorm/clause"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// UserRepository 用户仓储
|
|||
|
|
type UserRepository struct {
|
|||
|
|
db *gorm.DB
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewUserRepository 创建用户仓储
|
|||
|
|
func NewUserRepository(db *gorm.DB) *UserRepository {
|
|||
|
|
return &UserRepository{db: db}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create 创建用户
|
|||
|
|
func (r *UserRepository) Create(user *model.User) error {
|
|||
|
|
return r.db.Create(user).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByID 根据ID获取用户
|
|||
|
|
func (r *UserRepository) GetByID(id string) (*model.User, error) {
|
|||
|
|
var user model.User
|
|||
|
|
err := r.db.First(&user, "id = ?", id).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &user, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByUsername 根据用户名获取用户
|
|||
|
|
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
|
|||
|
|
var user model.User
|
|||
|
|
err := r.db.First(&user, "username = ?", username).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &user, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByEmail 根据邮箱获取用户
|
|||
|
|
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
|
|||
|
|
var user model.User
|
|||
|
|
err := r.db.First(&user, "email = ?", email).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &user, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByPhone 根据手机号获取用户
|
|||
|
|
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
|
|||
|
|
var user model.User
|
|||
|
|
err := r.db.First(&user, "phone = ?", phone).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &user, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Update 更新用户
|
|||
|
|
func (r *UserRepository) Update(user *model.User) error {
|
|||
|
|
return r.db.Save(user).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Delete 删除用户
|
|||
|
|
func (r *UserRepository) Delete(id string) error {
|
|||
|
|
return r.db.Delete(&model.User{}, "id = ?", id).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// List 分页获取用户列表
|
|||
|
|
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
|
|||
|
|
var users []*model.User
|
|||
|
|
var total int64
|
|||
|
|
|
|||
|
|
r.db.Model(&model.User{}).Count(&total)
|
|||
|
|
|
|||
|
|
offset := (page - 1) * pageSize
|
|||
|
|
err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error
|
|||
|
|
|
|||
|
|
return users, total, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetFollowers 获取用户粉丝
|
|||
|
|
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
|||
|
|
var users []*model.User
|
|||
|
|
var total int64
|
|||
|
|
|
|||
|
|
subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")
|
|||
|
|
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
|||
|
|
|
|||
|
|
offset := (page - 1) * pageSize
|
|||
|
|
err := r.db.Where("id IN (?)", subQuery).
|
|||
|
|
Order("created_at DESC, id DESC").
|
|||
|
|
Offset(offset).Limit(pageSize).
|
|||
|
|
Find(&users).Error
|
|||
|
|
|
|||
|
|
return users, total, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetFollowing 获取用户关注
|
|||
|
|
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
|||
|
|
var users []*model.User
|
|||
|
|
var total int64
|
|||
|
|
|
|||
|
|
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
|
|||
|
|
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
|||
|
|
|
|||
|
|
offset := (page - 1) * pageSize
|
|||
|
|
err := r.db.Where("id IN (?)", subQuery).
|
|||
|
|
Order("created_at DESC, id DESC").
|
|||
|
|
Offset(offset).Limit(pageSize).
|
|||
|
|
Find(&users).Error
|
|||
|
|
|
|||
|
|
return users, total, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// CreateFollow 创建关注关系
|
|||
|
|
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
|
|||
|
|
return r.db.Create(follow).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteFollow 删除关注关系
|
|||
|
|
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
|
|||
|
|
return r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID).Delete(&model.Follow{}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IsFollowing 检查是否关注了某用户
|
|||
|
|
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
|
|||
|
|
var count int64
|
|||
|
|
err := r.db.Model(&model.Follow{}).Where("follower_id = ? AND following_id = ?", followerID, followingID).Count(&count).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return false, err
|
|||
|
|
}
|
|||
|
|
return count > 0, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IncrementFollowersCount 增加用户粉丝数
|
|||
|
|
func (r *UserRepository) IncrementFollowersCount(userID string) error {
|
|||
|
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
|||
|
|
UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DecrementFollowersCount 减少用户粉丝数
|
|||
|
|
func (r *UserRepository) DecrementFollowersCount(userID string) error {
|
|||
|
|
return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID).
|
|||
|
|
UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IncrementFollowingCount 增加用户关注数
|
|||
|
|
func (r *UserRepository) IncrementFollowingCount(userID string) error {
|
|||
|
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
|||
|
|
UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DecrementFollowingCount 减少用户关注数
|
|||
|
|
func (r *UserRepository) DecrementFollowingCount(userID string) error {
|
|||
|
|
return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID).
|
|||
|
|
UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RefreshFollowersCount 刷新用户粉丝数(通过实际计数)
|
|||
|
|
func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
|||
|
|
var count int64
|
|||
|
|
err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
|||
|
|
UpdateColumn("followers_count", count).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetPostsCount 获取用户帖子数(实时计算)
|
|||
|
|
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
|||
|
|
var count int64
|
|||
|
|
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
|||
|
|
return count, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetPostsCountBatch 批量获取用户帖子数(实时计算)
|
|||
|
|
// 返回 map[userID]postsCount
|
|||
|
|
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
|
|||
|
|
result := make(map[string]int64)
|
|||
|
|
if len(userIDs) == 0 {
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 初始化所有用户ID的计数为0
|
|||
|
|
for _, userID := range userIDs {
|
|||
|
|
result[userID] = 0
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 使用 GROUP BY 一次性查询所有用户的帖子数
|
|||
|
|
type CountResult struct {
|
|||
|
|
UserID string
|
|||
|
|
Count int64
|
|||
|
|
}
|
|||
|
|
var counts []CountResult
|
|||
|
|
err := r.db.Model(&model.Post{}).
|
|||
|
|
Select("user_id, count(*) as count").
|
|||
|
|
Where("user_id IN ?", userIDs).
|
|||
|
|
Group("user_id").
|
|||
|
|
Scan(&counts).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 更新查询结果
|
|||
|
|
for _, c := range counts {
|
|||
|
|
result[c.UserID] = c.Count
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RefreshFollowingCount 刷新用户关注数(通过实际计数)
|
|||
|
|
func (r *UserRepository) RefreshFollowingCount(userID string) error {
|
|||
|
|
var count int64
|
|||
|
|
err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
|||
|
|
UpdateColumn("following_count", count).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IsBlocked 检查拉黑关系是否存在(blocker -> blocked)
|
|||
|
|
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
|
|||
|
|
var count int64
|
|||
|
|
err := r.db.Model(&model.UserBlock{}).
|
|||
|
|
Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
|||
|
|
Count(&count).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return false, err
|
|||
|
|
}
|
|||
|
|
return count > 0, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IsBlockedEitherDirection 检查是否任一方向存在拉黑
|
|||
|
|
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
|
|||
|
|
var count int64
|
|||
|
|
err := r.db.Model(&model.UserBlock{}).
|
|||
|
|
Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
|
|||
|
|
userA, userB, userB, userA).
|
|||
|
|
Count(&count).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return false, err
|
|||
|
|
}
|
|||
|
|
return count > 0, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BlockUserAndCleanupRelations 拉黑用户并清理双向关注关系(事务)
|
|||
|
|
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
|
|||
|
|
return r.db.Transaction(func(tx *gorm.DB) error {
|
|||
|
|
block := &model.UserBlock{
|
|||
|
|
BlockerID: blockerID,
|
|||
|
|
BlockedID: blockedID,
|
|||
|
|
}
|
|||
|
|
if err := tx.Clauses(clause.OnConflict{
|
|||
|
|
Columns: []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
|
|||
|
|
DoNothing: true,
|
|||
|
|
}).Create(block).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID).
|
|||
|
|
Delete(&model.Follow{}).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID).
|
|||
|
|
Delete(&model.Follow{}).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, uid := range []string{blockerID, blockedID} {
|
|||
|
|
var followersCount int64
|
|||
|
|
if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followersCount).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
|||
|
|
UpdateColumn("followers_count", followersCount).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var followingCount int64
|
|||
|
|
if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&followingCount).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
|||
|
|
UpdateColumn("following_count", followingCount).Error; err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UnblockUser 取消拉黑
|
|||
|
|
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
|
|||
|
|
return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
|||
|
|
Delete(&model.UserBlock{}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetBlockedUsers 获取用户黑名单列表
|
|||
|
|
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
|||
|
|
var users []*model.User
|
|||
|
|
var total int64
|
|||
|
|
|
|||
|
|
subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")
|
|||
|
|
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
|||
|
|
|
|||
|
|
offset := (page - 1) * pageSize
|
|||
|
|
err := r.db.Where("id IN (?)", subQuery).
|
|||
|
|
Order("created_at DESC, id DESC").
|
|||
|
|
Offset(offset).
|
|||
|
|
Limit(pageSize).
|
|||
|
|
Find(&users).Error
|
|||
|
|
|
|||
|
|
return users, total, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Search 搜索用户
|
|||
|
|
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
|||
|
|
var users []*model.User
|
|||
|
|
var total int64
|
|||
|
|
|
|||
|
|
query := r.db.Model(&model.User{})
|
|||
|
|
|
|||
|
|
// 搜索用户名、昵称、简介
|
|||
|
|
if keyword != "" {
|
|||
|
|
if r.db.Dialector.Name() == "postgres" {
|
|||
|
|
query = query.Where(
|
|||
|
|
"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
|
|||
|
|
keyword,
|
|||
|
|
)
|
|||
|
|
} else {
|
|||
|
|
searchPattern := "%" + keyword + "%"
|
|||
|
|
query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
query.Count(&total)
|
|||
|
|
|
|||
|
|
offset := (page - 1) * pageSize
|
|||
|
|
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
|
|||
|
|
|
|||
|
|
return users, total, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetMutualFollowStatus 批量获取双向关注状态
|
|||
|
|
// 返回 map[userID][isFollowing, isFollowingMe]
|
|||
|
|
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
|||
|
|
result := make(map[string][2]bool)
|
|||
|
|
|
|||
|
|
if len(targetUserIDs) == 0 {
|
|||
|
|
return result, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUserID=%s, targetUserIDs=%v\n", currentUserID, targetUserIDs)
|
|||
|
|
|
|||
|
|
// 初始化所有目标用户为未关注状态
|
|||
|
|
for _, userID := range targetUserIDs {
|
|||
|
|
result[userID] = [2]bool{false, false}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 查询当前用户关注了哪些目标用户 (isFollowing)
|
|||
|
|
var followingIDs []string
|
|||
|
|
err := r.db.Model(&model.Follow{}).
|
|||
|
|
Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
|
|||
|
|
Pluck("following_id", &followingIDs).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: currentUser follows these targets: %v\n", followingIDs)
|
|||
|
|
for _, id := range followingIDs {
|
|||
|
|
status := result[id]
|
|||
|
|
status[0] = true
|
|||
|
|
result[id] = status
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 查询哪些目标用户关注了当前用户 (isFollowingMe)
|
|||
|
|
var followerIDs []string
|
|||
|
|
err = r.db.Model(&model.Follow{}).
|
|||
|
|
Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
|
|||
|
|
Pluck("follower_id", &followerIDs).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: these targets follow currentUser: %v\n", followerIDs)
|
|||
|
|
for _, id := range followerIDs {
|
|||
|
|
status := result[id]
|
|||
|
|
status[1] = true
|
|||
|
|
result[id] = status
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fmt.Printf("[DEBUG] GetMutualFollowStatus: final result=%v\n", result)
|
|||
|
|
return result, nil
|
|||
|
|
}
|