feat: add rate limiting and improve event handling

- Introduced rate limiting configuration in config.toml with options for enabling, requests per second (RPS), and burst capacity.
- Enhanced event handling in the OneBot11 adapter to ignore messages sent by the bot itself.
- Updated the dispatcher to register rate limit middleware based on configuration settings.
- Refactored WebSocket message handling to support flexible JSON parsing and improved event type detection.
- Removed deprecated echo plugin and associated tests to streamline the codebase.
This commit is contained in:
lafay
2026-01-05 01:00:38 +08:00
parent 44fe05ff62
commit d16261e6bd
11 changed files with 1001 additions and 427 deletions

View File

@@ -18,6 +18,11 @@ version = "1.0"
[protocol.options] [protocol.options]
# Protocol specific options can be added here # Protocol specific options can be added here
[engine.rate_limit]
enabled = true
rps = 100 # 每秒请求数
burst = 200 # 突发容量
# ============================================================================ # ============================================================================
# Bot 配置 # Bot 配置
# ============================================================================ # ============================================================================

306
docs/plugin_guide.md Normal file
View File

@@ -0,0 +1,306 @@
# CellBot 插件开发指南
## 快速开始
CellBot 提供了类似 ZeroBot 风格的插件注册方式,可以在一个包内注册多个处理函数。
### ZeroBot 风格(推荐)
`init` 函数中使用 `OnXXX().Handle()` 注册处理函数,一个包内可以注册多个:
```go
package echo
import (
"context"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
func init() {
// 处理私聊消息
engine.OnPrivateMessage().
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
data := event.GetData()
message := data["message"]
userID := data["user_id"]
// 获取 bot 实例
bot, _ := botManager.Get(event.GetSelfID())
// 发送回复
action := &protocol.BaseAction{
Type: protocol.ActionTypeSendPrivateMessage,
Params: map[string]interface{}{
"user_id": userID,
"message": message,
},
}
return bot.SendAction(ctx, action)
})
// 可以继续注册更多处理函数
engine.OnGroupMessage().
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
// 处理群消息
return nil
})
// 处理命令
engine.OnCommand("/help").
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
// 处理 /help 命令
return nil
})
}
```
**注意**:需要在 `internal/di/providers.go` 中导入插件包以触发 `init` 函数:
```go
import (
_ "cellbot/internal/plugins/echo" // 导入插件以触发 init
)
```
### 方式二:传统方式
实现 `protocol.EventHandler` 接口:
```go
package myplugin
import (
"context"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
type MyPlugin struct {
logger *zap.Logger
botManager *protocol.BotManager
}
func NewMyPlugin(logger *zap.Logger, botManager *protocol.BotManager) *MyPlugin {
return &MyPlugin{
logger: logger.Named("my-plugin"),
botManager: botManager,
}
}
func (p *MyPlugin) Name() string {
return "MyPlugin"
}
func (p *MyPlugin) Description() string {
return "我的插件"
}
func (p *MyPlugin) Priority() int {
return 100
}
func (p *MyPlugin) Match(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage
}
func (p *MyPlugin) Handle(ctx context.Context, event protocol.Event) error {
// 处理逻辑
return nil
}
```
## 内置匹配器
提供了以下便捷的匹配器函数:
- `engine.OnPrivateMessage()` - 匹配私聊消息
- `engine.OnGroupMessage()` - 匹配群消息
- `engine.OnMessage()` - 匹配所有消息
- `engine.OnNotice()` - 匹配通知事件
- `engine.OnRequest()` - 匹配请求事件
- `engine.OnCommand(prefix)` - 匹配命令(以指定前缀开头)
- `engine.OnPrefix(prefix)` - 匹配以指定前缀开头的消息
- `engine.OnSuffix(suffix)` - 匹配以指定后缀结尾的消息
- `engine.OnKeyword(keyword)` - 匹配包含指定关键词的消息
- `engine.On(matchFunc)` - 自定义匹配器
### 自定义匹配器
```go
func init() {
engine.On(func(event protocol.Event) bool {
// 自定义匹配逻辑
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
message, ok := data["raw_message"].(string)
if !ok {
return false
}
// 只匹配以 "/" 开头的消息
return len(message) > 0 && message[0] == '/'
}).
Priority(50). // 可以设置优先级
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
// 处理命令
return nil
})
}
```
## 注册插件
插件通过 `init` 函数自动注册,只需在 `internal/di/providers.go` 中导入插件包:
```go
import (
_ "cellbot/internal/plugins/echo" // 导入插件以触发 init
_ "cellbot/internal/plugins/other" // 可以导入多个插件包
)
```
插件会在应用启动时自动加载,无需手动注册。
## 插件优先级
优先级数值越小,越先执行。建议:
- 0-50: 高优先级(预处理、权限检查等)
- 51-100: 中等优先级(普通功能插件)
- 101-200: 低优先级(日志记录、统计等)
## 完整示例
### 示例 1关键词回复插件
```go
package keyword
import (
"context"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
func init() {
keywords := map[string]string{
"你好": "你好呀!",
"再见": "再见~",
"帮助": "发送 /help 查看帮助",
}
engine.OnMessage().
Priority(80).
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
data := event.GetData()
message, ok := data["raw_message"].(string)
if !ok {
return nil
}
// 检查关键词
reply, found := keywords[message]
if !found {
return nil
}
// 获取 bot 和用户信息
bot, _ := botManager.Get(event.GetSelfID())
userID := data["user_id"]
// 发送回复
action := &protocol.BaseAction{
Type: protocol.ActionTypeSendPrivateMessage,
Params: map[string]interface{}{
"user_id": userID,
"message": reply,
},
}
_, err := bot.SendAction(ctx, action)
return err
})
}
```
### 示例 2命令插件
```go
func RegisterCommandPlugin(registry *engine.PluginRegistry, botManager *protocol.BotManager, logger *zap.Logger) {
plugin := engine.NewPlugin("CommandPlugin").
Description("命令处理插件").
Priority(50).
Match(func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
message, ok := data["raw_message"].(string)
return ok && len(message) > 0 && message[0] == '/'
}).
Handle(func(ctx context.Context, event protocol.Event) error {
data := event.GetData()
message := data["raw_message"].(string)
userID := data["user_id"]
bot, _ := botManager.Get(event.GetSelfID())
var reply string
switch message {
case "/help":
reply = "可用命令:\n/help - 显示帮助\n/ping - 测试连接\n/time - 显示时间"
case "/ping":
reply = "pong!"
case "/time":
reply = time.Now().Format("2006-01-02 15:04:05")
default:
reply = "未知命令,发送 /help 查看帮助"
}
action := &protocol.BaseAction{
Type: protocol.ActionTypeSendPrivateMessage,
Params: map[string]interface{}{
"user_id": userID,
"message": reply,
},
}
_, err := bot.SendAction(ctx, action)
return err
}).
Build()
registry.Register(plugin)
}
```
## 最佳实践
1. **使用简化方式**:对于简单插件,使用 `engine.NewPlugin` 构建器
2. **使用传统方式**:对于复杂插件(需要状态管理、配置等),使用传统方式
3. **合理设置优先级**:确保插件按正确顺序执行
4. **错误处理**:在 Handle 函数中妥善处理错误
5. **日志记录**:使用 logger 记录关键操作
6. **避免阻塞**Handle 函数应快速返回,耗时操作应使用 goroutine
## 插件生命周期
插件在应用启动时注册,在应用运行期间持续监听事件。目前不支持热重载。
## 调试技巧
1. 使用 `logger.Debug` 记录调试信息
2. 在 Match 函数中添加日志,确认匹配逻辑
3. 检查事件数据结构,确保字段存在
4. 使用 `zap.Any` 打印完整事件数据
```go
logger.Debug("Event data", zap.Any("data", event.GetData()))
```

