package handler

import (
	"context"
	"encoding/json"
	"log"
	"net/http"
	"strconv"
	"strings"
	"time"

	"carrot_bbs/internal/dto"
	"carrot_bbs/internal/model"
	ws "carrot_bbs/internal/pkg/websocket"
	"carrot_bbs/internal/repository"
	"carrot_bbs/internal/service"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)

const (
	// wsReadTimeout is how long a connection may stay silent (no inbound
	// message and no pong) before the read loop gives up on it. It is sized
	// for mobile clients that send a heartbeat roughly every 55 seconds.
	wsReadTimeout = 120 * time.Second

	// wsPingInterval is the interval between server-initiated pings.
	wsPingInterval = 30 * time.Second

	// wsMaxMessageSize is the largest inbound message accepted (512KB).
	wsMaxMessageSize = 512 * 1024

	// wsSendBufferSize is the capacity of each client's outbound channel.
	wsSendBufferSize = 256

	// wsGroupMemberPage and wsGroupMemberPageSize are the paging arguments
	// used for every group member lookup in this file. Only the first page
	// is fetched, so groups with more than wsGroupMemberPageSize members are
	// only partially reached by broadcasts and mentions.
	wsGroupMemberPage     = 1
	wsGroupMemberPageSize = 1000

	// wsMentionPreviewLen is the maximum number of characters of message
	// text included in an @mention notification.
	wsMentionPreviewLen = 50
)

// upgrader upgrades incoming HTTP requests to WebSocket connections.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		// SECURITY: every origin is accepted. Production deployments should
		// restrict this to the known front-end origins.
		return true
	},
}

// WebSocketHandler serves the chat WebSocket endpoint and dispatches the
// messages received on each connection.
type WebSocketHandler struct {
	jwtService   *service.JWTService
	chatService  service.ChatService
	groupService service.GroupService
	groupRepo    repository.GroupRepository
	userRepo     *repository.UserRepository
	wsManager    *ws.WebSocketManager
}

// NewWebSocketHandler creates a WebSocketHandler from its dependencies.
func NewWebSocketHandler(
	jwtService *service.JWTService,
	chatService service.ChatService,
	groupService service.GroupService,
	groupRepo repository.GroupRepository,
	userRepo *repository.UserRepository,
	wsManager *ws.WebSocketManager,
) *WebSocketHandler {
	return &WebSocketHandler{
		jwtService:   jwtService,
		chatService:  chatService,
		groupService: groupService,
		groupRepo:    groupRepo,
		userRepo:     userRepo,
		wsManager:    wsManager,
	}
}

// HandleWebSocket authenticates the request, upgrades it to a WebSocket
// connection, registers the resulting client with the manager and starts its
// write and read goroutines.
//
// The token is taken from the "token" query parameter, falling back to an
// "Authorization: Bearer <token>" header. Requests without a valid token are
// answered with 401 and are not upgraded.
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
	// Debug: log the handshake-related request headers.
	log.Printf("[WebSocket] 收到请求: Method=%s, Path=%s", c.Request.Method, c.Request.URL.Path)
	log.Printf("[WebSocket] 请求头: Connection=%s, Upgrade=%s",
		c.GetHeader("Connection"), c.GetHeader("Upgrade"))
	log.Printf("[WebSocket] Sec-WebSocket-Key=%s, Sec-WebSocket-Version=%s",
		c.GetHeader("Sec-WebSocket-Key"), c.GetHeader("Sec-WebSocket-Version"))

	// Prefer the query parameter, then fall back to the Authorization header.
	token := c.Query("token")
	if token == "" {
		authHeader := c.GetHeader("Authorization")
		if strings.HasPrefix(authHeader, "Bearer ") {
			token = strings.TrimPrefix(authHeader, "Bearer ")
		}
	}
	if token == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "missing token"})
		return
	}

	claims, err := h.jwtService.ParseToken(token)
	if err != nil {
		log.Printf("Invalid token: %v", err)
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
		return
	}

	userID := claims.UserID
	if userID == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token claims"})
		return
	}

	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// No response is written here; the upgrader was handed c.Writer.
		log.Printf("Failed to upgrade connection: %v", err)
		log.Printf("[WebSocket] 请求详情 - User-Agent: %s, Content-Type: %s",
			c.GetHeader("User-Agent"), c.GetHeader("Content-Type"))
		return
	}

	// NOTE(review): the client ID equals the user ID, so two connections of
	// the same user share an ID — confirm the manager handles that.
	client := &ws.Client{
		ID:      userID,
		UserID:  userID,
		Conn:    conn,
		Send:    make(chan []byte, wsSendBufferSize),
		Manager: h.wsManager,
	}
	h.wsManager.Register(client)

	go client.WritePump()
	go h.handleMessages(client)
}

