package net

import (
	"context"
	"fmt"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"cellbot/internal/engine"
	"cellbot/internal/protocol"

	"github.com/bytedance/sonic"
	"github.com/fasthttp/websocket"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
)

const (
	// wsDefaultMaxReconnect is the reconnect budget given to new connections.
	wsDefaultMaxReconnect = 5
	// wsDefaultHeartbeatTick is the ping interval used when none (or a
	// non-positive one) is configured.
	wsDefaultHeartbeatTick = 30 * time.Second
)

// wsConnSeq disambiguates connection IDs generated within the same clock tick.
var wsConnSeq uint64

// WebSocketManager keeps track of WebSocket connections, keyed by connection
// ID, and forwards the events they receive to the event bus.
type WebSocketManager struct {
	connections map[string]*WebSocketConnection
	logger      *zap.Logger
	eventBus    *engine.EventBus
	mu          sync.RWMutex // guards connections
	upgrader    *websocket.FastHTTPUpgrader
}

// NewWebSocketManager creates a manager that logs under the "websocket" name
// and publishes received events to eventBus.
func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSocketManager {
	return &WebSocketManager{
		connections: make(map[string]*WebSocketConnection),
		logger:      logger.Named("websocket"),
		eventBus:    eventBus,
		upgrader: &websocket.FastHTTPUpgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			CheckOrigin: func(ctx *fasthttp.RequestCtx) bool {
				// Allows every origin; production deployments should tighten this check.
				return true
			},
		},
	}
}

// ConnectionType distinguishes who initiated a connection.
type ConnectionType string

const (
	ConnectionTypeReverse ConnectionType = "reverse" // reverse connection (accepted passively)
	ConnectionTypeForward ConnectionType = "forward" // forward connection (dialed actively)
)

// WebSocketConnection is a single WebSocket connection belonging to a bot.
//
// Conn and RemoteAddr are replaced when a forward connection reconnects; the
// methods in this file read them under an internal mutex.
type WebSocketConnection struct {
	ID             string
	Conn           *websocket.Conn
	BotID          string
	Logger         *zap.Logger
	ctx            context.Context
	cancel         context.CancelFunc
	Type           ConnectionType
	RemoteAddr     string
	reconnectURL   string        // dial target used when reconnecting a forward connection
	maxReconnect   int           // maximum number of reconnect attempts
	reconnectCount int           // attempts made so far; only touched by reconnectLoop
	heartbeatTick  time.Duration // interval between pings

	mu         sync.Mutex    // guards Conn, RemoteAddr, ctx, cancel and closed
	writeMu    sync.Mutex    // serializes data-frame writes in SendMessage
	closed     bool          // true once the current Conn has been closed
	removed    chan struct{} // closed when the connection is removed from its manager
	removeOnce sync.Once     // makes markRemoved idempotent
}

// NewWebSocketConnection wraps conn in a WebSocketConnection with a fresh ID,
// a cancellable context, and the default reconnect and heartbeat settings.
//
// conn may be nil when the underlying socket is attached later (as
// UpgradeWebSocket does); RemoteAddr is then left empty.
func NewWebSocketConnection(conn *websocket.Conn, botID string, connType ConnectionType, logger *zap.Logger) *WebSocketConnection {
	ctx, cancel := context.WithCancel(context.Background())
	connID := generateConnID()

	wsc := &WebSocketConnection{
		ID:            connID,
		Conn:          conn,
		BotID:         botID,
		Logger:        logger.With(zap.String("conn_id", connID)),
		ctx:           ctx,
		cancel:        cancel,
		Type:          connType,
		maxReconnect:  wsDefaultMaxReconnect,
		heartbeatTick: wsDefaultHeartbeatTick,
		removed:       make(chan struct{}),
	}
	if conn != nil {
		wsc.RemoteAddr = conn.RemoteAddr().String()
	}
	return wsc
}

// UpgradeWebSocket upgrades the HTTP request to a WebSocket (reverse)
// connection for the bot named by the "bot_id" query parameter.
//
// The upgrader only hands over the socket from its callback, so this function
// must not wait for it: it returns the connection object right after the
// handshake has been prepared. The socket is attached, the connection is
// registered with the manager, and reading starts once that callback runs.
// Until then SendMessage on the returned connection reports an error. The
// connection is removed from the manager when its read loop ends.
//
// It returns an error when bot_id is missing or the upgrade fails.
func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSocketConnection, error) {
	botID := string(ctx.QueryArgs().Peek("bot_id"))
	if botID == "" {
		return nil, fmt.Errorf("missing bot_id parameter")
	}

	// Build the connection object up front so it can be returned to the
	// caller; the socket itself is attached in the upgrade callback.
	wsConn := NewWebSocketConnection(nil, botID, ConnectionTypeReverse, wsm.logger)
	wsConn.RemoteAddr = ctx.RemoteAddr().String()

	err := wsm.upgrader.Upgrade(ctx, func(conn *websocket.Conn) {
		wsm.serveReverse(wsConn, conn)
	})
	if err != nil {
		wsConn.cancel()
		return nil, fmt.Errorf("failed to upgrade connection: %w", err)
	}
	return wsConn, nil
}