View File

@@ -206,6 +206,11 @@ func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) {
return nil, fmt.Errorf("failed to unmarshal raw event: %w", err) return nil, fmt.Errorf("failed to unmarshal raw event: %w", err)
} }
// 忽略机器人自己发送的消息
if rawEvent.PostType == "message_sent" {
return nil, fmt.Errorf("ignoring message_sent event")
}
return a.convertToEvent(&rawEvent) return a.convertToEvent(&rawEvent)
} }
@@ -419,17 +424,24 @@ func (a *Adapter) handleWebSocketMessages() {
zap.Int("size", len(message)), zap.Int("size", len(message)),
zap.String("preview", string(message[:min(len(message), 200)]))) zap.String("preview", string(message[:min(len(message), 200)])))
// 尝试解析为响应 // 尝试解析为响应(先检查是否有 echo 字段)
var resp OB11Response var tempMap map[string]interface{}
if err := sonic.Unmarshal(message, &resp); err == nil { if err := sonic.Unmarshal(message, &tempMap); err == nil {
// 如果有echo字段说明是API响应 if echo, ok := tempMap["echo"].(string); ok && echo != "" {
if resp.Echo != "" { // 有 echo 字段说明是API响应
a.logger.Debug("Received API response", var resp OB11Response
zap.String("echo", resp.Echo), if err := sonic.Unmarshal(message, &resp); err == nil {
zap.String("status", resp.Status), a.logger.Debug("Received API response",
zap.Int("retcode", resp.RetCode)) zap.String("echo", resp.Echo),
a.wsWaiter.Notify(&resp) zap.String("status", resp.Status),
continue zap.Int("retcode", resp.RetCode))
a.wsWaiter.Notify(&resp)
continue
} else {
a.logger.Warn("Failed to parse API response",
zap.Error(err),
zap.String("echo", echo))
}
} }
} }
@@ -441,9 +453,14 @@ func (a *Adapter) handleWebSocketMessages() {
a.logger.Info("Parsing OneBot event...") a.logger.Info("Parsing OneBot event...")
event, err := a.ParseMessage(message) event, err := a.ParseMessage(message)
if err != nil { if err != nil {
a.logger.Error("Failed to parse event", // 如果是忽略的事件(如 message_sent只记录 debug 日志
zap.Error(err), if err.Error() == "ignoring message_sent event" {
zap.ByteString("raw_message", message)) a.logger.Debug("Ignoring message_sent event")
} else {
a.logger.Error("Failed to parse event",
zap.Error(err),
zap.ByteString("raw_message", message))
}
continue continue
} }

View File

@@ -15,6 +15,7 @@ type Config struct {
Log LogConfig `toml:"log"` Log LogConfig `toml:"log"`
Protocol ProtocolConfig `toml:"protocol"` Protocol ProtocolConfig `toml:"protocol"`
Bots []BotConfig `toml:"bots"` Bots []BotConfig `toml:"bots"`
Engine EngineConfig `toml:"engine"`
} }
// ServerConfig 服务器配置 // ServerConfig 服务器配置
@@ -39,6 +40,18 @@ type ProtocolConfig struct {
Options map[string]string `toml:"options"` Options map[string]string `toml:"options"`
} }
// EngineConfig 引擎配置
type EngineConfig struct {
RateLimit RateLimitConfig `toml:"rate_limit"`
}
// RateLimitConfig 限流配置
type RateLimitConfig struct {
Enabled bool `toml:"enabled"` // 是否启用限流
RPS int `toml:"rps"` // 每秒请求数
Burst int `toml:"burst"` // 突发容量
}
// BotConfig Bot 配置 // BotConfig Bot 配置
type BotConfig struct { type BotConfig struct {
ID string `toml:"id"` ID string `toml:"id"`

View File

@@ -1,14 +1,15 @@
package di package di
import ( import (
"context"
"cellbot/internal/adapter/milky" "cellbot/internal/adapter/milky"
"cellbot/internal/adapter/onebot11" "cellbot/internal/adapter/onebot11"
"cellbot/internal/config" "cellbot/internal/config"
"cellbot/internal/engine" "cellbot/internal/engine"
"cellbot/internal/plugins/echo" _ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数
"cellbot/internal/protocol" "cellbot/internal/protocol"
"cellbot/pkg/net" "cellbot/pkg/net"
"context"
"go.uber.org/fx" "go.uber.org/fx"
"go.uber.org/zap" "go.uber.org/zap"
@@ -41,8 +42,27 @@ func ProvideEventBus(logger *zap.Logger) *engine.EventBus {
return engine.NewEventBus(logger, 10000) return engine.NewEventBus(logger, 10000)
} }
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger) *engine.Dispatcher { func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config) *engine.Dispatcher {
return engine.NewDispatcher(eventBus, logger) dispatcher := engine.NewDispatcher(eventBus, logger)
// 注册限流中间件
if cfg.Engine.RateLimit.Enabled {
rateLimitMiddleware := engine.NewRateLimitMiddleware(
logger,
cfg.Engine.RateLimit.RPS,
cfg.Engine.RateLimit.Burst,
)
dispatcher.RegisterMiddleware(rateLimitMiddleware)
logger.Info("Rate limit middleware registered",
zap.Int("rps", cfg.Engine.RateLimit.RPS),
zap.Int("burst", cfg.Engine.RateLimit.Burst))
}
return dispatcher
}
func ProvidePluginRegistry(dispatcher *engine.Dispatcher, logger *zap.Logger) *engine.PluginRegistry {
return engine.NewPluginRegistry(dispatcher, logger)
} }
func ProvideBotManager(logger *zap.Logger) *protocol.BotManager { func ProvideBotManager(logger *zap.Logger) *protocol.BotManager {
@@ -129,10 +149,26 @@ func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net.
return nil return nil
} }
func ProvideEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager, dispatcher *engine.Dispatcher) { // LoadPlugins 加载所有插件
echoPlugin := echo.NewEchoPlugin(logger, botManager) func LoadPlugins(logger *zap.Logger, botManager *protocol.BotManager, registry *engine.PluginRegistry) {
dispatcher.RegisterHandler(echoPlugin) // 从全局注册表加载所有插件(通过 RegisterPlugin 注册的)
logger.Info("Echo plugin registered") plugins := engine.LoadAllPlugins(botManager, logger)
for _, plugin := range plugins {
registry.Register(plugin)
}
// 加载所有通过 OnXXX().Handle() 注册的处理器(注入依赖)
handlers := engine.LoadAllHandlers(botManager, logger)
for _, handler := range handlers {
registry.Register(handler)
}
totalCount := len(plugins) + len(handlers)
logger.Info("All plugins loaded",
zap.Int("plugin_count", len(plugins)),
zap.Int("handler_count", len(handlers)),
zap.Int("total_count", totalCount),
zap.Strings("plugins", engine.GetRegisteredPlugins()))
} }
var Providers = fx.Options( var Providers = fx.Options(
@@ -142,11 +178,12 @@ var Providers = fx.Options(
ProvideLogger, ProvideLogger,
ProvideEventBus, ProvideEventBus,
ProvideDispatcher, ProvideDispatcher,
ProvidePluginRegistry,
ProvideBotManager, ProvideBotManager,
ProvideWebSocketManager, ProvideWebSocketManager,
ProvideServer, ProvideServer,
), ),
fx.Invoke(ProvideMilkyBots), fx.Invoke(ProvideMilkyBots),
fx.Invoke(ProvideOneBot11Bots), fx.Invoke(ProvideOneBot11Bots),
fx.Invoke(ProvideEchoPlugin), fx.Invoke(LoadPlugins),
) )