// handleMessages is the per-connection read loop. It reads JSON messages
// until the connection fails or times out, dispatching each one to
// processMessage, and unregisters and closes the client when it returns.
//
// Timeouts are tuned for mobile clients: the read deadline is wsReadTimeout
// (120s, matching a ~55s client heartbeat) and the server pings every
// wsPingInterval (30s).
func (h *WebSocketHandler) handleMessages(client *ws.Client) {
	done := make(chan struct{})
	defer func() {
		close(done) // stops pingLoop
		h.wsManager.Unregister(client)
		client.Conn.Close()
	}()

	extendDeadline := func() {
		// The error is ignored on purpose: if the deadline cannot be set the
		// connection is unusable and the next ReadMessage reports it.
		_ = client.Conn.SetReadDeadline(time.Now().Add(wsReadTimeout))
	}

	client.Conn.SetReadLimit(wsMaxMessageSize)
	extendDeadline()
	client.Conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	// ReadMessage blocks, so pings must be driven from a separate goroutine;
	// polling the ticker from this loop would only fire between inbound
	// messages and never on an idle connection.
	go h.pingLoop(client, done)

	for {
		_, payload, err := client.Conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error: %v", err)
			}
			return
		}

		// Any inbound traffic (including the client's application-level
		// heartbeat) proves the peer is alive.
		extendDeadline()

		var wsMsg ws.WSMessage
		if err := json.Unmarshal(payload, &wsMsg); err != nil {
			log.Printf("Failed to unmarshal message: %v", err)
			continue
		}
		h.processMessage(client, &wsMsg)
	}
}

// pingLoop sends a ping to the client every wsPingInterval until done is
// closed. If a ping fails it closes the connection, which makes the pending
// read in handleMessages fail so the client is torn down there.
//
// NOTE(review): assumes client.SendPing may be called concurrently with
// WritePump and SendPong — confirm against the ws package.
func (h *WebSocketHandler) pingLoop(client *ws.Client, done <-chan struct{}) {
	ticker := time.NewTicker(wsPingInterval)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			if err := client.SendPing(); err != nil {
				log.Printf("Failed to send ping: %v", err)
				client.Conn.Close()
				return
			}
		}
	}
}

// processMessage dispatches one decoded message to the handler for its type.
// Unknown types are logged and otherwise ignored.
func (h *WebSocketHandler) processMessage(client *ws.Client, msg *ws.WSMessage) {
	switch msg.Type {
	case ws.MessageTypePing:
		// Application-level heartbeat: answer with a pong.
		if err := client.SendPong(); err != nil {
			log.Printf("Failed to send pong: %v", err)
		}
	case ws.MessageTypePong:
		// Client answered a heartbeat; nothing to do.
	case ws.MessageTypeMessage:
		h.handleChatMessage(client, msg)
	case ws.MessageTypeTyping:
		h.handleTyping(client, msg)
	case ws.MessageTypeRead:
		h.handleReadReceipt(client, msg)

	// Group messages.
	case ws.MessageTypeGroupMessage:
		h.handleGroupMessage(client, msg)
	case ws.MessageTypeGroupTyping:
		h.handleGroupTyping(client, msg)
	case ws.MessageTypeGroupRead:
		h.handleGroupReadReceipt(client, msg)
	case ws.MessageTypeGroupRecall:
		h.handleGroupRecall(client, msg)
	default:
		log.Printf("Unknown message type: %s", msg.Type)
	}
}

