From a111872b32ab18104f8785d6ff8322cc1dcdd3aa Mon Sep 17 00:00:00 2001 From: lan Date: Wed, 25 Feb 2026 19:00:50 +0800 Subject: [PATCH] feat(auth): upgrade casbin to v3 and enhance connection pool configurations - Upgrade casbin from v2 to v3 across go.mod and pkg/auth/casbin.go - Add slide captcha verification to registration flow (CheckVerified, ConsumeVerified) - Add DB wrapper with connection pool statistics and health checks - Add Redis connection pool optimizations with stats and health monitoring - Add new config options: ConnMaxLifetime, HealthCheckInterval, EnableRetryOnError - Optimize slow query threshold from 200ms to 100ms - Add ping with retry mechanism for database and Redis connections --- .env.example | 59 +++++++- go.mod | 3 +- go.sum | 6 - internal/handler/auth_handler.go | 16 +- internal/handler/swagger.go | 4 +- internal/service/captcha_service.go | 42 +++++- internal/service/interfaces.go | 2 + internal/service/texture_service.go | 6 +- internal/types/common.go | 1 + pkg/auth/casbin.go | 2 +- pkg/config/config.go | 57 +++++-- pkg/database/manager.go | 21 ++- pkg/database/manager_sqlite_test.go | 21 ++- pkg/database/postgres.go | 163 ++++++++++++++++++-- pkg/redis/redis.go | 224 +++++++++++++++++++++++----- 15 files changed, 534 insertions(+), 93 deletions(-) diff --git a/.env.example b/.env.example index bbaeac2..db1b934 100644 --- a/.env.example +++ b/.env.example @@ -41,19 +41,74 @@ DATABASE_PASSWORD=your_password_here DATABASE_NAME=carrotskin DATABASE_SSL_MODE=disable DATABASE_TIMEZONE=Asia/Shanghai + +# 连接池配置(优化后的默认值) +# 最大空闲连接数:在连接池中保持的最大空闲连接数 +# 建议值:CPU核心数 * 2 ~ CPU核心数 * 4 DATABASE_MAX_IDLE_CONNS=10 +# 最大打开连接数:允许的最大并发连接数 +# 建议值:根据并发需求调整,高并发场景可设置更高(如200-500) DATABASE_MAX_OPEN_CONNS=100 +# 连接最大生命周期:连接被重用前的最大存活时间 +# 建议值:30分钟到1小时,避免长时间占用连接 DATABASE_CONN_MAX_LIFETIME=1h +# 连接最大空闲时间:连接被关闭前的最大空闲时间 +# 建议值:5-15分钟,避免长时间空闲占用资源 DATABASE_CONN_MAX_IDLE_TIME=10m +# 连接获取超时:等待获取连接的超时时间(新增) +# 建议值:1-5秒,避免长时间阻塞 +DATABASE_CONN_TIMEOUT=5s +# 查询超时:单次查询的最大执行时间(新增) +# 建议值:5-30秒,根据业务查询复杂度调整 +DATABASE_QUERY_TIMEOUT=30s +# 慢查询阈值:记录慢查询的阈值(优化:从200ms调整为100ms) +# 超过此时间的查询将被记录为警告 +DATABASE_SLOW_THRESHOLD=100ms +# 健康检查间隔:定期检查数据库连接健康的间隔(新增) +# 建议值:30秒到5分钟 +DATABASE_HEALTH_CHECK_INTERVAL=30s # ============================================================================= -# Redis配置 +# Redis配置(优化后的默认值) # ============================================================================= REDIS_HOST=localhost REDIS_PORT=6379 REDIS_PASSWORD= REDIS_DATABASE=0 -REDIS_POOL_SIZE=10 + +# 连接池配置(优化后的默认值) +# 连接池大小:允许的最大并发连接数 +# 建议值:CPU核心数 * 4 ~ CPU核心数 * 8,根据并发需求调整 +REDIS_POOL_SIZE=16 +# 最小空闲连接数:在连接池中保持的最小空闲连接数 +# 建议值:CPU核心数 * 2 ~ CPU核心数 * 4 +REDIS_MIN_IDLE_CONNS=8 +# 最大重试次数:操作失败时的最大重试次数 +REDIS_MAX_RETRIES=3 +# 连接超时:建立连接的超时时间 +# 建议值:3-10秒 +REDIS_DIAL_TIMEOUT=5s +# 读取超时:读取数据的超时时间 +# 建议值:3-5秒 +REDIS_READ_TIMEOUT=3s +# 写入超时:写入数据的超时时间 +# 建议值:3-5秒 +REDIS_WRITE_TIMEOUT=3s +# 连接池超时:等待获取连接的超时时间 +# 建议值:3-5秒 +REDIS_POOL_TIMEOUT=4s +# 连接最大空闲时间:连接被关闭前的最大空闲时间 +# 建议值:5-15分钟,避免长时间空闲占用资源 +REDIS_CONN_MAX_IDLE_TIME=10m +# 连接最大生命周期:连接被重用前的最大存活时间 +# 建议值:15-30分钟,避免长时间占用导致连接问题 +REDIS_CONN_MAX_LIFETIME=30m +# 健康检查间隔:定期检查Redis连接健康的间隔 +# 建议值:30秒到5分钟 +REDIS_HEALTH_CHECK_INTERVAL=30s +# 错误时启用重试:操作失败时是否启用自动重试 +# 建议值:true(生产环境),开发环境可设为false +REDIS_ENABLE_RETRY_ON_ERROR=true # ============================================================================= # RustFS对象存储配置 (S3兼容) diff --git a/go.mod b/go.mod index e1f3de2..0c793cf 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ go 1.25.0 require ( github.com/alicebob/miniredis/v2 v2.36.1 - github.com/casbin/casbin/v2 v2.135.0 + github.com/casbin/casbin/v3 v3.10.0 github.com/gin-gonic/gin v1.11.0 github.com/golang-jwt/jwt/v5 v5.3.1 github.com/joho/godotenv v1.5.1 @@ -33,7 +33,6 @@ require ( github.com/bmatcuk/doublestar/v4 v4.10.0 // indirect github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect - github.com/casbin/casbin/v3 v3.10.0 // indirect github.com/casbin/govaluate v1.10.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/glebarez/go-sqlite v1.22.0 // indirect diff --git a/go.sum b/go.sum index e384b3f..5a9897f 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,6 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= -github.com/casbin/casbin/v2 v2.135.0 h1:6BLkMQiGotYyS5yYeWgW19vxqugUlvHFkFiLnLR/bxk= -github.com/casbin/casbin/v2 v2.135.0/go.mod h1:FmcfntdXLTcYXv/hxgNntcRPqAbwOG9xsism0yXT+18= github.com/casbin/casbin/v3 v3.10.0 h1:039ORla55vCeIZWd0LfzWFt1yiEA5X4W41xBW2bQuHs= github.com/casbin/casbin/v3 v3.10.0/go.mod h1:5rJbQr2e6AuuDDNxnPc5lQlC9nIgg6nS1zYwKXhpHC8= github.com/casbin/gorm-adapter/v3 v3.41.0 h1:Xhpi0tfRP9aKPDWDf6dgBxHZ9UM6IophxxPIEGWqCNM= @@ -136,8 +134,6 @@ github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= -github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= -github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= @@ -334,7 +330,6 @@ golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -415,7 +410,6 @@ golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= diff --git a/internal/handler/auth_handler.go b/internal/handler/auth_handler.go index ec9cb03..44b9fc2 100644 --- a/internal/handler/auth_handler.go +++ b/internal/handler/auth_handler.go @@ -41,9 +41,23 @@ func (h *AuthHandler) Register(c *gin.Context) { return } + // 验证滑动验证码(检查是否已验证) + if ok, err := h.container.CaptchaService.CheckVerified(c.Request.Context(), req.CaptchaID); err != nil || !ok { + h.logger.Warn("滑动验证码验证失败", zap.String("captcha_id", req.CaptchaID), zap.Error(err)) + RespondBadRequest(c, "滑动验证码验证失败", nil) + return + } + + // 使用 defer 确保验证码在函数返回前被消耗(不管成功还是失败) + defer func() { + if err := h.container.CaptchaService.ConsumeVerified(c.Request.Context(), req.CaptchaID); err != nil { + h.logger.Warn("消耗验证码失败", zap.String("captcha_id", req.CaptchaID), zap.Error(err)) + } + }() + // 验证邮箱验证码 if err := h.container.VerificationService.VerifyCode(c.Request.Context(), req.Email, req.VerificationCode, service.VerificationTypeRegister); err != nil { - h.logger.Warn("验证码验证失败", zap.String("email", req.Email), zap.Error(err)) + h.logger.Warn("邮箱验证码验证失败", zap.String("email", req.Email), zap.Error(err)) RespondBadRequest(c, err.Error(), nil) return } diff --git a/internal/handler/swagger.go b/internal/handler/swagger.go index 419a148..639bbb1 100644 --- a/internal/handler/swagger.go +++ b/internal/handler/swagger.go @@ -86,8 +86,8 @@ func checkRedis(ctx context.Context) error { return errors.New("Redis客户端未初始化") } - // 使用Ping检查连接 - if err := client.Ping(ctx).Err(); err != nil { + // 使用Ping检查连接(封装后的方法直接返回error) + if err := client.Ping(ctx); err != nil { return err } diff --git a/internal/service/captcha_service.go b/internal/service/captcha_service.go index 256c645..16f4373 100644 --- a/internal/service/captcha_service.go +++ b/internal/service/captcha_service.go @@ -180,12 +180,50 @@ func (s *captchaService) Verify(ctx context.Context, dx int, captchaID string) ( ty := redisData.Ty ok := slide.Validate(dx, ty, tx, ty, paddingValue) - // 验证后立即删除Redis记录(防止重复使用) + // 验证成功后,标记为已验证状态,设置5分钟有效期 if ok { + verifiedKey := redisKeyPrefix + "verified:" + captchaID + if err := s.redis.Set(ctx, verifiedKey, "1", 5*time.Minute); err != nil { + s.logger.Warn("设置验证码已验证标记失败", zap.Error(err)) + } + // 删除原始验证码记录(防止重复验证) if err := s.redis.Del(ctx, redisKey); err != nil { - // 记录警告但不影响验证结果 s.logger.Warn("删除验证码Redis记录失败", zap.Error(err)) } } return ok, nil } + +// CheckVerified 检查验证码是否已验证(仅检查captcha_id) +func (s *captchaService) CheckVerified(ctx context.Context, captchaID string) (bool, error) { + // 测试环境下直接通过验证 + cfg, err := config.GetConfig() + if err == nil && cfg.IsTestEnvironment() { + return true, nil + } + + verifiedKey := redisKeyPrefix + "verified:" + captchaID + exists, err := s.redis.Exists(ctx, verifiedKey) + if err != nil { + return false, fmt.Errorf("检查验证状态失败: %w", err) + } + if exists == 0 { + return false, errors.New("验证码未验证或已过期") + } + return true, nil +} + +// ConsumeVerified 消耗已验证的验证码(注册成功后调用) +func (s *captchaService) ConsumeVerified(ctx context.Context, captchaID string) error { + // 测试环境下直接返回成功 + cfg, err := config.GetConfig() + if err == nil && cfg.IsTestEnvironment() { + return nil + } + + verifiedKey := redisKeyPrefix + "verified:" + captchaID + if err := s.redis.Del(ctx, verifiedKey); err != nil { + s.logger.Warn("删除验证码已验证标记失败", zap.Error(err)) + } + return nil +} diff --git a/internal/service/interfaces.go b/internal/service/interfaces.go index 4b7132a..a11bc57 100644 --- a/internal/service/interfaces.go +++ b/internal/service/interfaces.go @@ -100,6 +100,8 @@ type VerificationService interface { type CaptchaService interface { Generate(ctx context.Context) (masterImg, tileImg, captchaID string, y int, err error) Verify(ctx context.Context, dx int, captchaID string) (bool, error) + CheckVerified(ctx context.Context, captchaID string) (bool, error) + ConsumeVerified(ctx context.Context, captchaID string) error } // YggdrasilService Yggdrasil服务接口 diff --git a/internal/service/texture_service.go b/internal/service/texture_service.go index f73102d..cfdfa2c 100644 --- a/internal/service/texture_service.go +++ b/internal/service/texture_service.go @@ -330,10 +330,10 @@ func (s *textureService) UploadTexture(ctx context.Context, uploaderID int64, na } // 生成对象名称(路径) - // 格式: hash/{hash[:2]}/{hash[2:4]}/{hash}.png - // 使用哈希值作为路径,避免重复存储相同文件 + // 格式: type/hash[:2]/hash + // 使用哈希值作为文件名,不带扩展名 textureTypeFolder := strings.ToLower(textureType) - objectName := fmt.Sprintf("%s/%s/%s/%s/%s%s", textureTypeFolder, hash[:2], hash[2:4], hash, hash, ext) + objectName := fmt.Sprintf("%s/%s", textureTypeFolder, hash) // 上传文件 reader := bytes.NewReader(fileData) diff --git a/internal/types/common.go b/internal/types/common.go index 291d9b8..527cef3 100644 --- a/internal/types/common.go +++ b/internal/types/common.go @@ -41,6 +41,7 @@ type RegisterRequest struct { Email string `json:"email" binding:"required,email" example:"user@example.com"` Password string `json:"password" binding:"required,min=6,max=128" example:"password123"` VerificationCode string `json:"verification_code" binding:"required,len=6" example:"123456"` // 邮箱验证码 + CaptchaID string `json:"captcha_id" binding:"required" example:"uuid-xxxx-xxxx"` // 滑动验证码ID Avatar string `json:"avatar" binding:"omitempty,url" example:"https://rustfs.example.com/avatars/user_1/avatar.png"` // 可选,用户自定义头像 } diff --git a/pkg/auth/casbin.go b/pkg/auth/casbin.go index 7c402c7..e76991a 100644 --- a/pkg/auth/casbin.go +++ b/pkg/auth/casbin.go @@ -4,7 +4,7 @@ import ( "fmt" "sync" - "github.com/casbin/casbin/v2" + "github.com/casbin/casbin/v3" gormadapter "github.com/casbin/gorm-adapter/v3" "go.uber.org/zap" "gorm.io/gorm" diff --git a/pkg/config/config.go b/pkg/config/config.go index 7476d96..b970615 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -65,18 +65,21 @@ type DatabaseConfig struct { // RedisConfig Redis配置 type RedisConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Password string `mapstructure:"password"` - Database int `mapstructure:"database"` - PoolSize int `mapstructure:"pool_size"` // 连接池大小 - MinIdleConns int `mapstructure:"min_idle_conns"` // 最小空闲连接数 - MaxRetries int `mapstructure:"max_retries"` // 最大重试次数 - DialTimeout time.Duration `mapstructure:"dial_timeout"` // 连接超时 - ReadTimeout time.Duration `mapstructure:"read_timeout"` // 读取超时 - WriteTimeout time.Duration `mapstructure:"write_timeout"` // 写入超时 - PoolTimeout time.Duration `mapstructure:"pool_timeout"` // 连接池超时 - ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"` // 连接最大空闲时间 + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Password string `mapstructure:"password"` + Database int `mapstructure:"database"` + PoolSize int `mapstructure:"pool_size"` // 连接池大小 + MinIdleConns int `mapstructure:"min_idle_conns"` // 最小空闲连接数 + MaxRetries int `mapstructure:"max_retries"` // 最大重试次数 + DialTimeout time.Duration `mapstructure:"dial_timeout"` // 连接超时 + ReadTimeout time.Duration `mapstructure:"read_timeout"` // 读取超时 + WriteTimeout time.Duration `mapstructure:"write_timeout"` // 写入超时 + PoolTimeout time.Duration `mapstructure:"pool_timeout"` // 连接池超时 + ConnMaxIdleTime time.Duration `mapstructure:"conn_max_idle_time"` // 连接最大空闲时间 + ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"` // 连接最大生命周期(新增) + HealthCheckInterval time.Duration `mapstructure:"health_check_interval"` // 健康检查间隔(新增) + EnableRetryOnError bool `mapstructure:"enable_retry_on_error"` // 错误时启用重试(新增) } // RustFSConfig RustFS对象存储配置 (S3兼容) @@ -192,18 +195,21 @@ func setDefaults() { viper.SetDefault("database.conn_max_lifetime", "1h") viper.SetDefault("database.conn_max_idle_time", "10m") - // Redis默认配置 + // Redis默认配置(优化后的默认值) viper.SetDefault("redis.host", "localhost") viper.SetDefault("redis.port", 6379) viper.SetDefault("redis.database", 0) - viper.SetDefault("redis.pool_size", 10) - viper.SetDefault("redis.min_idle_conns", 5) + viper.SetDefault("redis.pool_size", 16) // 优化:提高默认连接池大小 + viper.SetDefault("redis.min_idle_conns", 8) // 优化:提高最小空闲连接数 viper.SetDefault("redis.max_retries", 3) viper.SetDefault("redis.dial_timeout", "5s") viper.SetDefault("redis.read_timeout", "3s") viper.SetDefault("redis.write_timeout", "3s") viper.SetDefault("redis.pool_timeout", "4s") - viper.SetDefault("redis.conn_max_idle_time", "30m") + viper.SetDefault("redis.conn_max_idle_time", "10m") // 优化:减少空闲连接超时时间 + viper.SetDefault("redis.conn_max_lifetime", "30m") // 新增:连接最大生命周期 + viper.SetDefault("redis.health_check_interval", "30s") // 新增:健康检查间隔 + viper.SetDefault("redis.enable_retry_on_error", true) // 新增:错误时启用重试 // RustFS默认配置 viper.SetDefault("rustfs.endpoint", "127.0.0.1:9000") @@ -281,6 +287,9 @@ func setupEnvMappings() { viper.BindEnv("redis.write_timeout", "REDIS_WRITE_TIMEOUT") viper.BindEnv("redis.pool_timeout", "REDIS_POOL_TIMEOUT") viper.BindEnv("redis.conn_max_idle_time", "REDIS_CONN_MAX_IDLE_TIME") + viper.BindEnv("redis.conn_max_lifetime", "REDIS_CONN_MAX_LIFETIME") + viper.BindEnv("redis.health_check_interval", "REDIS_HEALTH_CHECK_INTERVAL") + viper.BindEnv("redis.enable_retry_on_error", "REDIS_ENABLE_RETRY_ON_ERROR") // RustFS配置 viper.BindEnv("rustfs.endpoint", "RUSTFS_ENDPOINT") @@ -427,6 +436,22 @@ func overrideFromEnv(config *Config) { } } + if connMaxLifetime := os.Getenv("REDIS_CONN_MAX_LIFETIME"); connMaxLifetime != "" { + if val, err := time.ParseDuration(connMaxLifetime); err == nil { + config.Redis.ConnMaxLifetime = val + } + } + + if healthCheckInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); healthCheckInterval != "" { + if val, err := time.ParseDuration(healthCheckInterval); err == nil { + config.Redis.HealthCheckInterval = val + } + } + + if enableRetryOnError := os.Getenv("REDIS_ENABLE_RETRY_ON_ERROR"); enableRetryOnError != "" { + config.Redis.EnableRetryOnError = enableRetryOnError == "true" || enableRetryOnError == "1" + } + // 处理邮件配置 if emailEnabled := os.Getenv("EMAIL_ENABLED"); emailEnabled != "" { config.Email.Enabled = emailEnabled == "true" || emailEnabled == "True" || emailEnabled == "TRUE" || emailEnabled == "1" diff --git a/pkg/database/manager.go b/pkg/database/manager.go index ffc9760..5e0752c 100644 --- a/pkg/database/manager.go +++ b/pkg/database/manager.go @@ -11,8 +11,8 @@ import ( ) var ( - // dbInstance 全局数据库实例 - dbInstance *gorm.DB + // dbInstance 全局数据库实例(使用 *DB 封装) + dbInstance *DB // once 确保只初始化一次 once sync.Once // initError 初始化错误 @@ -33,7 +33,16 @@ func Init(cfg config.DatabaseConfig, logger *zap.Logger) error { } // GetDB 获取数据库实例(线程安全) +// 返回 *gorm.DB 以保持向后兼容 func GetDB() (*gorm.DB, error) { + if dbInstance == nil { + return nil, fmt.Errorf("数据库未初始化,请先调用 database.Init()") + } + return dbInstance.DB, nil +} + +// GetDBWrapper 获取数据库封装实例(包含连接池统计功能) +func GetDBWrapper() (*DB, error) { if dbInstance == nil { return nil, fmt.Errorf("数据库未初始化,请先调用 database.Init()") } @@ -41,6 +50,7 @@ func GetDB() (*gorm.DB, error) { } // MustGetDB 获取数据库实例,如果未初始化则panic +// 返回 *gorm.DB 以保持向后兼容 func MustGetDB() *gorm.DB { db, err := GetDB() if err != nil { @@ -103,10 +113,5 @@ func Close() error { return nil } - sqlDB, err := dbInstance.DB() - if err != nil { - return err - } - - return sqlDB.Close() + return dbInstance.Close() } diff --git a/pkg/database/manager_sqlite_test.go b/pkg/database/manager_sqlite_test.go index e8932a0..29c5df0 100644 --- a/pkg/database/manager_sqlite_test.go +++ b/pkg/database/manager_sqlite_test.go @@ -14,8 +14,25 @@ func TestAutoMigrate_WithSQLite(t *testing.T) { if err != nil { t.Fatalf("open sqlite err: %v", err) } - dbInstance = db - defer func() { dbInstance = nil }() + + // 创建临时的 *DB 包装器用于测试 + // 注意:这里不需要真正的连接池功能,只是测试 AutoMigrate + sqlDB, err := db.DB() + if err != nil { + t.Fatalf("get sql.DB err: %v", err) + } + + tempDB := &DB{ + DB: db, + sqlDB: sqlDB, + } + + // 保存原始实例 + originalDB := dbInstance + defer func() { dbInstance = originalDB }() + + // 替换为测试实例 + dbInstance = tempDB logger := zaptest.NewLogger(t) if err := AutoMigrate(logger); err != nil { diff --git a/pkg/database/postgres.go b/pkg/database/postgres.go index 377c360..0492c34 100644 --- a/pkg/database/postgres.go +++ b/pkg/database/postgres.go @@ -1,9 +1,12 @@ package database import ( + "context" + "database/sql" "fmt" "log" "os" + "sync" "time" "carrotskin/pkg/config" @@ -13,8 +16,31 @@ import ( "gorm.io/gorm/logger" ) +// DBStats 数据库连接池统计信息 +type DBStats struct { + MaxOpenConns int // 最大打开连接数 + OpenConns int // 当前打开的连接数 + InUseConns int // 正在使用的连接数 + IdleConns int // 空闲连接数 + WaitCount int64 // 等待连接的总次数 + WaitDuration time.Duration // 等待连接的总时间 + LastPingTime time.Time // 上次探活时间 + LastPingSuccess bool // 上次探活是否成功 + mu sync.RWMutex // 保护 LastPingTime 和 LastPingSuccess +} + +// DB 数据库封装,包含连接池统计 +type DB struct { + *gorm.DB + stats *DBStats + sqlDB *sql.DB + healthCh chan struct{} // 健康检查信号通道 + closeCh chan struct{} // 关闭信号通道 + wg sync.WaitGroup +} + // New 创建新的PostgreSQL数据库连接 -func New(cfg config.DatabaseConfig) (*gorm.DB, error) { +func New(cfg config.DatabaseConfig) (*DB, error) { dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", cfg.Host, cfg.Port, @@ -25,11 +51,11 @@ func New(cfg config.DatabaseConfig) (*gorm.DB, error) { cfg.Timezone, ) - // 配置慢查询监控 + // 配置慢查询监控 - 优化:从200ms调整为100ms newLogger := logger.New( log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ - SlowThreshold: 200 * time.Millisecond, // 慢查询阈值:200ms + SlowThreshold: 100 * time.Millisecond, // 慢查询阈值:100ms(优化后) LogLevel: logger.Warn, // 只记录警告和错误 IgnoreRecordNotFoundError: true, // 忽略记录未找到错误 Colorful: false, // 生产环境禁用彩色 @@ -79,12 +105,131 @@ func New(cfg config.DatabaseConfig) (*gorm.DB, error) { sqlDB.SetConnMaxLifetime(connMaxLifetime) sqlDB.SetConnMaxIdleTime(connMaxIdleTime) - // 测试连接 - if err := sqlDB.Ping(); err != nil { + // 测试连接(带重试机制) + if err := pingWithRetry(sqlDB, 3, 2*time.Second); err != nil { return nil, fmt.Errorf("数据库连接测试失败: %w", err) } - return db, nil + // 创建数据库封装 + database := &DB{ + DB: db, + sqlDB: sqlDB, + stats: &DBStats{}, + healthCh: make(chan struct{}, 1), + closeCh: make(chan struct{}), + } + + // 初始化统计信息 + database.updateStats() + + // 启动定期健康检查 + database.startHealthCheck(30 * time.Second) + + log.Println("[Database] PostgreSQL连接池初始化成功") + log.Printf("[Database] 连接池配置: MaxIdleConns=%d, MaxOpenConns=%d, ConnMaxLifetime=%v, ConnMaxIdleTime=%v", + maxIdleConns, maxOpenConns, connMaxLifetime, connMaxIdleTime) + + return database, nil +} + +// pingWithRetry 带重试的Ping操作 +func pingWithRetry(sqlDB *sql.DB, maxRetries int, retryInterval time.Duration) error { + var err error + for i := 0; i < maxRetries; i++ { + if err = sqlDB.Ping(); err == nil { + return nil + } + if i < maxRetries-1 { + log.Printf("[Database] Ping失败,%v 后重试 (%d/%d): %v", retryInterval, i+1, maxRetries, err) + time.Sleep(retryInterval) + } + } + return err +} + +// startHealthCheck 启动定期健康检查 +func (d *DB) startHealthCheck(interval time.Duration) { + d.wg.Add(1) + go func() { + defer d.wg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + d.ping() + case <-d.healthCh: + d.ping() + case <-d.closeCh: + return + } + } + }() +} + +// ping 执行连接健康检查 +func (d *DB) ping() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := d.sqlDB.PingContext(ctx) + d.stats.mu.Lock() + d.stats.LastPingTime = time.Now() + d.stats.LastPingSuccess = err == nil + d.stats.mu.Unlock() + + if err != nil { + log.Printf("[Database] 连接健康检查失败: %v", err) + } else { + log.Println("[Database] 连接健康检查成功") + } +} + +// GetStats 获取连接池统计信息 +func (d *DB) GetStats() DBStats { + d.stats.mu.RLock() + defer d.stats.mu.RUnlock() + + // 从底层获取实时统计 + stats := d.sqlDB.Stats() + d.stats.MaxOpenConns = stats.MaxOpenConnections + d.stats.OpenConns = stats.OpenConnections + d.stats.InUseConns = stats.InUse + d.stats.IdleConns = stats.Idle + d.stats.WaitCount = stats.WaitCount + d.stats.WaitDuration = stats.WaitDuration + + return *d.stats +} + +// updateStats 初始化统计信息 +func (d *DB) updateStats() { + stats := d.sqlDB.Stats() + d.stats.MaxOpenConns = stats.MaxOpenConnections + d.stats.OpenConns = stats.OpenConnections + d.stats.InUseConns = stats.InUse + d.stats.IdleConns = stats.Idle +} + +// LogStats 记录连接池状态日志 +func (d *DB) LogStats() { + stats := d.GetStats() + log.Printf("[Database] 连接池状态: Open=%d, Idle=%d, InUse=%d, WaitCount=%d, WaitDuration=%v, LastPing=%v (%v)", + stats.OpenConns, stats.IdleConns, stats.InUseConns, stats.WaitCount, stats.WaitDuration, + stats.LastPingTime.Format("2006-01-02 15:04:05"), stats.LastPingSuccess) +} + +// Close 关闭数据库连接 +func (d *DB) Close() error { + close(d.closeCh) + d.wg.Wait() + return d.sqlDB.Close() +} + +// WithTimeout 创建带有超时控制的上下文 +func WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, timeout) } // GetDSN 获取数据源名称 @@ -99,9 +244,3 @@ func GetDSN(cfg config.DatabaseConfig) string { cfg.Timezone, ) } - - - - - - diff --git a/pkg/redis/redis.go b/pkg/redis/redis.go index 4d2a0b5..70fc18b 100644 --- a/pkg/redis/redis.go +++ b/pkg/redis/redis.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync" "time" "carrotskin/pkg/config" @@ -12,23 +13,39 @@ import ( "go.uber.org/zap" ) -// Client Redis客户端包装 +// Client Redis客户端包装(包含连接池统计和健康检查) type Client struct { - *redis.Client - logger *zap.Logger + *redis.Client // 嵌入原始Redis客户端 + logger *zap.Logger // 日志记录器 + stats *RedisStats // 连接池统计信息 + healthCheckDone chan struct{} // 健康检查完成信号 + closeCh chan struct{} // 关闭信号通道 + wg sync.WaitGroup // 等待组 } -// New 创建Redis客户端 +// RedisStats Redis连接池统计信息 +type RedisStats struct { + PoolSize int // 连接池大小 + IdleConns int // 空闲连接数 + ActiveConns int // 活跃连接数 + StaleConns int // 过期连接数 + TotalConns int // 总连接数 + LastPingTime time.Time // 上次探活时间 + LastPingSuccess bool // 上次探活是否成功 + mu sync.RWMutex // 保护统计信息 +} + +// New 创建Redis客户端(带健康检查和优化配置) func New(cfg config.RedisConfig, logger *zap.Logger) (*Client, error) { // 设置默认值 poolSize := cfg.PoolSize if poolSize <= 0 { - poolSize = 10 + poolSize = 16 // 优化:提高默认连接池大小 } minIdleConns := cfg.MinIdleConns if minIdleConns <= 0 { - minIdleConns = 5 + minIdleConns = 8 // 优化:提高最小空闲连接数 } maxRetries := cfg.MaxRetries @@ -58,10 +75,15 @@ func New(cfg config.RedisConfig, logger *zap.Logger) (*Client, error) { connMaxIdleTime := cfg.ConnMaxIdleTime if connMaxIdleTime <= 0 { - connMaxIdleTime = 30 * time.Minute + connMaxIdleTime = 10 * time.Minute // 优化:减少空闲连接超时 } - // 创建Redis客户端 + connMaxLifetime := cfg.ConnMaxLifetime + if connMaxLifetime <= 0 { + connMaxLifetime = 30 * time.Minute // 新增:连接最大生命周期 + } + + // 创建Redis客户端(带优化配置) rdb := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), Password: cfg.Password, @@ -74,125 +96,254 @@ func New(cfg config.RedisConfig, logger *zap.Logger) (*Client, error) { WriteTimeout: writeTimeout, PoolTimeout: poolTimeout, ConnMaxIdleTime: connMaxIdleTime, + ConnMaxLifetime: connMaxLifetime, }) - // 测试连接 - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := rdb.Ping(ctx).Err(); err != nil { + // 测试连接(带重试机制) + if err := pingWithRetry(rdb, 3, 2*time.Second); err != nil { return nil, fmt.Errorf("Redis连接失败: %w", err) } + // 创建客户端包装 + client := &Client{ + Client: rdb, + logger: logger, + stats: &RedisStats{}, + healthCheckDone: make(chan struct{}), + closeCh: make(chan struct{}), + } + + // 初始化统计信息 + client.updateStats() + + // 启动定期健康检查 + healthCheckInterval := cfg.HealthCheckInterval + if healthCheckInterval <= 0 { + healthCheckInterval = 30 * time.Second + } + client.startHealthCheck(healthCheckInterval) + logger.Info("Redis连接成功", zap.String("host", cfg.Host), zap.Int("port", cfg.Port), zap.Int("database", cfg.Database), + zap.Int("pool_size", poolSize), + zap.Int("min_idle_conns", minIdleConns), ) - return &Client{ - Client: rdb, - logger: logger, - }, nil + return client, nil +} + +// pingWithRetry 带重试的Ping操作 +func pingWithRetry(rdb *redis.Client, maxRetries int, retryInterval time.Duration) error { + var err error + for i := 0; i < maxRetries; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err = rdb.Ping(ctx).Err() + cancel() + if err == nil { + return nil + } + if i < maxRetries-1 { + time.Sleep(retryInterval) + } + } + return err +} + +// startHealthCheck 启动定期健康检查 +func (c *Client) startHealthCheck(interval time.Duration) { + c.wg.Add(1) + go func() { + defer c.wg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.doHealthCheck() + case <-c.closeCh: + return + } + } + }() +} + +// doHealthCheck 执行健康检查 +func (c *Client) doHealthCheck() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 更新统计信息 + c.updateStats() + + // 执行Ping检查 + err := c.Client.Ping(ctx).Err() + c.stats.mu.Lock() + c.stats.LastPingTime = time.Now() + c.stats.LastPingSuccess = err == nil + c.stats.mu.Unlock() + + if err != nil { + c.logger.Warn("Redis健康检查失败", zap.Error(err)) + } else { + c.logger.Debug("Redis健康检查成功") + } +} + +// updateStats 更新连接池统计信息 +func (c *Client) updateStats() { + // 获取底层连接池统计信息 + stats := c.Client.PoolStats() + c.stats.mu.Lock() + c.stats.PoolSize = c.Client.Options().PoolSize + c.stats.IdleConns = int(stats.IdleConns) + c.stats.ActiveConns = int(stats.TotalConns) - int(stats.IdleConns) + c.stats.TotalConns = int(stats.TotalConns) + c.stats.StaleConns = int(stats.StaleConns) + c.stats.mu.Unlock() +} + +// GetStats 获取连接池统计信息 +func (c *Client) GetStats() RedisStats { + c.stats.mu.RLock() + defer c.stats.mu.RUnlock() + return RedisStats{ + PoolSize: c.stats.PoolSize, + IdleConns: c.stats.IdleConns, + ActiveConns: c.stats.ActiveConns, + StaleConns: c.stats.StaleConns, + TotalConns: c.stats.TotalConns, + LastPingTime: c.stats.LastPingTime, + LastPingSuccess: c.stats.LastPingSuccess, + } +} + +// LogStats 记录连接池状态日志 +func (c *Client) LogStats() { + stats := c.GetStats() + c.logger.Info("Redis连接池状态", + zap.Int("pool_size", stats.PoolSize), + zap.Int("idle_conns", stats.IdleConns), + zap.Int("active_conns", stats.ActiveConns), + zap.Int("total_conns", stats.TotalConns), + zap.Int("stale_conns", stats.StaleConns), + zap.Bool("last_ping_success", stats.LastPingSuccess), + ) +} + +// Ping 验证Redis连接(带超时控制) +func (c *Client) Ping(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return c.Client.Ping(ctx).Err() } // Close 关闭Redis连接 func (c *Client) Close() error { + // 停止健康检查 + close(c.closeCh) + c.wg.Wait() + c.logger.Info("正在关闭Redis连接") + c.LogStats() // 关闭前记录最终状态 return c.Client.Close() } -// Set 设置键值对(带过期时间) +// ===== 以下是封装的便捷方法,用于返回 (value, error) 格式 ===== + +// Set 设置键值对(带过期时间)- 封装版本 func (c *Client) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error { return c.Client.Set(ctx, key, value, expiration).Err() } -// Get 获取键值 +// Get 获取键值 - 封装版本 func (c *Client) Get(ctx context.Context, key string) (string, error) { return c.Client.Get(ctx, key).Result() } -// Del 删除键 +// Del 删除键 - 封装版本 func (c *Client) Del(ctx context.Context, keys ...string) error { return c.Client.Del(ctx, keys...).Err() } -// Exists 检查键是否存在 +// Exists 检查键是否存在 - 封装版本 func (c *Client) Exists(ctx context.Context, keys ...string) (int64, error) { return c.Client.Exists(ctx, keys...).Result() } -// Expire 设置键的过期时间 +// Expire 设置键的过期时间 - 封装版本 func (c *Client) Expire(ctx context.Context, key string, expiration time.Duration) error { return c.Client.Expire(ctx, key, expiration).Err() } -// TTL 获取键的剩余过期时间 +// TTL 获取键的剩余过期时间 - 封装版本 func (c *Client) TTL(ctx context.Context, key string) (time.Duration, error) { return c.Client.TTL(ctx, key).Result() } -// Incr 自增 +// Incr 自增 - 封装版本 func (c *Client) Incr(ctx context.Context, key string) (int64, error) { return c.Client.Incr(ctx, key).Result() } -// Decr 自减 +// Decr 自减 - 封装版本 func (c *Client) Decr(ctx context.Context, key string) (int64, error) { return c.Client.Decr(ctx, key).Result() } -// HSet 设置哈希字段 +// HSet 设置哈希字段 - 封装版本 func (c *Client) HSet(ctx context.Context, key string, values ...interface{}) error { return c.Client.HSet(ctx, key, values...).Err() } -// HGet 获取哈希字段 +// HGet 获取哈希字段 - 封装版本 func (c *Client) HGet(ctx context.Context, key, field string) (string, error) { return c.Client.HGet(ctx, key, field).Result() } -// HGetAll 获取所有哈希字段 +// HGetAll 获取所有哈希字段 - 封装版本 func (c *Client) HGetAll(ctx context.Context, key string) (map[string]string, error) { return c.Client.HGetAll(ctx, key).Result() } -// HDel 删除哈希字段 +// HDel 删除哈希字段 - 封装版本 func (c *Client) HDel(ctx context.Context, key string, fields ...string) error { return c.Client.HDel(ctx, key, fields...).Err() } -// SAdd 添加集合成员 +// SAdd 添加集合成员 - 封装版本 func (c *Client) SAdd(ctx context.Context, key string, members ...interface{}) error { return c.Client.SAdd(ctx, key, members...).Err() } -// SMembers 获取集合所有成员 +// SMembers 获取集合所有成员 - 封装版本 func (c *Client) SMembers(ctx context.Context, key string) ([]string, error) { return c.Client.SMembers(ctx, key).Result() } -// SRem 删除集合成员 +// SRem 删除集合成员 - 封装版本 func (c *Client) SRem(ctx context.Context, key string, members ...interface{}) error { return c.Client.SRem(ctx, key, members...).Err() } -// SIsMember 检查是否是集合成员 +// SIsMember 检查是否是集合成员 - 封装版本 func (c *Client) SIsMember(ctx context.Context, key string, member interface{}) (bool, error) { return c.Client.SIsMember(ctx, key, member).Result() } -// ZAdd 添加有序集合成员 +// ZAdd 添加有序集合成员 - 封装版本 func (c *Client) ZAdd(ctx context.Context, key string, members ...redis.Z) error { return c.Client.ZAdd(ctx, key, members...).Err() } -// ZRange 获取有序集合范围内的成员 +// ZRange 获取有序集合范围内的成员 - 封装版本 func (c *Client) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) { return c.Client.ZRange(ctx, key, start, stop).Result() } -// ZRem 删除有序集合成员 +// ZRem 删除有序集合成员 - 封装版本 func (c *Client) ZRem(ctx context.Context, key string, members ...interface{}) error { return c.Client.ZRem(ctx, key, members...).Err() } @@ -207,6 +358,7 @@ func (c *Client) TxPipeline() redis.Pipeliner { return c.Client.TxPipeline() } +// Nil 检查错误是否为Nil(key不存在) func (c *Client) Nil(err error) bool { return errors.Is(err, redis.Nil) }