package engine import ( "context" "sync" "time" "cellbot/internal/protocol" "go.uber.org/zap" "golang.org/x/time/rate" ) // LoggingMiddleware 日志中间件 type LoggingMiddleware struct { logger *zap.Logger } // NewLoggingMiddleware 创建日志中间件 func NewLoggingMiddleware(logger *zap.Logger) *LoggingMiddleware { return &LoggingMiddleware{ logger: logger.Named("middleware.logging"), } } // Process 处理事件 func (m *LoggingMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { start := time.Now() m.logger.Info("Event received", zap.String("type", string(event.GetType())), zap.String("detail_type", event.GetDetailType()), zap.String("self_id", event.GetSelfID())) err := next(ctx, event) m.logger.Info("Event processed", zap.String("type", string(event.GetType())), zap.Duration("duration", time.Since(start)), zap.Error(err)) return err } // RateLimitMiddleware 限流中间件 type RateLimitMiddleware struct { limiters map[string]*rate.Limiter mu sync.RWMutex logger *zap.Logger rps int // 每秒请求数 burst int // 突发容量 } // NewRateLimitMiddleware 创建限流中间件 func NewRateLimitMiddleware(logger *zap.Logger, rps, burst int) *RateLimitMiddleware { if rps <= 0 { rps = 100 } if burst <= 0 { burst = rps * 2 } return &RateLimitMiddleware{ limiters: make(map[string]*rate.Limiter), logger: logger.Named("middleware.ratelimit"), rps: rps, burst: burst, } } // Process 处理事件 func (m *RateLimitMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { // 根据事件类型获取限流器 key := string(event.GetType()) m.mu.RLock() limiter, exists := m.limiters[key] m.mu.RUnlock() if !exists { m.mu.Lock() limiter = rate.NewLimiter(rate.Limit(m.rps), m.burst) m.limiters[key] = limiter m.mu.Unlock() } // 等待令牌 if err := limiter.Wait(ctx); err != nil { m.logger.Warn("Rate limit exceeded", zap.String("event_type", key), zap.Error(err)) return err } return next(ctx, event) } // RetryMiddleware 重试中间件 type RetryMiddleware struct { logger *zap.Logger maxRetries int delay time.Duration } // NewRetryMiddleware 创建重试中间件 func NewRetryMiddleware(logger *zap.Logger, maxRetries int, delay time.Duration) *RetryMiddleware { if maxRetries <= 0 { maxRetries = 3 } if delay <= 0 { delay = time.Second } return &RetryMiddleware{ logger: logger.Named("middleware.retry"), maxRetries: maxRetries, delay: delay, } } // Process 处理事件 func (m *RetryMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { var err error for i := 0; i <= m.maxRetries; i++ { if i > 0 { m.logger.Info("Retrying event", zap.String("event_type", string(event.GetType())), zap.Int("attempt", i), zap.Int("max_retries", m.maxRetries)) // 指数退避 backoff := m.delay * time.Duration(1< 0 { m.logger.Info("Event succeeded after retry", zap.String("event_type", string(event.GetType())), zap.Int("attempts", i+1)) } return nil } m.logger.Warn("Event processing failed", zap.String("event_type", string(event.GetType())), zap.Int("attempt", i+1), zap.Error(err)) } m.logger.Error("Event failed after all retries", zap.String("event_type", string(event.GetType())), zap.Int("total_attempts", m.maxRetries+1), zap.Error(err)) return err } // TimeoutMiddleware 超时中间件 type TimeoutMiddleware struct { logger *zap.Logger timeout time.Duration } // NewTimeoutMiddleware 创建超时中间件 func NewTimeoutMiddleware(logger *zap.Logger, timeout time.Duration) *TimeoutMiddleware { if timeout <= 0 { timeout = 30 * time.Second } return &TimeoutMiddleware{ logger: logger.Named("middleware.timeout"), timeout: timeout, } } // Process 处理事件 func (m *TimeoutMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { ctx, cancel := context.WithTimeout(ctx, m.timeout) defer cancel() done := make(chan error, 1) go func() { done <- next(ctx, event) }() select { case err := <-done: return err case <-ctx.Done(): m.logger.Warn("Event processing timeout", zap.String("event_type", string(event.GetType())), zap.Duration("timeout", m.timeout)) return ctx.Err() } } // RecoveryMiddleware 恢复中间件(捕获panic) type RecoveryMiddleware struct { logger *zap.Logger } // NewRecoveryMiddleware 创建恢复中间件 func NewRecoveryMiddleware(logger *zap.Logger) *RecoveryMiddleware { return &RecoveryMiddleware{ logger: logger.Named("middleware.recovery"), } } // Process 处理事件 func (m *RecoveryMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) (err error) { defer func() { if r := recover(); r != nil { m.logger.Error("Recovered from panic", zap.Any("panic", r), zap.String("event_type", string(event.GetType()))) err = protocol.ErrNotImplemented // 或者自定义错误 } }() return next(ctx, event) } // MetricsMiddleware 指标中间件 type MetricsMiddleware struct { logger *zap.Logger eventCounts map[string]int64 eventTimes map[string]time.Duration mu sync.RWMutex } // NewMetricsMiddleware 创建指标中间件 func NewMetricsMiddleware(logger *zap.Logger) *MetricsMiddleware { return &MetricsMiddleware{ logger: logger.Named("middleware.metrics"), eventCounts: make(map[string]int64), eventTimes: make(map[string]time.Duration), } } // Process 处理事件 func (m *MetricsMiddleware) Process(ctx context.Context, event protocol.Event, next func(context.Context, protocol.Event) error) error { start := time.Now() err := next(ctx, event) duration := time.Since(start) eventType := string(event.GetType()) m.mu.Lock() m.eventCounts[eventType]++ m.eventTimes[eventType] += duration m.mu.Unlock() return err } // GetMetrics 获取指标 func (m *MetricsMiddleware) GetMetrics() map[string]interface{} { m.mu.RLock() defer m.mu.RUnlock() metrics := make(map[string]interface{}) for eventType, count := range m.eventCounts { avgTime := m.eventTimes[eventType] / time.Duration(count) metrics[eventType] = map[string]interface{}{ "count": count, "avg_time": avgTime.String(), } } return metrics } // LogMetrics 记录指标 func (m *MetricsMiddleware) LogMetrics() { metrics := m.GetMetrics() m.logger.Info("Event metrics", zap.Any("metrics", metrics)) }