package database import ( "context" "database/sql" "fmt" "log" "os" "sync" "time" "carrotskin/pkg/config" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" ) // DBStats 数据库连接池统计信息 type DBStats struct { MaxOpenConns int // 最大打开连接数 OpenConns int // 当前打开的连接数 InUseConns int // 正在使用的连接数 IdleConns int // 空闲连接数 WaitCount int64 // 等待连接的总次数 WaitDuration time.Duration // 等待连接的总时间 LastPingTime time.Time // 上次探活时间 LastPingSuccess bool // 上次探活是否成功 mu sync.RWMutex // 保护 LastPingTime 和 LastPingSuccess } // DB 数据库封装,包含连接池统计 type DB struct { *gorm.DB stats *DBStats sqlDB *sql.DB healthCh chan struct{} // 健康检查信号通道 closeCh chan struct{} // 关闭信号通道 wg sync.WaitGroup } // New 创建新的PostgreSQL数据库连接 func New(cfg config.DatabaseConfig) (*DB, error) { dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", cfg.Host, cfg.Port, cfg.Username, cfg.Password, cfg.Database, cfg.SSLMode, cfg.Timezone, ) // 配置慢查询监控 - 优化:从200ms调整为100ms newLogger := logger.New( log.New(os.Stdout, "\r\n", log.LstdFlags), logger.Config{ SlowThreshold: 100 * time.Millisecond, // 慢查询阈值:100ms(优化后) LogLevel: logger.Warn, // 只记录警告和错误 IgnoreRecordNotFoundError: true, // 忽略记录未找到错误 Colorful: false, // 生产环境禁用彩色 }, ) // 打开数据库连接 db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: newLogger, DisableForeignKeyConstraintWhenMigrating: true, // 禁用外键约束 PrepareStmt: true, // 启用预编译语句缓存 QueryFields: true, // 明确指定查询字段 }) if err != nil { return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err) } // 获取底层SQL数据库实例 sqlDB, err := db.DB() if err != nil { return nil, fmt.Errorf("获取数据库实例失败: %w", err) } // 优化连接池配置 maxIdleConns := cfg.MaxIdleConns if maxIdleConns <= 0 { maxIdleConns = 10 } maxOpenConns := cfg.MaxOpenConns if maxOpenConns <= 0 { maxOpenConns = 100 } connMaxLifetime := cfg.ConnMaxLifetime if connMaxLifetime <= 0 { connMaxLifetime = 1 * time.Hour } connMaxIdleTime := cfg.ConnMaxIdleTime if connMaxIdleTime <= 0 { connMaxIdleTime = 10 * time.Minute } sqlDB.SetMaxIdleConns(maxIdleConns) sqlDB.SetMaxOpenConns(maxOpenConns) sqlDB.SetConnMaxLifetime(connMaxLifetime) sqlDB.SetConnMaxIdleTime(connMaxIdleTime) // 测试连接(带重试机制) if err := pingWithRetry(sqlDB, 3, 2*time.Second); err != nil { return nil, fmt.Errorf("数据库连接测试失败: %w", err) } // 创建数据库封装 database := &DB{ DB: db, sqlDB: sqlDB, stats: &DBStats{}, healthCh: make(chan struct{}, 1), closeCh: make(chan struct{}), } // 初始化统计信息 database.updateStats() // 启动定期健康检查 database.startHealthCheck(30 * time.Second) log.Println("[Database] PostgreSQL连接池初始化成功") log.Printf("[Database] 连接池配置: MaxIdleConns=%d, MaxOpenConns=%d, ConnMaxLifetime=%v, ConnMaxIdleTime=%v", maxIdleConns, maxOpenConns, connMaxLifetime, connMaxIdleTime) return database, nil } // pingWithRetry 带重试的Ping操作 func pingWithRetry(sqlDB *sql.DB, maxRetries int, retryInterval time.Duration) error { var err error for i := 0; i < maxRetries; i++ { if err = sqlDB.Ping(); err == nil { return nil } if i < maxRetries-1 { log.Printf("[Database] Ping失败,%v 后重试 (%d/%d): %v", retryInterval, i+1, maxRetries, err) time.Sleep(retryInterval) } } return err } // startHealthCheck 启动定期健康检查 func (d *DB) startHealthCheck(interval time.Duration) { d.wg.Add(1) go func() { defer d.wg.Done() ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: d.ping() case <-d.healthCh: d.ping() case <-d.closeCh: return } } }() } // ping 执行连接健康检查 func (d *DB) ping() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := d.sqlDB.PingContext(ctx) d.stats.mu.Lock() d.stats.LastPingTime = time.Now() d.stats.LastPingSuccess = err == nil d.stats.mu.Unlock() if err != nil { log.Printf("[Database] 连接健康检查失败: %v", err) } else { log.Println("[Database] 连接健康检查成功") } } // GetStats 获取连接池统计信息 func (d *DB) GetStats() DBStats { d.stats.mu.RLock() defer d.stats.mu.RUnlock() // 从底层获取实时统计 stats := d.sqlDB.Stats() d.stats.MaxOpenConns = stats.MaxOpenConnections d.stats.OpenConns = stats.OpenConnections d.stats.InUseConns = stats.InUse d.stats.IdleConns = stats.Idle d.stats.WaitCount = stats.WaitCount d.stats.WaitDuration = stats.WaitDuration return *d.stats } // updateStats 初始化统计信息 func (d *DB) updateStats() { stats := d.sqlDB.Stats() d.stats.MaxOpenConns = stats.MaxOpenConnections d.stats.OpenConns = stats.OpenConnections d.stats.InUseConns = stats.InUse d.stats.IdleConns = stats.Idle } // LogStats 记录连接池状态日志 func (d *DB) LogStats() { stats := d.GetStats() log.Printf("[Database] 连接池状态: Open=%d, Idle=%d, InUse=%d, WaitCount=%d, WaitDuration=%v, LastPing=%v (%v)", stats.OpenConns, stats.IdleConns, stats.InUseConns, stats.WaitCount, stats.WaitDuration, stats.LastPingTime.Format("2006-01-02 15:04:05"), stats.LastPingSuccess) } // Close 关闭数据库连接 func (d *DB) Close() error { close(d.closeCh) d.wg.Wait() return d.sqlDB.Close() } // WithTimeout 创建带有超时控制的上下文 func WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { return context.WithTimeout(parent, timeout) } // GetDSN 获取数据源名称 func GetDSN(cfg config.DatabaseConfig) string { return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", cfg.Host, cfg.Port, cfg.Username, cfg.Password, cfg.Database, cfg.SSLMode, cfg.Timezone, ) }