feat: add rate limiting and improve event handling
- Introduced rate limiting configuration in config.toml with options for enabling, requests per second (RPS), and burst capacity. - Enhanced event handling in the OneBot11 adapter to ignore messages sent by the bot itself. - Updated the dispatcher to register rate limit middleware based on configuration settings. - Refactored WebSocket message handling to support flexible JSON parsing and improved event type detection. - Removed deprecated echo plugin and associated tests to streamline the codebase.
This commit is contained in:
@@ -206,6 +206,11 @@ func (a *Adapter) ParseMessage(raw []byte) (protocol.Event, error) {
|
||||
return nil, fmt.Errorf("failed to unmarshal raw event: %w", err)
|
||||
}
|
||||
|
||||
// 忽略机器人自己发送的消息
|
||||
if rawEvent.PostType == "message_sent" {
|
||||
return nil, fmt.Errorf("ignoring message_sent event")
|
||||
}
|
||||
|
||||
return a.convertToEvent(&rawEvent)
|
||||
}
|
||||
|
||||
@@ -419,17 +424,24 @@ func (a *Adapter) handleWebSocketMessages() {
|
||||
zap.Int("size", len(message)),
|
||||
zap.String("preview", string(message[:min(len(message), 200)])))
|
||||
|
||||
// 尝试解析为响应
|
||||
var resp OB11Response
|
||||
if err := sonic.Unmarshal(message, &resp); err == nil {
|
||||
// 如果有echo字段,说明是API响应
|
||||
if resp.Echo != "" {
|
||||
a.logger.Debug("Received API response",
|
||||
zap.String("echo", resp.Echo),
|
||||
zap.String("status", resp.Status),
|
||||
zap.Int("retcode", resp.RetCode))
|
||||
a.wsWaiter.Notify(&resp)
|
||||
continue
|
||||
// 尝试解析为响应(先检查是否有 echo 字段)
|
||||
var tempMap map[string]interface{}
|
||||
if err := sonic.Unmarshal(message, &tempMap); err == nil {
|
||||
if echo, ok := tempMap["echo"].(string); ok && echo != "" {
|
||||
// 有 echo 字段,说明是API响应
|
||||
var resp OB11Response
|
||||
if err := sonic.Unmarshal(message, &resp); err == nil {
|
||||
a.logger.Debug("Received API response",
|
||||
zap.String("echo", resp.Echo),
|
||||
zap.String("status", resp.Status),
|
||||
zap.Int("retcode", resp.RetCode))
|
||||
a.wsWaiter.Notify(&resp)
|
||||
continue
|
||||
} else {
|
||||
a.logger.Warn("Failed to parse API response",
|
||||
zap.Error(err),
|
||||
zap.String("echo", echo))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,9 +453,14 @@ func (a *Adapter) handleWebSocketMessages() {
|
||||
a.logger.Info("Parsing OneBot event...")
|
||||
event, err := a.ParseMessage(message)
|
||||
if err != nil {
|
||||
a.logger.Error("Failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.ByteString("raw_message", message))
|
||||
// 如果是忽略的事件(如 message_sent),只记录 debug 日志
|
||||
if err.Error() == "ignoring message_sent event" {
|
||||
a.logger.Debug("Ignoring message_sent event")
|
||||
} else {
|
||||
a.logger.Error("Failed to parse event",
|
||||
zap.Error(err),
|
||||
zap.ByteString("raw_message", message))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ type Config struct {
|
||||
Log LogConfig `toml:"log"`
|
||||
Protocol ProtocolConfig `toml:"protocol"`
|
||||
Bots []BotConfig `toml:"bots"`
|
||||
Engine EngineConfig `toml:"engine"`
|
||||
}
|
||||
|
||||
// ServerConfig 服务器配置
|
||||
@@ -39,6 +40,18 @@ type ProtocolConfig struct {
|
||||
Options map[string]string `toml:"options"`
|
||||
}
|
||||
|
||||
// EngineConfig 引擎配置
|
||||
type EngineConfig struct {
|
||||
RateLimit RateLimitConfig `toml:"rate_limit"`
|
||||
}
|
||||
|
||||
// RateLimitConfig 限流配置
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `toml:"enabled"` // 是否启用限流
|
||||
RPS int `toml:"rps"` // 每秒请求数
|
||||
Burst int `toml:"burst"` // 突发容量
|
||||
}
|
||||
|
||||
// BotConfig Bot 配置
|
||||
type BotConfig struct {
|
||||
ID string `toml:"id"`
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package di
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cellbot/internal/adapter/milky"
|
||||
"cellbot/internal/adapter/onebot11"
|
||||
"cellbot/internal/config"
|
||||
"cellbot/internal/engine"
|
||||
"cellbot/internal/plugins/echo"
|
||||
_ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数
|
||||
"cellbot/internal/protocol"
|
||||
"cellbot/pkg/net"
|
||||
"context"
|
||||
|
||||
"go.uber.org/fx"
|
||||
"go.uber.org/zap"
|
||||
@@ -41,8 +42,27 @@ func ProvideEventBus(logger *zap.Logger) *engine.EventBus {
|
||||
return engine.NewEventBus(logger, 10000)
|
||||
}
|
||||
|
||||
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger) *engine.Dispatcher {
|
||||
return engine.NewDispatcher(eventBus, logger)
|
||||
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config) *engine.Dispatcher {
|
||||
dispatcher := engine.NewDispatcher(eventBus, logger)
|
||||
|
||||
// 注册限流中间件
|
||||
if cfg.Engine.RateLimit.Enabled {
|
||||
rateLimitMiddleware := engine.NewRateLimitMiddleware(
|
||||
logger,
|
||||
cfg.Engine.RateLimit.RPS,
|
||||
cfg.Engine.RateLimit.Burst,
|
||||
)
|
||||
dispatcher.RegisterMiddleware(rateLimitMiddleware)
|
||||
logger.Info("Rate limit middleware registered",
|
||||
zap.Int("rps", cfg.Engine.RateLimit.RPS),
|
||||
zap.Int("burst", cfg.Engine.RateLimit.Burst))
|
||||
}
|
||||
|
||||
return dispatcher
|
||||
}
|
||||
|
||||
func ProvidePluginRegistry(dispatcher *engine.Dispatcher, logger *zap.Logger) *engine.PluginRegistry {
|
||||
return engine.NewPluginRegistry(dispatcher, logger)
|
||||
}
|
||||
|
||||
func ProvideBotManager(logger *zap.Logger) *protocol.BotManager {
|
||||
@@ -129,10 +149,26 @@ func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net.
|
||||
return nil
|
||||
}
|
||||
|
||||
func ProvideEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager, dispatcher *engine.Dispatcher) {
|
||||
echoPlugin := echo.NewEchoPlugin(logger, botManager)
|
||||
dispatcher.RegisterHandler(echoPlugin)
|
||||
logger.Info("Echo plugin registered")
|
||||
// LoadPlugins 加载所有插件
|
||||
func LoadPlugins(logger *zap.Logger, botManager *protocol.BotManager, registry *engine.PluginRegistry) {
|
||||
// 从全局注册表加载所有插件(通过 RegisterPlugin 注册的)
|
||||
plugins := engine.LoadAllPlugins(botManager, logger)
|
||||
for _, plugin := range plugins {
|
||||
registry.Register(plugin)
|
||||
}
|
||||
|
||||
// 加载所有通过 OnXXX().Handle() 注册的处理器(注入依赖)
|
||||
handlers := engine.LoadAllHandlers(botManager, logger)
|
||||
for _, handler := range handlers {
|
||||
registry.Register(handler)
|
||||
}
|
||||
|
||||
totalCount := len(plugins) + len(handlers)
|
||||
logger.Info("All plugins loaded",
|
||||
zap.Int("plugin_count", len(plugins)),
|
||||
zap.Int("handler_count", len(handlers)),
|
||||
zap.Int("total_count", totalCount),
|
||||
zap.Strings("plugins", engine.GetRegisteredPlugins()))
|
||||
}
|
||||
|
||||
var Providers = fx.Options(
|
||||
@@ -142,11 +178,12 @@ var Providers = fx.Options(
|
||||
ProvideLogger,
|
||||
ProvideEventBus,
|
||||
ProvideDispatcher,
|
||||
ProvidePluginRegistry,
|
||||
ProvideBotManager,
|
||||
ProvideWebSocketManager,
|
||||
ProvideServer,
|
||||
),
|
||||
fx.Invoke(ProvideMilkyBots),
|
||||
fx.Invoke(ProvideOneBot11Bots),
|
||||
fx.Invoke(ProvideEchoPlugin),
|
||||
fx.Invoke(LoadPlugins),
|
||||
)
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cellbot/internal/protocol"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestEventBus_PublishSubscribe(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
eventBus := NewEventBus(logger, 100)
|
||||
eventBus.Start()
|
||||
defer eventBus.Stop()
|
||||
|
||||
// 创建测试事件
|
||||
event := &protocol.BaseEvent{
|
||||
Type: protocol.EventTypeMessage,
|
||||
DetailType: "private",
|
||||
Timestamp: time.Now().Unix(),
|
||||
SelfID: "test_bot",
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// 订阅事件
|
||||
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
|
||||
|
||||
// 发布事件
|
||||
eventBus.Publish(event)
|
||||
|
||||
// 接收事件
|
||||
select {
|
||||
case receivedEvent := <-eventChan:
|
||||
if receivedEvent.GetType() != protocol.EventTypeMessage {
|
||||
t.Errorf("Expected event type '%s', got '%s'", protocol.EventTypeMessage, receivedEvent.GetType())
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timeout waiting for event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBus_Filter(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
eventBus := NewEventBus(logger, 100)
|
||||
eventBus.Start()
|
||||
defer eventBus.Stop()
|
||||
|
||||
// 创建测试事件
|
||||
event1 := &protocol.BaseEvent{
|
||||
Type: protocol.EventTypeMessage,
|
||||
DetailType: "private",
|
||||
Timestamp: time.Now().Unix(),
|
||||
SelfID: "test_bot",
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
event2 := &protocol.BaseEvent{
|
||||
Type: protocol.EventTypeMessage,
|
||||
DetailType: "group",
|
||||
Timestamp: time.Now().Unix(),
|
||||
SelfID: "test_bot",
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// 订阅并过滤:只接收private消息
|
||||
filter := func(e protocol.Event) bool {
|
||||
return e.GetDetailType() == "private"
|
||||
}
|
||||
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, filter)
|
||||
|
||||
// 发布两个事件
|
||||
eventBus.Publish(event1)
|
||||
eventBus.Publish(event2)
|
||||
|
||||
// 应该只收到private消息
|
||||
select {
|
||||
case receivedEvent := <-eventChan:
|
||||
if receivedEvent.GetDetailType() != "private" {
|
||||
t.Errorf("Expected detail type 'private', got '%s'", receivedEvent.GetDetailType())
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Timeout waiting for event")
|
||||
}
|
||||
|
||||
// 不应该再收到第二个事件
|
||||
select {
|
||||
case <-eventChan:
|
||||
t.Error("Should not receive group message")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
// 正确,不应该收到
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBus_Concurrent(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
eventBus := NewEventBus(logger, 10000)
|
||||
eventBus.Start()
|
||||
defer eventBus.Stop()
|
||||
|
||||
numSubscribers := 10
|
||||
numPublishers := 10
|
||||
numEvents := 100
|
||||
|
||||
// 创建多个订阅者
|
||||
subscribers := make([]chan protocol.Event, numSubscribers)
|
||||
for i := 0; i < numSubscribers; i++ {
|
||||
subscribers[i] = eventBus.Subscribe(protocol.EventTypeMessage, nil)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numPublishers)
|
||||
|
||||
// 多个发布者并发发布事件
|
||||
for i := 0; i < numPublishers; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numEvents; j++ {
|
||||
event := &protocol.BaseEvent{
|
||||
Type: protocol.EventTypeMessage,
|
||||
DetailType: "private",
|
||||
Timestamp: time.Now().Unix(),
|
||||
SelfID: "test_bot",
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
eventBus.Publish(event)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 验证每个订阅者都收到了事件
|
||||
for _, ch := range subscribers {
|
||||
count := 0
|
||||
for count < numPublishers*numEvents {
|
||||
select {
|
||||
case <-ch:
|
||||
count++
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Timeout waiting for events, received %d, expected %d", count, numPublishers*numEvents)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventBus_Benchmark(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
eventBus := NewEventBus(logger, 100000)
|
||||
eventBus.Start()
|
||||
defer eventBus.Stop()
|
||||
|
||||
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
event := &protocol.BaseEvent{
|
||||
Type: protocol.EventTypeMessage,
|
||||
DetailType: "private",
|
||||
Timestamp: time.Now().Unix(),
|
||||
SelfID: "test_bot",
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
eventBus.Publish(event)
|
||||
}
|
||||
})
|
||||
|
||||
// 消耗channel避免阻塞
|
||||
go func() {
|
||||
for range eventChan {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func TestDispatcher_HandlerPriority(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
eventBus := NewEventBus(logger, 100)
|
||||
dispatcher := NewDispatcher(eventBus, logger)
|
||||
|
||||
// 创建测试处理器
|
||||
handlers := make([]*TestHandler, 3)
|
||||
priorities := []int{3, 1, 2}
|
||||
|
||||
for i, priority := range priorities {
|
||||
handlers[i] = &TestHandler{
|
||||
priority: priority,
|
||||
matched: false,
|
||||
executed: false,
|
||||
}
|
||||
dispatcher.RegisterHandler(handlers[i])
|
||||
}
|
||||
|
||||
// 验证处理器按优先级排序
|
||||
if dispatcher.GetHandlerCount() != 3 {
|
||||
t.Errorf("Expected 3 handlers, got %d", dispatcher.GetHandlerCount())
|
||||
}
|
||||
|
||||
event := &protocol.BaseEvent{
|
||||
Type: protocol.EventTypeMessage,
|
||||
DetailType: "private",
|
||||
Timestamp: time.Now().Unix(),
|
||||
SelfID: "test_bot",
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// 分发事件
|
||||
dispatcher.handleEvent(ctx, event)
|
||||
|
||||
// 验证处理器是否按优先级执行
|
||||
executedOrder := make([]int, 0)
|
||||
for _, handler := range handlers {
|
||||
if handler.executed {
|
||||
executedOrder = append(executedOrder, handler.priority)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否按升序执行
|
||||
for i := 1; i < len(executedOrder); i++ {
|
||||
if executedOrder[i] < executedOrder[i-1] {
|
||||
t.Errorf("Handlers not executed in priority order: %v", executedOrder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandler 测试处理器
|
||||
type TestHandler struct {
|
||||
priority int
|
||||
matched bool
|
||||
executed bool
|
||||
}
|
||||
|
||||
func (h *TestHandler) Handle(ctx context.Context, event protocol.Event) error {
|
||||
h.executed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *TestHandler) Priority() int {
|
||||
return h.priority
|
||||
}
|
||||
|
||||
func (h *TestHandler) Match(event protocol.Event) bool {
|
||||
h.matched = true
|
||||
return true
|
||||
}
|
||||
421
internal/engine/plugin.go
Normal file
421
internal/engine/plugin.go
Normal file
@@ -0,0 +1,421 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"cellbot/internal/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PluginFactory 插件工厂函数类型
|
||||
// 用于在运行时创建插件实例,支持依赖注入
|
||||
type PluginFactory func(botManager *protocol.BotManager, logger *zap.Logger) protocol.EventHandler
|
||||
|
||||
// globalPluginRegistry 全局插件注册表
|
||||
var (
|
||||
globalPluginRegistry = make(map[string]PluginFactory)
|
||||
globalPluginMu sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterPlugin 注册插件(供插件在 init 函数中调用)
|
||||
func RegisterPlugin(name string, factory PluginFactory) {
|
||||
globalPluginMu.Lock()
|
||||
defer globalPluginMu.Unlock()
|
||||
|
||||
if _, exists := globalPluginRegistry[name]; exists {
|
||||
panic("plugin already registered: " + name)
|
||||
}
|
||||
|
||||
globalPluginRegistry[name] = factory
|
||||
}
|
||||
|
||||
// GetRegisteredPlugins 获取所有已注册的插件名称
|
||||
func GetRegisteredPlugins() []string {
|
||||
globalPluginMu.RLock()
|
||||
defer globalPluginMu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(globalPluginRegistry))
|
||||
for name := range globalPluginRegistry {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// LoadPlugin 加载指定名称的插件
|
||||
func LoadPlugin(name string, botManager *protocol.BotManager, logger *zap.Logger) (protocol.EventHandler, error) {
|
||||
globalPluginMu.RLock()
|
||||
factory, exists := globalPluginRegistry[name]
|
||||
globalPluginMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, nil // 插件不存在,返回 nil 而不是错误
|
||||
}
|
||||
|
||||
return factory(botManager, logger), nil
|
||||
}
|
||||
|
||||
// LoadAllPlugins 加载所有已注册的插件
|
||||
func LoadAllPlugins(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
|
||||
globalPluginMu.RLock()
|
||||
defer globalPluginMu.RUnlock()
|
||||
|
||||
plugins := make([]protocol.EventHandler, 0, len(globalPluginRegistry))
|
||||
for name, factory := range globalPluginRegistry {
|
||||
plugin := factory(botManager, logger.Named("plugin."+name))
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
return plugins
|
||||
}
|
||||
|
||||
// PluginRegistry 插件注册表
|
||||
type PluginRegistry struct {
|
||||
plugins []protocol.EventHandler
|
||||
dispatcher *Dispatcher
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPluginRegistry 创建插件注册表
|
||||
func NewPluginRegistry(dispatcher *Dispatcher, logger *zap.Logger) *PluginRegistry {
|
||||
return &PluginRegistry{
|
||||
plugins: make([]protocol.EventHandler, 0),
|
||||
dispatcher: dispatcher,
|
||||
logger: logger.Named("plugin-registry"),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册插件
|
||||
func (r *PluginRegistry) Register(plugin protocol.EventHandler) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.plugins = append(r.plugins, plugin)
|
||||
r.dispatcher.RegisterHandler(plugin)
|
||||
r.logger.Info("Plugin registered",
|
||||
zap.String("name", plugin.Name()),
|
||||
zap.String("description", plugin.Description()))
|
||||
}
|
||||
|
||||
// GetPlugins 获取所有插件
|
||||
func (r *PluginRegistry) GetPlugins() []protocol.EventHandler {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.plugins
|
||||
}
|
||||
|
||||
// PluginBuilder 插件构建器
|
||||
type PluginBuilder struct {
|
||||
name string
|
||||
description string
|
||||
priority int
|
||||
matchFunc func(protocol.Event) bool
|
||||
handleFunc func(context.Context, protocol.Event) error
|
||||
}
|
||||
|
||||
// NewPlugin 创建插件构建器
|
||||
func NewPlugin(name string) *PluginBuilder {
|
||||
return &PluginBuilder{
|
||||
name: name,
|
||||
priority: 100, // 默认优先级
|
||||
}
|
||||
}
|
||||
|
||||
// Description 设置插件描述
|
||||
func (b *PluginBuilder) Description(desc string) *PluginBuilder {
|
||||
b.description = desc
|
||||
return b
|
||||
}
|
||||
|
||||
// Priority 设置优先级
|
||||
func (b *PluginBuilder) Priority(priority int) *PluginBuilder {
|
||||
b.priority = priority
|
||||
return b
|
||||
}
|
||||
|
||||
// Match 设置匹配函数
|
||||
func (b *PluginBuilder) Match(fn func(protocol.Event) bool) *PluginBuilder {
|
||||
b.matchFunc = fn
|
||||
return b
|
||||
}
|
||||
|
||||
// Handle 设置处理函数
|
||||
func (b *PluginBuilder) Handle(fn func(context.Context, protocol.Event) error) *PluginBuilder {
|
||||
b.handleFunc = fn
|
||||
return b
|
||||
}
|
||||
|
||||
// Build 构建插件
|
||||
func (b *PluginBuilder) Build() protocol.EventHandler {
|
||||
return &simplePlugin{
|
||||
name: b.name,
|
||||
description: b.description,
|
||||
priority: b.priority,
|
||||
matchFunc: b.matchFunc,
|
||||
handleFunc: b.handleFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// simplePlugin 简单插件实现
|
||||
type simplePlugin struct {
|
||||
name string
|
||||
description string
|
||||
priority int
|
||||
matchFunc func(protocol.Event) bool
|
||||
handleFunc func(context.Context, protocol.Event) error
|
||||
}
|
||||
|
||||
func (p *simplePlugin) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *simplePlugin) Description() string {
|
||||
return p.description
|
||||
}
|
||||
|
||||
func (p *simplePlugin) Priority() int {
|
||||
return p.priority
|
||||
}
|
||||
|
||||
func (p *simplePlugin) Match(event protocol.Event) bool {
|
||||
if p.matchFunc == nil {
|
||||
return true
|
||||
}
|
||||
return p.matchFunc(event)
|
||||
}
|
||||
|
||||
func (p *simplePlugin) Handle(ctx context.Context, event protocol.Event) error {
|
||||
if p.handleFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return p.handleFunc(ctx, event)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ZeroBot 风格的全局注册 API
|
||||
// ============================================================================
|
||||
|
||||
// globalHandlerRegistry 全局处理器注册表
|
||||
var (
|
||||
globalHandlerRegistry = make([]*HandlerBuilder, 0)
|
||||
globalHandlerMu sync.RWMutex
|
||||
)
|
||||
|
||||
// HandlerFunc 处理函数类型(支持依赖注入)
|
||||
type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error
|
||||
|
||||
// HandlerBuilder 处理器构建器(类似 ZeroBot 的 API)
|
||||
type HandlerBuilder struct {
|
||||
matchFunc func(protocol.Event) bool
|
||||
priority int
|
||||
handleFunc HandlerFunc
|
||||
}
|
||||
|
||||
// OnPrivateMessage 匹配私聊消息
|
||||
func OnPrivateMessage() *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private"
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnGroupMessage 匹配群消息
|
||||
func OnGroupMessage() *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "group"
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnMessage 匹配所有消息
|
||||
func OnMessage() *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeMessage
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnNotice 匹配通知事件
|
||||
func OnNotice() *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeNotice
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest 匹配请求事件
|
||||
func OnRequest() *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeRequest
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// On 自定义匹配器
|
||||
func On(matchFunc func(protocol.Event) bool) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: matchFunc,
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnCommand 匹配命令(以指定前缀开头的消息)
|
||||
func OnCommand(prefix string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if event.GetType() != protocol.EventTypeMessage {
|
||||
return false
|
||||
}
|
||||
data := event.GetData()
|
||||
rawMessage, ok := data["raw_message"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// 检查是否以命令前缀开头
|
||||
if len(rawMessage) > 0 && len(prefix) > 0 {
|
||||
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
|
||||
}
|
||||
return false
|
||||
},
|
||||
priority: 50, // 命令通常优先级较高
|
||||
}
|
||||
}
|
||||
|
||||
// OnPrefix 匹配以指定前缀开头的消息
|
||||
func OnPrefix(prefix string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if event.GetType() != protocol.EventTypeMessage {
|
||||
return false
|
||||
}
|
||||
data := event.GetData()
|
||||
rawMessage, ok := data["raw_message"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnSuffix 匹配以指定后缀结尾的消息
|
||||
func OnSuffix(suffix string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if event.GetType() != protocol.EventTypeMessage {
|
||||
return false
|
||||
}
|
||||
data := event.GetData()
|
||||
rawMessage, ok := data["raw_message"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return len(rawMessage) >= len(suffix) && rawMessage[len(rawMessage)-len(suffix):] == suffix
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnKeyword 匹配包含指定关键词的消息
|
||||
func OnKeyword(keyword string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if event.GetType() != protocol.EventTypeMessage {
|
||||
return false
|
||||
}
|
||||
data := event.GetData()
|
||||
rawMessage, ok := data["raw_message"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// 简单的字符串包含检查
|
||||
return len(rawMessage) >= len(keyword) &&
|
||||
(len(rawMessage) == len(keyword) ||
|
||||
(rawMessage[:len(keyword)] == keyword ||
|
||||
rawMessage[len(rawMessage)-len(keyword):] == keyword ||
|
||||
contains(rawMessage, keyword)))
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// contains 检查字符串是否包含子串(简单实现)
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Priority 设置优先级
|
||||
func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder {
|
||||
b.priority = priority
|
||||
return b
|
||||
}
|
||||
|
||||
// Handle 注册处理函数(在 init 中调用)
|
||||
func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
|
||||
globalHandlerMu.Lock()
|
||||
defer globalHandlerMu.Unlock()
|
||||
|
||||
b.handleFunc = handleFunc
|
||||
globalHandlerRegistry = append(globalHandlerRegistry, b)
|
||||
}
|
||||
|
||||
// generateHandlerName 生成处理器名称
|
||||
var handlerCounter int64
|
||||
|
||||
func generateHandlerName() string {
|
||||
counter := atomic.AddInt64(&handlerCounter, 1)
|
||||
return fmt.Sprintf("handler_%d", counter)
|
||||
}
|
||||
|
||||
// LoadAllHandlers 加载所有已注册的处理器(注入依赖)
|
||||
func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
|
||||
globalHandlerMu.RLock()
|
||||
defer globalHandlerMu.RUnlock()
|
||||
|
||||
handlers := make([]protocol.EventHandler, 0, len(globalHandlerRegistry))
|
||||
for i, builder := range globalHandlerRegistry {
|
||||
if builder.handleFunc == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 生成唯一的插件名称
|
||||
pluginName := fmt.Sprintf("handler_%d", i+1)
|
||||
|
||||
// 创建包装的处理函数,注入依赖
|
||||
handleFunc := func(ctx context.Context, event protocol.Event) error {
|
||||
return builder.handleFunc(ctx, event, botManager, logger)
|
||||
}
|
||||
|
||||
handler := &simplePlugin{
|
||||
name: pluginName,
|
||||
description: "Handler registered via OnXXX().Handle()",
|
||||
priority: builder.priority,
|
||||
matchFunc: builder.matchFunc,
|
||||
handleFunc: handleFunc,
|
||||
}
|
||||
|
||||
handlers = append(handlers, handler)
|
||||
}
|
||||
|
||||
return handlers
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cellbot/internal/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// EchoPlugin 回声插件
|
||||
type EchoPlugin struct {
|
||||
logger *zap.Logger
|
||||
botManager *protocol.BotManager
|
||||
}
|
||||
|
||||
// NewEchoPlugin 创建回声插件
|
||||
func NewEchoPlugin(logger *zap.Logger, botManager *protocol.BotManager) *EchoPlugin {
|
||||
return &EchoPlugin{
|
||||
logger: logger.Named("echo-plugin"),
|
||||
botManager: botManager,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 处理事件
|
||||
func (p *EchoPlugin) Handle(ctx context.Context, event protocol.Event) error {
|
||||
// 获取事件数据
|
||||
data := event.GetData()
|
||||
|
||||
// 获取消息内容
|
||||
message, ok := data["message"]
|
||||
if !ok {
|
||||
p.logger.Debug("No message field in event")
|
||||
return nil
|
||||
}
|
||||
|
||||
rawMessage, ok := data["raw_message"].(string)
|
||||
if !ok {
|
||||
p.logger.Debug("No raw_message field in event")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取用户ID
|
||||
userID, ok := data["user_id"]
|
||||
if !ok {
|
||||
p.logger.Debug("No user_id field in event")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.logger.Info("Received private message",
|
||||
zap.Any("user_id", userID),
|
||||
zap.String("message", rawMessage))
|
||||
|
||||
// 获取 self_id 来确定是哪个 bot
|
||||
selfID := event.GetSelfID()
|
||||
|
||||
// 获取对应的 bot 实例
|
||||
bot, ok := p.botManager.Get(selfID)
|
||||
if !ok {
|
||||
// 如果通过 selfID 找不到,尝试获取第一个 bot
|
||||
bots := p.botManager.GetAll()
|
||||
if len(bots) == 0 {
|
||||
p.logger.Error("No bot instance available")
|
||||
return nil
|
||||
}
|
||||
bot = bots[0]
|
||||
p.logger.Debug("Using first available bot",
|
||||
zap.String("bot_id", bot.GetID()))
|
||||
}
|
||||
|
||||
// 构建回复动作
|
||||
action := &protocol.BaseAction{
|
||||
Type: protocol.ActionTypeSendPrivateMessage,
|
||||
Params: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"message": message, // 原封不动返回消息
|
||||
},
|
||||
}
|
||||
|
||||
p.logger.Info("Sending echo reply",
|
||||
zap.Any("user_id", userID),
|
||||
zap.String("reply", rawMessage))
|
||||
|
||||
// 发送消息
|
||||
result, err := bot.SendAction(ctx, action)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to send echo reply",
|
||||
zap.Error(err),
|
||||
zap.Any("user_id", userID))
|
||||
return err
|
||||
}
|
||||
|
||||
p.logger.Info("Echo reply sent successfully",
|
||||
zap.Any("user_id", userID),
|
||||
zap.Any("result", result))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Priority 获取处理器优先级
|
||||
func (p *EchoPlugin) Priority() int {
|
||||
return 100 // 中等优先级
|
||||
}
|
||||
|
||||
// Match 判断是否匹配事件
|
||||
func (p *EchoPlugin) Match(event protocol.Event) bool {
|
||||
// 只处理私聊消息
|
||||
eventType := event.GetType()
|
||||
detailType := event.GetDetailType()
|
||||
|
||||
p.logger.Debug("Echo plugin matching event",
|
||||
zap.String("event_type", string(eventType)),
|
||||
zap.String("detail_type", detailType))
|
||||
|
||||
if eventType != protocol.EventTypeMessage {
|
||||
p.logger.Debug("Event type mismatch", zap.String("expected", string(protocol.EventTypeMessage)))
|
||||
return false
|
||||
}
|
||||
|
||||
if detailType != "private" {
|
||||
p.logger.Debug("Detail type mismatch", zap.String("expected", "private"), zap.String("got", detailType))
|
||||
return false
|
||||
}
|
||||
|
||||
p.logger.Info("Echo plugin matched event!")
|
||||
return true
|
||||
}
|
||||
|
||||
// Name 获取插件名称
|
||||
func (p *EchoPlugin) Name() string {
|
||||
return "Echo"
|
||||
}
|
||||
|
||||
// Description 获取插件描述
|
||||
func (p *EchoPlugin) Description() string {
|
||||
return "回声插件:将私聊消息原封不动返回"
|
||||
}
|
||||
64
internal/plugins/echo/echo_new.go
Normal file
64
internal/plugins/echo/echo_new.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package echo
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cellbot/internal/engine"
|
||||
"cellbot/internal/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 在 init 函数中注册多个处理函数(类似 ZeroBot 风格)
|
||||
|
||||
// 处理私聊消息
|
||||
engine.OnPrivateMessage().
|
||||
Handle(func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error {
|
||||
// 获取消息内容
|
||||
data := event.GetData()
|
||||
message, ok := data["message"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, ok := data["user_id"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 获取 bot 实例
|
||||
selfID := event.GetSelfID()
|
||||
bot, ok := botManager.Get(selfID)
|
||||
if !ok {
|
||||
bots := botManager.GetAll()
|
||||
if len(bots) == 0 {
|
||||
logger.Error("No bot instance available")
|
||||
return nil
|
||||
}
|
||||
bot = bots[0]
|
||||
}
|
||||
|
||||
// 发送回复
|
||||
action := &protocol.BaseAction{
|
||||
Type: protocol.ActionTypeSendPrivateMessage,
|
||||
Params: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := bot.SendAction(ctx, action)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send reply", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("Echo reply sent", zap.Any("user_id", userID))
|
||||
return nil
|
||||
})
|
||||
|
||||
// 可以继续注册更多处理函数
|
||||
// engine.OnGroupMessage().Handle(...)
|
||||
// engine.OnCommand("help").Handle(...)
|
||||
}
|
||||
@@ -68,6 +68,10 @@ type EventHandler interface {
|
||||
Priority() int
|
||||
// Match 判断是否匹配事件
|
||||
Match(event Event) bool
|
||||
// Name 获取处理器名称
|
||||
Name() string
|
||||
// Description 获取处理器描述
|
||||
Description() string
|
||||
}
|
||||
|
||||
// Middleware 中间件接口
|
||||
|
||||
Reference in New Issue
Block a user