Files
cellbot/internal/adapter/milky/adapter.go
xiaolan f3a72264af chore: update dependencies and refactor webhook handling
- Added new dependencies for SQLite support and improved HTTP client functionality in go.mod and go.sum.
- Refactored webhook server implementation to utilize a simplified version, enhancing code maintainability.
- Updated API client to leverage a generic request method, streamlining API interactions.
- Modified configuration to include access token for webhook server, improving security.
- Enhanced event handling and request processing in the API client for better performance.
2026-01-05 18:42:45 +08:00

341 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package milky
import (
	"context"
	"fmt"
	"net/url"
	"strconv"
	"strings"
	"time"

	"cellbot/internal/engine"
	"cellbot/internal/protocol"
	"cellbot/pkg/net"

	"go.uber.org/zap"
)
// Config holds the settings of the Milky protocol adapter (decoded via the
// `toml` struct tags below).
type Config struct {
	// ProtocolURL is the base address of the protocol server
	// (e.g. http://localhost:3000). "/event" is appended to it for the SSE
	// and WebSocket event streams.
	ProtocolURL string `toml:"protocol_url"`
	// AccessToken is the access token handed to the API client, the SSE client
	// and the webhook server, and sent as the access_token query parameter in
	// WebSocket mode.
	AccessToken string `toml:"access_token"`
	// EventMode selects how events are received: "sse", "websocket" or
	// "webhook". Connect rejects any other value.
	EventMode string `toml:"event_mode"`
	// WebhookListenAddr is the listen address of the webhook server; it is
	// required only when EventMode is "webhook".
	WebhookListenAddr string `toml:"webhook_listen_addr"`
	// Timeout is the timeout, in seconds, passed to the API client; 0 selects
	// the 30s default.
	Timeout int `toml:"timeout"`
	// RetryCount is the retry count passed to the API client; 0 selects the
	// default of 3.
	RetryCount int `toml:"retry_count"`
}
// Adapter is the Milky protocol adapter. It calls the protocol server's API
// through apiClient and, per Connect call, receives events through one
// transport (SSE, WebSocket or webhook, chosen by Config.EventMode),
// converting them with eventConverter and publishing them to eventBus.
type Adapter struct {
	config         *Config
	selfID         string
	apiClient      *APIClient
	// Event transports: Connect sets only the one matching Config.EventMode.
	sseClient      *net.SSEClient
	// wsManager is supplied by the caller of NewAdapter and is used to dial
	// the WebSocket transport.
	wsManager      *net.WebSocketManager
	wsConn         *net.WebSocketConnection
	webhookServer  *WebhookServer
	eventBus       *engine.EventBus
	eventConverter *EventConverter
	logger         *zap.Logger
	// ctx is checked by the event-handling goroutines; Disconnect calls cancel.
	ctx            context.Context
	cancel         context.CancelFunc
}
// NewAdapter creates a Milky adapter from config.
//
// A non-positive config.Timeout falls back to 30 seconds and a zero
// config.RetryCount falls back to 3; both are handed to the API client. The
// adapter owns a cancellable context that bounds its event-handling goroutines
// and is cancelled by Disconnect. No connection is made until Connect is
// called.
func NewAdapter(config *Config, selfID string, eventBus *engine.EventBus, wsManager *net.WebSocketManager, logger *zap.Logger) *Adapter {
	ctx, cancel := context.WithCancel(context.Background())

	// Only a strictly positive value is meaningful; a negative timeout would
	// otherwise reach the API client as a negative duration.
	timeout := 30 * time.Second
	if config.Timeout > 0 {
		timeout = time.Duration(config.Timeout) * time.Second
	}

	// 0 means "not configured" and selects the default.
	retryCount := config.RetryCount
	if retryCount == 0 {
		retryCount = 3
	}

	return &Adapter{
		config:         config,
		selfID:         selfID,
		apiClient:      NewAPIClient(config.ProtocolURL, config.AccessToken, timeout, retryCount, logger),
		eventBus:       eventBus,
		wsManager:      wsManager,
		eventConverter: NewEventConverter(logger),
		logger:         logger.Named("milky-adapter"),
		ctx:            ctx,
		cancel:         cancel,
	}
}
// Connect begins receiving events from the Milky protocol server using the
// transport named by Config.EventMode: "sse", "websocket" or "webhook".
// ctx is passed to the SSE and WebSocket helpers; webhook mode does not use
// it. Any other mode returns an error.
func (a *Adapter) Connect(ctx context.Context) error {
	mode := a.config.EventMode
	a.logger.Info("Connecting to Milky protocol server",
		zap.String("url", a.config.ProtocolURL),
		zap.String("event_mode", mode))

	switch mode {
	case "webhook":
		return a.startWebhook()
	case "websocket":
		return a.connectWebSocket(ctx)
	case "sse":
		return a.connectSSE(ctx)
	}
	return fmt.Errorf("unknown event mode: %s", mode)
}
// connectSSE subscribes to the protocol server's event stream over
// Server-Sent Events at <ProtocolURL>/event.
//
// The SSE client is configured to:
//   - reconnect without limit (MaxReconnect -1) with a 5s delay;
//   - accept only "milky_event" events;
//   - use a buffer size of 100.
//
// On success the client is stored in a.sseClient and a goroutine forwards its
// events to handleEvents. On failure the client is closed and not stored, so
// IsConnected does not report a transport that never connected and a retried
// Connect does not leak the failed client.
func (a *Adapter) connectSSE(ctx context.Context) error {
	// Trim trailing slashes so "http://host/" does not become "http://host//event".
	eventURL := strings.TrimRight(a.config.ProtocolURL, "/") + "/event"

	sseConfig := net.SSEClientConfig{
		URL:            eventURL,
		AccessToken:    a.config.AccessToken,
		ReconnectDelay: 5 * time.Second,
		MaxReconnect:   -1,            // reconnect without limit
		EventFilter:    "milky_event", // only accept milky_event events
		BufferSize:     100,
	}
	client := net.NewSSEClient(sseConfig, a.logger)

	if err := client.Connect(ctx); err != nil {
		// Best-effort cleanup: the connect error is what the caller needs, so
		// a close failure is only logged.
		if closeErr := client.Close(); closeErr != nil {
			a.logger.Debug("Failed to close SSE client after connect error", zap.Error(closeErr))
		}
		return fmt.Errorf("failed to connect SSE: %w", err)
	}
	a.sseClient = client

	go a.handleEvents(client.Events())

	a.logger.Info("SSE connection established")
	return nil
}
// connectWebSocket opens the event stream over WebSocket.
//
// It builds the URL by appending "/event" to the protocol URL and mapping the
// scheme http:// to ws:// and https:// to wss://. When an access token is
// configured it is sent, query-escaped, as the access_token query parameter.
// On success the connection is stored in a.wsConn and a goroutine is started
// that reads events from it.
//
// ctx is accepted for symmetry with connectSSE; WebSocketManager.Dial as
// called here takes no context.
func (a *Adapter) connectWebSocket(ctx context.Context) error {
	// Trim trailing slashes so "http://host/" does not become "ws://host//event".
	eventURL := strings.TrimRight(a.config.ProtocolURL, "/") + "/event"

	switch {
	case strings.HasPrefix(eventURL, "http://"):
		eventURL = "ws://" + strings.TrimPrefix(eventURL, "http://")
	case strings.HasPrefix(eventURL, "https://"):
		eventURL = "wss://" + strings.TrimPrefix(eventURL, "https://")
	}

	// The token is escaped so characters such as '&', '+' or '#' survive, and it
	// is kept out of the logged URL so credentials do not end up in logs.
	dialURL := eventURL
	if a.config.AccessToken != "" {
		dialURL += "?access_token=" + url.QueryEscape(a.config.AccessToken)
	}

	a.logger.Info("Connecting to WebSocket", zap.String("url", eventURL))

	conn, err := a.wsManager.Dial(dialURL, a.selfID)
	if err != nil {
		return fmt.Errorf("failed to dial WebSocket: %w", err)
	}
	a.wsConn = conn

	go a.handleWebSocketEvents()

	a.logger.Info("WebSocket connection established")
	return nil
}
// handleWebSocketEvents reads messages from the WebSocket connection, converts
// each one and publishes the result to the event bus. Messages that fail
// conversion are logged and skipped.
//
// The connection is captured once when the goroutine starts. A later Connect
// that replaces a.wsConn therefore cannot make this goroutine read from a
// connection that another reader owns.
//
// The loop stops when the adapter context is cancelled or a read fails. A read
// failure after cancellation is expected during shutdown and is logged at
// debug level only.
//
// ReadMessage takes no context, so cancelling a.ctx alone does not interrupt a
// read that is already blocked; the connection has to be closed for that.
func (a *Adapter) handleWebSocketEvents() {
	conn := a.wsConn
	if conn == nil || conn.Conn == nil {
		return
	}

	for {
		if a.ctx.Err() != nil {
			return
		}

		_, message, err := conn.Conn.ReadMessage()
		if err != nil {
			if a.ctx.Err() != nil {
				a.logger.Debug("WebSocket reader stopped", zap.Error(err))
				return
			}
			a.logger.Error("Failed to read WebSocket message", zap.Error(err))
			return
		}

		event, err := a.eventConverter.Convert(message)
		if err != nil {
			a.logger.Error("Failed to convert event", zap.Error(err))
			continue
		}

		a.eventBus.Publish(event)
	}
}
// startWebhook starts the webhook server on Config.WebhookListenAddr. This is
// presumably a listener the protocol server pushes events to; confirm against
// WebhookServer. It forwards the server's event channel to handleEvents.
// WebhookListenAddr must be non-empty.
//
// The server is stored in a.webhookServer only after Start succeeds, so
// IsConnected does not report a server that failed to start.
func (a *Adapter) startWebhook() error {
	if a.config.WebhookListenAddr == "" {
		return fmt.Errorf("webhook_listen_addr is required for webhook mode")
	}

	server := NewWebhookServer(a.config.WebhookListenAddr, a.config.AccessToken, a.logger)
	if err := server.Start(); err != nil {
		return fmt.Errorf("failed to start webhook server: %w", err)
	}
	a.webhookServer = server

	go a.handleEvents(server.Events())

	a.logger.Info("Webhook server started", zap.String("addr", a.config.WebhookListenAddr))
	return nil
}
// handleEvents drains eventChan, converting every raw payload and publishing
// the resulting event to the event bus. It returns when the adapter context is
// cancelled or when eventChan is closed; payloads that fail conversion are
// logged and skipped.
func (a *Adapter) handleEvents(eventChan <-chan []byte) {
	done := a.ctx.Done()
	for {
		var (
			raw  []byte
			open bool
		)
		select {
		case <-done:
			return
		case raw, open = <-eventChan:
		}

		if !open {
			a.logger.Info("Event channel closed")
			return
		}

		evt, convErr := a.eventConverter.Convert(raw)
		if convErr != nil {
			a.logger.Error("Failed to convert event", zap.Error(convErr))
			continue
		}
		a.eventBus.Publish(evt)
	}
}
// SendAction sends action to the protocol server. The action type is used as
// the API endpoint and the action params as the request parameters. It
// returns the response's Data field.
//
// Errors from the API client are wrapped (so errors.Is/As still work) with
// the endpoint name. A nil response without an error is reported as an error
// instead of panicking on resp.Data.
func (a *Adapter) SendAction(ctx context.Context, action protocol.Action) (map[string]interface{}, error) {
	endpoint := string(action.GetType())

	resp, err := a.apiClient.Call(ctx, endpoint, action.GetParams())
	if err != nil {
		return nil, fmt.Errorf("failed to call API %q: %w", endpoint, err)
	}
	if resp == nil {
		return nil, fmt.Errorf("API %q returned no response", endpoint)
	}
	return resp.Data, nil
}
// ParseMessage converts a raw Milky payload into a protocol event using the
// adapter's event converter, returning the converter's result unchanged.
func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) {
	converter := a.eventConverter
	return converter.Convert(raw)
}
// Disconnect stops event delivery and releases the adapter's resources.
//
// It cancels the adapter context, which stops the event-handling goroutines
// at their next check. It then closes whichever of the SSE client, the
// WebSocket connection, the webhook server and the API client exist.
//
// Close failures are logged rather than returned, so Disconnect always
// returns nil.
func (a *Adapter) Disconnect() error {
	a.logger.Info("Disconnecting from Milky protocol server")
	a.cancel()

	if a.sseClient != nil {
		if err := a.sseClient.Close(); err != nil {
			a.logger.Error("Failed to close SSE client", zap.Error(err))
		}
	}

	if a.wsConn != nil && a.wsConn.Conn != nil {
		// Nothing ties a.ctx to the socket, and the reader goroutine may be
		// blocked in ReadMessage; closing the connection makes that read return
		// so the goroutine can exit.
		if err := a.wsConn.Conn.Close(); err != nil {
			a.logger.Error("Failed to close WebSocket connection", zap.Error(err))
		}
	}

	if a.webhookServer != nil {
		if err := a.webhookServer.Stop(); err != nil {
			a.logger.Error("Failed to stop webhook server", zap.Error(err))
		}
	}

	if a.apiClient != nil {
		if err := a.apiClient.Close(); err != nil {
			a.logger.Error("Failed to close API client", zap.Error(err))
		}
	}

	return nil
}
// GetProtocolName returns the protocol identifier of this adapter, "milky".
func (a *Adapter) GetProtocolName() string {
	const name = "milky"
	return name
}
// GetProtocolVersion returns the protocol version reported by this adapter,
// "1.0".
func (a *Adapter) GetProtocolVersion() string {
	const version = "1.0"
	return version
}
// GetSelfID returns the bot's own ID as it was given to NewAdapter or last set
// via SetSelfID.
func (a *Adapter) GetSelfID() string {
	id := a.selfID
	return id
}
// IsConnected reports whether the transport selected by Config.EventMode has
// been set up. It returns false for an unknown event mode and once Disconnect
// has cancelled the adapter context; before this check, the transport fields
// stayed non-nil after Disconnect and kept reporting true.
//
// NOTE(review): this only checks that the transport object exists (for
// WebSocket, also that it has an underlying connection); it does not probe
// liveness.
func (a *Adapter) IsConnected() bool {
	if a.ctx != nil && a.ctx.Err() != nil {
		return false
	}

	switch a.config.EventMode {
	case "sse":
		return a.sseClient != nil
	case "websocket":
		return a.wsConn != nil && a.wsConn.Conn != nil
	case "webhook":
		return a.webhookServer != nil
	default:
		return false
	}
}
// GetStats returns a snapshot of adapter information: protocol name, self ID,
// event mode and connection state. In WebSocket mode with an active
// connection it also includes the connection's remote address and type.
func (a *Adapter) GetStats() map[string]interface{} {
	mode := a.config.EventMode

	stats := make(map[string]interface{}, 6)
	stats["protocol"] = "milky"
	stats["self_id"] = a.selfID
	stats["event_mode"] = mode
	stats["connected"] = a.IsConnected()

	if mode != "websocket" || a.wsConn == nil {
		return stats
	}
	stats["remote_addr"] = a.wsConn.RemoteAddr
	stats["connection_type"] = a.wsConn.Type
	return stats
}
// CallAPI invokes a protocol API endpoint directly (exposed for use by the
// Bot) and returns the API client's response and error unchanged.
func (a *Adapter) CallAPI(ctx context.Context, endpoint string, params map[string]interface{}) (*APIResponse, error) {
	client := a.apiClient
	return client.Call(ctx, endpoint, params)
}
// GetConfig returns the adapter's configuration. The pointer itself is
// returned, so changes made through it are visible to the adapter.
func (a *Adapter) GetConfig() *Config {
	cfg := a.config
	return cfg
}
// SetSelfID replaces the bot's own ID. No locking is performed here.
func (a *Adapter) SetSelfID(selfID string) {
	id := selfID
	a.selfID = id
}
// GetSelfIDInt64 returns the bot's own ID parsed as a base-10 int64. The
// result and error of strconv.ParseInt are returned unchanged, for example
// when the ID is not a decimal integer or does not fit in 64 bits.
func (a *Adapter) GetSelfIDInt64() (int64, error) {
	const (
		base    = 10
		bitSize = 64
	)
	return strconv.ParseInt(a.selfID, base, bitSize)
}