View File

@@ -1,250 +0,0 @@
package engine
import (
"context"
"sync"
"testing"
"time"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
func TestEventBus_PublishSubscribe(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
eventBus.Start()
defer eventBus.Stop()
// 创建测试事件
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
// 订阅事件
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
// 发布事件
eventBus.Publish(event)
// 接收事件
select {
case receivedEvent := <-eventChan:
if receivedEvent.GetType() != protocol.EventTypeMessage {
t.Errorf("Expected event type '%s', got '%s'", protocol.EventTypeMessage, receivedEvent.GetType())
}
case <-time.After(100 * time.Millisecond):
t.Error("Timeout waiting for event")
}
}
func TestEventBus_Filter(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
eventBus.Start()
defer eventBus.Stop()
// 创建测试事件
event1 := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
event2 := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "group",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
// 订阅并过滤:只接收private消息
filter := func(e protocol.Event) bool {
return e.GetDetailType() == "private"
}
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, filter)
// 发布两个事件
eventBus.Publish(event1)
eventBus.Publish(event2)
// 应该只收到private消息
select {
case receivedEvent := <-eventChan:
if receivedEvent.GetDetailType() != "private" {
t.Errorf("Expected detail type 'private', got '%s'", receivedEvent.GetDetailType())
}
case <-time.After(100 * time.Millisecond):
t.Error("Timeout waiting for event")
}
// 不应该再收到第二个事件
select {
case <-eventChan:
t.Error("Should not receive group message")
case <-time.After(50 * time.Millisecond):
// 正确,不应该收到
}
}
func TestEventBus_Concurrent(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 10000)
eventBus.Start()
defer eventBus.Stop()
numSubscribers := 10
numPublishers := 10
numEvents := 100
// 创建多个订阅者
subscribers := make([]chan protocol.Event, numSubscribers)
for i := 0; i < numSubscribers; i++ {
subscribers[i] = eventBus.Subscribe(protocol.EventTypeMessage, nil)
}
var wg sync.WaitGroup
wg.Add(numPublishers)
// 多个发布者并发发布事件
for i := 0; i < numPublishers; i++ {
go func() {
defer wg.Done()
for j := 0; j < numEvents; j++ {
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
eventBus.Publish(event)
}
}()
}
wg.Wait()
// 验证每个订阅者都收到了事件
for _, ch := range subscribers {
count := 0
for count < numPublishers*numEvents {
select {
case <-ch:
count++
case <-time.After(500 * time.Millisecond):
t.Errorf("Timeout waiting for events, received %d, expected %d", count, numPublishers*numEvents)
break
}
}
}
}
func TestEventBus_Benchmark(b *testing.B) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100000)
eventBus.Start()
defer eventBus.Stop()
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
eventBus.Publish(event)
}
})
// 消耗channel避免阻塞
go func() {
for range eventChan {
}
}()
}
func TestDispatcher_HandlerPriority(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
dispatcher := NewDispatcher(eventBus, logger)
// 创建测试处理器
handlers := make([]*TestHandler, 3)
priorities := []int{3, 1, 2}
for i, priority := range priorities {
handlers[i] = &TestHandler{
priority: priority,
matched: false,
executed: false,
}
dispatcher.RegisterHandler(handlers[i])
}
// 验证处理器按优先级排序
if dispatcher.GetHandlerCount() != 3 {
t.Errorf("Expected 3 handlers, got %d", dispatcher.GetHandlerCount())
}
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
ctx := context.Background()
// 分发事件
dispatcher.handleEvent(ctx, event)
// 验证处理器是否按优先级执行
executedOrder := make([]int, 0)
for _, handler := range handlers {
if handler.executed {
executedOrder = append(executedOrder, handler.priority)
}
}
// 检查是否按升序执行
for i := 1; i < len(executedOrder); i++ {
if executedOrder[i] < executedOrder[i-1] {
t.Errorf("Handlers not executed in priority order: %v", executedOrder)
}
}
}
// TestHandler 测试处理器
type TestHandler struct {
priority int
matched bool
executed bool
}
func (h *TestHandler) Handle(ctx context.Context, event protocol.Event) error {
h.executed = true
return nil
}
func (h *TestHandler) Priority() int {
return h.priority
}
func (h *TestHandler) Match(event protocol.Event) bool {
h.matched = true
return true
}

