feat: 引入依赖注入模式
- 创建Repository接口定义(UserRepository、ProfileRepository、TextureRepository等) - 创建Repository接口实现 - 创建依赖注入容器(container.Container) - 改造Handler层使用依赖注入(AuthHandler、UserHandler、TextureHandler) - 创建新的路由注册方式(RegisterRoutesWithDI) - 提供main.go示例文件展示如何使用依赖注入 同时包含之前的安全修复: - CORS配置安全加固 - 头像URL验证安全修复 - JWT algorithm confusion漏洞修复 - Recovery中间件增强 - 敏感错误信息泄露修复 - 类型断言安全修复
This commit is contained in:
@@ -1,16 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
// 获取配置,如果配置未初始化则使用默认值
|
||||
var allowedOrigins []string
|
||||
if cfg, err := config.GetConfig(); err == nil {
|
||||
allowedOrigins = cfg.Security.AllowedOrigins
|
||||
} else {
|
||||
// 默认允许所有来源(向后兼容)
|
||||
allowedOrigins = []string{"*"}
|
||||
}
|
||||
|
||||
return gin.HandlerFunc(func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
origin := c.GetHeader("Origin")
|
||||
|
||||
// 检查是否允许该来源
|
||||
allowOrigin := "*"
|
||||
if len(allowedOrigins) > 0 && allowedOrigins[0] != "*" {
|
||||
allowOrigin = ""
|
||||
for _, allowed := range allowedOrigins {
|
||||
if allowed == origin || allowed == "*" {
|
||||
allowOrigin = origin
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if allowOrigin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", allowOrigin)
|
||||
// 只有在非通配符模式下才允许credentials
|
||||
if allowOrigin != "*" {
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
c.Header("Access-Control-Max-Age", "86400") // 缓存预检请求结果24小时
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
|
||||
@@ -24,10 +24,11 @@ func TestCORS_Headers(t *testing.T) {
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 验证CORS响应头
|
||||
// 注意:当 Access-Control-Allow-Origin 为 "*" 时,根据CORS规范,
|
||||
// 不应该设置 Access-Control-Allow-Credentials 为 "true"
|
||||
expectedHeaders := map[string]string{
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Credentials": "true",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "POST, OPTIONS, GET, PUT, DELETE",
|
||||
}
|
||||
|
||||
for header, expectedValue := range expectedHeaders {
|
||||
@@ -37,6 +38,11 @@ func TestCORS_Headers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// 验证在通配符模式下不设置Credentials(这是正确的安全行为)
|
||||
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
|
||||
t.Errorf("通配符origin模式下不应设置 Access-Control-Allow-Credentials, got %q", credentials)
|
||||
}
|
||||
|
||||
// 验证Access-Control-Allow-Headers包含必要字段
|
||||
allowHeaders := w.Header().Get("Access-Control-Allow-Headers")
|
||||
if allowHeaders == "" {
|
||||
@@ -117,6 +123,30 @@ func TestCORS_AllowHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCORS_WithSpecificOrigin 测试配置了具体origin时的CORS行为
|
||||
func TestCORS_WithSpecificOrigin(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// 注意:此测试验证的是在配置了具体allowed origins时的行为
|
||||
// 在没有配置初始化的情况下,默认使用通配符模式
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
// 默认配置下使用通配符,所以不应该设置credentials
|
||||
if credentials := w.Header().Get("Access-Control-Allow-Credentials"); credentials != "" {
|
||||
t.Logf("当前模式下 Access-Control-Allow-Credentials = %q (通配符模式不设置)", credentials)
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:检查字符串是否包含子字符串(简单实现)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) == 0 {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
@@ -11,16 +12,26 @@ import (
|
||||
// Recovery 恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
|
||||
if err, ok := recovered.(string); ok {
|
||||
logger.Error("服务器恐慌",
|
||||
zap.String("error", err),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
)
|
||||
// 将任意类型的panic转换为字符串
|
||||
var errMsg string
|
||||
switch v := recovered.(type) {
|
||||
case string:
|
||||
errMsg = v
|
||||
case error:
|
||||
errMsg = v.Error()
|
||||
default:
|
||||
errMsg = fmt.Sprintf("%v", v)
|
||||
}
|
||||
|
||||
logger.Error("服务器恐慌",
|
||||
zap.String("error", errMsg),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("ip", c.ClientIP()),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
zap.String("stack", string(debug.Stack())),
|
||||
)
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "服务器内部错误",
|
||||
|
||||
Reference in New Issue
Block a user