This removes verbose trace output in handlers/services and keeps only actionable error-level logs.
399 lines
12 KiB
Go
399 lines
12 KiB
Go
package repository
|
||
|
||
import (
|
||
"carrot_bbs/internal/model"
|
||
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/clause"
|
||
)
|
||
|
||
// UserRepository 用户仓储
|
||
type UserRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewUserRepository 创建用户仓储
|
||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||
return &UserRepository{db: db}
|
||
}
|
||
|
||
// Create 创建用户
|
||
func (r *UserRepository) Create(user *model.User) error {
|
||
return r.db.Create(user).Error
|
||
}
|
||
|
||
// GetByID 根据ID获取用户
|
||
func (r *UserRepository) GetByID(id string) (*model.User, error) {
|
||
var user model.User
|
||
err := r.db.First(&user, "id = ?", id).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByUsername 根据用户名获取用户
|
||
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
|
||
var user model.User
|
||
err := r.db.First(&user, "username = ?", username).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByEmail 根据邮箱获取用户
|
||
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
|
||
var user model.User
|
||
err := r.db.First(&user, "email = ?", email).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByPhone 根据手机号获取用户
|
||
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
|
||
var user model.User
|
||
err := r.db.First(&user, "phone = ?", phone).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// Update 更新用户
|
||
func (r *UserRepository) Update(user *model.User) error {
|
||
return r.db.Save(user).Error
|
||
}
|
||
|
||
// Delete 删除用户
|
||
func (r *UserRepository) Delete(id string) error {
|
||
return r.db.Delete(&model.User{}, "id = ?", id).Error
|
||
}
|
||
|
||
// List 分页获取用户列表
|
||
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
|
||
var users []*model.User
|
||
var total int64
|
||
|
||
r.db.Model(&model.User{}).Count(&total)
|
||
|
||
offset := (page - 1) * pageSize
|
||
err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error
|
||
|
||
return users, total, err
|
||
}
|
||
|
||
// GetFollowers 获取用户粉丝
|
||
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||
var users []*model.User
|
||
var total int64
|
||
|
||
subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")
|
||
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||
|
||
offset := (page - 1) * pageSize
|
||
err := r.db.Where("id IN (?)", subQuery).
|
||
Order("created_at DESC, id DESC").
|
||
Offset(offset).Limit(pageSize).
|
||
Find(&users).Error
|
||
|
||
return users, total, err
|
||
}
|
||
|
||
// GetFollowing 获取用户关注
|
||
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
|
||
var users []*model.User
|
||
var total int64
|
||
|
||
subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")
|
||
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||
|
||
offset := (page - 1) * pageSize
|
||
err := r.db.Where("id IN (?)", subQuery).
|
||
Order("created_at DESC, id DESC").
|
||
Offset(offset).Limit(pageSize).
|
||
Find(&users).Error
|
||
|
||
return users, total, err
|
||
}
|
||
|
||
// CreateFollow 创建关注关系
|
||
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
|
||
return r.db.Create(follow).Error
|
||
}
|
||
|
||
// DeleteFollow 删除关注关系
|
||
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
|
||
return r.db.Where("follower_id = ? AND following_id = ?", followerID, followingID).Delete(&model.Follow{}).Error
|
||
}
|
||
|
||
// IsFollowing 检查是否关注了某用户
|
||
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
|
||
var count int64
|
||
err := r.db.Model(&model.Follow{}).Where("follower_id = ? AND following_id = ?", followerID, followingID).Count(&count).Error
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
// IncrementFollowersCount 增加用户粉丝数
|
||
func (r *UserRepository) IncrementFollowersCount(userID string) error {
|
||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||
UpdateColumn("followers_count", gorm.Expr("followers_count + 1")).Error
|
||
}
|
||
|
||
// DecrementFollowersCount 减少用户粉丝数
|
||
func (r *UserRepository) DecrementFollowersCount(userID string) error {
|
||
return r.db.Model(&model.User{}).Where("id = ? AND followers_count > 0", userID).
|
||
UpdateColumn("followers_count", gorm.Expr("followers_count - 1")).Error
|
||
}
|
||
|
||
// IncrementFollowingCount 增加用户关注数
|
||
func (r *UserRepository) IncrementFollowingCount(userID string) error {
|
||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||
UpdateColumn("following_count", gorm.Expr("following_count + 1")).Error
|
||
}
|
||
|
||
// DecrementFollowingCount 减少用户关注数
|
||
func (r *UserRepository) DecrementFollowingCount(userID string) error {
|
||
return r.db.Model(&model.User{}).Where("id = ? AND following_count > 0", userID).
|
||
UpdateColumn("following_count", gorm.Expr("following_count - 1")).Error
|
||
}
|
||
|
||
// RefreshFollowersCount 刷新用户粉丝数(通过实际计数)
|
||
func (r *UserRepository) RefreshFollowersCount(userID string) error {
|
||
var count int64
|
||
err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&count).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||
UpdateColumn("followers_count", count).Error
|
||
}
|
||
|
||
// GetPostsCount 获取用户帖子数(实时计算)
|
||
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
|
||
var count int64
|
||
err := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&count).Error
|
||
return count, err
|
||
}
|
||
|
||
// GetPostsCountBatch 批量获取用户帖子数(实时计算)
|
||
// 返回 map[userID]postsCount
|
||
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
|
||
result := make(map[string]int64)
|
||
if len(userIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
// 初始化所有用户ID的计数为0
|
||
for _, userID := range userIDs {
|
||
result[userID] = 0
|
||
}
|
||
|
||
// 使用 GROUP BY 一次性查询所有用户的帖子数
|
||
type CountResult struct {
|
||
UserID string
|
||
Count int64
|
||
}
|
||
var counts []CountResult
|
||
err := r.db.Model(&model.Post{}).
|
||
Select("user_id, count(*) as count").
|
||
Where("user_id IN ?", userIDs).
|
||
Group("user_id").
|
||
Scan(&counts).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 更新查询结果
|
||
for _, c := range counts {
|
||
result[c.UserID] = c.Count
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// RefreshFollowingCount 刷新用户关注数(通过实际计数)
|
||
func (r *UserRepository) RefreshFollowingCount(userID string) error {
|
||
var count int64
|
||
err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&count).Error
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return r.db.Model(&model.User{}).Where("id = ?", userID).
|
||
UpdateColumn("following_count", count).Error
|
||
}
|
||
|
||
// IsBlocked 检查拉黑关系是否存在(blocker -> blocked)
|
||
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
|
||
var count int64
|
||
err := r.db.Model(&model.UserBlock{}).
|
||
Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
||
Count(&count).Error
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
// IsBlockedEitherDirection 检查是否任一方向存在拉黑
|
||
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
|
||
var count int64
|
||
err := r.db.Model(&model.UserBlock{}).
|
||
Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
|
||
userA, userB, userB, userA).
|
||
Count(&count).Error
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return count > 0, nil
|
||
}
|
||
|
||
// BlockUserAndCleanupRelations 拉黑用户并清理双向关注关系(事务)
|
||
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
|
||
return r.db.Transaction(func(tx *gorm.DB) error {
|
||
block := &model.UserBlock{
|
||
BlockerID: blockerID,
|
||
BlockedID: blockedID,
|
||
}
|
||
if err := tx.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
|
||
DoNothing: true,
|
||
}).Create(block).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tx.Where("follower_id = ? AND following_id = ?", blockerID, blockedID).
|
||
Delete(&model.Follow{}).Error; err != nil {
|
||
return err
|
||
}
|
||
if err := tx.Where("follower_id = ? AND following_id = ?", blockedID, blockerID).
|
||
Delete(&model.Follow{}).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
for _, uid := range []string{blockerID, blockedID} {
|
||
var followersCount int64
|
||
if err := tx.Model(&model.Follow{}).Where("following_id = ?", uid).Count(&followersCount).Error; err != nil {
|
||
return err
|
||
}
|
||
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
||
UpdateColumn("followers_count", followersCount).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
var followingCount int64
|
||
if err := tx.Model(&model.Follow{}).Where("follower_id = ?", uid).Count(&followingCount).Error; err != nil {
|
||
return err
|
||
}
|
||
if err := tx.Model(&model.User{}).Where("id = ?", uid).
|
||
UpdateColumn("following_count", followingCount).Error; err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// UnblockUser 取消拉黑
|
||
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
|
||
return r.db.Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
|
||
Delete(&model.UserBlock{}).Error
|
||
}
|
||
|
||
// GetBlockedUsers 获取用户黑名单列表
|
||
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
|
||
var users []*model.User
|
||
var total int64
|
||
|
||
subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")
|
||
r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total)
|
||
|
||
offset := (page - 1) * pageSize
|
||
err := r.db.Where("id IN (?)", subQuery).
|
||
Order("created_at DESC, id DESC").
|
||
Offset(offset).
|
||
Limit(pageSize).
|
||
Find(&users).Error
|
||
|
||
return users, total, err
|
||
}
|
||
|
||
// Search 搜索用户
|
||
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
|
||
var users []*model.User
|
||
var total int64
|
||
|
||
query := r.db.Model(&model.User{})
|
||
|
||
// 搜索用户名、昵称、简介
|
||
if keyword != "" {
|
||
if r.db.Dialector.Name() == "postgres" {
|
||
query = query.Where(
|
||
"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
|
||
keyword,
|
||
)
|
||
} else {
|
||
searchPattern := "%" + keyword + "%"
|
||
query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern)
|
||
}
|
||
}
|
||
|
||
query.Count(&total)
|
||
|
||
offset := (page - 1) * pageSize
|
||
err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&users).Error
|
||
|
||
return users, total, err
|
||
}
|
||
|
||
// GetMutualFollowStatus 批量获取双向关注状态
|
||
// 返回 map[userID][isFollowing, isFollowingMe]
|
||
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
|
||
result := make(map[string][2]bool)
|
||
|
||
if len(targetUserIDs) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
// 初始化所有目标用户为未关注状态
|
||
for _, userID := range targetUserIDs {
|
||
result[userID] = [2]bool{false, false}
|
||
}
|
||
|
||
// 查询当前用户关注了哪些目标用户 (isFollowing)
|
||
var followingIDs []string
|
||
err := r.db.Model(&model.Follow{}).
|
||
Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
|
||
Pluck("following_id", &followingIDs).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, id := range followingIDs {
|
||
status := result[id]
|
||
status[0] = true
|
||
result[id] = status
|
||
}
|
||
|
||
// 查询哪些目标用户关注了当前用户 (isFollowingMe)
|
||
var followerIDs []string
|
||
err = r.db.Model(&model.Follow{}).
|
||
Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
|
||
Pluck("follower_id", &followerIDs).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, id := range followerIDs {
|
||
status := result[id]
|
||
status[1] = true
|
||
result[id] = status
|
||
}
|
||
|
||
return result, nil
|
||
}
|