// Package net manages WebSocket connections (both accepted "reverse"
// connections and dialed "forward" connections) and converts incoming JSON
// messages into protocol events published on the engine event bus.
package net

import (
	"context"
	"fmt"
	"net/url"
	"sync"
	"time"

	"cellbot/internal/engine"
	"cellbot/internal/protocol"

	"github.com/bytedance/sonic"
	"github.com/fasthttp/websocket"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
)

const (
	// defaultMaxReconnect is the reconnect attempt limit used when none is configured.
	defaultMaxReconnect = 5
	// defaultHeartbeatTick is the ping interval used when none (or a
	// non-positive one) is configured.
	defaultHeartbeatTick = 30 * time.Second
	// pingWriteTimeout bounds how long a single heartbeat ping may take to write.
	pingWriteTimeout = 10 * time.Second
	// reconnectBackoffStep is multiplied by the attempt number to get the
	// delay before each reconnect attempt.
	reconnectBackoffStep = 5 * time.Second
)

// messageJSON is the decoder configuration used for incoming messages.
// It is frozen once at package initialization instead of once per message.
// UseInt64 makes integral JSON numbers decode as int64 rather than float64.
// NoValidateJSONSkip is carried over unchanged from the original
// configuration; see the sonic.Config documentation for its exact semantics.
var messageJSON = sonic.Config{
	UseInt64:           true,
	NoValidateJSONSkip: true,
}.Froze()

// WebSocketManager keeps track of all live WebSocket connections.
type WebSocketManager struct {
	connections map[string]*WebSocketConnection
	logger      *zap.Logger
	eventBus    *engine.EventBus
	mu          sync.RWMutex // guards connections
	upgrader    *websocket.FastHTTPUpgrader
}

// NewWebSocketManager creates a manager whose logger is a child of logger
// named "websocket" and whose connections publish events to eventBus.
func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSocketManager {
	return &WebSocketManager{
		connections: make(map[string]*WebSocketConnection),
		logger:      logger.Named("websocket"),
		eventBus:    eventBus,
		upgrader: &websocket.FastHTTPUpgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
				// Every origin is accepted; production deployments should
				// tighten this check.
				return true
			},
		},
	}
}

// ConnectionType tells who initiated a connection.
type ConnectionType string

const (
	// ConnectionTypeReverse is a connection accepted from a remote peer.
	ConnectionTypeReverse ConnectionType = "reverse"
	// ConnectionTypeForward is a connection dialed by this process.
	ConnectionTypeForward ConnectionType = "forward"
)

// WebSocketConnection is a single WebSocket connection bound to a bot.
type WebSocketConnection struct {
	ID         string
	Conn       *websocket.Conn
	BotID      string
	Logger     *zap.Logger
	ctx        context.Context
	cancel     context.CancelFunc
	Type       ConnectionType
	RemoteAddr string

	reconnectURL   string        // dial target used when re-establishing a forward connection
	maxReconnect   int           // maximum number of reconnect attempts
	reconnectCount int           // reconnect attempts made so far
	heartbeatTick  time.Duration // interval between heartbeat pings

	// writeMu serializes SendMessage calls and the Conn swap performed on
	// reconnect: the underlying websocket connection supports only one
	// concurrent writer.
	writeMu sync.Mutex
}

// NewWebSocketConnection wraps conn in a WebSocketConnection with a freshly
// generated ID, its own cancellable context, and the default reconnect limit
// and heartbeat interval. conn must be non-nil.
func NewWebSocketConnection(conn *websocket.Conn, botID string, connType ConnectionType, logger *zap.Logger) *WebSocketConnection {
	ctx, cancel := context.WithCancel(context.Background())
	connID := generateConnID()
	return &WebSocketConnection{
		ID:            connID,
		Conn:          conn,
		BotID:         botID,
		Logger:        logger.With(zap.String("conn_id", connID)),
		ctx:           ctx,
		cancel:        cancel,
		Type:          connType,
		RemoteAddr:    conn.RemoteAddr().String(),
		maxReconnect:  defaultMaxReconnect,
		heartbeatTick: defaultHeartbeatTick,
	}
}

