Files
cellbot/pkg/net/websocket.go
lafay 44fe05ff62 chore: update dependencies and improve bot configuration
- Upgrade Go version to 1.24.0 and update toolchain.
- Update various dependencies in go.mod and go.sum, including:
  - Upgrade `fasthttp/websocket` to v1.5.12
  - Upgrade `fsnotify/fsnotify` to v1.9.0
  - Upgrade `valyala/fasthttp` to v1.58.0
  - Add new dependencies for `bytedance/sonic` and `google/uuid`.
- Refactor bot configuration in config.toml to support multiple bot protocols, including "milky" and "onebot11".
- Modify internal configuration structures to accommodate new bot settings.
- Enhance event dispatcher with metrics tracking and asynchronous processing capabilities.
- Implement WebSocket connection management with heartbeat and reconnection logic.
- Update server handling for bot management and event publishing.
2026-01-05 00:40:09 +08:00

402 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package net
import (
"context"
"fmt"
"net/url"
"sync"
"time"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"github.com/bytedance/sonic"
"github.com/fasthttp/websocket"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
)
// WebSocketManager keeps track of WebSocket connections and holds the
// upgrader used to accept incoming ones.
type WebSocketManager struct {
	connections map[string]*WebSocketConnection // keyed by WebSocketConnection.ID
	logger      *zap.Logger
	eventBus    *engine.EventBus // handed to each connection's read loop
	mu          sync.RWMutex     // guards connections
	upgrader    *websocket.FastHTTPUpgrader
}
// NewWebSocketManager returns a manager with an empty connection table.
//
// The supplied logger is scoped under the name "websocket" and eventBus is
// stored for use by the connections' read loops. The upgrader uses 1024-byte
// read and write buffers.
func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSocketManager {
	// Every origin is accepted; production deployments should tighten this check.
	allowAnyOrigin := func(_ *fasthttp.RequestCtx) bool {
		return true
	}

	manager := &WebSocketManager{
		connections: make(map[string]*WebSocketConnection),
		logger:      logger.Named("websocket"),
		eventBus:    eventBus,
	}
	manager.upgrader = &websocket.FastHTTPUpgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     allowAnyOrigin,
	}
	return manager
}
// ConnectionType identifies which side initiated a WebSocket connection.
type ConnectionType string

