package engine

import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"
	"sync/atomic"

	"cellbot/internal/protocol"
	"go.uber.org/zap"
)

// PluginFactory constructs a plugin instance. Factories are invoked at load
// time so that dependencies (bot manager, logger) can be injected.
type PluginFactory func(botManager *protocol.BotManager, logger *zap.Logger) protocol.EventHandler

// Global plugin registry, keyed by plugin name and populated by RegisterPlugin.
var (
	globalPluginRegistry = make(map[string]PluginFactory)
	globalPluginMu       sync.RWMutex
)

// RegisterPlugin registers factory under name. It is intended to be called
// from a plugin package's init function.
//
// It panics if factory is nil or if name is already registered; both are
// programmer errors that should surface at startup instead of at load time.
func RegisterPlugin(name string, factory PluginFactory) {
	if factory == nil {
		panic("plugin factory is nil: " + name)
	}

	globalPluginMu.Lock()
	defer globalPluginMu.Unlock()
	if _, exists := globalPluginRegistry[name]; exists {
		panic("plugin already registered: " + name)
	}
	globalPluginRegistry[name] = factory
}

// GetRegisteredPlugins returns the names of all registered plugins, sorted
// alphabetically so the result is deterministic.
func GetRegisteredPlugins() []string {
	globalPluginMu.RLock()
	names := make([]string, 0, len(globalPluginRegistry))
	for name := range globalPluginRegistry {
		names = append(names, name)
	}
	globalPluginMu.RUnlock()

	sort.Strings(names)
	return names
}

// LoadPlugin instantiates the plugin registered under name.
//
// When no plugin is registered under name it returns (nil, nil) rather than
// an error, so callers must nil-check the returned handler. Unlike
// LoadAllPlugins, logger is handed to the factory unchanged.
func LoadPlugin(name string, botManager *protocol.BotManager, logger *zap.Logger) (protocol.EventHandler, error) {
	globalPluginMu.RLock()
	factory, exists := globalPluginRegistry[name]
	globalPluginMu.RUnlock()
	if !exists {
		// Deliberate: an unknown plugin is not treated as an error.
		return nil, nil
	}
	return factory(botManager, logger), nil
}

// LoadAllPlugins instantiates every registered plugin, in name order.
//
// Each factory receives a child logger named "plugin.<name>". Factories that
// return a nil handler are skipped with a warning, so the result never holds
// a nil entry. The registry is snapshotted first and the lock is released
// before any factory runs.
func LoadAllPlugins(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
	globalPluginMu.RLock()
	factories := make(map[string]PluginFactory, len(globalPluginRegistry))
	for name, factory := range globalPluginRegistry {
		factories[name] = factory
	}
	globalPluginMu.RUnlock()

	names := make([]string, 0, len(factories))
	for name := range factories {
		names = append(names, name)
	}
	sort.Strings(names)

	plugins := make([]protocol.EventHandler, 0, len(names))
	for _, name := range names {
		plugin := factories[name](botManager, logger.Named("plugin."+name))
		if plugin == nil {
			logger.Warn("Plugin factory returned nil, skipping", zap.String("name", name))
			continue
		}
		plugins = append(plugins, plugin)
	}
	return plugins
}

// PluginRegistry records plugins and forwards each one to a Dispatcher.
type PluginRegistry struct {
	plugins    []protocol.EventHandler
	dispatcher *Dispatcher
	logger     *zap.Logger
	mu         sync.RWMutex
}

// NewPluginRegistry creates a registry that registers plugins with dispatcher.
func NewPluginRegistry(dispatcher *Dispatcher, logger *zap.Logger) *PluginRegistry {
	return &PluginRegistry{
		plugins:    make([]protocol.EventHandler, 0),
		dispatcher: dispatcher,
		logger:     logger.Named("plugin-registry"),
	}
}

