Files
cellbot/internal/adapter/milky/adapter.go

341 lines
8.3 KiB
Go
Raw Permalink Normal View History

package milky
import (
	"context"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"cellbot/internal/engine"
	"cellbot/internal/protocol"
	"cellbot/pkg/net"

	"go.uber.org/zap"
)
// Config holds the settings of the Milky protocol adapter, decoded from TOML.
//
// A zero Timeout or RetryCount is replaced with a default by NewAdapter.
type Config struct {
	// ProtocolURL is the base address of the protocol server,
	// e.g. "http://localhost:3000". "/event" is appended to it to reach the
	// event stream.
	ProtocolURL string `toml:"protocol_url"`

	// AccessToken is the token presented to the protocol server.
	AccessToken string `toml:"access_token"`

	// EventMode chooses how events are received: "sse", "websocket" or
	// "webhook".
	EventMode string `toml:"event_mode"`

	// WebhookListenAddr is the address the local webhook server listens on.
	// It is only needed when EventMode is "webhook".
	WebhookListenAddr string `toml:"webhook_listen_addr"`

	// Timeout is the API client timeout, in seconds.
	Timeout int `toml:"timeout"`

	// RetryCount is the retry count handed to the API client.
	RetryCount int `toml:"retry_count"`
}
// Adapter is the Milky protocol adapter. It receives events from the
// protocol server (via SSE, WebSocket or webhook, depending on the config),
// publishes them on the event bus, and sends actions through the API client.
type Adapter struct {
	config *Config
	selfID string

	// apiClient performs outgoing API calls.
	apiClient *APIClient

	// Event sources. Connect populates the one matching config.EventMode;
	// wsManager is supplied by the caller and is used to dial the WebSocket.
	sseClient     *net.SSEClient
	wsManager     *net.WebSocketManager
	wsConn        *net.WebSocketConnection
	webhookServer *WebhookServer

	eventBus       *engine.EventBus
	eventConverter *EventConverter
	logger         *zap.Logger

	// ctx bounds the background event loops; cancel is called by Disconnect.
	ctx    context.Context
	cancel context.CancelFunc
}
// NewAdapter creates a Milky adapter.
//
// config.Timeout (seconds) falls back to 30 seconds and config.RetryCount
// falls back to 3 when they are not positive. The adapter creates its own
// cancellable context, which Disconnect cancels to stop the background
// event loops.
func NewAdapter(config *Config, selfID string, eventBus *engine.EventBus, wsManager *net.WebSocketManager, logger *zap.Logger) *Adapter {
	ctx, cancel := context.WithCancel(context.Background())

	// A negative value (or an overflowed seconds-to-Duration conversion) is
	// as meaningless as zero, so treat both as "unset" rather than handing a
	// negative timeout/count to the API client.
	timeout := time.Duration(config.Timeout) * time.Second
	if timeout <= 0 {
		timeout = 30 * time.Second
	}
	retryCount := config.RetryCount
	if retryCount <= 0 {
		retryCount = 3
	}

	return &Adapter{
		config:         config,
		selfID:         selfID,
		apiClient:      NewAPIClient(config.ProtocolURL, config.AccessToken, timeout, retryCount, logger),
		eventBus:       eventBus,
		wsManager:      wsManager,
		eventConverter: NewEventConverter(logger),
		logger:         logger.Named("milky-adapter"),
		ctx:            ctx,
		cancel:         cancel,
	}
}
// Connect starts receiving events from the protocol server using the
// transport named by config.EventMode: "sse", "websocket" or "webhook".
// Any other mode is rejected with an error.
func (a *Adapter) Connect(ctx context.Context) error {
	mode := a.config.EventMode
	a.logger.Info("Connecting to Milky protocol server",
		zap.String("url", a.config.ProtocolURL),
		zap.String("event_mode", mode))

	if mode == "sse" {
		return a.connectSSE(ctx)
	}
	if mode == "websocket" {
		return a.connectWebSocket(ctx)
	}
	if mode == "webhook" {
		return a.startWebhook()
	}
	return fmt.Errorf("unknown event mode: %s", mode)
}
// connectSSE subscribes to the server's "/event" endpoint over SSE and starts
// a goroutine that forwards received events to the event bus.
//
// The client is stored on the adapter only after Connect succeeds. Storing
// it earlier made IsConnected report true after a failed attempt.
func (a *Adapter) connectSSE(ctx context.Context) error {
	// Trim trailing slashes so "http://host:3000/" does not become "//event".
	eventURL := strings.TrimRight(a.config.ProtocolURL, "/") + "/event"

	client := net.NewSSEClient(net.SSEClientConfig{
		URL:            eventURL,
		AccessToken:    a.config.AccessToken,
		ReconnectDelay: 5 * time.Second,
		MaxReconnect:   -1,            // reconnect without limit
		EventFilter:    "milky_event", // only accept events of type milky_event
		BufferSize:     100,
	}, a.logger)

	if err := client.Connect(ctx); err != nil {
		// NOTE(review): if SSEClient.Connect can leave resources behind on
		// failure, the client should be closed here.
		return fmt.Errorf("failed to connect SSE: %w", err)
	}
	a.sseClient = client

	go a.handleEvents(client.Events())
	a.logger.Info("SSE connection established")
	return nil
}
// connectWebSocket dials the server's "/event" endpoint over WebSocket and
// starts a goroutine that forwards received messages to the event bus.
//
// The http/https scheme of ProtocolURL is mapped to ws/wss. If an access
// token is configured, it is sent URL-escaped as the access_token query
// parameter. The token is never written to the log.
func (a *Adapter) connectWebSocket(ctx context.Context) error {
	// Trim trailing slashes so "http://host:3000/" does not become "//event".
	eventURL := strings.TrimRight(a.config.ProtocolURL, "/") + "/event"
	switch {
	case strings.HasPrefix(eventURL, "http://"):
		eventURL = "ws://" + strings.TrimPrefix(eventURL, "http://")
	case strings.HasPrefix(eventURL, "https://"):
		eventURL = "wss://" + strings.TrimPrefix(eventURL, "https://")
	}

	// Log before the token is attached so the secret never reaches the logs.
	a.logger.Info("Connecting to WebSocket", zap.String("url", eventURL))

	dialURL := eventURL
	if a.config.AccessToken != "" {
		// Escape the token: characters such as '+', '&' or '#' would
		// otherwise corrupt the query string.
		dialURL += "?access_token=" + url.QueryEscape(a.config.AccessToken)
	}

	conn, err := a.wsManager.Dial(dialURL, a.selfID)
	if err != nil {
		return fmt.Errorf("failed to dial WebSocket: %w", err)
	}
	a.wsConn = conn

	go a.handleWebSocketEvents()
	a.logger.Info("WebSocket connection established")
	return nil
}
// handleWebSocketEvents reads messages from the WebSocket connection,
// converts each one, and publishes it on the event bus. Messages that fail to
// convert are logged and skipped. The loop ends when the adapter context is
// cancelled or when a read fails.
func (a *Adapter) handleWebSocketEvents() {
	// Capture the connection once. If Connect runs again and replaces
	// a.wsConn, this goroutine keeps draining its own socket instead of
	// becoming a second concurrent reader of the new one.
	conn := a.wsConn.Conn
	for {
		if a.ctx.Err() != nil {
			return
		}

		// ReadMessage blocks. A cancelled context is only noticed after it
		// returns, i.e. when a message arrives or the socket is closed.
		_, message, err := conn.ReadMessage()
		if err != nil {
			if a.ctx.Err() != nil {
				// Shutting down: a read error here is expected, not a fault.
				a.logger.Debug("WebSocket read loop stopped", zap.Error(err))
				return
			}
			a.logger.Error("Failed to read WebSocket message", zap.Error(err))
			return
		}

		event, err := a.eventConverter.Convert(message)
		if err != nil {
			a.logger.Error("Failed to convert event", zap.Error(err))
			continue
		}
		a.eventBus.Publish(event)
	}
}
// startWebhook starts the local webhook server on config.WebhookListenAddr.
// It also starts a goroutine that forwards the server's events to the event
// bus. An empty listen address is an error.
//
// The server is stored on the adapter only after Start succeeds. Storing it
// earlier made IsConnected report true after a failed start.
func (a *Adapter) startWebhook() error {
	addr := a.config.WebhookListenAddr
	if addr == "" {
		return fmt.Errorf("webhook_listen_addr is required for webhook mode")
	}

	server := NewWebhookServer(addr, a.config.AccessToken, a.logger)
	if err := server.Start(); err != nil {
		return fmt.Errorf("failed to start webhook server: %w", err)
	}
	a.webhookServer = server

	go a.handleEvents(server.Events())
	a.logger.Info("Webhook server started", zap.String("addr", addr))
	return nil
}
// handleEvents drains eventChan, converting each raw payload and publishing
// it on the event bus. Payloads that fail to convert are logged and skipped.
// It returns when the adapter context is cancelled or eventChan is closed.
func (a *Adapter) handleEvents(eventChan <-chan []byte) {
	done := a.ctx.Done()
	for {
		var (
			raw []byte
			ok  bool
		)
		select {
		case <-done:
			return
		case raw, ok = <-eventChan:
		}

		if !ok {
			a.logger.Info("Event channel closed")
			return
		}

		event, err := a.eventConverter.Convert(raw)
		if err != nil {
			a.logger.Error("Failed to convert event", zap.Error(err))
			continue
		}
		a.eventBus.Publish(event)
	}
}
// SendAction sends an action to the protocol server. The action's type is
// used as the API endpoint and its params as the payload. It returns the
// Data of the API response, or a wrapped error if the call fails.
func (a *Adapter) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) {
	endpoint := string(action.GetType())
	params := action.GetParams()

	resp, err := a.apiClient.Call(ctx, endpoint, params)
	if err == nil {
		return resp.Data, nil
	}
	return nil, fmt.Errorf("failed to call API: %w", err)
}
// ParseMessage converts a raw event payload into a protocol event using the
// adapter's event converter.
func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) {
	event, err := a.eventConverter.Convert(raw)
	return event, err
}
// Disconnect shuts the adapter down. It cancels the adapter context (ending
// the background event loops), then closes whichever event source was set up
// and the API client. Close failures are logged rather than returned, so
// Disconnect always returns nil.
func (a *Adapter) Disconnect() error {
	a.logger.Info("Disconnecting from Milky protocol server")
	a.cancel()

	if a.sseClient != nil {
		if err := a.sseClient.Close(); err != nil {
			a.logger.Error("Failed to close SSE client", zap.Error(err))
		}
	}

	// The WebSocket was dialed without the adapter context, so cancelling the
	// context does not close it. Close it explicitly so the blocked
	// ReadMessage in handleWebSocketEvents returns and the socket is released.
	// NOTE(review): if WebSocketManager tracks this connection, it may also
	// need to be unregistered there.
	if a.wsConn != nil && a.wsConn.Conn != nil {
		if err := a.wsConn.Conn.Close(); err != nil {
			a.logger.Error("Failed to close WebSocket connection", zap.Error(err))
		}
	}

	if a.webhookServer != nil {
		if err := a.webhookServer.Stop(); err != nil {
			a.logger.Error("Failed to stop webhook server", zap.Error(err))
		}
	}

	if a.apiClient != nil {
		if err := a.apiClient.Close(); err != nil {
			a.logger.Error("Failed to close API client", zap.Error(err))
		}
	}

	return nil
}
// GetProtocolName returns the protocol identifier, "milky".
func (a *Adapter) GetProtocolName() string {
	const name = "milky"
	return name
}
// GetProtocolVersion returns the protocol version string, "1.0".
func (a *Adapter) GetProtocolVersion() string {
	const version = "1.0"
	return version
}
// GetSelfID returns the bot's own ID as a string.
func (a *Adapter) GetSelfID() string {
	id := a.selfID
	return id
}
// IsConnected reports whether the event source for the configured mode has
// been set up. For WebSocket mode it also requires a non-nil underlying
// connection. It does not probe whether the source is still alive, and it
// returns false for an unknown mode.
func (a *Adapter) IsConnected() bool {
	mode := a.config.EventMode
	if mode == "sse" {
		return a.sseClient != nil
	}
	if mode == "websocket" {
		return a.wsConn != nil && a.wsConn.Conn != nil
	}
	if mode == "webhook" {
		return a.webhookServer != nil
	}
	return false
}
// GetStats returns a snapshot of adapter status. It always includes protocol,
// self ID, event mode and connection state. In WebSocket mode with an active
// connection it also includes the remote address and connection type.
func (a *Adapter) GetStats() map[string]interface{} {
	mode := a.config.EventMode

	stats := make(map[string]interface{})
	stats["protocol"] = "milky"
	stats["self_id"] = a.selfID
	stats["event_mode"] = mode
	stats["connected"] = a.IsConnected()

	if mode != "websocket" || a.wsConn == nil {
		return stats
	}
	stats["remote_addr"] = a.wsConn.RemoteAddr
	stats["connection_type"] = a.wsConn.Type
	return stats
}
// CallAPI invokes an arbitrary API endpoint directly and returns the raw API
// response. It is exposed for use by the Bot.
func (a *Adapter) CallAPI(ctx context.Context, endpoint string, params map[string]interface{}) (*APIResponse, error) {
	client := a.apiClient
	return client.Call(ctx, endpoint, params)
}
// GetConfig returns the adapter's configuration. The pointer is shared with
// the adapter, so changes made through it are visible to the adapter.
func (a *Adapter) GetConfig() *Config {
	cfg := a.config
	return cfg
}
// SetSelfID replaces the bot's own ID.
func (a *Adapter) SetSelfID(selfID string) {
	a.selfID = selfID
}
// GetSelfIDInt64 returns the bot's own ID parsed as a base-10 int64. If the
// ID is not a valid int64, it returns the strconv.ParseInt error.
func (a *Adapter) GetSelfIDInt64() (int64, error) {
	id, err := strconv.ParseInt(a.selfID, 10, 64)
	return id, err
}