Files
backend/internal/task/runner.go

169 lines
3.5 KiB
Go
Raw Normal View History

package task
import (
"context"
"math/rand"
"runtime/debug"
"sync"
"time"
"go.uber.org/zap"
)
// Task describes a unit of work that a Runner schedules periodically.
type Task interface {
// Name returns the identifier attached to this task's log entries.
Name() string
// Interval returns the desired delay between runs. Runner re-reads it
// after every wait and substitutes one minute for non-positive values.
Interval() time.Duration
// Run performs one execution of the task. Runner logs a non-nil error as
// a warning (when it has a logger) and keeps scheduling later runs.
Run(ctx context.Context) error
}
// Runner is a simple periodic task scheduler: Start launches one goroutine
// per task and Wait blocks until all of those goroutines have returned.
type Runner struct {
// tasks are the tasks launched by Start.
tasks []Task
// logger receives task failures and panics; a nil logger disables logging.
logger *zap.Logger
// wg tracks the goroutines launched by Start.
wg sync.WaitGroup
// startImmediately makes each task run once before its first wait.
startImmediately bool
// jitterPercent bounds the random delay added to each wait. It is used as
// a fraction of the interval (0.1 means up to 10%); values <= 0 disable it.
jitterPercent float64
}
// NewRunner creates a Runner for the given tasks by calling
// NewRunnerWithOptions without any options, i.e. with its defaults.
func NewRunner(logger *zap.Logger, tasks ...Task) *Runner {
return NewRunnerWithOptions(logger, tasks)
}
// RunnerOption adjusts a Runner's configuration; NewRunnerWithOptions applies
// each option to the Runner it is constructing.
type RunnerOption func(r *Runner)
// WithStartImmediately controls whether every task is executed once right
// after Start, before its first wait (default: true).
func WithStartImmediately(start bool) RunnerOption {
	apply := func(r *Runner) {
		r.startImmediately = start
	}
	return apply
}
// WithJitter adds a random extra delay to every wait so that several tasks
// are less likely to fire at the same moment. Despite its name, percent is
// multiplied directly with the interval, so it is a fraction: 0.1 adds up to
// 10% of the interval. Negative values are clamped to 0, and 0 (the default)
// turns jitter off.
func WithJitter(percent float64) RunnerOption {
	ratio := percent
	if ratio < 0 {
		ratio = 0
	}
	return func(r *Runner) {
		r.jitterPercent = ratio
	}
}
// NewRunnerWithOptions creates a Runner for tasks and then applies opts in
// the order given. Before any option runs, the Runner starts tasks
// immediately and uses no jitter. logger may be nil, in which case nothing
// is logged.
func NewRunnerWithOptions(logger *zap.Logger, tasks []Task, opts ...RunnerOption) *Runner {
	runner := new(Runner)
	runner.tasks = tasks
	runner.logger = logger
	runner.startImmediately = true
	runner.jitterPercent = 0

	for i := range opts {
		opts[i](runner)
	}
	return runner
}
// Start launches one goroutine per registered task and returns without
// waiting for them. Each goroutine optionally runs its task once right away,
// then alternates between waiting (interval plus jitter) and running until
// ctx is cancelled. Use Wait to block until every goroutine has returned.
func (r *Runner) Start(ctx context.Context) {
	for _, t := range r.tasks {
		r.wg.Add(1)
		// The task is handed over as an argument so that every goroutine
		// works on its own copy of the loop variable.
		go func(job Task) {
			defer r.wg.Done()
			// A panic that propagates up to this goroutine is logged and
			// ends this task's loop.
			defer r.recoverPanic(job)

			period := r.normalizeInterval(job.Interval())

			// Optional first run before any waiting.
			if r.startImmediately {
				r.runOnce(ctx, job)
			}

			// Periodic runs: the loop ends as soon as a wait is cut short
			// by cancellation.
			for r.wait(ctx, r.applyJitter(period)) {
				// Re-read the interval every round so a task can change
				// its own schedule while running.
				period = r.normalizeInterval(job.Interval())

				// Skip the run if ctx was cancelled in the meantime.
				select {
				case <-ctx.Done():
					return
				default:
				}
				r.runOnce(ctx, job)
			}
		}(t)
	}
}
// Wait blocks until every goroutine launched by Start has returned.
func (r *Runner) Wait() {
r.wg.Wait()
}
// runOnce executes task a single time. An error returned by Run is logged as
// a warning when a logger is configured. A panic raised by Run is recovered
// and logged here.
func (r *Runner) runOnce(ctx context.Context, task Task) {
	// Recover per execution: if the panic were only recovered at the top of
	// the goroutine in Start, one panicking run would unwind the scheduling
	// loop and the task would never be scheduled again.
	defer r.recoverPanic(task)

	if err := task.Run(ctx); err != nil && r.logger != nil {
		r.logger.Warn("任务执行失败", zap.String("task", task.Name()), zap.Error(err))
	}
}
// normalizeInterval returns d unchanged when it is positive and falls back
// to one minute otherwise, so the scheduler never uses a zero or negative
// interval.
func (r *Runner) normalizeInterval(d time.Duration) time.Duration {
	if d > 0 {
		return d
	}
	return time.Minute
}
// applyJitter returns base plus a random extra delay in the half-open range
// [0, base*jitterPercent). base is returned unchanged when jitter is
// disabled (jitterPercent <= 0) or when the jitter range truncates to a
// non-positive duration.
func (r *Runner) applyJitter(base time.Duration) time.Duration {
	delay := base
	if r.jitterPercent <= 0 {
		return delay
	}
	// Int63n requires a strictly positive bound, hence the span check.
	if span := int64(float64(base) * r.jitterPercent); span > 0 {
		delay += time.Duration(rand.Int63n(span))
	}
	return delay
}
// wait is a context-aware sleep. It returns true once d has elapsed and
// false if ctx is cancelled first. For d <= 0 it does not block and only
// reports whether ctx is still live.
func (r *Runner) wait(ctx context.Context, d time.Duration) bool {
	cancelled := ctx.Done()

	if d > 0 {
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-t.C:
			return true
		case <-cancelled:
			return false
		}
	}

	// Nothing to sleep for: non-blocking cancellation check.
	select {
	case <-cancelled:
		return false
	default:
	}
	return true
}
// recoverPanic stops a panic that is unwinding the current goroutine and,
// when a logger is configured, logs it with the task name and a stack trace.
// It calls recover itself, so it only has an effect when it is the deferred
// function (defer r.recoverPanic(task)). Execution resumes in the caller of
// the function that deferred it, not at the point of the panic.
func (r *Runner) recoverPanic(task Task) {
	rec := recover()
	if rec == nil || r.logger == nil {
		return
	}
	r.logger.Error("任务发生panic",
		zap.String("task", task.Name()),
		zap.Any("panic", rec),
		zap.ByteString("stack", debug.Stack()),
	)
}