package middleware

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
)

// newCORSTestRouter returns a gin engine in test mode with the CORS
// middleware installed. For each given method (GET when none are given) it
// registers a handler on /test that replies 200 {"message": "success"}.
func newCORSTestRouter(methods ...string) *gin.Engine {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(CORS())
	if len(methods) == 0 {
		methods = []string{http.MethodGet}
	}
	for _, method := range methods {
		router.Handle(method, "/test", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{"message": "success"})
		})
	}
	return router
}

// serveCORSTestRequest sends a request with the given method to /test on
// router, setting any extra request headers, and returns the recorded response.
func serveCORSTestRequest(router http.Handler, method string, headers map[string]string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(method, "/test", nil)
	for name, value := range headers {
		req.Header.Set(name, value)
	}
	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)
	return w
}

// headerListContains reports whether the comma-separated header value list
// contains token as a complete element, compared case-insensitively.
// Matching whole elements avoids false positives such as "Accept" matching
// inside "Accept-Encoding"; case-insensitivity follows HTTP header-name rules.
func headerListContains(list, token string) bool {
	for _, item := range strings.Split(list, ",") {
		if strings.EqualFold(strings.TrimSpace(item), token) {
			return true
		}
	}
	return false
}

// TestCORS_Headers checks the CORS response headers set on a plain GET request.
func TestCORS_Headers(t *testing.T) {
	router := newCORSTestRouter()
	w := serveCORSTestRequest(router, http.MethodGet, nil)

	// Per the CORS spec, when Access-Control-Allow-Origin is "*" the response
	// must not set Access-Control-Allow-Credentials to "true".
	expectedHeaders := map[string]string{
		"Access-Control-Allow-Origin":  "*",
		"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
	}
	for header, expectedValue := range expectedHeaders {
		if actualValue := w.Header().Get(header); actualValue != expectedValue {
			t.Errorf("Header %s = %q, want %q", header, actualValue, expectedValue)
		}
	}

	// In wildcard-origin mode Credentials must be absent (the secure behaviour).
	if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
		t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
	}

	// Access-Control-Allow-Headers must be present; its contents are checked
	// in TestCORS_AllowHeaders.
	if allowHeaders := w.Header().Get("Access-Control-Allow-Headers"); allowHeaders == "" {
		t.Error("Access-Control-Allow-Headers 不应为空")
	}
}

// TestCORS_OPTIONS checks that an OPTIONS (preflight) request gets 204 No Content.
func TestCORS_OPTIONS(t *testing.T) {
	// Only GET is registered, so the 204 must be produced by the middleware
	// rather than by a route handler.
	router := newCORSTestRouter()
	w := serveCORSTestRequest(router, http.MethodOptions, nil)

	if w.Code != http.StatusNoContent {
		t.Errorf("OPTIONS请求状态码 = %d, want %d", w.Code, http.StatusNoContent)
	}
}

// TestCORS_AllowMethods checks that Access-Control-Allow-Methods lists each
// method, for a request made with that method.
func TestCORS_AllowMethods(t *testing.T) {
	methods := []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}
	// Register every method so each request reaches a real route instead of
	// gin's 404 fallback.
	router := newCORSTestRouter(methods...)

	for _, method := range methods {
		t.Run(method, func(t *testing.T) {
			w := serveCORSTestRequest(router, method, nil)

			allowMethods := w.Header().Get("Access-Control-Allow-Methods")
			if !headerListContains(allowMethods, method) {
				t.Errorf("Access-Control-Allow-Methods = %q, want it to include %s", allowMethods, method)
			}
		})
	}
}

// TestCORS_AllowHeaders checks that Access-Control-Allow-Headers lists the
// request headers clients need.
func TestCORS_AllowHeaders(t *testing.T) {
	router := newCORSTestRouter()
	w := serveCORSTestRequest(router, http.MethodGet, nil)

	allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
	for _, expectedHeader := range []string{"Content-Type", "Authorization", "Accept"} {
		if !headerListContains(allowHeaders, expectedHeader) {
			t.Errorf("Access-Control-Allow-Headers 应包含 %s", expectedHeader)
		}
	}
}

// TestCORS_WithSpecificOrigin checks CORS behaviour when the request carries
// an Origin header. No allowed-origins configuration is initialised here, so
// the middleware is expected to be in its default wildcard mode.
func TestCORS_WithSpecificOrigin(t *testing.T) {
	router := newCORSTestRouter()
	w := serveCORSTestRequest(router, http.MethodGet, map[string]string{"Origin": "http://example.com"})

	origin := w.Header().Get("Access-Control-Allow-Origin")
	credentials := w.Header().Get("Access-Control-Allow-Credentials")
	t.Logf("Access-Control-Allow-Origin = %q, Access-Control-Allow-Credentials = %q", origin, credentials)

	// Browsers reject a wildcard origin combined with credentials, so this
	// combination is a real failure, not merely something to log.
	if origin == "*" && credentials != "" {
		t.Errorf("wildcard Access-Control-Allow-Origin must not be combined with Access-Control-Allow-Credentials, got %q", credentials)
	}
}

// contains reports whether substr occurs anywhere in s (an empty substr
// always matches). It is a thin wrapper over strings.Contains, kept at package
// level in case other test files in this package call it.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}