Files

401 lines
12 KiB
Go
Raw Permalink Normal View History

package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UserRepository is the persistence layer for users, follow relationships
// and user blocks. Every query it issues goes through the wrapped *gorm.DB.
type UserRepository struct{ db *gorm.DB }
// NewUserRepository returns a UserRepository that issues its queries through
// db. The handle is stored as given; nothing is queried here.
func NewUserRepository(db *gorm.DB) *UserRepository {
	repo := &UserRepository{db: db}
	return repo
}
// Create inserts user as a new row in the users table and returns the
// insert error, if any, unchanged.
func (r *UserRepository) Create(user *model.User) error {
	result := r.db.Create(user)
	return result.Error
}
// GetByID loads the first user whose id equals id.
//
// The GORM error is returned unchanged (gorm.ErrRecordNotFound when no row
// matches), and the returned user is nil whenever err is non-nil.
func (r *UserRepository) GetByID(id string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// GetByUsername loads the first user whose username equals username.
//
// The GORM error is returned unchanged (gorm.ErrRecordNotFound when no row
// matches), and the returned user is nil whenever err is non-nil.
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "username = ?", username).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// GetByEmail loads the first user whose email equals email.
//
// The GORM error is returned unchanged (gorm.ErrRecordNotFound when no row
// matches), and the returned user is nil whenever err is non-nil.
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "email = ?", email).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// GetByPhone loads the first user whose phone equals phone.
//
// The GORM error is returned unchanged (gorm.ErrRecordNotFound when no row
// matches), and the returned user is nil whenever err is non-nil.
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "phone = ?", phone).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// Update persists user with GORM's Save, which writes all of the model's
// fields rather than only the changed ones.
func (r *UserRepository) Update(user *model.User) error {
	if err := r.db.Save(user).Error; err != nil {
		return err
	}
	return nil
}
// Delete removes the user whose id equals id. How the row is removed (hard
// delete or soft delete) depends on how model.User is declared.
func (r *UserRepository) Delete(id string) error {
	res := r.db.Delete(&model.User{}, "id = ?", id)
	return res.Error
}
// List returns one page of users together with the total number of users.
//
// page is 1-based and pageSize is the maximum number of users returned.
// Users are ordered by created_at descending, with id descending as a
// tie-breaker so pagination stays stable when rows share a created_at.
//
// The total and the page come from two separate queries. If either query
// fails, its error is returned with a nil slice and a zero total. (Previously
// a failed count was ignored and reported as total 0.)
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
	var total int64
	if err := r.db.Model(&model.User{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	if err := r.db.Order("created_at DESC, id DESC").
		Offset(offset).Limit(pageSize).
		Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// GetFollowers returns one page of the users who follow userID, together with
// the total number of such users.
//
// A follower is any user whose id appears as follower_id in a follows row
// whose following_id is userID. page is 1-based. Results are ordered by the
// users' created_at descending, with id descending as a tie-breaker.
//
// The total and the page come from two separate queries. If either query
// fails, its error is returned with a nil slice and a zero total. (Previously
// a failed count was ignored and reported as total 0.)
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
	// The subquery is reused by both statements; GORM renders it afresh each time.
	followerIDs := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")

	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", followerIDs).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	if err := r.db.Where("id IN (?)", followerIDs).
		Order("created_at DESC, id DESC").
		Offset(offset).Limit(pageSize).
		Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// GetFollowing returns one page of the users that userID follows, together
// with the total number of such users.
//
// A followed user is any user whose id appears as following_id in a follows
// row whose follower_id is userID. page is 1-based. Results are ordered by the
// users' created_at descending, with id descending as a tie-breaker.
//
// The total and the page come from two separate queries. If either query
// fails, its error is returned with a nil slice and a zero total. (Previously
// a failed count was ignored and reported as total 0.)
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
	// The subquery is reused by both statements; GORM renders it afresh each time.
	followingIDs := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")

	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", followingIDs).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	if err := r.db.Where("id IN (?)", followingIDs).
		Order("created_at DESC, id DESC").
		Offset(offset).Limit(pageSize).
		Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// CreateFollow inserts follow as a new follows row. It does not touch the
// cached followers_count / following_count columns on users.
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
	tx := r.db.Create(follow)
	return tx.Error
}
// DeleteFollow removes the follows row(s) in which followerID follows
// followingID. It does not touch the cached counter columns on users.
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
	pair := r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID)
	return pair.Delete(&model.Follow{}).Error
}
// IsFollowing reports whether at least one follows row records followerID
// following followingID.
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
	var n int64
	tx := r.db.Model(&model.Follow{}).
		Where("follower_id = ? AND following_id = ?", followerID, followingID).
		Count(&n)
	if tx.Error != nil {
		return false, tx.Error
	}
	return n > 0, nil
}
// IncrementFollowersCount adds one to the followers_count column of user
// userID, using a single UPDATE statement
// (followers_count = followers_count + 1). UpdateColumn skips GORM hooks and
// does not modify updated_at.
func (r *UserRepository) IncrementFollowersCount(userID string) error {
	bump := gorm.Expr("followers_count + 1")
	return r.db.Model(&model.User{}).
		Where("id = ?", userID).
		UpdateColumn("followers_count", bump).
		Error
}
// DecrementFollowersCount subtracts one from the followers_count column of
// user userID, using a single UPDATE statement.
//
// The "followers_count > 0" condition keeps the counter from going negative.
// When the counter is already 0, no row matches and no error is returned.
func (r *UserRepository) DecrementFollowersCount(userID string) error {
	drop := gorm.Expr("followers_count - 1")
	return r.db.Model(&model.User{}).
		Where("id = ? AND followers_count > 0", userID).
		UpdateColumn("followers_count", drop).
		Error
}
// IncrementFollowingCount adds one to the following_count column of user
// userID, using a single UPDATE statement
// (following_count = following_count + 1). UpdateColumn skips GORM hooks and
// does not modify updated_at.
func (r *UserRepository) IncrementFollowingCount(userID string) error {
	bump := gorm.Expr("following_count + 1")
	return r.db.Model(&model.User{}).
		Where("id = ?", userID).
		UpdateColumn("following_count", bump).
		Error
}
// DecrementFollowingCount subtracts one from the following_count column of
// user userID, using a single UPDATE statement.
//
// The "following_count > 0" condition keeps the counter from going negative.
// When the counter is already 0, no row matches and no error is returned.
func (r *UserRepository) DecrementFollowingCount(userID string) error {
	drop := gorm.Expr("following_count - 1")
	return r.db.Model(&model.User{}).
		Where("id = ? AND following_count > 0", userID).
		UpdateColumn("following_count", drop).
		Error
}
// RefreshFollowersCount recomputes user userID's followers_count from the
// follows table, counting the rows whose following_id is userID, and writes
// the result into the cached column.
//
// The count and the write are two separate statements, not wrapped in a
// transaction here.
func (r *UserRepository) RefreshFollowersCount(userID string) error {
	var actual int64
	if err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&actual).Error; err != nil {
		return err
	}
	target := r.db.Model(&model.User{}).Where("id = ?", userID)
	return target.UpdateColumn("followers_count", actual).Error
}
// GetPostsCount counts, live from the posts table, the posts authored by
// userID whose status is model.PostStatusPublished. No cached counter is used.
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
	var n int64
	tx := r.db.Model(&model.Post{}).
		Where("user_id = ? AND status = ?", userID, model.PostStatusPublished).
		Count(&n)
	return n, tx.Error
}
// GetPostsCountBatch counts published posts (status model.PostStatusPublished)
// for several authors at once, live from the posts table.
//
// The result maps each requested user ID to its post count. Every ID in
// userIDs is present, and IDs with no published posts map to 0. An empty
// userIDs returns an empty, non-nil map without querying. All counts come
// from a single GROUP BY user_id query.
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
	counts := make(map[string]int64, len(userIDs))
	if len(userIDs) == 0 {
		return counts, nil
	}
	// Seed every requested ID so authors without posts still appear as 0.
	for _, id := range userIDs {
		counts[id] = 0
	}

	// One row per author that has at least one published post.
	type postCount struct {
		UserID string
		Count  int64
	}
	var rows []postCount
	if err := r.db.Model(&model.Post{}).
		Select("user_id, count(*) as count").
		Where("user_id IN ? AND status = ?", userIDs, model.PostStatusPublished).
		Group("user_id").
		Scan(&rows).Error; err != nil {
		return nil, err
	}

	for _, row := range rows {
		counts[row.UserID] = row.Count
	}
	return counts, nil
}
// RefreshFollowingCount recomputes user userID's following_count from the
// follows table, counting the rows whose follower_id is userID, and writes
// the result into the cached column.
//
// The count and the write are two separate statements, not wrapped in a
// transaction here.
func (r *UserRepository) RefreshFollowingCount(userID string) error {
	var actual int64
	if err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&actual).Error; err != nil {
		return err
	}
	target := r.db.Model(&model.User{}).Where("id = ?", userID)
	return target.UpdateColumn("following_count", actual).Error
}
// IsBlocked reports whether blockerID has blocked blockedID. Only that one
// direction (blocker -> blocked) is checked.
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
	var n int64
	tx := r.db.Model(&model.UserBlock{}).
		Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
		Count(&n)
	if tx.Error != nil {
		return false, tx.Error
	}
	return n > 0, nil
}
// IsBlockedEitherDirection reports whether a block exists between userA and
// userB in either direction (A blocked B, or B blocked A). Both directions
// are checked with a single query.
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
	const eitherWay = "(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)"
	var n int64
	tx := r.db.Model(&model.UserBlock{}).
		Where(eitherWay, userA, userB, userB, userA).
		Count(&n)
	if tx.Error != nil {
		return false, tx.Error
	}
	return n > 0, nil
}
// BlockUserAndCleanupRelations records that blockerID blocks blockedID and,
// within the same transaction, removes every follow between the two users in
// both directions.
//
// The block row is inserted with an ON CONFLICT (blocker_id, blocked_id)
// do-nothing clause, so re-blocking an existing pair is not an error. After
// the follow rows are deleted, followers_count and following_count for both
// users are recomputed from the follows table. Returning any error from the
// transaction callback rolls the whole operation back.
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		insertIgnoringDup := tx.Clauses(clause.OnConflict{
			Columns:   []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
			DoNothing: true,
		})
		if err := insertIgnoringDup.Create(&model.UserBlock{
			BlockerID: blockerID,
			BlockedID: blockedID,
		}).Error; err != nil {
			return err
		}

		// Drop blocker -> blocked first, then blocked -> blocker.
		for _, edge := range [][2]string{{blockerID, blockedID}, {blockedID, blockerID}} {
			if err := tx.Where("follower_id = ? AND following_id = ?", edge[0], edge[1]).
				Delete(&model.Follow{}).Error; err != nil {
				return err
			}
		}

		// syncCounters rewrites uid's cached follow counters from the follows
		// table: followers_count first, then following_count.
		syncCounters := func(uid string) error {
			var followers, following int64
			if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followers).Error; err != nil {
				return err
			}
			if err := tx.Model(&model.User{}).Where("id = ?", uid).
				UpdateColumn("followers_count", followers).Error; err != nil {
				return err
			}
			if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&following).Error; err != nil {
				return err
			}
			return tx.Model(&model.User{}).Where("id = ?", uid).
				UpdateColumn("following_count", following).Error
		}

		if err := syncCounters(blockerID); err != nil {
			return err
		}
		return syncCounters(blockedID)
	})
}
// UnblockUser deletes the block that blockerID placed on blockedID. The
// reverse direction is not touched, and follow relationships removed when
// blocking are not restored.
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
	pair := r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID)
	return pair.Delete(&model.UserBlock{}).Error
}
// GetBlockedUsers returns one page of the users that blockerID has blocked,
// together with the total number of such users.
//
// Blocked users are those whose id appears as blocked_id in a user-block row
// whose blocker_id is blockerID. page is 1-based. Results are ordered by the
// users' created_at descending, with id descending as a tie-breaker.
//
// The total and the page come from two separate queries. If either query
// fails, its error is returned with a nil slice and a zero total. (Previously
// a failed count was ignored and reported as total 0.)
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
	// The subquery is reused by both statements; GORM renders it afresh each time.
	blockedIDs := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")

	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", blockedIDs).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	if err := r.db.Where("id IN (?)", blockedIDs).
		Order("created_at DESC, id DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// Search returns one page of users matching keyword, together with the total
// number of matches.
//
// An empty keyword matches every user. Matching is done over username,
// nickname and bio:
//   - On PostgreSQL, full-text search is used (to_tsvector/plainto_tsquery
//     with the 'simple' configuration).
//   - On any other dialect, a substring LIKE match is used. The LIKE
//     metacharacters '%' and '_' in keyword, and the escape character '!'
//     itself, are escaped, so a keyword such as "50%" or "_" matches
//     literally instead of acting as a wildcard.
//
// page is 1-based. Results are ordered by created_at descending with id
// descending as a tie-breaker, the same order used by the other paginated
// queries, so pages stay stable when rows share a created_at. If either the
// count or the page query fails, its error is returned with a nil slice and a
// zero total.
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
	query := r.db.Model(&model.User{})

	if keyword != "" {
		if r.db.Dialector.Name() == "postgres" {
			query = query.Where(
				"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
				keyword,
			)
		} else {
			// Escape LIKE metacharacters with '!'. Unlike '\', '!' needs no
			// extra quoting inside a SQL string literal, and ESCAPE is spelled
			// out explicitly because some engines (e.g. SQLite) have no default
			// LIKE escape character. Scanning byte-by-byte is UTF-8 safe
			// because '!', '%' and '_' are ASCII and never occur inside a
			// multi-byte sequence.
			escaped := make([]byte, 0, len(keyword))
			for i := 0; i < len(keyword); i++ {
				switch c := keyword[i]; c {
				case '!', '%', '_':
					escaped = append(escaped, '!', c)
				default:
					escaped = append(escaped, c)
				}
			}
			searchPattern := "%" + string(escaped) + "%"
			query = query.Where(
				"username LIKE ? ESCAPE '!' OR nickname LIKE ? ESCAPE '!' OR bio LIKE ? ESCAPE '!'",
				searchPattern, searchPattern, searchPattern,
			)
		}
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	if err := query.Offset(offset).Limit(pageSize).
		Order("created_at DESC, id DESC").
		Find(&users).Error; err != nil {
		return nil, 0, err
	}
	return users, total, nil
}
// GetMutualFollowStatus looks up, for each target user, the follow
// relationship in both directions with currentUserID.
//
// The result maps each target user ID to a [2]bool:
//   - [0] is true when currentUserID follows that user.
//   - [1] is true when that user follows currentUserID.
//
// Every ID in targetUserIDs is present, defaulting to {false, false}. An
// empty targetUserIDs returns an empty, non-nil map without querying.
// Otherwise exactly two queries are issued, one per direction.
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
	status := make(map[string][2]bool, len(targetUserIDs))
	if len(targetUserIDs) == 0 {
		return status, nil
	}
	for _, id := range targetUserIDs {
		status[id] = [2]bool{}
	}

	// mark sets the given slot to true for every listed user ID.
	mark := func(ids []string, slot int) {
		for _, id := range ids {
			entry := status[id]
			entry[slot] = true
			status[id] = entry
		}
	}

	// Direction 1: targets that the current user follows.
	var followedByMe []string
	if err := r.db.Model(&model.Follow{}).
		Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
		Pluck("following_id", &followedByMe).Error; err != nil {
		return nil, err
	}
	mark(followedByMe, 0)

	// Direction 2: targets that follow the current user.
	var followingMe []string
	if err := r.db.Model(&model.Follow{}).
		Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
		Pluck("follower_id", &followingMe).Error; err != nil {
		return nil, err
	}
	mark(followingMe, 1)

	return status, nil
}