package milky

import (
	"context"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"cellbot/internal/engine"
	"cellbot/internal/protocol"
	"cellbot/pkg/net"

	"go.uber.org/zap"
)

// Config holds the Milky adapter configuration.
type Config struct {
	// ProtocolURL is the protocol server address (e.g. http://localhost:3000).
	ProtocolURL string `toml:"protocol_url"`
	// AccessToken is passed to the API client, the SSE client and the
	// WebSocket URL; empty means no token.
	AccessToken string `toml:"access_token"`
	// EventMode selects how events are received: sse, websocket or webhook.
	EventMode string `toml:"event_mode"`
	// WebhookListenAddr is the webhook listen address (required only when
	// event_mode = "webhook").
	WebhookListenAddr string `toml:"webhook_listen_addr"`
	// Timeout is the API client timeout in seconds; values <= 0 mean 30.
	Timeout int `toml:"timeout"`
	// RetryCount is the API client retry count; values <= 0 mean 3.
	RetryCount int `toml:"retry_count"`
}

// Adapter is the Milky protocol adapter. It receives events through the
// configured transport (SSE, WebSocket or webhook), converts them with an
// EventConverter, publishes them on the engine's EventBus, and sends actions
// through an APIClient.
type Adapter struct {
	config         *Config
	selfID         string
	apiClient      *APIClient
	sseClient      *net.SSEClient
	wsManager      *net.WebSocketManager
	wsConn         *net.WebSocketConnection
	webhookServer  *WebhookServer
	eventBus       *engine.EventBus
	eventConverter *EventConverter
	logger         *zap.Logger
	// ctx is cancelled by Disconnect and stops the event-handling goroutines.
	ctx    context.Context
	cancel context.CancelFunc
}

// NewAdapter creates a Milky adapter. Non-positive Timeout/RetryCount values
// in config are replaced by the defaults (30s and 3 retries respectively).
// No connection is made until Connect is called.
func NewAdapter(config *Config, selfID string, eventBus *engine.EventBus, wsManager *net.WebSocketManager, logger *zap.Logger) *Adapter {
	ctx, cancel := context.WithCancel(context.Background())

	// Treat negative values like "unset": a negative timeout or retry count
	// is never a meaningful setting.
	timeout := time.Duration(config.Timeout) * time.Second
	if timeout <= 0 {
		timeout = 30 * time.Second
	}

	retryCount := config.RetryCount
	if retryCount <= 0 {
		retryCount = 3
	}

	return &Adapter{
		config:         config,
		selfID:         selfID,
		apiClient:      NewAPIClient(config.ProtocolURL, config.AccessToken, timeout, retryCount, logger),
		eventBus:       eventBus,
		wsManager:      wsManager,
		eventConverter: NewEventConverter(logger),
		logger:         logger.Named("milky-adapter"),
		ctx:            ctx,
		cancel:         cancel,
	}
}

// Connect starts receiving events using the transport selected by
// config.EventMode. It returns an error for an unknown mode or when the
// transport fails to start.
func (a *Adapter) Connect(ctx context.Context) error {
	a.logger.Info("Connecting to Milky protocol server",
		zap.String("url", a.config.ProtocolURL),
		zap.String("event_mode", a.config.EventMode))

	switch a.config.EventMode {
	case "sse":
		return a.connectSSE(ctx)
	case "websocket":
		return a.connectWebSocket(ctx)
	case "webhook":
		return a.startWebhook()
	default:
		return fmt.Errorf("unknown event mode: %s", a.config.EventMode)
	}
}

// eventEndpoint returns ProtocolURL + "/event", dropping trailing slashes
// from ProtocolURL so the result never contains "//event".
func (a *Adapter) eventEndpoint() string {
	return strings.TrimRight(a.config.ProtocolURL, "/") + "/event"
}

