103 lines
1.7 KiB
Go
103 lines
1.7 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"sync"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"github.com/gin-gonic/gin"
|
||
|
|
)
|
||
|
|
|
||
|
|
// RateLimiter is a sliding-window rate limiter keyed by an arbitrary string
// (for example a client IP). Each key may have at most limit accepted
// requests within any window-long span of time.
//
// All access to the per-key records is serialized by mu, so a RateLimiter may
// be shared between goroutines. NewRateLimiter starts a background goroutine
// that periodically evicts stale keys; call Stop to terminate it once the
// limiter is no longer needed.
type RateLimiter struct {
	// requests maps each key to the timestamps of its accepted requests that
	// have not been pruned yet, in the order they were accepted.
	requests map[string][]time.Time

	// mu guards requests.
	mu sync.Mutex

	// limit is the maximum number of accepted requests per key per window.
	limit int

	// window is the length of the sliding window.
	window time.Duration

	// stop is closed by Stop to end the cleanup goroutine; stopOnce makes
	// Stop safe to call more than once.
	stop     chan struct{}
	stopOnce sync.Once
}

// NewRateLimiter creates a limiter that accepts up to limit requests per key
// within each window. A non-positive limit rejects every request.
//
// When window is positive, a background goroutine evicts expired records once
// per window until Stop is called. When window is non-positive no earlier
// request can still be inside the window, so no records are kept and no
// goroutine is started.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		requests: make(map[string][]time.Time),
		limit:    limit,
		window:   window,
		stop:     make(chan struct{}),
	}

	// time.NewTicker panics on a non-positive interval, and a sleep-based loop
	// would spin without pausing in that case.
	if window > 0 {
		go rl.cleanupLoop()
	}

	return rl
}

// Stop terminates the background cleanup goroutine. It may be called more
// than once. The limiter keeps enforcing limits after Stop; only the periodic
// eviction of idle keys ceases.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() { close(rl.stop) })
}

// cleanupLoop runs cleanup once per window until Stop closes rl.stop.
func (rl *RateLimiter) cleanupLoop() {
	ticker := time.NewTicker(rl.window)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.stop:
			return
		}
	}
}

// cleanup drops expired timestamps for every key and deletes keys that have
// none left, so clients that went idle do not stay in the map.
func (rl *RateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	for key, times := range rl.requests {
		// Assigning to or deleting the current key while ranging is permitted.
		if kept := rl.pruneExpired(times, now); len(kept) > 0 {
			rl.requests[key] = kept
		} else {
			delete(rl.requests, key)
		}
	}
}

// pruneExpired returns the timestamps in times that are still inside the
// window relative to now, preserving their order. It filters in place, reusing
// the backing array of times, so the caller must own times (callers hold
// rl.mu and immediately store the result back into rl.requests).
func (rl *RateLimiter) pruneExpired(times []time.Time, now time.Time) []time.Time {
	kept := times[:0]
	for _, t := range times {
		if now.Sub(t) < rl.window {
			kept = append(kept, t)
		}
	}
	return kept
}

// isAllowed reports whether a request for key may proceed now and, if so,
// records it. Rejected requests are not recorded, so a client that keeps
// retrying while throttled does not extend its own lockout.
func (rl *RateLimiter) isAllowed(key string) bool {
	// With a non-positive window every earlier request has already expired,
	// so the outcome depends only on limit. Skipping the bookkeeping avoids
	// storing entries that no cleanup goroutine would ever evict. limit and
	// window are never modified after construction, so reading them without
	// the lock is safe.
	if rl.window <= 0 {
		return rl.limit > 0
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	valid := rl.pruneExpired(rl.requests[key], now)

	if len(valid) >= rl.limit {
		// Keep the pruned slice so expired entries are not rescanned, but do
		// not record the rejected attempt itself.
		rl.requests[key] = valid
		return false
	}

	rl.requests[key] = append(valid, now)
	return true
}
|
||
|
|
|
||
|
|
// RateLimit 限流中间件
|
||
|
|
func RateLimit(requestsPerMinute int) gin.HandlerFunc {
|
||
|
|
limiter := NewRateLimiter(requestsPerMinute, time.Minute)
|
||
|
|
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
ip := c.ClientIP()
|
||
|
|
|
||
|
|
if !limiter.isAllowed(ip) {
|
||
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
|
|
"code": 429,
|
||
|
|
"message": "too many requests",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|