// reply encodes msg as JSON and queues it on the client's outbound channel.
// It does nothing when the client has no outbound channel, and logs instead
// of sending when msg cannot be encoded. The send blocks while the channel
// buffer is full.
func (h *WebSocketHandler) reply(client *ws.Client, msg *ws.WSMessage) {
	if client.Send == nil {
		return
	}
	payload, err := json.Marshal(msg)
	if err != nil {
		log.Printf("Failed to marshal outbound message: %v", err)
		return
	}
	client.Send <- payload
}

// replyError queues an error message carrying the given fields for the client.
func (h *WebSocketHandler) replyError(client *ws.Client, fields map[string]string) {
	h.reply(client, ws.CreateWSMessage(ws.MessageTypeError, fields))
}

// broadcastToGroup sends msg to the members of groupID (first page of
// wsGroupMemberPageSize members only). A non-empty excludeUserID is skipped.
// purpose completes the log line written when the member lookup fails.
func (h *WebSocketHandler) broadcastToGroup(groupID string, msg *ws.WSMessage, excludeUserID, purpose string) {
	members, _, err := h.groupRepo.GetMembers(groupID, wsGroupMemberPage, wsGroupMemberPageSize)
	if err != nil {
		log.Printf("Failed to get group members for %s: %v", purpose, err)
		return
	}
	for _, member := range members {
		if excludeUserID != "" && member.UserID == excludeUserID {
			continue
		}
		h.wsManager.SendToUser(member.UserID, msg)
	}
}

// wsStringField returns the value of the first key in keys that holds a JSON
// string in data, or "" when none does.
func wsStringField(data map[string]interface{}, keys ...string) string {
	for _, key := range keys {
		if v, ok := data[key].(string); ok {
			return v
		}
	}
	return ""
}

// wsGroupID extracts the group ID from a message payload. It accepts the
// camelCase "groupId" and snake_case "group_id" keys, each as either a JSON
// string or a JSON number, and returns "" when no usable ID is present.
//
// NOTE(review): numeric IDs are decoded as float64, so integers above 2^53
// lose precision; clients should prefer sending the ID as a string.
func wsGroupID(data map[string]interface{}) string {
	for _, key := range []string{"groupId", "group_id"} {
		switch v := data[key].(type) {
		case string:
			if v != "" {
				return v
			}
		case float64:
			if v != 0 {
				return strconv.FormatFloat(v, 'f', 0, 64)
			}
		}
	}
	return ""
}

// wsDecodeSegments converts the generic JSON value found under "segments"
// into model.MessageSegments by re-encoding and decoding it. Decode problems
// are logged and whatever could be decoded is returned, so a malformed
// payload yields empty (or partial) segments rather than an error.
func wsDecodeSegments(raw interface{}) model.MessageSegments {
	var segments model.MessageSegments
	if raw == nil {
		return segments
	}
	encoded, err := json.Marshal(raw)
	if err != nil {
		log.Printf("Failed to re-encode message segments: %v", err)
		return segments
	}
	if err := json.Unmarshal(encoded, &segments); err != nil {
		log.Printf("Failed to decode message segments: %v", err)
	}
	return segments
}

// wsMentionUsers extracts the mentioned user IDs from "mentionUsers" (or
// "mention_users"). Entries may be JSON numbers or decimal strings; anything
// else, including negative numbers, is skipped. The result is nil when no
// ID was found.
func wsMentionUsers(data map[string]interface{}) []uint64 {
	raw, ok := data["mentionUsers"].([]interface{})
	if !ok {
		raw, _ = data["mention_users"].([]interface{})
	}

	var ids []uint64
	for _, item := range raw {
		switch v := item.(type) {
		case float64:
			if v >= 0 {
				ids = append(ids, uint64(v))
			}
		case string:
			if id, err := strconv.ParseUint(v, 10, 64); err == nil {
				ids = append(ids, id)
			}
		}
	}
	return ids
}