// serveReverse runs the whole lifetime of an accepted connection. It blocks
// until the read loop ends so that the socket stays usable for as long as the
// upgrade callback is running, then unregisters the connection.
func (wsm *WebSocketManager) serveReverse(wsConn *WebSocketConnection, conn *websocket.Conn) {
	wsConn.attach(conn)

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket reverse connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", wsConn.BotID),
		zap.String("remote_addr", wsConn.RemoteAddr))

	heartbeatDone := make(chan struct{})
	go func() {
		defer close(heartbeatDone)
		wsConn.heartbeatLoop()
	}()

	wsConn.readLoop(wsm.eventBus)
	// readLoop closes the connection on exit, which stops the heartbeat;
	// wait for it so nothing touches the socket after this function returns.
	<-heartbeatDone
	wsm.RemoveConnection(wsConn.ID)
}

// attach installs the socket of a connection that was created without one.
func (wsc *WebSocketConnection) attach(conn *websocket.Conn) {
	wsc.mu.Lock()
	wsc.Conn = conn
	wsc.mu.Unlock()
}

// snapshot returns the current socket and context under the lock, so loops
// keep working on the generation they started with across reconnects.
func (wsc *WebSocketConnection) snapshot() (*websocket.Conn, context.Context) {
	wsc.mu.Lock()
	defer wsc.mu.Unlock()
	return wsc.Conn, wsc.ctx
}

// markRemoved records that the connection was removed from its manager, which
// stops any pending or future reconnect attempt.
func (wsc *WebSocketConnection) markRemoved() {
	wsc.removeOnce.Do(func() {
		if wsc.removed != nil {
			close(wsc.removed)
		}
	})
}

// isRemoved reports whether markRemoved has been called.
func (wsc *WebSocketConnection) isRemoved() bool {
	select {
	case <-wsc.removed:
		return true
	default:
		return false
	}
}

// readLoop reads messages until the connection fails or its context is
// cancelled, passing every text message to handleMessage. It closes the
// connection on exit.
func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) {
	defer wsc.close()

	conn, ctx := wsc.snapshot()
	if conn == nil {
		return
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// A read error after cancellation is the expected result of close().
			if ctx.Err() == nil {
				wsc.Logger.Error("Failed to read message", zap.Error(err))
			}
			return
		}

		// Only text messages are processed; other types are ignored.
		if messageType != websocket.TextMessage {
			wsc.Logger.Warn("Received non-text message, ignoring", zap.Int("message_type", messageType))
			continue
		}

		wsc.handleMessage(message, eventBus)
	}
}

// handleMessage decodes data as a protocol.BaseEvent, fills in missing
// fields, and publishes the event to eventBus. Messages that fail to parse or
// have an empty type are logged and dropped.
func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) {
	wsc.Logger.Debug("Received message", zap.ByteString("data", data))

	var event protocol.BaseEvent
	if err := sonic.Unmarshal(data, &event); err != nil {
		wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data))
		return
	}

	// The event type is mandatory.
	if event.Type == "" {
		wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data))
		return
	}

	// Default the timestamp to now when the sender did not provide one.
	if event.Timestamp == 0 {
		event.Timestamp = time.Now().Unix()
	}

	// Default SelfID to the bot this connection belongs to.
	if event.SelfID == "" {
		event.SelfID = wsc.BotID
	}

	// Make sure Data is never nil.
	if event.Data == nil {
		event.Data = make(map[string]interface{})
	}

	wsc.Logger.Info("Event received",
		zap.String("type", string(event.Type)),
		zap.String("detail_type", event.DetailType),
		zap.String("self_id", event.SelfID))

	eventBus.Publish(&event)
}

