package middleware

import (
	"net/http"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)

// minCleanupInterval is the janitor interval used when the configured window
// is non-positive; without it the background sweep would spin in a tight loop.
const minCleanupInterval = time.Second

// RateLimiter is a sliding-window rate limiter keyed by an arbitrary string.
// Each key may have at most limit accepted requests within any window-long
// period. All access to the request log is guarded by mu.
//
// Create one with NewRateLimiter. A zero-value RateLimiter does not panic,
// but it has no background cleanup, and with a zero limit it rejects every
// request.
type RateLimiter struct {
	mu sync.Mutex
	// requests holds, per key, the timestamps of accepted requests in the
	// order they were accepted.
	requests map[string][]time.Time
	limit    int
	window   time.Duration

	// stop is closed by Stop to terminate the cleanup goroutine.
	stop     chan struct{}
	stopOnce sync.Once
}

// NewRateLimiter returns a limiter that allows at most limit requests per key
// within any sliding window of the given duration. A non-positive limit
// rejects every request.
//
// It starts a background goroutine that periodically discards expired
// records. Call Stop to terminate it once the limiter is no longer needed.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
		stop:     make(chan struct{}),
	}

	interval := window
	if interval <= 0 {
		// time.NewTicker panics on non-positive durations, and sweeping with
		// no delay would busy-loop on the mutex.
		interval = minCleanupInterval
	}
	go rl.janitor(interval)
	return rl
}

// janitor calls cleanup every interval until Stop is called.
func (rl *RateLimiter) janitor(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.stop:
			return
		}
	}
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once. The limiter keeps enforcing limits after Stop, but records of
// keys that are never queried again are no longer purged.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.stop != nil {
			close(rl.stop)
		}
	})
}

// pruneExpired keeps only the timestamps in times that are younger than
// rl.window relative to now. It filters in place and reuses the backing array
// of times. That is safe here because callers hold rl.mu and the slice is
// referenced only from rl.requests.
func (rl *RateLimiter) pruneExpired(times []time.Time, now time.Time) []time.Time {
	kept := times[:0]
	for _, t := range times {
		if now.Sub(t) < rl.window {
			kept = append(kept, t)
		}
	}
	return kept
}

// cleanup drops expired timestamps for every key and removes keys with no
// remaining records, so idle clients do not accumulate in memory.
func (rl *RateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	for key, times := range rl.requests {
		// Deleting or overwriting the current key during range is permitted.
		if kept := rl.pruneExpired(times, now); len(kept) == 0 {
			delete(rl.requests, key)
		} else {
			rl.requests[key] = kept
		}
	}
}

// isAllowed reports whether a request for key may proceed and, if so,
// records it. Rejected requests are not recorded, so hammering the endpoint
// while blocked does not extend the block.
func (rl *RateLimiter) isAllowed(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Guard against a zero-value RateLimiter; writing to a nil map panics.
	if rl.requests == nil {
		rl.requests = make(map[string][]time.Time)
	}

	now := time.Now()
	valid := rl.pruneExpired(rl.requests[key], now)
	if len(valid) >= rl.limit {
		rl.requests[key] = valid
		return false
	}
	rl.requests[key] = append(valid, now)
	return true
}

// RateLimit returns Gin middleware that limits each client, identified by
// c.ClientIP(), to requestsPerMinute requests per sliding one-minute window.
// Requests over the limit are aborted with HTTP 429 and a JSON body.
//
// NOTE(review): ClientIP may derive the address from forwarding headers,
// depending on the engine's trusted-proxy settings. Confirm that
// configuration so clients cannot spoof their key.
//
// Each call creates its own limiter and cleanup goroutine that is never
// stopped, so build the middleware once rather than per request.
func RateLimit(requestsPerMinute int) gin.HandlerFunc {
	limiter := NewRateLimiter(requestsPerMinute, time.Minute)
	return func(c *gin.Context) {
		if !limiter.isAllowed(c.ClientIP()) {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"code":    429,
				"message": "too many requests",
			})
			return
		}
		c.Next()
	}
}