Files

103 lines
1.7 KiB
Go
Raw Permalink Normal View History

package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// RateLimiter is a sliding-window rate limiter keyed by an arbitrary string.
// A key is admitted at most limit times within any trailing window; rejected
// attempts are not counted. All state is guarded by mu, so a RateLimiter is
// safe for concurrent use. Create one with NewRateLimiter; the zero value has
// a nil map and must not be used for requests.
type RateLimiter struct {
	requests map[string][]time.Time // admitted-request timestamps per key, oldest first
	mu       sync.Mutex             // guards requests
	limit    int                    // maximum admitted requests per key per window
	window   time.Duration          // length of the sliding window

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop idempotent
}

// NewRateLimiter returns a limiter that admits at most limit requests per key
// within any trailing window.
//
// If window > 0, a background goroutine discards expired timestamps once per
// window until Stop is called. A non-positive window makes every recorded
// timestamp expire immediately (so every request is admitted when limit > 0);
// no sweeper is started in that case, because a zero-interval sweep would spin
// on the CPU.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
		stop:     make(chan struct{}),
	}
	if window > 0 {
		go rl.cleanupLoop()
	}
	return rl
}

// cleanupLoop runs cleanup once per window until Stop closes rl.stop.
func (rl *RateLimiter) cleanupLoop() {
	ticker := time.NewTicker(rl.window)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.stop:
			return
		}
	}
}

// Stop terminates the background cleanup goroutine started by NewRateLimiter.
// It is safe to call more than once. The limiter keeps enforcing limits after
// Stop, but timestamps of idle keys are then only pruned when those keys are
// seen again.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		// A zero-value RateLimiter has no channel; closing nil would panic.
		if rl.stop != nil {
			close(rl.stop)
		}
	})
}

// cleanup drops expired timestamps for every key and deletes keys that have
// none left, so idle clients do not keep map entries alive.
func (rl *RateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	now := time.Now()
	for key, times := range rl.requests {
		// Deleting or overwriting the current key during range is permitted.
		if valid := rl.prune(times, now); len(valid) == 0 {
			delete(rl.requests, key)
		} else {
			rl.requests[key] = valid
		}
	}
}

// prune returns the timestamps in times that are still inside the window
// ending at now. It filters in place, reusing the backing array of times, so
// the caller must replace its stored slice with the result. The caller must
// hold rl.mu.
func (rl *RateLimiter) prune(times []time.Time, now time.Time) []time.Time {
	valid := times[:0]
	for _, t := range times {
		if now.Sub(t) < rl.window {
			valid = append(valid, t)
		}
	}
	return valid
}

// isAllowed reports whether a request for key is admitted now. An admitted
// request is recorded; a rejected one is not, so a throttled client that keeps
// retrying does not extend its own lockout.
func (rl *RateLimiter) isAllowed(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	now := time.Now()
	valid := rl.prune(rl.requests[key], now)
	if len(valid) >= rl.limit {
		rl.requests[key] = valid
		return false
	}
	rl.requests[key] = append(valid, now)
	return true
}
// RateLimit returns gin middleware that admits at most requestsPerMinute
// requests per client IP within any trailing one-minute window. Requests over
// the limit are aborted with HTTP 429 and a JSON body
// {"code": 429, "message": "too many requests"}; the remaining handlers in the
// chain are not run for them.
//
// One limiter is shared by every request that passes through the returned
// handler. Its background cleanup is never stopped here, so it lives for the
// rest of the process.
//
// NOTE(review): c.ClientIP() may derive the address from forwarding headers,
// depending on the engine's trusted-proxy configuration. If proxies are not
// restricted, a client could vary those headers to evade the limit — confirm
// the router's trusted-proxy setup.
func RateLimit(requestsPerMinute int) gin.HandlerFunc {
	rl := NewRateLimiter(requestsPerMinute, time.Minute)
	return func(c *gin.Context) {
		if rl.isAllowed(c.ClientIP()) {
			c.Next()
			return
		}
		c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
			"code":    429,
			"message": "too many requests",
		})
	}
}