feat: enhance event handling and add scheduling capabilities
- Introduced a new scheduler to manage timed tasks within the event dispatcher. - Updated the dispatcher to support the new scheduler, allowing for improved event processing. - Enhanced action serialization in the OneBot11 adapter to convert message chains to the appropriate format. - Added new dependencies for cron scheduling and other indirect packages in go.mod and go.sum. - Improved logging for event publishing and handler matching, providing better insights during execution. - Refactored plugin loading to include scheduled job management.
This commit is contained in:
@@ -30,6 +30,7 @@ type Dispatcher struct {
|
||||
middlewares []protocol.Middleware
|
||||
logger *zap.Logger
|
||||
eventBus *EventBus
|
||||
scheduler *Scheduler
|
||||
metrics DispatcherMetrics
|
||||
mu sync.RWMutex
|
||||
workerPool chan struct{} // 工作池,限制并发数
|
||||
@@ -43,6 +44,13 @@ func NewDispatcher(eventBus *EventBus, logger *zap.Logger) *Dispatcher {
|
||||
return NewDispatcherWithConfig(eventBus, logger, 100, true)
|
||||
}
|
||||
|
||||
// NewDispatcherWithScheduler 创建带调度器的事件分发器
|
||||
func NewDispatcherWithScheduler(eventBus *EventBus, logger *zap.Logger, scheduler *Scheduler) *Dispatcher {
|
||||
dispatcher := NewDispatcherWithConfig(eventBus, logger, 100, true)
|
||||
dispatcher.scheduler = scheduler
|
||||
return dispatcher
|
||||
}
|
||||
|
||||
// NewDispatcherWithConfig 使用配置创建事件分发器
|
||||
func NewDispatcherWithConfig(eventBus *EventBus, logger *zap.Logger, maxWorkers int, async bool) *Dispatcher {
|
||||
if maxWorkers <= 0 {
|
||||
@@ -114,14 +122,37 @@ func (d *Dispatcher) Start(ctx context.Context) {
|
||||
go d.eventLoop(ctx, eventChan)
|
||||
}
|
||||
|
||||
// 启动调度器
|
||||
if d.scheduler != nil {
|
||||
if err := d.scheduler.Start(); err != nil {
|
||||
d.logger.Error("Failed to start scheduler", zap.Error(err))
|
||||
} else {
|
||||
d.logger.Info("Scheduler started")
|
||||
}
|
||||
}
|
||||
|
||||
d.logger.Info("Dispatcher started")
|
||||
}
|
||||
|
||||
// Stop 停止分发器
|
||||
func (d *Dispatcher) Stop() {
|
||||
// 停止调度器
|
||||
if d.scheduler != nil {
|
||||
if err := d.scheduler.Stop(); err != nil {
|
||||
d.logger.Error("Failed to stop scheduler", zap.Error(err))
|
||||
} else {
|
||||
d.logger.Info("Scheduler stopped")
|
||||
}
|
||||
}
|
||||
|
||||
d.logger.Info("Dispatcher stopped")
|
||||
}
|
||||
|
||||
// GetScheduler 获取调度器
|
||||
func (d *Dispatcher) GetScheduler() *Scheduler {
|
||||
return d.scheduler
|
||||
}
|
||||
|
||||
// eventLoop 事件循环
|
||||
func (d *Dispatcher) eventLoop(ctx context.Context, eventChan chan protocol.Event) {
|
||||
for {
|
||||
@@ -215,13 +246,16 @@ func (d *Dispatcher) createHandlerChain(ctx context.Context, event protocol.Even
|
||||
|
||||
for i, handler := range handlers {
|
||||
matched := handler.Match(event)
|
||||
d.logger.Info("Checking handler",
|
||||
d.logger.Debug("Checking handler",
|
||||
zap.Int("handler_index", i),
|
||||
zap.String("handler_name", handler.Name()),
|
||||
zap.Int("priority", handler.Priority()),
|
||||
zap.Bool("matched", matched))
|
||||
if matched {
|
||||
d.logger.Info("Handler matched, calling Handle",
|
||||
zap.Int("handler_index", i))
|
||||
zap.Int("handler_index", i),
|
||||
zap.String("handler_name", handler.Name()),
|
||||
zap.String("handler_description", handler.Description()))
|
||||
// 使用defer捕获单个handler的panic
|
||||
func() {
|
||||
defer func() {
|
||||
|
||||
@@ -78,9 +78,10 @@ func (eb *EventBus) Stop() {
|
||||
|
||||
// Publish 发布事件
|
||||
func (eb *EventBus) Publish(event protocol.Event) {
|
||||
eb.logger.Info("Publishing event to channel",
|
||||
eb.logger.Debug("Publishing event to channel",
|
||||
zap.String("event_type", string(event.GetType())),
|
||||
zap.String("detail_type", event.GetDetailType()),
|
||||
zap.String("self_id", event.GetSelfID()),
|
||||
zap.Int("channel_len", len(eb.eventChan)),
|
||||
zap.Int("channel_cap", cap(eb.eventChan)))
|
||||
|
||||
@@ -88,8 +89,10 @@ func (eb *EventBus) Publish(event protocol.Event) {
|
||||
case eb.eventChan <- event:
|
||||
atomic.AddInt64(&eb.metrics.PublishedTotal, 1)
|
||||
atomic.StoreInt64(&eb.metrics.LastEventTime, time.Now().Unix())
|
||||
eb.logger.Info("Event successfully queued",
|
||||
zap.String("event_type", string(event.GetType())))
|
||||
eb.logger.Info("Event published successfully",
|
||||
zap.String("event_type", string(event.GetType())),
|
||||
zap.String("detail_type", event.GetDetailType()),
|
||||
zap.String("self_id", event.GetSelfID()))
|
||||
case <-eb.ctx.Done():
|
||||
atomic.AddInt64(&eb.metrics.DroppedTotal, 1)
|
||||
eb.logger.Warn("Event bus is shutting down, event dropped",
|
||||
|
||||
@@ -3,6 +3,7 @@ package engine
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@@ -208,11 +209,16 @@ var (
|
||||
// HandlerFunc 处理函数类型(支持依赖注入)
|
||||
type HandlerFunc func(ctx context.Context, event protocol.Event, botManager *protocol.BotManager, logger *zap.Logger) error
|
||||
|
||||
// HandlerMiddleware 处理器中间件函数类型
|
||||
// 返回 true 表示通过中间件检查,false 表示不通过
|
||||
type HandlerMiddleware func(event protocol.Event) bool
|
||||
|
||||
// HandlerBuilder 处理器构建器(类似 ZeroBot 的 API)
|
||||
type HandlerBuilder struct {
|
||||
matchFunc func(protocol.Event) bool
|
||||
priority int
|
||||
handleFunc HandlerFunc
|
||||
matchFunc func(protocol.Event) bool
|
||||
priority int
|
||||
handleFunc HandlerFunc
|
||||
middlewares []HandlerMiddleware
|
||||
}
|
||||
|
||||
// OnPrivateMessage 匹配私聊消息
|
||||
@@ -246,20 +252,82 @@ func OnMessage() *HandlerBuilder {
|
||||
}
|
||||
|
||||
// OnNotice 匹配通知事件
|
||||
func OnNotice() *HandlerBuilder {
|
||||
// 用法:
|
||||
//
|
||||
// OnNotice() - 匹配所有通知事件
|
||||
// OnNotice("group_increase") - 匹配群成员增加事件
|
||||
// OnNotice("group_increase", "group_decrease") - 匹配群成员增加或减少事件
|
||||
func OnNotice(detailTypes ...string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeNotice
|
||||
if event.GetType() != protocol.EventTypeNotice {
|
||||
return false
|
||||
}
|
||||
// 如果没有指定类型,匹配所有通知事件
|
||||
if len(detailTypes) == 0 {
|
||||
return true
|
||||
}
|
||||
// 检查 detail_type 是否在指定列表中
|
||||
eventDetailType := event.GetDetailType()
|
||||
for _, dt := range detailTypes {
|
||||
if dt == eventDetailType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnRequest 匹配请求事件
|
||||
func OnRequest() *HandlerBuilder {
|
||||
// 用法:
|
||||
//
|
||||
// OnRequest() - 匹配所有请求事件
|
||||
// OnRequest("friend") - 匹配好友请求事件
|
||||
// OnRequest("friend", "group") - 匹配好友或群请求事件
|
||||
func OnRequest(detailTypes ...string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeRequest
|
||||
if event.GetType() != protocol.EventTypeRequest {
|
||||
return false
|
||||
}
|
||||
// 如果没有指定类型,匹配所有请求事件
|
||||
if len(detailTypes) == 0 {
|
||||
return true
|
||||
}
|
||||
// 检查 detail_type 是否在指定列表中
|
||||
eventDetailType := event.GetDetailType()
|
||||
for _, dt := range detailTypes {
|
||||
if dt == eventDetailType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnEvent 匹配指定类型的事件(可传一个或多个 EventType)
|
||||
// 用法:
|
||||
//
|
||||
// OnEvent() - 匹配所有事件
|
||||
// OnEvent(protocol.EventTypeMessage) - 匹配消息事件
|
||||
// OnEvent(protocol.EventTypeMessage, protocol.EventTypeNotice) - 匹配消息和通知事件
|
||||
func OnEvent(eventTypes ...protocol.EventType) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if len(eventTypes) == 0 {
|
||||
return true // 不传参数时匹配所有事件
|
||||
}
|
||||
// 检查事件类型是否在指定列表中
|
||||
for _, et := range eventTypes {
|
||||
if event.GetType() == et {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
@@ -273,8 +341,13 @@ func On(matchFunc func(protocol.Event) bool) *HandlerBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
// OnCommand 匹配命令(以指定前缀开头的消息)
|
||||
func OnCommand(prefix string) *HandlerBuilder {
|
||||
// OnCommand 匹配命令
|
||||
// 用法:
|
||||
//
|
||||
// OnCommand("/help") - 匹配 /help 命令(前缀为 /,命令为 help)
|
||||
// OnCommand("/", "help") - 匹配以 / 开头且命令为 help 的消息
|
||||
// OnCommand("/", "help", "h") - 匹配以 / 开头且命令为 help 或 h 的消息
|
||||
func OnCommand(prefix string, commands ...string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if event.GetType() != protocol.EventTypeMessage {
|
||||
@@ -285,9 +358,35 @@ func OnCommand(prefix string) *HandlerBuilder {
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// 检查是否以命令前缀开头
|
||||
if len(rawMessage) > 0 && len(prefix) > 0 {
|
||||
return len(rawMessage) >= len(prefix) && rawMessage[:len(prefix)] == prefix
|
||||
|
||||
// 检查是否以前缀开头
|
||||
if len(rawMessage) < len(prefix) || rawMessage[:len(prefix)] != prefix {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果没有指定具体命令,匹配所有以该前缀开头的消息
|
||||
if len(commands) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// 提取命令部分(去除前缀和空格)
|
||||
cmdText := strings.TrimSpace(rawMessage[len(prefix):])
|
||||
if cmdText == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取第一个单词作为命令
|
||||
parts := strings.Fields(cmdText)
|
||||
if len(parts) == 0 {
|
||||
return false
|
||||
}
|
||||
cmd := parts[0]
|
||||
|
||||
// 检查是否匹配指定的命令
|
||||
for _, c := range commands {
|
||||
if cmd == c {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
@@ -364,12 +463,56 @@ func contains(s, substr string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OnFullMatch 完全匹配文本
|
||||
func OnFullMatch(text string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
if event.GetType() != protocol.EventTypeMessage {
|
||||
return false
|
||||
}
|
||||
data := event.GetData()
|
||||
rawMessage, ok := data["raw_message"].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return rawMessage == text
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnDetailType 匹配指定 detail_type 的事件
|
||||
func OnDetailType(detailType string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetDetailType() == detailType
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// OnSubType 匹配指定 sub_type 的事件
|
||||
func OnSubType(subType string) *HandlerBuilder {
|
||||
return &HandlerBuilder{
|
||||
matchFunc: func(event protocol.Event) bool {
|
||||
return event.GetSubType() == subType
|
||||
},
|
||||
priority: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 设置优先级
|
||||
func (b *HandlerBuilder) Priority(priority int) *HandlerBuilder {
|
||||
b.priority = priority
|
||||
return b
|
||||
}
|
||||
|
||||
// Use 添加中间件(链式调用)
|
||||
func (b *HandlerBuilder) Use(middleware HandlerMiddleware) *HandlerBuilder {
|
||||
b.middlewares = append(b.middlewares, middleware)
|
||||
return b
|
||||
}
|
||||
|
||||
// Handle 注册处理函数(在 init 中调用)
|
||||
func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
|
||||
globalHandlerMu.Lock()
|
||||
@@ -379,6 +522,16 @@ func (b *HandlerBuilder) Handle(handleFunc HandlerFunc) {
|
||||
globalHandlerRegistry = append(globalHandlerRegistry, b)
|
||||
}
|
||||
|
||||
// applyMiddlewares 应用所有中间件
|
||||
func (b *HandlerBuilder) applyMiddlewares(event protocol.Event) bool {
|
||||
for _, middleware := range b.middlewares {
|
||||
if !middleware(event) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// generateHandlerName 生成处理器名称
|
||||
var handlerCounter int64
|
||||
|
||||
@@ -406,11 +559,21 @@ func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []prot
|
||||
return builder.handleFunc(ctx, event, botManager, logger)
|
||||
}
|
||||
|
||||
// 创建包装的匹配函数,应用中间件
|
||||
matchFunc := func(event protocol.Event) bool {
|
||||
// 先检查基础匹配
|
||||
if builder.matchFunc != nil && !builder.matchFunc(event) {
|
||||
return false
|
||||
}
|
||||
// 再应用中间件
|
||||
return builder.applyMiddlewares(event)
|
||||
}
|
||||
|
||||
handler := &simplePlugin{
|
||||
name: pluginName,
|
||||
description: "Handler registered via OnXXX().Handle()",
|
||||
priority: builder.priority,
|
||||
matchFunc: builder.matchFunc,
|
||||
matchFunc: matchFunc,
|
||||
handleFunc: handleFunc,
|
||||
}
|
||||
|
||||
@@ -419,3 +582,129 @@ func LoadAllHandlers(botManager *protocol.BotManager, logger *zap.Logger) []prot
|
||||
|
||||
return handlers
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 常用中间件(类似 NoneBot 风格)
|
||||
// ============================================================================
|
||||
|
||||
// OnlyToMe 只响应@机器人的消息(群聊中)
|
||||
func OnlyToMe() HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
// 只对群消息生效
|
||||
if event.GetType() != protocol.EventTypeMessage || event.GetDetailType() != "group" {
|
||||
return true // 非群消息不检查,让其他中间件处理
|
||||
}
|
||||
|
||||
data := event.GetData()
|
||||
selfID := event.GetSelfID()
|
||||
|
||||
// 检查消息段中是否包含@机器人的消息
|
||||
if segments, ok := data["message_segments"].([]interface{}); ok {
|
||||
for _, seg := range segments {
|
||||
if segMap, ok := seg.(map[string]interface{}); ok {
|
||||
segType, _ := segMap["type"].(string)
|
||||
if segType == "at" || segType == "mention" {
|
||||
segData, _ := segMap["data"].(map[string]interface{})
|
||||
// 检查是否@了机器人
|
||||
if userID, ok := segData["user_id"]; ok {
|
||||
if userIDStr := fmt.Sprintf("%v", userID); userIDStr == selfID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if qq, ok := segData["qq"]; ok {
|
||||
if qqStr := fmt.Sprintf("%v", qq); qqStr == selfID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 raw_message 中是否包含@机器人的信息(兼容性检查)
|
||||
if rawMessage, ok := data["raw_message"].(string); ok {
|
||||
// 简单的检查:消息是否以 @机器人 开头
|
||||
// 注意:这里需要根据实际协议调整
|
||||
if strings.Contains(rawMessage, fmt.Sprintf("[CQ:at,qq=%s]", selfID)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyPrivate 只在私聊中响应
|
||||
func OnlyPrivate() HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
return event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private"
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyGroup 只在群聊中响应(消息事件)或群相关事件(通知/请求事件)
|
||||
func OnlyGroup() HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
// 消息事件:检查 detail_type
|
||||
if event.GetType() == protocol.EventTypeMessage {
|
||||
return event.GetDetailType() == "group"
|
||||
}
|
||||
// 通知/请求事件:检查是否有 group_id
|
||||
data := event.GetData()
|
||||
_, hasGroupID := data["group_id"]
|
||||
return hasGroupID
|
||||
}
|
||||
}
|
||||
|
||||
// OnlySuperuser 只允许超级用户(需要从配置或数据中获取)
|
||||
func OnlySuperuser(superusers []string) HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
data := event.GetData()
|
||||
userID, ok := data["user_id"]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
userIDStr := fmt.Sprintf("%v", userID)
|
||||
for _, su := range superusers {
|
||||
if su == userIDStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// BlockPrivate 阻止私聊消息
|
||||
func BlockPrivate() HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
return !(event.GetType() == protocol.EventTypeMessage && event.GetDetailType() == "private")
|
||||
}
|
||||
}
|
||||
|
||||
// BlockGroup 阻止群聊消息
|
||||
func BlockGroup() HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
// 消息事件:检查 detail_type
|
||||
if event.GetType() == protocol.EventTypeMessage {
|
||||
return event.GetDetailType() != "group"
|
||||
}
|
||||
// 通知/请求事件:检查是否有 group_id
|
||||
data := event.GetData()
|
||||
_, hasGroupID := data["group_id"]
|
||||
return !hasGroupID
|
||||
}
|
||||
}
|
||||
|
||||
// OnlyDetailType 只匹配指定的 detail_type
|
||||
func OnlyDetailType(detailType string) HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
return event.GetDetailType() == detailType
|
||||
}
|
||||
}
|
||||
|
||||
// OnlySubType 只匹配指定的 sub_type
|
||||
func OnlySubType(subType string) HandlerMiddleware {
|
||||
return func(event protocol.Event) bool {
|
||||
return event.GetSubType() == subType
|
||||
}
|
||||
}
|
||||
|
||||
631
internal/engine/scheduler.go
Normal file
631
internal/engine/scheduler.go
Normal file
@@ -0,0 +1,631 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Job 定时任务接口
|
||||
type Job interface {
|
||||
// ID 返回任务唯一标识
|
||||
ID() string
|
||||
// Start 启动任务
|
||||
Start(ctx context.Context) error
|
||||
// Stop 停止任务
|
||||
Stop() error
|
||||
// IsRunning 检查任务是否正在运行
|
||||
IsRunning() bool
|
||||
// NextRun 返回下次执行时间
|
||||
NextRun() time.Time
|
||||
}
|
||||
|
||||
// JobFunc 任务执行函数类型
|
||||
type JobFunc func(ctx context.Context) error
|
||||
|
||||
// Scheduler 定时任务调度器
|
||||
type Scheduler struct {
|
||||
jobs map[string]Job
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
running int32
|
||||
}
|
||||
|
||||
// NewScheduler 创建新的调度器
|
||||
func NewScheduler(logger *zap.Logger) *Scheduler {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Scheduler{
|
||||
jobs: make(map[string]Job),
|
||||
logger: logger.Named("scheduler"),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动调度器
|
||||
func (s *Scheduler) Start() error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return fmt.Errorf("scheduler is already running")
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 启动所有任务
|
||||
for id, job := range s.jobs {
|
||||
if err := job.Start(s.ctx); err != nil {
|
||||
s.logger.Error("Failed to start job",
|
||||
zap.String("job_id", id),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
s.logger.Info("Job started", zap.String("job_id", id))
|
||||
}
|
||||
|
||||
s.logger.Info("Scheduler started", zap.Int("job_count", len(s.jobs)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止调度器
|
||||
func (s *Scheduler) Stop() error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 1, 0) {
|
||||
return fmt.Errorf("scheduler is not running")
|
||||
}
|
||||
|
||||
s.cancel()
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// 停止所有任务
|
||||
for id, job := range s.jobs {
|
||||
if err := job.Stop(); err != nil {
|
||||
s.logger.Error("Failed to stop job",
|
||||
zap.String("job_id", id),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
s.logger.Info("Job stopped", zap.String("job_id", id))
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
s.logger.Info("Scheduler stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJob 添加任务
|
||||
func (s *Scheduler) AddJob(job Job) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
id := job.ID()
|
||||
if _, exists := s.jobs[id]; exists {
|
||||
return fmt.Errorf("job with id %s already exists", id)
|
||||
}
|
||||
|
||||
s.jobs[id] = job
|
||||
|
||||
// 如果调度器正在运行,立即启动任务
|
||||
if atomic.LoadInt32(&s.running) == 1 {
|
||||
if err := job.Start(s.ctx); err != nil {
|
||||
delete(s.jobs, id)
|
||||
return fmt.Errorf("failed to start job: %w", err)
|
||||
}
|
||||
s.logger.Info("Job added and started", zap.String("job_id", id))
|
||||
} else {
|
||||
s.logger.Info("Job added", zap.String("job_id", id))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveJob 移除任务
|
||||
func (s *Scheduler) RemoveJob(id string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
job, exists := s.jobs[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("job with id %s not found", id)
|
||||
}
|
||||
|
||||
if err := job.Stop(); err != nil {
|
||||
s.logger.Error("Failed to stop job during removal",
|
||||
zap.String("job_id", id),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
delete(s.jobs, id)
|
||||
s.logger.Info("Job removed", zap.String("job_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetJob 获取任务
|
||||
func (s *Scheduler) GetJob(id string) (Job, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
job, exists := s.jobs[id]
|
||||
return job, exists
|
||||
}
|
||||
|
||||
// GetAllJobs 获取所有任务
|
||||
func (s *Scheduler) GetAllJobs() map[string]Job {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result := make(map[string]Job, len(s.jobs))
|
||||
for id, job := range s.jobs {
|
||||
result[id] = job
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// IsRunning 检查调度器是否正在运行
|
||||
func (s *Scheduler) IsRunning() bool {
|
||||
return atomic.LoadInt32(&s.running) == 1
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Job 实现
|
||||
// ============================================================================
|
||||
|
||||
// CronJob 基于 Cron 表达式的任务
|
||||
type CronJob struct {
|
||||
id string
|
||||
spec string
|
||||
handler JobFunc
|
||||
cron *cron.Cron
|
||||
logger *zap.Logger
|
||||
running int32
|
||||
nextRun time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCronJob 创建 Cron 任务
|
||||
func NewCronJob(id, spec string, handler JobFunc, logger *zap.Logger) (*CronJob, error) {
|
||||
parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
|
||||
c := cron.New(cron.WithParser(parser), cron.WithChain(cron.Recover(cron.DefaultLogger)))
|
||||
|
||||
job := &CronJob{
|
||||
id: id,
|
||||
spec: spec,
|
||||
handler: handler,
|
||||
cron: c,
|
||||
logger: logger.Named("cron-job").With(zap.String("job_id", id)),
|
||||
}
|
||||
|
||||
// 添加任务到 cron
|
||||
_, err := c.AddFunc(spec, func() {
|
||||
ctx := context.Background()
|
||||
if err := handler(ctx); err != nil {
|
||||
job.logger.Error("Cron job execution failed", zap.Error(err))
|
||||
}
|
||||
// 更新下次执行时间
|
||||
entries := c.Entries()
|
||||
job.mu.Lock()
|
||||
if len(entries) > 0 {
|
||||
// 找到最近的执行时间
|
||||
next := entries[0].Next
|
||||
for _, entry := range entries {
|
||||
if entry.Next.Before(next) {
|
||||
next = entry.Next
|
||||
}
|
||||
}
|
||||
job.nextRun = next
|
||||
}
|
||||
job.mu.Unlock()
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid cron spec: %w", err)
|
||||
}
|
||||
|
||||
// 计算初始下次执行时间(需要先启动 cron 才能计算)
|
||||
// 这里先设置为零值,在 Start 时再计算
|
||||
job.mu.Lock()
|
||||
job.nextRun = time.Time{}
|
||||
job.mu.Unlock()
|
||||
|
||||
return job, nil
|
||||
}
|
||||
|
||||
func (j *CronJob) ID() string {
|
||||
return j.id
|
||||
}
|
||||
|
||||
func (j *CronJob) Start(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&j.running, 0, 1) {
|
||||
return fmt.Errorf("job is already running")
|
||||
}
|
||||
|
||||
j.cron.Start()
|
||||
j.logger.Info("Cron job started", zap.String("spec", j.spec))
|
||||
|
||||
// 更新下次执行时间
|
||||
entries := j.cron.Entries()
|
||||
if len(entries) > 0 {
|
||||
j.mu.Lock()
|
||||
// 找到最近的执行时间
|
||||
next := entries[0].Next
|
||||
for _, entry := range entries {
|
||||
if entry.Next.Before(next) {
|
||||
next = entry.Next
|
||||
}
|
||||
}
|
||||
j.nextRun = next
|
||||
j.mu.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *CronJob) Stop() error {
|
||||
if !atomic.CompareAndSwapInt32(&j.running, 1, 0) {
|
||||
return fmt.Errorf("job is not running")
|
||||
}
|
||||
|
||||
ctx := j.cron.Stop()
|
||||
<-ctx.Done()
|
||||
j.logger.Info("Cron job stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *CronJob) IsRunning() bool {
|
||||
return atomic.LoadInt32(&j.running) == 1
|
||||
}
|
||||
|
||||
func (j *CronJob) NextRun() time.Time {
|
||||
j.mu.RLock()
|
||||
defer j.mu.RUnlock()
|
||||
return j.nextRun
|
||||
}
|
||||
|
||||
// IntervalJob 固定间隔的任务
|
||||
type IntervalJob struct {
|
||||
id string
|
||||
interval time.Duration
|
||||
handler JobFunc
|
||||
logger *zap.Logger
|
||||
running int32
|
||||
nextRun time.Time
|
||||
mu sync.RWMutex
|
||||
ticker *time.Ticker
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewIntervalJob 创建固定间隔任务
|
||||
func NewIntervalJob(id string, interval time.Duration, handler JobFunc, logger *zap.Logger) *IntervalJob {
|
||||
return &IntervalJob{
|
||||
id: id,
|
||||
interval: interval,
|
||||
handler: handler,
|
||||
logger: logger.Named("interval-job").With(zap.String("job_id", id)),
|
||||
}
|
||||
}
|
||||
|
||||
func (j *IntervalJob) ID() string {
|
||||
return j.id
|
||||
}
|
||||
|
||||
func (j *IntervalJob) Start(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&j.running, 0, 1) {
|
||||
return fmt.Errorf("job is already running")
|
||||
}
|
||||
|
||||
j.ctx, j.cancel = context.WithCancel(ctx)
|
||||
j.ticker = time.NewTicker(j.interval)
|
||||
|
||||
j.mu.Lock()
|
||||
j.nextRun = time.Now().Add(j.interval)
|
||||
j.mu.Unlock()
|
||||
|
||||
j.wg.Add(1)
|
||||
go j.run()
|
||||
|
||||
j.logger.Info("Interval job started", zap.Duration("interval", j.interval))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *IntervalJob) run() {
|
||||
defer j.wg.Done()
|
||||
|
||||
// 立即执行一次(可选,根据需求调整)
|
||||
// if err := j.handler(j.ctx); err != nil {
|
||||
// j.logger.Error("Interval job execution failed", zap.Error(err))
|
||||
// }
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-j.ticker.C:
|
||||
j.mu.Lock()
|
||||
j.nextRun = time.Now().Add(j.interval)
|
||||
j.mu.Unlock()
|
||||
|
||||
if err := j.handler(j.ctx); err != nil {
|
||||
j.logger.Error("Interval job execution failed", zap.Error(err))
|
||||
}
|
||||
case <-j.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (j *IntervalJob) Stop() error {
|
||||
if !atomic.CompareAndSwapInt32(&j.running, 1, 0) {
|
||||
return fmt.Errorf("job is not running")
|
||||
}
|
||||
|
||||
if j.cancel != nil {
|
||||
j.cancel()
|
||||
}
|
||||
if j.ticker != nil {
|
||||
j.ticker.Stop()
|
||||
}
|
||||
j.wg.Wait()
|
||||
|
||||
j.logger.Info("Interval job stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *IntervalJob) IsRunning() bool {
|
||||
return atomic.LoadInt32(&j.running) == 1
|
||||
}
|
||||
|
||||
func (j *IntervalJob) NextRun() time.Time {
|
||||
j.mu.RLock()
|
||||
defer j.mu.RUnlock()
|
||||
return j.nextRun
|
||||
}
|
||||
|
||||
// OnceJob 单次延迟执行的任务
|
||||
type OnceJob struct {
|
||||
id string
|
||||
delay time.Duration
|
||||
handler JobFunc
|
||||
logger *zap.Logger
|
||||
running int32
|
||||
nextRun time.Time
|
||||
mu sync.RWMutex
|
||||
timer *time.Timer
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewOnceJob 创建单次延迟执行任务
|
||||
func NewOnceJob(id string, delay time.Duration, handler JobFunc, logger *zap.Logger) *OnceJob {
|
||||
return &OnceJob{
|
||||
id: id,
|
||||
delay: delay,
|
||||
handler: handler,
|
||||
logger: logger.Named("once-job").With(zap.String("job_id", id)),
|
||||
}
|
||||
}
|
||||
|
||||
func (j *OnceJob) ID() string {
|
||||
return j.id
|
||||
}
|
||||
|
||||
func (j *OnceJob) Start(ctx context.Context) error {
|
||||
if !atomic.CompareAndSwapInt32(&j.running, 0, 1) {
|
||||
return fmt.Errorf("job is already running")
|
||||
}
|
||||
|
||||
j.ctx, j.cancel = context.WithCancel(ctx)
|
||||
j.timer = time.NewTimer(j.delay)
|
||||
|
||||
j.mu.Lock()
|
||||
j.nextRun = time.Now().Add(j.delay)
|
||||
j.mu.Unlock()
|
||||
|
||||
j.wg.Add(1)
|
||||
go j.run()
|
||||
|
||||
j.logger.Info("Once job started", zap.Duration("delay", j.delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *OnceJob) run() {
|
||||
defer j.wg.Done()
|
||||
|
||||
select {
|
||||
case <-j.timer.C:
|
||||
if err := j.handler(j.ctx); err != nil {
|
||||
j.logger.Error("Once job execution failed", zap.Error(err))
|
||||
}
|
||||
atomic.StoreInt32(&j.running, 0)
|
||||
case <-j.ctx.Done():
|
||||
if !j.timer.Stop() {
|
||||
<-j.timer.C
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (j *OnceJob) Stop() error {
|
||||
if !atomic.CompareAndSwapInt32(&j.running, 1, 0) {
|
||||
return fmt.Errorf("job is not running")
|
||||
}
|
||||
|
||||
if j.cancel != nil {
|
||||
j.cancel()
|
||||
}
|
||||
if j.timer != nil {
|
||||
if !j.timer.Stop() {
|
||||
<-j.timer.C
|
||||
}
|
||||
}
|
||||
j.wg.Wait()
|
||||
|
||||
j.logger.Info("Once job stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j *OnceJob) IsRunning() bool {
|
||||
return atomic.LoadInt32(&j.running) == 1
|
||||
}
|
||||
|
||||
func (j *OnceJob) NextRun() time.Time {
|
||||
j.mu.RLock()
|
||||
defer j.mu.RUnlock()
|
||||
return j.nextRun
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 全局调度器 API(链式风格,延迟注册)
|
||||
// ============================================================================
|
||||
|
||||
var (
|
||||
globalJobRegistry = make([]JobBuilder, 0)
|
||||
globalJobMu sync.RWMutex
|
||||
jobCounter int64
|
||||
)
|
||||
|
||||
// JobBuilder 任务构建器接口(延迟注册)
|
||||
type JobBuilder interface {
|
||||
// Build 构建任务实例(由依赖注入系统调用)
|
||||
Build(logger *zap.Logger) (Job, error)
|
||||
}
|
||||
|
||||
// generateJobID 生成任务 ID
|
||||
func generateJobID(prefix string) string {
|
||||
counter := atomic.AddInt64(&jobCounter, 1)
|
||||
return fmt.Sprintf("%s_%d", prefix, counter)
|
||||
}
|
||||
|
||||
// CronJobBuilder Cron 任务构建器
|
||||
type CronJobBuilder struct {
|
||||
id string
|
||||
spec string
|
||||
handler JobFunc
|
||||
}
|
||||
|
||||
// Cron 创建 Cron 任务构建器(在 init 函数中调用)
|
||||
func Cron(spec string) *CronJobBuilder {
|
||||
return &CronJobBuilder{
|
||||
id: generateJobID("cron"),
|
||||
spec: spec,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 设置处理函数并注册到全局注册表(延迟注册)
|
||||
func (b *CronJobBuilder) Handle(handler JobFunc) {
|
||||
b.handler = handler
|
||||
if b.handler == nil {
|
||||
panic("scheduler: handler cannot be nil")
|
||||
}
|
||||
|
||||
globalJobMu.Lock()
|
||||
defer globalJobMu.Unlock()
|
||||
globalJobRegistry = append(globalJobRegistry, b)
|
||||
}
|
||||
|
||||
// Build 构建 Cron 任务
|
||||
func (b *CronJobBuilder) Build(logger *zap.Logger) (Job, error) {
|
||||
return NewCronJob(b.id, b.spec, b.handler, logger)
|
||||
}
|
||||
|
||||
// IntervalJobBuilder 固定间隔任务构建器
|
||||
type IntervalJobBuilder struct {
|
||||
id string
|
||||
interval time.Duration
|
||||
handler JobFunc
|
||||
}
|
||||
|
||||
// Interval 创建固定间隔任务构建器(在 init 函数中调用)
|
||||
func Interval(interval time.Duration) *IntervalJobBuilder {
|
||||
return &IntervalJobBuilder{
|
||||
id: generateJobID("interval"),
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 设置处理函数并注册到全局注册表(延迟注册)
|
||||
func (b *IntervalJobBuilder) Handle(handler JobFunc) {
|
||||
b.handler = handler
|
||||
if b.handler == nil {
|
||||
panic("scheduler: handler cannot be nil")
|
||||
}
|
||||
|
||||
globalJobMu.Lock()
|
||||
defer globalJobMu.Unlock()
|
||||
globalJobRegistry = append(globalJobRegistry, b)
|
||||
}
|
||||
|
||||
// Build 构建固定间隔任务
|
||||
func (b *IntervalJobBuilder) Build(logger *zap.Logger) (Job, error) {
|
||||
return NewIntervalJob(b.id, b.interval, b.handler, logger), nil
|
||||
}
|
||||
|
||||
// OnceJobBuilder 单次延迟任务构建器
|
||||
type OnceJobBuilder struct {
|
||||
id string
|
||||
delay time.Duration
|
||||
handler JobFunc
|
||||
}
|
||||
|
||||
// Once 创建单次延迟任务构建器(在 init 函数中调用)
|
||||
func Once(delay time.Duration) *OnceJobBuilder {
|
||||
return &OnceJobBuilder{
|
||||
id: generateJobID("once"),
|
||||
delay: delay,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 设置处理函数并注册到全局注册表(延迟注册)
|
||||
func (b *OnceJobBuilder) Handle(handler JobFunc) {
|
||||
b.handler = handler
|
||||
if b.handler == nil {
|
||||
panic("scheduler: handler cannot be nil")
|
||||
}
|
||||
|
||||
globalJobMu.Lock()
|
||||
defer globalJobMu.Unlock()
|
||||
globalJobRegistry = append(globalJobRegistry, b)
|
||||
}
|
||||
|
||||
// Build 构建单次延迟任务
|
||||
func (b *OnceJobBuilder) Build(logger *zap.Logger) (Job, error) {
|
||||
return NewOnceJob(b.id, b.delay, b.handler, logger), nil
|
||||
}
|
||||
|
||||
// LoadAllJobs 加载所有已注册的任务(由依赖注入系统调用)
|
||||
func LoadAllJobs(scheduler *Scheduler, logger *zap.Logger) error {
|
||||
globalJobMu.RLock()
|
||||
defer globalJobMu.RUnlock()
|
||||
|
||||
for i, builder := range globalJobRegistry {
|
||||
job, err := builder.Build(logger)
|
||||
if err != nil {
|
||||
logger.Error("Failed to build job",
|
||||
zap.Int("index", i),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
if err := scheduler.AddJob(job); err != nil {
|
||||
logger.Error("Failed to add job to scheduler",
|
||||
zap.String("job_id", job.ID()),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debug("Job loaded",
|
||||
zap.String("job_id", job.ID()))
|
||||
}
|
||||
|
||||
logger.Info("All scheduled jobs loaded",
|
||||
zap.Int("job_count", len(globalJobRegistry)))
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user