// Supported connection directions.
const (
	// ConnectionTypeReverse marks a connection that was passively accepted.
	ConnectionTypeReverse = ConnectionType("reverse")
	// ConnectionTypeForward marks a connection that was actively dialed.
	ConnectionTypeForward = ConnectionType("forward")
)
// WebSocketConnection is a single WebSocket link together with the state
// needed to read from it, keep it alive, and (for forward connections)
// re-establish it.
type WebSocketConnection struct {
	ID             string
	Conn           *websocket.Conn
	BotID          string
	Logger         *zap.Logger
	ctx            context.Context    // cancelled when the connection is closed
	cancel         context.CancelFunc // cancels ctx
	Type           ConnectionType
	RemoteAddr     string
	reconnectURL   string        // URL redialed when a forward connection reconnects
	maxReconnect   int           // maximum number of reconnect attempts
	reconnectCount int           // number of reconnect attempts made so far
	heartbeatTick  time.Duration // interval between heartbeat pings
}
// NewWebSocketConnection wraps conn in a WebSocketConnection.
//
// The connection receives a freshly generated ID, a cancellable context, a
// logger tagged with that ID, and the defaults of 5 reconnect attempts and a
// 30-second heartbeat interval. RemoteAddr is taken from conn.
func NewWebSocketConnection(conn *websocket.Conn, botID string, connType ConnectionType, logger *zap.Logger) *WebSocketConnection {
	connCtx, stop := context.WithCancel(context.Background())
	id := generateConnID()

	wsc := new(WebSocketConnection)
	wsc.ID = id
	wsc.Conn = conn
	wsc.BotID = botID
	wsc.Logger = logger.With(zap.String("conn_id", id))
	wsc.ctx = connCtx
	wsc.cancel = stop
	wsc.Type = connType
	wsc.RemoteAddr = conn.RemoteAddr().String()
	wsc.maxReconnect = 5
	wsc.heartbeatTick = 30 * time.Second
	return wsc
}
// UpgradeWebSocket upgrades an incoming fasthttp request to a WebSocket and
// manages it as a reverse (passively accepted) connection.
//
// The request must carry a non-empty "bot_id" query parameter; otherwise an
// error is returned and no upgrade is attempted.
//
// fasthttp invokes the upgrade callback only after the request handler has
// returned, and closes the hijacked connection as soon as that callback
// returns. This method therefore does NOT wait for the callback: it returns
// right after the upgrade has been scheduled, and the callback itself
// registers the connection, starts the heartbeat, and runs the read loop for
// the lifetime of the connection. The connection is unregistered when the
// read loop ends.
//
// The returned WebSocketConnection carries its ID, BotID and RemoteAddr
// immediately; its Conn field is populated once the callback runs, so callers
// must not send on it before the request handler has returned.
func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSocketConnection, error) {
	// Read everything needed from ctx up front; it must not be touched from
	// the callback, which runs after the handler has returned.
	botID := string(ctx.QueryArgs().Peek("bot_id"))
	if botID == "" {
		return nil, fmt.Errorf("missing bot_id parameter")
	}
	remoteAddr := ctx.RemoteAddr().String()

	connCtx, cancel := context.WithCancel(context.Background())
	connID := generateConnID()
	wsConn := &WebSocketConnection{
		ID:            connID,
		BotID:         botID,
		Logger:        wsm.logger.With(zap.String("conn_id", connID)),
		ctx:           connCtx,
		cancel:        cancel,
		Type:          ConnectionTypeReverse,
		RemoteAddr:    remoteAddr,
		maxReconnect:  5,
		heartbeatTick: 30 * time.Second,
	}

	err := wsm.upgrader.Upgrade(ctx, func(conn *websocket.Conn) {
		wsConn.Conn = conn

		// Register only once Conn is usable, so lookups and broadcasts never
		// see a connection without a socket.
		wsm.mu.Lock()
		wsm.connections[wsConn.ID] = wsConn
		wsm.mu.Unlock()

		wsm.logger.Info("WebSocket reverse connection established",
			zap.String("conn_id", wsConn.ID),
			zap.String("bot_id", botID),
			zap.String("remote_addr", wsConn.RemoteAddr))

		go wsConn.heartbeatLoop()

		// Block here: returning from the callback makes fasthttp close the
		// underlying connection.
		wsConn.readLoop(wsm.eventBus)

		// readLoop has already closed the connection; just drop the entry.
		wsm.mu.Lock()
		delete(wsm.connections, wsConn.ID)
		wsm.mu.Unlock()
		wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", wsConn.ID))
	})
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to upgrade connection: %w", err)
	}

	return wsConn, nil
}
// readLoop reads frames from the connection until a read fails or the
// connection context is cancelled, forwarding every text frame to
// handleMessage. Non-text frames are logged and skipped. The connection is
// closed when the loop exits.
func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) {
	defer wsc.close()

	// The context is checked before each read; a blocked read is not
	// interrupted by cancellation itself.
	for wsc.ctx.Err() == nil {
		kind, payload, err := wsc.Conn.ReadMessage()
		if err != nil {
			wsc.Logger.Error("Failed to read message", zap.Error(err))
			return
		}

		if kind == websocket.TextMessage {
			wsc.handleMessage(payload, eventBus)
			continue
		}

		wsc.Logger.Warn("Received non-text message, ignoring",
			zap.Int("message_type", kind))
	}
}
// handleMessage decodes one raw frame into a protocol.BaseEvent, fills in
// defaults for missing fields, and publishes the event on eventBus.
//
// Frames that fail to decode, or whose event type is empty, are logged and
// dropped without being published.
func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) {
	wsc.Logger.Debug("Received message", zap.ByteString("data", data))

	evt := &protocol.BaseEvent{}
	if decodeErr := sonic.Unmarshal(data, evt); decodeErr != nil {
		wsc.Logger.Error("Failed to parse message", zap.Error(decodeErr), zap.ByteString("data", data))
		return
	}

	// The event type is mandatory.
	if evt.Type == "" {
		wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data))
		return
	}

	// Default a missing timestamp to the current time.
	if evt.Timestamp == 0 {
		evt.Timestamp = time.Now().Unix()
	}

	// Default a missing SelfID to the bot this connection belongs to.
	if evt.SelfID == "" {
		evt.SelfID = wsc.BotID
	}

	// Ensure Data is never nil when the event is published.
	if evt.Data == nil {
		evt.Data = make(map[string]interface{})
	}

	wsc.Logger.Info("Event received",
		zap.String("type", string(evt.Type)),
		zap.String("detail_type", evt.DetailType),
		zap.String("self_id", evt.SelfID))

	eventBus.Publish(evt)
}
// SendMessage writes data to the connection as a single text frame.
// A write failure is returned wrapped with context.
func (wsc *WebSocketConnection) SendMessage(data []byte) error {
	wsc.Logger.Debug("Sending message", zap.ByteString("data", data))

	if writeErr := wsc.Conn.WriteMessage(websocket.TextMessage, data); writeErr != nil {
		return fmt.Errorf("failed to send message: %w", writeErr)
	}
	return nil
}
// heartbeatLoop sends a WebSocket ping every heartbeatTick until the
// connection context is cancelled or a ping cannot be written.
//
// A non-positive heartbeatTick falls back to 30 seconds, because
// time.NewTicker panics when given a non-positive interval.
func (wsc *WebSocketConnection) heartbeatLoop() {
	interval := wsc.heartbeatTick
	if interval <= 0 {
		interval = 30 * time.Second
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// Each ping gets a 10-second write deadline.
			deadline := time.Now().Add(10 * time.Second)
			if err := wsc.Conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
				wsc.Logger.Warn("Failed to send ping", zap.Error(err))
				return
			}
			wsc.Logger.Debug("Heartbeat ping sent")
		case <-wsc.ctx.Done():
			return
		}
	}
}
// reconnectLoop waits for the connection context to be cancelled and then
// tries to re-establish a forward connection by redialing reconnectURL.
//
// It does nothing for non-forward connections or when no reconnect URL is
// set. It also stops without redialing when the connection is no longer
// registered with wsm (for example after RemoveConnection), so that a
// deliberately removed connection is not brought back untracked.
//
// Attempt k sleeps k*5 seconds before dialing. On success the connection's
// socket, remote address and context are replaced, the attempt counter is
// reset, and the read, heartbeat and reconnect loops are started again. After
// maxReconnect failed attempts the connection is removed from wsm.
func (wsc *WebSocketConnection) reconnectLoop(wsm *WebSocketManager) {
	<-wsc.ctx.Done() // wait until the connection goes down

	if wsc.Type != ConnectionTypeForward || wsc.reconnectURL == "" {
		return
	}

	// registered reports whether the manager still tracks this connection.
	registered := func() bool {
		current, ok := wsm.GetConnection(wsc.ID)
		return ok && current == wsc
	}

	if !registered() {
		wsc.Logger.Info("Connection removed from manager, not reconnecting")
		return
	}

	wsc.Logger.Info("Connection closed, attempting to reconnect",
		zap.Int("max_reconnect", wsc.maxReconnect))

	for wsc.reconnectCount < wsc.maxReconnect {
		wsc.reconnectCount++
		backoff := time.Duration(wsc.reconnectCount) * 5 * time.Second
		wsc.Logger.Info("Reconnecting",
			zap.Int("attempt", wsc.reconnectCount),
			zap.Int("max", wsc.maxReconnect),
			zap.Duration("backoff", backoff))
		time.Sleep(backoff)

		// The connection may have been removed while we were sleeping.
		if !registered() {
			wsc.Logger.Info("Connection removed from manager, not reconnecting")
			return
		}

		conn, _, err := websocket.DefaultDialer.Dial(wsc.reconnectURL, nil)
		if err != nil {
			wsc.Logger.Error("Reconnect failed", zap.Error(err))
			continue
		}

		// The connection may also have been removed while dialing; do not
		// keep a socket nobody tracks.
		if !registered() {
			if closeErr := conn.Close(); closeErr != nil {
				wsc.Logger.Warn("Failed to close unused reconnect socket", zap.Error(closeErr))
			}
			return
		}

		// Swap in the new socket and a fresh context.
		wsc.Conn = conn
		wsc.RemoteAddr = conn.RemoteAddr().String()
		wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
		wsc.reconnectCount = 0 // reset the attempt counter

		wsc.Logger.Info("Reconnected successfully",
			zap.String("remote_addr", wsc.RemoteAddr))

		// Restart the per-connection loops on the new socket.
		go wsc.readLoop(wsm.eventBus)
		go wsc.heartbeatLoop()
		go wsc.reconnectLoop(wsm)
		return
	}

	wsc.Logger.Error("Max reconnect attempts reached, giving up",
		zap.Int("attempts", wsc.reconnectCount))

	// Give up: drop the connection from the manager.
	wsm.RemoveConnection(wsc.ID)
}
// close cancels the connection context and closes the underlying socket,
// logging any error returned by the socket close.
func (wsc *WebSocketConnection) close() {
	wsc.cancel()

	closeErr := wsc.Conn.Close()
	if closeErr == nil {
		return
	}
	wsc.Logger.Error("Failed to close connection", zap.Error(closeErr))
}
// RemoveConnection closes the connection registered under connID and deletes
// it from the manager. Unknown IDs are ignored.
func (wsm *WebSocketManager) RemoveConnection(connID string) {
	wsm.mu.Lock()
	defer wsm.mu.Unlock()

	target, found := wsm.connections[connID]
	if !found {
		return
	}

	target.close()
	delete(wsm.connections, connID)
	wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", connID))
}
// GetConnection looks up a connection by ID. The boolean result reports
// whether a connection is registered under connID.
func (wsm *WebSocketManager) GetConnection(connID string) (*WebSocketConnection, bool) {
	wsm.mu.RLock()
	found, exists := wsm.connections[connID]
	wsm.mu.RUnlock()

	return found, exists
}
// GetConnectionByBotID returns every registered connection whose BotID equals
// botID. The result is an empty, non-nil slice when nothing matches.
func (wsm *WebSocketManager) GetConnectionByBotID(botID string) []*WebSocketConnection {
	matched := []*WebSocketConnection{}

	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	for _, candidate := range wsm.connections {
		if candidate.BotID != botID {
			continue
		}
		matched = append(matched, candidate)
	}
	return matched
}
// BroadcastToBot sends data to every connection registered for botID.
// A failed send is logged and does not stop delivery to the remaining
// connections.
func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) {
	for _, target := range wsm.GetConnectionByBotID(botID) {
		sendErr := target.SendMessage(data)
		if sendErr == nil {
			continue
		}
		wsm.logger.Error("Failed to send message to connection",
			zap.String("conn_id", target.ID),
			zap.Error(sendErr))
	}
}
// DialConfig holds the settings for an outgoing (client-side) WebSocket
// connection.
type DialConfig struct {
	URL           string        // address to dial; DialWithConfig requires a ws or wss scheme
	BotID         string        // bot the connection is associated with
	MaxReconnect  int           // copied to the connection's reconnect-attempt limit
	HeartbeatTick time.Duration // copied to the connection's heartbeat interval
}
// Dial opens a forward (client-side) WebSocket connection to addr for botID,
// using the defaults of 5 reconnect attempts and a 30-second heartbeat.
func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) {
	cfg := DialConfig{
		MaxReconnect:  5,
		HeartbeatTick: 30 * time.Second,
	}
	cfg.URL = addr
	cfg.BotID = botID

	return wsm.DialWithConfig(cfg)
}
// DialWithConfig opens a forward (client-side) WebSocket connection described
// by config, registers it with the manager, and starts its read, heartbeat
// and reconnect loops.
//
// config.URL must parse and use the ws or wss scheme. config.MaxReconnect is
// applied as given. config.HeartbeatTick is applied only when positive;
// otherwise the 30-second default set by NewWebSocketConnection is kept,
// since a non-positive interval would make the heartbeat ticker panic.
//
// An error is returned when the URL is invalid or the dial fails.
func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnection, error) {
	u, err := url.Parse(config.URL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}

	// Only WebSocket schemes are accepted.
	if u.Scheme != "ws" && u.Scheme != "wss" {
		return nil, fmt.Errorf("invalid URL scheme: %s, expected ws or wss", u.Scheme)
	}

	conn, _, err := websocket.DefaultDialer.Dial(config.URL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	wsConn := NewWebSocketConnection(conn, config.BotID, ConnectionTypeForward, wsm.logger)
	wsConn.reconnectURL = config.URL
	wsConn.maxReconnect = config.MaxReconnect
	if config.HeartbeatTick > 0 {
		wsConn.heartbeatTick = config.HeartbeatTick
	}

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket forward connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", config.BotID),
		zap.String("addr", config.URL),
		zap.String("remote_addr", wsConn.RemoteAddr))

	// Start the read loop, the heartbeat, and the reconnect monitor.
	go wsConn.readLoop(wsm.eventBus)
	go wsConn.heartbeatLoop()
	go wsConn.reconnectLoop(wsm)

	return wsConn, nil
}
// connIDMu guards connIDSeq.
var connIDMu sync.Mutex

// connIDSeq counts the connection IDs handed out by generateConnID.
var connIDSeq uint64

// generateConnID returns a process-unique connection ID of the form
// "conn-<unix-nanoseconds>-<sequence>".
//
// The sequence number keeps IDs distinct even when two connections are
// created within the same clock reading, which a timestamp alone cannot
// guarantee.
func generateConnID() string {
	connIDMu.Lock()
	connIDSeq++
	seq := connIDSeq
	connIDMu.Unlock()

	return fmt.Sprintf("conn-%d-%d", time.Now().UnixNano(), seq)
}