Files
cellbot/pkg/net/websocket.go
lafay d16261e6bd feat: add rate limiting and improve event handling
- Introduced rate limiting configuration in config.toml with options for enabling, requests per second (RPS), and burst capacity.
- Enhanced event handling in the OneBot11 adapter to ignore messages sent by the bot itself.
- Updated the dispatcher to register rate limit middleware based on configuration settings.
- Refactored WebSocket message handling to support flexible JSON parsing and improved event type detection.
- Removed deprecated echo plugin and associated tests to streamline the codebase.
2026-01-05 01:00:38 +08:00

496 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package net
import (
"context"
"fmt"
"net/url"
"sync"
"time"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"github.com/bytedance/sonic"
"github.com/fasthttp/websocket"
"github.com/valyala/fasthttp"
"go.uber.org/zap"
)
// WebSocketManager owns every live WebSocket connection, both reverse
// (accepted from clients) and forward (dialed by this process), and hands
// inbound events to the shared event bus.
type WebSocketManager struct {
	// mu guards connections.
	mu          sync.RWMutex
	connections map[string]*WebSocketConnection // keyed by WebSocketConnection.ID

	upgrader *websocket.FastHTTPUpgrader // upgrades inbound HTTP requests
	eventBus *engine.EventBus            // destination for decoded events
	logger   *zap.Logger
}
// NewWebSocketManager builds a manager with an empty connection table, a
// logger named "websocket", and an upgrader using 1 KiB read/write buffers.
func NewWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *WebSocketManager {
	// Every origin is accepted. Production deployments should tighten this check.
	acceptAnyOrigin := func(*fasthttp.RequestCtx) bool {
		return true
	}

	upgrader := &websocket.FastHTTPUpgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		CheckOrigin:     acceptAnyOrigin,
	}

	mgr := &WebSocketManager{
		connections: make(map[string]*WebSocketConnection),
		eventBus:    eventBus,
		upgrader:    upgrader,
	}
	mgr.logger = logger.Named("websocket")
	return mgr
}
// ConnectionType records which side initiated a WebSocket connection.
type ConnectionType string

