154 lines
3.6 KiB
Go
154 lines
3.6 KiB
Go
|
|
package middleware
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"net/http"
|
|||
|
|
"net/http/httptest"
|
|||
|
|
"testing"
|
|||
|
|
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
"go.uber.org/zap/zaptest"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// TestRecovery_PanicHandling 测试恢复中间件处理panic
|
|||
|
|
func TestRecovery_PanicHandling(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
|
|||
|
|
logger := zaptest.NewLogger(t)
|
|||
|
|
router := gin.New()
|
|||
|
|
router.Use(Recovery(logger))
|
|||
|
|
|
|||
|
|
// 创建一个会panic的路由
|
|||
|
|
router.GET("/panic", func(c *gin.Context) {
|
|||
|
|
panic("test panic")
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req, _ := http.NewRequest("GET", "/panic", nil)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
|
|||
|
|
// 应该不会导致测试panic,而是返回500错误
|
|||
|
|
router.ServeHTTP(w, req)
|
|||
|
|
|
|||
|
|
// 验证返回500状态码
|
|||
|
|
if w.Code != http.StatusInternalServerError {
|
|||
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestRecovery_StringPanic 测试字符串类型的panic
|
|||
|
|
func TestRecovery_StringPanic(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
|
|||
|
|
logger := zaptest.NewLogger(t)
|
|||
|
|
router := gin.New()
|
|||
|
|
router.Use(Recovery(logger))
|
|||
|
|
|
|||
|
|
router.GET("/panic", func(c *gin.Context) {
|
|||
|
|
panic("string panic message")
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req, _ := http.NewRequest("GET", "/panic", nil)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
|
|||
|
|
router.ServeHTTP(w, req)
|
|||
|
|
|
|||
|
|
// 验证返回500状态码
|
|||
|
|
if w.Code != http.StatusInternalServerError {
|
|||
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestRecovery_ErrorPanic 测试error类型的panic
|
|||
|
|
func TestRecovery_ErrorPanic(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
|
|||
|
|
logger := zaptest.NewLogger(t)
|
|||
|
|
router := gin.New()
|
|||
|
|
router.Use(Recovery(logger))
|
|||
|
|
|
|||
|
|
router.GET("/panic", func(c *gin.Context) {
|
|||
|
|
panic(http.ErrBodyReadAfterClose)
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req, _ := http.NewRequest("GET", "/panic", nil)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
|
|||
|
|
// 应该不会导致测试panic
|
|||
|
|
router.ServeHTTP(w, req)
|
|||
|
|
|
|||
|
|
// 验证返回500状态码
|
|||
|
|
if w.Code != http.StatusInternalServerError {
|
|||
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestRecovery_NilPanic 测试nil panic
|
|||
|
|
func TestRecovery_NilPanic(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
|
|||
|
|
logger := zaptest.NewLogger(t)
|
|||
|
|
router := gin.New()
|
|||
|
|
router.Use(Recovery(logger))
|
|||
|
|
|
|||
|
|
router.GET("/panic", func(c *gin.Context) {
|
|||
|
|
// 直接panic模拟nil pointer错误,避免linter警告
|
|||
|
|
panic("runtime error: invalid memory address or nil pointer dereference")
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req, _ := http.NewRequest("GET", "/panic", nil)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
|
|||
|
|
router.ServeHTTP(w, req)
|
|||
|
|
|
|||
|
|
// 验证返回500状态码
|
|||
|
|
if w.Code != http.StatusInternalServerError {
|
|||
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusInternalServerError)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestRecovery_ResponseFormat 测试恢复后的响应格式
|
|||
|
|
func TestRecovery_ResponseFormat(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
|
|||
|
|
logger := zaptest.NewLogger(t)
|
|||
|
|
router := gin.New()
|
|||
|
|
router.Use(Recovery(logger))
|
|||
|
|
|
|||
|
|
router.GET("/panic", func(c *gin.Context) {
|
|||
|
|
panic("test panic")
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req, _ := http.NewRequest("GET", "/panic", nil)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
|
|||
|
|
router.ServeHTTP(w, req)
|
|||
|
|
|
|||
|
|
// 验证响应体包含错误信息
|
|||
|
|
body := w.Body.String()
|
|||
|
|
if body == "" {
|
|||
|
|
t.Error("响应体不应为空")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TestRecovery_NormalRequest 测试正常请求不受影响
|
|||
|
|
func TestRecovery_NormalRequest(t *testing.T) {
|
|||
|
|
gin.SetMode(gin.TestMode)
|
|||
|
|
|
|||
|
|
logger := zaptest.NewLogger(t)
|
|||
|
|
router := gin.New()
|
|||
|
|
router.Use(Recovery(logger))
|
|||
|
|
|
|||
|
|
router.GET("/normal", func(c *gin.Context) {
|
|||
|
|
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
req, _ := http.NewRequest("GET", "/normal", nil)
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
|
|||
|
|
router.ServeHTTP(w, req)
|
|||
|
|
|
|||
|
|
// 正常请求应该不受影响
|
|||
|
|
if w.Code != http.StatusOK {
|
|||
|
|
t.Errorf("状态码 = %d, want %d", w.Code, http.StatusOK)
|
|||
|
|
}
|
|||
|
|
}
|