Files
cellbot/pkg/net/websocket.go

402 lines
10 KiB
Go
Raw Normal View History

package net
import (
	"context"
	"fmt"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"cellbot/internal/engine"
	"cellbot/internal/protocol"

	"github.com/bytedance/sonic"
	"github.com/fasthttp/websocket"
	"github.com/valyala/fasthttp"
	"go.uber.org/zap"
)
// WebSocketManager keeps track of every live WebSocket connection, both
// accepted (reverse) and dialed (forward), and owns the upgrader used to
// accept incoming ones.
type WebSocketManager struct {
	// mu guards connections.
	mu          sync.RWMutex
	connections map[string]*WebSocketConnection // keyed by WebSocketConnection.ID

	upgrader *websocket.FastHTTPUpgrader // used by UpgradeWebSocket
	eventBus *engine.EventBus            // receives every decoded inbound event
	logger   *zap.Logger
}
// NewWebSocketManager returns a manager with no connections. Inbound events
// are published to eventBus, and logging goes through a child of logger named
// "websocket".
//
// The upgrader uses 1024-byte read and write buffers and accepts any Origin.
// NOTE(review): accepting every origin is convenient during development but
// should be tightened before the endpoint is exposed publicly.
func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSocketManager {
	acceptAnyOrigin := func(*fasthttp.RequestCtx) bool { return true }

	upgrader := &websocket.FastHTTPUpgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     acceptAnyOrigin,
	}

	return &WebSocketManager{
		connections: make(map[string]*WebSocketConnection),
		logger:      logger.Named("websocket"),
		eventBus:    eventBus,
		upgrader:    upgrader,
	}
}
// ConnectionType records which side opened a WebSocketConnection.
type ConnectionType string

