Initial backend repository commit.

Set up project files and add .gitignore to exclude local build/runtime artifacts.

Made-with: Cursor
This commit is contained in:
2026-03-09 21:28:58 +08:00
commit 4d8f2ec997
102 changed files with 25022 additions and 0 deletions

View File

@@ -0,0 +1,253 @@
package handler
import (
"encoding/json"
"errors"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// CommentHandler 评论处理器
type CommentHandler struct {
commentService *service.CommentService
}
// NewCommentHandler 创建评论处理器
func NewCommentHandler(commentService *service.CommentService) *CommentHandler {
return &CommentHandler{
commentService: commentService,
}
}
// Create 创建评论
func (h *CommentHandler) Create(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type CreateRequest struct {
PostID string `json:"post_id" binding:"required"`
Content string `json:"content"` // 内容可选,允许只发图片
ParentID *string `json:"parent_id"`
Images []string `json:"images"` // 图片URL列表
}
var req CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证:评论必须有内容或图片
if req.Content == "" && len(req.Images) == 0 {
response.BadRequest(c, "评论内容或图片不能同时为空")
return
}
// 将图片列表转换为JSON字符串
var imagesJSON string
if len(req.Images) > 0 {
imagesBytes, _ := json.Marshal(req.Images)
imagesJSON = string(imagesBytes)
}
comment, err := h.commentService.Create(c.Request.Context(), req.PostID, userID, req.Content, req.ParentID, imagesJSON, req.Images)
if err != nil {
var moderationErr *service.CommentModerationRejectedError
if errors.As(err, &moderationErr) {
response.BadRequest(c, moderationErr.UserMessage())
return
}
response.InternalServerError(c, "failed to create comment")
return
}
response.Success(c, dto.ConvertCommentToResponse(comment, false))
}
// GetByID 获取单条评论详情
func (h *CommentHandler) GetByID(c *gin.Context) {
userID := c.GetString("user_id")
id := c.Param("id")
comment, err := h.commentService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "comment not found")
return
}
resp := dto.ConvertCommentToResponse(comment, h.commentService.IsLiked(c.Request.Context(), id, userID))
response.Success(c, resp)
}
// GetByPostID 获取帖子评论
func (h *CommentHandler) GetByPostID(c *gin.Context) {
userID := c.GetString("user_id")
postID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
comments, total, err := h.commentService.GetByPostID(c.Request.Context(), postID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get comments")
return
}
// 转换为响应结构,检查每个评论的点赞状态
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
response.Paginated(c, commentResponses, total, page, pageSize)
}
// GetReplies 获取回复
func (h *CommentHandler) GetReplies(c *gin.Context) {
userID := c.GetString("user_id")
parentID := c.Param("id")
comments, err := h.commentService.GetReplies(c.Request.Context(), parentID)
if err != nil {
response.InternalServerError(c, "failed to get replies")
return
}
// 转换为响应结构,检查每个回复的点赞状态
commentResponses := dto.ConvertCommentsToResponseWithUser(comments, userID, h.commentService)
response.Success(c, commentResponses)
}
// GetRepliesByRootID 根据根评论ID分页获取回复扁平化
func (h *CommentHandler) GetRepliesByRootID(c *gin.Context) {
userID := c.GetString("user_id")
rootID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
replies, total, err := h.commentService.GetRepliesByRootID(c.Request.Context(), rootID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get replies")
return
}
// 转换为响应结构,检查每个回复的点赞状态
replyResponses := dto.ConvertCommentsToResponseWithUser(replies, userID, h.commentService)
response.Paginated(c, replyResponses, total, page, pageSize)
}
// Update 更新评论
func (h *CommentHandler) Update(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
comment, err := h.commentService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "comment not found")
return
}
if comment.UserID != userID {
response.Forbidden(c, "cannot update others' comment")
return
}
type UpdateRequest struct {
Content string `json:"content" binding:"required"`
}
var req UpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
comment.Content = req.Content
err = h.commentService.Update(c.Request.Context(), comment)
if err != nil {
response.InternalServerError(c, "failed to update comment")
return
}
response.Success(c, dto.ConvertCommentToResponse(comment, false))
}
// Delete 删除评论
func (h *CommentHandler) Delete(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
comment, err := h.commentService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "comment not found")
return
}
if comment.UserID != userID {
response.Forbidden(c, "cannot delete others' comment")
return
}
err = h.commentService.Delete(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to delete comment")
return
}
response.SuccessWithMessage(c, "comment deleted", nil)
}
// Like 点赞评论
func (h *CommentHandler) Like(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.commentService.Like(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to like comment")
return
}
response.SuccessWithMessage(c, "liked", nil)
}
// Unlike 取消点赞评论
func (h *CommentHandler) Unlike(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.commentService.Unlike(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to unlike comment")
return
}
response.SuccessWithMessage(c, "unliked", nil)
}

View File