421
internal/engine/plugin.go Normal file
View File

@@ -0,0 +1,421 @@
package engine
import (
"context"
"fmt"
"sync"
"sync/atomic"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
// PluginFactory 插件工厂函数类型
// 用于在运行时创建插件实例,支持依赖注入
type PluginFactory func(botManager *protocol.BotManager, logger *zap.Logger) protocol.EventHandler
// globalPluginRegistry 全局插件注册表
var (
globalPluginRegistry = make(map[string]PluginFactory)
globalPluginMu sync.RWMutex
)
// RegisterPlugin 注册插件(供插件在 init 函数中调用)
func RegisterPlugin(name string, factory PluginFactory) {
globalPluginMu.Lock()
defer globalPluginMu.Unlock()
if _, exists := globalPluginRegistry[name]; exists {
panic("plugin already registered: " + name)
}
globalPluginRegistry[name] = factory
}
// GetRegisteredPlugins 获取所有已注册的插件名称
func GetRegisteredPlugins() []string {
globalPluginMu.RLock()
defer globalPluginMu.RUnlock()
names := make([]string, 0, len(globalPluginRegistry))
for name := range globalPluginRegistry {
names = append(names, name)
}
return names
}
// LoadPlugin 加载指定名称的插件
func LoadPlugin(name string, botManager *protocol.BotManager, logger *zap.Logger) (protocol.EventHandler, error) {
globalPluginMu.RLock()
factory, exists := globalPluginRegistry[name]
globalPluginMu.RUnlock()
if !exists {
return nil, nil // 插件不存在,返回 nil 而不是错误
}
return factory(botManager, logger), nil
}
// LoadAllPlugins 加载所有已注册的插件
func LoadAllPlugins(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
globalPluginMu.RLock()
defer globalPluginMu.RUnlock()
plugins := make([]protocol.EventHandler, 0, len(globalPluginRegistry))
for name, factory := range globalPluginRegistry {
plugin := factory(botManager, logger.Named("plugin."+name))
plugins = append(plugins, plugin)
}
return plugins
}
// PluginRegistry 插件注册表
type PluginRegistry struct {
plugins []protocol.EventHandler
dispatcher *Dispatcher
logger *zap.Logger
mu sync.RWMutex
}
// NewPluginRegistry 创建插件注册表
func NewPluginRegistry(dispatcher *Dispatcher, logger *zap.Logger) *PluginRegistry {
return &PluginRegistry{
plugins: make([]protocol.EventHandler, 0),
dispatcher: dispatcher,
logger: logger.Named("plugin-registry"),
}
}
// Register 注册插件
func (r *PluginRegistry) Register(plugin protocol.EventHandler) {
r.mu.Lock()
defer r.mu.Unlock()
r.plugins = append(r.plugins, plugin)
r.dispatcher.RegisterHandler(plugin)
r.logger.Info("Plugin registered",
zap.String("name", plugin.Name()),
zap.String("description", plugin.Description()))
}
// GetPlugins 获取所有插件
func (r *PluginRegistry) GetPlugins() []protocol.EventHandler {
r.mu.RLock()
defer r.mu.RUnlock()
return r.plugins
}
// PluginBuilder 插件构建器
type PluginBuilder struct {
name string
description string
priority int
matchFunc func(protocol.Event) bool
handleFunc func(context.Context, protocol.Event) error
}
// NewPlugin 创建插件构建器
func NewPlugin(name string) *PluginBuilder {
return &PluginBuilder{
name: name,
priority: 100, // 默认优先级
}
}
// Description 设置插件描述
func (b *PluginBuilder) Description(desc string) *PluginBuilder {
b.description = desc
return b
}
// Priority 设置优先级
func (b *PluginBuilder) Priority(priority int) *PluginBuilder {
b.priority = priority
return b
}
// Match 设置匹配函数
func (b *PluginBuilder) Match(fn func(protocol.Event) bool) *PluginBuilder {
b.matchFunc = fn
return b
}
// Handle 设置处理函数
func (b *PluginBuilder) Handle(fn func(context.Context, protocol.Event) error) *PluginBuilder {
b.handleFunc = fn
return b
}
// Build 构建插件
func (b *PluginBuilder) Build() protocol.EventHandler {
return &simplePlugin{
name: b.name,
description: b.description,
priority: b.priority,
matchFunc: b.matchFunc,
handleFunc: b.handleFunc,
}
}
// simplePlugin 简单插件实现
type simplePlugin struct {
name string
description string
priority int
matchFunc func(protocol.Event) bool
handleFunc func(context.Context, protocol.Event) error
}
func (p *simplePlugin) Name() string {
return p.name
}
func (p *simplePlugin) Description() string {
return p.description
}
func (p *simplePlugin) Priority() int {
return p.priority
}
func (p *simplePlugin) Match(event protocol.Event) bool {
if p.matchFunc == nil {
return true
}
return p.matchFunc(event)
}
func (p *simplePlugin) Handle(ctx context.Context, event protocol.Event) error {
if p.handleFunc == nil {
return nil
}
return p.handleFunc(ctx, event)
}
// ============================================================================
// ZeroBot 风格的全局注册 API
// ============================================================================
// globalHandlerRegistry 全局处理器注册表
var (
globalHandlerRegistry = make([]*HandlerBuilder, 0)
globalHandlerMu sync.RWMutex
)
// HandlerFunc 处理函数类型(支持依赖注入)
type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error
// HandlerBuilder 处理器构建器(类似 ZeroBot 的 API
type HandlerBuilder struct {
matchFunc func(protocol.Event) bool
priority int
handleFunc HandlerFunc
}
// OnPrivateMessage 匹配私聊消息
func OnPrivateMessage() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private"
},
priority: 100,
}
}
// OnGroupMessage 匹配群消息
func OnGroupMessage() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "group"
},
priority: 100,
}
}
// OnMessage 匹配所有消息
func OnMessage() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage
},
priority: 100,
}
}
// OnNotice 匹配通知事件
func OnNotice() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeNotice
},
priority: 100,
}
}
// OnRequest 匹配请求事件
func OnRequest() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeRequest
},
priority: 100,
}
}
// On 自定义匹配器
func On(matchFunc func(protocol.Event) bool) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: matchFunc,
priority: 100,
}
}
// OnCommand 匹配命令(以指定前缀开头的消息)
func OnCommand(prefix string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
// 检查是否以命令前缀开头
if len(rawMessage) > 0 && len(prefix) > 0 {
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
}
return false
},
priority: 50, // 命令通常优先级较高
}
}
// OnPrefix 匹配以指定前缀开头的消息
func OnPrefix(prefix string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
},
priority: 100,
}
}
// OnSuffix 匹配以指定后缀结尾的消息
func OnSuffix(suffix string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
return len(rawMessage) >= len(suffix) && rawMessage[len(rawMessage)-len(suffix):] == suffix
},
priority: 100,
}
}
// OnKeyword 匹配包含指定关键词的消息
func OnKeyword(keyword string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
// 简单的字符串包含检查
return len(rawMessage) >= len(keyword) &&
(len(rawMessage) == len(keyword) ||
(rawMessage[:len(keyword)] == keyword ||
rawMessage[len(rawMessage)-len(keyword):] == keyword ||
contains(rawMessage, keyword)))
},
priority: 100,
}
}
// contains 检查字符串是否包含子串(简单实现)
func contains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// Priority 设置优先级
func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder {
b.priority = priority
return b
}
// Handle 注册处理函数(在 init 中调用)
func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
globalHandlerMu.Lock()
defer globalHandlerMu.Unlock()
b.handleFunc = handleFunc
globalHandlerRegistry = append(globalHandlerRegistry, b)
}
// generateHandlerName 生成处理器名称
var handlerCounter int64
func generateHandlerName() string {
counter := atomic.AddInt64(&handlerCounter, 1)
return fmt.Sprintf("handler_%d", counter)
}
// LoadAllHandlers 加载所有已注册的处理器(注入依赖)
func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
globalHandlerMu.RLock()
defer globalHandlerMu.RUnlock()
handlers := make([]protocol.EventHandler, 0, len(globalHandlerRegistry))
for i, builder := range globalHandlerRegistry {
if builder.handleFunc == nil {
continue
}
// 生成唯一的插件名称
pluginName := fmt.Sprintf("handler_%d", i+1)
// 创建包装的处理函数,注入依赖
handleFunc := func(ctx context.Context, event protocol.Event) error {
return builder.handleFunc(ctx, event, botManager, logger)
}
handler := &simplePlugin{
name: pluginName,
description: "Handler registered via OnXXX().Handle()",
priority: builder.priority,
matchFunc: builder.matchFunc,
handleFunc: handleFunc,
}
handlers = append(handlers, handler)
}
return handlers
}