// Register records plugin and registers it with the dispatcher.
// A nil plugin is ignored with a warning instead of panicking on Name().
func (r *PluginRegistry) Register(plugin protocol.EventHandler) {
	if plugin == nil {
		r.logger.Warn("Ignoring nil plugin")
		return
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	r.plugins = append(r.plugins, plugin)
	r.dispatcher.RegisterHandler(plugin)
	r.logger.Info("Plugin registered",
		zap.String("name", plugin.Name()),
		zap.String("description", plugin.Description()))
}

// GetPlugins returns a copy of the registered plugins. Returning a copy keeps
// callers from mutating the registry's internal slice.
func (r *PluginRegistry) GetPlugins() []protocol.EventHandler {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]protocol.EventHandler, len(r.plugins))
	copy(out, r.plugins)
	return out
}

// PluginBuilder assembles a simple EventHandler through chained setters.
type PluginBuilder struct {
	name        string
	description string
	priority    int
	matchFunc   func(protocol.Event) bool
	handleFunc  func(context.Context, protocol.Event) error
}

// NewPlugin starts building a plugin called name with the default priority 100.
func NewPlugin(name string) *PluginBuilder {
	return &PluginBuilder{
		name:     name,
		priority: 100, // default priority
	}
}

// Description sets the plugin description.
func (b *PluginBuilder) Description(desc string) *PluginBuilder {
	b.description = desc
	return b
}

// Priority sets the plugin priority.
func (b *PluginBuilder) Priority(priority int) *PluginBuilder {
	b.priority = priority
	return b
}

// Match sets the match predicate; a nil predicate matches every event.
func (b *PluginBuilder) Match(fn func(protocol.Event) bool) *PluginBuilder {
	b.matchFunc = fn
	return b
}

// Handle sets the handler function; a nil handler is a no-op.
func (b *PluginBuilder) Handle(fn func(context.Context, protocol.Event) error) *PluginBuilder {
	b.handleFunc = fn
	return b
}

// Build returns the configured plugin as an EventHandler.
func (b *PluginBuilder) Build() protocol.EventHandler {
	return &simplePlugin{
		name:        b.name,
		description: b.description,
		priority:    b.priority,
		matchFunc:   b.matchFunc,
		handleFunc:  b.handleFunc,
	}
}

// simplePlugin is an EventHandler backed by plain function values.
type simplePlugin struct {
	name        string
	description string
	priority    int
	matchFunc   func(protocol.Event) bool
	handleFunc  func(context.Context, protocol.Event) error
}

func (p *simplePlugin) Name() string        { return p.name }
func (p *simplePlugin) Description() string { return p.description }
func (p *simplePlugin) Priority() int       { return p.priority }

// Match reports whether the plugin wants event; with no predicate it matches all.
func (p *simplePlugin) Match(event protocol.Event) bool {
	if p.matchFunc == nil {
		return true
	}
	return p.matchFunc(event)
}

// Handle runs the handler; with no handler it does nothing and returns nil.
func (p *simplePlugin) Handle(ctx context.Context, event protocol.Event) error {
	if p.handleFunc == nil {
		return nil
	}
	return p.handleFunc(ctx, event)
}

// ============================================================================
// ZeroBot-style global registration API
// ============================================================================

// Global handler registry, appended to by HandlerBuilder.Handle.
var (
	globalHandlerRegistry = make([]*HandlerBuilder, 0)
	globalHandlerMu       sync.RWMutex
)

// HandlerFunc is a handler whose dependencies are injected by LoadAllHandlers.
type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error

// HandlerMiddleware is an extra filter applied after the base matcher.
// Returning true lets the event through; false rejects it.
type HandlerMiddleware func(event protocol.Event) bool

// HandlerBuilder describes a handler (matcher, priority, middlewares) before
// it is registered with Handle (ZeroBot-like API).
type HandlerBuilder struct {
	matchFunc   func(protocol.Event) bool
	priority    int
	handleFunc  HandlerFunc
	middlewares []HandlerMiddleware
}