@@ -0,0 +1,234 @@
package handler
import (
"context"
"log"
"strings"
"time"
"carrot_bbs/internal/config"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/gorse"
"carrot_bbs/internal/pkg/response"
gorseio "github.com/gorse-io/gorse-go"
"github.com/gin-gonic/gin"
)
// GorseHandler Gorse推荐处理器
type GorseHandler struct {
importPassword string
gorseConfig config.GorseConfig
}
// NewGorseHandler 创建Gorse处理器
func NewGorseHandler(cfg config.GorseConfig) *GorseHandler {
return &GorseHandler{
importPassword: cfg.ImportPassword,
gorseConfig: cfg,
}
}
// ImportRequest 导入请求
type ImportRequest struct {
Password string `json:"password"`
}
// ImportData 导入数据到Gorse
// POST /api/v1/gorse/import
func (h *GorseHandler) ImportData(c *gin.Context) {
// 验证密码
if h.importPassword == "" {
response.BadRequest(c, "Gorse import is disabled")
return
}
var req ImportRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "invalid request body")
return
}
if req.Password != h.importPassword {
response.Unauthorized(c, "invalid password")
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Minute)
defer cancel()
stats, err := h.importAllData(ctx)
if err != nil {
log.Printf("[ERROR] gorse import failed: %v", err)
response.InternalServerError(c, "gorse import failed: "+err.Error())
return
}
response.Success(c, gin.H{
"message": "import completed",
"status": "done",
"stats": stats,
})
}
// GetStatus 获取Gorse状态
// GET /api/v1/gorse/status
func (h *GorseHandler) GetStatus(c *gin.Context) {
// 返回Gorse连接状态和配置信息
hasPassword := h.importPassword != ""
response.Success(c, gin.H{
"enabled": h.gorseConfig.Enabled,
"has_password": hasPassword,
"address": h.gorseConfig.Address,
"api_key": strings.Repeat("*", 8), // 不返回实际APIKey
})
}
func (h *GorseHandler) importAllData(ctx context.Context) (gin.H, error) {
gorseClient := gorseio.NewGorseClient(h.gorseConfig.Address, h.gorseConfig.APIKey)
gorse.InitEmbeddingWithConfig(h.gorseConfig.EmbeddingAPIKey, h.gorseConfig.EmbeddingURL, h.gorseConfig.EmbeddingModel)
stats := gin.H{
"items": 0,
"users": 0,
"likes": 0,
"favorites": 0,
"comments": 0,
}
// 导入帖子
var posts []model.Post
if err := model.DB.Find(&posts).Error; err != nil {
return nil, err
}
for _, post := range posts {
embedding, err := gorse.GetEmbedding(strings.TrimSpace(post.Title + " " + post.Content))
if err != nil {
log.Printf("[WARN] get embedding failed for post %s: %v", post.ID, err)
embedding = make([]float64, 1024)
}
_, err = gorseClient.InsertItem(ctx, gorseio.Item{
ItemId: post.ID,
IsHidden: post.DeletedAt.Valid,
Categories: buildPostCategories(&post),
Comment: post.Title,
Timestamp: post.CreatedAt.UTC().Truncate(time.Second),
Labels: map[string]any{
"embedding": embedding,
},
})
if err != nil {
log.Printf("[WARN] import item failed (%s): %v", post.ID, err)
continue
}
stats["items"] = stats["items"].(int) + 1
}
// 导入用户
var users []model.User
if err := model.DB.Find(&users).Error; err != nil {
return nil, err
}
for _, user := range users {
_, err := gorseClient.InsertUser(ctx, gorseio.User{
UserId: user.ID,
Labels: map[string]any{
"posts_count": user.PostsCount,
"followers_count": user.FollowersCount,
"following_count": user.FollowingCount,
},
Comment: user.Nickname,
})
if err != nil {
log.Printf("[WARN] import user failed (%s): %v", user.ID, err)
continue
}
stats["users"] = stats["users"].(int) + 1
}
// 导入点赞
var likes []model.PostLike
if err := model.DB.Find(&likes).Error; err != nil {
return nil, err
}
for _, like := range likes {
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
FeedbackType: string(gorse.FeedbackTypeLike),
UserId: like.UserID,
ItemId: like.PostID,
Timestamp: like.CreatedAt.UTC().Truncate(time.Second),
}})
if err != nil {
log.Printf("[WARN] import like failed (%s/%s): %v", like.UserID, like.PostID, err)
continue
}
stats["likes"] = stats["likes"].(int) + 1
}
// 导入收藏
var favorites []model.Favorite
if err := model.DB.Find(&favorites).Error; err != nil {
return nil, err
}
for _, fav := range favorites {
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
FeedbackType: string(gorse.FeedbackTypeStar),
UserId: fav.UserID,
ItemId: fav.PostID,
Timestamp: fav.CreatedAt.UTC().Truncate(time.Second),
}})
if err != nil {
log.Printf("[WARN] import favorite failed (%s/%s): %v", fav.UserID, fav.PostID, err)
continue
}
stats["favorites"] = stats["favorites"].(int) + 1
}
// 导入评论(按用户-帖子去重)
var comments []model.Comment
if err := model.DB.Where("status = ?", model.CommentStatusPublished).Find(&comments).Error; err != nil {
return nil, err
}
seen := make(map[string]struct{})
for _, cm := range comments {
key := cm.UserID + ":" + cm.PostID
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
_, err := gorseClient.InsertFeedback(ctx, []gorseio.Feedback{{
FeedbackType: string(gorse.FeedbackTypeComment),
UserId: cm.UserID,
ItemId: cm.PostID,
Timestamp: cm.CreatedAt.UTC().Truncate(time.Second),
}})
if err != nil {
log.Printf("[WARN] import comment failed (%s/%s): %v", cm.UserID, cm.PostID, err)
continue
}
stats["comments"] = stats["comments"].(int) + 1
}
return stats, nil
}
func buildPostCategories(post *model.Post) []string {
var categories []string
if post.ViewsCount > 1000 {
categories = append(categories, "hot_high")
} else if post.ViewsCount > 100 {
categories = append(categories, "hot_medium")
}
if post.LikesCount > 100 {
categories = append(categories, "likes_100+")
} else if post.LikesCount > 10 {
categories = append(categories, "likes_10+")
}
age := time.Since(post.CreatedAt)
if age < 24*time.Hour {
categories = append(categories, "today")
} else if age < 7*24*time.Hour {
categories = append(categories, "this_week")
}
return categories
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,879 @@
package handler
import (
"context"
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// MessageHandler 消息处理器
type MessageHandler struct {
chatService service.ChatService
messageService *service.MessageService
userService *service.UserService
groupService service.GroupService
}
// NewMessageHandler 创建消息处理器
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler {
return &MessageHandler{
chatService: chatService,
messageService: messageService,
userService: userService,
groupService: groupService,
}
}
// GetConversations 获取会话列表
// GET /api/conversations
func (h *MessageHandler) GetConversations(c *gin.Context) {
userID := c.GetString("user_id")
// 添加调试日志
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get conversations")
return
}
// 过滤掉系统会话(系统通知现在使用独立的表)
filteredConvs := make([]*model.Conversation, 0)
for _, conv := range convs {
if conv.ID != model.SystemConversationID {
filteredConvs = append(filteredConvs, conv)
}
}
// 转换为响应格式
result := make([]*dto.ConversationResponse, len(filteredConvs))
for i, conv := range filteredConvs {
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
// 获取最后一条消息
var lastMessage *model.Message
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
if len(messages) > 0 {
lastMessage = messages[0]
}
// 群聊时返回member_count私聊时返回participants
var resp *dto.ConversationResponse
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
// 群聊:实时计算群成员数量
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
// 创建响应并设置member_count
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
resp.MemberCount = memberCount
} else {
// 私聊:获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
}
result[i] = resp
}
// 更新 total 为过滤后的数量
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
}
// CreateConversation 创建私聊会话
// POST /api/conversations
func (h *MessageHandler) CreateConversation(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req dto.CreateConversationRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证目标用户是否存在
targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID)
if err != nil {
response.BadRequest(c, "target user not found")
return
}
// 不能和自己创建会话
if userID == req.UserID {
response.BadRequest(c, "cannot create conversation with yourself")
return
}
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID)
if err != nil {
response.InternalServerError(c, "failed to create conversation")
return
}
// 获取参与者信息
participants := []*model.User{targetUser}
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
}
// GetConversationByID 获取会话详情
// GET /api/conversations/:id
func (h *MessageHandler) GetConversationByID(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr)
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err)
response.BadRequest(c, "invalid conversation id")
return
}
fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID)
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
// 获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
// 获取当前用户的已读位置
myLastReadSeq := int64(0)
isPinned := false
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
for _, p := range allParticipants {
if p.UserID == userID {
myLastReadSeq = p.LastReadSeq
isPinned = p.IsPinned
break
}
}
// 获取对方用户的已读位置
otherLastReadSeq := int64(0)
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
}
// GetMessages 获取消息列表
// GET /api/conversations/:id/messages
func (h *MessageHandler) GetMessages(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
response.BadRequest(c, "invalid conversation id")
return
}
// 检查是否使用增量同步after_seq参数
afterSeqStr := c.Query("after_seq")
if afterSeqStr != "" {
// 增量同步模式
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid after_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 检查是否使用历史消息加载before_seq参数
beforeSeqStr := c.Query("before_seq")
if beforeSeqStr != "" {
// 加载更早的消息(下拉加载更多)
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid before_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 分页模式
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Paginated(c, result, total, page, pageSize)
}
// SendMessage 发送消息
// POST /api/conversations/:id/messages
func (h *MessageHandler) SendMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr)
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err)
response.BadRequest(c, "invalid conversation id")
return
}
fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID)
var req dto.SendMessageRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 直接使用 segments
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, dto.ConvertMessageToResponse(msg))
}
// HandleSendMessage RESTful 风格的发送消息端点
// POST /api/v1/conversations/send_message
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.SendMessageParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证参数
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
if params.DetailType == "" {
response.BadRequest(c, "detail_type is required")
return
}
if params.Segments == nil || len(params.Segments) == 0 {
response.BadRequest(c, "segments is required")
return
}
// 发送消息
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 构建 WSEventResponse 格式响应
wsResponse := dto.WSEventResponse{
ID: msg.ID,
Time: msg.CreatedAt.UnixMilli(),
Type: "message",
DetailType: params.DetailType,
Seq: strconv.FormatInt(msg.Seq, 10),
Segments: params.Segments,
SenderID: userID,
}
response.Success(c, wsResponse)
}
// HandleDeleteMsg 撤回消息
// POST /api/v1/messages/delete_msg
// 请求体格式: {"message_id": "xxx"}
func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.DeleteMsgParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证参数
if params.MessageID == "" {
response.BadRequest(c, "message_id is required")
return
}
// 撤回消息
err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "消息已撤回", nil)
}
// HandleGetConversationList 获取会话列表
// GET /api/v1/conversations/list
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get conversations")
return
}
// 过滤掉系统会话(系统通知现在使用独立的表)
filteredConvs := make([]*model.Conversation, 0)
for _, conv := range convs {
if conv.ID != model.SystemConversationID {
filteredConvs = append(filteredConvs, conv)
}
}
// 转换为响应格式
result := make([]*dto.ConversationResponse, len(filteredConvs))
for i, conv := range filteredConvs {
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
// 获取最后一条消息
var lastMessage *model.Message
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
if len(messages) > 0 {
lastMessage = messages[0]
}
// 群聊时返回member_count私聊时返回participants
var resp *dto.ConversationResponse
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
// 群聊:实时计算群成员数量
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
// 创建响应并设置member_count
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
resp.MemberCount = memberCount
} else {
// 私聊:获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
}
result[i] = resp
}
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
}
// HandleDeleteConversationForSelf 仅自己删除会话
// DELETE /api/v1/conversations/:id/self
func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationID := c.Param("id")
if conversationID == "" {
response.BadRequest(c, "conversation id is required")
return
}
if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "conversation deleted for self", nil)
}
// MarkAsRead 标记为已读
// POST /api/conversations/:id/read
func (h *MessageHandler) MarkAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
response.BadRequest(c, "invalid conversation id")
return
}
var req dto.MarkReadRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "last_read_seq is required")
return
}
err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "marked as read", nil)
}
// GetUnreadCount 获取未读消息总数
// GET /api/conversations/unread/count
func (h *MessageHandler) GetUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
// 添加调试日志
fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID)
if userID == "" {
fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n")
response.Unauthorized(c, "")
return
}
count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to get unread count")
return
}
response.Success(c, &dto.UnreadCountResponse{
TotalUnreadCount: count,
})
}
// GetConversationUnreadCount 获取单个会话的未读数
// GET /api/conversations/:id/unread/count
func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationIDStr := c.Param("id")
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
response.BadRequest(c, "invalid conversation id")
return
}
count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.Success(c, &dto.ConversationUnreadCountResponse{
ConversationID: conversationID,
UnreadCount: count,
})
}
// RecallMessage 撤回消息
// POST /api/messages/:id/recall
func (h *MessageHandler) RecallMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
messageIDStr := c.Param("id")
// 直接使用 string 类型的 messageID
err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "message recalled", nil)
}
// DeleteMessage 删除消息
// DELETE /api/messages/:id
func (h *MessageHandler) DeleteMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
messageIDStr := c.Param("id")
// 直接使用 string 类型的 messageID
err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "message deleted", nil)
}
// 辅助函数:验证内容类型
func isValidContentType(contentType model.ContentType) bool {
switch contentType {
case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile:
return true
default:
return false
}
}
// 辅助函数:获取会话参与者信息
func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) {
// 从repository获取参与者列表
participants, err := h.messageService.GetConversationParticipants(conversationID)
if err != nil {
return nil, err
}
// 获取参与者用户信息
var users []*model.User
for _, p := range participants {
// 跳过当前用户
if p.UserID == currentUserID {
continue
}
user, err := h.userService.GetUserByID(ctx, p.UserID)
if err != nil {
continue
}
users = append(users, user)
}
return users, nil
}
// 获取当前用户在会话中的参与者信息
func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
participants, err := h.messageService.GetConversationParticipants(conversationID)
if err != nil {
return nil, err
}
for _, p := range participants {
if p.UserID == userID {
return p, nil
}
}
return nil, nil
}
// ==================== RESTful Action 端点 ====================
// HandleCreateConversation 创建会话
// POST /api/v1/conversations/create
func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.CreateConversationParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证目标用户是否存在
targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID)
if err != nil {
response.BadRequest(c, "target user not found")
return
}
// 不能和自己创建会话
if userID == params.UserID {
response.BadRequest(c, "cannot create conversation with yourself")
return
}
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID)
if err != nil {
response.InternalServerError(c, "failed to create conversation")
return
}
// 获取参与者信息
participants := []*model.User{targetUser}
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
isPinned := myParticipant != nil && myParticipant.IsPinned
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
}
// HandleGetConversation 获取会话详情
// GET /api/v1/conversations/get?conversation_id=xxx
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationID := c.Query("conversation_id")
if conversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 获取未读数
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
// 获取参与者信息
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
// 获取当前用户的已读位置
myLastReadSeq := int64(0)
isPinned := false
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
for _, p := range allParticipants {
if p.UserID == userID {
myLastReadSeq = p.LastReadSeq
isPinned = p.IsPinned
break
}
}
// 获取对方用户的已读位置
otherLastReadSeq := int64(0)
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
}
// HandleGetMessages 获取会话消息
// GET /api/v1/conversations/get_messages?conversation_id=xxx
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
conversationID := c.Query("conversation_id")
if conversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
// 检查是否使用增量同步after_seq参数
afterSeqStr := c.Query("after_seq")
if afterSeqStr != "" {
// 增量同步模式
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid after_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 检查是否使用历史消息加载before_seq参数
beforeSeqStr := c.Query("before_seq")
if beforeSeqStr != "" {
// 加载更早的消息(下拉加载更多)
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid before_seq")
return
}
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Success(c, &dto.MessageSyncResponse{
Messages: result,
HasMore: len(messages) == limit,
})
return
}
// 分页模式
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
if err != nil {
response.BadRequest(c, err.Error())
return
}
// 转换为响应格式
result := dto.ConvertMessagesToResponse(messages)
response.Paginated(c, result, total, page, pageSize)
}
// HandleMarkRead 标记已读
// POST /api/v1/conversations/mark_read
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.MarkReadParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
if err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "marked as read", nil)
}
// HandleSetConversationPinned 设置会话置顶
// POST /api/v1/conversations/set_pinned
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var params dto.SetConversationPinnedParams
if err := c.ShouldBindJSON(&params); err != nil {
response.BadRequest(c, err.Error())
return
}
if params.ConversationID == "" {
response.BadRequest(c, "conversation_id is required")
return
}
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
response.BadRequest(c, err.Error())
return
}
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
"conversation_id": params.ConversationID,
"is_pinned": params.IsPinned,
})
}

