feat: 初始化多机器人服务端项目框架

基于Go语言构建多机器人服务端框架,包含配置管理、事件总线、依赖注入等核心模块
添加项目基础结构、README、gitignore和初始代码实现
This commit is contained in:
2026-01-04 21:19:17 +08:00
commit ac0dfb64c9
22 changed files with 2385 additions and 0 deletions

143
internal/config/config.go Normal file
View File

@@ -0,0 +1,143 @@
package config
import (
"fmt"
"sync"
"github.com/BurntSushi/toml"
"github.com/fsnotify/fsnotify"
"go.uber.org/zap"
)
// Config 应用配置结构
type Config struct {
Server ServerConfig `toml:"server"`
Log LogConfig `toml:"log"`
Protocol ProtocolConfig `toml:"protocol"`
}
// ServerConfig 服务器配置
type ServerConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
}
// LogConfig 日志配置
type LogConfig struct {
Level string `toml:"level"`
Output string `toml:"output"`
MaxSize int `toml:"max_size"`
MaxBackups int `toml:"max_backups"`
MaxAge int `toml:"max_age"`
}
// ProtocolConfig 协议配置
type ProtocolConfig struct {
Name string `toml:"name"`
Version string `toml:"version"`
Options map[string]string `toml:"options"`
}
// ConfigManager 配置管理器
type ConfigManager struct {
configPath string
config *Config
logger *zap.Logger
mu sync.RWMutex
callbacks []func(*Config)
}
// NewConfigManager 创建配置管理器
func NewConfigManager(configPath string, logger *zap.Logger) *ConfigManager {
return &ConfigManager{
configPath: configPath,
logger: logger,
callbacks: make([]func(*Config), 0),
}
}
// Load 加载配置文件
func (cm *ConfigManager) Load() error {
cm.mu.Lock()
defer cm.mu.Unlock()
var cfg Config
if _, err := toml.DecodeFile(cm.configPath, &cfg); err != nil {
return fmt.Errorf("failed to decode config: %w", err)
}
cm.config = &cfg
cm.logger.Info("Config loaded successfully",
zap.String("path", cm.configPath),
zap.String("server", fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)),
)
// 触发回调
for _, cb := range cm.callbacks {
cb(cm.config)
}
return nil
}
// Get 获取当前配置
func (cm *ConfigManager) Get() *Config {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.config
}
// Reload 重新加载配置
func (cm *ConfigManager) Reload() error {
return cm.Load()
}
// RegisterCallback 注册配置变更回调
func (cm *ConfigManager) RegisterCallback(callback func(*Config)) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.callbacks = append(cm.callbacks, callback)
}
// Watch 监听配置文件变化
func (cm *ConfigManager) Watch() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to create watcher: %w", err)
}
if err := watcher.Add(cm.configPath); err != nil {
return fmt.Errorf("failed to watch config file: %w", err)
}
go func() {
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write == fsnotify.Write {
cm.logger.Info("Config file changed, reloading...",
zap.String("file", event.Name))
if err := cm.Reload(); err != nil {
cm.logger.Error("Failed to reload config",
zap.Error(err))
}
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
cm.logger.Error("Watcher error", zap.Error(err))
}
}
}()
return nil
}
// Close 关闭配置管理器
func (cm *ConfigManager) Close() error {
return nil
}

View File

