Align group and conversation handlers/services with path-based endpoints, and unify response/service error handling for related modules. Made-with: Cursor
974 lines
26 KiB
Go
974 lines
26 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net/http"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
|
||
"carrot_bbs/internal/dto"
|
||
"carrot_bbs/internal/model"
|
||
"carrot_bbs/internal/pkg/response"
|
||
"carrot_bbs/internal/pkg/sse"
|
||
"carrot_bbs/internal/service"
|
||
)
|
||
|
||
// MessageHandler 消息处理器
|
||
type MessageHandler struct {
|
||
chatService service.ChatService
|
||
messageService *service.MessageService
|
||
userService *service.UserService
|
||
groupService service.GroupService
|
||
sseHub *sse.Hub
|
||
}
|
||
|
||
// NewMessageHandler 创建消息处理器
|
||
func NewMessageHandler(chatService service.ChatService, messageService *service.MessageService, userService *service.UserService, groupService service.GroupService, sseHub *sse.Hub) *MessageHandler {
|
||
return &MessageHandler{
|
||
chatService: chatService,
|
||
messageService: messageService,
|
||
userService: userService,
|
||
groupService: groupService,
|
||
sseHub: sseHub,
|
||
}
|
||
}
|
||
|
||
// HandleSSE 实时消息订阅(SSE)
|
||
// GET /api/v1/realtime/sse
|
||
func (h *MessageHandler) HandleSSE(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
if h.sseHub == nil {
|
||
response.InternalServerError(c, "sse hub not available")
|
||
return
|
||
}
|
||
|
||
lastID := sse.ParseEventID(c.GetHeader("Last-Event-ID"))
|
||
if lastID == 0 {
|
||
lastID = sse.ParseEventID(c.Query("last_event_id"))
|
||
}
|
||
ch, cancel, replay := h.sseHub.Subscribe(userID, lastID)
|
||
defer cancel()
|
||
|
||
w := c.Writer
|
||
flusher, ok := w.(http.Flusher)
|
||
if !ok {
|
||
response.InternalServerError(c, "streaming unsupported")
|
||
return
|
||
}
|
||
|
||
w.Header().Set("Content-Type", "text/event-stream")
|
||
w.Header().Set("Cache-Control", "no-cache")
|
||
w.Header().Set("Connection", "keep-alive")
|
||
w.Header().Set("X-Accel-Buffering", "no")
|
||
c.Status(http.StatusOK)
|
||
flusher.Flush()
|
||
|
||
writeEvent := func(ev sse.Event) bool {
|
||
data, err := sse.EncodeData(ev)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Event, data); err != nil {
|
||
return false
|
||
}
|
||
flusher.Flush()
|
||
return true
|
||
}
|
||
|
||
for _, ev := range replay {
|
||
if !writeEvent(ev) {
|
||
return
|
||
}
|
||
}
|
||
|
||
heartbeat := time.NewTicker(25 * time.Second)
|
||
defer heartbeat.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-c.Request.Context().Done():
|
||
return
|
||
case ev, ok := <-ch:
|
||
if !ok || !writeEvent(ev) {
|
||
return
|
||
}
|
||
case <-heartbeat.C:
|
||
if _, err := fmt.Fprint(w, "event: heartbeat\ndata: {}\n\n"); err != nil {
|
||
return
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
}
|
||
|
||
// HandleTyping 输入状态上报
|
||
// POST /api/v1/conversations/typing
|
||
func (h *MessageHandler) HandleTyping(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
var params struct {
|
||
ConversationID string `json:"conversation_id" binding:"required"`
|
||
}
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
h.chatService.SendTyping(c.Request.Context(), userID, params.ConversationID)
|
||
response.SuccessWithMessage(c, "typing sent", nil)
|
||
}
|
||
|
||
// GetConversations 获取会话列表
|
||
// GET /api/conversations
|
||
func (h *MessageHandler) GetConversations(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
// 添加调试日志
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||
|
||
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||
if err != nil {
|
||
response.InternalServerError(c, "failed to get conversations")
|
||
return
|
||
}
|
||
|
||
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||
filteredConvs := make([]*model.Conversation, 0)
|
||
for _, conv := range convs {
|
||
if conv.ID != model.SystemConversationID {
|
||
filteredConvs = append(filteredConvs, conv)
|
||
}
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||
for i, conv := range filteredConvs {
|
||
// 获取未读数
|
||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||
|
||
// 获取最后一条消息
|
||
var lastMessage *model.Message
|
||
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||
if len(messages) > 0 {
|
||
lastMessage = messages[0]
|
||
}
|
||
|
||
// 群聊时返回member_count,私聊时返回participants
|
||
var resp *dto.ConversationResponse
|
||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||
// 群聊:实时计算群成员数量
|
||
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||
// 创建响应并设置member_count
|
||
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||
resp.MemberCount = memberCount
|
||
} else {
|
||
// 私聊:获取参与者信息
|
||
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||
}
|
||
result[i] = resp
|
||
}
|
||
|
||
// 更新 total 为过滤后的数量
|
||
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||
}
|
||
|
||
// CreateConversation 创建私聊会话
|
||
// POST /api/conversations
|
||
func (h *MessageHandler) CreateConversation(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
var req dto.CreateConversationRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 验证目标用户是否存在
|
||
targetUser, err := h.userService.GetUserByID(c.Request.Context(), req.UserID)
|
||
if err != nil {
|
||
response.BadRequest(c, "target user not found")
|
||
return
|
||
}
|
||
|
||
// 不能和自己创建会话
|
||
if userID == req.UserID {
|
||
response.BadRequest(c, "cannot create conversation with yourself")
|
||
return
|
||
}
|
||
|
||
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, req.UserID)
|
||
if err != nil {
|
||
response.InternalServerError(c, "failed to create conversation")
|
||
return
|
||
}
|
||
|
||
// 获取参与者信息
|
||
participants := []*model.User{targetUser}
|
||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||
|
||
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||
}
|
||
|
||
// GetConversationByID 获取会话详情
|
||
// GET /api/conversations/:id
|
||
func (h *MessageHandler) GetConversationByID(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationIDStr := c.Param("id")
|
||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid conversation id")
|
||
return
|
||
}
|
||
|
||
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 获取未读数
|
||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||
|
||
// 获取参与者信息
|
||
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||
|
||
// 获取当前用户的已读位置
|
||
myLastReadSeq := int64(0)
|
||
isPinned := false
|
||
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||
for _, p := range allParticipants {
|
||
if p.UserID == userID {
|
||
myLastReadSeq = p.LastReadSeq
|
||
isPinned = p.IsPinned
|
||
break
|
||
}
|
||
}
|
||
|
||
// 获取对方用户的已读位置
|
||
otherLastReadSeq := int64(0)
|
||
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||
}
|
||
|
||
// GetMessages 获取消息列表
|
||
// GET /api/conversations/:id/messages
|
||
func (h *MessageHandler) GetMessages(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationIDStr := c.Param("id")
|
||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid conversation id")
|
||
return
|
||
}
|
||
|
||
// 检查是否使用增量同步(after_seq参数)
|
||
afterSeqStr := c.Query("after_seq")
|
||
if afterSeqStr != "" {
|
||
// 增量同步模式
|
||
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid after_seq")
|
||
return
|
||
}
|
||
|
||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||
|
||
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := dto.ConvertMessagesToResponse(messages)
|
||
|
||
response.Success(c, &dto.MessageSyncResponse{
|
||
Messages: result,
|
||
HasMore: len(messages) == limit,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 检查是否使用历史消息加载(before_seq参数)
|
||
beforeSeqStr := c.Query("before_seq")
|
||
if beforeSeqStr != "" {
|
||
// 加载更早的消息(下拉加载更多)
|
||
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid before_seq")
|
||
return
|
||
}
|
||
|
||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||
|
||
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := dto.ConvertMessagesToResponse(messages)
|
||
|
||
response.Success(c, &dto.MessageSyncResponse{
|
||
Messages: result,
|
||
HasMore: len(messages) == limit,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 分页模式
|
||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||
|
||
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := dto.ConvertMessagesToResponse(messages)
|
||
|
||
response.Paginated(c, result, total, page, pageSize)
|
||
}
|
||
|
||
// SendMessage 发送消息
|
||
// POST /api/conversations/:id/messages
|
||
func (h *MessageHandler) SendMessage(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationIDStr := c.Param("id")
|
||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid conversation id")
|
||
return
|
||
}
|
||
|
||
var req dto.SendMessageRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 直接使用 segments
|
||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, conversationID, req.Segments, req.ReplyToID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.Success(c, dto.ConvertMessageToResponse(msg))
|
||
}
|
||
|
||
// HandleSendMessage RESTful 风格的发送消息端点
|
||
// POST /api/v1/conversations/send_message
|
||
// 请求体格式: {"detail_type": "private", "conversation_id": "123445667", "segments": [{"type": "text", "data": {"text": "嗨~"}}]}
|
||
func (h *MessageHandler) HandleSendMessage(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
var params dto.SendMessageParams
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 验证参数
|
||
if params.ConversationID == "" {
|
||
response.BadRequest(c, "conversation_id is required")
|
||
return
|
||
}
|
||
if params.DetailType == "" {
|
||
response.BadRequest(c, "detail_type is required")
|
||
return
|
||
}
|
||
if params.Segments == nil || len(params.Segments) == 0 {
|
||
response.BadRequest(c, "segments is required")
|
||
return
|
||
}
|
||
|
||
// 发送消息
|
||
msg, err := h.chatService.SendMessage(c.Request.Context(), userID, params.ConversationID, params.Segments, params.ReplyToID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 构建 WSEventResponse 格式响应
|
||
wsResponse := dto.WSEventResponse{
|
||
ID: msg.ID,
|
||
Time: msg.CreatedAt.UnixMilli(),
|
||
Type: "message",
|
||
DetailType: params.DetailType,
|
||
Seq: strconv.FormatInt(msg.Seq, 10),
|
||
Segments: params.Segments,
|
||
SenderID: userID,
|
||
}
|
||
|
||
response.Success(c, wsResponse)
|
||
}
|
||
|
||
// HandleDeleteMsg 撤回消息
|
||
// POST /api/v1/messages/delete_msg
|
||
// 请求体格式: {"message_id": "xxx"}
|
||
func (h *MessageHandler) HandleDeleteMsg(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
var params dto.DeleteMsgParams
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 验证参数
|
||
if params.MessageID == "" {
|
||
response.BadRequest(c, "message_id is required")
|
||
return
|
||
}
|
||
|
||
// 撤回消息
|
||
err := h.chatService.RecallMessage(c.Request.Context(), params.MessageID, userID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "消息已撤回", nil)
|
||
}
|
||
|
||
// HandleGetConversationList 获取会话列表
|
||
// GET /api/v1/conversations/list
|
||
func (h *MessageHandler) HandleGetConversationList(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||
|
||
convs, _, err := h.chatService.GetConversationList(c.Request.Context(), userID, page, pageSize)
|
||
if err != nil {
|
||
response.InternalServerError(c, "failed to get conversations")
|
||
return
|
||
}
|
||
|
||
// 过滤掉系统会话(系统通知现在使用独立的表)
|
||
filteredConvs := make([]*model.Conversation, 0)
|
||
for _, conv := range convs {
|
||
if conv.ID != model.SystemConversationID {
|
||
filteredConvs = append(filteredConvs, conv)
|
||
}
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := make([]*dto.ConversationResponse, len(filteredConvs))
|
||
for i, conv := range filteredConvs {
|
||
// 获取未读数
|
||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conv.ID, userID)
|
||
|
||
// 获取最后一条消息
|
||
var lastMessage *model.Message
|
||
messages, _, _ := h.chatService.GetMessages(c.Request.Context(), conv.ID, userID, 1, 1)
|
||
if len(messages) > 0 {
|
||
lastMessage = messages[0]
|
||
}
|
||
|
||
// 群聊时返回member_count,私聊时返回participants
|
||
var resp *dto.ConversationResponse
|
||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||
if conv.Type == model.ConversationTypeGroup && conv.GroupID != nil && *conv.GroupID != "" {
|
||
// 群聊:实时计算群成员数量
|
||
memberCount, _ := h.groupService.GetMemberCount(*conv.GroupID)
|
||
// 创建响应并设置member_count
|
||
resp = dto.ConvertConversationToResponse(conv, nil, int(unreadCount), lastMessage, isPinned)
|
||
resp.MemberCount = memberCount
|
||
} else {
|
||
// 私聊:获取参与者信息
|
||
participants, _ := h.getConversationParticipants(c.Request.Context(), conv.ID, userID)
|
||
resp = dto.ConvertConversationToResponse(conv, participants, int(unreadCount), lastMessage, isPinned)
|
||
}
|
||
result[i] = resp
|
||
}
|
||
|
||
response.Paginated(c, result, int64(len(filteredConvs)), page, pageSize)
|
||
}
|
||
|
||
// HandleDeleteConversationForSelf 仅自己删除会话
|
||
// DELETE /api/v1/conversations/:id/self
|
||
func (h *MessageHandler) HandleDeleteConversationForSelf(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationID := getIDParam(c, "id")
|
||
if conversationID == "" {
|
||
response.BadRequest(c, "conversation id is required")
|
||
return
|
||
}
|
||
|
||
if err := h.chatService.DeleteConversationForSelf(c.Request.Context(), conversationID, userID); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "conversation deleted for self", nil)
|
||
}
|
||
|
||
// MarkAsRead 标记为已读
|
||
// POST /api/conversations/:id/read
|
||
func (h *MessageHandler) MarkAsRead(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationIDStr := c.Param("id")
|
||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid conversation id")
|
||
return
|
||
}
|
||
|
||
var req dto.MarkReadRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
response.BadRequest(c, "last_read_seq is required")
|
||
return
|
||
}
|
||
|
||
err = h.chatService.MarkAsRead(c.Request.Context(), conversationID, userID, req.LastReadSeq)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "marked as read", nil)
|
||
}
|
||
|
||
// GetUnreadCount 获取未读消息总数
|
||
// GET /api/conversations/unread/count
|
||
func (h *MessageHandler) GetUnreadCount(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
count, err := h.chatService.GetAllUnreadCount(c.Request.Context(), userID)
|
||
if err != nil {
|
||
response.InternalServerError(c, "failed to get unread count")
|
||
return
|
||
}
|
||
|
||
response.Success(c, &dto.UnreadCountResponse{
|
||
TotalUnreadCount: count,
|
||
})
|
||
}
|
||
|
||
// GetConversationUnreadCount 获取单个会话的未读数
|
||
// GET /api/conversations/:id/unread/count
|
||
func (h *MessageHandler) GetConversationUnreadCount(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationIDStr := c.Param("id")
|
||
conversationID, err := service.ParseConversationID(conversationIDStr)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid conversation id")
|
||
return
|
||
}
|
||
|
||
count, err := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.Success(c, &dto.ConversationUnreadCountResponse{
|
||
ConversationID: conversationID,
|
||
UnreadCount: count,
|
||
})
|
||
}
|
||
|
||
// RecallMessage 撤回消息
|
||
// POST /api/messages/:id/recall
|
||
func (h *MessageHandler) RecallMessage(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
messageIDStr := c.Param("id")
|
||
// 直接使用 string 类型的 messageID
|
||
err := h.chatService.RecallMessage(c.Request.Context(), messageIDStr, userID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "message recalled", nil)
|
||
}
|
||
|
||
// DeleteMessage 删除消息
|
||
// DELETE /api/messages/:id
|
||
func (h *MessageHandler) DeleteMessage(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
messageIDStr := c.Param("id")
|
||
// 直接使用 string 类型的 messageID
|
||
err := h.chatService.DeleteMessage(c.Request.Context(), messageIDStr, userID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "message deleted", nil)
|
||
}
|
||
|
||
// 辅助函数:验证内容类型
|
||
func isValidContentType(contentType model.ContentType) bool {
|
||
switch contentType {
|
||
case model.ContentTypeText, model.ContentTypeImage, model.ContentTypeVideo, model.ContentTypeAudio, model.ContentTypeFile:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// 辅助函数:获取会话参与者信息
|
||
func (h *MessageHandler) getConversationParticipants(ctx context.Context, conversationID string, currentUserID string) ([]*model.User, error) {
|
||
// 从repository获取参与者列表
|
||
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取参与者用户信息
|
||
var users []*model.User
|
||
for _, p := range participants {
|
||
// 跳过当前用户
|
||
if p.UserID == currentUserID {
|
||
continue
|
||
}
|
||
user, err := h.userService.GetUserByID(ctx, p.UserID)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
users = append(users, user)
|
||
}
|
||
return users, nil
|
||
}
|
||
|
||
// 获取当前用户在会话中的参与者信息
|
||
func (h *MessageHandler) getMyConversationParticipant(conversationID string, userID string) (*model.ConversationParticipant, error) {
|
||
participants, err := h.messageService.GetConversationParticipants(conversationID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, p := range participants {
|
||
if p.UserID == userID {
|
||
return p, nil
|
||
}
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
// getIDParam 从路径参数获取 ID
|
||
func getIDParam(c *gin.Context, paramName string) string {
|
||
return c.Param(paramName)
|
||
}
|
||
|
||
// ==================== Action-style endpoints ====================
|
||
|
||
// HandleCreateConversation 创建会话
|
||
// POST /api/v1/conversations/create
|
||
func (h *MessageHandler) HandleCreateConversation(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
var params dto.CreateConversationParams
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 验证目标用户是否存在
|
||
targetUser, err := h.userService.GetUserByID(c.Request.Context(), params.UserID)
|
||
if err != nil {
|
||
response.BadRequest(c, "target user not found")
|
||
return
|
||
}
|
||
|
||
// 不能和自己创建会话
|
||
if userID == params.UserID {
|
||
response.BadRequest(c, "cannot create conversation with yourself")
|
||
return
|
||
}
|
||
|
||
conv, err := h.chatService.GetOrCreateConversation(c.Request.Context(), userID, params.UserID)
|
||
if err != nil {
|
||
response.InternalServerError(c, "failed to create conversation")
|
||
return
|
||
}
|
||
|
||
// 获取参与者信息
|
||
participants := []*model.User{targetUser}
|
||
myParticipant, _ := h.getMyConversationParticipant(conv.ID, userID)
|
||
isPinned := myParticipant != nil && myParticipant.IsPinned
|
||
|
||
response.Success(c, dto.ConvertConversationToResponse(conv, participants, 0, nil, isPinned))
|
||
}
|
||
|
||
// HandleGetConversation 获取会话详情
|
||
// GET /api/v1/conversations/get?conversation_id=xxx
|
||
// GET /api/v1/conversations/:id
|
||
func (h *MessageHandler) HandleGetConversation(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationID := getIDParam(c, "id")
|
||
if conversationID == "" {
|
||
response.BadRequest(c, "conversation_id is required")
|
||
return
|
||
}
|
||
|
||
conv, err := h.chatService.GetConversationByID(c.Request.Context(), conversationID, userID)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 获取未读数
|
||
unreadCount, _ := h.chatService.GetUnreadCount(c.Request.Context(), conversationID, userID)
|
||
|
||
// 获取参与者信息
|
||
participants, _ := h.getConversationParticipants(c.Request.Context(), conversationID, userID)
|
||
|
||
// 获取当前用户的已读位置
|
||
myLastReadSeq := int64(0)
|
||
isPinned := false
|
||
allParticipants, _ := h.messageService.GetConversationParticipants(conversationID)
|
||
for _, p := range allParticipants {
|
||
if p.UserID == userID {
|
||
myLastReadSeq = p.LastReadSeq
|
||
isPinned = p.IsPinned
|
||
break
|
||
}
|
||
}
|
||
|
||
// 获取对方用户的已读位置
|
||
otherLastReadSeq := int64(0)
|
||
response.Success(c, dto.ConvertConversationToDetailResponse(conv, participants, unreadCount, nil, myLastReadSeq, otherLastReadSeq, isPinned))
|
||
}
|
||
|
||
// HandleGetMessages 获取会话消息
|
||
// GET /api/v1/conversations/get_messages?conversation_id=xxx
|
||
// GET /api/v1/conversations/:id/messages
|
||
func (h *MessageHandler) HandleGetMessages(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
conversationID := getIDParam(c, "id")
|
||
if conversationID == "" {
|
||
response.BadRequest(c, "conversation_id is required")
|
||
return
|
||
}
|
||
|
||
// 检查是否使用增量同步(after_seq参数)
|
||
afterSeqStr := c.Query("after_seq")
|
||
if afterSeqStr != "" {
|
||
// 增量同步模式
|
||
afterSeq, err := strconv.ParseInt(afterSeqStr, 10, 64)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid after_seq")
|
||
return
|
||
}
|
||
|
||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||
|
||
messages, err := h.chatService.GetMessagesAfterSeq(c.Request.Context(), conversationID, userID, afterSeq, limit)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := dto.ConvertMessagesToResponse(messages)
|
||
|
||
response.Success(c, &dto.MessageSyncResponse{
|
||
Messages: result,
|
||
HasMore: len(messages) == limit,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 检查是否使用历史消息加载(before_seq参数)
|
||
beforeSeqStr := c.Query("before_seq")
|
||
if beforeSeqStr != "" {
|
||
// 加载更早的消息(下拉加载更多)
|
||
beforeSeq, err := strconv.ParseInt(beforeSeqStr, 10, 64)
|
||
if err != nil {
|
||
response.BadRequest(c, "invalid before_seq")
|
||
return
|
||
}
|
||
|
||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||
|
||
messages, err := h.chatService.GetMessagesBeforeSeq(c.Request.Context(), conversationID, userID, beforeSeq, limit)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := dto.ConvertMessagesToResponse(messages)
|
||
|
||
response.Success(c, &dto.MessageSyncResponse{
|
||
Messages: result,
|
||
HasMore: len(messages) == limit,
|
||
})
|
||
return
|
||
}
|
||
|
||
// 分页模式
|
||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||
|
||
messages, total, err := h.chatService.GetMessages(c.Request.Context(), conversationID, userID, page, pageSize)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
// 转换为响应格式
|
||
result := dto.ConvertMessagesToResponse(messages)
|
||
|
||
response.Paginated(c, result, total, page, pageSize)
|
||
}
|
||
|
||
// HandleMarkRead 标记已读
|
||
// POST /api/v1/conversations/mark_read
|
||
func (h *MessageHandler) HandleMarkRead(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
var params dto.MarkReadParams
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
if params.ConversationID == "" {
|
||
response.BadRequest(c, "conversation_id is required")
|
||
return
|
||
}
|
||
|
||
err := h.chatService.MarkAsRead(c.Request.Context(), params.ConversationID, userID, params.LastReadSeq)
|
||
if err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "marked as read", nil)
|
||
}
|
||
|
||
// HandleSetConversationPinned 设置会话置顶
|
||
// POST /api/v1/conversations/set_pinned
|
||
func (h *MessageHandler) HandleSetConversationPinned(c *gin.Context) {
|
||
userID := c.GetString("user_id")
|
||
if userID == "" {
|
||
response.Unauthorized(c, "")
|
||
return
|
||
}
|
||
|
||
var params dto.SetConversationPinnedParams
|
||
if err := c.ShouldBindJSON(¶ms); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
if params.ConversationID == "" {
|
||
response.BadRequest(c, "conversation_id is required")
|
||
return
|
||
}
|
||
|
||
if err := h.chatService.SetConversationPinned(c.Request.Context(), params.ConversationID, userID, params.IsPinned); err != nil {
|
||
response.BadRequest(c, err.Error())
|
||
return
|
||
}
|
||
|
||
response.SuccessWithMessage(c, "conversation pinned status updated", gin.H{
|
||
"conversation_id": params.ConversationID,
|
||
"is_pinned": params.IsPinned,
|
||
})
|
||
}
|