// newDetailTypeBuilder returns a priority-100 builder matching events of
// eventType whose detail_type is in detailTypes; an empty list matches every
// event of that type. detailTypes is copied so later caller mutation of a
// slice passed with `...` cannot change the matcher.
func newDetailTypeBuilder(eventType protocol.EventType, detailTypes []string) *HandlerBuilder {
	wanted := append([]string(nil), detailTypes...)
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			if event.GetType() != eventType {
				return false
			}
			if len(wanted) == 0 {
				return true
			}
			detailType := event.GetDetailType()
			for _, dt := range wanted {
				if dt == detailType {
					return true
				}
			}
			return false
		},
		priority: 100,
	}
}

// rawMessageText returns the "raw_message" string of a message event.
// ok is false for non-message events or when the field is absent or not a string.
func rawMessageText(event protocol.Event) (string, bool) {
	if event.GetType() != protocol.EventTypeMessage {
		return "", false
	}
	raw, ok := event.GetData()["raw_message"].(string)
	return raw, ok
}

// OnPrivateMessage matches private-chat message events.
func OnPrivateMessage() *HandlerBuilder {
	return newDetailTypeBuilder(protocol.EventTypeMessage, []string{"private"})
}

// OnGroupMessage matches group-chat message events.
func OnGroupMessage() *HandlerBuilder {
	return newDetailTypeBuilder(protocol.EventTypeMessage, []string{"group"})
}

// OnMessage matches every message event.
func OnMessage() *HandlerBuilder {
	return newDetailTypeBuilder(protocol.EventTypeMessage, nil)
}

// OnNotice matches notice events.
// Usage:
//
//	OnNotice()                                   - every notice event
//	OnNotice("group_increase")                   - group member joined
//	OnNotice("group_increase", "group_decrease") - member joined or left
func OnNotice(detailTypes ...string) *HandlerBuilder {
	return newDetailTypeBuilder(protocol.EventTypeNotice, detailTypes)
}

// OnRequest matches request events.
// Usage:
//
//	OnRequest()                  - every request event
//	OnRequest("friend")          - friend requests
//	OnRequest("friend", "group") - friend or group requests
func OnRequest(detailTypes ...string) *HandlerBuilder {
	return newDetailTypeBuilder(protocol.EventTypeRequest, detailTypes)
}

// OnEvent matches events whose type is one of eventTypes.
// Usage:
//
//	OnEvent()                                                   - every event
//	OnEvent(protocol.EventTypeMessage)                          - message events
//	OnEvent(protocol.EventTypeMessage, protocol.EventTypeNotice) - message and notice events
func OnEvent(eventTypes ...protocol.EventType) *HandlerBuilder {
	wanted := append([]protocol.EventType(nil), eventTypes...)
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			if len(wanted) == 0 {
				return true // no arguments: match every event
			}
			for _, et := range wanted {
				if event.GetType() == et {
					return true
				}
			}
			return false
		},
		priority: 100,
	}
}

// On creates a builder with a custom matcher.
func On(matchFunc func(protocol.Event) bool) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: matchFunc,
		priority:  100,
	}
}

// OnCommand matches command messages (priority 50).
//
// The message's raw text must start with prefix. If no commands are given,
// any message starting with prefix matches. Note that OnCommand("/help")
// therefore also matches "/helpme". Otherwise, the first whitespace-separated
// word after the prefix must equal one of commands.
// Usage:
//
//	OnCommand("/help")       - any message starting with "/help"
//	OnCommand("/", "help")   - "/help ..." (or "/ help ...")
//	OnCommand("/", "help", "h") - "/help ..." or "/h ..."
func OnCommand(prefix string, commands ...string) *HandlerBuilder {
	wanted := append([]string(nil), commands...)
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			rawMessage, ok := rawMessageText(event)
			if !ok || !strings.HasPrefix(rawMessage, prefix) {
				return false
			}
			if len(wanted) == 0 {
				return true
			}
			// The first field after the prefix is the command; Fields skips leading space.
			parts := strings.Fields(rawMessage[len(prefix):])
			if len(parts) == 0 {
				return false
			}
			cmd := parts[0]
			for _, c := range wanted {
				if cmd == c {
					return true
				}
			}
			return false
		},
		priority: 50, // commands usually take precedence
	}
}

