package middleware import ( "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" "go.uber.org/zap/zaptest" ) // TestRecovery_PanicHandling 测试恢复中间件处理panic func TestRecovery_PanicHandling(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Recovery(logger)) // 创建一个会panic的路由 router.GET("/panic", func(c *gin.Context) { panic("test panic") }) req, _ := http.NewRequest("GET", "/panic", nil) w := httptest.NewRecorder() // 应该不会导致测试panic,而是返回500错误 router.ServeHTTP(w, req) // 验证返回500状态码 if w.Code != http.StatusInternalServerError { t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError) } } // TestRecovery_StringPanic 测试字符串类型的panic func TestRecovery_StringPanic(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Recovery(logger)) router.GET("/panic", func(c *gin.Context) { panic("string panic message") }) req, _ := http.NewRequest("GET", "/panic", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) // 验证返回500状态码 if w.Code != http.StatusInternalServerError { t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError) } } // TestRecovery_ErrorPanic 测试error类型的panic func TestRecovery_ErrorPanic(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Recovery(logger)) router.GET("/panic", func(c *gin.Context) { panic(http.ErrBodyReadAfterClose) }) req, _ := http.NewRequest("GET", "/panic", nil) w := httptest.NewRecorder() // 应该不会导致测试panic router.ServeHTTP(w, req) // 验证返回500状态码 if w.Code != http.StatusInternalServerError { t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError) } } // TestRecovery_NilPanic 测试nil panic func TestRecovery_NilPanic(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Recovery(logger)) router.GET("/panic", func(c *gin.Context) { // 直接panic模拟nil pointer错误,避免linter警告 panic("runtime error: invalid memory address or nil pointer dereference") }) req, _ := http.NewRequest("GET", "/panic", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) // 验证返回500状态码 if w.Code != http.StatusInternalServerError { t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError) } } // TestRecovery_ResponseFormat 测试恢复后的响应格式 func TestRecovery_ResponseFormat(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Recovery(logger)) router.GET("/panic", func(c *gin.Context) { panic("test panic") }) req, _ := http.NewRequest("GET", "/panic", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) // 验证响应体包含错误信息 body := w.Body.String() if body == "" { t.Error("响应体不应为空") } } // TestRecovery_NormalRequest 测试正常请求不受影响 func TestRecovery_NormalRequest(t *testing.T) { gin.SetMode(gin.TestMode) logger := zaptest.NewLogger(t) router := gin.New() router.Use(Recovery(logger)) router.GET("/normal", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "success"}) }) req, _ := http.NewRequest("GET", "/normal", nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) // 正常请求应该不受影响 if w.Code != http.StatusOK { t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK) } }