// connectSSE opens the SSE event stream and starts forwarding its events.
func (a *Adapter) connectSSE(ctx context.Context) error {
	sseConfig := net.SSEClientConfig{
		URL:            a.eventEndpoint(),
		AccessToken:    a.config.AccessToken,
		ReconnectDelay: 5 * time.Second,
		MaxReconnect:   -1,            // reconnect forever
		EventFilter:    "milky_event", // only accept milky_event events
		BufferSize:     100,
	}

	a.sseClient = net.NewSSEClient(sseConfig, a.logger)

	if err := a.sseClient.Connect(ctx); err != nil {
		return fmt.Errorf("failed to connect SSE: %w", err)
	}

	go a.handleEvents(a.sseClient.Events())

	a.logger.Info("SSE connection established")
	return nil
}

// connectWebSocket dials the WebSocket event endpoint through wsManager and
// starts the read loop. ctx is currently unused because Dial does not take
// a context.
func (a *Adapter) connectWebSocket(ctx context.Context) error {
	eventURL := a.eventEndpoint()

	// Map the HTTP scheme onto its WebSocket counterpart; other schemes
	// (e.g. an explicit ws://) are left untouched.
	switch {
	case strings.HasPrefix(eventURL, "http://"):
		eventURL = "ws://" + strings.TrimPrefix(eventURL, "http://")
	case strings.HasPrefix(eventURL, "https://"):
		eventURL = "wss://" + strings.TrimPrefix(eventURL, "https://")
	}

	// Log before attaching the token so the secret never reaches the logs.
	a.logger.Info("Connecting to WebSocket", zap.String("url", eventURL))

	dialURL := eventURL
	if a.config.AccessToken != "" {
		// Escape the token: characters such as '&', '+' or '#' would
		// otherwise corrupt the query string.
		dialURL += "?access_token=" + url.QueryEscape(a.config.AccessToken)
	}

	conn, err := a.wsManager.Dial(dialURL, a.selfID)
	if err != nil {
		return fmt.Errorf("failed to dial WebSocket: %w", err)
	}
	a.wsConn = conn

	go a.handleWebSocketEvents()

	a.logger.Info("WebSocket connection established")
	return nil
}

// handleWebSocketEvents reads messages from the WebSocket connection,
// converts them and publishes them on the event bus. It returns when the
// adapter context is cancelled or a read fails. Messages that fail to
// convert are logged and skipped.
func (a *Adapter) handleWebSocketEvents() {
	// Capture the connection once instead of re-reading the field each loop.
	conn := a.wsConn

	for {
		select {
		case <-a.ctx.Done():
			return
		default:
		}

		_, message, err := conn.Conn.ReadMessage()
		if err != nil {
			// Disconnect cancels a.ctx and closes the connection to unblock
			// ReadMessage; that error is expected and not worth reporting.
			if a.ctx.Err() != nil {
				return
			}
			a.logger.Error("Failed to read WebSocket message", zap.Error(err))
			return
		}

		event, err := a.eventConverter.Convert(message)
		if err != nil {
			a.logger.Error("Failed to convert event", zap.Error(err))
			continue
		}

		a.eventBus.Publish(event)
	}
}

// startWebhook starts the webhook HTTP server and forwards its events.
// WebhookListenAddr must be set.
func (a *Adapter) startWebhook() error {
	if a.config.WebhookListenAddr == "" {
		return fmt.Errorf("webhook_listen_addr is required for webhook mode")
	}

	a.webhookServer = NewWebhookServer(a.config.WebhookListenAddr, a.logger)

	if err := a.webhookServer.Start(); err != nil {
		return fmt.Errorf("failed to start webhook server: %w", err)
	}

	go a.handleEvents(a.webhookServer.Events())

	a.logger.Info("Webhook server started", zap.String("addr", a.config.WebhookListenAddr))
	return nil
}

// handleEvents converts raw events received on eventChan and publishes them
// on the event bus until the adapter context is cancelled or the channel is
// closed. Events that fail to convert are logged and skipped.
func (a *Adapter) handleEvents(eventChan <-chan []byte) {
	for {
		select {
		case <-a.ctx.Done():
			return
		case rawEvent, ok := <-eventChan:
			if !ok {
				a.logger.Info("Event channel closed")
				return
			}

			event, err := a.eventConverter.Convert(rawEvent)
			if err != nil {
				a.logger.Error("Failed to convert event", zap.Error(err))
				continue
			}

			a.eventBus.Publish(event)
		}
	}
}