// OnPrefix matches messages whose raw text starts with prefix.
func OnPrefix(prefix string) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			rawMessage, ok := rawMessageText(event)
			return ok && strings.HasPrefix(rawMessage, prefix)
		},
		priority: 100,
	}
}

// OnSuffix matches messages whose raw text ends with suffix.
func OnSuffix(suffix string) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			rawMessage, ok := rawMessageText(event)
			return ok && strings.HasSuffix(rawMessage, suffix)
		},
		priority: 100,
	}
}

// OnKeyword matches messages whose raw text contains keyword.
func OnKeyword(keyword string) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			rawMessage, ok := rawMessageText(event)
			return ok && strings.Contains(rawMessage, keyword)
		},
		priority: 100,
	}
}

// contains reports whether substr occurs within s.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}

// OnFullMatch matches messages whose raw text equals text exactly.
func OnFullMatch(text string) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			rawMessage, ok := rawMessageText(event)
			return ok && rawMessage == text
		},
		priority: 100,
	}
}

// OnDetailType matches events (of any type) with the given detail_type.
func OnDetailType(detailType string) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			return event.GetDetailType() == detailType
		},
		priority: 100,
	}
}

// OnSubType matches events (of any type) with the given sub_type.
func OnSubType(subType string) *HandlerBuilder {
	return &HandlerBuilder{
		matchFunc: func(event protocol.Event) bool {
			return event.GetSubType() == subType
		},
		priority: 100,
	}
}

// Priority sets the handler priority.
func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder {
	b.priority = priority
	return b
}

// Use appends a middleware; middlewares run in the order they were added.
func (b *HandlerBuilder) Use(middleware HandlerMiddleware) *HandlerBuilder {
	b.middlewares = append(b.middlewares, middleware)
	return b
}

// Handle sets handleFunc and appends the builder to the global registry.
// Intended for init functions. Each call appends again, so calling Handle
// twice on the same builder registers it twice.
func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
	globalHandlerMu.Lock()
	defer globalHandlerMu.Unlock()
	b.handleFunc = handleFunc
	globalHandlerRegistry = append(globalHandlerRegistry, b)
}

// applyMiddlewares reports whether every middleware accepts event,
// stopping at the first rejection.
func (b *HandlerBuilder) applyMiddlewares(event protocol.Event) bool {
	for _, middleware := range b.middlewares {
		if !middleware(event) {
			return false
		}
	}
	return true
}

// handlerCounter backs generateHandlerName.
var handlerCounter int64

// generateHandlerName returns a process-unique name "handler_<n>".
// Note: LoadAllHandlers below uses index-based names instead.
func generateHandlerName() string {
	counter := atomic.AddInt64(&handlerCounter, 1)
	return fmt.Sprintf("handler_%d", counter)
}

// LoadAllHandlers wraps every registered builder that has a handler into an
// EventHandler, injecting botManager and logger into each call. Handlers are
// named "handler_<i+1>" after their registry index.
func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []protocol.EventHandler {
	globalHandlerMu.RLock()
	defer globalHandlerMu.RUnlock()

	handlers := make([]protocol.EventHandler, 0, len(globalHandlerRegistry))
	for i, builder := range globalHandlerRegistry {
		if builder.handleFunc == nil {
			continue
		}
		// Re-bind per iteration: before Go 1.22 the range variable is shared
		// across iterations, so the closures below would all see the last builder.
		builder := builder

		handlers = append(handlers, &simplePlugin{
			name:        fmt.Sprintf("handler_%d", i+1),
			description: "Handler registered via OnXXX().Handle()",
			priority:    builder.priority,
			// The base matcher runs first; middlewares only see events it accepted.
			matchFunc: func(event protocol.Event) bool {
				if builder.matchFunc != nil && !builder.matchFunc(event) {
					return false
				}
				return builder.applyMiddlewares(event)
			},
			handleFunc: func(ctx context.Context, event protocol.Event) error {
				return builder.handleFunc(ctx, event, botManager, logger)
			},
		})
	}
	return handlers
}

// ============================================================================
// Common middlewares (NoneBot style)
// ============================================================================