// UpgradeWebSocket upgrades the HTTP request in ctx to a WebSocket (reverse)
// connection, registers it with the manager and starts its read and
// heartbeat loops. The request must carry a non-empty "bot_id" query
// parameter; otherwise an error is returned.
//
// NOTE(review): fasthttp/websocket's Upgrade appears to invoke the handler
// from a hijack callback that runs only after the request handler has
// returned, and fasthttp closes the hijacked connection once that callback
// returns. If that holds, the receive from connChan below blocks forever
// when this method is called from inside the request handler, and the
// connection would not outlive the callback anyway. Confirm against the
// library and the caller; a fix requires changing this method's contract.
func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSocketConnection, error) {
	botID := string(ctx.QueryArgs().Peek("bot_id"))
	if botID == "" {
		return nil, fmt.Errorf("missing bot_id parameter")
	}

	// Buffered so the upgrade callback never blocks on handing over the connection.
	connChan := make(chan *websocket.Conn, 1)

	err := wsm.upgrader.Upgrade(ctx, func(conn *websocket.Conn) {
		connChan <- conn
	})
	if err != nil {
		return nil, fmt.Errorf("failed to upgrade connection: %w", err)
	}

	// Wait for the callback to deliver the established connection.
	conn := <-connChan

	wsConn := NewWebSocketConnection(conn, botID, ConnectionTypeReverse, wsm.logger)

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket reverse connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", botID),
		zap.String("remote_addr", wsConn.RemoteAddr))

	go wsConn.readLoop(wsm.eventBus)
	go wsConn.heartbeatLoop()

	return wsConn, nil
}

// readLoop reads messages until the connection context is cancelled or a
// read fails, forwarding text messages to handleMessage. The connection is
// closed when the loop exits.
func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) {
	defer wsc.close()

	for {
		select {
		case <-wsc.ctx.Done():
			return
		default:
			messageType, message, err := wsc.Conn.ReadMessage()
			if err != nil {
				wsc.Logger.Error("Failed to read message", zap.Error(err))
				return
			}

			// Only text frames carry events; everything else is ignored.
			if messageType != websocket.TextMessage {
				wsc.Logger.Warn("Received non-text message, ignoring",
					zap.Int("message_type", messageType))
				continue
			}

			wsc.handleMessage(message, eventBus)
		}
	}
}

// handleMessage decodes one JSON message, converts it to a protocol.BaseEvent
// and publishes it on eventBus. Messages that look like API responses (a
// non-empty string "echo" and no "post_type"), "message_sent" events, and
// messages without a resolvable event type are dropped.
func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) {
	wsc.Logger.Debug("Received message", zap.ByteString("data", data))

	// Decode into a generic map first because some fields (for example
	// self_id) may arrive either as a number or as a string.
	var rawMap map[string]interface{}
	if err := messageJSON.Unmarshal(data, &rawMap); err != nil {
		wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data))
		return
	}

	// API responses are not events; they are left for the adapter to handle.
	if echo, hasEcho := rawMap["echo"].(string); hasEcho && echo != "" {
		if _, hasPostType := rawMap["post_type"]; !hasPostType {
			wsc.Logger.Debug("Skipping API response in handleMessage, will be handled by adapter",
				zap.String("echo", echo))
			return
		}
	}

	event := &protocol.BaseEvent{
		Data: make(map[string]interface{}),
	}

	if selfIDVal, ok := rawMap["self_id"]; ok {
		event.SelfID = selfIDString(selfIDVal)
	}
	// Fall back to the bot this connection was registered for.
	if event.SelfID == "" {
		event.SelfID = wsc.BotID
	}

	if timeVal, ok := rawMap["time"]; ok {
		event.Timestamp = unixSeconds(timeVal)
	}
	if event.Timestamp == 0 {
		event.Timestamp = time.Now().Unix()
	}

	eventType, ignore := resolveEventType(rawMap)
	if ignore {
		// Messages sent by the bot itself are not republished.
		wsc.Logger.Debug("Ignoring message_sent event")
		return
	}
	event.Type = eventType

	if event.Type == "" {
		wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data))
		return
	}

	// "message_type" takes precedence; "detail_type" is consulted only when
	// "message_type" is absent altogether.
	if detailTypeVal, ok := rawMap["message_type"]; ok {
		if detailTypeStr, ok := detailTypeVal.(string); ok {
			event.DetailType = detailTypeStr
		}
	} else if detailTypeVal, ok := rawMap["detail_type"]; ok {
		if detailTypeStr, ok := detailTypeVal.(string); ok {
			event.DetailType = detailTypeStr
		}
	}

	// Everything that was not mapped onto a dedicated field goes into Data.
	for k, v := range rawMap {
		switch k {
		case "self_id", "time", "post_type", "type", "message_type", "detail_type":
			continue
		}
		event.Data[k] = v
	}

	wsc.Logger.Info("Event received",
		zap.String("type", string(event.Type)),
		zap.String("detail_type", event.DetailType),
		zap.String("self_id", event.SelfID))

	eventBus.Publish(event)
}

// selfIDString renders a decoded "self_id" value as a string. A JSON null
// yields the empty string so that the caller's fallback applies instead of
// the literal "<nil>".
func selfIDString(v interface{}) string {
	switch id := v.(type) {
	case nil:
		return ""
	case string:
		return id
	case float64:
		return fmt.Sprintf("%.0f", id)
	case int64:
		return fmt.Sprintf("%d", id)
	case int:
		return fmt.Sprintf("%d", id)
	default:
		return fmt.Sprintf("%v", id)
	}
}

