Files
backend/internal/middleware/cors_test.go

165 lines
4.5 KiB
Go
Raw Normal View History

package middleware
import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
)
// TestCORS_Headers verifies the response headers the CORS middleware sets on a
// plain GET request in the default wildcard-origin mode.
func TestCORS_Headers(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	// A table (rather than a map) keeps failure output in a stable order.
	// Per the CORS spec, Access-Control-Allow-Credentials must not be "true"
	// when Access-Control-Allow-Origin is "*", so it is not listed here.
	expectedHeaders := []struct {
		name, want string
	}{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE"},
	}
	for _, h := range expectedHeaders {
		if got := w.Header().Get(h.name); got != h.want {
			t.Errorf("Header %s = %q, want %q", h.name, got, h.want)
		}
	}

	// In wildcard mode no credentials header may be emitted (correct security behavior).
	if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
		t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
	}

	// Only presence is checked here; the individual entries are covered by a separate test.
	if allowHeaders := w.Header().Get("Access-Control-Allow-Headers"); allowHeaders == "" {
		t.Error("Access-Control-Allow-Headers 不应为空")
	}
}
// TestCORS_OPTIONS verifies that an OPTIONS request passing through the CORS
// middleware is answered with 204 No Content.
func TestCORS_OPTIONS(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	// Only a GET route is registered, and its handler answers 200, so a 204
	// here cannot come from the route handler.
	req, err := http.NewRequest(http.MethodOptions, "/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	if w.Code != http.StatusNoContent {
		t.Errorf("OPTIONS请求状态码 = %d, want %d", w.Code, http.StatusNoContent)
	}
}
// TestCORS_AllowMethods verifies, for each of GET, POST, PUT and DELETE, that a
// request using that method receives an Access-Control-Allow-Methods header
// listing the method.
func TestCORS_AllowMethods(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	methods := []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}
	for _, method := range methods {
		t.Run(method, func(t *testing.T) {
			req, err := http.NewRequest(method, "/test", nil)
			if err != nil {
				t.Fatalf("building %s request: %v", method, err)
			}
			w := httptest.NewRecorder()
			router.ServeHTTP(w, req)

			allowMethods := w.Header().Get("Access-Control-Allow-Methods")
			if allowMethods == "" {
				t.Fatal("Access-Control-Allow-Methods 不应为空")
			}
			// The header is a comma-separated list. Method names are
			// case-sensitive, so match whole trimmed tokens exactly.
			for _, allowed := range strings.Split(allowMethods, ",") {
				if strings.TrimSpace(allowed) == method {
					return
				}
			}
			t.Errorf("Access-Control-Allow-Methods = %q, want it to include %s", allowMethods, method)
		})
	}
}
// TestCORS_AllowHeaders verifies that Access-Control-Allow-Headers lists
// Content-Type, Authorization and Accept.
func TestCORS_AllowHeaders(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	// Index the comma-separated list by canonical header name. Header names are
	// case-insensitive ("accept" == "Accept"). Whole-token matching also keeps
	// "Accept" from being satisfied by an unrelated entry such as "Accept-Encoding".
	allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
	allowed := make(map[string]bool)
	for _, name := range strings.Split(allowHeaders, ",") {
		allowed[http.CanonicalHeaderKey(strings.TrimSpace(name))] = true
	}

	for _, expectedHeader := range []string{"Content-Type", "Authorization", "Accept"} {
		if !allowed[http.CanonicalHeaderKey(expectedHeader)] {
			t.Errorf("Access-Control-Allow-Headers 应包含 %s", expectedHeader)
		}
	}
}
// TestCORS_WithSpecificOrigin sends a request carrying an Origin header. It
// checks the CORS invariant that a wildcard Access-Control-Allow-Origin is
// never combined with Access-Control-Allow-Credentials.
//
// The test does not configure specific allowed origins itself. Without such
// configuration the middleware is expected to run in wildcard mode, but the
// assertion below holds whichever mode is active.
func TestCORS_WithSpecificOrigin(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CORS())
	router.GET("/test", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"message": "success"})
	})

	req, err := http.NewRequest(http.MethodGet, "/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	req.Header.Set("Origin", "http://example.com")
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	allowOrigin := w.Header().Get("Access-Control-Allow-Origin")
	credentials := w.Header().Get("Access-Control-Allow-Credentials")
	// Browsers reject credentialed responses whose allowed origin is "*", so
	// this combination must fail the test rather than merely be logged.
	if allowOrigin == "*" && credentials != "" {
		t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
	}
	t.Logf("Access-Control-Allow-Origin = %q, Access-Control-Allow-Credentials = %q", allowOrigin, credentials)
}
// contains reports whether substr occurs within s. An empty substr always
// matches. The comparison is byte-wise and case-sensitive.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}