Files
backend/pkg/database/postgres.go
lan a111872b32
Some checks failed
Build / build (push) Successful in 2m23s
Build / build-docker (push) Failing after 1m37s
feat(auth): upgrade casbin to v3 and enhance connection pool configurations
- Upgrade casbin from v2 to v3 across go.mod and pkg/auth/casbin.go
- Add slide captcha verification to registration flow (CheckVerified, ConsumeVerified)
- Add DB wrapper with connection pool statistics and health checks
- Add Redis connection pool optimizations with stats and health monitoring
- Add new config options: ConnMaxLifetime, HealthCheckInterval, EnableRetryOnError
- Optimize slow query threshold from 200ms to 100ms
- Add ping with retry mechanism for database and Redis connections
2026-02-25 19:00:50 +08:00

247 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package database
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"sync"
"time"
"carrotskin/pkg/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DBStats holds connection pool statistics for the database together with
// the outcome of the most recent health-check ping.
type DBStats struct {
MaxOpenConns int // maximum number of open connections
OpenConns int // number of currently open connections
InUseConns int // number of connections currently in use
IdleConns int // number of idle connections
WaitCount int64 // total number of times a connection was waited for
WaitDuration time.Duration // total time spent waiting for a connection
LastPingTime time.Time // time of the last health-check ping
LastPingSuccess bool // whether the last health-check ping succeeded
mu sync.RWMutex // guards LastPingTime and LastPingSuccess
}
// DB wraps a *gorm.DB and carries connection pool statistics plus the
// channels and WaitGroup used by the background health-check goroutine.
type DB struct {
*gorm.DB
stats *DBStats
sqlDB *sql.DB
healthCh chan struct{} // health-check signal channel; a receive triggers an immediate ping
closeCh chan struct{} // shutdown signal channel; closed by Close to stop the health-check loop
wg sync.WaitGroup
}
// New opens a PostgreSQL connection through GORM, configures the underlying
// connection pool, verifies connectivity and starts the periodic health check.
//
// Non-positive pool settings in cfg fall back to defaults: 10 idle
// connections, 100 open connections, a 1 hour connection lifetime and a
// 10 minute idle time.
//
// It returns the wrapped *DB, or an error when the connection cannot be
// opened, the underlying *sql.DB cannot be obtained, or the initial ping
// (3 attempts, 2 seconds apart) fails. On ping failure the pool that was
// already opened is closed again so its connections are not leaked.
func New(cfg config.DatabaseConfig) (*DB, error) {
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
		cfg.Host,
		cfg.Port,
		cfg.Username,
		cfg.Password,
		cfg.Database,
		cfg.SSLMode,
		cfg.Timezone,
	)

	// Slow-query monitoring: statements slower than 100ms are reported.
	newLogger := logger.New(
		log.New(os.Stdout, "\r\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             100 * time.Millisecond, // slow-query threshold
			LogLevel:                  logger.Warn,            // warnings and errors only
			IgnoreRecordNotFoundError: true,                   // do not log record-not-found errors
			Colorful:                  false,                  // no colored output
		},
	)

	// Open the database connection.
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
		Logger:                                   newLogger,
		DisableForeignKeyConstraintWhenMigrating: true, // do not create foreign key constraints when migrating
		PrepareStmt:                              true, // cache prepared statements
		QueryFields:                              true, // select explicit field names
	})
	if err != nil {
		return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err)
	}

	// Obtain the underlying *sql.DB so the pool can be tuned and pinged.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("获取数据库实例失败: %w", err)
	}

	// Pool configuration, with defaults for unset (non-positive) values.
	maxIdleConns := cfg.MaxIdleConns
	if maxIdleConns <= 0 {
		maxIdleConns = 10
	}
	maxOpenConns := cfg.MaxOpenConns
	if maxOpenConns <= 0 {
		maxOpenConns = 100
	}
	connMaxLifetime := cfg.ConnMaxLifetime
	if connMaxLifetime <= 0 {
		connMaxLifetime = 1 * time.Hour
	}
	connMaxIdleTime := cfg.ConnMaxIdleTime
	if connMaxIdleTime <= 0 {
		connMaxIdleTime = 10 * time.Minute
	}
	sqlDB.SetMaxIdleConns(maxIdleConns)
	sqlDB.SetMaxOpenConns(maxOpenConns)
	sqlDB.SetConnMaxLifetime(connMaxLifetime)
	sqlDB.SetConnMaxIdleTime(connMaxIdleTime)

	// Verify connectivity, retrying a few times before giving up.
	if err := pingWithRetry(sqlDB, 3, 2*time.Second); err != nil {
		// The pool is already open at this point; release it before bailing
		// out, otherwise every failed New call leaks a *sql.DB.
		if closeErr := sqlDB.Close(); closeErr != nil {
			log.Printf("[Database] 关闭连接池失败: %v", closeErr)
		}
		return nil, fmt.Errorf("数据库连接测试失败: %w", err)
	}

	// Build the wrapper.
	database := &DB{
		DB:       db,
		sqlDB:    sqlDB,
		stats:    &DBStats{},
		healthCh: make(chan struct{}, 1),
		closeCh:  make(chan struct{}),
	}

	// Seed the statistics before the health-check goroutine starts.
	database.updateStats()

	// Start the periodic health check.
	database.startHealthCheck(30 * time.Second)

	log.Println("[Database] PostgreSQL连接池初始化成功")
	log.Printf("[Database] 连接池配置: MaxIdleConns=%d, MaxOpenConns=%d, ConnMaxLifetime=%v, ConnMaxIdleTime=%v",
		maxIdleConns, maxOpenConns, connMaxLifetime, connMaxIdleTime)
	return database, nil
}
// pingWithRetry pings sqlDB up to maxRetries times, sleeping retryInterval
// between failed attempts, and returns nil as soon as one ping succeeds.
//
// At least one attempt is always made: a maxRetries below 1 is treated as 1,
// so the function can never report success without having pinged the
// database. If every attempt fails, the error of the last attempt is returned.
func pingWithRetry(sqlDB *sql.DB, maxRetries int, retryInterval time.Duration) error {
	if maxRetries < 1 {
		maxRetries = 1
	}

	var err error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		if err = sqlDB.Ping(); err == nil {
			return nil
		}
		// Only log and sleep when another attempt will follow.
		if attempt < maxRetries {
			log.Printf("[Database] Ping失败%v 后重试 (%d/%d): %v", retryInterval, attempt, maxRetries, err)
			time.Sleep(retryInterval)
		}
	}
	return err
}
// startHealthCheck launches a background goroutine that pings the database
// every interval and whenever a signal arrives on healthCh. The goroutine is
// registered with d.wg and exits once closeCh is closed.
//
// interval must be positive; time.NewTicker panics otherwise.
func (d *DB) startHealthCheck(interval time.Duration) {
	loop := func() {
		defer d.wg.Done()

		tick := time.NewTicker(interval)
		defer tick.Stop()

		for {
			select {
			case <-d.closeCh:
				return
			case <-tick.C:
			case <-d.healthCh:
			}
			// Reached after either a tick or an explicit health-check signal.
			d.ping()
		}
	}

	d.wg.Add(1)
	go loop()
}
// ping runs a single health check against the pool with a 5 second timeout,
// records the time and outcome in d.stats under the stats mutex, and logs
// the result.
func (d *DB) ping() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	pingErr := d.sqlDB.PingContext(ctx)
	healthy := pingErr == nil

	d.stats.mu.Lock()
	d.stats.LastPingTime = time.Now()
	d.stats.LastPingSuccess = healthy
	d.stats.mu.Unlock()

	if !healthy {
		log.Printf("[Database] 连接健康检查失败: %v", pingErr)
		return
	}
	log.Println("[Database] 连接健康检查成功")
}
// GetStats refreshes the cached pool counters from the underlying *sql.DB and
// returns a snapshot of them together with the result of the last health
// check.
//
// The cached stats are written here, so the write lock is taken; a read lock
// would let concurrent callers write the same fields at once. The snapshot is
// built field by field so that the returned value carries a fresh, unlocked
// mutex rather than a copy of the one that is currently held.
func (d *DB) GetStats() DBStats {
	// Fetch the live counters first so the lock is held only for the copy.
	pool := d.sqlDB.Stats()

	d.stats.mu.Lock()
	defer d.stats.mu.Unlock()

	d.stats.MaxOpenConns = pool.MaxOpenConnections
	d.stats.OpenConns = pool.OpenConnections
	d.stats.InUseConns = pool.InUse
	d.stats.IdleConns = pool.Idle
	d.stats.WaitCount = pool.WaitCount
	d.stats.WaitDuration = pool.WaitDuration

	return DBStats{
		MaxOpenConns:    d.stats.MaxOpenConns,
		OpenConns:       d.stats.OpenConns,
		InUseConns:      d.stats.InUseConns,
		IdleConns:       d.stats.IdleConns,
		WaitCount:       d.stats.WaitCount,
		WaitDuration:    d.stats.WaitDuration,
		LastPingTime:    d.stats.LastPingTime,
		LastPingSuccess: d.stats.LastPingSuccess,
	}
}
// updateStats copies the current pool counters from the underlying *sql.DB
// into d.stats. It fills the same six counters that GetStats reports
// (including WaitCount and WaitDuration) and holds the stats mutex while
// writing, so it remains safe if it is ever called after the health-check
// goroutine has started.
func (d *DB) updateStats() {
	pool := d.sqlDB.Stats()

	d.stats.mu.Lock()
	defer d.stats.mu.Unlock()

	d.stats.MaxOpenConns = pool.MaxOpenConnections
	d.stats.OpenConns = pool.OpenConnections
	d.stats.InUseConns = pool.InUse
	d.stats.IdleConns = pool.Idle
	d.stats.WaitCount = pool.WaitCount
	d.stats.WaitDuration = pool.WaitDuration
}
// LogStats writes one log line describing the current pool state: open, idle
// and in-use connection counts, wait count and duration, and the time and
// outcome of the last health-check ping.
func (d *DB) LogStats() {
	snap := d.GetStats()
	lastPing := snap.LastPingTime.Format("2006-01-02 15:04:05")

	log.Printf(
		"[Database] 连接池状态: Open=%d, Idle=%d, InUse=%d, WaitCount=%d, WaitDuration=%v, LastPing=%v (%v)",
		snap.OpenConns,
		snap.IdleConns,
		snap.InUseConns,
		snap.WaitCount,
		snap.WaitDuration,
		lastPing,
		snap.LastPingSuccess,
	)
}
// Close stops the background health-check goroutine, waits for it to exit and
// then closes the underlying connection pool, returning the error from
// sql.DB.Close.
//
// Calling Close again after it has returned no longer panics with "close of
// closed channel": the shutdown signal is only sent once. Close is still not
// intended to be called from several goroutines at the same time.
func (d *DB) Close() error {
	select {
	case <-d.closeCh:
		// Nothing is ever sent on closeCh, so a successful receive means an
		// earlier Close already closed it.
	default:
		close(d.closeCh)
	}
	d.wg.Wait()
	return d.sqlDB.Close()
}
// WithTimeout derives a context from parent that is cancelled once timeout
// has elapsed. It is a thin wrapper around context.WithTimeout; the caller
// must invoke the returned CancelFunc to release the associated resources.
func WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithTimeout(parent, timeout)
	return ctx, cancel
}
// GetDSN builds the PostgreSQL data source name for cfg as a space-separated
// key=value string (host, port, user, password, dbname, sslmode, TimeZone).
// Values are inserted verbatim with no quoting or escaping, and the result
// contains the password in plain text.
func GetDSN(cfg config.DatabaseConfig) string {
	const layout = "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s"

	return fmt.Sprintf(layout,
		cfg.Host, cfg.Port,
		cfg.Username, cfg.Password,
		cfg.Database,
		cfg.SSLMode,
		cfg.Timezone,
	)
}