// unixSeconds converts a decoded numeric "time" value to int64. Values of
// any other type yield 0.
func unixSeconds(v interface{}) int64 {
	switch t := v.(type) {
	case float64:
		return int64(t)
	case int64:
		return t
	case int:
		return int64(t)
	}
	return 0
}

// resolveEventType derives the event type from "post_type" or, only when
// "post_type" is absent, from "type". ignore is true for "message_sent".
// A present but non-string value yields an empty type.
func resolveEventType(rawMap map[string]interface{}) (eventType protocol.EventType, ignore bool) {
	if typeVal, ok := rawMap["post_type"]; ok {
		typeStr, isString := typeVal.(string)
		if !isString {
			return "", false
		}
		switch typeStr {
		case "message":
			return protocol.EventTypeMessage, false
		case "notice":
			return protocol.EventTypeNotice, false
		case "request":
			return protocol.EventTypeRequest, false
		case "meta_event":
			return protocol.EventTypeMeta, false
		case "message_sent":
			return "", true
		default:
			return protocol.EventType(typeStr), false
		}
	}

	if typeVal, ok := rawMap["type"]; ok {
		if typeStr, isString := typeVal.(string); isString {
			return protocol.EventType(typeStr), false
		}
	}
	return "", false
}

// SendMessage writes data as a single text message. Calls are serialized so
// that it is safe to use from multiple goroutines. It returns a wrapped
// error when the write fails.
func (wsc *WebSocketConnection) SendMessage(data []byte) error {
	wsc.Logger.Debug("Sending message", zap.ByteString("data", data))

	wsc.writeMu.Lock()
	defer wsc.writeMu.Unlock()

	if err := wsc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return fmt.Errorf("failed to send message: %w", err)
	}
	return nil
}

// heartbeatLoop sends a ping every heartbeat interval until the connection
// context is cancelled or a ping fails.
func (wsc *WebSocketConnection) heartbeatLoop() {
	tick := wsc.heartbeatTick
	if tick <= 0 {
		// time.NewTicker panics on a non-positive interval.
		tick = defaultHeartbeatTick
	}

	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			deadline := time.Now().Add(pingWriteTimeout)
			if err := wsc.Conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
				wsc.Logger.Warn("Failed to send ping", zap.Error(err))
				return
			}
			wsc.Logger.Debug("Heartbeat ping sent")
		case <-wsc.ctx.Done():
			return
		}
	}
}

// reconnectLoop waits for the connection context to be cancelled and then,
// for forward connections with a reconnect URL, redials with a linearly
// growing backoff. On success it installs the new connection and restarts
// the read, heartbeat and reconnect loops; after maxReconnect consecutive
// failures it removes the connection from the manager.
//
// NOTE(review): the read loop is restarted here regardless of the
// AutoReadLoop setting used at dial time, and a reconnect is also triggered
// when the context was cancelled via RemoveConnection (in which case the
// revived connection is no longer registered in the manager). Confirm
// against the adapter whether both are intended.
func (wsc *WebSocketConnection) reconnectLoop(wsm *WebSocketManager) {
	<-wsc.ctx.Done() // block until the connection is torn down

	if wsc.Type != ConnectionTypeForward || wsc.reconnectURL == "" {
		return
	}

	wsc.Logger.Info("Connection closed, attempting to reconnect",
		zap.Int("max_reconnect", wsc.maxReconnect))

	for wsc.reconnectCount < wsc.maxReconnect {
		wsc.reconnectCount++
		backoff := time.Duration(wsc.reconnectCount) * reconnectBackoffStep

		wsc.Logger.Info("Reconnecting",
			zap.Int("attempt", wsc.reconnectCount),
			zap.Int("max", wsc.maxReconnect),
			zap.Duration("backoff", backoff))

		time.Sleep(backoff)

		conn, _, err := websocket.DefaultDialer.Dial(wsc.reconnectURL, nil)
		if err != nil {
			wsc.Logger.Error("Reconnect failed", zap.Error(err))
			continue
		}

		// Swap the connection under the write lock so a concurrent
		// SendMessage never observes the change mid-write.
		wsc.writeMu.Lock()
		wsc.Conn = conn
		wsc.RemoteAddr = conn.RemoteAddr().String()
		wsc.writeMu.Unlock()

		wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
		wsc.reconnectCount = 0 // a successful dial resets the attempt budget

		wsc.Logger.Info("Reconnected successfully",
			zap.String("remote_addr", wsc.RemoteAddr))

		go wsc.readLoop(wsm.eventBus)
		go wsc.heartbeatLoop()
		go wsc.reconnectLoop(wsm)

		return
	}

	wsc.Logger.Error("Max reconnect attempts reached, giving up",
		zap.Int("attempts", wsc.reconnectCount))

	wsm.RemoveConnection(wsc.ID)
}