// handleChatMessage handles a private chat message: it sends the message
// through the chat service and answers the sender with a "meta" ack that
// carries the stored message, or with an error message on failure.
func (h *WebSocketHandler) handleChatMessage(client *ws.Client, msg *ws.WSMessage) {
	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		log.Printf("Invalid message data format")
		return
	}

	conversationIDStr, _ := data["conversationId"].(string)
	if conversationIDStr == "" {
		log.Printf("Missing conversationId")
		return
	}

	conversationID, err := service.ParseConversationID(conversationIDStr)
	if err != nil {
		log.Printf("Invalid conversation ID: %v", err)
		return
	}

	segments := wsDecodeSegments(data["segments"])

	// The replied-to message ID, if any, is carried inside the segments.
	var replyToIDPtr *string
	if replyToID := dto.GetReplyMessageID(segments); replyToID != "" {
		replyToIDPtr = &replyToID
	}

	message, err := h.chatService.SendMessage(context.Background(), client.UserID, conversationID, segments, replyToIDPtr)
	if err != nil {
		log.Printf("Failed to send message: %v", err)
		h.replyError(client, map[string]string{
			"error": "Failed to send message",
		})
		return
	}

	// Acknowledge with a meta event that includes the full message content.
	h.reply(client, ws.CreateWSMessage("meta", map[string]interface{}{
		"detail_type":     ws.MetaDetailTypeAck,
		"conversation_id": conversationID,
		"id":              message.ID,
		"user_id":         client.UserID,
		"sender_id":       client.UserID,
		"seq":             message.Seq,
		"segments":        message.Segments,
		"created_at":      message.CreatedAt.UnixMilli(),
	}))
}

// handleTyping forwards a "typing" indicator for a private conversation to
// the chat service. Malformed payloads are silently ignored.
func (h *WebSocketHandler) handleTyping(client *ws.Client, msg *ws.WSMessage) {
	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		return
	}

	conversationIDStr, _ := data["conversationId"].(string)
	if conversationIDStr == "" {
		return
	}

	conversationID, err := service.ParseConversationID(conversationIDStr)
	if err != nil {
		return
	}

	h.chatService.SendTyping(context.Background(), client.UserID, conversationID)
}

// handleReadReceipt marks a private conversation as read up to
// "lastReadSeq". Payloads without a conversation ID or with a zero or
// missing sequence number are ignored.
func (h *WebSocketHandler) handleReadReceipt(client *ws.Client, msg *ws.WSMessage) {
	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		return
	}

	conversationIDStr, _ := data["conversationId"].(string)
	if conversationIDStr == "" {
		return
	}

	conversationID, err := service.ParseConversationID(conversationIDStr)
	if err != nil {
		return
	}

	// JSON numbers arrive as float64.
	lastReadSeq, _ := data["lastReadSeq"].(float64)
	if lastReadSeq == 0 {
		return
	}

	if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
		log.Printf("Failed to mark as read: %v", err)
	}
}

// ==================== Group message handling ====================