View File

@@ -1,137 +0,0 @@
package echo
import (
"context"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
// EchoPlugin 回声插件
type EchoPlugin struct {
logger *zap.Logger
botManager *protocol.BotManager
}
// NewEchoPlugin 创建回声插件
func NewEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager) *EchoPlugin {
return &EchoPlugin{
logger: logger.Named("echo-plugin"),
botManager: botManager,
}
}
// Handle 处理事件
func (p *EchoPlugin) Handle(ctx context.Context, event protocol.Event) error {
// 获取事件数据
data := event.GetData()
// 获取消息内容
message, ok := data["message"]
if !ok {
p.logger.Debug("No message field in event")
return nil
}
rawMessage, ok := data["raw_message"].(string)
if !ok {
p.logger.Debug("No raw_message field in event")
return nil
}
// 获取用户ID
userID, ok := data["user_id"]
if !ok {
p.logger.Debug("No user_id field in event")
return nil
}
p.logger.Info("Received private message",
zap.Any("user_id", userID),
zap.String("message", rawMessage))
// 获取 self_id 来确定是哪个 bot
selfID := event.GetSelfID()
// 获取对应的 bot 实例
bot, ok := p.botManager.Get(selfID)
if !ok {
// 如果通过 selfID 找不到,尝试获取第一个 bot
bots := p.botManager.GetAll()
if len(bots) == 0 {
p.logger.Error("No bot instance available")
return nil
}
bot = bots[0]
p.logger.Debug("Using first available bot",
zap.String("bot_id", bot.GetID()))
}
// 构建回复动作
action := &protocol.BaseAction{
Type: protocol.ActionTypeSendPrivateMessage,
Params: map[string]interface{}{
"user_id": userID,
"message": message, // 原封不动返回消息
},
}
p.logger.Info("Sending echo reply",
zap.Any("user_id", userID),
zap.String("reply", rawMessage))
// 发送消息
result, err := bot.SendAction(ctx, action)
if err != nil {
p.logger.Error("Failed to send echo reply",
zap.Error(err),
zap.Any("user_id", userID))
return err
}
p.logger.Info("Echo reply sent successfully",
zap.Any("user_id", userID),
zap.Any("result", result))
return nil
}
// Priority 获取处理器优先级
func (p *EchoPlugin) Priority() int {
return 100 // 中等优先级
}
// Match 判断是否匹配事件
func (p *EchoPlugin) Match(event protocol.Event) bool {
// 只处理私聊消息
eventType := event.GetType()
detailType := event.GetDetailType()
p.logger.Debug("Echo plugin matching event",
zap.String("event_type", string(eventType)),
zap.String("detail_type", detailType))
if eventType != protocol.EventTypeMessage {
p.logger.Debug("Event type mismatch", zap.String("expected", string(protocol.EventTypeMessage)))
return false
}
if detailType != "private" {
p.logger.Debug("Detail type mismatch", zap.String("expected", "private"), zap.String("got", detailType))
return false
}
p.logger.Info("Echo plugin matched event!")
return true
}
// Name 获取插件名称
func (p *EchoPlugin) Name() string {
return "Echo"
}
// Description 获取插件描述
func (p *EchoPlugin) Description() string {
return "回声插件:将私聊消息原封不动返回"
}