// Known connection types.
const (
	// ConnectionTypeReverse is a connection the remote peer opened and we
	// accepted passively (see UpgradeWebSocket).
	ConnectionTypeReverse ConnectionType = "reverse"

	// ConnectionTypeForward is a connection we dialed out actively
	// (see Dial / DialWithConfig).
	ConnectionTypeForward ConnectionType = "forward"
)
// WebSocketConnection is a single WebSocket session. It also holds the state
// needed for heartbeats and, for forward connections, reconnects.
//
// NOTE(review): reconnectLoop reassigns Conn, RemoteAddr, ctx and cancel
// without any locking while other goroutines may be reading them; confirm
// whether these fields need a mutex.
type WebSocketConnection struct {
	// Identity.
	ID     string
	BotID  string
	Type   ConnectionType
	Logger *zap.Logger

	// Transport.
	Conn       *websocket.Conn
	RemoteAddr string

	// Lifecycle: cancel is invoked when the connection is closed.
	ctx    context.Context
	cancel context.CancelFunc

	// Reconnect and heartbeat settings.
	reconnectURL   string        // dial target used when a forward connection reconnects
	maxReconnect   int           // upper bound on consecutive reconnect attempts
	reconnectCount int           // attempts made since the last successful (re)connect
	heartbeatTick  time.Duration // interval between ping frames
}
// NewWebSocketConnection wraps an already-established conn for botID.
//
// The result gets:
//   - a freshly generated ID, plus a logger derived from logger and tagged
//     with that ID;
//   - its own cancellable context;
//   - the peer address read from conn;
//   - default settings of 5 reconnect attempts and a 30-second heartbeat.
func NewWebSocketConnection(conn *websocket.Conn, botID string, connType ConnectionType, logger *zap.Logger) *WebSocketConnection {
	lifetime, stop := context.WithCancel(context.Background())
	id := generateConnID()

	wc := &WebSocketConnection{
		ID:            id,
		Conn:          conn,
		BotID:         botID,
		Logger:        logger.With(zap.String("conn_id", id)),
		ctx:           lifetime,
		cancel:        stop,
		Type:          connType,
		RemoteAddr:    conn.RemoteAddr().String(),
		maxReconnect:  5,
		heartbeatTick: 30 * time.Second,
	}
	return wc
}
// UpgradeWebSocket upgrades the HTTP request in ctx to a reverse WebSocket
// connection. The bot is identified by the "bot_id" query parameter.
//
// fasthttp runs the upgrade callback only after the request handler has
// returned, and it closes the hijacked socket as soon as that callback
// returns. So the connection cannot be waited for from the request handler.
// Instead it is registered and served entirely inside the callback:
//   - the callback blocks in readLoop for the lifetime of the connection;
//   - it removes the entry from the manager when readLoop returns.
//
// The returned *WebSocketConnection is allocated up front, so callers see
// its ID, BotID, Type and RemoteAddr immediately. Its Conn field is nil
// until the upgrade completes, which happens after the request handler
// returns. The connection appears in the manager (GetConnection,
// BroadcastToBot, ...) only while the socket is live.
//
// Returns an error if bot_id is missing or the upgrade handshake fails.
func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSocketConnection, error) {
	botID := string(ctx.QueryArgs().Peek("bot_id"))
	if botID == "" {
		return nil, fmt.Errorf("missing bot_id parameter")
	}

	// Same defaults as NewWebSocketConnection. That constructor needs a live
	// *websocket.Conn, which does not exist until the callback below runs.
	connID := generateConnID()
	connCtx, cancel := context.WithCancel(context.Background())
	wsConn := &WebSocketConnection{
		ID:            connID,
		BotID:         botID,
		Logger:        wsm.logger.With(zap.String("conn_id", connID)),
		ctx:           connCtx,
		cancel:        cancel,
		Type:          ConnectionTypeReverse,
		RemoteAddr:    ctx.RemoteAddr().String(),
		maxReconnect:  5,
		heartbeatTick: 30 * time.Second,
	}

	err := wsm.upgrader.Upgrade(ctx, func(conn *websocket.Conn) {
		wsConn.Conn = conn

		wsm.mu.Lock()
		wsm.connections[wsConn.ID] = wsConn
		wsm.mu.Unlock()

		wsm.logger.Info("WebSocket reverse connection established",
			zap.String("conn_id", wsConn.ID),
			zap.String("bot_id", botID),
			zap.String("remote_addr", wsConn.RemoteAddr))

		go wsConn.heartbeatLoop()

		// Block here until the peer disconnects or the connection is closed
		// locally; fasthttp closes the socket once this callback returns.
		wsConn.readLoop(wsm.eventBus)

		wsm.mu.Lock()
		delete(wsm.connections, wsConn.ID)
		wsm.mu.Unlock()
	})
	if err != nil {
		cancel() // release the context; the connection never came up
		return nil, fmt.Errorf("failed to upgrade connection: %w", err)
	}

	return wsConn, nil
}
// readLoop reads frames until a read fails or the connection's context is
// cancelled. Each text frame goes to handleMessage; non-text frames are
// logged and dropped. The connection is closed when the loop exits.
//
// Read errors are logged by cause:
//   - the context was already cancelled (e.g. RemoveConnection or close()
//     shut the socket): Debug;
//   - the peer sent a normal or going-away close frame: Info;
//   - anything else: Error.
func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) {
	defer wsc.close()

	// Capture the context and socket this loop was started for, so the loop
	// never observes values that reconnectLoop swaps in later.
	ctx, conn := wsc.ctx, wsc.Conn

	for ctx.Err() == nil {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			switch {
			case ctx.Err() != nil:
				wsc.Logger.Debug("Read loop stopped after local close", zap.Error(err))
			case websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway):
				wsc.Logger.Info("Connection closed by peer", zap.Error(err))
			default:
				wsc.Logger.Error("Failed to read message", zap.Error(err))
			}
			return
		}

		// Only text frames carry events; everything else is ignored.
		if messageType != websocket.TextMessage {
			wsc.Logger.Warn("Received non-text message, ignoring",
				zap.Int("message_type", messageType))
			continue
		}

		wsc.handleMessage(message, eventBus)
	}
}
// handleMessage 处理消息
func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) {
wsc.Logger.Debug("Received message", zap.ByteString("data", data))
// 解析JSON消息为BaseEvent
var event protocol.BaseEvent
if err := sonic.Unmarshal(data, &event); err != nil {
wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data))
return
}
// 验证必需字段
if event.Type == "" {
wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data))
return
}
// 如果没有时间戳,使用当前时间
if event.Timestamp == 0 {
event.Timestamp = time.Now().Unix()
}
// 如果没有SelfID使用连接的BotID
if event.SelfID == "" {
event.SelfID = wsc.BotID
}
// 确保Data字段不为nil
if event.Data == nil {
event.Data = make(map[string]interface{})
}
wsc.Logger.Info("Event received",
zap.String("type", string(event.Type)),
zap.String("detail_type", event.DetailType),
zap.String("self_id", event.SelfID))
// 发布到事件总线
eventBus.Publish(&event)
}
// SendMessage writes data to the peer as one text frame and wraps any write
// error.
//
// NOTE(review): websocket.Conn supports only one concurrent writer, and this
// method does not serialize calls. Concurrent senders (e.g. overlapping
// BroadcastToBot calls) must coordinate externally. Consider adding a
// per-connection write mutex.
func (wsc *WebSocketConnection) SendMessage(data []byte) error {
	wsc.Logger.Debug("Sending message", zap.ByteString("data", data))

	if err := wsc.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
		return fmt.Errorf("failed to send message: %w", err)
	}
	return nil
}
// heartbeatLoop sends a ping control frame every heartbeatTick. It stops when
// the connection's context is cancelled or a ping cannot be written. Each
// ping uses a 10-second write deadline.
//
// Two safeguards:
//   - time.NewTicker panics on a non-positive interval. heartbeatTick is zero
//     when a DialConfig leaves HeartbeatTick unset, so such values fall back
//     to the 30-second default used by NewWebSocketConnection.
//   - The context and socket are captured once at start. A loop started for
//     an earlier connection therefore exits when that connection ends,
//     instead of adopting the context and socket that reconnectLoop installs
//     later (which would leave a duplicate pinger running per reconnect).
func (wsc *WebSocketConnection) heartbeatLoop() {
	interval := wsc.heartbeatTick
	if interval <= 0 {
		interval = 30 * time.Second
		wsc.Logger.Warn("Non-positive heartbeat interval, using default",
			zap.Duration("configured", wsc.heartbeatTick),
			zap.Duration("interval", interval))
	}

	done, conn := wsc.ctx.Done(), wsc.Conn

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
				wsc.Logger.Warn("Failed to send ping", zap.Error(err))
				return
			}
			wsc.Logger.Debug("Heartbeat ping sent")
		case <-done:
			return
		}
	}
}
// reconnectLoop waits until the connection's context is cancelled. For
// forward connections with a reconnect URL, it then re-dials that URL up to
// maxReconnect times, with a linear backoff of 5s, 10s, 15s, ...
//
// On success it:
//   - swaps in the new socket and a fresh context;
//   - resets the attempt counter;
//   - restarts the read, heartbeat and reconnect goroutines.
//
// If every attempt fails, the connection is removed from wsm.
//
// A connection that is no longer registered in wsm (for example, it was
// dropped with RemoveConnection) is treated as intentionally closed and is
// not reconnected. This is checked before the first attempt and again after
// every backoff sleep.
//
// NOTE(review): Conn, RemoteAddr, ctx and cancel are reassigned here without
// synchronization while other goroutines may read them.
func (wsc *WebSocketConnection) reconnectLoop(wsm *WebSocketManager) {
	<-wsc.ctx.Done() // wait for the current connection to end

	if wsc.Type != ConnectionTypeForward || wsc.reconnectURL == "" {
		return
	}

	// RemoveConnection cancels the context and deletes the map entry while
	// holding wsm.mu. GetConnection needs the read lock, so an intentional
	// removal is already visible by the time this check runs.
	stillRegistered := func() bool {
		_, ok := wsm.GetConnection(wsc.ID)
		return ok
	}
	if !stillRegistered() {
		wsc.Logger.Info("Connection removed from manager, skipping reconnect")
		return
	}

	wsc.Logger.Info("Connection closed, attempting to reconnect",
		zap.Int("max_reconnect", wsc.maxReconnect))

	for wsc.reconnectCount < wsc.maxReconnect {
		wsc.reconnectCount++
		backoff := time.Duration(wsc.reconnectCount) * 5 * time.Second
		wsc.Logger.Info("Reconnecting",
			zap.Int("attempt", wsc.reconnectCount),
			zap.Int("max", wsc.maxReconnect),
			zap.Duration("backoff", backoff))
		time.Sleep(backoff)

		// The connection may have been removed while we were sleeping.
		if !stillRegistered() {
			wsc.Logger.Info("Connection removed from manager, skipping reconnect")
			return
		}

		conn, _, err := websocket.DefaultDialer.Dial(wsc.reconnectURL, nil)
		if err != nil {
			wsc.Logger.Error("Reconnect failed", zap.Error(err))
			continue
		}

		// Install the new socket and a fresh lifecycle context.
		wsc.Conn = conn
		wsc.RemoteAddr = conn.RemoteAddr().String()
		wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
		wsc.reconnectCount = 0

		wsc.Logger.Info("Reconnected successfully",
			zap.String("remote_addr", wsc.RemoteAddr))

		go wsc.readLoop(wsm.eventBus)
		go wsc.heartbeatLoop()
		go wsc.reconnectLoop(wsm)
		return
	}

	wsc.Logger.Error("Max reconnect attempts reached, giving up",
		zap.Int("attempts", wsc.reconnectCount))
	wsm.RemoveConnection(wsc.ID)
}
// close cancels the connection's context, which stops heartbeatLoop and wakes
// reconnectLoop. It then closes the underlying socket and logs any error
// from that close.
func (wsc *WebSocketConnection) close() {
	wsc.cancel()

	err := wsc.Conn.Close()
	if err == nil {
		return
	}
	wsc.Logger.Error("Failed to close connection", zap.Error(err))
}
// RemoveConnection closes the connection with the given ID and drops it from
// the manager. Unknown IDs are ignored. The manager's write lock is held for
// the whole operation.
func (wsm *WebSocketManager) RemoveConnection(connID string) {
	wsm.mu.Lock()
	defer wsm.mu.Unlock()

	target, found := wsm.connections[connID]
	if !found {
		return
	}

	target.close()
	delete(wsm.connections, connID)
	wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", connID))
}
// GetConnection looks up a registered connection by its ID. The boolean
// reports whether one was found.
func (wsm *WebSocketManager) GetConnection(connID string) (*WebSocketConnection, bool) {
	wsm.mu.RLock()
	found, exists := wsm.connections[connID]
	wsm.mu.RUnlock()

	return found, exists
}
// GetConnectionByBotID returns a snapshot of every registered connection
// whose BotID equals botID. Order follows map iteration and is unspecified.
// The slice is never nil; it is empty when nothing matches.
func (wsm *WebSocketManager) GetConnectionByBotID(botID string) []*WebSocketConnection {
	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	matches := []*WebSocketConnection{}
	for _, candidate := range wsm.connections {
		if candidate.BotID != botID {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches
}
// BroadcastToBot sends data via SendMessage to every connection currently
// registered for botID. A failure on one connection is logged and does not
// stop delivery to the rest; nothing is reported back to the caller.
func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) {
	for _, target := range wsm.GetConnectionByBotID(botID) {
		err := target.SendMessage(data)
		if err == nil {
			continue
		}
		wsm.logger.Error("Failed to send message to connection",
			zap.String("conn_id", target.ID),
			zap.Error(err))
	}
}
// DialConfig holds the settings for an outgoing (forward) connection created
// by DialWithConfig.
type DialConfig struct {
	// URL is the ws:// or wss:// endpoint to dial. It is also the target of
	// reconnect attempts. BotID is recorded on the resulting connection.
	URL, BotID string

	// MaxReconnect caps the number of consecutive reconnect attempts made
	// after the connection drops.
	MaxReconnect int

	// HeartbeatTick is the interval between ping frames sent to the peer.
	HeartbeatTick time.Duration
}
// Dial opens a forward WebSocket connection to addr for botID with default
// settings: up to 5 reconnect attempts and a 30-second heartbeat. It is
// shorthand for DialWithConfig.
func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) {
	cfg := DialConfig{URL: addr, BotID: botID}
	cfg.MaxReconnect = 5
	cfg.HeartbeatTick = 30 * time.Second

	return wsm.DialWithConfig(cfg)
}
// DialWithConfig dials config.URL, which must use the ws or wss scheme. It
// then registers the resulting forward connection with the manager and
// starts its read, heartbeat and reconnect goroutines.
//
// A non-positive HeartbeatTick is replaced with the 30-second default. It
// would otherwise reach time.NewTicker in the heartbeat goroutine, which
// panics on non-positive intervals and would crash the process.
//
// Returns an error if the URL cannot be parsed, uses another scheme, or the
// dial/handshake fails.
func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnection, error) {
	u, err := url.Parse(config.URL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}
	if u.Scheme != "ws" && u.Scheme != "wss" {
		return nil, fmt.Errorf("invalid URL scheme: %s, expected ws or wss", u.Scheme)
	}

	if config.HeartbeatTick <= 0 {
		config.HeartbeatTick = 30 * time.Second
	}

	conn, _, err := websocket.DefaultDialer.Dial(config.URL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	wsConn := NewWebSocketConnection(conn, config.BotID, ConnectionTypeForward, wsm.logger)
	wsConn.reconnectURL = config.URL
	wsConn.maxReconnect = config.MaxReconnect
	wsConn.heartbeatTick = config.HeartbeatTick

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket forward connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", config.BotID),
		zap.String("addr", config.URL),
		zap.String("remote_addr", wsConn.RemoteAddr))

	go wsConn.readLoop(wsm.eventBus)
	go wsConn.heartbeatLoop()
	// Dialed connections are always forward connections, so every one of
	// them gets a reconnect watcher.
	go wsConn.reconnectLoop(wsm)

	return wsConn, nil
}
// connIDSeq is a process-wide counter appended to every generated connection
// ID.
var connIDSeq uint64

// generateConnID returns a process-unique connection ID of the form
// "conn-<unix-nanos>-<seq>". The atomic sequence suffix keeps IDs distinct
// even when several are generated at the same clock reading; clock
// resolution can be much coarser than one nanosecond on some platforms.
// It is safe to call from multiple goroutines.
func generateConnID() string {
	seq := atomic.AddUint64(&connIDSeq, 1)
	return fmt.Sprintf("conn-%d-%d", time.Now().UnixNano(), seq)
}