- 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等) - 创建Repository接口实现 - 创建依赖注入容器(container.Container) - 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler) - 创建新的路由注册方式(RegisterRoutesWithDI) - 提供main.go示例文件展示如何使用依赖注入 同时包含之前的安全修复: - CORS配置安全加固 - 头像URL验证安全修复 - JWT algorithm confusion漏洞修复 - Recovery中间件增强 - 敏感错误信息泄露修复 - 类型断言安全修复
165 lines
4.5 KiB
Go
165 lines
4.5 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"testing"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// TestCORS_Headers 测试CORS中间件设置的响应头
|
||
func TestCORS_Headers(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
router := gin.New()
|
||
router.Use(CORS())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||
})
|
||
|
||
req, _ := http.NewRequest("GET", "/test", nil)
|
||
w := httptest.NewRecorder()
|
||
|
||
router.ServeHTTP(w, req)
|
||
|
||
// 验证CORS响应头
|
||
// 注意:当 Access-Control-Allow-Origin 为 "*" 时,根据CORS规范,
|
||
// 不应该设置 Access-Control-Allow-Credentials 为 "true"
|
||
expectedHeaders := map[string]string{
|
||
"Access-Control-Allow-Origin": "*",
|
||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||
}
|
||
|
||
for header, expectedValue := range expectedHeaders {
|
||
actualValue := w.Header().Get(header)
|
||
if actualValue != expectedValue {
|
||
t.Errorf("Header %s = %q, want %q", header, actualValue, expectedValue)
|
||
}
|
||
}
|
||
|
||
// 验证在通配符模式下不设置Credentials(这是正确的安全行为)
|
||
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
|
||
t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
|
||
}
|
||
|
||
// 验证Access-Control-Allow-Headers包含必要字段
|
||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||
if allowHeaders == "" {
|
||
t.Error("Access-Control-Allow-Headers 不应为空")
|
||
}
|
||
}
|
||
|
||
// TestCORS_OPTIONS 测试OPTIONS请求处理
|
||
func TestCORS_OPTIONS(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
router := gin.New()
|
||
router.Use(CORS())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||
})
|
||
|
||
req, _ := http.NewRequest("OPTIONS", "/test", nil)
|
||
w := httptest.NewRecorder()
|
||
|
||
router.ServeHTTP(w, req)
|
||
|
||
// OPTIONS请求应该返回204状态码
|
||
if w.Code != http.StatusNoContent {
|
||
t.Errorf("OPTIONS请求状态码 = %d, want %d", w.Code, http.StatusNoContent)
|
||
}
|
||
}
|
||
|
||
// TestCORS_AllowMethods 测试允许的HTTP方法
|
||
func TestCORS_AllowMethods(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
router := gin.New()
|
||
router.Use(CORS())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||
})
|
||
|
||
methods := []string{"GET", "POST", "PUT", "DELETE"}
|
||
for _, method := range methods {
|
||
t.Run(method, func(t *testing.T) {
|
||
req, _ := http.NewRequest(method, "/test", nil)
|
||
w := httptest.NewRecorder()
|
||
|
||
router.ServeHTTP(w, req)
|
||
|
||
// 验证允许的方法头包含该方法
|
||
allowMethods := w.Header().Get("Access-Control-Allow-Methods")
|
||
if allowMethods == "" {
|
||
t.Error("Access-Control-Allow-Methods 不应为空")
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// TestCORS_AllowHeaders 测试允许的请求头
|
||
func TestCORS_AllowHeaders(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
router := gin.New()
|
||
router.Use(CORS())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||
})
|
||
|
||
req, _ := http.NewRequest("GET", "/test", nil)
|
||
w := httptest.NewRecorder()
|
||
|
||
router.ServeHTTP(w, req)
|
||
|
||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||
expectedHeaders := []string{"Content-Type", "Authorization", "Accept"}
|
||
|
||
for _, expectedHeader := range expectedHeaders {
|
||
if !contains(allowHeaders, expectedHeader) {
|
||
t.Errorf("Access-Control-Allow-Headers 应包含 %s", expectedHeader)
|
||
}
|
||
}
|
||
}
|
||
|
||
// TestCORS_WithSpecificOrigin 测试配置了具体origin时的CORS行为
|
||
func TestCORS_WithSpecificOrigin(t *testing.T) {
|
||
gin.SetMode(gin.TestMode)
|
||
|
||
// 注意:此测试验证的是在配置了具体allowed origins时的行为
|
||
// 在没有配置初始化的情况下,默认使用通配符模式
|
||
router := gin.New()
|
||
router.Use(CORS())
|
||
router.GET("/test", func(c *gin.Context) {
|
||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||
})
|
||
|
||
req, _ := http.NewRequest("GET", "/test", nil)
|
||
req.Header.Set("Origin", "http://example.com")
|
||
w := httptest.NewRecorder()
|
||
|
||
router.ServeHTTP(w, req)
|
||
|
||
// 默认配置下使用通配符,所以不应该设置credentials
|
||
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
|
||
t.Logf("当前模式下 Access-Control-Allow-Credentials = %q (通配符模式不设置)", credentials)
|
||
}
|
||
}
|
||
|
||
// 辅助函数:检查字符串是否包含子字符串(简单实现)
|
||
func contains(s, substr string) bool {
|
||
if len(substr) == 0 {
|
||
return true
|
||
}
|
||
if len(s) < len(substr) {
|
||
return false
|
||
}
|
||
for i := 0; i <= len(s)-len(substr); i++ {
|
||
if s[i:i+len(substr)] == substr {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|