// handleGroupMessage handles a message sent to a group: it checks that the
// sender may post, stores the message, broadcasts it to the other group
// members, acknowledges it to the sender with a "meta" ack and finally sends
// @mention notifications.
//
// Payload keys are accepted in camelCase and snake_case.
//
// NOTE(review): the conversation ID comes from the client and is not checked
// against the group here — confirm SaveMessage validates it. Mentions are
// broadcast and notified but not stored by this function.
func (h *WebSocketHandler) handleGroupMessage(client *ws.Client, msg *ws.WSMessage) {
	// Debug: log the received message type and data.
	log.Printf("[handleGroupMessage] Received message type: %s", msg.Type)
	log.Printf("[handleGroupMessage] Message data: %+v", msg.Data)

	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		log.Printf("Invalid group message data format: data is not map[string]interface{}")
		return
	}

	groupID := wsGroupID(data)
	if groupID == "" {
		log.Printf("Missing groupId in group message")
		return
	}

	conversationID := wsStringField(data, "conversationId", "conversation_id")
	if conversationID == "" {
		log.Printf("Missing conversationId in group message")
		return
	}

	segments := wsDecodeSegments(data["segments"])
	mentionUsers := wsMentionUsers(data)

	mentionAll, ok := data["mentionAll"].(bool)
	if !ok {
		mentionAll, _ = data["mention_all"].(bool)
	}

	// Verify membership and mute status before doing anything else.
	if err := h.groupService.CanSendGroupMessage(client.UserID, groupID); err != nil {
		log.Printf("User cannot send group message: %v", err)
		h.replyError(client, map[string]string{
			"error":   "Cannot send group message",
			"reason":  err.Error(),
			"type":    "group_message_error",
			"groupId": groupID,
		})
		return
	}

	// Only group admins (per IsGroupAdmin) may mention everyone; for other
	// senders the flag is dropped and the message is still delivered.
	if mentionAll && !h.groupService.IsGroupAdmin(client.UserID, groupID) {
		log.Printf("User %s has no permission to mention all in group %s", client.UserID, groupID)
		mentionAll = false
	}

	// Store the message only. Group delivery happens through
	// BroadcastGroupMessage below.
	savedMessage, err := h.chatService.SaveMessage(context.Background(), client.UserID, conversationID, segments, nil)
	if err != nil {
		log.Printf("Failed to save group message: %v", err)
		h.replyError(client, map[string]string{
			"error": "Failed to save group message",
		})
		return
	}

	groupMsg := &ws.GroupMessage{
		ID:             savedMessage.ID,
		ConversationID: conversationID,
		GroupID:        groupID,
		SenderID:       client.UserID,
		Seq:            savedMessage.Seq,
		Segments:       segments,
		MentionUsers:   mentionUsers,
		MentionAll:     mentionAll,
		CreatedAt:      savedMessage.CreatedAt.UnixMilli(),
	}

	// Deliver to every member except the sender.
	h.BroadcastGroupMessage(groupID, groupMsg, client.UserID)

	// Acknowledge to the sender as a meta event.
	if client.Send == nil {
		log.Printf("[ERROR HandleGroupMessageSend] client.Send 为 nil, userID=%s", client.UserID)
	} else {
		h.reply(client, ws.CreateWSMessage("meta", map[string]interface{}{
			"detail_type":     ws.MetaDetailTypeAck,
			"conversation_id": conversationID,
			"group_id":        groupID,
			"id":              savedMessage.ID,
			"user_id":         client.UserID,
			"sender_id":       client.UserID,
			"seq":             savedMessage.Seq,
			"segments":        segments,
			"created_at":      savedMessage.CreatedAt.UnixMilli(),
		}))
	}

	if len(mentionUsers) > 0 || mentionAll {
		// Notification text: the mentioned names followed by the plain text
		// body of the message.
		textContent := dto.ExtractTextContentFromModel(segments)
		mentionContent := h.buildMentionContent(groupID, mentionUsers, mentionAll, textContent)
		h.handleGroupMention(groupID, savedMessage.ID, client.UserID, mentionContent, mentionUsers, mentionAll)
	}
}

// handleGroupTyping broadcasts a group "typing" indicator to the other
// members of the group. The request is ignored when the payload is
// malformed, the sender is not a member, or the sender cannot be loaded.
func (h *WebSocketHandler) handleGroupTyping(client *ws.Client, msg *ws.WSMessage) {
	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		return
	}

	groupID := wsGroupID(data)
	if groupID == "" {
		return
	}
	isTyping, _ := data["isTyping"].(bool)

	isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
	if err != nil || !isMember {
		return
	}

	user, err := h.userRepo.GetByID(client.UserID)
	if err != nil {
		return
	}

	typingMsg := &ws.GroupTypingMessage{
		GroupID:  groupID,
		UserID:   client.UserID,
		Username: user.Username,
		IsTyping: isTyping,
	}
	wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
	h.BroadcastGroupNoticeExclude(groupID, wsMsg, client.UserID)
}

// handleGroupReadReceipt marks a group conversation as read up to
// "lastReadSeq". Payloads without a conversation ID or with a zero or
// missing sequence number are ignored.
func (h *WebSocketHandler) handleGroupReadReceipt(client *ws.Client, msg *ws.WSMessage) {
	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		return
	}

	conversationID, _ := data["conversationId"].(string)
	if conversationID == "" {
		return
	}

	lastReadSeq, _ := data["lastReadSeq"].(float64)
	if lastReadSeq == 0 {
		return
	}

	if err := h.chatService.MarkAsRead(context.Background(), conversationID, client.UserID, int64(lastReadSeq)); err != nil {
		log.Printf("Failed to mark group message as read: %v", err)
	}
}

