Files
backend/pkg/database/postgres.go

247 lines
6.4 KiB
Go
Raw Normal View History

package database
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"sync"
"time"
"carrotskin/pkg/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DBStats 数据库连接池统计信息
type DBStats struct {
MaxOpenConns int // 最大打开连接数
OpenConns int // 当前打开的连接数
InUseConns int // 正在使用的连接数
IdleConns int // 空闲连接数
WaitCount int64 // 等待连接的总次数
WaitDuration time.Duration // 等待连接的总时间
LastPingTime time.Time // 上次探活时间
LastPingSuccess bool // 上次探活是否成功
mu sync.RWMutex // 保护 LastPingTime 和 LastPingSuccess
}
// DB 数据库封装,包含连接池统计
type DB struct {
*gorm.DB
stats *DBStats
sqlDB *sql.DB
healthCh chan struct{} // 健康检查信号通道
closeCh chan struct{} // 关闭信号通道
wg sync.WaitGroup
}
// New 创建新的PostgreSQL数据库连接
func New(cfg config.DatabaseConfig) (*DB, error) {
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
cfg.Host,
cfg.Port,
cfg.Username,
cfg.Password,
cfg.Database,
cfg.SSLMode,
cfg.Timezone,
)
// 配置慢查询监控 - 优化从200ms调整为100ms
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: 100 * time.Millisecond, // 慢查询阈值100ms优化后
LogLevel: logger.Warn, // 只记录警告和错误
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
Colorful: false, // 生产环境禁用彩色
},
)
// 打开数据库连接
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: newLogger,
DisableForeignKeyConstraintWhenMigrating: true, // 禁用外键约束
PrepareStmt: true, // 启用预编译语句缓存
QueryFields: true, // 明确指定查询字段
})
if err != nil {
return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err)
}
// 获取底层SQL数据库实例
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
// 优化连接池配置
maxIdleConns := cfg.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = 10
}
maxOpenConns := cfg.MaxOpenConns
if maxOpenConns <= 0 {
maxOpenConns = 100
}
connMaxLifetime := cfg.ConnMaxLifetime
if connMaxLifetime <= 0 {
connMaxLifetime = 1 * time.Hour
}
2025-12-03 10:58:39 +08:00
connMaxIdleTime := cfg.ConnMaxIdleTime
if connMaxIdleTime <= 0 {
connMaxIdleTime = 10 * time.Minute
}
sqlDB.SetMaxIdleConns(maxIdleConns)
sqlDB.SetMaxOpenConns(maxOpenConns)
sqlDB.SetConnMaxLifetime(connMaxLifetime)
2025-12-03 10:58:39 +08:00
sqlDB.SetConnMaxIdleTime(connMaxIdleTime)
// 测试连接(带重试机制)
if err := pingWithRetry(sqlDB, 3, 2*time.Second); err != nil {
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
}
// 创建数据库封装
database := &DB{
DB: db,
sqlDB: sqlDB,
stats: &DBStats{},
healthCh: make(chan struct{}, 1),
closeCh: make(chan struct{}),
}
// 初始化统计信息
database.updateStats()
// 启动定期健康检查
database.startHealthCheck(30 * time.Second)
log.Println("[Database] PostgreSQL连接池初始化成功")
log.Printf("[Database] 连接池配置: MaxIdleConns=%d, MaxOpenConns=%d, ConnMaxLifetime=%v, ConnMaxIdleTime=%v",
maxIdleConns, maxOpenConns, connMaxLifetime, connMaxIdleTime)
return database, nil
}
// pingWithRetry 带重试的Ping操作
func pingWithRetry(sqlDB *sql.DB, maxRetries int, retryInterval time.Duration) error {
var err error
for i := 0; i < maxRetries; i++ {
if err = sqlDB.Ping(); err == nil {
return nil
}
if i < maxRetries-1 {
log.Printf("[Database] Ping失败%v 后重试 (%d/%d): %v", retryInterval, i+1, maxRetries, err)
time.Sleep(retryInterval)
}
}
return err
}
// startHealthCheck 启动定期健康检查
func (d *DB) startHealthCheck(interval time.Duration) {
d.wg.Add(1)
go func() {
defer d.wg.Done()
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
d.ping()
case <-d.healthCh:
d.ping()
case <-d.closeCh:
return
}
}
}()
}
// ping 执行连接健康检查
func (d *DB) ping() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := d.sqlDB.PingContext(ctx)
d.stats.mu.Lock()
d.stats.LastPingTime = time.Now()
d.stats.LastPingSuccess = err == nil
d.stats.mu.Unlock()
if err != nil {
log.Printf("[Database] 连接健康检查失败: %v", err)
} else {
log.Println("[Database] 连接健康检查成功")
}
}
// GetStats 获取连接池统计信息
func (d *DB) GetStats() DBStats {
d.stats.mu.RLock()
defer d.stats.mu.RUnlock()
// 从底层获取实时统计
stats := d.sqlDB.Stats()
d.stats.MaxOpenConns = stats.MaxOpenConnections
d.stats.OpenConns = stats.OpenConnections
d.stats.InUseConns = stats.InUse
d.stats.IdleConns = stats.Idle
d.stats.WaitCount = stats.WaitCount
d.stats.WaitDuration = stats.WaitDuration
return *d.stats
}
// updateStats 初始化统计信息
func (d *DB) updateStats() {
stats := d.sqlDB.Stats()
d.stats.MaxOpenConns = stats.MaxOpenConnections
d.stats.OpenConns = stats.OpenConnections
d.stats.InUseConns = stats.InUse
d.stats.IdleConns = stats.Idle
}
// LogStats 记录连接池状态日志
func (d *DB) LogStats() {
stats := d.GetStats()
log.Printf("[Database] 连接池状态: Open=%d, Idle=%d, InUse=%d, WaitCount=%d, WaitDuration=%v, LastPing=%v (%v)",
stats.OpenConns, stats.IdleConns, stats.InUseConns, stats.WaitCount, stats.WaitDuration,
stats.LastPingTime.Format("2006-01-02 15:04:05"), stats.LastPingSuccess)
}
// Close 关闭数据库连接
func (d *DB) Close() error {
close(d.closeCh)
d.wg.Wait()
return d.sqlDB.Close()
}
// WithTimeout 创建带有超时控制的上下文
func WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, timeout)
}
// GetDSN 获取数据源名称
func GetDSN(cfg config.DatabaseConfig) string {
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
cfg.Host,
cfg.Port,
cfg.Username,
cfg.Password,
cfg.Database,
cfg.SSLMode,
cfg.Timezone,
)
}