Files
backend/internal/repository/user_repo.go
lan 4c0177149a Clean backend debug logging and standardize error reporting.
This removes verbose trace output in handlers/services and keeps only actionable error-level logs.
2026-03-09 22:20:44 +08:00

399 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"carrot_bbs/internal/model"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// UserRepository provides gorm-backed persistence for users together with
// their follow relationships and block lists.
type UserRepository struct {
	// db is the handle every repository method issues its queries through.
	db *gorm.DB
}
// NewUserRepository builds a UserRepository that issues all of its queries
// through db. The handle is stored as-is; no queries are run here.
func NewUserRepository(db *gorm.DB) *UserRepository {
	repo := &UserRepository{db: db}
	return repo
}
// Create inserts user as a new row using gorm's Create and returns the
// resulting error, if any.
func (r *UserRepository) Create(user *model.User) error {
	res := r.db.Create(user)
	return res.Error
}
// GetByID loads the first user whose id column equals id.
// On any query error (including the no-row case reported by gorm's First)
// it returns a nil user and that error unchanged.
func (r *UserRepository) GetByID(id string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "id = ?", id).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// GetByUsername loads the first user whose username column equals username.
// On any query error (including the no-row case reported by gorm's First)
// it returns a nil user and that error unchanged.
func (r *UserRepository) GetByUsername(username string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "username = ?", username).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// GetByEmail loads the first user whose email column equals email.
// On any query error (including the no-row case reported by gorm's First)
// it returns a nil user and that error unchanged.
func (r *UserRepository) GetByEmail(email string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "email = ?", email).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// GetByPhone loads the first user whose phone column equals phone.
// On any query error (including the no-row case reported by gorm's First)
// it returns a nil user and that error unchanged.
func (r *UserRepository) GetByPhone(phone string) (*model.User, error) {
	user := new(model.User)
	if err := r.db.First(user, "phone = ?", phone).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// Update persists user through gorm's Save and returns the resulting error.
func (r *UserRepository) Update(user *model.User) error {
	res := r.db.Save(user)
	return res.Error
}
// Delete removes the user row(s) whose id equals id.
// Whether this is a hard or a soft delete depends on how model.User is
// declared (not visible here).
func (r *UserRepository) Delete(id string) error {
	res := r.db.Delete(&model.User{}, "id = ?", id)
	return res.Error
}
// List returns one page of users ordered newest first, together with the
// total number of users.
//
// page is 1-based and pageSize is the maximum number of rows returned. The id
// column breaks created_at ties so the ordering is stable across pages.
// If counting fails, the error is returned instead of a silently-zero total.
func (r *UserRepository) List(page, pageSize int) ([]*model.User, int64, error) {
	var total int64
	if err := r.db.Model(&model.User{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	err := r.db.Order("created_at DESC, id DESC").Offset(offset).Limit(pageSize).Find(&users).Error
	return users, total, err
}
// GetFollowers returns one page of the users who follow userID (rows in the
// follow table whose following_id is userID), newest first, plus the total
// number of such users.
//
// page is 1-based. If counting fails, the error is returned instead of a
// silently-zero total.
func (r *UserRepository) GetFollowers(userID string, page, pageSize int) ([]*model.User, int64, error) {
	// Follower ids of userID, embedded by gorm as an IN (subquery).
	subQuery := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Select("follower_id")

	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	err := r.db.Where("id IN (?)", subQuery).
		Order("created_at DESC, id DESC").
		Offset(offset).Limit(pageSize).
		Find(&users).Error
	return users, total, err
}
// GetFollowing returns one page of the users that userID follows (rows in the
// follow table whose follower_id is userID), newest first, plus the total
// number of such users.
//
// page is 1-based. If counting fails, the error is returned instead of a
// silently-zero total.
func (r *UserRepository) GetFollowing(userID string, page, pageSize int) ([]*model.User, int64, error) {
	// Ids followed by userID, embedded by gorm as an IN (subquery).
	subQuery := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Select("following_id")

	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	err := r.db.Where("id IN (?)", subQuery).
		Order("created_at DESC, id DESC").
		Offset(offset).Limit(pageSize).
		Find(&users).Error
	return users, total, err
}
// CreateFollow inserts a follow relationship row and returns the insert error, if any.
func (r *UserRepository) CreateFollow(follow *model.Follow) error {
	res := r.db.Create(follow)
	return res.Error
}
// DeleteFollow removes the follow row(s) in which followerID follows followingID.
// Deleting a relationship that does not exist is not treated as an error here.
func (r *UserRepository) DeleteFollow(followerID, followingID string) error {
	res := r.db.
		Where("follower_id = ? AND following_id = ?", followerID, followingID).
		Delete(&model.Follow{})
	return res.Error
}
// IsFollowing reports whether at least one follow row exists in which
// followerID follows followingID.
func (r *UserRepository) IsFollowing(followerID, followingID string) (bool, error) {
	var n int64
	if err := r.db.Model(&model.Follow{}).
		Where("follower_id = ? AND following_id = ?", followerID, followingID).
		Count(&n).Error; err != nil {
		return false, err
	}
	return n > 0, nil
}
// IncrementFollowersCount adds one to the user's followers_count column.
// The increment is computed in SQL (followers_count + 1), not read-modify-write in Go.
func (r *UserRepository) IncrementFollowersCount(userID string) error {
	plusOne := gorm.Expr("followers_count + 1")
	return r.db.Model(&model.User{}).
		Where("id = ?", userID).
		UpdateColumn("followers_count", plusOne).
		Error
}
// DecrementFollowersCount subtracts one from the user's followers_count column.
// The "followers_count > 0" guard keeps the counter from going negative: when
// it is already 0 no row matches and nothing is updated.
func (r *UserRepository) DecrementFollowersCount(userID string) error {
	minusOne := gorm.Expr("followers_count - 1")
	return r.db.Model(&model.User{}).
		Where("id = ? AND followers_count > 0", userID).
		UpdateColumn("followers_count", minusOne).
		Error
}
// IncrementFollowingCount adds one to the user's following_count column.
// The increment is computed in SQL (following_count + 1), not read-modify-write in Go.
func (r *UserRepository) IncrementFollowingCount(userID string) error {
	plusOne := gorm.Expr("following_count + 1")
	return r.db.Model(&model.User{}).
		Where("id = ?", userID).
		UpdateColumn("following_count", plusOne).
		Error
}
// DecrementFollowingCount subtracts one from the user's following_count column.
// The "following_count > 0" guard keeps the counter from going negative: when
// it is already 0 no row matches and nothing is updated.
func (r *UserRepository) DecrementFollowingCount(userID string) error {
	minusOne := gorm.Expr("following_count - 1")
	return r.db.Model(&model.User{}).
		Where("id = ? AND following_count > 0", userID).
		UpdateColumn("following_count", minusOne).
		Error
}
// RefreshFollowersCount recomputes the user's followers_count from the follow
// table (rows whose following_id is userID) and stores it.
// The count and the update are two separate statements on r.db, not wrapped
// in a transaction.
func (r *UserRepository) RefreshFollowersCount(userID string) error {
	var actual int64
	if err := r.db.Model(&model.Follow{}).Where("following_id = ?", userID).Count(&actual).Error; err != nil {
		return err
	}
	return r.db.Model(&model.User{}).
		Where("id = ?", userID).
		UpdateColumn("followers_count", actual).
		Error
}
// GetPostsCount counts the user's posts live from the post table (rows whose
// user_id is userID). The count is returned alongside any query error.
func (r *UserRepository) GetPostsCount(userID string) (int64, error) {
	var n int64
	res := r.db.Model(&model.Post{}).Where("user_id = ?", userID).Count(&n)
	return n, res.Error
}
// GetPostsCountBatch counts posts for many users with a single GROUP BY query.
//
// The returned map has an entry for every id in userIDs; users without posts
// map to 0. An empty userIDs yields an empty (non-nil) map without querying.
// On query error the map is nil.
func (r *UserRepository) GetPostsCountBatch(userIDs []string) (map[string]int64, error) {
	counts := make(map[string]int64, len(userIDs))
	for _, id := range userIDs {
		counts[id] = 0
	}
	if len(userIDs) == 0 {
		return counts, nil
	}

	// Field names must stay UserID/Count so gorm maps the user_id and count columns onto them.
	type userPostCount struct {
		UserID string
		Count  int64
	}
	var rows []userPostCount
	if err := r.db.Model(&model.Post{}).
		Select("user_id, count(*) as count").
		Where("user_id IN ?", userIDs).
		Group("user_id").
		Scan(&rows).Error; err != nil {
		return nil, err
	}

	for _, row := range rows {
		counts[row.UserID] = row.Count
	}
	return counts, nil
}
// RefreshFollowingCount recomputes the user's following_count from the follow
// table (rows whose follower_id is userID) and stores it.
// The count and the update are two separate statements on r.db, not wrapped
// in a transaction.
func (r *UserRepository) RefreshFollowingCount(userID string) error {
	var actual int64
	if err := r.db.Model(&model.Follow{}).Where("follower_id = ?", userID).Count(&actual).Error; err != nil {
		return err
	}
	return r.db.Model(&model.User{}).
		Where("id = ?", userID).
		UpdateColumn("following_count", actual).
		Error
}
// IsBlocked reports whether blockerID has blocked blockedID. Only this one
// direction is checked (blocker -> blocked).
func (r *UserRepository) IsBlocked(blockerID, blockedID string) (bool, error) {
	var n int64
	res := r.db.Model(&model.UserBlock{}).
		Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
		Count(&n)
	if res.Error != nil {
		return false, res.Error
	}
	return n > 0, nil
}
// IsBlockedEitherDirection reports whether a block exists between userA and
// userB in either direction (A blocked B, or B blocked A).
func (r *UserRepository) IsBlockedEitherDirection(userA, userB string) (bool, error) {
	var n int64
	res := r.db.Model(&model.UserBlock{}).
		Where("(blocker_id = ? AND blocked_id = ?) OR (blocker_id = ? AND blocked_id = ?)",
			userA, userB, userB, userA).
		Count(&n)
	if res.Error != nil {
		return false, res.Error
	}
	return n > 0, nil
}
// BlockUserAndCleanupRelations records that blockerID blocks blockedID and,
// within the same transaction:
//   - removes the follow edge between the two users in both directions;
//   - recomputes followers_count and following_count for both users from the
//     remaining follow rows.
//
// Inserting the block is idempotent: an existing (blocker_id, blocked_id) row
// is kept via ON CONFLICT DO NOTHING. Returning an error from the callback
// makes gorm's Transaction roll back every step.
func (r *UserRepository) BlockUserAndCleanupRelations(blockerID, blockedID string) error {
	return r.db.Transaction(func(tx *gorm.DB) error {
		ignoreDuplicate := clause.OnConflict{
			Columns:   []clause.Column{{Name: "blocker_id"}, {Name: "blocked_id"}},
			DoNothing: true,
		}
		if err := tx.Clauses(ignoreDuplicate).Create(&model.UserBlock{
			BlockerID: blockerID,
			BlockedID: blockedID,
		}).Error; err != nil {
			return err
		}

		// Drop the follow edge blocker->blocked, then blocked->blocker.
		edges := [][2]string{{blockerID, blockedID}, {blockedID, blockerID}}
		for _, e := range edges {
			if err := tx.Where("follower_id = ? AND following_id = ?", e[0], e[1]).
				Delete(&model.Follow{}).Error; err != nil {
				return err
			}
		}

		// Recount from the follow table instead of decrementing, so the
		// counters are right whether or not an edge actually existed.
		counters := []struct{ column, filter string }{
			{"followers_count", "following_id = ?"},
			{"following_count", "follower_id = ?"},
		}
		for _, uid := range []string{blockerID, blockedID} {
			for _, c := range counters {
				var n int64
				if err := tx.Model(&model.Follow{}).Where(c.filter, uid).Count(&n).Error; err != nil {
					return err
				}
				if err := tx.Model(&model.User{}).Where("id = ?", uid).
					UpdateColumn(c.column, n).Error; err != nil {
					return err
				}
			}
		}
		return nil
	})
}
// UnblockUser deletes the block row(s) in which blockerID blocked blockedID.
// Removing a block that does not exist is not treated as an error here.
func (r *UserRepository) UnblockUser(blockerID, blockedID string) error {
	res := r.db.
		Where("blocker_id = ? AND blocked_id = ?", blockerID, blockedID).
		Delete(&model.UserBlock{})
	return res.Error
}
// GetBlockedUsers returns one page of the users that blockerID has blocked,
// newest first, plus the total number of blocked users.
//
// page is 1-based. If counting fails, the error is returned instead of a
// silently-zero total.
func (r *UserRepository) GetBlockedUsers(blockerID string, page, pageSize int) ([]*model.User, int64, error) {
	// Ids blocked by blockerID, embedded by gorm as an IN (subquery).
	subQuery := r.db.Model(&model.UserBlock{}).Where("blocker_id = ?", blockerID).Select("blocked_id")

	var total int64
	if err := r.db.Model(&model.User{}).Where("id IN (?)", subQuery).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	err := r.db.Where("id IN (?)", subQuery).
		Order("created_at DESC, id DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&users).Error
	return users, total, err
}
// Search returns one page of users matching keyword, newest first, plus the
// total number of matches.
//
// An empty keyword matches every user. On PostgreSQL the keyword is matched
// with full-text search (the 'simple' configuration) over username, nickname
// and bio. On other dialects it is a substring LIKE match on the same columns.
// LIKE wildcards (% and _) inside keyword are not escaped.
//
// page is 1-based. A count failure is returned explicitly. The id column
// breaks created_at ties, as in the other paginated queries, so pages stay
// stable.
func (r *UserRepository) Search(keyword string, page, pageSize int) ([]*model.User, int64, error) {
	query := r.db.Model(&model.User{})
	if keyword != "" {
		if r.db.Dialector.Name() == "postgres" {
			query = query.Where(
				"to_tsvector('simple', COALESCE(username, '') || ' ' || COALESCE(nickname, '') || ' ' || COALESCE(bio, '')) @@ plainto_tsquery('simple', ?)",
				keyword,
			)
		} else {
			searchPattern := "%" + keyword + "%"
			query = query.Where("username LIKE ? OR nickname LIKE ? OR bio LIKE ?", searchPattern, searchPattern, searchPattern)
		}
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var users []*model.User
	offset := (page - 1) * pageSize
	err := query.Offset(offset).Limit(pageSize).Order("created_at DESC, id DESC").Find(&users).Error
	return users, total, err
}
// GetMutualFollowStatus reports, for each target user, the follow relationship
// with currentUserID in both directions, using two queries in total.
//
// The returned map has an entry for every id in targetUserIDs:
//   - value[0]: currentUserID follows the target;
//   - value[1]: the target follows currentUserID.
//
// An empty targetUserIDs yields an empty (non-nil) map without querying.
// On query error the map is nil.
func (r *UserRepository) GetMutualFollowStatus(currentUserID string, targetUserIDs []string) (map[string][2]bool, error) {
	const (
		slotIFollow   = 0
		slotFollowsMe = 1
	)

	status := make(map[string][2]bool)
	if len(targetUserIDs) == 0 {
		return status, nil
	}
	for _, id := range targetUserIDs {
		status[id] = [2]bool{}
	}

	// mark sets the given slot to true for each id.
	mark := func(ids []string, slot int) {
		for _, id := range ids {
			flags := status[id]
			flags[slot] = true
			status[id] = flags
		}
	}

	var followed []string
	if err := r.db.Model(&model.Follow{}).
		Where("follower_id = ? AND following_id IN ?", currentUserID, targetUserIDs).
		Pluck("following_id", &followed).Error; err != nil {
		return nil, err
	}
	mark(followed, slotIFollow)

	var followers []string
	if err := r.db.Model(&model.Follow{}).
		Where("follower_id IN ? AND following_id = ?", targetUserIDs, currentUserID).
		Pluck("follower_id", &followers).Error; err != nil {
		return nil, err
	}
	mark(followers, slotFollowsMe)

	return status, nil
}