package di import ( "context" "fmt" "cellbot/internal/adapter/milky" "cellbot/internal/adapter/onebot11" "cellbot/internal/config" "cellbot/internal/database" "cellbot/internal/engine" _ "cellbot/internal/plugins/echo" // 导入插件以触发 init 函数 "cellbot/internal/plugins/mcstatus" _ "cellbot/internal/plugins/mcstatus" // 导入插件以触发 init 函数 _ "cellbot/internal/plugins/welcome" // 导入插件以触发 init 函数 "cellbot/internal/protocol" "cellbot/pkg/net" "go.uber.org/fx" "go.uber.org/zap" ) func ProvideLogger(cfg *config.Config) (*zap.Logger, error) { return config.InitLogger(&cfg.Log) } func ProvideConfig() (*config.Config, error) { configManager := config.NewConfigManager("configs/config.toml", zap.NewNop()) if err := configManager.Load(); err != nil { return nil, err } return configManager.Get(), nil } func ProvideConfigManager(logger *zap.Logger) (*config.ConfigManager, error) { configManager := config.NewConfigManager("configs/config.toml", logger) if err := configManager.Load(); err != nil { return nil, err } if err := configManager.Watch(); err != nil { logger.Warn("Failed to watch config file", zap.Error(err)) } return configManager, nil } func ProvideEventBus(logger *zap.Logger) *engine.EventBus { return engine.NewEventBus(logger, 10000) } func ProvideScheduler(logger *zap.Logger) *engine.Scheduler { return engine.NewScheduler(logger) } func ProvideDispatcher(eventBus *engine.EventBus, logger *zap.Logger, cfg *config.Config, scheduler *engine.Scheduler) *engine.Dispatcher { dispatcher := engine.NewDispatcherWithScheduler(eventBus, logger, scheduler) // 注册限流中间件 if cfg.Engine.RateLimit.Enabled { rateLimitMiddleware := engine.NewRateLimitMiddleware( logger, cfg.Engine.RateLimit.RPS, cfg.Engine.RateLimit.Burst, ) dispatcher.RegisterMiddleware(rateLimitMiddleware) logger.Info("Rate limit middleware registered", zap.Int("rps", cfg.Engine.RateLimit.RPS), zap.Int("burst", cfg.Engine.RateLimit.Burst)) } return dispatcher } func ProvidePluginRegistry(dispatcher *engine.Dispatcher, logger *zap.Logger) *engine.PluginRegistry { return engine.NewPluginRegistry(dispatcher, logger) } func ProvideBotManager(logger *zap.Logger) *protocol.BotManager { return protocol.NewBotManager(logger) } func ProvideWebSocketManager(logger *zap.Logger, eventBus *engine.EventBus) *net.WebSocketManager { return net.NewWebSocketManager(logger, eventBus) } func ProvideServer(cfg *config.Config, logger *zap.Logger, botManager *protocol.BotManager, eventBus *engine.EventBus) *net.Server { return net.NewServer(cfg.Server.Host, cfg.Server.Port, logger, botManager, eventBus) } // ProvideDatabase 提供数据库服务 func ProvideDatabase(logger *zap.Logger) database.Database { return database.NewSQLiteDatabase(logger, "data/cellbot.db") } // InitMCStatusDatabase 初始化 MC 状态插件的数据库 func InitMCStatusDatabase(dbService database.Database, logger *zap.Logger, botManager *protocol.BotManager) error { // 为每个 bot 初始化数据库表 bots := botManager.GetAll() for _, bot := range bots { botID := bot.GetID() db, err := dbService.GetDB(botID) if err != nil { logger.Error("Failed to get database for bot", zap.String("bot_id", botID), zap.Error(err)) continue } // 创建表(使用原始 SQL 避免循环依赖) // 注意:虽然使用 fmt.Sprintf,但 tableName 已经通过 sanitizeTableName 清理过,相对安全 tableName := database.GetTableName(botID, "mc_server_binds") // 使用参数化查询更安全,但 SQLite 的 CREATE TABLE 不支持参数化表名 // 所以这里使用清理过的表名是合理的 if err := db.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id TEXT PRIMARY KEY, server_ip TEXT NOT NULL ) `, tableName)).Error; err != nil { logger.Error("Failed to create table", zap.String("bot_id", botID), zap.String("table", tableName), zap.Error(err)) } else { logger.Info("Database table initialized", zap.String("bot_id", botID), zap.String("table", tableName)) } } // 初始化插件数据库 mcstatus.InitDatabase(dbService) return nil } func ProvideMilkyBots(cfg *config.Config, logger *zap.Logger, eventBus *engine.EventBus, wsManager *net.WebSocketManager, botManager *protocol.BotManager, lc fx.Lifecycle) error { for _, botCfg := range cfg.Bots { if botCfg.Protocol == "milky" && botCfg.Enabled { logger.Info("Creating Milky bot", zap.String("bot_id", botCfg.ID)) milkyCfg := &milky.Config{ ProtocolURL: botCfg.Milky.ProtocolURL, AccessToken: botCfg.Milky.AccessToken, EventMode: botCfg.Milky.EventMode, WebhookListenAddr: botCfg.Milky.WebhookListenAddr, Timeout: botCfg.Milky.Timeout, RetryCount: botCfg.Milky.RetryCount, } bot := milky.NewBot(botCfg.ID, milkyCfg, eventBus, wsManager, logger) botManager.Add(bot) lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { logger.Info("Starting Milky bot", zap.String("bot_id", botCfg.ID)) // 在后台启动连接,失败时只记录错误,不终止应用 go func() { if err := bot.Connect(context.Background()); err != nil { logger.Error("Failed to connect Milky bot, will retry in background", zap.String("bot_id", botCfg.ID), zap.Error(err)) // 可以在这里实现重试逻辑 } else { logger.Info("Milky bot connected successfully", zap.String("bot_id", botCfg.ID)) } }() return nil }, OnStop: func(ctx context.Context) error { logger.Info("Stopping Milky bot", zap.String("bot_id", botCfg.ID)) return bot.Disconnect(ctx) }, }) } } return nil } func ProvideOneBot11Bots(cfg *config.Config, logger *zap.Logger, wsManager *net.WebSocketManager, eventBus *engine.EventBus, botManager *protocol.BotManager, lc fx.Lifecycle) error { for _, botCfg := range cfg.Bots { if botCfg.Protocol == "onebot11" && botCfg.Enabled { logger.Info("Creating OneBot11 bot", zap.String("bot_id", botCfg.ID)) ob11Cfg := &onebot11.Config{ ConnectionType: botCfg.OneBot11.ConnectionType, Host: botCfg.OneBot11.Host, Port: botCfg.OneBot11.Port, AccessToken: botCfg.OneBot11.AccessToken, WSUrl: botCfg.OneBot11.WSUrl, WSReverseUrl: botCfg.OneBot11.WSReverseUrl, Heartbeat: botCfg.OneBot11.Heartbeat, ReconnectInterval: botCfg.OneBot11.ReconnectInterval, HTTPUrl: botCfg.OneBot11.HTTPUrl, HTTPPostUrl: botCfg.OneBot11.HTTPPostUrl, Secret: botCfg.OneBot11.Secret, Timeout: botCfg.OneBot11.Timeout, SelfID: botCfg.OneBot11.SelfID, Nickname: botCfg.OneBot11.Nickname, } bot := onebot11.NewBot(botCfg.ID, ob11Cfg, logger, wsManager, eventBus) botManager.Add(bot) lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { logger.Info("Starting OneBot11 bot", zap.String("bot_id", botCfg.ID)) // 在后台启动连接,失败时只记录错误,不终止应用 go func() { if err := bot.Connect(context.Background()); err != nil { logger.Error("Failed to connect OneBot11 bot, will retry in background", zap.String("bot_id", botCfg.ID), zap.Error(err)) // 可以在这里实现重试逻辑 } else { logger.Info("OneBot11 bot connected successfully", zap.String("bot_id", botCfg.ID)) } }() return nil }, OnStop: func(ctx context.Context) error { logger.Info("Stopping OneBot11 bot", zap.String("bot_id", botCfg.ID)) return bot.Disconnect(ctx) }, }) } } return nil } // LoadPlugins 加载所有插件 func LoadPlugins(logger *zap.Logger, botManager *protocol.BotManager, registry *engine.PluginRegistry) { // 从全局注册表加载所有插件(通过 RegisterPlugin 注册的) plugins := engine.LoadAllPlugins(botManager, logger) for _, plugin := range plugins { registry.Register(plugin) } // 加载所有通过 OnXXX().Handle() 注册的处理器(注入依赖) handlers := engine.LoadAllHandlers(botManager, logger) for _, handler := range handlers { registry.Register(handler) } totalCount := len(plugins) + len(handlers) logger.Info("All plugins loaded", zap.Int("plugin_count", len(plugins)), zap.Int("handler_count", len(handlers)), zap.Int("total_count", totalCount), zap.Strings("plugins", engine.GetRegisteredPlugins())) } // LoadScheduledJobs 加载所有定时任务(由依赖注入系统调用) func LoadScheduledJobs(scheduler *engine.Scheduler, logger *zap.Logger) error { return engine.LoadAllJobs(scheduler, logger) } var Providers = fx.Options( fx.Provide( ProvideConfig, ProvideConfigManager, ProvideLogger, ProvideEventBus, ProvideScheduler, ProvideDispatcher, ProvidePluginRegistry, ProvideBotManager, ProvideWebSocketManager, ProvideDatabase, ProvideServer, ), fx.Invoke(ProvideMilkyBots), fx.Invoke(ProvideOneBot11Bots), fx.Invoke(LoadPlugins), fx.Invoke(LoadScheduledJobs), fx.Invoke(InitMCStatusDatabase), )