package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap/zaptest"
)

// TestLogger_Middleware checks that a request routed through the Logger
// middleware is handled successfully and within a generous time bound.
func TestLogger_Middleware(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// zaptest routes log output to t.Log, so entries appear only in verbose
	// test output; they are not asserted on here.
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Logger(logger))
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	// httptest.NewRequest never returns an error (it panics on invalid input),
	// so no error is silently discarded as with http.NewRequest.
	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	w := httptest.NewRecorder()

	start := time.Now()
	router.ServeHTTP(w, req)
	duration := time.Since(start)

	// The request must still be handled normally with the middleware installed.
	if w.Code != http.StatusOK {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
	}

	// Sanity bound: a trivial handler plus logging should finish quickly.
	if duration > 1*time.Second {
		t.Errorf("处理时间过长: %v", duration)
	}
}

// TestLogger_RequestInfo sends requests with different methods through the
// Logger middleware, including one that matches no route, and checks that
// each produces exactly the expected status code.
func TestLogger_RequestInfo(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Logger(logger))
	success := func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	}
	// Register POST as well as GET so the POST case exercises a real handler
	// instead of silently falling through to gin's 404 handling.
	router.GET("/test", success)
	router.POST("/test", success)

	tests := []struct {
		name       string
		method     string
		path       string
		wantStatus int
	}{
		{name: "GET请求", method: http.MethodGet, path: "/test", wantStatus: http.StatusOK},
		{name: "POST请求", method: http.MethodPost, path: "/test", wantStatus: http.StatusOK},
		// Unmatched routes still pass through global middleware in gin.
		{name: "未匹配路由", method: http.MethodGet, path: "/missing", wantStatus: http.StatusNotFound},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(tt.method, tt.path, nil)
			w := httptest.NewRecorder()

			router.ServeHTTP(w, req)

			// Assert the exact status for each case rather than accepting
			// any of several codes, which would mask routing mistakes.
			if w.Code != tt.wantStatus {
				t.Errorf("状态码 = %d, want %d", w.Code, tt.wantStatus)
			}
		})
	}
}

// TestLogger_QueryParams checks that a request carrying a query string is
// handled normally with the Logger middleware installed.
func TestLogger_QueryParams(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Logger(logger))
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	req := httptest.NewRequest(http.MethodGet, "/test?page=1&size=20", nil)
	w := httptest.NewRecorder()

	router.ServeHTTP(w, req)

	// The query string must not interfere with routing or handling.
	if w.Code != http.StatusOK {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
	}
}
测试不同状态码的日志记录 func TestLogger_StatusCodes(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Logger(logger)) router.GET("/success", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "success"}) }) router.GET("/notfound", func(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"message": "not found"}) }) router.GET("/error", func(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"message": "error"}) }) tests := []struct { name string path string wantStatus int }{ { name: "成功请求", path: "/success", wantStatus: http.StatusOK, }, { name: "404请求", path: "/notfound", wantStatus: http.StatusNotFound, }, { name: "500请求", path: "/error", wantStatus: http.StatusInternalServerError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest("GET", tt.path, nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) if w.Code != tt.wantStatus { t.Errorf("状态码 = %d, want %d", w.Code, tt.wantStatus) } }) } } // TestLogger_Latency 测试延迟计算 func TestLogger_Latency(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Logger(logger)) router.GET("/test", func(c *gin.Context) { // 模拟一些处理时间 time.Sleep(10 * time.Millisecond) c.JSON(http.StatusOK, gin.H{"message": "success"}) }) req, _ := http.NewRequest("GET", "/test", nil) w := httptest.NewRecorder() start := time.Now() router.ServeHTTP(w, req) duration := time.Since(start) // 验证延迟计算合理(应该包含处理时间) if duration < 10*time.Millisecond { t.Errorf("延迟计算可能不正确: %v", duration) } }