View File

@@ -0,0 +1,132 @@
package handler
import (
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// NotificationHandler 通知处理器
type NotificationHandler struct {
notificationService *service.NotificationService
}
// NewNotificationHandler 创建通知处理器
func NewNotificationHandler(notificationService *service.NotificationService) *NotificationHandler {
return &NotificationHandler{
notificationService: notificationService,
}
}
// GetNotifications 获取通知列表
func (h *NotificationHandler) GetNotifications(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
unreadOnly := c.Query("unread_only") == "true"
notifications, total, err := h.notificationService.GetByUserID(c.Request.Context(), userID, page, pageSize, unreadOnly)
if err != nil {
response.InternalServerError(c, "failed to get notifications")
return
}
response.Paginated(c, notifications, total, page, pageSize)
}
// MarkAsRead 标记为已读
func (h *NotificationHandler) MarkAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.notificationService.MarkAsReadWithUserID(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to mark as read")
return
}
response.SuccessWithMessage(c, "marked as read", nil)
}
// MarkAllAsRead 标记所有为已读
func (h *NotificationHandler) MarkAllAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
err := h.notificationService.MarkAllAsRead(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to mark all as read")
return
}
response.SuccessWithMessage(c, "all marked as read", nil)
}
// GetUnreadCount 获取未读数量
func (h *NotificationHandler) GetUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
count, err := h.notificationService.GetUnreadCount(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to get unread count")
return
}
response.Success(c, gin.H{"count": count})
}
// DeleteNotification 删除通知
func (h *NotificationHandler) DeleteNotification(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
err := h.notificationService.DeleteNotification(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to delete notification")
return
}
response.Success(c, gin.H{"success": true})
}
// ClearAllNotifications 清空所有通知
func (h *NotificationHandler) ClearAllNotifications(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
err := h.notificationService.ClearAllNotifications(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to clear notifications")
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -0,0 +1,511 @@
package handler
import (
"errors"
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// PostHandler 帖子处理器
type PostHandler struct {
postService *service.PostService
userService *service.UserService
}
// NewPostHandler 创建帖子处理器
func NewPostHandler(postService *service.PostService, userService *service.UserService) *PostHandler {
return &PostHandler{
postService: postService,
userService: userService,
}
}
// Create 创建帖子
func (h *PostHandler) Create(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type CreateRequest struct {
Title string `json:"title" binding:"required"`
Content string `json:"content" binding:"required"`
Images []string `json:"images"`
}
var req CreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
post, err := h.postService.Create(c.Request.Context(), userID, req.Title, req.Content, req.Images)
if err != nil {
var moderationErr *service.PostModerationRejectedError
if errors.As(err, &moderationErr) {
response.BadRequest(c, moderationErr.UserMessage())
return
}
response.InternalServerError(c, "failed to create post")
return
}
response.Success(c, dto.ConvertPostToResponse(post, false, false))
}
// GetByID 获取帖子(不增加浏览量)
func (h *PostHandler) GetByID(c *gin.Context) {
id := c.Param("id")
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
// 非作者不可查看未发布内容
currentUserID := c.GetString("user_id")
if post.Status != model.PostStatusPublished && post.UserID != currentUserID {
response.NotFound(c, "post not found")
return
}
// 注意:不再自动增加浏览量,浏览量通过 RecordView 端点单独记录
// 获取当前用户ID用于判断点赞和收藏状态
fmt.Printf("[DEBUG] GetByID - postID: %s, currentUserID: %s\n", id, currentUserID)
var isLiked, isFavorited bool
if currentUserID != "" {
isLiked = h.postService.IsLiked(c.Request.Context(), id, currentUserID)
isFavorited = h.postService.IsFavorited(c.Request.Context(), id, currentUserID)
fmt.Printf("[DEBUG] GetByID - isLiked: %v, isFavorited: %v\n", isLiked, isFavorited)
} else {
fmt.Printf("[DEBUG] GetByID - user not logged in, isLiked: false, isFavorited: false\n")
}
// 如果有当前用户,检查与帖子作者的相互关注状态
var authorWithFollowStatus *dto.UserResponse
if currentUserID != "" && post.User != nil {
_, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), post.UserID, currentUserID)
if err == nil {
authorWithFollowStatus = dto.ConvertUserToResponseWithMutualFollow(post.User, isFollowing, isFollowingMe)
} else {
// 如果出错使用默认的author
authorWithFollowStatus = dto.ConvertUserToResponse(post.User)
}
}
// 构建响应
responseData := &dto.PostResponse{
ID: post.ID,
UserID: post.UserID,
Title: post.Title,
Content: post.Content,
Images: dto.ConvertPostImagesToResponse(post.Images),
LikesCount: post.LikesCount,
CommentsCount: post.CommentsCount,
FavoritesCount: post.FavoritesCount,
SharesCount: post.SharesCount,
ViewsCount: post.ViewsCount,
IsPinned: post.IsPinned,
IsLocked: post.IsLocked,
IsVote: post.IsVote,
CreatedAt: dto.FormatTime(post.CreatedAt),
Author: authorWithFollowStatus,
IsLiked: isLiked,
IsFavorited: isFavorited,
}
response.Success(c, responseData)
}
// RecordView 记录帖子浏览(增加浏览量)
func (h *PostHandler) RecordView(c *gin.Context) {
id := c.Param("id")
userID := c.GetString("user_id")
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
// 增加浏览量
if err := h.postService.IncrementViews(c.Request.Context(), id, userID); err != nil {
fmt.Printf("[DEBUG] Failed to increment views for post %s: %v\n", id, err)
response.InternalServerError(c, "failed to record view")
return
}
response.Success(c, gin.H{"success": true})
}
// List 获取帖子列表
func (h *PostHandler) List(c *gin.Context) {
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
userID := c.Query("user_id")
tab := c.Query("tab") // recommend, follow, hot, latest
// 获取当前用户ID
currentUserID := c.GetString("user_id")
var posts []*model.Post
var total int64
var err error
// 根据 tab 参数选择不同的获取方式
switch tab {
case "follow":
// 获取关注用户的帖子,需要登录
if currentUserID == "" {
response.Unauthorized(c, "请先登录")
return
}
posts, total, err = h.postService.GetFollowingPosts(c.Request.Context(), currentUserID, page, pageSize)
case "hot":
// 获取热门帖子
posts, total, err = h.postService.GetHotPosts(c.Request.Context(), page, pageSize)
case "recommend":
// 推荐帖子从Gorse获取个性化推荐
posts, total, err = h.postService.GetRecommendedPosts(c.Request.Context(), currentUserID, page, pageSize)
case "latest":
// 最新帖子
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
default:
// 默认获取最新帖子
posts, total, err = h.postService.GetLatestPosts(c.Request.Context(), page, pageSize, userID)
}
if err != nil {
response.InternalServerError(c, "failed to get posts")
return
}
fmt.Printf("[DEBUG] List - tab: %s, currentUserID: %s, posts count: %d\n", tab, currentUserID, len(posts))
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLiked := h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
isLikedMap[post.ID] = isLiked
isFavoritedMap[post.ID] = isFavorited
fmt.Printf("[DEBUG] List - postID: %s, isLiked: %v, isFavorited: %v\n", post.ID, isLiked, isFavorited)
}
} else {
fmt.Printf("[DEBUG] List - user not logged in\n")
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}
// Update 更新帖子
func (h *PostHandler) Update(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
if post.UserID != userID {
response.Forbidden(c, "cannot update others' post")
return
}
type UpdateRequest struct {
Title string `json:"title"`
Content string `json:"content"`
}
var req UpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if req.Title != "" {
post.Title = req.Title
}
if req.Content != "" {
post.Content = req.Content
}
err = h.postService.Update(c.Request.Context(), post)
if err != nil {
response.InternalServerError(c, "failed to update post")
return
}
currentUserID := c.GetString("user_id")
var isLiked, isFavorited bool
if currentUserID != "" {
isLiked = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavorited = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Delete 删除帖子
func (h *PostHandler) Delete(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.NotFound(c, "post not found")
return
}
if post.UserID != userID {
response.Forbidden(c, "cannot delete others' post")
return
}
err = h.postService.Delete(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to delete post")
return
}
response.SuccessWithMessage(c, "post deleted", nil)
}
// Like 点赞帖子
func (h *PostHandler) Like(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Like - postID: %s, userID: %s\n", id, userID)
err := h.postService.Like(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to like post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Like - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Unlike 取消点赞
func (h *PostHandler) Unlike(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Unlike - postID: %s, userID: %s\n", id, userID)
err := h.postService.Unlike(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to unlike post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Unlike - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Favorite 收藏帖子
func (h *PostHandler) Favorite(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Favorite - postID: %s, userID: %s\n", id, userID)
err := h.postService.Favorite(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to favorite post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Favorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// Unfavorite 取消收藏
func (h *PostHandler) Unfavorite(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
id := c.Param("id")
fmt.Printf("[DEBUG] Unfavorite - postID: %s, userID: %s\n", id, userID)
err := h.postService.Unfavorite(c.Request.Context(), id, userID)
if err != nil {
response.InternalServerError(c, "failed to unfavorite post")
return
}
// 获取更新后的帖子状态
post, err := h.postService.GetByID(c.Request.Context(), id)
if err != nil {
response.InternalServerError(c, "failed to get post")
return
}
isLiked := h.postService.IsLiked(c.Request.Context(), id, userID)
isFavorited := h.postService.IsFavorited(c.Request.Context(), id, userID)
fmt.Printf("[DEBUG] Unfavorite - postID: %s, isLiked: %v, isFavorited: %v\n", id, isLiked, isFavorited)
response.Success(c, dto.ConvertPostToResponse(post, isLiked, isFavorited))
}
// GetUserPosts 获取用户帖子列表
func (h *PostHandler) GetUserPosts(c *gin.Context) {
userID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.GetUserPosts(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get user posts")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}
// GetFavorites 获取收藏列表
func (h *PostHandler) GetFavorites(c *gin.Context) {
userID := c.Param("id")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.GetFavorites(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get favorites")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}
// Search 搜索帖子
func (h *PostHandler) Search(c *gin.Context) {
keyword := c.Query("keyword")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
posts, total, err := h.postService.Search(c.Request.Context(), keyword, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to search posts")
return
}
// 获取当前用户ID用于判断点赞和收藏状态
currentUserID := c.GetString("user_id")
isLikedMap := make(map[string]bool)
isFavoritedMap := make(map[string]bool)
if currentUserID != "" {
for _, post := range posts {
isLikedMap[post.ID] = h.postService.IsLiked(c.Request.Context(), post.ID, currentUserID)
isFavoritedMap[post.ID] = h.postService.IsFavorited(c.Request.Context(), post.ID, currentUserID)
}
}
// 转换为响应结构
postResponses := dto.ConvertPostsToResponse(posts, isLikedMap, isFavoritedMap)
response.Paginated(c, postResponses, total, page, pageSize)
}

View File

@@ -0,0 +1,157 @@
package handler
import (
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
"github.com/gin-gonic/gin"
)
// PushHandler 推送处理器
type PushHandler struct {
pushService service.PushService
}
// NewPushHandler 创建推送处理器
func NewPushHandler(pushService service.PushService) *PushHandler {
return &PushHandler{
pushService: pushService,
}
}
// RegisterDevice 注册设备
// POST /api/v1/push/devices
func (h *PushHandler) RegisterDevice(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req dto.RegisterDeviceRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 验证设备类型
deviceType := model.DeviceType(req.DeviceType)
if !isValidDeviceType(deviceType) {
response.BadRequest(c, "invalid device type")
return
}
err := h.pushService.RegisterDevice(c.Request.Context(), userID, req.DeviceID, deviceType, req.PushToken)
if err != nil {
response.InternalServerError(c, "failed to register device")
return
}
response.SuccessWithMessage(c, "device registered successfully", nil)
}
// UnregisterDevice 注销设备
// DELETE /api/v1/push/devices/:device_id
func (h *PushHandler) UnregisterDevice(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
deviceID := c.Param("device_id")
if deviceID == "" {
response.BadRequest(c, "device_id is required")
return
}
err := h.pushService.UnregisterDevice(c.Request.Context(), deviceID)
if err != nil {
response.InternalServerError(c, "failed to unregister device")
return
}
response.SuccessWithMessage(c, "device unregistered successfully", nil)
}
// GetMyDevices 获取当前用户的设备列表
// GET /api/v1/push/devices
func (h *PushHandler) GetMyDevices(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
// 这里需要从DeviceTokenRepository获取设备列表
// 由于PushService接口没有提供获取设备列表的方法我们暂时返回空列表
// TODO: 在PushService接口中添加GetUserDevices方法
_ = userID // 避免未使用变量警告
response.Success(c, []*dto.DeviceTokenResponse{})
}
// GetPushRecords 获取推送记录
// GET /api/v1/push/records
func (h *PushHandler) GetPushRecords(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
records, err := h.pushService.GetPendingPushes(c.Request.Context(), userID)
if err != nil {
response.InternalServerError(c, "failed to get push records")
return
}
response.Success(c, &dto.PushRecordListResponse{
Records: dto.PushRecordsToResponse(records),
Total: int64(len(records)),
})
}
// 辅助函数:验证设备类型
func isValidDeviceType(deviceType model.DeviceType) bool {
switch deviceType {
case model.DeviceTypeIOS, model.DeviceTypeAndroid, model.DeviceTypeWeb:
return true
default:
return false
}
}
// UpdateDeviceToken 更新设备推送Token
// PUT /api/v1/push/devices/:device_id/token
func (h *PushHandler) UpdateDeviceToken(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
deviceID := c.Param("device_id")
if deviceID == "" {
response.BadRequest(c, "device_id is required")
return
}
var req struct {
PushToken string `json:"push_token" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
err := h.pushService.UpdateDeviceToken(c.Request.Context(), deviceID, req.PushToken)
if err != nil {
response.InternalServerError(c, "failed to update device token")
return
}
response.SuccessWithMessage(c, "device token updated successfully", nil)
}

View File

@@ -0,0 +1,164 @@
package handler
import (
"net/http"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// StickerHandler 自定义表情处理器
type StickerHandler struct {
stickerService service.StickerService
}
// NewStickerHandler 创建自定义表情处理器
func NewStickerHandler(stickerService service.StickerService) *StickerHandler {
return &StickerHandler{
stickerService: stickerService,
}
}
// GetStickersRequest 获取表情列表请求
type GetStickersRequest struct {
UserID string `form:"user_id" binding:"required"`
}
// AddStickerRequest 添加表情请求
type AddStickerRequest struct {
URL string `json:"url" binding:"required"`
Width int `json:"width"`
Height int `json:"height"`
}
// DeleteStickerRequest 删除表情请求
type DeleteStickerRequest struct {
StickerID string `json:"sticker_id" binding:"required"`
}
// ReorderStickersRequest 重新排序请求
type ReorderStickersRequest struct {
Orders map[string]int `json:"orders" binding:"required"`
}
// CheckStickerRequest 检查表情是否存在请求
type CheckStickerRequest struct {
URL string `form:"url" binding:"required"`
}
// GetStickers 获取用户的表情列表
func (h *StickerHandler) GetStickers(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
stickers, err := h.stickerService.GetUserStickers(userID)
if err != nil {
response.InternalServerError(c, "failed to get stickers")
return
}
response.Success(c, gin.H{"stickers": stickers})
}
// AddSticker 添加表情
func (h *StickerHandler) AddSticker(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req AddStickerRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
sticker, err := h.stickerService.AddSticker(userID, req.URL, req.Width, req.Height)
if err != nil {
if err == service.ErrStickerAlreadyExists {
response.Error(c, http.StatusConflict, "sticker already exists")
return
}
if err == service.ErrInvalidStickerURL {
response.BadRequest(c, "invalid sticker url, only http/https is allowed")
return
}
response.InternalServerError(c, err.Error())
return
}
response.Success(c, gin.H{"sticker": sticker})
}
// DeleteSticker 删除表情
func (h *StickerHandler) DeleteSticker(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req DeleteStickerRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.stickerService.DeleteSticker(userID, req.StickerID); err != nil {
response.InternalServerError(c, err.Error())
return
}
response.SuccessWithMessage(c, "sticker deleted successfully", nil)
}
// ReorderStickers 重新排序表情
func (h *StickerHandler) ReorderStickers(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
var req ReorderStickersRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.stickerService.ReorderStickers(userID, req.Orders); err != nil {
response.InternalServerError(c, err.Error())
return
}
response.SuccessWithMessage(c, "stickers reordered successfully", nil)
}
// CheckStickerExists 检查表情是否存在
func (h *StickerHandler) CheckStickerExists(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
url := c.Query("url")
if url == "" {
response.BadRequest(c, "url is required")
return
}
exists, err := h.stickerService.CheckExists(userID, url)
if err != nil {
response.InternalServerError(c, err.Error())
return
}
response.Success(c, gin.H{"exists": exists})
}

View File

@@ -0,0 +1,154 @@
package handler
import (
"strconv"
"carrot_bbs/internal/cache"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/repository"
"carrot_bbs/internal/service"
)
// SystemMessageHandler 系统消息处理器
type SystemMessageHandler struct {
systemMsgService service.SystemMessageService
notifyRepo *repository.SystemNotificationRepository
}
// NewSystemMessageHandler 创建系统消息处理器
func NewSystemMessageHandler(
systemMsgService service.SystemMessageService,
notifyRepo *repository.SystemNotificationRepository,
) *SystemMessageHandler {
return &SystemMessageHandler{
systemMsgService: systemMsgService,
notifyRepo: notifyRepo,
}
}
// GetSystemMessages 获取系统消息列表
// GET /api/v1/messages/system
func (h *SystemMessageHandler) GetSystemMessages(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
// 获取当前用户的系统通知(从独立表中获取)
notifications, total, err := h.notifyRepo.GetByReceiverID(userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get system messages")
return
}
// 转换为响应格式
result := make([]*dto.SystemMessageResponse, 0)
for _, n := range notifications {
resp := dto.SystemNotificationToResponse(n)
result = append(result, resp)
}
response.Paginated(c, result, total, page, pageSize)
}
// GetUnreadCount 获取系统消息未读数
// GET /api/v1/messages/system/unread-count
func (h *SystemMessageHandler) GetUnreadCount(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
// 获取当前用户的未读通知数
unreadCount, err := h.notifyRepo.GetUnreadCount(userID)
if err != nil {
response.InternalServerError(c, "failed to get unread count")
return
}
response.Success(c, &dto.SystemUnreadCountResponse{
UnreadCount: unreadCount,
})
}
// MarkAsRead 标记系统消息为已读
// PUT /api/v1/messages/system/:id/read
func (h *SystemMessageHandler) MarkAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
notificationIDStr := c.Param("id")
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid notification id")
return
}
// 标记为已读
err = h.notifyRepo.MarkAsRead(notificationID, userID)
if err != nil {
response.InternalServerError(c, "failed to mark as read")
return
}
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
response.SuccessWithMessage(c, "marked as read", nil)
}
// MarkAllAsRead 标记所有系统消息为已读
// PUT /api/v1/messages/system/read-all
func (h *SystemMessageHandler) MarkAllAsRead(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
// 标记当前用户所有通知为已读
err := h.notifyRepo.MarkAllAsRead(userID)
if err != nil {
response.InternalServerError(c, "failed to mark all as read")
return
}
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
response.SuccessWithMessage(c, "all messages marked as read", nil)
}
// DeleteSystemMessage 删除系统消息
// DELETE /api/v1/messages/system/:id
func (h *SystemMessageHandler) DeleteSystemMessage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
notificationIDStr := c.Param("id")
notificationID, err := strconv.ParseInt(notificationIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "invalid notification id")
return
}
// 删除通知
err = h.notifyRepo.Delete(notificationID, userID)
if err != nil {
response.InternalServerError(c, "failed to delete notification")
return
}
cache.InvalidateUnreadSystem(cache.GetCache(), userID)
response.SuccessWithMessage(c, "notification deleted", nil)
}

View File

@@ -0,0 +1,90 @@
package handler
import (
"github.com/gin-gonic/gin"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// UploadHandler 上传处理器
type UploadHandler struct {
uploadService *service.UploadService
}
// NewUploadHandler 创建上传处理器
func NewUploadHandler(uploadService *service.UploadService) *UploadHandler {
return &UploadHandler{
uploadService: uploadService,
}
}
// UploadImage 上传图片
func (h *UploadHandler) UploadImage(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
file, err := c.FormFile("image")
if err != nil {
response.BadRequest(c, "image file is required")
return
}
url, err := h.uploadService.UploadImage(c.Request.Context(), file)
if err != nil {
response.InternalServerError(c, "failed to upload image")
return
}
response.Success(c, gin.H{"url": url})
}
// UploadAvatar 上传头像
func (h *UploadHandler) UploadAvatar(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
file, err := c.FormFile("image")
if err != nil {
response.BadRequest(c, "avatar file is required")
return
}
url, err := h.uploadService.UploadAvatar(c.Request.Context(), userID, file)
if err != nil {
response.InternalServerError(c, "failed to upload avatar")
return
}
response.Success(c, gin.H{"url": url})
}
// UploadCover 上传头图(个人主页封面)
func (h *UploadHandler) UploadCover(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
file, err := c.FormFile("image")
if err != nil {
response.BadRequest(c, "image file is required")
return
}
url, err := h.uploadService.UploadCover(c.Request.Context(), userID, file)
if err != nil {
response.InternalServerError(c, "failed to upload cover")
return
}
response.Success(c, gin.H{"url": url})
}

View File

@@ -0,0 +1,705 @@
package handler
import (
"fmt"
"strconv"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// UserHandler 用户处理器
type UserHandler struct {
userService *service.UserService
jwtService *service.JWTService
}
// NewUserHandler 创建用户处理器
func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{
userService: userService,
}
}
// Register 用户注册
func (h *UserHandler) Register(c *gin.Context) {
type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required,min=6"`
Nickname string `json:"nickname" binding:"required"`
Phone string `json:"phone"`
VerificationCode string `json:"verification_code" binding:"required"`
}
var req RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
user, err := h.userService.Register(c.Request.Context(), req.Username, req.Email, req.Password, req.Nickname, req.Phone, req.VerificationCode)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to register")
return
}
// 生成Token
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
response.Success(c, gin.H{
"user": dto.ConvertUserToResponse(user),
"token": accessToken,
"refresh_token": refreshToken,
})
}
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
type LoginRequest struct {
Username string `json:"username"`
Account string `json:"account"`
Password string `json:"password" binding:"required"`
}
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
account := req.Account
if account == "" {
account = req.Username
}
if account == "" {
response.BadRequest(c, "username or account is required")
return
}
user, err := h.userService.Login(c.Request.Context(), account, req.Password)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to login")
return
}
// 生成Token
accessToken, _ := h.jwtService.GenerateAccessToken(user.ID, user.Username)
refreshToken, _ := h.jwtService.GenerateRefreshToken(user.ID, user.Username)
response.Success(c, gin.H{
"user": dto.ConvertUserToResponse(user),
"token": accessToken,
"refresh_token": refreshToken,
})
}
// SendRegisterCode 发送注册验证码
func (h *UserHandler) SendRegisterCode(c *gin.Context) {
type SendCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
var req SendCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.SendRegisterCode(c.Request.Context(), req.Email); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send register verification code")
return
}
response.Success(c, gin.H{"success": true})
}
// SendPasswordResetCode 发送找回密码验证码
func (h *UserHandler) SendPasswordResetCode(c *gin.Context) {
type SendCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
var req SendCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.SendPasswordResetCode(c.Request.Context(), req.Email); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send reset verification code")
return
}
response.Success(c, gin.H{"success": true})
}
// ResetPassword 找回密码并重置
func (h *UserHandler) ResetPassword(c *gin.Context) {
type ResetPasswordRequest struct {
Email string `json:"email" binding:"required,email"`
VerificationCode string `json:"verification_code" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
var req ResetPasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.ResetPasswordByEmail(c.Request.Context(), req.Email, req.VerificationCode, req.NewPassword); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to reset password")
return
}
response.Success(c, gin.H{"success": true})
}
// GetCurrentUser 获取当前用户
func (h *UserHandler) GetCurrentUser(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "user not found")
return
}
// 实时计算帖子数量
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
if err != nil {
// 如果获取失败,使用数据库中的值
postsCount = int64(user.PostsCount)
}
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
}
// GetUserByID 获取指定用户
func (h *UserHandler) GetUserByID(c *gin.Context) {
id := c.Param("id")
currentUserID := c.GetString("user_id")
// 获取用户信息,包含双向关注状态
user, isFollowing, isFollowingMe, err := h.userService.GetUserByIDWithMutualFollowStatus(c.Request.Context(), id, currentUserID)
if err != nil {
response.NotFound(c, "user not found")
return
}
// 实时计算帖子数量
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), id)
if err != nil {
// 如果获取失败,使用数据库中的值
postsCount = int64(user.PostsCount)
}
// 转换为响应格式,包含双向关注状态和实时计算的帖子数量
userResponse := dto.ConvertUserToResponseWithMutualFollowAndPostsCount(user, isFollowing, isFollowingMe, int(postsCount))
response.Success(c, userResponse)
}
// UpdateUser 更新用户
func (h *UserHandler) UpdateUser(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type UpdateRequest struct {
Nickname string `json:"nickname"`
Bio string `json:"bio"`
Website string `json:"website"`
Location string `json:"location"`
Avatar string `json:"avatar"`
Phone *string `json:"phone"`
Email *string `json:"email"`
}
var req UpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
if err != nil {
response.NotFound(c, "user not found")
return
}
if req.Nickname != "" {
user.Nickname = req.Nickname
}
if req.Bio != "" {
user.Bio = req.Bio
}
if req.Website != "" {
user.Website = req.Website
}
if req.Location != "" {
user.Location = req.Location
}
if req.Avatar != "" {
user.Avatar = req.Avatar
}
if req.Phone != nil {
user.Phone = req.Phone
}
if req.Email != nil {
if user.Email == nil || *user.Email != *req.Email {
user.EmailVerified = false
}
user.Email = req.Email
}
err = h.userService.UpdateUser(c.Request.Context(), user)
if err != nil {
response.InternalServerError(c, "failed to update user")
return
}
// 实时计算帖子数量
postsCount, err := h.userService.GetUserPostCount(c.Request.Context(), userID)
if err != nil {
// 如果获取失败,使用数据库中的值
postsCount = int64(user.PostsCount)
}
response.Success(c, dto.ConvertUserToDetailResponseWithPostsCount(user, int(postsCount)))
}
// SendEmailVerifyCode 发送当前用户邮箱验证码
func (h *UserHandler) SendEmailVerifyCode(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type SendCodeRequest struct {
Email string `json:"email" binding:"required,email"`
}
var req SendCodeRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.SendCurrentUserEmailVerifyCode(c.Request.Context(), userID, req.Email); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send email verify code")
return
}
response.Success(c, gin.H{"success": true})
}
// VerifyEmail 验证当前用户邮箱
func (h *UserHandler) VerifyEmail(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "")
return
}
type VerifyEmailRequest struct {
Email string `json:"email" binding:"required,email"`
VerificationCode string `json:"verification_code" binding:"required"`
}
var req VerifyEmailRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.userService.VerifyCurrentUserEmail(c.Request.Context(), userID, req.Email, req.VerificationCode); err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to verify email")
return
}
response.Success(c, gin.H{"success": true})
}
// RefreshToken 刷新Token
func (h *UserHandler) RefreshToken(c *gin.Context) {
type RefreshRequest struct {
RefreshToken string `json:"refresh_token" binding:"required"`
}
var req RefreshRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 解析 refresh token
claims, err := h.jwtService.ParseToken(req.RefreshToken)
if err != nil {
response.Unauthorized(c, "invalid refresh token")
return
}
// 生成新 token
accessToken, _ := h.jwtService.GenerateAccessToken(claims.UserID, claims.Username)
refreshToken, _ := h.jwtService.GenerateRefreshToken(claims.UserID, claims.Username)
response.Success(c, gin.H{
"token": accessToken,
"refresh_token": refreshToken,
})
}
// SetJWTService 设置JWT服务
func (h *UserHandler) SetJWTService(jwtSvc *service.JWTService) {
h.jwtService = jwtSvc
}
// FollowUser 关注用户
func (h *UserHandler) FollowUser(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
if userID == currentUserID {
response.BadRequest(c, "cannot follow yourself")
return
}
err := h.userService.FollowUser(c.Request.Context(), currentUserID, userID)
if err != nil {
response.InternalServerError(c, "failed to follow user")
return
}
response.Success(c, gin.H{"success": true})
}
// UnfollowUser 取消关注用户
func (h *UserHandler) UnfollowUser(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
err := h.userService.UnfollowUser(c.Request.Context(), currentUserID, userID)
if err != nil {
response.InternalServerError(c, "failed to unfollow user")
return
}
response.Success(c, gin.H{"success": true})
}
// BlockUser 拉黑用户
func (h *UserHandler) BlockUser(c *gin.Context) {
targetUserID := c.Param("id")
currentUserID := c.GetString("user_id")
if targetUserID == currentUserID {
response.BadRequest(c, "cannot block yourself")
return
}
err := h.userService.BlockUser(c.Request.Context(), currentUserID, targetUserID)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to block user")
return
}
response.Success(c, gin.H{"success": true})
}
// UnblockUser 取消拉黑
func (h *UserHandler) UnblockUser(c *gin.Context) {
targetUserID := c.Param("id")
currentUserID := c.GetString("user_id")
if targetUserID == currentUserID {
response.BadRequest(c, "cannot unblock yourself")
return
}
err := h.userService.UnblockUser(c.Request.Context(), currentUserID, targetUserID)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to unblock user")
return
}
response.Success(c, gin.H{"success": true})
}
// GetBlockedUsers 获取黑名单列表
func (h *UserHandler) GetBlockedUsers(c *gin.Context) {
currentUserID := c.GetString("user_id")
if currentUserID == "" {
response.Unauthorized(c, "")
return
}
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
users, total, err := h.userService.GetBlockedUsers(c.Request.Context(), currentUserID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get blocked users")
return
}
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses := dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
response.Paginated(c, userResponses, total, page, pageSize)
}
// GetBlockStatus 获取拉黑状态
func (h *UserHandler) GetBlockStatus(c *gin.Context) {
targetUserID := c.Param("id")
currentUserID := c.GetString("user_id")
if currentUserID == "" {
response.Unauthorized(c, "")
return
}
if targetUserID == "" {
response.BadRequest(c, "target user id is required")
return
}
isBlocked, err := h.userService.IsBlocked(c.Request.Context(), currentUserID, targetUserID)
if err != nil {
response.InternalServerError(c, "failed to get block status")
return
}
response.Success(c, gin.H{"is_blocked": isBlocked})
}
// GetFollowingList 获取关注列表
func (h *UserHandler) GetFollowingList(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
page := c.DefaultQuery("page", "1")
pageSize := c.DefaultQuery("page_size", "20")
users, err := h.userService.GetFollowingList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get following list")
return
}
// 如果已登录,获取双向关注状态和实时计算的帖子数量
var userResponses []*dto.UserResponse
if currentUserID != "" && len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
} else if len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
} else {
userResponses = dto.ConvertUsersToResponse(users)
}
response.Success(c, gin.H{
"list": userResponses,
})
}
// GetFollowersList 获取粉丝列表
func (h *UserHandler) GetFollowersList(c *gin.Context) {
userID := c.Param("id")
currentUserID := c.GetString("user_id")
page := c.DefaultQuery("page", "1")
pageSize := c.DefaultQuery("page_size", "20")
fmt.Printf("[DEBUG] GetFollowersList: userID=%s, currentUserID=%s\n", userID, currentUserID)
users, err := h.userService.GetFollowersList(c.Request.Context(), userID, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to get followers list")
return
}
fmt.Printf("[DEBUG] GetFollowersList: found %d users\n", len(users))
// 如果已登录,获取双向关注状态和实时计算的帖子数量
var userResponses []*dto.UserResponse
if currentUserID != "" && len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
fmt.Printf("[DEBUG] GetFollowersList: checking mutual follow status for userIDs=%v\n", userIDs)
statusMap, _ := h.userService.GetMutualFollowStatus(c.Request.Context(), currentUserID, userIDs)
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, statusMap, postsCountMap)
} else if len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
} else {
userResponses = dto.ConvertUsersToResponse(users)
}
response.Success(c, gin.H{
"list": userResponses,
})
}
// CheckUsername 检查用户名是否可用
func (h *UserHandler) CheckUsername(c *gin.Context) {
username := c.Query("username")
if username == "" {
response.BadRequest(c, "username is required")
return
}
available, err := h.userService.CheckUsernameAvailable(c.Request.Context(), username)
if err != nil {
response.InternalServerError(c, "failed to check username")
return
}
response.Success(c, gin.H{"available": available})
}
// ChangePassword 修改密码
func (h *UserHandler) ChangePassword(c *gin.Context) {
currentUserID := c.GetString("user_id")
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
VerificationCode string `json:"verification_code" binding:"required"`
}
var req ChangePasswordRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
err := h.userService.ChangePassword(c.Request.Context(), currentUserID, req.OldPassword, req.NewPassword, req.VerificationCode)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to change password")
return
}
response.Success(c, gin.H{"success": true})
}
// SendChangePasswordCode 发送修改密码验证码
func (h *UserHandler) SendChangePasswordCode(c *gin.Context) {
currentUserID := c.GetString("user_id")
if currentUserID == "" {
response.Unauthorized(c, "")
return
}
err := h.userService.SendChangePasswordCode(c.Request.Context(), currentUserID)
if err != nil {
if se, ok := err.(*service.ServiceError); ok {
response.Error(c, se.Code, se.Message)
return
}
response.InternalServerError(c, "failed to send change password code")
return
}
response.Success(c, gin.H{"success": true})
}
// Search 搜索用户
func (h *UserHandler) Search(c *gin.Context) {
keyword := c.Query("keyword")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
users, total, err := h.userService.Search(c.Request.Context(), keyword, page, pageSize)
if err != nil {
response.InternalServerError(c, "failed to search users")
return
}
// 获取实时计算的帖子数量
var userResponses []*dto.UserResponse
if len(users) > 0 {
userIDs := make([]string, len(users))
for i, u := range users {
userIDs[i] = u.ID
}
postsCountMap, _ := h.userService.GetUserPostCountBatch(c.Request.Context(), userIDs)
userResponses = dto.ConvertUsersToResponseWithMutualFollowAndPostsCount(users, nil, postsCountMap)
} else {
userResponses = dto.ConvertUsersToResponse(users)
}
response.Paginated(c, userResponses, total, page, pageSize)
}

View File

@@ -0,0 +1,216 @@
package handler
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/pkg/response"
"carrot_bbs/internal/service"
)
// VoteHandler 投票处理器
type VoteHandler struct {
voteService *service.VoteService
postService *service.PostService
}
// NewVoteHandler 创建投票处理器
func NewVoteHandler(voteService *service.VoteService, postService *service.PostService) *VoteHandler {
return &VoteHandler{
voteService: voteService,
postService: postService,
}
}
// CreateVotePost 创建投票帖子
// POST /api/v1/posts/vote
func (h *VoteHandler) CreateVotePost(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
var req dto.CreateVotePostRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
post, err := h.voteService.CreateVotePost(c.Request.Context(), userID, &req)
if err != nil {
var moderationErr *service.PostModerationRejectedError
if errors.As(err, &moderationErr) {
response.BadRequest(c, moderationErr.UserMessage())
return
}
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, post)
}
// GetVoteResult 获取投票结果
// GET /api/v1/posts/:id/vote
func (h *VoteHandler) GetVoteResult(c *gin.Context) {
postID := c.Param("id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), postID)
if err != nil {
response.NotFound(c, "帖子不存在")
return
}
// 获取当前用户ID可选登录
userID := c.GetString("user_id")
// 如果用户未登录返回不带has_voted的结果
if userID == "" {
options, err := h.voteService.GetVoteOptions(postID)
if err != nil {
response.InternalServerError(c, "获取投票选项失败")
return
}
// 计算总票数
totalVotes := 0
for _, option := range options {
totalVotes += option.VotesCount
}
result := &dto.VoteResultDTO{
Options: options,
TotalVotes: totalVotes,
HasVoted: false,
}
response.Success(c, result)
return
}
// 用户已登录,获取完整的投票结果
result, err := h.voteService.GetVoteResult(postID, userID)
if err != nil {
response.InternalServerError(c, "获取投票结果失败")
return
}
response.Success(c, result)
}
// Vote 投票
// POST /api/v1/posts/:id/vote
func (h *VoteHandler) Vote(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
postID := c.Param("id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), postID)
if err != nil {
response.NotFound(c, "帖子不存在")
return
}
// 解析请求体
var req struct {
OptionID string `json:"option_id" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := h.voteService.Vote(c.Request.Context(), postID, userID, req.OptionID); err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"success": true})
}
// Unvote 取消投票
// DELETE /api/v1/posts/:id/vote
func (h *VoteHandler) Unvote(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
postID := c.Param("id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
// 验证帖子存在
_, err := h.postService.GetByID(c.Request.Context(), postID)
if err != nil {
response.NotFound(c, "帖子不存在")
return
}
if err := h.voteService.Unvote(c.Request.Context(), postID, userID); err != nil {
response.Error(c, http.StatusBadRequest, err.Error())
return
}
response.Success(c, gin.H{"success": true})
}
// UpdateVoteOption 更新投票选项(仅作者)
// PUT /api/v1/vote-options/:id
func (h *VoteHandler) UpdateVoteOption(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
response.Unauthorized(c, "请先登录")
return
}
optionID := c.Param("id")
if optionID == "" {
response.BadRequest(c, "选项ID不能为空")
return
}
// 解析请求体
var req struct {
Content string `json:"content" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
// 获取帖子ID从查询参数或请求体中获取
postID := c.Query("post_id")
if postID == "" {
response.BadRequest(c, "帖子ID不能为空")
return
}
if err := h.voteService.UpdateVoteOption(c.Request.Context(), postID, optionID, userID, req.Content); err != nil {
response.Error(c, http.StatusForbidden, err.Error())
return
}
response.Success(c, gin.H{"success": true})
}

View File

@@ -0,0 +1,866 @@
package handler
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
"strings"
"time"
"carrot_bbs/internal/dto"
"carrot_bbs/internal/model"
ws "carrot_bbs/internal/pkg/websocket"
"carrot_bbs/internal/repository"
"carrot_bbs/internal/service"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // 允许所有来源,生产环境应该限制
},
}
// WebSocketHandler WebSocket处理器
type WebSocketHandler struct {
jwtService *service.JWTService
chatService service.ChatService
groupService service.GroupService
groupRepo repository.GroupRepository
userRepo *repository.UserRepository
wsManager *ws.WebSocketManager
}
// NewWebSocketHandler 创建WebSocket处理器
func NewWebSocketHandler(
jwtService *service.JWTService,
chatService service.ChatService,
groupService service.GroupService,
groupRepo repository.GroupRepository,
userRepo *repository.UserRepository,
wsManager *ws.WebSocketManager,
) *WebSocketHandler {
return &WebSocketHandler{
jwtService: jwtService,
chatService: chatService,
groupService: groupService,
groupRepo: groupRepo,
userRepo: userRepo,
wsManager: wsManager,
}
}
// HandleWebSocket 处理WebSocket连接
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// 调试:打印请求头信息
log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
c.GetHeader("Connection"),
c.GetHeader("Upgrade"))
log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
c.GetHeader("Sec-WebSocket-Key"),
c.GetHeader("Sec-WebSocket-Version"))
// 从query参数获取token
token := c.Query("token")
if token == "" {
// 尝试从header获取
authHeader := c.GetHeader("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
return
}
// 验证token
claims, err := h.jwtService.ParseToken(token)
if err != nil {
log.Printf("Invalid token: %v", err)
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
userID := claims.UserID
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
return
}
// 升级HTTP连接为WebSocket连接
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to upgrade connection: %v", err)
log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
c.GetHeader("User-Agent"),
c.GetHeader("Content-Type"))
return
}
// 如果用户已在线,先注销旧连接
if h.wsManager.IsUserOnline(userID) {
log.Printf("[DEBUG] 用户 %s 已有在线连接,复用该连接", userID)
} else {
log.Printf("[DEBUG] 用户 %s 当前不在线,创建新连接", userID)
}
// 创建客户端
client := &ws.Client{
ID: userID,
UserID: userID,
Conn: conn,
Send: make(chan []byte, 256),
Manager: h.wsManager,
}
// 注册客户端
h.wsManager.Register(client)
// 启动读写协程
go client.WritePump()
go h.handleMessages(client)
log.Printf("[DEBUG] WebSocket连接建立: userID=%s, 当前在线=%v", userID, h.wsManager.IsUserOnline(userID))
}
// handleMessages 处理客户端消息
// 针对移动端优化:增加超时时间到 120 秒,配合客户端 55 秒心跳
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
defer func() {
h.wsManager.Unregister(client)
client.Conn.Close()
}()
client.Conn.SetReadLimit(512 * 1024) // 512KB
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
client.Conn.SetPongHandler(func(string) error {
client.Conn.SetReadDeadline(time.Now().Add(120 * time.Second)) // 增加到 120 秒
return nil
})
// 心跳定时器 - 服务端主动 ping 间隔保持 30 秒
pingTicker := time.NewTicker(30 * time.Second)
defer pingTicker.Stop()
for {
select {
case <-pingTicker.C:
// 发送心跳
if err := client.SendPing(); err != nil {
log.Printf("Failed to send ping: %v", err)
return
}
default:
_, message, err := client.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("WebSocket error: %v", err)
}
return
}
var wsMsg ws.WSMessage
if err := json.Unmarshal(message, &wsMsg); err != nil {
log.Printf("Failed to unmarshal message: %v", err)
continue
}
h.processMessage(client, &wsMsg)
}
}
}
// processMessage 处理消息
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
switch msg.Type {
case ws.MessageTypePing:
// 响应心跳
if err := client.SendPong(); err != nil {
log.Printf("Failed to send pong: %v", err)
}
case ws.MessageTypePong:
// 客户端响应心跳
case ws.MessageTypeMessage:
// 处理聊天消息
h.handleChatMessage(client, msg)
case ws.MessageTypeTyping:
// 处理正在输入状态
h.handleTyping(client, msg)
case ws.MessageTypeRead:
// 处理已读回执
h.handleReadReceipt(client, msg)
// 群组消息处理
case ws.MessageTypeGroupMessage:
// 处理群消息
h.handleGroupMessage(client, msg)
case ws.MessageTypeGroupTyping:
// 处理群输入状态
h.handleGroupTyping(client, msg)
case ws.MessageTypeGroupRead:
// 处理群消息已读
h.handleGroupReadReceipt(client, msg)
case ws.MessageTypeGroupRecall:
// 处理群消息撤回
h.handleGroupRecall(client, msg)
default:
log.Printf("Unknown message type: %s", msg.Type)
}
}
// handleChatMessage 处理聊天消息
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
log.Printf("Invalid message data format")
return
}
log.Printf("[DEBUG handleChatMessage] 完整data: %+v", data)
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
log.Printf("Missing conversationId")
return
}
// 解析会话ID
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
log.Printf("Invalid conversation ID: %v", err)
return
}
// 解析 segments
var segments model.MessageSegments
if data["segments"] != nil {
segmentsBytes, err := json.Marshal(data["segments"])
if err == nil {
json.Unmarshal(segmentsBytes, &segments)
}
}
// 从 segments 中提取回复消息ID
replyToID := dto.GetReplyMessageID(segments)
var replyToIDPtr *string
if replyToID != "" {
replyToIDPtr = &replyToID
}
// 发送消息 - 使用 segments
message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
if err != nil {
log.Printf("Failed to send message: %v", err)
// 发送错误消息
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to send message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 发送确认消息(使用 meta 事件格式,包含完整的消息内容)
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
"detail_type": ws.MetaDetailTypeAck,
"conversation_id": conversationID,
"id": message.ID,
"user_id": client.UserID,
"sender_id": client.UserID,
"seq": message.Seq,
"segments": message.Segments,
"created_at": message.CreatedAt.UnixMilli(),
})
if client.Send != nil {
msgBytes, _ := json.Marshal(metaAckMsg)
log.Printf("[DEBUG handleChatMessage] 私聊 ack 消息: %s", string(msgBytes))
log.Printf("[DEBUG handleChatMessage] message.Segments 类型: %T, 值: %+v", message.Segments, message.Segments)
client.Send <- msgBytes
}
}
// handleTyping 处理正在输入状态
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
return
}
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
return
}
// 直接使用 string 类型的 userID
h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
}
// handleReadReceipt 处理已读回执
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationIDStr, _ := data["conversationId"].(string)
if conversationIDStr == "" {
return
}
conversationID, err := service.ParseConversationID(conversationIDStr)
if err != nil {
return
}
// 获取lastReadSeq
lastReadSeq, _ := data["lastReadSeq"].(float64)
if lastReadSeq == 0 {
return
}
// 直接使用 string 类型的 userID 和 conversationID
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
log.Printf("Failed to mark as read: %v", err)
}
}
// ==================== 群组消息处理 ====================
// handleGroupMessage 处理群消息
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
// 打印接收到的消息类型和数据,用于调试
log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)
data, ok := msg.Data.(map[string]interface{})
if !ok {
log.Printf("Invalid group message data format: data is not map[string]interface{}")
return
}
// 解析群组ID支持 camelCase 和 snake_case
var groupIDFloat float64
groupID := "" // 使用 groupID 作为最终变量名
if val, ok := data["groupId"].(float64); ok {
groupIDFloat = val
groupID = strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
} else if val, ok := data["group_id"].(string); ok {
groupID = val
}
if groupID == "" {
log.Printf("Missing groupId in group message")
return
}
// 解析会话ID支持 camelCase 和 snake_case
var conversationID string
if val, ok := data["conversationId"].(string); ok {
conversationID = val
} else if val, ok := data["conversation_id"].(string); ok {
conversationID = val
}
if conversationID == "" {
log.Printf("Missing conversationId in group message")
return
}
// 解析 segments
var segments model.MessageSegments
if data["segments"] != nil {
segmentsBytes, err := json.Marshal(data["segments"])
if err == nil {
json.Unmarshal(segmentsBytes, &segments)
}
}
// 解析@用户列表(支持 camelCase 和 snake_case
var mentionUsers []uint64
var mentionUsersInterface []interface{}
if val, ok := data["mentionUsers"].([]interface{}); ok {
mentionUsersInterface = val
} else if val, ok := data["mention_users"].([]interface{}); ok {
mentionUsersInterface = val
}
if len(mentionUsersInterface) > 0 {
for _, uid := range mentionUsersInterface {
if uidFloat, ok := uid.(float64); ok {
mentionUsers = append(mentionUsers, uint64(uidFloat))
} else if uidStr, ok := uid.(string); ok {
// 处理字符串格式的用户ID
if uidInt, err := strconv.ParseUint(uidStr, 10, 64); err == nil {
mentionUsers = append(mentionUsers, uidInt)
}
}
}
}
// 解析@所有人(支持 camelCase 和 snake_case
var mentionAll bool
if val, ok := data["mentionAll"].(bool); ok {
mentionAll = val
} else if val, ok := data["mention_all"].(bool); ok {
mentionAll = val
}
// 检查用户是否可以发送群消息(验证成员身份和禁言状态)
// client.UserID 已经是 string 格式的 UUID
if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
log.Printf("User cannot send group message: %v", err)
// 发送错误消息
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Cannot send group message",
"reason": err.Error(),
"type": "group_message_error",
"groupId": groupID,
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 检查@所有人权限(只有群主和管理员可以@所有人)
if mentionAll {
if !h.groupService.IsGroupAdmin(client.UserID, groupID) {
log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
mentionAll = false // 取消@所有人标记
}
}
// 创建消息
message := &model.Message{
ConversationID: conversationID,
SenderID: client.UserID,
Segments: segments,
Status: model.MessageStatusNormal,
MentionAll: mentionAll,
}
// 序列化mentionUsers为JSON
if len(mentionUsers) > 0 {
mentionUsersJSON, _ := json.Marshal(mentionUsers)
message.MentionUsers = string(mentionUsersJSON)
}
// 保存消息到数据库(只存库,不发私聊 WebSocket 帧,群消息通过 BroadcastGroupMessage 单独广播)
savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
if err != nil {
log.Printf("Failed to save group message: %v", err)
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to save group message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 更新消息的mention信息
if len(mentionUsers) > 0 || mentionAll {
message.ID = savedMessage.ID
message.Seq = savedMessage.Seq
}
// 构造群消息响应
groupMsg := &ws.GroupMessage{
ID: savedMessage.ID,
ConversationID: conversationID,
GroupID: groupID,
SenderID: client.UserID,
Seq: savedMessage.Seq,
Segments: segments,
MentionUsers: mentionUsers,
MentionAll: mentionAll,
CreatedAt: savedMessage.CreatedAt.UnixMilli(),
}
// 广播消息给群组所有成员(排除发送者)
h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)
// 发送确认消息给发送者作为meta事件
// 使用 meta 事件格式发送 ack
log.Printf("[DEBUG HandleGroupMessageSend] 准备发送 ack 消息, userID=%s, messageID=%s, seq=%d",
client.UserID, savedMessage.ID, savedMessage.Seq)
metaAckMsg := ws.CreateWSMessage("meta", map[string]interface{}{
"detail_type": ws.MetaDetailTypeAck,
"conversation_id": conversationID,
"group_id": groupID,
"id": savedMessage.ID,
"user_id": client.UserID,
"sender_id": client.UserID,
"seq": savedMessage.Seq,
"segments": segments,
"created_at": savedMessage.CreatedAt.UnixMilli(),
})
if client.Send != nil {
msgBytes, _ := json.Marshal(metaAckMsg)
log.Printf("[DEBUG HandleGroupMessageSend] 发送 ack 消息到 channel, userID=%s, msg=%s",
client.UserID, string(msgBytes))
client.Send <- msgBytes
} else {
log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
}
// 处理@提及通知
if len(mentionUsers) > 0 || mentionAll {
// 提取文本正文(不含 @ 部分)
textContent := dto.ExtractTextContentFromModel(segments)
// 在通知内容前拼接被@的真实昵称,通过群成员列表查找
mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
}
}
// handleGroupTyping 处理群输入状态
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
groupIDFloat, _ := data["groupId"].(float64)
if groupIDFloat == 0 {
return
}
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
isTyping, _ := data["isTyping"].(bool)
// 验证用户是否是群成员
// client.UserID 已经是 string 格式的 UUID
isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
if err != nil || !isMember {
return
}
// 获取用户信息
user, err := h.userRepo.GetByID(client.UserID)
if err != nil {
return
}
// 构造输入状态消息
typingMsg := &ws.GroupTypingMessage{
GroupID: groupID,
UserID: client.UserID,
Username: user.Username,
IsTyping: isTyping,
}
// 广播给群组其他成员
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
}
// handleGroupReadReceipt 处理群消息已读回执
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
conversationID, _ := data["conversationId"].(string)
if conversationID == "" {
return
}
lastReadSeq, _ := data["lastReadSeq"].(float64)
if lastReadSeq == 0 {
return
}
// 标记已读
if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
log.Printf("Failed to mark group message as read: %v", err)
}
}
// handleGroupRecall 处理群消息撤回
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
data, ok := msg.Data.(map[string]interface{})
if !ok {
return
}
messageID, _ := data["messageId"].(string)
if messageID == "" {
return
}
groupIDFloat, _ := data["groupId"].(float64)
if groupIDFloat == 0 {
return
}
groupID := strconv.FormatFloat(groupIDFloat, 'f', 0, 64)
// 撤回消息
if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
log.Printf("Failed to recall group message: %v", err)
errorMsg := ws.CreateWSMessage(ws.MessageTypeError, map[string]string{
"error": "Failed to recall message",
})
if client.Send != nil {
msgBytes, _ := json.Marshal(errorMsg)
client.Send <- msgBytes
}
return
}
// 广播撤回通知给群组所有成员
recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
"messageId": messageID,
"groupId": groupID,
"userId": client.UserID,
"timestamp": time.Now().UnixMilli(),
})
h.BroadcastGroupNotice(groupID, recallNotice)
}
// handleGroupMention 处理群消息@提及通知
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
// 如果@所有人,获取所有群成员
if mentionAll {
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for mention all: %v", err)
return
}
for _, member := range members {
// 不通知发送者自己
memberIDStr := member.UserID
if memberIDStr == senderID {
continue
}
// 发送@提及通知
mentionMsg := &ws.GroupMentionMessage{
GroupID: groupID,
MessageID: messageID,
FromUserID: senderID,
Content: truncateContent(content, 50),
MentionAll: true,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
return
}
// 处理特定用户的@提及
for _, userID := range mentionUsers {
// userID 是 uint64转换为 string
userIDStr := strconv.FormatUint(userID, 10)
if userIDStr == senderID {
continue // 不通知发送者自己
}
mentionMsg := &ws.GroupMentionMessage{
GroupID: groupID,
MessageID: messageID,
FromUserID: senderID,
Content: truncateContent(content, 50),
MentionAll: false,
CreatedAt: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, mentionMsg)
h.wsManager.SendToUser(userIDStr, wsMsg)
}
}
// buildMentionContent 构建@提及通知的内容,通过群成员列表查找被@用户的真实昵称
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
var prefix string
if mentionAll {
prefix = "@所有人 "
} else if len(mentionUsers) > 0 {
// 查询群成员列表,找到被@用户的昵称
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err == nil {
memberNickMap := make(map[string]string, len(members))
for _, m := range members {
displayName := m.Nickname
if displayName == "" {
displayName = m.UserID
}
memberNickMap[m.UserID] = displayName
}
for _, uid := range mentionUsers {
uidStr := strconv.FormatUint(uid, 10)
if name, ok := memberNickMap[uidStr]; ok {
prefix += "@" + name + " "
} else {
prefix += "@某人 "
}
}
} else {
for range mentionUsers {
prefix += "@某人 "
}
}
}
return prefix + textBody
}
// BroadcastGroupMessage 向群组所有成员广播消息
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除发送者
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}
// BroadcastGroupNotice 广播群组通知给所有成员
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for notice broadcast: %v", err)
return
}
// 遍历成员,如果在线则发送通知
for _, member := range members {
memberIDStr := member.UserID
h.wsManager.SendToUser(memberIDStr, notice)
}
}
// BroadcastGroupNoticeExclude 广播群组通知给所有成员(排除指定用户)
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for notice broadcast: %v", err)
return
}
// 遍历成员,如果在线则发送通知
for _, member := range members {
memberIDStr := member.UserID
if memberIDStr == excludeUserID {
continue
}
h.wsManager.SendToUser(memberIDStr, notice)
}
}
// SendGroupMemberNotice 发送群成员变动通知
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
notice := &ws.GroupNoticeMessage{
NoticeType: noticeType,
GroupID: groupID,
Data: data,
Timestamp: time.Now().UnixMilli(),
}
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
h.BroadcastGroupNotice(groupID, wsMsg)
}
// truncateContent 截断内容
func truncateContent(content string, maxLen int) string {
if len(content) <= maxLen {
return content
}
return content[:maxLen] + "..."
}
// BroadcastGroupTyping 向群组所有成员广播输入状态
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for typing broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除指定用户
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}
// BroadcastGroupRead 向群组所有成员广播已读状态
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
// 获取群组所有成员
members, _, err := h.groupRepo.GetMembers(groupID, 1, 1000)
if err != nil {
log.Printf("Failed to get group members for read broadcast: %v", err)
return
}
// 创建WebSocket消息
wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
// 遍历成员,如果在线则发送消息
for _, member := range members {
memberIDStr := member.UserID
// 排除指定用户
if memberIDStr == excludeUserID {
continue
}
// 发送消息
h.wsManager.SendToUser(memberIDStr, wsMsg)
}
}