feat: add rate limiting and improve event handling

- Introduced rate limiting configuration in config.toml with options for enabling, requests per second (RPS), and burst capacity.
- Enhanced event handling in the OneBot11 adapter to ignore messages sent by the bot itself.
- Updated the dispatcher to register rate limit middleware based on configuration settings.
- Refactored WebSocket message handling to support flexible JSON parsing and improved event type detection.
- Removed deprecated echo plugin and associated tests to streamline the codebase.
This commit is contained in:
lafay
2026-01-05 01:00:38 +08:00
parent 44fe05ff62
commit d16261e6bd
11 changed files with 1001 additions and 427 deletions

View File

@@ -1,250 +0,0 @@
package engine
import (
"context"
"sync"
"testing"
"time"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
func TestEventBus_PublishSubscribe(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
eventBus.Start()
defer eventBus.Stop()
// 创建测试事件
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
// 订阅事件
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
// 发布事件
eventBus.Publish(event)
// 接收事件
select {
case receivedEvent := <-eventChan:
if receivedEvent.GetType() != protocol.EventTypeMessage {
t.Errorf("Expected event type '%s', got '%s'", protocol.EventTypeMessage, receivedEvent.GetType())
}
case <-time.After(100 * time.Millisecond):
t.Error("Timeout waiting for event")
}
}
func TestEventBus_Filter(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
eventBus.Start()
defer eventBus.Stop()
// 创建测试事件
event1 := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
event2 := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "group",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
// 订阅并过滤:只接收private消息
filter := func(e protocol.Event) bool {
return e.GetDetailType() == "private"
}
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, filter)
// 发布两个事件
eventBus.Publish(event1)
eventBus.Publish(event2)
// 应该只收到private消息
select {
case receivedEvent := <-eventChan:
if receivedEvent.GetDetailType() != "private" {
t.Errorf("Expected detail type 'private', got '%s'", receivedEvent.GetDetailType())
}
case <-time.After(100 * time.Millisecond):
t.Error("Timeout waiting for event")
}
// 不应该再收到第二个事件
select {
case <-eventChan:
t.Error("Should not receive group message")
case <-time.After(50 * time.Millisecond):
// 正确,不应该收到
}
}
func TestEventBus_Concurrent(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 10000)
eventBus.Start()
defer eventBus.Stop()
numSubscribers := 10
numPublishers := 10
numEvents := 100
// 创建多个订阅者
subscribers := make([]chan protocol.Event, numSubscribers)
for i := 0; i < numSubscribers; i++ {
subscribers[i] = eventBus.Subscribe(protocol.EventTypeMessage, nil)
}
var wg sync.WaitGroup
wg.Add(numPublishers)
// 多个发布者并发发布事件
for i := 0; i < numPublishers; i++ {
go func() {
defer wg.Done()
for j := 0; j < numEvents; j++ {
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
eventBus.Publish(event)
}
}()
}
wg.Wait()
// 验证每个订阅者都收到了事件
for _, ch := range subscribers {
count := 0
for count < numPublishers*numEvents {
select {
case <-ch:
count++
case <-time.After(500 * time.Millisecond):
t.Errorf("Timeout waiting for events, received %d, expected %d", count, numPublishers*numEvents)
break
}
}
}
}
func TestEventBus_Benchmark(b *testing.B) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100000)
eventBus.Start()
defer eventBus.Stop()
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
eventBus.Publish(event)
}
})
// 消耗channel避免阻塞
go func() {
for range eventChan {
}
}()
}
func TestDispatcher_HandlerPriority(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
dispatcher := NewDispatcher(eventBus, logger)
// 创建测试处理器
handlers := make([]*TestHandler, 3)
priorities := []int{3, 1, 2}
for i, priority := range priorities {
handlers[i] = &TestHandler{
priority: priority,
matched: false,
executed: false,
}
dispatcher.RegisterHandler(handlers[i])
}
// 验证处理器按优先级排序
if dispatcher.GetHandlerCount() != 3 {
t.Errorf("Expected 3 handlers, got %d", dispatcher.GetHandlerCount())
}
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
ctx := context.Background()
// 分发事件
dispatcher.handleEvent(ctx, event)
// 验证处理器是否按优先级执行
executedOrder := make([]int, 0)
for _, handler := range handlers {
if handler.executed {
executedOrder = append(executedOrder, handler.priority)
}
}
// 检查是否按升序执行
for i := 1; i < len(executedOrder); i++ {
if executedOrder[i] < executedOrder[i-1] {
t.Errorf("Handlers not executed in priority order: %v", executedOrder)
}
}
}
// TestHandler 测试处理器
type TestHandler struct {
priority int
matched bool
executed bool
}
func (h *TestHandler) Handle(ctx context.Context, event protocol.Event) error {
h.executed = true
return nil
}
func (h *TestHandler) Priority() int {
return h.priority
}
func (h *TestHandler) Match(event protocol.Event) bool {
h.matched = true
return true
}