// OnlyToMe lets a group message through only if it @-mentions the bot.
// Every event that is not a group message passes unchecked.
//
// The mention is detected either from a "message_segments" entry of type
// "at"/"mention" whose data "user_id" or "qq" equals the bot's self ID, or
// from a "[CQ:at,qq=<selfID>]" code in raw_message.
func OnlyToMe() HandlerMiddleware {
	return func(event protocol.Event) bool {
		// Only group messages are checked.
		if event.GetType() != protocol.EventTypeMessage || event.GetDetailType() != "group" {
			return true
		}

		data := event.GetData()
		selfID := event.GetSelfID()

		// NOTE(review): IDs are compared via %v formatting; a numeric ID stored
		// as float64 (e.g. from JSON) would format in exponent form and never
		// match — confirm how segment data is decoded.
		if segments, ok := data["message_segments"].([]interface{}); ok {
			for _, seg := range segments {
				segMap, ok := seg.(map[string]interface{})
				if !ok {
					continue
				}
				segType, _ := segMap["type"].(string)
				if segType != "at" && segType != "mention" {
					continue
				}
				segData, _ := segMap["data"].(map[string]interface{})
				if userID, ok := segData["user_id"]; ok && fmt.Sprintf("%v", userID) == selfID {
					return true
				}
				if qq, ok := segData["qq"]; ok && fmt.Sprintf("%v", qq) == selfID {
					return true
				}
			}
		}

		// Fallback for CQ-code formatted raw messages; may need adjusting per protocol.
		if rawMessage, ok := data["raw_message"].(string); ok {
			if strings.Contains(rawMessage, fmt.Sprintf("[CQ:at,qq=%s]", selfID)) {
				return true
			}
		}
		return false
	}
}

// OnlyPrivate lets through only private-chat message events.
func OnlyPrivate() HandlerMiddleware {
	// Accept only message events whose detail type is "private".
	return func(event protocol.Event) bool {
		if event.GetType() != protocol.EventTypeMessage {
			return false
		}
		return event.GetDetailType() == "private"
	}
}

// OnlyGroup accepts group-chat message events and, for any non-message event
// (notices, requests, ...), events whose data carries a "group_id" key.
func OnlyGroup() HandlerMiddleware {
	return func(event protocol.Event) bool {
		switch event.GetType() {
		case protocol.EventTypeMessage:
			return event.GetDetailType() == "group"
		default:
			_, present := event.GetData()["group_id"]
			return present
		}
	}
}

// OnlySuperuser accepts events whose "user_id" (formatted with %v) appears in
// superusers. Events without a user_id are rejected.
func OnlySuperuser(superusers []string) HandlerMiddleware {
	return func(event protocol.Event) bool {
		rawID, found := event.GetData()["user_id"]
		if !found {
			return false
		}
		id := fmt.Sprintf("%v", rawID)
		for i := range superusers {
			if superusers[i] == id {
				return true
			}
		}
		return false
	}
}

// BlockPrivate rejects private-chat message events and accepts everything else.
func BlockPrivate() HandlerMiddleware {
	return func(event protocol.Event) bool {
		return event.GetType() != protocol.EventTypeMessage || event.GetDetailType() != "private"
	}
}

// BlockGroup is the inverse of OnlyGroup: it rejects group-chat messages and
// non-message events whose data carries a "group_id" key.
func BlockGroup() HandlerMiddleware {
	return func(event protocol.Event) bool {
		switch event.GetType() {
		case protocol.EventTypeMessage:
			return event.GetDetailType() != "group"
		default:
			_, present := event.GetData()["group_id"]
			return !present
		}
	}
}

// OnlyDetailType accepts only events whose detail_type equals detailType.
func OnlyDetailType(detailType string) HandlerMiddleware {
	want := detailType
	return func(event protocol.Event) bool {
		return want == event.GetDetailType()
	}
}

// OnlySubType accepts only events whose sub_type equals subType.
func OnlySubType(subType string) HandlerMiddleware {
	want := subType
	return func(event protocol.Event) bool {
		return want == event.GetSubType()
	}
}