// SendMessage writes data as a single text message. Concurrent calls are
// serialized because the underlying connection supports only one writer at a
// time. It returns an error when no socket is attached yet or the write fails.
func (wsc *WebSocketConnection) SendMessage(data []byte) error {
	wsc.Logger.Debug("Sending message", zap.ByteString("data", data))

	conn, _ := wsc.snapshot()
	if conn == nil {
		return fmt.Errorf("failed to send message: connection not established")
	}

	wsc.writeMu.Lock()
	defer wsc.writeMu.Unlock()

	if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return fmt.Errorf("failed to send message: %w", err)
	}
	return nil
}

// heartbeatLoop sends a ping every heartbeatTick until a ping fails or the
// connection's context is cancelled.
func (wsc *WebSocketConnection) heartbeatLoop() {
	conn, ctx := wsc.snapshot()
	if conn == nil {
		return
	}

	tick := wsc.heartbeatTick
	if tick <= 0 {
		// time.NewTicker panics on a non-positive interval.
		tick = wsDefaultHeartbeatTick
	}
	ticker := time.NewTicker(tick)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			deadline := time.Now().Add(10 * time.Second)
			if err := conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
				wsc.Logger.Warn("Failed to send ping", zap.Error(err))
				return
			}
			wsc.Logger.Debug("Heartbeat ping sent")
		case <-ctx.Done():
			return
		}
	}
}

// reconnectLoop re-establishes a forward connection every time it drops. It
// does nothing for other connection types or when no reconnect URL is set.
// It stops when the connection is removed from the manager, and removes the
// connection itself once the reconnect attempts are exhausted.
func (wsc *WebSocketConnection) reconnectLoop(wsm *WebSocketManager) {
	if wsc.Type != ConnectionTypeForward || wsc.reconnectURL == "" {
		return
	}

	for {
		// Wait for the current generation of the connection to end.
		_, ctx := wsc.snapshot()
		<-ctx.Done()

		// An explicit removal also cancels the context; do not reconnect then.
		if wsc.isRemoved() {
			return
		}

		wsc.Logger.Info("Connection closed, attempting to reconnect", zap.Int("max_reconnect", wsc.maxReconnect))

		if !wsc.redial() {
			if !wsc.isRemoved() {
				wsc.Logger.Error("Max reconnect attempts reached, giving up", zap.Int("attempts", wsc.reconnectCount))
				wsm.RemoveConnection(wsc.ID)
			}
			return
		}

		// Restart the read loop and heartbeat on the new socket.
		go wsc.readLoop(wsm.eventBus)
		go wsc.heartbeatLoop()
	}
}

// redial tries to dial reconnectURL with a linearly growing backoff until it
// succeeds, the attempts are exhausted, or the connection is removed. It
// reports whether a new socket was installed.
func (wsc *WebSocketConnection) redial() bool {
	for wsc.reconnectCount < wsc.maxReconnect {
		wsc.reconnectCount++
		backoff := time.Duration(wsc.reconnectCount) * 5 * time.Second

		wsc.Logger.Info("Reconnecting",
			zap.Int("attempt", wsc.reconnectCount),
			zap.Int("max", wsc.maxReconnect),
			zap.Duration("backoff", backoff))

		// Wait out the backoff, but give up immediately on removal.
		timer := time.NewTimer(backoff)
		select {
		case <-timer.C:
		case <-wsc.removed:
			timer.Stop()
			return false
		}

		conn, _, err := websocket.DefaultDialer.Dial(wsc.reconnectURL, nil)
		if err != nil {
			wsc.Logger.Error("Reconnect failed", zap.Error(err))
			continue
		}

		if !wsc.install(conn) {
			// Removed while dialing: discard the fresh socket.
			if cerr := conn.Close(); cerr != nil {
				wsc.Logger.Warn("Failed to close connection", zap.Error(cerr))
			}
			return false
		}

		wsc.reconnectCount = 0 // reset the attempt counter after a success
		wsc.Logger.Info("Reconnected successfully", zap.String("remote_addr", conn.RemoteAddr().String()))
		return true
	}
	return false
}

// install swaps in a freshly dialed socket together with a new context. It
// reports false, leaving the connection untouched, when the connection has
// already been removed.
func (wsc *WebSocketConnection) install(conn *websocket.Conn) bool {
	ctx, cancel := context.WithCancel(context.Background())

	wsc.mu.Lock()
	defer wsc.mu.Unlock()

	if wsc.isRemoved() {
		cancel()
		return false
	}

	wsc.Conn = conn
	wsc.RemoteAddr = conn.RemoteAddr().String()
	wsc.ctx, wsc.cancel = ctx, cancel
	wsc.closed = false
	return true
}

