Files
backend/internal/middleware/recovery_test.go

154 lines
3.6 KiB
Go
Raw Normal View History

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"go.uber.org/zap/zaptest"
)
// TestRecovery_PanicHandling verifies that the Recovery middleware catches a
// panic raised inside a handler and answers with 500 Internal Server Error
// instead of letting the panic escape from ServeHTTP.
func TestRecovery_PanicHandling(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Recovery(logger))
	// Route whose handler always panics.
	router.GET("/panic", func(c *gin.Context) {
		panic("test panic")
	})

	// httptest.NewRequest builds a server-side request and panics on malformed
	// input itself, so there is no error value to silently discard.
	req := httptest.NewRequest(http.MethodGet, "/panic", nil)
	w := httptest.NewRecorder()

	// gin.New() installs no recovery of its own, so without Recovery this
	// call would propagate the panic and fail the test run.
	router.ServeHTTP(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
	}
}
// TestRecovery_StringPanic verifies that a panic whose value is a plain
// string is recovered and turned into a 500 response.
func TestRecovery_StringPanic(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Recovery(logger))
	router.GET("/panic", func(c *gin.Context) {
		panic("string panic message")
	})

	// httptest.NewRequest panics on malformed input instead of returning an
	// error, so nothing is silently ignored here.
	req := httptest.NewRequest(http.MethodGet, "/panic", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
	}
}
// TestRecovery_ErrorPanic verifies that a panic whose value is an error
// (here http.ErrBodyReadAfterClose) is recovered and turned into a 500
// response.
func TestRecovery_ErrorPanic(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Recovery(logger))
	router.GET("/panic", func(c *gin.Context) {
		panic(http.ErrBodyReadAfterClose)
	})

	req := httptest.NewRequest(http.MethodGet, "/panic", nil)
	w := httptest.NewRecorder()

	// Must not propagate the panic out of ServeHTTP.
	router.ServeHTTP(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
	}
}
// TestRecovery_NilPanic verifies that a genuine nil pointer dereference is
// recovered and turned into a 500 response. Unlike the string and error
// cases, the panic here is raised by the Go runtime itself, so the recovered
// value is a runtime.Error rather than something passed to panic explicitly.
func TestRecovery_NilPanic(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Recovery(logger))
	router.GET("/panic", func(c *gin.Context) {
		// Looking up a missing key in a nil map yields a nil *int; dereferencing
		// it makes the runtime panic with "invalid memory address or nil pointer
		// dereference". Going through a map lookup keeps this from being an
		// obvious nil literal that linters would flag. The dereference happens
		// while evaluating c.String's arguments, so nothing is written to the
		// response before the panic.
		var counters map[string]*int
		c.String(http.StatusOK, "%d", *counters["missing"])
	})

	req := httptest.NewRequest(http.MethodGet, "/panic", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	if w.Code != http.StatusInternalServerError {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
	}
}
// TestRecovery_ResponseFormat verifies that the response produced after a
// recovered panic carries the 500 status and a non-empty body describing the
// error.
func TestRecovery_ResponseFormat(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Recovery(logger))
	router.GET("/panic", func(c *gin.Context) {
		panic("test panic")
	})

	req := httptest.NewRequest(http.MethodGet, "/panic", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	// Check the status first so the body assertion below is known to be
	// inspecting the recovery response.
	if w.Code != http.StatusInternalServerError {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
	}

	// The exact payload shape is owned by Recovery; this test only requires
	// that some error information is returned.
	body := w.Body.String()
	if body == "" {
		t.Error("响应体不应为空")
	}
}
// TestRecovery_NormalRequest verifies that Recovery is transparent for
// handlers that do not panic: both the status code and the body written by
// the handler reach the client unchanged.
func TestRecovery_NormalRequest(t *testing.T) {
	gin.SetMode(gin.TestMode)
	logger := zaptest.NewLogger(t)

	router := gin.New()
	router.Use(Recovery(logger))
	router.GET("/normal", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	req := httptest.NewRequest(http.MethodGet, "/normal", nil)
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	// A normal request must not be affected by the middleware.
	if w.Code != http.StatusOK {
		t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
	}
	// The handler's JSON payload should pass through untouched.
	if got, want := w.Body.String(), `{"message":"success"}`; got != want {
		t.Errorf("body = %q, want %q", got, want)
	}
}