// SendAction sends action to the protocol server, using the action type as
// the API endpoint, and returns the response's Data field.
func (a *Adapter) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) {
	resp, err := a.apiClient.Call(ctx, string(action.GetType()), action.GetParams())
	if err != nil {
		return nil, fmt.Errorf("failed to call API: %w", err)
	}

	return resp.Data, nil
}

// ParseMessage converts a raw event payload into a protocol.Event.
func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) {
	return a.eventConverter.Convert(raw)
}

// Disconnect cancels the adapter context and closes every transport that
// was opened. Close failures are logged, not returned; the result is always
// nil.
func (a *Adapter) Disconnect() error {
	a.logger.Info("Disconnecting from Milky protocol server")

	a.cancel()

	if a.sseClient != nil {
		if err := a.sseClient.Close(); err != nil {
			a.logger.Error("Failed to close SSE client", zap.Error(err))
		}
	}

	if a.wsConn != nil && a.wsConn.Conn != nil {
		// Dial is not given a context, so cancelling a.ctx does not close the
		// socket by itself. Closing it here also unblocks the ReadMessage call
		// in handleWebSocketEvents so that goroutine can exit.
		// NOTE(review): confirm whether wsManager also needs to be told to
		// forget this connection.
		if err := a.wsConn.Conn.Close(); err != nil {
			a.logger.Warn("Failed to close WebSocket connection", zap.Error(err))
		}
	}

	if a.webhookServer != nil {
		if err := a.webhookServer.Stop(); err != nil {
			a.logger.Error("Failed to stop webhook server", zap.Error(err))
		}
	}

	if a.apiClient != nil {
		if err := a.apiClient.Close(); err != nil {
			a.logger.Error("Failed to close API client", zap.Error(err))
		}
	}

	return nil
}

// GetProtocolName returns the protocol name ("milky").
func (a *Adapter) GetProtocolName() string {
	return "milky"
}

// GetProtocolVersion returns the protocol version string ("1.0").
func (a *Adapter) GetProtocolVersion() string {
	return "1.0"
}

// GetSelfID returns the bot's own ID.
func (a *Adapter) GetSelfID() string {
	return a.selfID
}

// IsConnected reports whether the transport for the configured event mode
// has been set up and Disconnect has not been called.
func (a *Adapter) IsConnected() bool {
	// Disconnect cancels a.ctx but leaves the client fields set, so the
	// field checks below alone would keep reporting true afterwards.
	if a.ctx.Err() != nil {
		return false
	}

	switch a.config.EventMode {
	case "sse":
		return a.sseClient != nil
	case "websocket":
		return a.wsConn != nil && a.wsConn.Conn != nil
	case "webhook":
		return a.webhookServer != nil
	default:
		return false
	}
}

// GetStats returns adapter statistics. In websocket mode with a live
// connection it also includes the remote address and connection type.
func (a *Adapter) GetStats() map[string]interface{} {
	stats := map[string]interface{}{
		"protocol":   "milky",
		"self_id":    a.selfID,
		"event_mode": a.config.EventMode,
		"connected":  a.IsConnected(),
	}

	if a.config.EventMode == "websocket" && a.wsConn != nil {
		stats["remote_addr"] = a.wsConn.RemoteAddr
		stats["connection_type"] = a.wsConn.Type
	}

	return stats
}

// CallAPI calls an API endpoint directly (for use by the Bot).
func (a *Adapter) CallAPI(ctx context.Context, endpoint string, params map[string]interface{}) (*APIResponse, error) {
	return a.apiClient.Call(ctx, endpoint, params)
}

// GetConfig returns the adapter configuration.
func (a *Adapter) GetConfig() *Config {
	return a.config
}

// SetSelfID sets the bot's own ID.
func (a *Adapter) SetSelfID(selfID string) {
	a.selfID = selfID
}

// GetSelfIDInt64 parses the bot's own ID as a base-10 int64.
func (a *Adapter) GetSelfIDInt64() (int64, error) {
	return strconv.ParseInt(a.selfID, 10, 64)
}