// Known connection types.
const (
	// ConnectionTypeReverse is a connection accepted from a remote peer (passive side).
	ConnectionTypeReverse ConnectionType = "reverse"
	// ConnectionTypeForward is a connection this process dialed itself (active side).
	ConnectionTypeForward ConnectionType = "forward"
)
// WebSocketConnection is a single WebSocket connection tracked by
// WebSocketManager. It is either accepted from a peer (reverse) or dialed by
// this process (forward).
type WebSocketConnection struct {
	// ID is the unique identifier produced by generateConnID; it is the key
	// under which the connection is stored in the manager.
	ID string
	// Conn is the underlying socket. For forward connections reconnectLoop
	// replaces it in place after a successful re-dial.
	Conn *websocket.Conn
	// BotID is the bot this connection belongs to.
	BotID string
	// Logger is pre-tagged with the conn_id field.
	Logger *zap.Logger
	// ctx is cancelled by close(); readLoop, heartbeatLoop and reconnectLoop watch it.
	ctx    context.Context
	cancel context.CancelFunc
	// Type says whether the connection was accepted (reverse) or dialed (forward).
	Type ConnectionType
	// RemoteAddr is the peer address captured when the socket was (re)established.
	RemoteAddr string
	reconnectURL   string        // URL re-dialed by reconnectLoop (forward connections only)
	maxReconnect   int           // maximum number of reconnect attempts
	reconnectCount int           // attempts made so far in the current reconnect cycle
	heartbeatTick  time.Duration // interval between ping frames sent by heartbeatLoop
}
// NewWebSocketConnection wraps conn for botID with a fresh cancellable
// context, a newly generated ID, and the defaults used by both connection
// types: at most 5 reconnect attempts and a 30s heartbeat. conn must be non-nil
// because its remote address is read immediately.
func NewWebSocketConnection(conn *websocket.Conn, botID string, connType ConnectionType, logger *zap.Logger) *WebSocketConnection {
	lifecycle, stop := context.WithCancel(context.Background())
	id := generateConnID()

	wsc := &WebSocketConnection{
		ID:            id,
		Conn:          conn,
		BotID:         botID,
		Type:          connType,
		RemoteAddr:    conn.RemoteAddr().String(),
		Logger:        logger.With(zap.String("conn_id", id)),
		maxReconnect:  5,
		heartbeatTick: 30 * time.Second,
	}
	wsc.ctx, wsc.cancel = lifecycle, stop
	return wsc
}
// UpgradeWebSocket upgrades an inbound fasthttp request to a reverse
// (server-side) WebSocket connection for the bot named by the required
// "bot_id" query parameter.
//
// fasthttp runs the upgrade callback only after the current request handler
// has returned, and it closes the hijacked socket as soon as that callback
// returns. The previous implementation blocked the request handler waiting
// for the callback, which deadlocked, and it also let the callback return
// immediately, which would have lost the socket. Therefore this method:
//   - returns as soon as the upgrade has been scheduled, and
//   - runs the read loop inside the callback, so the socket stays open for the
//     connection's whole lifetime, then unregisters the connection.
//
// The returned connection has ID, BotID, Type and Logger set immediately.
// Conn and RemoteAddr are filled in, and the connection is registered with
// the manager, only once the upgrade completes after the request handler
// returns. Do not call SendMessage on it before then.
// NOTE(review): callers that need a live socket should look the connection up
// via GetConnection/GetConnectionByBotID — confirm against call sites.
//
// Errors: a missing bot_id, or a failed upgrade (wrapped).
func (wsm *WebSocketManager) UpgradeWebSocket(ctx *fasthttp.RequestCtx) (*WebSocketConnection, error) {
	botID := string(ctx.QueryArgs().Peek("bot_id"))
	if botID == "" {
		return nil, fmt.Errorf("missing bot_id parameter")
	}

	// Build the connection object up front so it can be returned before the
	// upgrade callback runs. The defaults mirror NewWebSocketConnection.
	connCtx, cancel := context.WithCancel(context.Background())
	connID := generateConnID()
	wsConn := &WebSocketConnection{
		ID:            connID,
		BotID:         botID,
		Logger:        wsm.logger.With(zap.String("conn_id", connID)),
		ctx:           connCtx,
		cancel:        cancel,
		Type:          ConnectionTypeReverse,
		maxReconnect:  5,
		heartbeatTick: 30 * time.Second,
	}

	err := wsm.upgrader.Upgrade(ctx, func(conn *websocket.Conn) {
		wsConn.Conn = conn
		wsConn.RemoteAddr = conn.RemoteAddr().String()

		wsm.mu.Lock()
		wsm.connections[wsConn.ID] = wsConn
		wsm.mu.Unlock()

		wsm.logger.Info("WebSocket reverse connection established",
			zap.String("conn_id", wsConn.ID),
			zap.String("bot_id", botID),
			zap.String("remote_addr", wsConn.RemoteAddr))

		go wsConn.heartbeatLoop()

		// Block here: fasthttp closes the hijacked socket when this callback
		// returns. readLoop closes the connection (and cancels its context,
		// stopping the heartbeat) before it returns.
		wsConn.readLoop(wsm.eventBus)

		// Unregister, unless RemoveConnection already did so.
		wsm.mu.Lock()
		current, registered := wsm.connections[wsConn.ID]
		stillOurs := registered && current == wsConn
		if stillOurs {
			delete(wsm.connections, wsConn.ID)
		}
		wsm.mu.Unlock()
		if stillOurs {
			wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", wsConn.ID))
		}
	})
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to upgrade connection: %w", err)
	}

	return wsConn, nil
}
// readLoop reads frames until a read fails or the connection's context is
// cancelled. Text frames go to handleMessage; any other frame type is logged
// and skipped.
//
// The context is only checked between frames. close() also closes the socket,
// and that is what unblocks a pending read. The connection is always closed
// when the loop ends.
func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) {
	defer wsc.close()

	for wsc.ctx.Err() == nil {
		frameType, payload, err := wsc.Conn.ReadMessage()
		if err != nil {
			wsc.Logger.Error("Failed to read message", zap.Error(err))
			return
		}

		if frameType == websocket.TextMessage {
			wsc.handleMessage(payload, eventBus)
			continue
		}

		wsc.Logger.Warn("Received non-text message, ignoring",
			zap.Int("message_type", frameType))
	}
}
// handleMessage 处理消息
func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) {
wsc.Logger.Debug("Received message", zap.ByteString("data", data))
// 先解析为 map 以支持灵活的字段类型(如 self_id 可能是数字或字符串)
// 使用 sonic.Config 配置更宽松的解析,允许数字和字符串之间的转换
cfg := sonic.Config{
UseInt64: true, // 使用 int64 而不是 float64 来解析数字
NoValidateJSONSkip: true, // 跳过类型验证,允许更灵活的类型转换
}.Froze()
var rawMap map[string]interface{}
if err := cfg.Unmarshal(data, &rawMap); err != nil {
wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data))
return
}
// 检查是否是 API 响应(有 echo 字段且没有 post_type
// 如果是响应,不在这里处理,让 adapter 的 handleWebSocketMessages 处理
if echo, hasEcho := rawMap["echo"].(string); hasEcho && echo != "" {
if _, hasPostType := rawMap["post_type"]; !hasPostType {
// 这是 API 响应,不在这里处理
// 正向 WebSocket 时adapter 的 handleWebSocketMessages 会处理
// 反向 WebSocket 时,响应应该通过 adapter 处理
wsc.Logger.Debug("Skipping API response in handleMessage, will be handled by adapter",
zap.String("echo", echo))
return
}
}
// 构建 BaseEvent
event := &protocol.BaseEvent{
Data: make(map[string]interface{}),
}
// 处理 self_id可能是数字或字符串
if selfIDVal, ok := rawMap["self_id"]; ok {
switch v := selfIDVal.(type) {
case string:
event.SelfID = v
case float64:
event.SelfID = fmt.Sprintf("%.0f", v)
case int64:
event.SelfID = fmt.Sprintf("%d", v)
case int:
event.SelfID = fmt.Sprintf("%d", v)
default:
event.SelfID = fmt.Sprintf("%v", v)
}
}
// 如果没有SelfID使用连接的BotID
if event.SelfID == "" {
event.SelfID = wsc.BotID
}
// 处理时间戳
if timeVal, ok := rawMap["time"]; ok {
switch v := timeVal.(type) {
case float64:
event.Timestamp = int64(v)
case int64:
event.Timestamp = v
case int:
event.Timestamp = int64(v)
}
}
if event.Timestamp == 0 {
event.Timestamp = time.Now().Unix()
}
// 处理类型字段
if typeVal, ok := rawMap["post_type"]; ok {
if typeStr, ok := typeVal.(string); ok {
// OneBot11 格式post_type -> EventType 映射
switch typeStr {
case "message":
event.Type = protocol.EventTypeMessage
case "notice":
event.Type = protocol.EventTypeNotice
case "request":
event.Type = protocol.EventTypeRequest
case "meta_event":
event.Type = protocol.EventTypeMeta
case "message_sent":
// 忽略机器人自己发送的消息
wsc.Logger.Debug("Ignoring message_sent event")
return
default:
event.Type = protocol.EventType(typeStr)
}
}
} else if typeVal, ok := rawMap["type"]; ok {
if typeStr, ok := typeVal.(string); ok {
event.Type = protocol.EventType(typeStr)
}
}
// 验证必需字段
if event.Type == "" {
wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data))
return
}
// 处理 detail_type
if detailTypeVal, ok := rawMap["message_type"]; ok {
if detailTypeStr, ok := detailTypeVal.(string); ok {
event.DetailType = detailTypeStr
}
} else if detailTypeVal, ok := rawMap["detail_type"]; ok {
if detailTypeStr, ok := detailTypeVal.(string); ok {
event.DetailType = detailTypeStr
}
}
// 将所有其他字段放入 Data
for k, v := range rawMap {
if k != "self_id" && k != "time" && k != "post_type" && k != "type" && k != "message_type" && k != "detail_type" {
event.Data[k] = v
}
}
wsc.Logger.Info("Event received",
zap.String("type", string(event.Type)),
zap.String("detail_type", event.DetailType),
zap.String("self_id", event.SelfID))
// 发布到事件总线
eventBus.Publish(event)
}
// SendMessage writes data to the peer as a single text frame. The write error,
// if any, is returned wrapped.
//
// NOTE(review): this method adds no locking. The websocket.Conn API (a
// gorilla/websocket fork) documents at most one concurrent writer; only
// Close/WriteControl may run concurrently with other calls. Callers that send
// from several goroutines (e.g. BroadcastToBot plus an adapter) must serialize
// their sends — confirm.
func (wsc *WebSocketConnection) SendMessage(data []byte) error {
	wsc.Logger.Debug("Sending message", zap.ByteString("data", data))

	if writeErr := wsc.Conn.WriteMessage(websocket.TextMessage, data); writeErr != nil {
		return fmt.Errorf("failed to send message: %w", writeErr)
	}
	return nil
}
// heartbeatLoop sends a ping control frame every heartbeatTick, each with a
// 10s write deadline, until the connection's context is cancelled.
//
// heartbeatTick must be positive, because time.NewTicker panics otherwise. A
// failed ping is logged and ends the heartbeat, but it does not close the
// connection or cancel its context.
func (wsc *WebSocketConnection) heartbeatLoop() {
	ticker := time.NewTicker(wsc.heartbeatTick)
	defer ticker.Stop()

	for {
		select {
		case <-wsc.ctx.Done():
			return
		case <-ticker.C:
		}

		deadline := time.Now().Add(10 * time.Second)
		if err := wsc.Conn.WriteControl(websocket.PingMessage, []byte{}, deadline); err != nil {
			wsc.Logger.Warn("Failed to send ping", zap.Error(err))
			return
		}
		wsc.Logger.Debug("Heartbeat ping sent")
	}
}
// reconnectLoop supervises a forward connection. It blocks until the
// connection's context is cancelled, then re-dials reconnectURL up to
// maxReconnect times. The backoff grows linearly: attempt number × 5s.
//
// It does nothing more (after the context is cancelled) for reverse
// connections or when reconnectURL is empty.
//
// On a successful re-dial it swaps a fresh socket and context into the same
// *WebSocketConnection, resets reconnectCount to 0, and restarts readLoop and
// heartbeatLoop. It also starts a new reconnectLoop to supervise the new
// socket. If every attempt fails, it removes the connection from wsm.
//
// NOTE(review): the context is cancelled only by close(), which runs when
// readLoop exits and from RemoveConnection. An explicit RemoveConnection on a
// forward connection therefore also triggers reconnection — confirm this is
// intended.
// NOTE(review): readLoop is restarted unconditionally, even for connections
// dialed with AutoReadLoop=false whose frames the adapter reads itself. The
// websocket package supports only one concurrent reader — verify.
// NOTE(review): Conn, RemoteAddr, ctx and cancel are reassigned without
// synchronization while other goroutines may read them. The backoff sleep is
// not interruptible.
func (wsc *WebSocketConnection) reconnectLoop(wsm *WebSocketManager) {
	<-wsc.ctx.Done() // wait until the connection has been closed
	if wsc.Type != ConnectionTypeForward || wsc.reconnectURL == "" {
		return
	}
	wsc.Logger.Info("Connection closed, attempting to reconnect",
		zap.Int("max_reconnect", wsc.maxReconnect))
	for wsc.reconnectCount < wsc.maxReconnect {
		wsc.reconnectCount++
		// Linear backoff: 5s, 10s, 15s, ...
		backoff := time.Duration(wsc.reconnectCount) * 5 * time.Second
		wsc.Logger.Info("Reconnecting",
			zap.Int("attempt", wsc.reconnectCount),
			zap.Int("max", wsc.maxReconnect),
			zap.Duration("backoff", backoff))
		time.Sleep(backoff)
		// Try to re-establish the connection.
		conn, _, err := websocket.DefaultDialer.Dial(wsc.reconnectURL, nil)
		if err != nil {
			wsc.Logger.Error("Reconnect failed", zap.Error(err))
			continue
		}
		// Swap the new socket and a fresh lifecycle context into this connection.
		wsc.Conn = conn
		wsc.RemoteAddr = conn.RemoteAddr().String()
		wsc.ctx, wsc.cancel = context.WithCancel(context.Background())
		wsc.reconnectCount = 0 // give the next disconnect a full attempt budget
		wsc.Logger.Info("Reconnected successfully",
			zap.String("remote_addr", wsc.RemoteAddr))
		// Restart the read loop and heartbeat, and supervise the new socket.
		go wsc.readLoop(wsm.eventBus)
		go wsc.heartbeatLoop()
		go wsc.reconnectLoop(wsm)
		return
	}
	wsc.Logger.Error("Max reconnect attempts reached, giving up",
		zap.Int("attempts", wsc.reconnectCount))
	// Give up: drop the connection from the manager.
	wsm.RemoveConnection(wsc.ID)
}
// close cancels the connection's context, which stops heartbeatLoop and wakes
// reconnectLoop, and then closes the underlying socket. A Close error is only
// logged, never returned.
func (wsc *WebSocketConnection) close() {
	wsc.cancel()

	closeErr := wsc.Conn.Close()
	if closeErr == nil {
		return
	}
	wsc.Logger.Error("Failed to close connection", zap.Error(closeErr))
}
// RemoveConnection unregisters the connection with the given ID and closes it.
// Unknown IDs are ignored.
//
// The entry is deleted under the manager lock, but the socket is closed only
// after the lock is released. Closing performs network I/O (for wss,
// crypto/tls sends a close_notify that may wait on a write deadline), and
// holding the write lock meanwhile would stall every GetConnection,
// GetConnectionByBotID and BroadcastToBot call.
func (wsm *WebSocketManager) RemoveConnection(connID string) {
	wsm.mu.Lock()
	conn, ok := wsm.connections[connID]
	if ok {
		delete(wsm.connections, connID)
	}
	wsm.mu.Unlock()

	if !ok {
		return
	}

	conn.close()
	wsm.logger.Info("WebSocket connection removed", zap.String("conn_id", connID))
}
// GetConnection looks up a registered connection by ID. The boolean reports
// whether it was found.
func (wsm *WebSocketManager) GetConnection(connID string) (*WebSocketConnection, bool) {
	wsm.mu.RLock()
	found, exists := wsm.connections[connID]
	wsm.mu.RUnlock()
	return found, exists
}
// GetConnectionByBotID returns every registered connection whose BotID equals
// botID, in no particular order. The result is never nil; it is an empty slice
// when nothing matches.
func (wsm *WebSocketManager) GetConnectionByBotID(botID string) []*WebSocketConnection {
	wsm.mu.RLock()
	defer wsm.mu.RUnlock()

	matches := []*WebSocketConnection{}
	for _, candidate := range wsm.connections {
		if candidate.BotID != botID {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches
}
// BroadcastToBot sends data as a text frame to every connection of botID.
// It is best-effort: per-connection failures are logged and the remaining
// connections are still attempted.
func (wsm *WebSocketManager) BroadcastToBot(botID string, data []byte) {
	for _, target := range wsm.GetConnectionByBotID(botID) {
		sendErr := target.SendMessage(data)
		if sendErr == nil {
			continue
		}
		wsm.logger.Error("Failed to send message to connection",
			zap.String("conn_id", target.ID),
			zap.Error(sendErr))
	}
}
// DialConfig configures a forward (client-side) WebSocket connection created
// by DialWithConfig.
type DialConfig struct {
	URL           string        // ws:// or wss:// endpoint to dial; reconnectLoop re-dials it
	BotID         string        // bot the connection belongs to
	MaxReconnect  int           // maximum reconnect attempts after a disconnect
	HeartbeatTick time.Duration // interval between ping frames; should be positive
	AutoReadLoop  bool          // whether to start readLoop automatically; set to false when the adapter reads messages itself
}
// Dial opens a forward (client-side) WebSocket connection to addr for botID
// with the standard settings: up to 5 reconnect attempts and a 30s heartbeat.
// readLoop is not started, because the adapter consumes frames itself.
func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnection, error) {
	cfg := DialConfig{URL: addr, BotID: botID}
	cfg.MaxReconnect = 5
	cfg.HeartbeatTick = 30 * time.Second
	cfg.AutoReadLoop = false // the adapter reads messages itself
	return wsm.DialWithConfig(cfg)
}
// DialWithConfig dials the forward (client-side) WebSocket connection
// described by config, registers it with the manager, and starts its
// background goroutines.
//
// config.URL must parse and use the ws or wss scheme. Errors are returned for
// an unparsable URL, a wrong scheme, or a failed dial (wrapped).
//
// heartbeatLoop and reconnectLoop are always started. readLoop is started only
// when config.AutoReadLoop is true; otherwise the caller reads frames itself.
//
// A non-positive config.HeartbeatTick keeps the default interval set by
// NewWebSocketConnection. Passing it through would make time.NewTicker panic
// inside the heartbeat goroutine and crash the whole process, e.g. for a
// DialConfig that only sets URL and BotID.
func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnection, error) {
	u, err := url.Parse(config.URL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}
	// Only WebSocket schemes are accepted.
	if u.Scheme != "ws" && u.Scheme != "wss" {
		return nil, fmt.Errorf("invalid URL scheme: %s, expected ws or wss", u.Scheme)
	}

	conn, _, err := websocket.DefaultDialer.Dial(config.URL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to dial: %w", err)
	}

	wsConn := NewWebSocketConnection(conn, config.BotID, ConnectionTypeForward, wsm.logger)
	wsConn.reconnectURL = config.URL
	wsConn.maxReconnect = config.MaxReconnect
	if config.HeartbeatTick > 0 {
		wsConn.heartbeatTick = config.HeartbeatTick
	}

	wsm.mu.Lock()
	wsm.connections[wsConn.ID] = wsConn
	wsm.mu.Unlock()

	wsm.logger.Info("WebSocket forward connection established",
		zap.String("conn_id", wsConn.ID),
		zap.String("bot_id", config.BotID),
		zap.String("addr", config.URL),
		zap.String("remote_addr", wsConn.RemoteAddr))

	if config.AutoReadLoop {
		go wsConn.readLoop(wsm.eventBus)
	}
	go wsConn.heartbeatLoop()
	// Every connection created here is a forward connection, so it is always
	// supervised for reconnection.
	go wsConn.reconnectLoop(wsm)

	return wsConn, nil
}
// connIDMu guards connIDSeq, a process-wide counter appended to connection
// IDs. A plain mutex is used so that no new import is needed.
var (
	connIDMu  sync.Mutex
	connIDSeq uint64
)

// generateConnID returns a unique connection ID of the form
// "conn-<unix-nanos>-<seq>".
//
// A nanosecond timestamp alone is not unique: two goroutines can read the same
// clock value, and some platforms have coarse clock resolution. A duplicate ID
// would silently overwrite another entry in the manager's connection map. The
// monotonically increasing sequence number rules that out.
func generateConnID() string {
	connIDMu.Lock()
	connIDSeq++
	seq := connIDSeq
	connIDMu.Unlock()
	return fmt.Sprintf("conn-%d-%d", time.Now().UnixNano(), seq)
}