421
internal/engine/plugin.go Normal file
View File

@@ -0,0 +1,421 @@
package engine
import (
"context"
"fmt"
"sync"
"sync/atomic"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
// PluginFactory 插件工厂函数类型
// 用于在运行时创建插件实例,支持依赖注入
type PluginFactory func(botManager *protocol.BotManager, logger *zap.Logger) protocol.EventHandler
// globalPluginRegistry 全局插件注册表
var (
globalPluginRegistry = make(map[string]PluginFactory)
globalPluginMu sync.RWMutex
)
// RegisterPlugin 注册插件(供插件在 init 函数中调用)
func RegisterPlugin(name string, factory PluginFactory) {
globalPluginMu.Lock()
defer globalPluginMu.Unlock()
if _, exists := globalPluginRegistry[name]; exists {
panic("plugin already registered: " + name)
}
globalPluginRegistry[name] = factory
}
// GetRegisteredPlugins 获取所有已注册的插件名称
func GetRegisteredPlugins() []string {
globalPluginMu.RLock()
defer globalPluginMu.RUnlock()
names := make([]string, 0, len(globalPluginRegistry))
for name := range globalPluginRegistry {
names = append(names, name)
}
return names
}
// LoadPlugin 加载指定名称的插件
func LoadPlugin(name string, botManager *protocol.BotManager, logger *zap.Logger) (protocol.EventHandler, error) {
globalPluginMu.RLock()
factory, exists := globalPluginRegistry[name]
globalPluginMu.RUnlock()
if !exists {
return nil, nil // 插件不存在,返回 nil 而不是错误
}
return factory(botManager, logger), nil
}
// LoadAllPlugins 加载所有已注册的插件
func LoadAllPlugins(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
globalPluginMu.RLock()
defer globalPluginMu.RUnlock()
plugins := make([]protocol.EventHandler, 0, len(globalPluginRegistry))
for name, factory := range globalPluginRegistry {
plugin := factory(botManager, logger.Named("plugin."+name))
plugins = append(plugins, plugin)
}
return plugins
}
// PluginRegistry 插件注册表
type PluginRegistry struct {
plugins []protocol.EventHandler
dispatcher *Dispatcher
logger *zap.Logger
mu sync.RWMutex
}
// NewPluginRegistry 创建插件注册表
func NewPluginRegistry(dispatcher *Dispatcher, logger *zap.Logger) *PluginRegistry {
return &PluginRegistry{
plugins: make([]protocol.EventHandler, 0),
dispatcher: dispatcher,
logger: logger.Named("plugin-registry"),
}
}
// Register 注册插件
func (r *PluginRegistry) Register(plugin protocol.EventHandler) {
r.mu.Lock()
defer r.mu.Unlock()
r.plugins = append(r.plugins, plugin)
r.dispatcher.RegisterHandler(plugin)
r.logger.Info("Plugin registered",
zap.String("name", plugin.Name()),
zap.String("description", plugin.Description()))
}
// GetPlugins 获取所有插件
func (r *PluginRegistry) GetPlugins() []protocol.EventHandler {
r.mu.RLock()
defer r.mu.RUnlock()
return r.plugins
}
// PluginBuilder 插件构建器
type PluginBuilder struct {
name string
description string
priority int
matchFunc func(protocol.Event) bool
handleFunc func(context.Context, protocol.Event) error
}
// NewPlugin 创建插件构建器
func NewPlugin(name string) *PluginBuilder {
return &PluginBuilder{
name: name,
priority: 100, // 默认优先级
}
}
// Description 设置插件描述
func (b *PluginBuilder) Description(desc string) *PluginBuilder {
b.description = desc
return b
}
// Priority 设置优先级
func (b *PluginBuilder) Priority(priority int) *PluginBuilder {
b.priority = priority
return b
}
// Match 设置匹配函数
func (b *PluginBuilder) Match(fn func(protocol.Event) bool) *PluginBuilder {
b.matchFunc = fn
return b
}
// Handle 设置处理函数
func (b *PluginBuilder) Handle(fn func(context.Context, protocol.Event) error) *PluginBuilder {
b.handleFunc = fn
return b
}
// Build 构建插件
func (b *PluginBuilder) Build() protocol.EventHandler {
return &simplePlugin{
name: b.name,
description: b.description,
priority: b.priority,
matchFunc: b.matchFunc,
handleFunc: b.handleFunc,
}
}
// simplePlugin 简单插件实现
type simplePlugin struct {
name string
description string
priority int
matchFunc func(protocol.Event) bool
handleFunc func(context.Context, protocol.Event) error
}
func (p *simplePlugin) Name() string {
return p.name
}
func (p *simplePlugin) Description() string {
return p.description
}
func (p *simplePlugin) Priority() int {
return p.priority
}
func (p *simplePlugin) Match(event protocol.Event) bool {
if p.matchFunc == nil {
return true
}
return p.matchFunc(event)
}
func (p *simplePlugin) Handle(ctx context.Context, event protocol.Event) error {
if p.handleFunc == nil {
return nil
}
return p.handleFunc(ctx, event)
}
// ============================================================================
// ZeroBot 风格的全局注册 API
// ============================================================================
// globalHandlerRegistry 全局处理器注册表
var (
globalHandlerRegistry = make([]*HandlerBuilder, 0)
globalHandlerMu sync.RWMutex
)
// HandlerFunc 处理函数类型(支持依赖注入)
type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error
// HandlerBuilder 处理器构建器(类似 ZeroBot 的 API
type HandlerBuilder struct {
matchFunc func(protocol.Event) bool
priority int
handleFunc HandlerFunc
}
// OnPrivateMessage 匹配私聊消息
func OnPrivateMessage() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private"
},
priority: 100,
}
}
// OnGroupMessage 匹配群消息
func OnGroupMessage() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "group"
},
priority: 100,
}
}
// OnMessage 匹配所有消息
func OnMessage() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeMessage
},
priority: 100,
}
}
// OnNotice 匹配通知事件
func OnNotice() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeNotice
},
priority: 100,
}
}
// OnRequest 匹配请求事件
func OnRequest() *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
return event.GetType() == protocol.EventTypeRequest
},
priority: 100,
}
}
// On 自定义匹配器
func On(matchFunc func(protocol.Event) bool) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: matchFunc,
priority: 100,
}
}
// OnCommand 匹配命令(以指定前缀开头的消息)
func OnCommand(prefix string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
// 检查是否以命令前缀开头
if len(rawMessage) > 0 && len(prefix) > 0 {
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
}
return false
},
priority: 50, // 命令通常优先级较高
}
}
// OnPrefix 匹配以指定前缀开头的消息
func OnPrefix(prefix string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
},
priority: 100,
}
}
// OnSuffix 匹配以指定后缀结尾的消息
func OnSuffix(suffix string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
return len(rawMessage) >= len(suffix) && rawMessage[len(rawMessage)-len(suffix):] == suffix
},
priority: 100,
}
}
// OnKeyword 匹配包含指定关键词的消息
func OnKeyword(keyword string) *HandlerBuilder {
return &HandlerBuilder{
matchFunc: func(event protocol.Event) bool {
if event.GetType() != protocol.EventTypeMessage {
return false
}
data := event.GetData()
rawMessage, ok := data["raw_message"].(string)
if !ok {
return false
}
// 简单的字符串包含检查
return len(rawMessage) >= len(keyword) &&
(len(rawMessage) == len(keyword) ||
(rawMessage[:len(keyword)] == keyword ||
rawMessage[len(rawMessage)-len(keyword):] == keyword ||
contains(rawMessage, keyword)))
},
priority: 100,
}
}
// contains 检查字符串是否包含子串(简单实现)
func contains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// Priority 设置优先级
func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder {
b.priority = priority
return b
}
// Handle 注册处理函数(在 init 中调用)
func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
globalHandlerMu.Lock()
defer globalHandlerMu.Unlock()
b.handleFunc = handleFunc
globalHandlerRegistry = append(globalHandlerRegistry, b)
}
// generateHandlerName 生成处理器名称
var handlerCounter int64
func generateHandlerName() string {
counter := atomic.AddInt64(&handlerCounter, 1)
return fmt.Sprintf("handler_%d", counter)
}
// LoadAllHandlers 加载所有已注册的处理器(注入依赖)
func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
globalHandlerMu.RLock()
defer globalHandlerMu.RUnlock()
handlers := make([]protocol.EventHandler, 0, len(globalHandlerRegistry))
for i, builder := range globalHandlerRegistry {
if builder.handleFunc == nil {
continue
}
// 生成唯一的插件名称
pluginName := fmt.Sprintf("handler_%d", i+1)
// 创建包装的处理函数,注入依赖
handleFunc := func(ctx context.Context, event protocol.Event) error {
return builder.handleFunc(ctx, event, botManager, logger)
}
handler := &simplePlugin{
name: pluginName,
description: "Handler registered via OnXXX().Handle()",
priority: builder.priority,
matchFunc: builder.matchFunc,
handleFunc: handleFunc,
}
handlers = append(handlers, handler)
}
return handlers
}