Files
backend/internal/handler/user_handler.go
lan 4d8f2ec997 Initial backend repository commit.
Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
2026-03-09 21:28:58 +08:00

706 lines
20 KiB
Go

package handler
import (
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// UserHandler serves the user-related gin endpoints: registration, login,
// token refresh, profile read/update, email verification, follow and block
// relationships, password management and user search.
type UserHandler struct {
// userService performs all user persistence and business rules.
userService *service.UserService
// jwtService issues and parses tokens for Register, Login and RefreshToken.
// NewUserHandler does not set it; it stays nil until SetJWTService is called,
// so those handlers must not be served before that.
jwtService *service.JWTService
}
// NewUserHandler returns a UserHandler that delegates to userService.
//
// The JWT service is not configured here; wire it in with SetJWTService
// before serving any endpoint that issues or parses tokens.
func NewUserHandler(userService *service.UserService) *UserHandler {
	h := new(UserHandler)
	h.userService = userService
	return h
}
// Register handles user sign-up.
//
// It binds a JSON body (username, email, password of at least 6 characters,
// nickname, optional phone, and a verification code) and creates the account
// via UserService.Register. On success it responds with the user plus a new
// access token ("token") and refresh token ("refresh_token").
//
// A *service.ServiceError from the service is forwarded with its own code and
// message; any other registration failure yields a generic server error.
// Requires h.jwtService to have been set via SetJWTService.
func (h *UserHandler) Register(c *gin.Context) {
	type RegisterRequest struct {
		Username         string `json:"username" binding:"required"`
		Email            string `json:"email" binding:"required,email"`
		Password         string `json:"password" binding:"required,min=6"`
		Nickname         string `json:"nickname" binding:"required"`
		Phone            string `json:"phone"`
		VerificationCode string `json:"verification_code" binding:"required"`
	}

	var req RegisterRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode)
	if err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to register")
		return
	}

	// Token generation errors must not be ignored: doing so would send empty
	// tokens inside a success response. The account already exists at this
	// point, so a client seeing this error can recover by logging in.
	accessToken, err := h.jwtService.GenerateAccessToken(user.ID, user.Username)
	if err != nil {
		response.InternalServerError(c, "failed to generate token")
		return
	}
	refreshToken, err := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
	if err != nil {
		response.InternalServerError(c, "failed to generate token")
		return
	}

	response.Success(c, gin.H{
		"user":          dto.ConvertUserToResponse(user),
		"token":         accessToken,
		"refresh_token": refreshToken,
	})
}
// Login authenticates a user and issues a token pair.
//
// The JSON body carries "password" plus an identifier in either "account" or
// the older "username" field. "account" wins when both are present, and the
// request is rejected when neither is set. On success the response contains
// the user, an access token ("token") and a refresh token ("refresh_token").
//
// A *service.ServiceError from UserService.Login is forwarded with its own
// code and message; any other failure yields a generic server error.
// Requires h.jwtService to have been set via SetJWTService.
func (h *UserHandler) Login(c *gin.Context) {
	type LoginRequest struct {
		Username string `json:"username"`
		Account  string `json:"account"`
		Password string `json:"password" binding:"required"`
	}

	var req LoginRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	account := req.Account
	if account == "" {
		account = req.Username
	}
	if account == "" {
		response.BadRequest(c, "username or account is required")
		return
	}

	user, err := h.userService.Login(c.Request.Context(), account, req.Password)
	if err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to login")
		return
	}

	// Do not hand out a "successful" login with empty tokens if signing fails.
	accessToken, err := h.jwtService.GenerateAccessToken(user.ID, user.Username)
	if err != nil {
		response.InternalServerError(c, "failed to generate token")
		return
	}
	refreshToken, err := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
	if err != nil {
		response.InternalServerError(c, "failed to generate token")
		return
	}

	response.Success(c, gin.H{
		"user":          dto.ConvertUserToResponse(user),
		"token":         accessToken,
		"refresh_token": refreshToken,
	})
}
// SendRegisterCode asks the user service to send a registration verification
// code to the email address given in the JSON body.
//
// A *service.ServiceError is relayed to the client with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) SendRegisterCode(c *gin.Context) {
	// The type name is kept as-is: binding error messages returned to the
	// client include it.
	type SendCodeRequest struct {
		Email string `json:"email" binding:"required,email"`
	}

	var body SendCodeRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		response.BadRequest(c, bindErr.Error())
		return
	}

	switch sendErr := h.userService.SendRegisterCode(c.Request.Context(), body.Email).(type) {
	case nil:
		response.Success(c, gin.H{"success": true})
	case *service.ServiceError:
		response.Error(c, sendErr.Code, sendErr.Message)
	default:
		response.InternalServerError(c, "failed to send register verification code")
	}
}
// SendPasswordResetCode asks the user service to send a password-reset
// verification code to the email address given in the JSON body.
//
// A *service.ServiceError is relayed to the client with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) SendPasswordResetCode(c *gin.Context) {
	type SendCodeRequest struct {
		Email string `json:"email" binding:"required,email"`
	}

	var payload SendCodeRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		response.BadRequest(c, bindErr.Error())
		return
	}

	err := h.userService.SendPasswordResetCode(c.Request.Context(), payload.Email)
	if err == nil {
		response.Success(c, gin.H{"success": true})
		return
	}
	if se, isServiceErr := err.(*service.ServiceError); isServiceErr {
		response.Error(c, se.Code, se.Message)
		return
	}
	response.InternalServerError(c, "failed to send reset verification code")
}
// ResetPassword sets a new password for the account identified by email,
// provided the supplied verification code is accepted by the user service.
// The new password must be at least 6 characters (binding rule).
//
// A *service.ServiceError is relayed to the client with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) ResetPassword(c *gin.Context) {
	type ResetPasswordRequest struct {
		Email            string `json:"email" binding:"required,email"`
		VerificationCode string `json:"verification_code" binding:"required"`
		NewPassword      string `json:"new_password" binding:"required,min=6"`
	}

	var in ResetPasswordRequest
	if bindErr := c.ShouldBindJSON(&in); bindErr != nil {
		response.BadRequest(c, bindErr.Error())
		return
	}

	ctx := c.Request.Context()
	switch resetErr := h.userService.ResetPasswordByEmail(ctx, in.Email, in.VerificationCode, in.NewPassword).(type) {
	case nil:
		response.Success(c, gin.H{"success": true})
	case *service.ServiceError:
		response.Error(c, resetErr.Code, resetErr.Message)
	default:
		response.InternalServerError(c, "failed to reset password")
	}
}
// GetCurrentUser returns the detailed profile of the authenticated user
// (the "user_id" value in the gin context).
//
// The post count is recomputed live; if that lookup fails, the stored
// PostsCount field is reported instead. Any error loading the user itself is
// reported as "user not found".
func (h *UserHandler) GetCurrentUser(c *gin.Context) {
	uid := c.GetString("user_id")
	if len(uid) == 0 {
		response.Unauthorized(c, "")
		return
	}

	ctx := c.Request.Context()
	me, loadErr := h.userService.GetUserByID(ctx, uid)
	if loadErr != nil {
		response.NotFound(c, "user not found")
		return
	}

	posts, countErr := h.userService.GetUserPostCount(ctx, uid)
	if countErr != nil {
		// Best effort: fall back to the persisted counter.
		posts = int64(me.PostsCount)
	}

	detail := dto.ConvertUserToDetailResponseWithPostsCount(me, int(posts))
	response.Success(c, detail)
}
// GetUserByID returns the profile of the user named by the ":id" path
// parameter.
//
// The caller's "user_id" from the context (possibly empty) is passed to the
// service so it can report follow status in both directions between the
// caller and the target. The post count is recomputed live and falls back to
// the stored PostsCount field if that lookup fails. Any error loading the
// user is reported as "user not found".
func (h *UserHandler) GetUserByID(c *gin.Context) {
	ctx := c.Request.Context()
	targetID := c.Param("id")
	viewerID := c.GetString("user_id")

	profile, following, followedBy, loadErr := h.userService.GetUserByIDWithMutualFollowStatus(ctx, targetID, viewerID)
	if loadErr != nil {
		response.NotFound(c, "user not found")
		return
	}

	posts, countErr := h.userService.GetUserPostCount(ctx, targetID)
	if countErr != nil {
		// Best effort: fall back to the persisted counter.
		posts = int64(profile.PostsCount)
	}

	response.Success(c, dto.ConvertUserToResponseWithMutualFollowAndPostsCount(profile, following, followedBy, int(posts)))
}
// UpdateUser applies a partial profile update for the authenticated user.
//
// String fields (nickname, bio, website, location, avatar) are applied only
// when non-empty, so they cannot be cleared through this endpoint. Phone and
// email are pointers: when omitted they are left untouched; when present,
// even as "", they overwrite the stored value. Changing the email to a
// different address resets EmailVerified, so the new address must be
// verified again.
//
// On success it responds with the updated user detail, including a live post
// count that falls back to the stored PostsCount if the count lookup fails.
// A *service.ServiceError from the update is forwarded with its own code and
// message, as the other handlers in this file do.
func (h *UserHandler) UpdateUser(c *gin.Context) {
	userID := c.GetString("user_id")
	if userID == "" {
		response.Unauthorized(c, "")
		return
	}

	type UpdateRequest struct {
		Nickname string  `json:"nickname"`
		Bio      string  `json:"bio"`
		Website  string  `json:"website"`
		Location string  `json:"location"`
		Avatar   string  `json:"avatar"`
		Phone    *string `json:"phone"`
		Email    *string `json:"email"`
	}
	var req UpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	ctx := c.Request.Context()
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		response.NotFound(c, "user not found")
		return
	}

	// Empty strings mean "leave unchanged".
	if req.Nickname != "" {
		user.Nickname = req.Nickname
	}
	if req.Bio != "" {
		user.Bio = req.Bio
	}
	if req.Website != "" {
		user.Website = req.Website
	}
	if req.Location != "" {
		user.Location = req.Location
	}
	if req.Avatar != "" {
		user.Avatar = req.Avatar
	}
	if req.Phone != nil {
		user.Phone = req.Phone
	}
	if req.Email != nil {
		// A changed address has not been verified yet.
		if user.Email == nil || *user.Email != *req.Email {
			user.EmailVerified = false
		}
		user.Email = req.Email
	}

	if err := h.userService.UpdateUser(ctx, user); err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to update user")
		return
	}

	postsCount, err := h.userService.GetUserPostCount(ctx, userID)
	if err != nil {
		// Best effort: fall back to the persisted counter.
		postsCount = int64(user.PostsCount)
	}
	response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
}
// SendEmailVerifyCode asks the user service to send an email-verification
// code for the authenticated user to the address given in the JSON body.
//
// A *service.ServiceError is relayed to the client with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) {
	uid := c.GetString("user_id")
	if len(uid) == 0 {
		response.Unauthorized(c, "")
		return
	}

	type SendCodeRequest struct {
		Email string `json:"email" binding:"required,email"`
	}
	var body SendCodeRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		response.BadRequest(c, bindErr.Error())
		return
	}

	switch sendErr := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), uid, body.Email).(type) {
	case nil:
		response.Success(c, gin.H{"success": true})
	case *service.ServiceError:
		response.Error(c, sendErr.Code, sendErr.Message)
	default:
		response.InternalServerError(c, "failed to send email verify code")
	}
}
// VerifyEmail submits an email address and verification code for the
// authenticated user to the user service for confirmation.
//
// A *service.ServiceError is relayed to the client with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) VerifyEmail(c *gin.Context) {
	uid := c.GetString("user_id")
	if len(uid) == 0 {
		response.Unauthorized(c, "")
		return
	}

	type VerifyEmailRequest struct {
		Email            string `json:"email" binding:"required,email"`
		VerificationCode string `json:"verification_code" binding:"required"`
	}
	var in VerifyEmailRequest
	if bindErr := c.ShouldBindJSON(&in); bindErr != nil {
		response.BadRequest(c, bindErr.Error())
		return
	}

	verifyErr := h.userService.VerifyCurrentUserEmail(c.Request.Context(), uid, in.Email, in.VerificationCode)
	if verifyErr == nil {
		response.Success(c, gin.H{"success": true})
		return
	}
	if se, isServiceErr := verifyErr.(*service.ServiceError); isServiceErr {
		response.Error(c, se.Code, se.Message)
		return
	}
	response.InternalServerError(c, "failed to verify email")
}
// RefreshToken exchanges a refresh token for a new access/refresh token pair.
//
// The "refresh_token" from the JSON body is parsed with JWTService.ParseToken;
// a parse failure is reported as "invalid refresh token". The new tokens are
// issued for the user ID and username carried in the parsed claims.
//
// NOTE(review): the same ParseToken is presumably used for access tokens too;
// confirm it rejects an access token presented here, otherwise any access
// token can be renewed indefinitely.
// Requires h.jwtService to have been set via SetJWTService.
func (h *UserHandler) RefreshToken(c *gin.Context) {
	type RefreshRequest struct {
		RefreshToken string `json:"refresh_token" binding:"required"`
	}
	var req RefreshRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	claims, err := h.jwtService.ParseToken(req.RefreshToken)
	if err != nil {
		response.Unauthorized(c, "invalid refresh token")
		return
	}

	// Never return a success response carrying empty tokens.
	accessToken, err := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username)
	if err != nil {
		response.InternalServerError(c, "failed to generate token")
		return
	}
	refreshToken, err := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username)
	if err != nil {
		response.InternalServerError(c, "failed to generate token")
		return
	}

	response.Success(c, gin.H{
		"token":         accessToken,
		"refresh_token": refreshToken,
	})
}
// SetJWTService injects the JWT service used by Register, Login and
// RefreshToken to issue and parse tokens. NewUserHandler leaves it unset, so
// this must be called before those handlers serve requests.
func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) {
h.jwtService = jwtSvc
}
// FollowUser makes the authenticated user follow the user named by the ":id"
// path parameter.
//
// Unauthenticated callers (empty "user_id") are rejected, and following
// yourself is a bad request. A *service.ServiceError from the service is
// forwarded with its own code and message, matching BlockUser; any other
// failure becomes a generic server error.
func (h *UserHandler) FollowUser(c *gin.Context) {
	userID := c.Param("id")
	currentUserID := c.GetString("user_id")
	// Reject anonymous callers up front instead of passing an empty follower
	// ID to the service.
	if currentUserID == "" {
		response.Unauthorized(c, "")
		return
	}
	if userID == currentUserID {
		response.BadRequest(c, "cannot follow yourself")
		return
	}

	if err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID); err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to follow user")
		return
	}
	response.Success(c, gin.H{"success": true})
}
// UnfollowUser makes the authenticated user stop following the user named by
// the ":id" path parameter.
//
// Unauthenticated callers (empty "user_id") are rejected. A
// *service.ServiceError from the service is forwarded with its own code and
// message, matching UnblockUser; any other failure becomes a generic server
// error.
func (h *UserHandler) UnfollowUser(c *gin.Context) {
	userID := c.Param("id")
	currentUserID := c.GetString("user_id")
	// Reject anonymous callers up front instead of passing an empty follower
	// ID to the service.
	if currentUserID == "" {
		response.Unauthorized(c, "")
		return
	}

	if err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID); err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to unfollow user")
		return
	}
	response.Success(c, gin.H{"success": true})
}
// BlockUser makes the authenticated user block the user named by the ":id"
// path parameter.
//
// Unauthenticated callers (empty "user_id") are rejected, consistent with
// GetBlockedUsers and GetBlockStatus, and blocking yourself is a bad request.
// A *service.ServiceError is forwarded with its own code and message; any
// other failure becomes a generic server error.
func (h *UserHandler) BlockUser(c *gin.Context) {
	targetUserID := c.Param("id")
	currentUserID := c.GetString("user_id")
	if currentUserID == "" {
		response.Unauthorized(c, "")
		return
	}
	if targetUserID == currentUserID {
		response.BadRequest(c, "cannot block yourself")
		return
	}

	if err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID); err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to block user")
		return
	}
	response.Success(c, gin.H{"success": true})
}
// UnblockUser removes the user named by the ":id" path parameter from the
// authenticated user's blocks.
//
// Unauthenticated callers (empty "user_id") are rejected, consistent with
// GetBlockedUsers and GetBlockStatus, and unblocking yourself is a bad
// request. A *service.ServiceError is forwarded with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) UnblockUser(c *gin.Context) {
	targetUserID := c.Param("id")
	currentUserID := c.GetString("user_id")
	if currentUserID == "" {
		response.Unauthorized(c, "")
		return
	}
	if targetUserID == currentUserID {
		response.BadRequest(c, "cannot unblock yourself")
		return
	}

	if err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID); err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to unblock user")
		return
	}
	response.Success(c, gin.H{"success": true})
}
// GetBlockedUsers returns a paginated list of the users the authenticated
// user has blocked, each with a live post count (best effort: a failed count
// lookup is ignored).
//
// "page" and "page_size" default to 1 and 20; values that parse to zero or
// less (including unparsable input) fall back to those defaults.
func (h *UserHandler) GetBlockedUsers(c *gin.Context) {
	viewerID := c.GetString("user_id")
	if len(viewerID) == 0 {
		response.Unauthorized(c, "")
		return
	}

	// Atoi errors are deliberately ignored: the (possibly zero) parsed value
	// is what decides whether the default applies.
	positiveQuery := func(key, raw string, fallback int) int {
		n, _ := strconv.Atoi(c.DefaultQuery(key, raw))
		if n <= 0 {
			return fallback
		}
		return n
	}
	page := positiveQuery("page", "1", 1)
	pageSize := positiveQuery("page_size", "20", 20)

	ctx := c.Request.Context()
	blocked, total, listErr := h.userService.GetBlockedUsers(ctx, viewerID, page, pageSize)
	if listErr != nil {
		response.InternalServerError(c, "failed to get blocked users")
		return
	}

	ids := make([]string, 0, len(blocked))
	for _, u := range blocked {
		ids = append(ids, u.ID)
	}
	counts, _ := h.userService.GetUserPostCountBatch(ctx, ids)

	response.Paginated(c, dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(blocked, nil, counts), total, page, pageSize)
}
// GetBlockStatus reports, as {"is_blocked": bool}, the result of
// UserService.IsBlocked for the authenticated user and the user named by the
// ":id" path parameter.
func (h *UserHandler) GetBlockStatus(c *gin.Context) {
	targetID := c.Param("id")
	viewerID := c.GetString("user_id")

	switch {
	case viewerID == "":
		response.Unauthorized(c, "")
		return
	case targetID == "":
		response.BadRequest(c, "target user id is required")
		return
	}

	blocked, lookupErr := h.userService.IsBlocked(c.Request.Context(), viewerID, targetID)
	if lookupErr != nil {
		response.InternalServerError(c, "failed to get block status")
		return
	}
	response.Success(c, gin.H{"is_blocked": blocked})
}
// GetFollowingList returns {"list": [...]} with the users followed by the
// user named by the ":id" path parameter. "page" and "page_size" (defaults
// "1" and "20") are forwarded to the service as raw strings.
//
// Every entry carries a live post count. When the caller is logged in, each
// entry also carries the follow status in both directions between the caller
// and that user. Failures of these two enrichment lookups are ignored.
func (h *UserHandler) GetFollowingList(c *gin.Context) {
	ownerID := c.Param("id")
	viewerID := c.GetString("user_id")
	ctx := c.Request.Context()

	users, err := h.userService.GetFollowingList(ctx, ownerID, c.DefaultQuery("page", "1"), c.DefaultQuery("page_size", "20"))
	if err != nil {
		response.InternalServerError(c, "failed to get following list")
		return
	}

	if len(users) == 0 {
		response.Success(c, gin.H{"list": dto.ConvertUsersToResponse(users)})
		return
	}

	ids := make([]string, 0, len(users))
	for _, u := range users {
		ids = append(ids, u.ID)
	}

	var list []*dto.UserResponse
	if viewerID != "" {
		followStatus, _ := h.userService.GetMutualFollowStatus(ctx, viewerID, ids)
		postCounts, _ := h.userService.GetUserPostCountBatch(ctx, ids)
		list = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, followStatus, postCounts)
	} else {
		postCounts, _ := h.userService.GetUserPostCountBatch(ctx, ids)
		list = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postCounts)
	}
	response.Success(c, gin.H{"list": list})
}
// GetFollowersList returns {"list": [...]} with the followers of the user
// named by the ":id" path parameter. "page" and "page_size" (defaults "1"
// and "20") are forwarded to the service as raw strings.
//
// Every entry carries a live post count. When the caller is logged in, each
// entry also carries the follow status in both directions between the caller
// and that user. Failures of these two enrichment lookups are ignored.
//
// NOTE(review): the [DEBUG] lines below print user IDs to stdout on every
// request. They are the only uses of the "fmt" import in this file, so
// replacing them with the project logger must also drop that import.
func (h *UserHandler) GetFollowersList(c *gin.Context) {
	ownerID := c.Param("id")
	viewerID := c.GetString("user_id")
	ctx := c.Request.Context()

	fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", ownerID, viewerID)
	users, err := h.userService.GetFollowersList(ctx, ownerID, c.DefaultQuery("page", "1"), c.DefaultQuery("page_size", "20"))
	if err != nil {
		response.InternalServerError(c, "failed to get followers list")
		return
	}
	fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users))

	if len(users) == 0 {
		response.Success(c, gin.H{"list": dto.ConvertUsersToResponse(users)})
		return
	}

	ids := make([]string, 0, len(users))
	for _, u := range users {
		ids = append(ids, u.ID)
	}

	var list []*dto.UserResponse
	if viewerID != "" {
		fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", ids)
		followStatus, _ := h.userService.GetMutualFollowStatus(ctx, viewerID, ids)
		postCounts, _ := h.userService.GetUserPostCountBatch(ctx, ids)
		list = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, followStatus, postCounts)
	} else {
		postCounts, _ := h.userService.GetUserPostCountBatch(ctx, ids)
		list = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postCounts)
	}
	response.Success(c, gin.H{"list": list})
}
// CheckUsername reports, as {"available": bool}, whether the "username"
// query parameter is available according to the user service. A missing or
// empty username is a bad request.
func (h *UserHandler) CheckUsername(c *gin.Context) {
	if name := c.Query("username"); name != "" {
		free, checkErr := h.userService.CheckUsernameAvailable(c.Request.Context(), name)
		if checkErr != nil {
			response.InternalServerError(c, "failed to check username")
			return
		}
		response.Success(c, gin.H{"available": free})
		return
	}
	response.BadRequest(c, "username is required")
}
// ChangePassword changes the authenticated user's password.
//
// The JSON body must include the old password, a new password of at least 6
// characters, and a verification code (presumably the one sent by
// SendChangePasswordCode). Unauthenticated callers (empty "user_id") are
// rejected, as SendChangePasswordCode does, instead of reaching the service
// with an empty user ID. A *service.ServiceError is forwarded with its own
// code and message; any other failure becomes a generic server error.
func (h *UserHandler) ChangePassword(c *gin.Context) {
	currentUserID := c.GetString("user_id")
	if currentUserID == "" {
		response.Unauthorized(c, "")
		return
	}

	type ChangePasswordRequest struct {
		OldPassword      string `json:"old_password" binding:"required"`
		NewPassword      string `json:"new_password" binding:"required,min=6"`
		VerificationCode string `json:"verification_code" binding:"required"`
	}
	var req ChangePasswordRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		response.BadRequest(c, err.Error())
		return
	}

	if err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode); err != nil {
		if se, ok := err.(*service.ServiceError); ok {
			response.Error(c, se.Code, se.Message)
			return
		}
		response.InternalServerError(c, "failed to change password")
		return
	}
	response.Success(c, gin.H{"success": true})
}
// SendChangePasswordCode asks the user service to send a password-change
// verification code for the authenticated user.
//
// A *service.ServiceError is relayed to the client with its own code and
// message; any other failure becomes a generic server error.
func (h *UserHandler) SendChangePasswordCode(c *gin.Context) {
	uid := c.GetString("user_id")
	if len(uid) == 0 {
		response.Unauthorized(c, "")
		return
	}

	switch sendErr := h.userService.SendChangePasswordCode(c.Request.Context(), uid).(type) {
	case nil:
		response.Success(c, gin.H{"success": true})
	case *service.ServiceError:
		response.Error(c, sendErr.Code, sendErr.Message)
	default:
		response.InternalServerError(c, "failed to send change password code")
	}
}
// Search returns a paginated list of users matching the "keyword" query
// parameter, each with a live post count (best effort: a failed count lookup
// is ignored).
//
// "page" and "page_size" default to 1 and 20. Values that parse to zero or
// less (including unparsable input) fall back to those defaults, as in
// GetBlockedUsers. The service and response.Paginated therefore never receive
// a zero or negative page or page size; such values would presumably yield a
// negative offset or a zero divisor downstream.
func (h *UserHandler) Search(c *gin.Context) {
	keyword := c.Query("keyword")
	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if page <= 0 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 20
	}

	ctx := c.Request.Context()
	users, total, err := h.userService.Search(ctx, keyword, page, pageSize)
	if err != nil {
		response.InternalServerError(c, "failed to search users")
		return
	}

	var userResponses []*dto.UserResponse
	if len(users) > 0 {
		userIDs := make([]string, len(users))
		for i, u := range users {
			userIDs[i] = u.ID
		}
		postsCountMap, _ := h.userService.GetUserPostCountBatch(ctx, userIDs)
		userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
	} else {
		userResponses = dto.ConvertUsersToResponse(users)
	}
	response.Paginated(c, userResponses, total, page, pageSize)
}