View File

@@ -0,0 +1,64 @@
package echo
import (
"context"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
func init() {
// 在 init 函数中注册多个处理函数(类似 ZeroBot 风格)
// 处理私聊消息
engine.OnPrivateMessage().
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
// 获取消息内容
data := event.GetData()
message, ok := data["message"]
if !ok {
return nil
}
userID, ok := data["user_id"]
if !ok {
return nil
}
// 获取 bot 实例
selfID := event.GetSelfID()
bot, ok := botManager.Get(selfID)
if !ok {
bots := botManager.GetAll()
if len(bots) == 0 {
logger.Error("No bot instance available")
return nil
}
bot = bots[0]
}
// 发送回复
action := &protocol.BaseAction{
Type: protocol.ActionTypeSendPrivateMessage,
Params: map[string]interface{}{
"user_id": userID,
"message": message,
},
}
_, err := bot.SendAction(ctx, action)
if err != nil {
logger.Error("Failed to send reply", zap.Error(err))
return err
}
logger.Info("Echo reply sent", zap.Any("user_id", userID))
return nil
})
// 可以继续注册更多处理函数
// engine.OnGroupMessage().Handle(...)
// engine.OnCommand("help").Handle(...)
}

View File

@@ -68,6 +68,10 @@ type EventHandler interface {
Priority() int Priority() int
// Match 判断是否匹配事件 // Match 判断是否匹配事件
Match(event Event) bool Match(event Event) bool
// Name 获取处理器名称
Name() string
// Description 获取处理器描述
Description() string
} }
// Middleware 中间件接口 // Middleware 中间件接口