// handleGroupRecall recalls a group message on behalf of the client and
// notifies every member of the group. The sender must be a member of the
// group named in the payload; otherwise, or when the recall fails, an error
// message is returned to the sender and nothing is broadcast.
func (h *WebSocketHandler) handleGroupRecall(client *ws.Client, msg *ws.WSMessage) {
	data, ok := msg.Data.(map[string]interface{})
	if !ok {
		return
	}

	messageID, _ := data["messageId"].(string)
	if messageID == "" {
		return
	}

	groupID := wsGroupID(data)
	if groupID == "" {
		return
	}

	// The group ID is client-supplied and decides who receives the recall
	// notice, so refuse to act for groups the sender does not belong to.
	isMember, err := h.groupRepo.IsMember(groupID, client.UserID)
	if err != nil || !isMember {
		log.Printf("Failed to recall group message: user %s is not a member of group %s (err=%v)", client.UserID, groupID, err)
		h.replyError(client, map[string]string{
			"error": "Failed to recall message",
		})
		return
	}

	if err := h.chatService.RecallMessage(context.Background(), messageID, client.UserID); err != nil {
		log.Printf("Failed to recall group message: %v", err)
		h.replyError(client, map[string]string{
			"error": "Failed to recall message",
		})
		return
	}

	recallNotice := ws.CreateWSMessage(ws.MessageTypeGroupRecall, map[string]interface{}{
		"messageId": messageID,
		"groupId":   groupID,
		"userId":    client.UserID,
		"timestamp": time.Now().UnixMilli(),
	})
	h.BroadcastGroupNotice(groupID, recallNotice)
}

// handleGroupMention sends @mention notifications for a group message.
//
// With mentionAll set, every group member except the sender is notified and
// mentionUsers is ignored. Otherwise each distinct ID in mentionUsers is
// notified once, again skipping the sender. The notification text is content
// truncated to wsMentionPreviewLen characters.
//
// NOTE(review): mentioned IDs are uint64 and are addressed by their decimal
// string form, while other comments in this file describe user IDs as UUID
// strings — confirm that individually mentioned users can actually be
// reached this way.
func (h *WebSocketHandler) handleGroupMention(groupID string, messageID, senderID, content string, mentionUsers []uint64, mentionAll bool) {
	// The notification is identical for every recipient, so build it once.
	wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMention, &ws.GroupMentionMessage{
		GroupID:    groupID,
		MessageID:  messageID,
		FromUserID: senderID,
		Content:    truncateContent(content, wsMentionPreviewLen),
		MentionAll: mentionAll,
		CreatedAt:  time.Now().UnixMilli(),
	})

	if mentionAll {
		members, _, err := h.groupRepo.GetMembers(groupID, wsGroupMemberPage, wsGroupMemberPageSize)
		if err != nil {
			log.Printf("Failed to get group members for mention all: %v", err)
			return
		}
		for _, member := range members {
			if member.UserID == senderID {
				continue // never notify the sender
			}
			h.wsManager.SendToUser(member.UserID, wsMsg)
		}
		return
	}

	notified := make(map[uint64]struct{}, len(mentionUsers))
	for _, userID := range mentionUsers {
		if _, dup := notified[userID]; dup {
			continue
		}
		notified[userID] = struct{}{}

		userIDStr := strconv.FormatUint(userID, 10)
		if userIDStr == senderID {
			continue // never notify the sender
		}
		h.wsManager.SendToUser(userIDStr, wsMsg)
	}
}

