package net import ( "context" "fmt" "net/url" "sync" "time" "cellbot/internal/engine" "cellbot/internal/protocol" "github.com/fasthttp/websocket" "github.com/valyala/fasthttp" "go.uber.org/zap" ) // WebSocketManager WebSocket连接管理器 type WebSocketManager struct { connections map[string]*WebSocketConnection logger *zap.Logger eventBus *engine.EventBus mu sync.RWMutex upgrader *websocket.FastHTTPUpgrader } // NewWebSocketManager 创建WebSocket管理器 func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSocketManager { return &WebSocketManager{ connections: make(map[string]*WebSocketConnection), logger: logger.Named("websocket"), eventBus: eventBus, upgrader: &websocket.FastHTTPUpgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(ctx *fasthttp.RequestCtx) bool { return true // 允许所有来源,生产环境应加强检查 }, }, } } // WebSocketConnection WebSocket连接 type WebSocketConnection struct { ID string Conn *websocket.Conn BotID string Logger *zap.Logger ctx context.Context cancel context.CancelFunc } // NewWebSocketConnection 创建WebSocket连接 func NewWebSocketConnection(conn *websocket.Conn, botID string, logger *zap.Logger) *WebSocketConnection { ctx, cancel := context.WithCancel(context.Background()) connID := generateConnID() return &WebSocketConnection{ ID: connID, Conn: conn, BotID: botID, Logger: logger.With(zap.String("conn_id", connID)), ctx: ctx, cancel: cancel, } } // UpgradeWebSocket 升级HTTP连接为WebSocket func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSocketConnection, error) { // 获取查询参数 botID := string(ctx.QueryArgs().Peek("bot_id")) if botID == "" { return nil, fmt.Errorf("missing bot_id parameter") } // 创建通道用于传递连接 connChan := make(chan *websocket.Conn, 1) // 升级连接 err := wsm.upgrader.Upgrade(ctx, func(conn *websocket.Conn) { connChan <- conn }) if err != nil { return nil, fmt.Errorf("failed to upgrade connection: %w", err) } // 等待连接在回调中建立 conn := <-connChan // 创建连接对象 wsConn := NewWebSocketConnection(conn, botID, wsm.logger) // 存储连接 wsm.mu.Lock() wsm.connections[wsConn.ID] = wsConn wsm.mu.Unlock() wsm.logger.Info("WebSocket connection established", zap.String("conn_id", wsConn.ID), zap.String("bot_id", botID)) // 启动读取循环 go wsConn.readLoop(wsm.eventBus) return wsConn, nil } // readLoop 读取循环 func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) { defer wsc.close() for { select { case <-wsc.ctx.Done(): return default: messageType, message, err := wsc.Conn.ReadMessage() if err != nil { wsc.Logger.Error("Failed to read message", zap.Error(err)) return } // 处理消息 wsc.handleMessage(message, eventBus) // messageType 可用于区分文本或二进制消息 } } } // handleMessage 处理消息 func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) { wsc.Logger.Debug("Received message", zap.ByteString("data", data)) // TODO: 解析消息为Event对象 // 这里简化实现,实际应该根据协议解析 event := &protocol.BaseEvent{ Type: protocol.EventTypeMessage, DetailType: "private", Timestamp: time.Now().Unix(), SelfID: wsc.BotID, Data: make(map[string]interface{}), } // 发布到事件总线 eventBus.Publish(event) } // SendMessage 发送消息 func (wsc *WebSocketConnection) SendMessage(data []byte) error { wsc.Logger.Debug("Sending message", zap.ByteString("data", data)) err := wsc.Conn.WriteMessage(websocket.TextMessage, data) if err != nil { return fmt.Errorf("failed to send message: %w", err) } return nil } // close 关闭连接 func (wsc *WebSocketConnection) close() { wsc.cancel() if err := wsc.Conn.Close(); err != nil { wsc.Logger.Error("Failed to close connection", zap.Error(err)) } } // RemoveConnection 移除连接 func (wsm *WebSocketManager) RemoveConnection(connID string) { wsm.mu.Lock() defer wsm.mu.Unlock() if conn, ok := wsm.connections[connID]; ok { conn.close() delete(wsm.connections, connID) wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", connID)) } } // GetConnection 获取连接 func (wsm *WebSocketManager) GetConnection(connID string) (*WebSocketConnection, bool) { wsm.mu.RLock() defer wsm.mu.RUnlock() conn, ok := wsm.connections[connID] return conn, ok } // GetConnectionByBotID 根据BotID获取连接 func (wsm *WebSocketManager) GetConnectionByBotID(botID string) []*WebSocketConnection { wsm.mu.RLock() defer wsm.mu.RUnlock() connections := make([]*WebSocketConnection, 0) for _, conn := range wsm.connections { if conn.BotID == botID { connections = append(connections, conn) } } return connections } // BroadcastToBot 向指定Bot的所有连接广播消息 func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) { connections := wsm.GetConnectionByBotID(botID) for _, conn := range connections { if err := conn.SendMessage(data); err != nil { wsm.logger.Error("Failed to send message to connection", zap.String("conn_id", conn.ID), zap.Error(err)) } } } // Dial 建立WebSocket客户端连接 func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) { u, err := url.Parse(addr) if err != nil { return nil, fmt.Errorf("invalid URL: %w", err) } conn, _, err := websocket.DefaultDialer.Dial(addr, nil) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } wsConn := NewWebSocketConnection(conn, botID, wsm.logger) wsm.mu.Lock() wsm.connections[wsConn.ID] = wsConn wsm.mu.Unlock() wsm.logger.Info("WebSocket client connected", zap.String("conn_id", wsConn.ID), zap.String("bot_id", botID), zap.String("addr", addr)) // 启动读取循环 go wsConn.readLoop(wsm.eventBus) return wsConn, nil } // generateConnID 生成连接ID func generateConnID() string { return fmt.Sprintf("conn-%d", time.Now().UnixNano()) }