186 lines
4.1 KiB
Go
186 lines
4.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"go.uber.org/zap/zaptest"
|
|
)
|
|
|
|
// TestLogger_Middleware 测试日志中间件基本功能
|
|
func TestLogger_Middleware(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
logger := zaptest.NewLogger(t)
|
|
router := gin.New()
|
|
router.Use(Logger(logger))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
start := time.Now()
|
|
router.ServeHTTP(w, req)
|
|
duration := time.Since(start)
|
|
|
|
// 验证请求成功处理
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
|
|
// 验证处理时间合理(应该很短)
|
|
if duration > 1*time.Second {
|
|
t.Errorf("处理时间过长: %v", duration)
|
|
}
|
|
}
|
|
|
|
// TestLogger_RequestInfo 测试日志中间件记录的请求信息
|
|
func TestLogger_RequestInfo(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
logger := zaptest.NewLogger(t)
|
|
router := gin.New()
|
|
router.Use(Logger(logger))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
method string
|
|
path string
|
|
}{
|
|
{
|
|
name: "GET请求",
|
|
method: "GET",
|
|
path: "/test",
|
|
},
|
|
{
|
|
name: "POST请求",
|
|
method: "POST",
|
|
path: "/test",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest(tt.method, tt.path, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
// 验证请求被正确处理
|
|
if w.Code != http.StatusOK && w.Code != http.StatusNotFound {
|
|
t.Errorf("状态码 = %d", w.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestLogger_QueryParams 测试带查询参数的请求
|
|
func TestLogger_QueryParams(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
logger := zaptest.NewLogger(t)
|
|
router := gin.New()
|
|
router.Use(Logger(logger))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/test?page=1&size=20", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
// 验证请求成功处理
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
// TestLogger_StatusCodes 测试不同状态码的日志记录
|
|
func TestLogger_StatusCodes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
logger := zaptest.NewLogger(t)
|
|
router := gin.New()
|
|
router.Use(Logger(logger))
|
|
|
|
router.GET("/success", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
router.GET("/notfound", func(c *gin.Context) {
|
|
c.JSON(http.StatusNotFound, gin.H{"message": "not found"})
|
|
})
|
|
router.GET("/error", func(c *gin.Context) {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"message": "error"})
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
wantStatus int
|
|
}{
|
|
{
|
|
name: "成功请求",
|
|
path: "/success",
|
|
wantStatus: http.StatusOK,
|
|
},
|
|
{
|
|
name: "404请求",
|
|
path: "/notfound",
|
|
wantStatus: http.StatusNotFound,
|
|
},
|
|
{
|
|
name: "500请求",
|
|
path: "/error",
|
|
wantStatus: http.StatusInternalServerError,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
req, _ := http.NewRequest("GET", tt.path, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(w, req)
|
|
|
|
if w.Code != tt.wantStatus {
|
|
t.Errorf("状态码 = %d, want %d", w.Code, tt.wantStatus)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestLogger_Latency 测试延迟计算
|
|
func TestLogger_Latency(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
logger := zaptest.NewLogger(t)
|
|
router := gin.New()
|
|
router.Use(Logger(logger))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
// 模拟一些处理时间
|
|
time.Sleep(10 * time.Millisecond)
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|
})
|
|
|
|
req, _ := http.NewRequest("GET", "/test", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
start := time.Now()
|
|
router.ServeHTTP(w, req)
|
|
duration := time.Since(start)
|
|
|
|
// 验证延迟计算合理(应该包含处理时间)
|
|
if duration < 10*time.Millisecond {
|
|
t.Errorf("延迟计算可能不正确: %v", duration)
|
|
}
|
|
}
|