// close cancels the connection's context and closes the current socket. It is
// safe to call more than once; the socket is closed only the first time.
func (wsc *WebSocketConnection) close() {
	wsc.mu.Lock()
	conn, cancel, alreadyClosed := wsc.Conn, wsc.cancel, wsc.closed
	wsc.closed = true
	wsc.mu.Unlock()

	if cancel != nil {
		cancel()
	}
	if alreadyClosed || conn == nil {
		return
	}
	if err := conn.Close(); err != nil {
		wsc.Logger.Error("Failed to close connection", zap.Error(err))
	}
}

// RemoveConnection unregisters the connection with the given ID, closes it,
// and prevents it from reconnecting. Unknown IDs are ignored.
func (wsm *WebSocketManager) RemoveConnection(connID string) {
	wsm.mu.Lock()
	conn, ok := wsm.connections[connID]
	if ok {
		delete(wsm.connections, connID)
	}
	wsm.mu.Unlock()

	if !ok {
		return
	}

	// Mark first so the reconnect loop sees the removal when close() wakes it;
	// the socket is closed outside the manager lock.
	conn.markRemoved()
	conn.close()
	wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", connID))
}

// GetConnection returns the connection registered under connID, if any.
func (wsm *WebSocketManager) GetConnection(connID string) (*WebSocketConnection, bool) {
	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	conn, ok := wsm.connections[connID]
	return conn, ok
}

// GetConnectionByBotID returns every registered connection whose BotID equals
// botID. The result is empty, never nil, when there is none.
func (wsm *WebSocketManager) GetConnectionByBotID(botID string) []*WebSocketConnection {
	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	connections := make([]*WebSocketConnection, 0)
	for _, conn := range wsm.connections {
		if conn.BotID == botID {
			connections = append(connections, conn)
		}
	}
	return connections
}

// BroadcastToBot sends data to every connection of the given bot. Failures
// are logged per connection and do not stop the broadcast.
func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) {
	for _, conn := range wsm.GetConnectionByBotID(botID) {
		if err := conn.SendMessage(data); err != nil {
			wsm.logger.Error("Failed to send message to connection",
				zap.String("conn_id", conn.ID),
				zap.Error(err))
		}
	}
}

// DialConfig configures an outgoing (forward) WebSocket connection.
type DialConfig struct {
	URL           string        // ws:// or wss:// address to dial
	BotID         string        // bot the connection belongs to
	MaxReconnect  int           // maximum reconnect attempts after the connection drops
	HeartbeatTick time.Duration // ping interval; non-positive values keep the default
}

// Dial opens a forward connection to addr for botID with the default
// reconnect limit and heartbeat interval.
func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) {
	return wsm.DialWithConfig(DialConfig{
		URL:           addr,
		BotID:         botID,
		MaxReconnect:  wsDefaultMaxReconnect,
		HeartbeatTick: wsDefaultHeartbeatTick,
	})
}

// DialWithConfig opens a forward connection described by config, registers it
// with the manager, and starts its read, heartbeat, and reconnect loops.
//
// It returns an error when the URL cannot be parsed, its scheme is neither
// ws nor wss, or the dial fails.
func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnection, error) {
	u, err := url.Parse(config.URL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}

	// Only WebSocket schemes are accepted.
	if u.Scheme != "ws" && u.Scheme != "wss" {
		return nil, fmt.Errorf("invalid URL scheme: %s, expected ws or wss", u.Scheme)
	}

	conn, _, err := websocket.DefaultDialer.Dial(config.URL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	wsConn := NewWebSocketConnection(conn, config.BotID, ConnectionTypeForward, wsm.logger)
	wsConn.reconnectURL = config.URL
	wsConn.maxReconnect = config.MaxReconnect
	if config.HeartbeatTick > 0 {
		// A zero interval would make the heartbeat ticker panic; keep the default.
		wsConn.heartbeatTick = config.HeartbeatTick
	}

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket forward connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", config.BotID),
		zap.String("addr", config.URL),
		zap.String("remote_addr", wsConn.RemoteAddr))

	go wsConn.readLoop(wsm.eventBus)
	go wsConn.heartbeatLoop()
	// Forward connections are watched so they can be re-dialed when they drop.
	go wsConn.reconnectLoop(wsm)

	return wsConn, nil
}

// generateConnID returns a connection ID built from the current time in
// nanoseconds and a process-wide sequence number, so IDs created within the
// same clock tick still differ.
func generateConnID() string {
	return fmt.Sprintf("conn-%d-%d", time.Now().UnixNano(), atomic.AddUint64(&wsConnSeq, 1))
}