View File

@@ -157,32 +157,122 @@ func (wsc *WebSocketConnection) readLoop(eventBus *engine.EventBus) {
func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) { func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.EventBus) {
wsc.Logger.Debug("Received message", zap.ByteString("data", data)) wsc.Logger.Debug("Received message", zap.ByteString("data", data))
// 解析JSON消息为BaseEvent // 解析为 map 以支持灵活的字段类型(如 self_id 可能是数字或字符串)
var event protocol.BaseEvent // 使用 sonic.Config 配置更宽松的解析,允许数字和字符串之间的转换
if err := sonic.Unmarshal(data, &event); err != nil { cfg := sonic.Config{
UseInt64: true, // 使用 int64 而不是 float64 来解析数字
NoValidateJSONSkip: true, // 跳过类型验证,允许更灵活的类型转换
}.Froze()
var rawMap map[string]interface{}
if err := cfg.Unmarshal(data, &rawMap); err != nil {
wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data)) wsc.Logger.Error("Failed to parse message", zap.Error(err), zap.ByteString("data", data))
return return
} }
// 检查是否是 API 响应(有 echo 字段且没有 post_type
// 如果是响应,不在这里处理,让 adapter 的 handleWebSocketMessages 处理
if echo, hasEcho := rawMap["echo"].(string); hasEcho && echo != "" {
if _, hasPostType := rawMap["post_type"]; !hasPostType {
// 这是 API 响应,不在这里处理
// 正向 WebSocket 时adapter 的 handleWebSocketMessages 会处理
// 反向 WebSocket 时,响应应该通过 adapter 处理
wsc.Logger.Debug("Skipping API response in handleMessage, will be handled by adapter",
zap.String("echo", echo))
return
}
}
// 构建 BaseEvent
event := &protocol.BaseEvent{
Data: make(map[string]interface{}),
}
// 处理 self_id可能是数字或字符串
if selfIDVal, ok := rawMap["self_id"]; ok {
switch v := selfIDVal.(type) {
case string:
event.SelfID = v
case float64:
event.SelfID = fmt.Sprintf("%.0f", v)
case int64:
event.SelfID = fmt.Sprintf("%d", v)
case int:
event.SelfID = fmt.Sprintf("%d", v)
default:
event.SelfID = fmt.Sprintf("%v", v)
}
}
// 如果没有SelfID使用连接的BotID
if event.SelfID == "" {
event.SelfID = wsc.BotID
}
// 处理时间戳
if timeVal, ok := rawMap["time"]; ok {
switch v := timeVal.(type) {
case float64:
event.Timestamp = int64(v)
case int64:
event.Timestamp = v
case int:
event.Timestamp = int64(v)
}
}
if event.Timestamp == 0 {
event.Timestamp = time.Now().Unix()
}
// 处理类型字段
if typeVal, ok := rawMap["post_type"]; ok {
if typeStr, ok := typeVal.(string); ok {
// OneBot11 格式post_type -> EventType 映射
switch typeStr {
case "message":
event.Type = protocol.EventTypeMessage
case "notice":
event.Type = protocol.EventTypeNotice
case "request":
event.Type = protocol.EventTypeRequest
case "meta_event":
event.Type = protocol.EventTypeMeta
case "message_sent":
// 忽略机器人自己发送的消息
wsc.Logger.Debug("Ignoring message_sent event")
return
default:
event.Type = protocol.EventType(typeStr)
}
}
} else if typeVal, ok := rawMap["type"]; ok {
if typeStr, ok := typeVal.(string); ok {
event.Type = protocol.EventType(typeStr)
}
}
// 验证必需字段 // 验证必需字段
if event.Type == "" { if event.Type == "" {
wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data)) wsc.Logger.Warn("Event type is empty", zap.ByteString("data", data))
return return
} }
// 如果没有时间戳,使用当前时间 // 处理 detail_type
if event.Timestamp == 0 { if detailTypeVal, ok := rawMap["message_type"]; ok {
event.Timestamp = time.Now().Unix() if detailTypeStr, ok := detailTypeVal.(string); ok {
event.DetailType = detailTypeStr
}
} else if detailTypeVal, ok := rawMap["detail_type"]; ok {
if detailTypeStr, ok := detailTypeVal.(string); ok {
event.DetailType = detailTypeStr
}
} }
// 如果没有SelfID使用连接的BotID // 将所有其他字段放入 Data
if event.SelfID == "" { for k, v := range rawMap {
event.SelfID = wsc.BotID if k != "self_id" && k != "time" && k != "post_type" && k != "type" && k != "message_type" && k != "detail_type" {
} event.Data[k] = v
}
// 确保Data字段不为nil
if event.Data == nil {
event.Data = make(map[string]interface{})
} }
wsc.Logger.Info("Event received", wsc.Logger.Info("Event received",
@@ -191,7 +281,7 @@ func (wsc *WebSocketConnection) handleMessage(data []byte, eventBus *engine.Even
zap.String("self_id", event.SelfID)) zap.String("self_id", event.SelfID))
// 发布到事件总线 // 发布到事件总线
eventBus.Publish(&event) eventBus.Publish(event)
} }
// SendMessage 发送消息 // SendMessage 发送消息
@@ -339,6 +429,7 @@ type DialConfig struct {
BotID string BotID string
MaxReconnect int MaxReconnect int
HeartbeatTick time.Duration HeartbeatTick time.Duration
AutoReadLoop bool // 是否自动启动 readLoopadapter 自己处理消息时设为 false
} }
// Dial 建立WebSocket客户端连接正向连接 // Dial 建立WebSocket客户端连接正向连接
@@ -348,6 +439,7 @@ func (wsm *WebSocketManager) Dial(addr string, botID string) (*WebSocketConnecti
BotID: botID, BotID: botID,
MaxReconnect: 5, MaxReconnect: 5,
HeartbeatTick: 30 * time.Second, HeartbeatTick: 30 * time.Second,
AutoReadLoop: false, // adapter 自己处理消息,不自动启动 readLoop
}) })
} }
@@ -383,8 +475,10 @@ func (wsm *WebSocketManager) DialWithConfig(config DialConfig) (*WebSocketConnec
zap.String("addr", config.URL), zap.String("addr", config.URL),
zap.String("remote_addr", wsConn.RemoteAddr)) zap.String("remote_addr", wsConn.RemoteAddr))
// 启动读取循环和心跳 // 启动读取循环和心跳(如果启用)
go wsConn.readLoop(wsm.eventBus) if config.AutoReadLoop {
go wsConn.readLoop(wsm.eventBus)
}
go wsConn.heartbeatLoop() go wsConn.heartbeatLoop()
// 如果是正向连接,启动重连监控 // 如果是正向连接,启动重连监控