@@ -0,0 +1,113 @@
package config
import (
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
)
func TestConfigManager_Load(t *testing.T) {
// 创建临时配置文件
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "test_config.toml")
configContent := `
[server]
host = "127.0.0.1"
port = 8080
[log]
level = "debug"
output = "stdout"
max_size = 100
max_backups = 3
max_age = 7
[protocol]
name = "test"
version = "1.0"
[protocol.options]
key = "value"
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to create config file: %v", err)
}
logger := zap.NewNop()
cm := NewConfigManager(configPath, logger)
err = cm.Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg := cm.Get()
if cfg == nil {
t.Fatal("Config is nil")
}
if cfg.Server.Host != "127.0.0.1" {
t.Errorf("Expected host '127.0.0.1', got '%s'", cfg.Server.Host)
}
if cfg.Server.Port != 8080 {
t.Errorf("Expected port 8080, got %d", cfg.Server.Port)
}
if cfg.Log.Level != "debug" {
t.Errorf("Expected log level 'debug', got '%s'", cfg.Log.Level)
}
if cfg.Protocol.Name != "test" {
t.Errorf("Expected protocol name 'test', got '%s'", cfg.Protocol.Name)
}
}
func TestInitLogger(t *testing.T) {
tests := []struct {
name string
cfg *LogConfig
wantErr bool
}{
{
name: "stdout logger",
cfg: &LogConfig{
Level: "info",
Output: "stdout",
},
wantErr: false,
},
{
name: "stderr logger",
cfg: &LogConfig{
Level: "error",
Output: "stderr",
},
wantErr: false,
},
{
name: "file logger",
cfg: &LogConfig{
Level: "debug",
Output: filepath.Join(t.TempDir(), "test.log"),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger, err := InitLogger(tt.cfg)
if (err != nil) != tt.wantErr {
t.Errorf("InitLogger() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && logger == nil {
t.Error("Expected non-nil logger")
}
})
}
}

69
internal/config/logger.go Normal file
View File

@@ -0,0 +1,69 @@
package config
import (
"fmt"
"os"
"path/filepath"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// InitLogger 初始化日志
func InitLogger(cfg *LogConfig) (*zap.Logger, error) {
encoderConfig := zapcore.EncoderConfig{
TimeKey: "time",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.LowercaseLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.SecondsDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
var writer zapcore.WriteSyncer
switch cfg.Output {
case "stdout":
writer = zapcore.AddSync(os.Stdout)
case "stderr":
writer = zapcore.AddSync(os.Stderr)
default:
// 创建日志目录
if err := os.MkdirAll(filepath.Dir(cfg.Output), 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
file, err := os.OpenFile(cfg.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open log file: %w", err)
}
writer = zapcore.AddSync(file)
}
// 解析日志级别
var level zapcore.Level
switch cfg.Level {
case "debug":
level = zapcore.DebugLevel
case "info":
level = zapcore.InfoLevel
case "warn":
level = zapcore.WarnLevel
case "error":
level = zapcore.ErrorLevel
default:
level = zapcore.InfoLevel
}
core := zapcore.NewCore(
zapcore.NewJSONEncoder(encoderConfig),
writer,
level,
)
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return logger, nil
}

11
internal/di/app.go Normal file
View File

@@ -0,0 +1,11 @@
package di
import "go.uber.org/fx"
// NewApp 创建应用实例
func NewApp() *fx.App {
return fx.New(
Providers,
Lifecycle,
)
}

77
internal/di/lifecycle.go Normal file
View File

@@ -0,0 +1,77 @@
package di
import (
"context"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"cellbot/pkg/net"
"go.uber.org/fx"
"go.uber.org/zap"
)
// RegisterLifecycleHooks 注册应用生命周期钩子
func RegisterLifecycleHooks(
logger *zap.Logger,
eventBus *engine.EventBus,
dispatcher *engine.Dispatcher,
botManager *protocol.BotManager,
server *net.Server,
) fx.Option {
return fx.Invoke(
func(lc fx.Lifecycle) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
logger.Info("Starting CellBot application...")
// 启动事件总线
eventBus.Start()
// 启动分发器
dispatcher.Start(ctx)
// 启动所有机器人
if err := botManager.StartAll(ctx); err != nil {
logger.Error("Failed to start bots", zap.Error(err))
}
// 启动HTTP服务器
if err := server.Start(); err != nil {
logger.Error("Failed to start server", zap.Error(err))
return err
}
logger.Info("CellBot application started successfully")
return nil
},
OnStop: func(ctx context.Context) error {
logger.Info("Stopping CellBot application...")
// 停止HTTP服务器
if err := server.Stop(); err != nil {
logger.Error("Failed to stop server", zap.Error(err))
}
// 停止所有机器人
if err := botManager.StopAll(ctx); err != nil {
logger.Error("Failed to stop bots", zap.Error(err))
}
// 停止分发器
dispatcher.Stop()
// 停止事件总线
eventBus.Stop()
logger.Info("CellBot application stopped successfully")
return nil
},
})
},
)
}
// Lifecycle 生命周期管理选项
var Lifecycle = fx.Options(
fx.Invoke(RegisterLifecycleHooks),
)

70
internal/di/providers.go Normal file
View File

@@ -0,0 +1,70 @@
package di
import (
"cellbot/internal/config"
"cellbot/internal/engine"
"cellbot/internal/protocol"
"cellbot/pkg/net"
"go.uber.org/fx"
"go.uber.org/zap"
)
// ProvideLogger 提供日志实例
func ProvideLogger(cfg *config.Config) (*zap.Logger, error) {
return config.InitLogger(&cfg.Log)
}
// ProvideConfig 提供配置实例
func ProvideConfig() (*config.Config, error) {
configManager := config.NewConfigManager("configs/config.toml", zap.NewNop())
if err := configManager.Load(); err != nil {
return nil, err
}
return configManager.Get(), nil
}
// ProvideConfigManager 提供配置管理器
func ProvideConfigManager(logger *zap.Logger) (*config.ConfigManager, error) {
configManager := config.NewConfigManager("configs/config.toml", logger)
if err := configManager.Load(); err != nil {
return nil, err
}
// 启动配置文件监听
if err := configManager.Watch(); err != nil {
logger.Warn("Failed to watch config file", zap.Error(err))
}
return configManager, nil
}
// ProvideEventBus 提供事件总线
func ProvideEventBus(logger *zap.Logger) *engine.EventBus {
return engine.NewEventBus(logger, 10000)
}
// ProvideDispatcher 提供事件分发器
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger) *engine.Dispatcher {
return engine.NewDispatcher(eventBus, logger)
}
// ProvideBotManager 提供机器人管理器
func ProvideBotManager(logger *zap.Logger) *protocol.BotManager {
return protocol.NewBotManager(logger)
}
// ProvideServer 提供HTTP服务器
func ProvideServer(cfg *config.Config, logger *zap.Logger, botManager *protocol.BotManager, eventBus *engine.EventBus) *net.Server {
return net.NewServer(cfg.Server.Host, cfg.Server.Port, logger, botManager, eventBus)
}
// Providers 依赖注入提供者列表
var Providers = fx.Options(
fx.Provide(
ProvideConfig,
ProvideConfigManager,
ProvideLogger,
ProvideEventBus,
ProvideDispatcher,
ProvideBotManager,
ProvideServer,
),
)

View File

@@ -0,0 +1,162 @@
package engine
import (
"context"
"sort"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
// Dispatcher 事件分发器
// 管理事件处理器并按照优先级分发事件
type Dispatcher struct {
handlers []protocol.EventHandler
middlewares []protocol.Middleware
logger *zap.Logger
eventBus *EventBus
}
// NewDispatcher 创建事件分发器
func NewDispatcher(eventBus *EventBus, logger *zap.Logger) *Dispatcher {
return &Dispatcher{
handlers: make([]protocol.EventHandler, 0),
middlewares: make([]protocol.Middleware, 0),
logger: logger.Named("dispatcher"),
eventBus: eventBus,
}
}
// RegisterHandler 注册事件处理器
func (d *Dispatcher) RegisterHandler(handler protocol.EventHandler) {
d.handlers = append(d.handlers, handler)
// 按优先级排序(数值越小优先级越高)
sort.Slice(d.handlers, func(i, j int) bool {
return d.handlers[i].Priority() < d.handlers[j].Priority()
})
d.logger.Debug("Handler registered",
zap.Int("priority", handler.Priority()),
zap.Int("total_handlers", len(d.handlers)))
}
// UnregisterHandler 取消注册事件处理器
func (d *Dispatcher) UnregisterHandler(handler protocol.EventHandler) {
for i, h := range d.handlers {
if h == handler {
d.handlers = append(d.handlers[:i], d.handlers[i+1:]...)
break
}
}
d.logger.Debug("Handler unregistered",
zap.Int("total_handlers", len(d.handlers)))
}
// RegisterMiddleware 注册中间件
func (d *Dispatcher) RegisterMiddleware(middleware protocol.Middleware) {
d.middlewares = append(d.middlewares, middleware)
d.logger.Debug("Middleware registered",
zap.Int("total_middlewares", len(d.middlewares)))
}
// Start 启动分发器
func (d *Dispatcher) Start(ctx context.Context) {
// 订阅所有类型的事件
for _, eventType := range []protocol.EventType{
protocol.EventTypeMessage,
protocol.EventTypeNotice,
protocol.EventTypeRequest,
protocol.EventTypeMeta,
} {
eventChan := d.eventBus.Subscribe(eventType, nil)
go d.eventLoop(ctx, eventChan)
}
d.logger.Info("Dispatcher started")
}
// Stop 停止分发器
func (d *Dispatcher) Stop() {
d.logger.Info("Dispatcher stopped")
}
// eventLoop 事件循环
func (d *Dispatcher) eventLoop(ctx context.Context, eventChan chan protocol.Event) {
for {
select {
case event, ok := <-eventChan:
if !ok {
return
}
d.handleEvent(ctx, event)
case <-ctx.Done():
return
}
}
}
// handleEvent 处理单个事件
func (d *Dispatcher) handleEvent(ctx context.Context, event protocol.Event) {
d.logger.Debug("Processing event",
zap.String("type", string(event.GetType())),
zap.String("detail_type", event.GetDetailType()))
// 通过中间件链处理事件
next := d.createHandlerChain(ctx, event)
// 执行中间件链
if len(d.middlewares) > 0 {
d.executeMiddlewares(ctx, event, func(ctx context.Context, e protocol.Event) error {
next(ctx, e)
return nil
})
} else {
next(ctx, event)
}
}
// createHandlerChain 创建处理器链
func (d *Dispatcher) createHandlerChain(ctx context.Context, event protocol.Event) func(context.Context, protocol.Event) {
return func(ctx context.Context, e protocol.Event) {
for _, handler := range d.handlers {
if handler.Match(event) {
if err := handler.Handle(ctx, e); err != nil {
d.logger.Error("Handler execution failed",
zap.Error(err),
zap.String("event_type", string(e.GetType())))
}
}
}
}
}
// executeMiddlewares 执行中间件链
func (d *Dispatcher) executeMiddlewares(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) {
// 从后向前构建中间件链
handler := next
for i := len(d.middlewares) - 1; i >= 0; i-- {
middleware := d.middlewares[i]
currentHandler := handler
handler = func(ctx context.Context, e protocol.Event) error {
if err := middleware.Process(ctx, e, currentHandler); err != nil {
d.logger.Error("Middleware execution failed",
zap.Error(err),
zap.String("event_type", string(e.GetType())))
}
return nil
}
}
// 执行中间件链
handler(ctx, event)
}
// GetHandlerCount 获取处理器数量
func (d *Dispatcher) GetHandlerCount() int {
return len(d.handlers)
}
// GetMiddlewareCount 获取中间件数量
func (d *Dispatcher) GetMiddlewareCount() int {
return len(d.middlewares)
}

182
internal/engine/eventbus.go Normal file
View File

@@ -0,0 +1,182 @@
package engine
import (
"context"
"sync"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
// Subscription 订阅信息
type Subscription struct {
ID string
Chan chan protocol.Event
Filter func(protocol.Event) bool
}
// EventBus 事件总线
// 基于channel的高性能发布订阅实现
type EventBus struct {
subscriptions map[string][]*Subscription
mu sync.RWMutex
logger *zap.Logger
eventChan chan protocol.Event
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
// NewEventBus 创建事件总线
func NewEventBus(logger *zap.Logger, bufferSize int) *EventBus {
ctx, cancel := context.WithCancel(context.Background())
return &EventBus{
subscriptions: make(map[string][]*Subscription),
logger: logger.Named("eventbus"),
eventChan: make(chan protocol.Event, bufferSize),
ctx: ctx,
cancel: cancel,
}
}
// Start 启动事件总线
func (eb *EventBus) Start() {
eb.wg.Add(1)
go eb.dispatch()
eb.logger.Info("Event bus started")
}
// Stop 停止事件总线
func (eb *EventBus) Stop() {
eb.cancel()
eb.wg.Wait()
close(eb.eventChan)
eb.logger.Info("Event bus stopped")
}
// Publish 发布事件
func (eb *EventBus) Publish(event protocol.Event) {
select {
case eb.eventChan <- event:
case <-eb.ctx.Done():
eb.logger.Warn("Event bus is shutting down, event dropped",
zap.String("type", string(event.GetType())))
}
}
// Subscribe 订阅事件
func (eb *EventBus) Subscribe(eventType protocol.EventType, filter func(protocol.Event) bool) chan protocol.Event {
eb.mu.Lock()
defer eb.mu.Unlock()
sub := &Subscription{
ID: generateSubscriptionID(),
Chan: make(chan protocol.Event, 100),
Filter: filter,
}
key := string(eventType)
eb.subscriptions[key] = append(eb.subscriptions[key], sub)
eb.logger.Debug("New subscription added",
zap.String("event_type", key),
zap.String("sub_id", sub.ID))
return sub.Chan
}
// Unsubscribe 取消订阅
func (eb *EventBus) Unsubscribe(eventType protocol.EventType, ch chan protocol.Event) {
eb.mu.Lock()
defer eb.mu.Unlock()
key := string(eventType)
subs := eb.subscriptions[key]
for i, sub := range subs {
if sub.Chan == ch {
close(sub.Chan)
eb.subscriptions[key] = append(subs[:i], subs[i+1:]...)
eb.logger.Debug("Subscription removed",
zap.String("event_type", key),
zap.String("sub_id", sub.ID))
return
}
}
}
// dispatch 分发事件到订阅者
func (eb *EventBus) dispatch() {
defer eb.wg.Done()
for {
select {
case event, ok := <-eb.eventChan:
if !ok {
return
}
eb.dispatchEvent(event)
case <-eb.ctx.Done():
return
}
}
}
// dispatchEvent 分发单个事件
func (eb *EventBus) dispatchEvent(event protocol.Event) {
eb.mu.RLock()
key := string(event.GetType())
subs := eb.subscriptions[key]
// 复制订阅者列表避免锁竞争
subsCopy := make([]*Subscription, len(subs))
copy(subsCopy, subs)
eb.mu.RUnlock()
for _, sub := range subsCopy {
if sub.Filter == nil || sub.Filter(event) {
select {
case sub.Chan <- event:
default:
// 订阅者channel已满,丢弃事件
eb.logger.Warn("Subscription channel full, event dropped",
zap.String("sub_id", sub.ID),
zap.String("event_type", key))
}
}
}
}
// GetSubscriptionCount 获取订阅者数量
func (eb *EventBus) GetSubscriptionCount(eventType protocol.EventType) int {
eb.mu.RLock()
defer eb.mu.RUnlock()
return len(eb.subscriptions[string(eventType)])
}
// Clear 清空所有订阅
func (eb *EventBus) Clear() {
eb.mu.Lock()
defer eb.mu.Unlock()
for eventType, subs := range eb.subscriptions {
for _, sub := range subs {
close(sub.Chan)
}
delete(eb.subscriptions, eventType)
}
eb.logger.Info("All subscriptions cleared")
}
// generateSubscriptionID 生成订阅ID
func generateSubscriptionID() string {
return "sub-" + randomString(8)
}
// randomString 生成随机字符串
func randomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[i%len(charset)]
}
return string(b)
}

View File

@@ -0,0 +1,250 @@
package engine
import (
"context"
"sync"
"testing"
"time"
"cellbot/internal/protocol"
"go.uber.org/zap"
)
func TestEventBus_PublishSubscribe(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
eventBus.Start()
defer eventBus.Stop()
// 创建测试事件
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
// 订阅事件
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
// 发布事件
eventBus.Publish(event)
// 接收事件
select {
case receivedEvent := <-eventChan:
if receivedEvent.GetType() != protocol.EventTypeMessage {
t.Errorf("Expected event type '%s', got '%s'", protocol.EventTypeMessage, receivedEvent.GetType())
}
case <-time.After(100 * time.Millisecond):
t.Error("Timeout waiting for event")
}
}
func TestEventBus_Filter(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
eventBus.Start()
defer eventBus.Stop()
// 创建测试事件
event1 := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
event2 := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "group",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
// 订阅并过滤:只接收private消息
filter := func(e protocol.Event) bool {
return e.GetDetailType() == "private"
}
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, filter)
// 发布两个事件
eventBus.Publish(event1)
eventBus.Publish(event2)
// 应该只收到private消息
select {
case receivedEvent := <-eventChan:
if receivedEvent.GetDetailType() != "private" {
t.Errorf("Expected detail type 'private', got '%s'", receivedEvent.GetDetailType())
}
case <-time.After(100 * time.Millisecond):
t.Error("Timeout waiting for event")
}
// 不应该再收到第二个事件
select {
case <-eventChan:
t.Error("Should not receive group message")
case <-time.After(50 * time.Millisecond):
// 正确,不应该收到
}
}
func TestEventBus_Concurrent(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 10000)
eventBus.Start()
defer eventBus.Stop()
numSubscribers := 10
numPublishers := 10
numEvents := 100
// 创建多个订阅者
subscribers := make([]chan protocol.Event, numSubscribers)
for i := 0; i < numSubscribers; i++ {
subscribers[i] = eventBus.Subscribe(protocol.EventTypeMessage, nil)
}
var wg sync.WaitGroup
wg.Add(numPublishers)
// 多个发布者并发发布事件
for i := 0; i < numPublishers; i++ {
go func() {
defer wg.Done()
for j := 0; j < numEvents; j++ {
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
eventBus.Publish(event)
}
}()
}
wg.Wait()
// 验证每个订阅者都收到了事件
for _, ch := range subscribers {
count := 0
for count < numPublishers*numEvents {
select {
case <-ch:
count++
case <-time.After(500 * time.Millisecond):
t.Errorf("Timeout waiting for events, received %d, expected %d", count, numPublishers*numEvents)
break
}
}
}
}
func TestEventBus_Benchmark(b *testing.B) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100000)
eventBus.Start()
defer eventBus.Stop()
eventChan := eventBus.Subscribe(protocol.EventTypeMessage, nil)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
eventBus.Publish(event)
}
})
// 消耗channel避免阻塞
go func() {
for range eventChan {
}
}()
}
func TestDispatcher_HandlerPriority(t *testing.T) {
logger := zap.NewNop()
eventBus := NewEventBus(logger, 100)
dispatcher := NewDispatcher(eventBus, logger)
// 创建测试处理器
handlers := make([]*TestHandler, 3)
priorities := []int{3, 1, 2}
for i, priority := range priorities {
handlers[i] = &TestHandler{
priority: priority,
matched: false,
executed: false,
}
dispatcher.RegisterHandler(handlers[i])
}
// 验证处理器按优先级排序
if dispatcher.GetHandlerCount() != 3 {
t.Errorf("Expected 3 handlers, got %d", dispatcher.GetHandlerCount())
}
event := &protocol.BaseEvent{
Type: protocol.EventTypeMessage,
DetailType: "private",
Timestamp: time.Now().Unix(),
SelfID: "test_bot",
Data: make(map[string]interface{}),
}
ctx := context.Background()
// 分发事件
dispatcher.handleEvent(ctx, event)
// 验证处理器是否按优先级执行
executedOrder := make([]int, 0)
for _, handler := range handlers {
if handler.executed {
executedOrder = append(executedOrder, handler.priority)
}
}
// 检查是否按升序执行
for i := 1; i < len(executedOrder); i++ {
if executedOrder[i] < executedOrder[i-1] {
t.Errorf("Handlers not executed in priority order: %v", executedOrder)
}
}
}
// TestHandler 测试处理器
type TestHandler struct {
priority int
matched bool
executed bool
}
func (h *TestHandler) Handle(ctx context.Context, event protocol.Event) error {
h.executed = true
return nil
}
func (h *TestHandler) Priority() int {
return h.priority
}
func (h *TestHandler) Match(event protocol.Event) bool {
h.matched = true
return true
}

107
internal/protocol/action.go Normal file
View File

@@ -0,0 +1,107 @@
package protocol
// Action 动作接口
// 参考OneBot12协议定义统一的动作操作接口
type Action interface {
// GetType 获取动作类型
GetType() ActionType
// GetParams 获取动作参数
GetParams() map[string]interface{}
// Execute 执行动作
Execute(ctx interface{}) (map[string]interface{}, error)
}
// ActionType 动作类型
type ActionType string
const (
// 消息相关动作
ActionTypeSendPrivateMessage ActionType = "send_private_message"
ActionTypeSendGroupMessage ActionType = "send_group_message"
ActionTypeDeleteMessage ActionType = "delete_message"
// 用户相关动作
ActionTypeGetUserInfo ActionType = "get_user_info"
ActionTypeGetFriendList ActionType = "get_friend_list"
ActionTypeGetGroupInfo ActionType = "get_group_info"
ActionTypeGetGroupMemberList ActionType = "get_group_member_list"
// 群组相关动作
ActionTypeSetGroupKick ActionType = "set_group_kick"
ActionTypeSetGroupBan ActionType = "set_group_ban"
ActionTypeSetGroupAdmin ActionType = "set_group_admin"
ActionTypeSetGroupWholeBan ActionType = "set_group_whole_ban"
// 其他动作
ActionTypeGetStatus ActionType = "get_status"
ActionTypeGetVersion ActionType = "get_version"
)
// BaseAction 基础动作结构
type BaseAction struct {
Type ActionType `json:"type"`
Params map[string]interface{} `json:"params"`
}
// GetType 获取动作类型
func (a *BaseAction) GetType() ActionType {
return a.Type
}
// GetParams 获取动作参数
func (a *BaseAction) GetParams() map[string]interface{} {
return a.Params
}
// Execute 执行动作(需子类实现)
func (a *BaseAction) Execute(ctx interface{}) (map[string]interface{}, error) {
return nil, ErrNotImplemented
}
// SendPrivateMessageAction 发送私聊消息动作
type SendPrivateMessageAction struct {
BaseAction
UserID string `json:"user_id"`
Message string `json:"message"`
}
// SendGroupMessageAction 发送群聊消息动作
type SendGroupMessageAction struct {
BaseAction
GroupID string `json:"group_id"`
Message string `json:"message"`
}
// DeleteMessageAction 删除消息动作
type DeleteMessageAction struct {
BaseAction
MessageID string `json:"message_id"`
}
// GetUserInfoAction 获取用户信息动作
type GetUserInfoAction struct {
BaseAction
UserID string `json:"user_id"`
}
// GetGroupInfoAction 获取群信息动作
type GetGroupInfoAction struct {
BaseAction
GroupID string `json:"group_id"`
}
// 错误定义
var (
ErrNotImplemented = &ProtocolError{Code: 10001, Message: "action not implemented"}
ErrInvalidParams = &ProtocolError{Code: 10002, Message: "invalid parameters"}
)
// ProtocolError 协议错误
type ProtocolError struct {
Code int `json:"code"`
Message string `json:"message"`
}
func (e *ProtocolError) Error() string {
return e.Message
}

206
internal/protocol/bot.go Normal file
View File

@@ -0,0 +1,206 @@
package protocol
import (
"context"
"sync"
"go.uber.org/zap"
)
// BaseBotInstance 机器人实例基类
type BaseBotInstance struct {
id string
protocol Protocol
status BotStatus
logger *zap.Logger
mu sync.RWMutex
}
// NewBaseBotInstance 创建机器人实例基类
func NewBaseBotInstance(id string, protocol Protocol, logger *zap.Logger) *BaseBotInstance {
return &BaseBotInstance{
id: id,
protocol: protocol,
status: BotStatusStopped,
logger: logger.With(zap.String("bot_id", id)),
}
}
// GetID 获取实例ID
func (b *BaseBotInstance) GetID() string {
return b.id
}
// Name 获取协议名称
func (b *BaseBotInstance) Name() string {
return b.protocol.Name()
}
// Version 获取协议版本
func (b *BaseBotInstance) Version() string {
return b.protocol.Version()
}
// Connect 建立连接
func (b *BaseBotInstance) Connect(ctx context.Context) error {
b.mu.Lock()
b.status = BotStatusStarting
b.mu.Unlock()
if err := b.protocol.Connect(ctx); err != nil {
b.mu.Lock()
b.status = BotStatusError
b.mu.Unlock()
return err
}
b.mu.Lock()
b.status = BotStatusRunning
b.mu.Unlock()
b.logger.Info("Bot instance connected")
return nil
}
// Disconnect 断开连接
func (b *BaseBotInstance) Disconnect(ctx context.Context) error {
b.mu.Lock()
b.status = BotStatusStopping
b.mu.Unlock()
if err := b.protocol.Disconnect(ctx); err != nil {
b.mu.Lock()
b.status = BotStatusError
b.mu.Unlock()
return err
}
b.mu.Lock()
b.status = BotStatusStopped
b.mu.Unlock()
b.logger.Info("Bot instance disconnected")
return nil
}
// IsConnected 检查连接状态
func (b *BaseBotInstance) IsConnected() bool {
b.mu.RLock()
defer b.mu.RUnlock()
return b.status == BotStatusRunning
}
// SendAction 发送动作
func (b *BaseBotInstance) SendAction(ctx context.Context, action Action) (map[string]interface{}, error) {
return b.protocol.SendAction(ctx, action)
}
// HandleEvent 处理事件
func (b *BaseBotInstance) HandleEvent(ctx context.Context, event Event) error {
return b.protocol.HandleEvent(ctx, event)
}
// GetSelfID 获取机器人自身ID
func (b *BaseBotInstance) GetSelfID() string {
return b.protocol.GetSelfID()
}
// Start 启动实例
func (b *BaseBotInstance) Start(ctx context.Context) error {
return b.Connect(ctx)
}
// Stop 停止实例
func (b *BaseBotInstance) Stop(ctx context.Context) error {
return b.Disconnect(ctx)
}
// GetStatus 获取实例状态
func (b *BaseBotInstance) GetStatus() BotStatus {
b.mu.RLock()
defer b.mu.RUnlock()
return b.status
}
// BotManager 机器人管理器
type BotManager struct {
bots map[string]BotInstance
logger *zap.Logger
mu sync.RWMutex
}
// NewBotManager 创建机器人管理器
func NewBotManager(logger *zap.Logger) *BotManager {
return &BotManager{
bots: make(map[string]BotInstance),
logger: logger,
}
}
// Add 添加机器人实例
func (bm *BotManager) Add(bot BotInstance) {
bm.mu.Lock()
defer bm.mu.Unlock()
bm.bots[bot.GetID()] = bot
bm.logger.Info("Bot added", zap.String("bot_id", bot.GetID()))
}
// Remove 移除机器人实例
func (bm *BotManager) Remove(id string) {
bm.mu.Lock()
defer bm.mu.Unlock()
if bot, ok := bm.bots[id]; ok {
bot.Stop(context.Background())
delete(bm.bots, id)
bm.logger.Info("Bot removed", zap.String("bot_id", id))
}
}
// Get 获取机器人实例
func (bm *BotManager) Get(id string) (BotInstance, bool) {
bm.mu.RLock()
defer bm.mu.RUnlock()
bot, ok := bm.bots[id]
return bot, ok
}
// GetAll 获取所有机器人实例
func (bm *BotManager) GetAll() []BotInstance {
bm.mu.RLock()
defer bm.mu.RUnlock()
bots := make([]BotInstance, 0, len(bm.bots))
for _, bot := range bm.bots {
bots = append(bots, bot)
}
return bots
}
// StartAll 启动所有机器人
func (bm *BotManager) StartAll(ctx context.Context) error {
bm.mu.RLock()
defer bm.mu.RUnlock()
for _, bot := range bm.bots {
if err := bot.Start(ctx); err != nil {
bm.logger.Error("Failed to start bot",
zap.String("bot_id", bot.GetID()),
zap.Error(err))
}
}
return nil
}
// StopAll 停止所有机器人
func (bm *BotManager) StopAll(ctx context.Context) error {
bm.mu.RLock()
defer bm.mu.RUnlock()
for _, bot := range bm.bots {
if err := bot.Stop(ctx); err != nil {
bm.logger.Error("Failed to stop bot",
zap.String("bot_id", bot.GetID()),
zap.Error(err))
}
}
return nil
}

105
internal/protocol/event.go Normal file
View File

@@ -0,0 +1,105 @@
package protocol
import "time"
// EventType 事件类型
type EventType string
const (
// 事件类型常量
EventTypeMessage EventType = "message"
EventTypeNotice EventType = "notice"
EventTypeRequest EventType = "request"
EventTypeMeta EventType = "meta"
EventTypeMessageSent EventType = "message_sent"
EventTypeNoticeSent EventType = "notice_sent"
EventTypeRequestSent EventType = "request_sent"
)
// Event 通用事件接口
// 参考OneBot12协议设计提供统一的事件抽象
type Event interface {
// GetType 获取事件类型
GetType() EventType
// GetDetailType 获取详细类型
GetDetailType() string
// GetSubType 获取子类型
GetSubType() string
// GetTimestamp 获取时间戳
GetTimestamp() time.Time
// GetSelfID 获取机器人自身ID
GetSelfID() string
// GetData 获取事件数据
GetData() map[string]interface{}
}
// BaseEvent 基础事件结构
type BaseEvent struct {
Type EventType `json:"type"`
DetailType string `json:"detail_type"`
SubType string `json:"sub_type,omitempty"`
Timestamp int64 `json:"timestamp"`
SelfID string `json:"self_id"`
Data map[string]interface{} `json:"data"`
}
// GetType 获取事件类型
func (e *BaseEvent) GetType() EventType {
return e.Type
}
// GetDetailType 获取详细类型
func (e *BaseEvent) GetDetailType() string {
return e.DetailType
}
// GetSubType 获取子类型
func (e *BaseEvent) GetSubType() string {
return e.SubType
}
// GetTimestamp 获取时间戳
func (e *BaseEvent) GetTimestamp() time.Time {
return time.Unix(e.Timestamp, 0)
}
// GetSelfID 获取机器人自身ID
func (e *BaseEvent) GetSelfID() string {
return e.SelfID
}
// GetData 获取事件数据
func (e *BaseEvent) GetData() map[string]interface{} {
return e.Data
}
// MessageEvent 消息事件
type MessageEvent struct {
BaseEvent
MessageID string `json:"message_id"`
Message string `json:"message"`
AltText string `json:"alt_text,omitempty"`
}
// NoticeEvent 通知事件
type NoticeEvent struct {
BaseEvent
GroupID string `json:"group_id,omitempty"`
UserID string `json:"user_id,omitempty"`
Operator string `json:"operator,omitempty"`
}
// RequestEvent 请求事件
type RequestEvent struct {
BaseEvent
RequestID string `json:"request_id"`
UserID string `json:"user_id"`
Comment string `json:"comment"`
Flag string `json:"flag"`
}
// MetaEvent 元事件
type MetaEvent struct {
BaseEvent
Status string `json:"status"`
}

View File

@@ -0,0 +1,77 @@
package protocol
import (
"context"
)
// Protocol 通用协议接口
// 参考OneBot12协议核心设计理念定义统一的机器人协议接口
type Protocol interface {
// Name 获取协议名称
Name() string
// Version 获取协议版本
Version() string
// Connect 建立连接
Connect(ctx context.Context) error
// Disconnect 断开连接
Disconnect(ctx context.Context) error
// IsConnected 检查连接状态
IsConnected() bool
// SendAction 发送动作
SendAction(ctx context.Context, action Action) (map[string]interface{}, error)
// HandleEvent 处理事件
HandleEvent(ctx context.Context, event Event) error
// GetSelfID 获取机器人自身ID
GetSelfID() string
}
// Adapter 协议适配器接口
// 实现具体协议的接入逻辑
type Adapter interface {
Protocol
// ParseMessage 解析原始消息为Event
ParseMessage(raw []byte) (Event, error)
// SerializeAction 序列化Action为协议格式
SerializeAction(action Action) ([]byte, error)
}
// BotInstance 机器人实例接口
// 管理单个机器人实例的生命周期
type BotInstance interface {
Protocol
// GetID 获取实例ID
GetID() string
// Start 启动实例
Start(ctx context.Context) error
// Stop 停止实例
Stop(ctx context.Context) error
// GetStatus 获取实例状态
GetStatus() BotStatus
}
// BotStatus 机器人状态
type BotStatus string
const (
BotStatusStarting BotStatus = "starting"
BotStatusRunning BotStatus = "running"
BotStatusStopping BotStatus = "stopping"
BotStatusStopped BotStatus = "stopped"
BotStatusError BotStatus = "error"
)
// EventHandler 事件处理器接口
type EventHandler interface {
// Handle 处理事件
Handle(ctx context.Context, event Event) error
// Priority 获取处理器优先级(数值越小优先级越高)
Priority() int
// Match 判断是否匹配事件
Match(event Event) bool
}
// Middleware 中间件接口
type Middleware interface {
// Process 处理事件
Process(ctx context.Context, event Event, next func(ctx context.Context, event Event) error) error
}