Files
cellbot/internal/di/providers.go
lafay fb5fae1524 chore: update project structure and enhance plugin functionality
- Added new entries to .gitignore for database files.
- Updated go.mod and go.sum to include new indirect dependencies for database and ORM support.
- Refactored event handling to improve message reply functionality in the protocol.
- Enhanced the dispatcher to allow for better event processing and logging.
- Removed outdated plugin documentation and unnecessary files to streamline the codebase.
- Improved welcome message formatting and screenshot options for better user experience.
2026-01-05 05:14:31 +08:00

276 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package di
import (
"context"
"fmt"
"cellbot/internal/adapter/milky"
"cellbot/internal/adapter/onebot11"
"cellbot/internal/config"
"cellbot/internal/database"
"cellbot/internal/engine"
_ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数
"cellbot/internal/plugins/mcstatus"
_ "cellbot/internal/plugins/mcstatus" // 导入插件以触发 init 函数
_ "cellbot/internal/plugins/welcome" // 导入插件以触发 init 函数
"cellbot/internal/protocol"
"cellbot/pkg/net"
"go.uber.org/fx"
"go.uber.org/zap"
)
// ProvideLogger builds the application logger from the Log section of cfg.
//
// It delegates to config.InitLogger and passes that function's logger and
// error straight through to the caller.
func ProvideLogger(cfg *config.Config) (*zap.Logger, error) {
	logCfg := &cfg.Log
	return config.InitLogger(logCfg)
}
// ProvideConfig loads "configs/config.toml" and returns the parsed
// configuration.
//
// A no-op logger is handed to the config manager here; the real logger is
// built from this configuration (see ProvideLogger) and therefore cannot be
// a dependency of this constructor. Any load error is returned unchanged
// together with a nil config.
func ProvideConfig() (*config.Config, error) {
	mgr := config.NewConfigManager("configs/config.toml", zap.NewNop())
	err := mgr.Load()
	if err != nil {
		return nil, err
	}
	return mgr.Get(), nil
}
// ProvideConfigManager creates a config manager for "configs/config.toml",
// loads the file and then tries to start watching it.
//
// A load failure is returned to the caller. A watch failure is only logged
// as a warning and the manager is still returned.
func ProvideConfigManager(logger *zap.Logger) (*config.ConfigManager, error) {
	mgr := config.NewConfigManager("configs/config.toml", logger)
	if loadErr := mgr.Load(); loadErr != nil {
		return nil, loadErr
	}

	// Watching is best-effort: the application keeps running with the
	// configuration that was just loaded.
	watchErr := mgr.Watch()
	if watchErr != nil {
		logger.Warn("Failed to watch config file", zap.Error(watchErr))
	}
	return mgr, nil
}
// ProvideEventBus constructs the shared event bus.
//
// NOTE(review): the second argument (10000) is presumably the bus buffer
// capacity; confirm against engine.NewEventBus.
func ProvideEventBus(logger *zap.Logger) *engine.EventBus {
	bus := engine.NewEventBus(logger, 10000)
	return bus
}
// ProvideScheduler constructs the job scheduler with the given logger.
func ProvideScheduler(logger *zap.Logger) *engine.Scheduler {
	sched := engine.NewScheduler(logger)
	return sched
}
// ProvideDispatcher creates the event dispatcher wired to eventBus and
// scheduler.
//
// When cfg.Engine.RateLimit.Enabled is set, a rate-limit middleware built
// from the configured RPS and Burst values is registered on the dispatcher
// and the registration is logged.
func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config, scheduler *engine.Scheduler) *engine.Dispatcher {
	d := engine.NewDispatcherWithScheduler(eventBus, logger, scheduler)
	if !cfg.Engine.RateLimit.Enabled {
		return d
	}

	// Register the rate-limit middleware.
	rps, burst := cfg.Engine.RateLimit.RPS, cfg.Engine.RateLimit.Burst
	d.RegisterMiddleware(engine.NewRateLimitMiddleware(logger, rps, burst))
	logger.Info("Rate limit middleware registered",
		zap.Int("rps", rps),
		zap.Int("burst", burst))
	return d
}
// ProvidePluginRegistry constructs the plugin registry on top of dispatcher.
func ProvidePluginRegistry(dispatcher *engine.Dispatcher, logger *zap.Logger) *engine.PluginRegistry {
	reg := engine.NewPluginRegistry(dispatcher, logger)
	return reg
}
// ProvideBotManager constructs the bot manager that the bot constructors in
// this file add their bots to.
func ProvideBotManager(logger *zap.Logger) *protocol.BotManager {
	mgr := protocol.NewBotManager(logger)
	return mgr
}
// ProvideWebSocketManager constructs the WebSocket manager, giving it the
// logger and the shared event bus.
func ProvideWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *net.WebSocketManager {
	ws := net.NewWebSocketManager(logger, eventBus)
	return ws
}
// ProvideServer constructs the server using the host and port from
// cfg.Server, together with the logger, bot manager and event bus.
func ProvideServer(cfg *config.Config, logger *zap.Logger, botManager *protocol.BotManager, eventBus *engine.EventBus) *net.Server {
	host, port := cfg.Server.Host, cfg.Server.Port
	return net.NewServer(host, port, logger, botManager, eventBus)
}
// ProvideDatabase provides the database service, a SQLite-backed
// implementation created for the path "data/cellbot.db".
func ProvideDatabase(logger *zap.Logger) database.Database {
	db := database.NewSQLiteDatabase(logger, "data/cellbot.db")
	return db
}
// InitMCStatusDatabase initializes the database used by the MC status plugin.
//
// For every bot currently registered in botManager it fetches that bot's
// database handle from dbService and issues CREATE TABLE IF NOT EXISTS for
// the bot's "mc_server_binds" table. A failure for one bot is logged and the
// loop moves on to the next bot. Afterwards dbService is handed to the
// mcstatus plugin. The function always returns nil.
func InitMCStatusDatabase(dbService database.Database, logger *zap.Logger, botManager *protocol.BotManager) error {
	// Initialize the table for each bot.
	for _, b := range botManager.GetAll() {
		id := b.GetID()
		conn, err := dbService.GetDB(id)
		if err != nil {
			logger.Error("Failed to get database for bot",
				zap.String("bot_id", id),
				zap.Error(err))
			continue
		}

		// The table is created with raw SQL to avoid a circular dependency.
		// SQLite's CREATE TABLE cannot take the table name as a bound
		// parameter, so the name is formatted into the statement.
		// NOTE(review): the original author states the name returned by
		// database.GetTableName has been sanitized; confirm in that package.
		table := database.GetTableName(id, "mc_server_binds")
		ddl := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id TEXT PRIMARY KEY,
server_ip TEXT NOT NULL
)
`, table)
		if execErr := conn.Exec(ddl).Error; execErr != nil {
			logger.Error("Failed to create table",
				zap.String("bot_id", id),
				zap.String("table", table),
				zap.Error(execErr))
			continue
		}
		logger.Info("Database table initialized",
			zap.String("bot_id", id),
			zap.String("table", table))
	}

	// Initialize the plugin's database.
	mcstatus.InitDatabase(dbService)
	return nil
}
// ProvideMilkyBots creates a Milky bot for every entry in cfg.Bots that is
// enabled and whose protocol is "milky", adds it to botManager, and appends
// an fx lifecycle hook that connects the bot on start and disconnects it on
// stop.
//
// The connection attempt runs in a background goroutine, so a bot that fails
// to connect is only logged and never aborts application startup. The
// function always returns nil.
func ProvideMilkyBots(cfg *config.Config, logger *zap.Logger, eventBus *engine.EventBus, wsManager *net.WebSocketManager, botManager *protocol.BotManager, lc fx.Lifecycle) error {
	for _, botCfg := range cfg.Bots {
		if botCfg.Protocol != "milky" || !botCfg.Enabled {
			continue
		}

		// Copy the ID into a per-iteration variable: the hooks below run long
		// after this loop has finished, and on toolchains older than Go 1.22
		// the range variable is shared by all iterations, which would make
		// every hook log the ID of the last bot.
		botID := botCfg.ID
		logger.Info("Creating Milky bot", zap.String("bot_id", botID))

		milkyCfg := &milky.Config{
			ProtocolURL:       botCfg.Milky.ProtocolURL,
			AccessToken:       botCfg.Milky.AccessToken,
			EventMode:         botCfg.Milky.EventMode,
			WebhookListenAddr: botCfg.Milky.WebhookListenAddr,
			Timeout:           botCfg.Milky.Timeout,
			RetryCount:        botCfg.Milky.RetryCount,
		}
		bot := milky.NewBot(botID, milkyCfg, eventBus, wsManager, logger)
		botManager.Add(bot)

		// The connect context is independent of the short-lived OnStart
		// context, but it is cancellable so that a connection attempt still
		// pending at shutdown does not outlive the application.
		connCtx, cancel := context.WithCancel(context.Background())

		lc.Append(fx.Hook{
			OnStart: func(ctx context.Context) error {
				logger.Info("Starting Milky bot", zap.String("bot_id", botID))
				// Connect in the background; a failure is only logged and
				// does not terminate the application.
				go func() {
					if err := bot.Connect(connCtx); err != nil {
						// No retry is performed here; retry logic could be
						// added at this point.
						logger.Error("Failed to connect Milky bot, will retry in background",
							zap.String("bot_id", botID),
							zap.Error(err))
						return
					}
					logger.Info("Milky bot connected successfully", zap.String("bot_id", botID))
				}()
				return nil
			},
			OnStop: func(ctx context.Context) error {
				logger.Info("Stopping Milky bot", zap.String("bot_id", botID))
				// Cancel only after Disconnect has returned so the normal
				// shutdown path is unchanged.
				defer cancel()
				return bot.Disconnect(ctx)
			},
		})
	}
	return nil
}
// ProvideOneBot11Bots creates a OneBot11 bot for every entry in cfg.Bots that
// is enabled and whose protocol is "onebot11", adds it to botManager, and
// appends an fx lifecycle hook that connects the bot on start and disconnects
// it on stop.
//
// The connection attempt runs in a background goroutine, so a bot that fails
// to connect is only logged and never aborts application startup. The
// function always returns nil.
func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net.WebSocketManager, eventBus *engine.EventBus, botManager *protocol.BotManager, lc fx.Lifecycle) error {
	for _, botCfg := range cfg.Bots {
		if botCfg.Protocol != "onebot11" || !botCfg.Enabled {
			continue
		}

		// Copy the ID into a per-iteration variable: the hooks below run long
		// after this loop has finished, and on toolchains older than Go 1.22
		// the range variable is shared by all iterations, which would make
		// every hook log the ID of the last bot.
		botID := botCfg.ID
		logger.Info("Creating OneBot11 bot", zap.String("bot_id", botID))

		ob11Cfg := &onebot11.Config{
			ConnectionType:    botCfg.OneBot11.ConnectionType,
			Host:              botCfg.OneBot11.Host,
			Port:              botCfg.OneBot11.Port,
			AccessToken:       botCfg.OneBot11.AccessToken,
			WSUrl:             botCfg.OneBot11.WSUrl,
			WSReverseUrl:      botCfg.OneBot11.WSReverseUrl,
			Heartbeat:         botCfg.OneBot11.Heartbeat,
			ReconnectInterval: botCfg.OneBot11.ReconnectInterval,
			HTTPUrl:           botCfg.OneBot11.HTTPUrl,
			HTTPPostUrl:       botCfg.OneBot11.HTTPPostUrl,
			Secret:            botCfg.OneBot11.Secret,
			Timeout:           botCfg.OneBot11.Timeout,
			SelfID:            botCfg.OneBot11.SelfID,
			Nickname:          botCfg.OneBot11.Nickname,
		}
		bot := onebot11.NewBot(botID, ob11Cfg, logger, wsManager, eventBus)
		botManager.Add(bot)

		// The connect context is independent of the short-lived OnStart
		// context, but it is cancellable so that a connection attempt still
		// pending at shutdown does not outlive the application.
		connCtx, cancel := context.WithCancel(context.Background())

		lc.Append(fx.Hook{
			OnStart: func(ctx context.Context) error {
				logger.Info("Starting OneBot11 bot", zap.String("bot_id", botID))
				// Connect in the background; a failure is only logged and
				// does not terminate the application.
				go func() {
					if err := bot.Connect(connCtx); err != nil {
						// No retry is performed here; retry logic could be
						// added at this point.
						logger.Error("Failed to connect OneBot11 bot, will retry in background",
							zap.String("bot_id", botID),
							zap.Error(err))
						return
					}
					logger.Info("OneBot11 bot connected successfully", zap.String("bot_id", botID))
				}()
				return nil
			},
			OnStop: func(ctx context.Context) error {
				logger.Info("Stopping OneBot11 bot", zap.String("bot_id", botID))
				// Cancel only after Disconnect has returned so the normal
				// shutdown path is unchanged.
				defer cancel()
				return bot.Disconnect(ctx)
			},
		})
	}
	return nil
}
// LoadPlugins loads all plugins and handlers and registers them on registry.
//
// Plugins come from engine.LoadAllPlugins and handlers from
// engine.LoadAllHandlers; both receive botManager and logger. A summary with
// the individual and combined counts plus the registered plugin names is
// logged at the end.
func LoadPlugins(logger *zap.Logger, botManager *protocol.BotManager, registry *engine.PluginRegistry) {
	// Plugins from the global registry (those registered via RegisterPlugin).
	loadedPlugins := engine.LoadAllPlugins(botManager, logger)
	for _, p := range loadedPlugins {
		registry.Register(p)
	}

	// Handlers registered through OnXXX().Handle(), with dependencies injected.
	loadedHandlers := engine.LoadAllHandlers(botManager, logger)
	for _, h := range loadedHandlers {
		registry.Register(h)
	}

	pluginCount, handlerCount := len(loadedPlugins), len(loadedHandlers)
	logger.Info("All plugins loaded",
		zap.Int("plugin_count", pluginCount),
		zap.Int("handler_count", handlerCount),
		zap.Int("total_count", pluginCount+handlerCount),
		zap.Strings("plugins", engine.GetRegisteredPlugins()))
}
// LoadScheduledJobs loads all scheduled jobs into scheduler. It is invoked
// by the dependency-injection container and returns whatever
// engine.LoadAllJobs returns.
func LoadScheduledJobs(scheduler *engine.Scheduler, logger *zap.Logger) error {
	err := engine.LoadAllJobs(scheduler, logger)
	return err
}
// Providers bundles every constructor and startup invocation of this package
// into a single fx option.
//
// Invocations run in the order listed. The bot constructors come first
// because they add bots to the bot manager, and InitMCStatusDatabase later
// iterates over the bots registered there.
var Providers = fx.Options(
	fx.Provide(
		ProvideConfig,
		ProvideConfigManager,
		ProvideLogger,
		ProvideEventBus,
		ProvideScheduler,
		ProvideDispatcher,
		ProvidePluginRegistry,
		ProvideBotManager,
		ProvideWebSocketManager,
		ProvideDatabase,
		ProvideServer,
	),
	fx.Invoke(
		ProvideMilkyBots,
		ProvideOneBot11Bots,
		LoadPlugins,
		LoadScheduledJobs,
		InitMCStatusDatabase,
	),
)