package handler import ( "context" "fmt" "strconv" "github.com/gin-gonic/gin" "carrot_bbs/internal/dto" "carrot_bbs/internal/model" "carrot_bbs/internal/pkg/response" "carrot_bbs/internal/service" ) // MessageHandler 消息处理器 type MessageHandler struct { chatService service.ChatService messageService *service.MessageService userService *service.UserService groupService service.GroupService } // NewMessageHandler 创建消息处理器 func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService) *MessageHandler { return &MessageHandler{ chatService: chatService, messageService: messageService, userService: userService, groupService: groupService, } } // GetConversations 获取会话列表 // GET /api/conversations func (h *MessageHandler) GetConversations(c *gin.Context) { userID := c.GetString("user_id") // 添加调试日志 if userID == "" { response.Unauthorized(c, "") return } page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize) if err != nil { response.InternalServerError(c, "failed to get conversations") return } // 过滤掉系统会话(系统通知现在使用独立的表) filteredConvs := make([]*model.Conversation, 0) for _, conv := range convs { if conv.ID != model.SystemConversationID { filteredConvs = append(filteredConvs, conv) } } // 转换为响应格式 result := make([]*dto.ConversationResponse, len(filteredConvs)) for i, conv := range filteredConvs { // 获取未读数 unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID) // 获取最后一条消息 var lastMessage *model.Message messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1) if len(messages) > 0 { lastMessage = messages[0] } // 群聊时返回member_count,私聊时返回participants var resp *dto.ConversationResponse myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) isPinned := myParticipant != nil && myParticipant.IsPinned if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" { // 群聊:实时计算群成员数量 memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID) // 创建响应并设置member_count resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned) resp.MemberCount = memberCount } else { // 私聊:获取参与者信息 participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID) resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned) } result[i] = resp } // 更新 total 为过滤后的数量 response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize) } // CreateConversation 创建私聊会话 // POST /api/conversations func (h *MessageHandler) CreateConversation(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } var req dto.CreateConversationRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, err.Error()) return } // 验证目标用户是否存在 targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID) if err != nil { response.BadRequest(c, "target user not found") return } // 不能和自己创建会话 if userID == req.UserID { response.BadRequest(c, "cannot create conversation with yourself") return } conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID) if err != nil { response.InternalServerError(c, "failed to create conversation") return } // 获取参与者信息 participants := []*model.User{targetUser} myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) isPinned := myParticipant != nil && myParticipant.IsPinned response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned)) } // GetConversationByID 获取会话详情 // GET /api/conversations/:id func (h *MessageHandler) GetConversationByID(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationIDStr := c.Param("id") fmt.Printf("[DEBUG] GetConversationByID: conversationIDStr = %s\n", conversationIDStr) conversationID, err := service.ParseConversationID(conversationIDStr) if err != nil { fmt.Printf("[DEBUG] GetConversationByID: failed to parse conversation ID: %v\n", err) response.BadRequest(c, "invalid conversation id") return } fmt.Printf("[DEBUG] GetConversationByID: conversationID = %s\n", conversationID) conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID) if err != nil { response.BadRequest(c, err.Error()) return } // 获取未读数 unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID) // 获取参与者信息 participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID) // 获取当前用户的已读位置 myLastReadSeq := int64(0) isPinned := false allParticipants, _ := h.messageService.GetConversationParticipants(conversationID) for _, p := range allParticipants { if p.UserID == userID { myLastReadSeq = p.LastReadSeq isPinned = p.IsPinned break } } // 获取对方用户的已读位置 otherLastReadSeq := int64(0) response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned)) } // GetMessages 获取消息列表 // GET /api/conversations/:id/messages func (h *MessageHandler) GetMessages(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationIDStr := c.Param("id") conversationID, err := service.ParseConversationID(conversationIDStr) if err != nil { response.BadRequest(c, "invalid conversation id") return } // 检查是否使用增量同步(after_seq参数) afterSeqStr := c.Query("after_seq") if afterSeqStr != "" { // 增量同步模式 afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64) if err != nil { response.BadRequest(c, "invalid after_seq") return } limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit) if err != nil { response.BadRequest(c, err.Error()) return } // 转换为响应格式 result := dto.ConvertMessagesToResponse(messages) response.Success(c, &dto.MessageSyncResponse{ Messages: result, HasMore: len(messages) == limit, }) return } // 检查是否使用历史消息加载(before_seq参数) beforeSeqStr := c.Query("before_seq") if beforeSeqStr != "" { // 加载更早的消息(下拉加载更多) beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64) if err != nil { response.BadRequest(c, "invalid before_seq") return } limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit) if err != nil { response.BadRequest(c, err.Error()) return } // 转换为响应格式 result := dto.ConvertMessagesToResponse(messages) response.Success(c, &dto.MessageSyncResponse{ Messages: result, HasMore: len(messages) == limit, }) return } // 分页模式 page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize) if err != nil { response.BadRequest(c, err.Error()) return } // 转换为响应格式 result := dto.ConvertMessagesToResponse(messages) response.Paginated(c, result, total, page, pageSize) } // SendMessage 发送消息 // POST /api/conversations/:id/messages func (h *MessageHandler) SendMessage(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationIDStr := c.Param("id") fmt.Printf("[DEBUG] SendMessage: conversationIDStr = %s\n", conversationIDStr) conversationID, err := service.ParseConversationID(conversationIDStr) if err != nil { fmt.Printf("[DEBUG] SendMessage: failed to parse conversation ID: %v\n", err) response.BadRequest(c, "invalid conversation id") return } fmt.Printf("[DEBUG] SendMessage: conversationID = %s, userID = %s\n", conversationID, userID) var req dto.SendMessageRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, err.Error()) return } // 直接使用 segments msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID) if err != nil { response.BadRequest(c, err.Error()) return } response.Success(c, dto.ConvertMessageToResponse(msg)) } // HandleSendMessage RESTful 风格的发送消息端点 // POST /api/v1/conversations/send_message // 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]} func (h *MessageHandler) HandleSendMessage(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } var params dto.SendMessageParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } // 验证参数 if params.ConversationID == "" { response.BadRequest(c, "conversation_id is required") return } if params.DetailType == "" { response.BadRequest(c, "detail_type is required") return } if params.Segments == nil || len(params.Segments) == 0 { response.BadRequest(c, "segments is required") return } // 发送消息 msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID) if err != nil { response.BadRequest(c, err.Error()) return } // 构建 WSEventResponse 格式响应 wsResponse := dto.WSEventResponse{ ID: msg.ID, Time: msg.CreatedAt.UnixMilli(), Type: "message", DetailType: params.DetailType, Seq: strconv.FormatInt(msg.Seq, 10), Segments: params.Segments, SenderID: userID, } response.Success(c, wsResponse) } // HandleDeleteMsg 撤回消息 // POST /api/v1/messages/delete_msg // 请求体格式: {"message_id": "xxx"} func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } var params dto.DeleteMsgParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } // 验证参数 if params.MessageID == "" { response.BadRequest(c, "message_id is required") return } // 撤回消息 err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID) if err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "消息已撤回", nil) } // HandleGetConversationList 获取会话列表 // GET /api/v1/conversations/list func (h *MessageHandler) HandleGetConversationList(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize) if err != nil { response.InternalServerError(c, "failed to get conversations") return } // 过滤掉系统会话(系统通知现在使用独立的表) filteredConvs := make([]*model.Conversation, 0) for _, conv := range convs { if conv.ID != model.SystemConversationID { filteredConvs = append(filteredConvs, conv) } } // 转换为响应格式 result := make([]*dto.ConversationResponse, len(filteredConvs)) for i, conv := range filteredConvs { // 获取未读数 unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID) // 获取最后一条消息 var lastMessage *model.Message messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1) if len(messages) > 0 { lastMessage = messages[0] } // 群聊时返回member_count,私聊时返回participants var resp *dto.ConversationResponse myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) isPinned := myParticipant != nil && myParticipant.IsPinned if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" { // 群聊:实时计算群成员数量 memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID) // 创建响应并设置member_count resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned) resp.MemberCount = memberCount } else { // 私聊:获取参与者信息 participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID) resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned) } result[i] = resp } response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize) } // HandleDeleteConversationForSelf 仅自己删除会话 // DELETE /api/v1/conversations/:id/self func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationID := c.Param("id") if conversationID == "" { response.BadRequest(c, "conversation id is required") return } if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "conversation deleted for self", nil) } // MarkAsRead 标记为已读 // POST /api/conversations/:id/read func (h *MessageHandler) MarkAsRead(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationIDStr := c.Param("id") conversationID, err := service.ParseConversationID(conversationIDStr) if err != nil { response.BadRequest(c, "invalid conversation id") return } var req dto.MarkReadRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "last_read_seq is required") return } err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq) if err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "marked as read", nil) } // GetUnreadCount 获取未读消息总数 // GET /api/conversations/unread/count func (h *MessageHandler) GetUnreadCount(c *gin.Context) { userID := c.GetString("user_id") // 添加调试日志 fmt.Printf("[DEBUG] GetUnreadCount: user_id from context = %q\n", userID) if userID == "" { fmt.Printf("[DEBUG] GetUnreadCount: user_id is empty, returning 401\n") response.Unauthorized(c, "") return } count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID) if err != nil { response.InternalServerError(c, "failed to get unread count") return } response.Success(c, &dto.UnreadCountResponse{ TotalUnreadCount: count, }) } // GetConversationUnreadCount 获取单个会话的未读数 // GET /api/conversations/:id/unread/count func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationIDStr := c.Param("id") conversationID, err := service.ParseConversationID(conversationIDStr) if err != nil { response.BadRequest(c, "invalid conversation id") return } count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID) if err != nil { response.BadRequest(c, err.Error()) return } response.Success(c, &dto.ConversationUnreadCountResponse{ ConversationID: conversationID, UnreadCount: count, }) } // RecallMessage 撤回消息 // POST /api/messages/:id/recall func (h *MessageHandler) RecallMessage(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } messageIDStr := c.Param("id") // 直接使用 string 类型的 messageID err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID) if err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "message recalled", nil) } // DeleteMessage 删除消息 // DELETE /api/messages/:id func (h *MessageHandler) DeleteMessage(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } messageIDStr := c.Param("id") // 直接使用 string 类型的 messageID err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID) if err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "message deleted", nil) } // 辅助函数:验证内容类型 func isValidContentType(contentType model.ContentType) bool { switch contentType { case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile: return true default: return false } } // 辅助函数:获取会话参与者信息 func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) { // 从repository获取参与者列表 participants, err := h.messageService.GetConversationParticipants(conversationID) if err != nil { return nil, err } // 获取参与者用户信息 var users []*model.User for _, p := range participants { // 跳过当前用户 if p.UserID == currentUserID { continue } user, err := h.userService.GetUserByID(ctx, p.UserID) if err != nil { continue } users = append(users, user) } return users, nil } // 获取当前用户在会话中的参与者信息 func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) { participants, err := h.messageService.GetConversationParticipants(conversationID) if err != nil { return nil, err } for _, p := range participants { if p.UserID == userID { return p, nil } } return nil, nil } // ==================== RESTful Action 端点 ==================== // HandleCreateConversation 创建会话 // POST /api/v1/conversations/create func (h *MessageHandler) HandleCreateConversation(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } var params dto.CreateConversationParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } // 验证目标用户是否存在 targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID) if err != nil { response.BadRequest(c, "target user not found") return } // 不能和自己创建会话 if userID == params.UserID { response.BadRequest(c, "cannot create conversation with yourself") return } conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID) if err != nil { response.InternalServerError(c, "failed to create conversation") return } // 获取参与者信息 participants := []*model.User{targetUser} myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID) isPinned := myParticipant != nil && myParticipant.IsPinned response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned)) } // HandleGetConversation 获取会话详情 // GET /api/v1/conversations/get?conversation_id=xxx func (h *MessageHandler) HandleGetConversation(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationID := c.Query("conversation_id") if conversationID == "" { response.BadRequest(c, "conversation_id is required") return } conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID) if err != nil { response.BadRequest(c, err.Error()) return } // 获取未读数 unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID) // 获取参与者信息 participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID) // 获取当前用户的已读位置 myLastReadSeq := int64(0) isPinned := false allParticipants, _ := h.messageService.GetConversationParticipants(conversationID) for _, p := range allParticipants { if p.UserID == userID { myLastReadSeq = p.LastReadSeq isPinned = p.IsPinned break } } // 获取对方用户的已读位置 otherLastReadSeq := int64(0) response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned)) } // HandleGetMessages 获取会话消息 // GET /api/v1/conversations/get_messages?conversation_id=xxx func (h *MessageHandler) HandleGetMessages(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } conversationID := c.Query("conversation_id") if conversationID == "" { response.BadRequest(c, "conversation_id is required") return } // 检查是否使用增量同步(after_seq参数) afterSeqStr := c.Query("after_seq") if afterSeqStr != "" { // 增量同步模式 afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64) if err != nil { response.BadRequest(c, "invalid after_seq") return } limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100")) messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit) if err != nil { response.BadRequest(c, err.Error()) return } // 转换为响应格式 result := dto.ConvertMessagesToResponse(messages) response.Success(c, &dto.MessageSyncResponse{ Messages: result, HasMore: len(messages) == limit, }) return } // 检查是否使用历史消息加载(before_seq参数) beforeSeqStr := c.Query("before_seq") if beforeSeqStr != "" { // 加载更早的消息(下拉加载更多) beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64) if err != nil { response.BadRequest(c, "invalid before_seq") return } limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20")) messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit) if err != nil { response.BadRequest(c, err.Error()) return } // 转换为响应格式 result := dto.ConvertMessagesToResponse(messages) response.Success(c, &dto.MessageSyncResponse{ Messages: result, HasMore: len(messages) == limit, }) return } // 分页模式 page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize) if err != nil { response.BadRequest(c, err.Error()) return } // 转换为响应格式 result := dto.ConvertMessagesToResponse(messages) response.Paginated(c, result, total, page, pageSize) } // HandleMarkRead 标记已读 // POST /api/v1/conversations/mark_read func (h *MessageHandler) HandleMarkRead(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } var params dto.MarkReadParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } if params.ConversationID == "" { response.BadRequest(c, "conversation_id is required") return } err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq) if err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "marked as read", nil) } // HandleSetConversationPinned 设置会话置顶 // POST /api/v1/conversations/set_pinned func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) { userID := c.GetString("user_id") if userID == "" { response.Unauthorized(c, "") return } var params dto.SetConversationPinnedParams if err := c.ShouldBindJSON(¶ms); err != nil { response.BadRequest(c, err.Error()) return } if params.ConversationID == "" { response.BadRequest(c, "conversation_id is required") return } if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil { response.BadRequest(c, err.Error()) return } response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{ "conversation_id": params.ConversationID, "is_pinned": params.IsPinned, }) }