Files
backend/internal/middleware/cors_test.go
lan f7589ebbb8 feat: 引入依赖注入模式
- 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等)
- 创建Repository接口实现
- 创建依赖注入容器(container.Container)
- 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler)
- 创建新的路由注册方式(RegisterRoutesWithDI)
- 提供main.go示例文件展示如何使用依赖注入

同时包含之前的安全修复:
- CORS配置安全加固
- 头像URL验证安全修复
- JWT algorithm confusion漏洞修复
- Recovery中间件增强
- 敏感错误信息泄露修复
- 类型断言安全修复
2025-12-02 17:40:39 +08:00

165 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestCORS_Headers 测试CORS中间件设置的响应头
func TestCORS_Headers(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "success"})
})
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 验证CORS响应头
// 注意:当 Access-Control-Allow-Origin 为 "*" 时根据CORS规范
// 不应该设置 Access-Control-Allow-Credentials 为 "true"
expectedHeaders := map[string]string{
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
}
for header, expectedValue := range expectedHeaders {
actualValue := w.Header().Get(header)
if actualValue != expectedValue {
t.Errorf("Header %s = %q, want %q", header, actualValue, expectedValue)
}
}
// 验证在通配符模式下不设置Credentials这是正确的安全行为
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
}
// 验证Access-Control-Allow-Headers包含必要字段
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
if allowHeaders == "" {
t.Error("Access-Control-Allow-Headers 不应为空")
}
}
// TestCORS_OPTIONS 测试OPTIONS请求处理
func TestCORS_OPTIONS(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "success"})
})
req, _ := http.NewRequest("OPTIONS", "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// OPTIONS请求应该返回204状态码
if w.Code != http.StatusNoContent {
t.Errorf("OPTIONS请求状态码 = %d, want %d", w.Code, http.StatusNoContent)
}
}
// TestCORS_AllowMethods 测试允许的HTTP方法
func TestCORS_AllowMethods(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "success"})
})
methods := []string{"GET", "POST", "PUT", "DELETE"}
for _, method := range methods {
t.Run(method, func(t *testing.T) {
req, _ := http.NewRequest(method, "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 验证允许的方法头包含该方法
allowMethods := w.Header().Get("Access-Control-Allow-Methods")
if allowMethods == "" {
t.Error("Access-Control-Allow-Methods 不应为空")
}
})
}
}
// TestCORS_AllowHeaders 测试允许的请求头
func TestCORS_AllowHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "success"})
})
req, _ := http.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
expectedHeaders := []string{"Content-Type", "Authorization", "Accept"}
for _, expectedHeader := range expectedHeaders {
if !contains(allowHeaders, expectedHeader) {
t.Errorf("Access-Control-Allow-Headers 应包含 %s", expectedHeader)
}
}
}
// TestCORS_WithSpecificOrigin 测试配置了具体origin时的CORS行为
func TestCORS_WithSpecificOrigin(t *testing.T) {
gin.SetMode(gin.TestMode)
// 注意此测试验证的是在配置了具体allowed origins时的行为
// 在没有配置初始化的情况下,默认使用通配符模式
router := gin.New()
router.Use(CORS())
router.GET("/test", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "success"})
})
req, _ := http.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", "http://example.com")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
// 默认配置下使用通配符所以不应该设置credentials
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
t.Logf("当前模式下 Access-Control-Allow-Credentials = %q (通配符模式不设置)", credentials)
}
}
// 辅助函数:检查字符串是否包含子字符串(简单实现)
func contains(s, substr string) bool {
if len(substr) == 0 {
return true
}
if len(s) < len(substr) {
return false
}
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}