// buildMentionContent builds the text of an @mention notification: the
// mentioned names followed by textBody.
//
// With mentionAll the prefix is "@所有人 ". Otherwise each mentioned user
// contributes "@<name> ", where the name is the member's nickname from the
// group member list (or the member's user ID when the nickname is empty);
// users that cannot be resolved contribute "@某人 ".
func (h *WebSocketHandler) buildMentionContent(groupID string, mentionUsers []uint64, mentionAll bool, textBody string) string {
	if mentionAll {
		return "@所有人 " + textBody
	}
	if len(mentionUsers) == 0 {
		return textBody
	}

	// Map member user ID -> display name. When the lookup fails the map stays
	// empty and every mention falls back to the anonymous placeholder.
	displayNames := make(map[string]string)
	members, _, err := h.groupRepo.GetMembers(groupID, wsGroupMemberPage, wsGroupMemberPageSize)
	if err != nil {
		log.Printf("Failed to get group members for mention content: %v", err)
	} else {
		for _, m := range members {
			name := m.Nickname
			if name == "" {
				name = m.UserID
			}
			displayNames[m.UserID] = name
		}
	}

	var b strings.Builder
	for _, uid := range mentionUsers {
		name, ok := displayNames[strconv.FormatUint(uid, 10)]
		if !ok {
			name = "某人"
		}
		b.WriteString("@")
		b.WriteString(name)
		b.WriteString(" ")
	}
	b.WriteString(textBody)
	return b.String()
}

// BroadcastGroupMessage sends a group message to the members of groupID,
// skipping excludeUserID when it is non-empty (normally the sender).
func (h *WebSocketHandler) BroadcastGroupMessage(groupID string, message *ws.GroupMessage, excludeUserID string) {
	wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupMessage, message)
	h.broadcastToGroup(groupID, wsMsg, excludeUserID, "broadcast")
}

// BroadcastGroupNotice sends notice to every member of groupID.
func (h *WebSocketHandler) BroadcastGroupNotice(groupID string, notice *ws.WSMessage) {
	h.broadcastToGroup(groupID, notice, "", "notice broadcast")
}

// BroadcastGroupNoticeExclude sends notice to the members of groupID,
// skipping excludeUserID when it is non-empty.
func (h *WebSocketHandler) BroadcastGroupNoticeExclude(groupID string, notice *ws.WSMessage, excludeUserID string) {
	h.broadcastToGroup(groupID, notice, excludeUserID, "notice broadcast")
}

// SendGroupMemberNotice broadcasts a group notice of the given type, stamped
// with the current time, to every member of groupID.
func (h *WebSocketHandler) SendGroupMemberNotice(noticeType string, groupID string, data *ws.GroupNoticeData) {
	notice := &ws.GroupNoticeMessage{
		NoticeType: noticeType,
		GroupID:    groupID,
		Data:       data,
		Timestamp:  time.Now().UnixMilli(),
	}
	wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupNotice, notice)
	h.BroadcastGroupNotice(groupID, wsMsg)
}

// truncateContent shortens content to at most maxLen characters (runes) and
// appends "..." when anything was cut off. Content that already fits is
// returned unchanged. Cutting on rune boundaries keeps multi-byte UTF-8 text
// valid. A negative maxLen is treated as 0.
func truncateContent(content string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	// Fast path: the byte length bounds the rune count from above.
	if len(content) <= maxLen {
		return content
	}

	runes := 0
	for offset := range content {
		if runes == maxLen {
			return content[:offset] + "..."
		}
		runes++
	}
	return content
}

// BroadcastGroupTyping sends a typing indicator to the members of groupID,
// skipping excludeUserID when it is non-empty.
func (h *WebSocketHandler) BroadcastGroupTyping(groupID string, typingMsg *ws.GroupTypingMessage, excludeUserID string) {
	wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupTyping, typingMsg)
	h.broadcastToGroup(groupID, wsMsg, excludeUserID, "typing broadcast")
}

// BroadcastGroupRead sends a read-status update to the members of groupID,
// skipping excludeUserID when it is non-empty.
func (h *WebSocketHandler) BroadcastGroupRead(groupID string, readMsg map[string]interface{}, excludeUserID string) {
	wsMsg := ws.CreateWSMessage(ws.MessageTypeGroupRead, readMsg)
	h.broadcastToGroup(groupID, wsMsg, excludeUserID, "read broadcast")
}