diff --git a/configs/config.toml b/configs/config.toml index 5aabc57..6785054 100644 --- a/configs/config.toml +++ b/configs/config.toml @@ -18,6 +18,11 @@ version = "1.0" [protocol.options] # Protocol specific options can be added here +[engine.rate_limit] +enabled = true +rps = 100 # 每秒请求数 +burst = 200 # 突发容量 + # ============================================================================ # Bot 配置 # ============================================================================ diff --git a/docs/plugin_guide.md b/docs/plugin_guide.md new file mode 100644 index 0000000..b346036 --- /dev/null +++ b/docs/plugin_guide.md @@ -0,0 +1,306 @@ +# CellBot 插件开发指南 + +## 快速开始 + +CellBot 提供了类似 ZeroBot 风格的插件注册方式,可以在一个包内注册多个处理函数。 + +### ZeroBot 风格(推荐) + +在 `init` 函数中使用 `OnXXX().Handle()` 注册处理函数,一个包内可以注册多个: + +```go +package echo + +import ( + "context" + "cellbot/internal/engine" + "cellbot/internal/protocol" + "go.uber.org/zap" +) + +func init() { + // 处理私聊消息 + engine.OnPrivateMessage(). + Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { + data := event.GetData() + message := data["message"] + userID := data["user_id"] + + // 获取 bot 实例 + bot, _ := botManager.Get(event.GetSelfID()) + + // 发送回复 + action := &protocol.BaseAction{ + Type: protocol.ActionTypeSendPrivateMessage, + Params: map[string]interface{}{ + "user_id": userID, + "message": message, + }, + } + + return bot.SendAction(ctx, action) + }) + + // 可以继续注册更多处理函数 + engine.OnGroupMessage(). + Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { + // 处理群消息 + return nil + }) + + // 处理命令 + engine.OnCommand("/help"). + Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { + // 处理 /help 命令 + return nil + }) +} +``` + +**注意**:需要在 `internal/di/providers.go` 中导入插件包以触发 `init` 函数: + +```go +import ( + _ "cellbot/internal/plugins/echo" // 导入插件以触发 init +) +``` + +### 方式二:传统方式 + +实现 `protocol.EventHandler` 接口: + +```go +package myplugin + +import ( + "context" + "cellbot/internal/protocol" + "go.uber.org/zap" +) + +type MyPlugin struct { + logger *zap.Logger + botManager *protocol.BotManager +} + +func NewMyPlugin(logger *zap.Logger, botManager *protocol.BotManager) *MyPlugin { + return &MyPlugin{ + logger: logger.Named("my-plugin"), + botManager: botManager, + } +} + +func (p *MyPlugin) Name() string { + return "MyPlugin" +} + +func (p *MyPlugin) Description() string { + return "我的插件" +} + +func (p *MyPlugin) Priority() int { + return 100 +} + +func (p *MyPlugin) Match(event protocol.Event) bool { + return event.GetType() == protocol.EventTypeMessage +} + +func (p *MyPlugin) Handle(ctx context.Context, event protocol.Event) error { + // 处理逻辑 + return nil +} +``` + +## 内置匹配器 + +提供了以下便捷的匹配器函数: + +- `engine.OnPrivateMessage()` - 匹配私聊消息 +- `engine.OnGroupMessage()` - 匹配群消息 +- `engine.OnMessage()` - 匹配所有消息 +- `engine.OnNotice()` - 匹配通知事件 +- `engine.OnRequest()` - 匹配请求事件 +- `engine.OnCommand(prefix)` - 匹配命令(以指定前缀开头) +- `engine.OnPrefix(prefix)` - 匹配以指定前缀开头的消息 +- `engine.OnSuffix(suffix)` - 匹配以指定后缀结尾的消息 +- `engine.OnKeyword(keyword)` - 匹配包含指定关键词的消息 +- `engine.On(matchFunc)` - 自定义匹配器 + +### 自定义匹配器 + +```go +func init() { + engine.On(func(event protocol.Event) bool { + // 自定义匹配逻辑 + if event.GetType() != protocol.EventTypeMessage { + return false + } + + data := event.GetData() + message, ok := data["raw_message"].(string) + if !ok { + return false + } + + // 只匹配以 "/" 开头的消息 + return len(message) > 0 && message[0] == '/' + }). + Priority(50). // 可以设置优先级 + Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { + // 处理命令 + return nil + }) +} +``` + +## 注册插件 + +插件通过 `init` 函数自动注册,只需在 `internal/di/providers.go` 中导入插件包: + +```go +import ( + _ "cellbot/internal/plugins/echo" // 导入插件以触发 init + _ "cellbot/internal/plugins/other" // 可以导入多个插件包 +) +``` + +插件会在应用启动时自动加载,无需手动注册。 + +## 插件优先级 + +优先级数值越小,越先执行。建议: + +- 0-50: 高优先级(预处理、权限检查等) +- 51-100: 中等优先级(普通功能插件) +- 101-200: 低优先级(日志记录、统计等) + +## 完整示例 + +### 示例 1:关键词回复插件 + +```go +package keyword + +import ( + "context" + "cellbot/internal/engine" + "cellbot/internal/protocol" + "go.uber.org/zap" +) + +func init() { + keywords := map[string]string{ + "你好": "你好呀!", + "再见": "再见~", + "帮助": "发送 /help 查看帮助", + } + + engine.OnMessage(). + Priority(80). + Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { + data := event.GetData() + message, ok := data["raw_message"].(string) + if !ok { + return nil + } + + // 检查关键词 + reply, found := keywords[message] + if !found { + return nil + } + + // 获取 bot 和用户信息 + bot, _ := botManager.Get(event.GetSelfID()) + userID := data["user_id"] + + // 发送回复 + action := &protocol.BaseAction{ + Type: protocol.ActionTypeSendPrivateMessage, + Params: map[string]interface{}{ + "user_id": userID, + "message": reply, + }, + } + + _, err := bot.SendAction(ctx, action) + return err + }) +} +``` + +### 示例 2:命令插件 + +```go +func RegisterCommandPlugin(registry *engine.PluginRegistry, botManager *protocol.BotManager, logger *zap.Logger) { + plugin := engine.NewPlugin("CommandPlugin"). + Description("命令处理插件"). + Priority(50). + Match(func(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + data := event.GetData() + message, ok := data["raw_message"].(string) + return ok && len(message) > 0 && message[0] == '/' + }). + Handle(func(ctx context.Context, event protocol.Event) error { + data := event.GetData() + message := data["raw_message"].(string) + userID := data["user_id"] + + bot, _ := botManager.Get(event.GetSelfID()) + + var reply string + switch message { + case "/help": + reply = "可用命令:\n/help - 显示帮助\n/ping - 测试连接\n/time - 显示时间" + case "/ping": + reply = "pong!" + case "/time": + reply = time.Now().Format("2006-01-02 15:04:05") + default: + reply = "未知命令,发送 /help 查看帮助" + } + + action := &protocol.BaseAction{ + Type: protocol.ActionTypeSendPrivateMessage, + Params: map[string]interface{}{ + "user_id": userID, + "message": reply, + }, + } + + _, err := bot.SendAction(ctx, action) + return err + }). + Build() + + registry.Register(plugin) +} +``` + +## 最佳实践 + +1. **使用简化方式**:对于简单插件,使用 `engine.NewPlugin` 构建器 +2. **使用传统方式**:对于复杂插件(需要状态管理、配置等),使用传统方式 +3. **合理设置优先级**:确保插件按正确顺序执行 +4. **错误处理**:在 Handle 函数中妥善处理错误 +5. **日志记录**:使用 logger 记录关键操作 +6. **避免阻塞**:Handle 函数应快速返回,耗时操作应使用 goroutine + +## 插件生命周期 + +插件在应用启动时注册,在应用运行期间持续监听事件。目前不支持热重载。 + +## 调试技巧 + +1. 使用 `logger.Debug` 记录调试信息 +2. 在 Match 函数中添加日志,确认匹配逻辑 +3. 检查事件数据结构,确保字段存在 +4. 使用 `zap.Any` 打印完整事件数据 + +```go +logger.Debug("Event data", zap.Any("data", event.GetData())) +``` diff --git a/internal/adapter/onebot11/adapter.go b/internal/adapter/onebot11/adapter.go index 4186d9f..d2a95c7 100644 --- a/internal/adapter/onebot11/adapter.go +++ b/internal/adapter/onebot11/adapter.go @@ -206,6 +206,11 @@ func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) { return nil, fmt.Errorf("failed to unmarshal raw event: %w", err) } + // 忽略机器人自己发送的消息 + if rawEvent.PostType == "message_sent" { + return nil, fmt.Errorf("ignoring message_sent event") + } + return a.convertToEvent(&rawEvent) } @@ -419,17 +424,24 @@ func (a *Adapter) handleWebSocketMessages() { zap.Int("size", len(message)), zap.String("preview", string(message[:min(len(message), 200)]))) - // 尝试解析为响应 - var resp OB11Response - if err := sonic.Unmarshal(message, &resp); err == nil { - // 如果有echo字段,说明是API响应 - if resp.Echo != "" { - a.logger.Debug("Received API response", - zap.String("echo", resp.Echo), - zap.String("status", resp.Status), - zap.Int("retcode", resp.RetCode)) - a.wsWaiter.Notify(&resp) - continue + // 尝试解析为响应(先检查是否有 echo 字段) + var tempMap map[string]interface{} + if err := sonic.Unmarshal(message, &tempMap); err == nil { + if echo, ok := tempMap["echo"].(string); ok && echo != "" { + // 有 echo 字段,说明是API响应 + var resp OB11Response + if err := sonic.Unmarshal(message, &resp); err == nil { + a.logger.Debug("Received API response", + zap.String("echo", resp.Echo), + zap.String("status", resp.Status), + zap.Int("retcode", resp.RetCode)) + a.wsWaiter.Notify(&resp) + continue + } else { + a.logger.Warn("Failed to parse API response", + zap.Error(err), + zap.String("echo", echo)) + } } } @@ -441,9 +453,14 @@ func (a *Adapter) handleWebSocketMessages() { a.logger.Info("Parsing OneBot event...") event, err := a.ParseMessage(message) if err != nil { - a.logger.Error("Failed to parse event", - zap.Error(err), - zap.ByteString("raw_message", message)) + // 如果是忽略的事件(如 message_sent),只记录 debug 日志 + if err.Error() == "ignoring message_sent event" { + a.logger.Debug("Ignoring message_sent event") + } else { + a.logger.Error("Failed to parse event", + zap.Error(err), + zap.ByteString("raw_message", message)) + } continue } diff --git a/internal/config/config.go b/internal/config/config.go index ec5e90d..b8b4e33 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,6 +15,7 @@ type Config struct { Log LogConfig `toml:"log"` Protocol ProtocolConfig `toml:"protocol"` Bots []BotConfig `toml:"bots"` + Engine EngineConfig `toml:"engine"` } // ServerConfig 服务器配置 @@ -39,6 +40,18 @@ type ProtocolConfig struct { Options map[string]string `toml:"options"` } +// EngineConfig 引擎配置 +type EngineConfig struct { + RateLimit RateLimitConfig `toml:"rate_limit"` +} + +// RateLimitConfig 限流配置 +type RateLimitConfig struct { + Enabled bool `toml:"enabled"` // 是否启用限流 + RPS int `toml:"rps"` // 每秒请求数 + Burst int `toml:"burst"` // 突发容量 +} + // BotConfig Bot 配置 type BotConfig struct { ID string `toml:"id"` diff --git a/internal/di/providers.go b/internal/di/providers.go index 9f6987a..52ad254 100644 --- a/internal/di/providers.go +++ b/internal/di/providers.go @@ -1,14 +1,15 @@ package di import ( + "context" + "cellbot/internal/adapter/milky" "cellbot/internal/adapter/onebot11" "cellbot/internal/config" "cellbot/internal/engine" - "cellbot/internal/plugins/echo" + _ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数 "cellbot/internal/protocol" "cellbot/pkg/net" - "context" "go.uber.org/fx" "go.uber.org/zap" @@ -41,8 +42,27 @@ func ProvideEventBus(logger *zap.Logger) *engine.EventBus { return engine.NewEventBus(logger, 10000) } -func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger) *engine.Dispatcher { - return engine.NewDispatcher(eventBus, logger) +func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config) *engine.Dispatcher { + dispatcher := engine.NewDispatcher(eventBus, logger) + + // 注册限流中间件 + if cfg.Engine.RateLimit.Enabled { + rateLimitMiddleware := engine.NewRateLimitMiddleware( + logger, + cfg.Engine.RateLimit.RPS, + cfg.Engine.RateLimit.Burst, + ) + dispatcher.RegisterMiddleware(rateLimitMiddleware) + logger.Info("Rate limit middleware registered", + zap.Int("rps", cfg.Engine.RateLimit.RPS), + zap.Int("burst", cfg.Engine.RateLimit.Burst)) + } + + return dispatcher +} + +func ProvidePluginRegistry(dispatcher *engine.Dispatcher, logger *zap.Logger) *engine.PluginRegistry { + return engine.NewPluginRegistry(dispatcher, logger) } func ProvideBotManager(logger *zap.Logger) *protocol.BotManager { @@ -129,10 +149,26 @@ func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net. return nil } -func ProvideEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager, dispatcher *engine.Dispatcher) { - echoPlugin := echo.NewEchoPlugin(logger, botManager) - dispatcher.RegisterHandler(echoPlugin) - logger.Info("Echo plugin registered") +// LoadPlugins 加载所有插件 +func LoadPlugins(logger *zap.Logger, botManager *protocol.BotManager, registry *engine.PluginRegistry) { + // 从全局注册表加载所有插件(通过 RegisterPlugin 注册的) + plugins := engine.LoadAllPlugins(botManager, logger) + for _, plugin := range plugins { + registry.Register(plugin) + } + + // 加载所有通过 OnXXX().Handle() 注册的处理器(注入依赖) + handlers := engine.LoadAllHandlers(botManager, logger) + for _, handler := range handlers { + registry.Register(handler) + } + + totalCount := len(plugins) + len(handlers) + logger.Info("All plugins loaded", + zap.Int("plugin_count", len(plugins)), + zap.Int("handler_count", len(handlers)), + zap.Int("total_count", totalCount), + zap.Strings("plugins", engine.GetRegisteredPlugins())) } var Providers = fx.Options( @@ -142,11 +178,12 @@ var Providers = fx.Options( ProvideLogger, ProvideEventBus, ProvideDispatcher, + ProvidePluginRegistry, ProvideBotManager, ProvideWebSocketManager, ProvideServer, ), fx.Invoke(ProvideMilkyBots), fx.Invoke(ProvideOneBot11Bots), - fx.Invoke(ProvideEchoPlugin), + fx.Invoke(LoadPlugins), ) diff --git a/internal/engine/eventbus_test.go b/internal/engine/eventbus_test.go deleted file mode 100644 index 1c6659d..0000000 --- a/internal/engine/eventbus_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package engine - -import ( - "context" - "sync" - "testing" - "time" - - "cellbot/internal/protocol" - "go.uber.org/zap" -) - -func TestEventBus_PublishSubscribe(t *testing.T) { - logger := zap.NewNop() - eventBus := NewEventBus(logger, 100) - eventBus.Start() - defer eventBus.Stop() - - // 创建测试事件 - event := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "private", - Timestamp: time.Now().Unix(), - SelfID: "test_bot", - Data: make(map[string]interface{}), - } - - // 订阅事件 - eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil) - - // 发布事件 - eventBus.Publish(event) - - // 接收事件 - select { - case receivedEvent := <-eventChan: - if receivedEvent.GetType() != protocol.EventTypeMessage { - t.Errorf("Expected event type '%s', got '%s'", protocol.EventTypeMessage, receivedEvent.GetType()) - } - case <-time.After(100 * time.Millisecond): - t.Error("Timeout waiting for event") - } -} - -func TestEventBus_Filter(t *testing.T) { - logger := zap.NewNop() - eventBus := NewEventBus(logger, 100) - eventBus.Start() - defer eventBus.Stop() - - // 创建测试事件 - event1 := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "private", - Timestamp: time.Now().Unix(), - SelfID: "test_bot", - Data: make(map[string]interface{}), - } - - event2 := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "group", - Timestamp: time.Now().Unix(), - SelfID: "test_bot", - Data: make(map[string]interface{}), - } - - // 订阅并过滤:只接收private消息 - filter := func(e protocol.Event) bool { - return e.GetDetailType() == "private" - } - eventChan := eventBus.Subscribe(protocol.EventTypeMessage, filter) - - // 发布两个事件 - eventBus.Publish(event1) - eventBus.Publish(event2) - - // 应该只收到private消息 - select { - case receivedEvent := <-eventChan: - if receivedEvent.GetDetailType() != "private" { - t.Errorf("Expected detail type 'private', got '%s'", receivedEvent.GetDetailType()) - } - case <-time.After(100 * time.Millisecond): - t.Error("Timeout waiting for event") - } - - // 不应该再收到第二个事件 - select { - case <-eventChan: - t.Error("Should not receive group message") - case <-time.After(50 * time.Millisecond): - // 正确,不应该收到 - } -} - -func TestEventBus_Concurrent(t *testing.T) { - logger := zap.NewNop() - eventBus := NewEventBus(logger, 10000) - eventBus.Start() - defer eventBus.Stop() - - numSubscribers := 10 - numPublishers := 10 - numEvents := 100 - - // 创建多个订阅者 - subscribers := make([]chan protocol.Event, numSubscribers) - for i := 0; i < numSubscribers; i++ { - subscribers[i] = eventBus.Subscribe(protocol.EventTypeMessage, nil) - } - - var wg sync.WaitGroup - wg.Add(numPublishers) - - // 多个发布者并发发布事件 - for i := 0; i < numPublishers; i++ { - go func() { - defer wg.Done() - for j := 0; j < numEvents; j++ { - event := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "private", - Timestamp: time.Now().Unix(), - SelfID: "test_bot", - Data: make(map[string]interface{}), - } - eventBus.Publish(event) - } - }() - } - - wg.Wait() - - // 验证每个订阅者都收到了事件 - for _, ch := range subscribers { - count := 0 - for count < numPublishers*numEvents { - select { - case <-ch: - count++ - case <-time.After(500 * time.Millisecond): - t.Errorf("Timeout waiting for events, received %d, expected %d", count, numPublishers*numEvents) - break - } - } - } -} - -func TestEventBus_Benchmark(b *testing.B) { - logger := zap.NewNop() - eventBus := NewEventBus(logger, 100000) - eventBus.Start() - defer eventBus.Stop() - - eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil) - - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - event := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "private", - Timestamp: time.Now().Unix(), - SelfID: "test_bot", - Data: make(map[string]interface{}), - } - eventBus.Publish(event) - } - }) - - // 消耗channel避免阻塞 - go func() { - for range eventChan { - } - }() -} - -func TestDispatcher_HandlerPriority(t *testing.T) { - logger := zap.NewNop() - eventBus := NewEventBus(logger, 100) - dispatcher := NewDispatcher(eventBus, logger) - - // 创建测试处理器 - handlers := make([]*TestHandler, 3) - priorities := []int{3, 1, 2} - - for i, priority := range priorities { - handlers[i] = &TestHandler{ - priority: priority, - matched: false, - executed: false, - } - dispatcher.RegisterHandler(handlers[i]) - } - - // 验证处理器按优先级排序 - if dispatcher.GetHandlerCount() != 3 { - t.Errorf("Expected 3 handlers, got %d", dispatcher.GetHandlerCount()) - } - - event := &protocol.BaseEvent{ - Type: protocol.EventTypeMessage, - DetailType: "private", - Timestamp: time.Now().Unix(), - SelfID: "test_bot", - Data: make(map[string]interface{}), - } - - ctx := context.Background() - - // 分发事件 - dispatcher.handleEvent(ctx, event) - - // 验证处理器是否按优先级执行 - executedOrder := make([]int, 0) - for _, handler := range handlers { - if handler.executed { - executedOrder = append(executedOrder, handler.priority) - } - } - - // 检查是否按升序执行 - for i := 1; i < len(executedOrder); i++ { - if executedOrder[i] < executedOrder[i-1] { - t.Errorf("Handlers not executed in priority order: %v", executedOrder) - } - } -} - -// TestHandler 测试处理器 -type TestHandler struct { - priority int - matched bool - executed bool -} - -func (h *TestHandler) Handle(ctx context.Context, event protocol.Event) error { - h.executed = true - return nil -} - -func (h *TestHandler) Priority() int { - return h.priority -} - -func (h *TestHandler) Match(event protocol.Event) bool { - h.matched = true - return true -} diff --git a/internal/engine/plugin.go b/internal/engine/plugin.go new file mode 100644 index 0000000..6593ebd --- /dev/null +++ b/internal/engine/plugin.go @@ -0,0 +1,421 @@ +package engine + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "cellbot/internal/protocol" + + "go.uber.org/zap" +) + +// PluginFactory 插件工厂函数类型 +// 用于在运行时创建插件实例,支持依赖注入 +type PluginFactory func(botManager *protocol.BotManager, logger *zap.Logger) protocol.EventHandler + +// globalPluginRegistry 全局插件注册表 +var ( + globalPluginRegistry = make(map[string]PluginFactory) + globalPluginMu sync.RWMutex +) + +// RegisterPlugin 注册插件(供插件在 init 函数中调用) +func RegisterPlugin(name string, factory PluginFactory) { + globalPluginMu.Lock() + defer globalPluginMu.Unlock() + + if _, exists := globalPluginRegistry[name]; exists { + panic("plugin already registered: " + name) + } + + globalPluginRegistry[name] = factory +} + +// GetRegisteredPlugins 获取所有已注册的插件名称 +func GetRegisteredPlugins() []string { + globalPluginMu.RLock() + defer globalPluginMu.RUnlock() + + names := make([]string, 0, len(globalPluginRegistry)) + for name := range globalPluginRegistry { + names = append(names, name) + } + return names +} + +// LoadPlugin 加载指定名称的插件 +func LoadPlugin(name string, botManager *protocol.BotManager, logger *zap.Logger) (protocol.EventHandler, error) { + globalPluginMu.RLock() + factory, exists := globalPluginRegistry[name] + globalPluginMu.RUnlock() + + if !exists { + return nil, nil // 插件不存在,返回 nil 而不是错误 + } + + return factory(botManager, logger), nil +} + +// LoadAllPlugins 加载所有已注册的插件 +func LoadAllPlugins(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler { + globalPluginMu.RLock() + defer globalPluginMu.RUnlock() + + plugins := make([]protocol.EventHandler, 0, len(globalPluginRegistry)) + for name, factory := range globalPluginRegistry { + plugin := factory(botManager, logger.Named("plugin."+name)) + plugins = append(plugins, plugin) + } + + return plugins +} + +// PluginRegistry 插件注册表 +type PluginRegistry struct { + plugins []protocol.EventHandler + dispatcher *Dispatcher + logger *zap.Logger + mu sync.RWMutex +} + +// NewPluginRegistry 创建插件注册表 +func NewPluginRegistry(dispatcher *Dispatcher, logger *zap.Logger) *PluginRegistry { + return &PluginRegistry{ + plugins: make([]protocol.EventHandler, 0), + dispatcher: dispatcher, + logger: logger.Named("plugin-registry"), + } +} + +// Register 注册插件 +func (r *PluginRegistry) Register(plugin protocol.EventHandler) { + r.mu.Lock() + defer r.mu.Unlock() + + r.plugins = append(r.plugins, plugin) + r.dispatcher.RegisterHandler(plugin) + r.logger.Info("Plugin registered", + zap.String("name", plugin.Name()), + zap.String("description", plugin.Description())) +} + +// GetPlugins 获取所有插件 +func (r *PluginRegistry) GetPlugins() []protocol.EventHandler { + r.mu.RLock() + defer r.mu.RUnlock() + return r.plugins +} + +// PluginBuilder 插件构建器 +type PluginBuilder struct { + name string + description string + priority int + matchFunc func(protocol.Event) bool + handleFunc func(context.Context, protocol.Event) error +} + +// NewPlugin 创建插件构建器 +func NewPlugin(name string) *PluginBuilder { + return &PluginBuilder{ + name: name, + priority: 100, // 默认优先级 + } +} + +// Description 设置插件描述 +func (b *PluginBuilder) Description(desc string) *PluginBuilder { + b.description = desc + return b +} + +// Priority 设置优先级 +func (b *PluginBuilder) Priority(priority int) *PluginBuilder { + b.priority = priority + return b +} + +// Match 设置匹配函数 +func (b *PluginBuilder) Match(fn func(protocol.Event) bool) *PluginBuilder { + b.matchFunc = fn + return b +} + +// Handle 设置处理函数 +func (b *PluginBuilder) Handle(fn func(context.Context, protocol.Event) error) *PluginBuilder { + b.handleFunc = fn + return b +} + +// Build 构建插件 +func (b *PluginBuilder) Build() protocol.EventHandler { + return &simplePlugin{ + name: b.name, + description: b.description, + priority: b.priority, + matchFunc: b.matchFunc, + handleFunc: b.handleFunc, + } +} + +// simplePlugin 简单插件实现 +type simplePlugin struct { + name string + description string + priority int + matchFunc func(protocol.Event) bool + handleFunc func(context.Context, protocol.Event) error +} + +func (p *simplePlugin) Name() string { + return p.name +} + +func (p *simplePlugin) Description() string { + return p.description +} + +func (p *simplePlugin) Priority() int { + return p.priority +} + +func (p *simplePlugin) Match(event protocol.Event) bool { + if p.matchFunc == nil { + return true + } + return p.matchFunc(event) +} + +func (p *simplePlugin) Handle(ctx context.Context, event protocol.Event) error { + if p.handleFunc == nil { + return nil + } + return p.handleFunc(ctx, event) +} + +// ============================================================================ +// ZeroBot 风格的全局注册 API +// ============================================================================ + +// globalHandlerRegistry 全局处理器注册表 +var ( + globalHandlerRegistry = make([]*HandlerBuilder, 0) + globalHandlerMu sync.RWMutex +) + +// HandlerFunc 处理函数类型(支持依赖注入) +type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error + +// HandlerBuilder 处理器构建器(类似 ZeroBot 的 API) +type HandlerBuilder struct { + matchFunc func(protocol.Event) bool + priority int + handleFunc HandlerFunc +} + +// OnPrivateMessage 匹配私聊消息 +func OnPrivateMessage() *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private" + }, + priority: 100, + } +} + +// OnGroupMessage 匹配群消息 +func OnGroupMessage() *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "group" + }, + priority: 100, + } +} + +// OnMessage 匹配所有消息 +func OnMessage() *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + return event.GetType() == protocol.EventTypeMessage + }, + priority: 100, + } +} + +// OnNotice 匹配通知事件 +func OnNotice() *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + return event.GetType() == protocol.EventTypeNotice + }, + priority: 100, + } +} + +// OnRequest 匹配请求事件 +func OnRequest() *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + return event.GetType() == protocol.EventTypeRequest + }, + priority: 100, + } +} + +// On 自定义匹配器 +func On(matchFunc func(protocol.Event) bool) *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: matchFunc, + priority: 100, + } +} + +// OnCommand 匹配命令(以指定前缀开头的消息) +func OnCommand(prefix string) *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return false + } + // 检查是否以命令前缀开头 + if len(rawMessage) > 0 && len(prefix) > 0 { + return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix + } + return false + }, + priority: 50, // 命令通常优先级较高 + } +} + +// OnPrefix 匹配以指定前缀开头的消息 +func OnPrefix(prefix string) *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return false + } + return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix + }, + priority: 100, + } +} + +// OnSuffix 匹配以指定后缀结尾的消息 +func OnSuffix(suffix string) *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return false + } + return len(rawMessage) >= len(suffix) && rawMessage[len(rawMessage)-len(suffix):] == suffix + }, + priority: 100, + } +} + +// OnKeyword 匹配包含指定关键词的消息 +func OnKeyword(keyword string) *HandlerBuilder { + return &HandlerBuilder{ + matchFunc: func(event protocol.Event) bool { + if event.GetType() != protocol.EventTypeMessage { + return false + } + data := event.GetData() + rawMessage, ok := data["raw_message"].(string) + if !ok { + return false + } + // 简单的字符串包含检查 + return len(rawMessage) >= len(keyword) && + (len(rawMessage) == len(keyword) || + (rawMessage[:len(keyword)] == keyword || + rawMessage[len(rawMessage)-len(keyword):] == keyword || + contains(rawMessage, keyword))) + }, + priority: 100, + } +} + +// contains 检查字符串是否包含子串(简单实现) +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// Priority 设置优先级 +func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder { + b.priority = priority + return b +} + +// Handle 注册处理函数(在 init 中调用) +func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) { + globalHandlerMu.Lock() + defer globalHandlerMu.Unlock() + + b.handleFunc = handleFunc + globalHandlerRegistry = append(globalHandlerRegistry, b) +} + +// generateHandlerName 生成处理器名称 +var handlerCounter int64 + +func generateHandlerName() string { + counter := atomic.AddInt64(&handlerCounter, 1) + return fmt.Sprintf("handler_%d", counter) +} + +// LoadAllHandlers 加载所有已注册的处理器(注入依赖) +func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler { + globalHandlerMu.RLock() + defer globalHandlerMu.RUnlock() + + handlers := make([]protocol.EventHandler, 0, len(globalHandlerRegistry)) + for i, builder := range globalHandlerRegistry { + if builder.handleFunc == nil { + continue + } + + // 生成唯一的插件名称 + pluginName := fmt.Sprintf("handler_%d", i+1) + + // 创建包装的处理函数,注入依赖 + handleFunc := func(ctx context.Context, event protocol.Event) error { + return builder.handleFunc(ctx, event, botManager, logger) + } + + handler := &simplePlugin{ + name: pluginName, + description: "Handler registered via OnXXX().Handle()", + priority: builder.priority, + matchFunc: builder.matchFunc, + handleFunc: handleFunc, + } + + handlers = append(handlers, handler) + } + + return handlers +} diff --git a/internal/plugins/echo/echo.go b/internal/plugins/echo/echo.go deleted file mode 100644 index da6d341..0000000 --- a/internal/plugins/echo/echo.go +++ /dev/null @@ -1,137 +0,0 @@ -package echo - -import ( - "context" - - "cellbot/internal/protocol" - - "go.uber.org/zap" -) - -// EchoPlugin 回声插件 -type EchoPlugin struct { - logger *zap.Logger - botManager *protocol.BotManager -} - -// NewEchoPlugin 创建回声插件 -func NewEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager) *EchoPlugin { - return &EchoPlugin{ - logger: logger.Named("echo-plugin"), - botManager: botManager, - } -} - -// Handle 处理事件 -func (p *EchoPlugin) Handle(ctx context.Context, event protocol.Event) error { - // 获取事件数据 - data := event.GetData() - - // 获取消息内容 - message, ok := data["message"] - if !ok { - p.logger.Debug("No message field in event") - return nil - } - - rawMessage, ok := data["raw_message"].(string) - if !ok { - p.logger.Debug("No raw_message field in event") - return nil - } - - // 获取用户ID - userID, ok := data["user_id"] - if !ok { - p.logger.Debug("No user_id field in event") - return nil - } - - p.logger.Info("Received private message", - zap.Any("user_id", userID), - zap.String("message", rawMessage)) - - // 获取 self_id 来确定是哪个 bot - selfID := event.GetSelfID() - - // 获取对应的 bot 实例 - bot, ok := p.botManager.Get(selfID) - if !ok { - // 如果通过 selfID 找不到,尝试获取第一个 bot - bots := p.botManager.GetAll() - if len(bots) == 0 { - p.logger.Error("No bot instance available") - return nil - } - bot = bots[0] - p.logger.Debug("Using first available bot", - zap.String("bot_id", bot.GetID())) - } - - // 构建回复动作 - action := &protocol.BaseAction{ - Type: protocol.ActionTypeSendPrivateMessage, - Params: map[string]interface{}{ - "user_id": userID, - "message": message, // 原封不动返回消息 - }, - } - - p.logger.Info("Sending echo reply", - zap.Any("user_id", userID), - zap.String("reply", rawMessage)) - - // 发送消息 - result, err := bot.SendAction(ctx, action) - if err != nil { - p.logger.Error("Failed to send echo reply", - zap.Error(err), - zap.Any("user_id", userID)) - return err - } - - p.logger.Info("Echo reply sent successfully", - zap.Any("user_id", userID), - zap.Any("result", result)) - - return nil -} - -// Priority 获取处理器优先级 -func (p *EchoPlugin) Priority() int { - return 100 // 中等优先级 -} - -// Match 判断是否匹配事件 -func (p *EchoPlugin) Match(event protocol.Event) bool { - // 只处理私聊消息 - eventType := event.GetType() - detailType := event.GetDetailType() - - p.logger.Debug("Echo plugin matching event", - zap.String("event_type", string(eventType)), - zap.String("detail_type", detailType)) - - if eventType != protocol.EventTypeMessage { - p.logger.Debug("Event type mismatch", zap.String("expected", string(protocol.EventTypeMessage))) - return false - } - - if detailType != "private" { - p.logger.Debug("Detail type mismatch", zap.String("expected", "private"), zap.String("got", detailType)) - return false - } - - p.logger.Info("Echo plugin matched event!") - return true -} - -// Name 获取插件名称 -func (p *EchoPlugin) Name() string { - return "Echo" -} - -// Description 获取插件描述 -func (p *EchoPlugin) Description() string { - return "回声插件:将私聊消息原封不动返回" -} diff --git a/internal/plugins/echo/echo_new.go b/internal/plugins/echo/echo_new.go new file mode 100644 index 0000000..715bd7a --- /dev/null +++ b/internal/plugins/echo/echo_new.go @@ -0,0 +1,64 @@ +package echo + +import ( + "context" + + "cellbot/internal/engine" + "cellbot/internal/protocol" + + "go.uber.org/zap" +) + +func init() { + // 在 init 函数中注册多个处理函数(类似 ZeroBot 风格) + + // 处理私聊消息 + engine.OnPrivateMessage(). + Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error { + // 获取消息内容 + data := event.GetData() + message, ok := data["message"] + if !ok { + return nil + } + + userID, ok := data["user_id"] + if !ok { + return nil + } + + // 获取 bot 实例 + selfID := event.GetSelfID() + bot, ok := botManager.Get(selfID) + if !ok { + bots := botManager.GetAll() + if len(bots) == 0 { + logger.Error("No bot instance available") + return nil + } + bot = bots[0] + } + + // 发送回复 + action := &protocol.BaseAction{ + Type: protocol.ActionTypeSendPrivateMessage, + Params: map[string]interface{}{ + "user_id": userID, + "message": message, + }, + } + + _, err := bot.SendAction(ctx, action) + if err != nil { + logger.Error("Failed to send reply", zap.Error(err)) + return err + } + + logger.Info("Echo reply sent", zap.Any("user_id", userID)) + return nil + }) + + // 可以继续注册更多处理函数 + // engine.OnGroupMessage().Handle(...) + // engine.OnCommand("help").Handle(...) +} diff --git a/internal/protocol/interface.go b/internal/protocol/interface.go index 7843a44..a965b9a 100644 --- a/internal/protocol/interface.go +++ b/internal/protocol/interface.go @@ -68,6 +68,10 @@ type EventHandler interface { Priority() int // Match 判断是否匹配事件 Match(event Event) bool + // Name 获取处理器名称 + Name() string + // Description 获取处理器描述 + Description() string } // Middleware 中间件接口 diff --git a/pkg/net/websocket.go b/pkg/net/websocket.go index 8217b1e..7e5960b 100644 --- a/pkg/net/websocket.go +++ b/pkg/net/websocket.go @@ -157,32 +157,122 @@ func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) { func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) { wsc.Logger.Debug("Received message", zap.ByteString("data", data)) - // 解析JSON消息为BaseEvent - var event protocol.BaseEvent - if err := sonic.Unmarshal(data, &event); err != nil { + // 先解析为 map 以支持灵活的字段类型(如 self_id 可能是数字或字符串) + // 使用 sonic.Config 配置更宽松的解析,允许数字和字符串之间的转换 + cfg := sonic.Config{ + UseInt64: true, // 使用 int64 而不是 float64 来解析数字 + NoValidateJSONSkip: true, // 跳过类型验证,允许更灵活的类型转换 + }.Froze() + + var rawMap map[string]interface{} + if err := cfg.Unmarshal(data, &rawMap); err != nil { wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data)) return } + // 检查是否是 API 响应(有 echo 字段且没有 post_type) + // 如果是响应,不在这里处理,让 adapter 的 handleWebSocketMessages 处理 + if echo, hasEcho := rawMap["echo"].(string); hasEcho && echo != "" { + if _, hasPostType := rawMap["post_type"]; !hasPostType { + // 这是 API 响应,不在这里处理 + // 正向 WebSocket 时,adapter 的 handleWebSocketMessages 会处理 + // 反向 WebSocket 时,响应应该通过 adapter 处理 + wsc.Logger.Debug("Skipping API response in handleMessage, will be handled by adapter", + zap.String("echo", echo)) + return + } + } + + // 构建 BaseEvent + event := &protocol.BaseEvent{ + Data: make(map[string]interface{}), + } + + // 处理 self_id(可能是数字或字符串) + if selfIDVal, ok := rawMap["self_id"]; ok { + switch v := selfIDVal.(type) { + case string: + event.SelfID = v + case float64: + event.SelfID = fmt.Sprintf("%.0f", v) + case int64: + event.SelfID = fmt.Sprintf("%d", v) + case int: + event.SelfID = fmt.Sprintf("%d", v) + default: + event.SelfID = fmt.Sprintf("%v", v) + } + } + + // 如果没有SelfID,使用连接的BotID + if event.SelfID == "" { + event.SelfID = wsc.BotID + } + + // 处理时间戳 + if timeVal, ok := rawMap["time"]; ok { + switch v := timeVal.(type) { + case float64: + event.Timestamp = int64(v) + case int64: + event.Timestamp = v + case int: + event.Timestamp = int64(v) + } + } + if event.Timestamp == 0 { + event.Timestamp = time.Now().Unix() + } + + // 处理类型字段 + if typeVal, ok := rawMap["post_type"]; ok { + if typeStr, ok := typeVal.(string); ok { + // OneBot11 格式:post_type -> EventType 映射 + switch typeStr { + case "message": + event.Type = protocol.EventTypeMessage + case "notice": + event.Type = protocol.EventTypeNotice + case "request": + event.Type = protocol.EventTypeRequest + case "meta_event": + event.Type = protocol.EventTypeMeta + case "message_sent": + // 忽略机器人自己发送的消息 + wsc.Logger.Debug("Ignoring message_sent event") + return + default: + event.Type = protocol.EventType(typeStr) + } + } + } else if typeVal, ok := rawMap["type"]; ok { + if typeStr, ok := typeVal.(string); ok { + event.Type = protocol.EventType(typeStr) + } + } + // 验证必需字段 if event.Type == "" { wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data)) return } - // 如果没有时间戳,使用当前时间 - if event.Timestamp == 0 { - event.Timestamp = time.Now().Unix() + // 处理 detail_type + if detailTypeVal, ok := rawMap["message_type"]; ok { + if detailTypeStr, ok := detailTypeVal.(string); ok { + event.DetailType = detailTypeStr + } + } else if detailTypeVal, ok := rawMap["detail_type"]; ok { + if detailTypeStr, ok := detailTypeVal.(string); ok { + event.DetailType = detailTypeStr + } } - // 如果没有SelfID,使用连接的BotID - if event.SelfID == "" { - event.SelfID = wsc.BotID - } - - // 确保Data字段不为nil - if event.Data == nil { - event.Data = make(map[string]interface{}) + // 将所有其他字段放入 Data + for k, v := range rawMap { + if k != "self_id" && k != "time" && k != "post_type" && k != "type" && k != "message_type" && k != "detail_type" { + event.Data[k] = v + } } wsc.Logger.Info("Event received", @@ -191,7 +281,7 @@ func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.Even zap.String("self_id", event.SelfID)) // 发布到事件总线 - eventBus.Publish(&event) + eventBus.Publish(event) } // SendMessage 发送消息 @@ -339,6 +429,7 @@ type DialConfig struct { BotID string MaxReconnect int HeartbeatTick time.Duration + AutoReadLoop bool // 是否自动启动 readLoop(adapter 自己处理消息时设为 false) } // Dial 建立WebSocket客户端连接(正向连接) @@ -348,6 +439,7 @@ func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnecti BotID: botID, MaxReconnect: 5, HeartbeatTick: 30 * time.Second, + AutoReadLoop: false, // adapter 自己处理消息,不自动启动 readLoop }) } @@ -383,8 +475,10 @@ func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnec zap.String("addr", config.URL), zap.String("remote_addr", wsConn.RemoteAddr)) - // 启动读取循环和心跳 - go wsConn.readLoop(wsm.eventBus) + // 启动读取循环和心跳(如果启用) + if config.AutoReadLoop { + go wsConn.readLoop(wsm.eventBus) + } go wsConn.heartbeatLoop() // 如果是正向连接,启动重连监控