// close cancels the connection context and closes the underlying
// connection, logging (not returning) any close error.
func (wsc *WebSocketConnection) close() {
	wsc.cancel()
	if err := wsc.Conn.Close(); err != nil {
		wsc.Logger.Error("Failed to close connection", zap.Error(err))
	}
}

// RemoveConnection closes the connection with the given ID and removes it
// from the manager. Unknown IDs are ignored.
func (wsm *WebSocketManager) RemoveConnection(connID string) {
	wsm.mu.Lock()
	defer wsm.mu.Unlock()

	if conn, ok := wsm.connections[connID]; ok {
		conn.close()
		delete(wsm.connections, connID)
		wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", connID))
	}
}

// GetConnection returns the connection registered under connID and whether
// it exists.
func (wsm *WebSocketManager) GetConnection(connID string) (*WebSocketConnection, bool) {
	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	conn, ok := wsm.connections[connID]
	return conn, ok
}

// GetConnectionByBotID returns every registered connection whose BotID
// equals botID. The result is a non-nil, possibly empty slice.
func (wsm *WebSocketManager) GetConnectionByBotID(botID string) []*WebSocketConnection {
	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	connections := make([]*WebSocketConnection, 0)
	for _, conn := range wsm.connections {
		if conn.BotID == botID {
			connections = append(connections, conn)
		}
	}
	return connections
}

// BroadcastToBot sends data to every connection registered for botID.
// Send failures are logged per connection and do not stop the broadcast.
func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) {
	for _, conn := range wsm.GetConnectionByBotID(botID) {
		if err := conn.SendMessage(data); err != nil {
			wsm.logger.Error("Failed to send message to connection",
				zap.String("conn_id", conn.ID),
				zap.Error(err))
		}
	}
}

// DialConfig configures an outgoing (forward) WebSocket connection.
type DialConfig struct {
	URL           string        // ws:// or wss:// address to dial
	BotID         string        // bot the connection belongs to
	MaxReconnect  int           // maximum reconnect attempts after a disconnect
	HeartbeatTick time.Duration // ping interval; non-positive values use the 30s default
	AutoReadLoop  bool          // start readLoop automatically (false when the adapter reads messages itself)
}

// Dial establishes a forward connection to addr for botID with the default
// reconnect limit and heartbeat interval and without an automatic read loop.
func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) {
	return wsm.DialWithConfig(DialConfig{
		URL:           addr,
		BotID:         botID,
		MaxReconnect:  defaultMaxReconnect,
		HeartbeatTick: defaultHeartbeatTick,
		AutoReadLoop:  false, // the adapter consumes messages itself
	})
}

// DialWithConfig establishes a forward connection described by config,
// registers it with the manager and starts its heartbeat and reconnect
// loops (plus the read loop when config.AutoReadLoop is set). It returns an
// error if the URL cannot be parsed, its scheme is not ws or wss, or the
// dial fails.
func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnection, error) {
	u, err := url.Parse(config.URL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}

	if u.Scheme != "ws" && u.Scheme != "wss" {
		return nil, fmt.Errorf("invalid URL scheme: %s, expected ws or wss", u.Scheme)
	}

	// A zero-value HeartbeatTick would make the heartbeat ticker panic, so
	// fall back to the default interval.
	heartbeat := config.HeartbeatTick
	if heartbeat <= 0 {
		heartbeat = defaultHeartbeatTick
	}

	conn, _, err := websocket.DefaultDialer.Dial(config.URL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	wsConn := NewWebSocketConnection(conn, config.BotID, ConnectionTypeForward, wsm.logger)
	wsConn.reconnectURL = config.URL
	wsConn.maxReconnect = config.MaxReconnect
	wsConn.heartbeatTick = heartbeat

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket forward connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", config.BotID),
		zap.String("addr", config.URL),
		zap.String("remote_addr", wsConn.RemoteAddr))

	if config.AutoReadLoop {
		go wsConn.readLoop(wsm.eventBus)
	}
	go wsConn.heartbeatLoop()

	// Forward connections are watched so they can be re-established.
	if wsConn.Type == ConnectionTypeForward {
		go wsConn.reconnectLoop(wsm)
	}

	return wsConn, nil
}

var (
	connSeqMu sync.Mutex // guards connSeq
	connSeq   uint64     // number of connection IDs generated so far
)

// generateConnID returns a process-unique connection ID. The timestamp alone
// can repeat when two connections are created within one clock tick, so a
// monotonically increasing sequence number is appended.
func generateConnID() string {
	connSeqMu.Lock()
	connSeq++
	seq := connSeq
	connSeqMu.Unlock()

	return fmt.Sprintf("conn-%d-%d", time.Now().UnixNano(), seq)
}