chore: 初始化仓库,排除二进制文件和覆盖率文件
This commit is contained in:
113
pkg/database/manager.go
Normal file
113
pkg/database/manager.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"carrotskin/internal/model"
|
||||
"carrotskin/pkg/config"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
// dbInstance 全局数据库实例
|
||||
dbInstance *gorm.DB
|
||||
// once 确保只初始化一次
|
||||
once sync.Once
|
||||
// initError 初始化错误
|
||||
initError error
|
||||
)
|
||||
|
||||
// Init 初始化数据库连接(线程安全,只会执行一次)
|
||||
func Init(cfg config.DatabaseConfig, logger *zap.Logger) error {
|
||||
once.Do(func() {
|
||||
dbInstance, initError = New(cfg)
|
||||
if initError != nil {
|
||||
logger.Error("数据库初始化失败", zap.Error(initError))
|
||||
return
|
||||
}
|
||||
logger.Info("数据库连接成功")
|
||||
})
|
||||
return initError
|
||||
}
|
||||
|
||||
// GetDB 获取数据库实例(线程安全)
|
||||
func GetDB() (*gorm.DB, error) {
|
||||
if dbInstance == nil {
|
||||
return nil, fmt.Errorf("数据库未初始化,请先调用 database.Init()")
|
||||
}
|
||||
return dbInstance, nil
|
||||
}
|
||||
|
||||
// MustGetDB 获取数据库实例,如果未初始化则panic
|
||||
func MustGetDB() *gorm.DB {
|
||||
db, err := GetDB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// AutoMigrate 自动迁移数据库表结构
|
||||
func AutoMigrate(logger *zap.Logger) error {
|
||||
db, err := GetDB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("开始执行数据库迁移...")
|
||||
|
||||
// 迁移所有表 - 注意顺序:先创建被引用的表,再创建引用表
|
||||
err = db.AutoMigrate(
|
||||
// 用户相关表(先创建,因为其他表可能引用它)
|
||||
&model.User{},
|
||||
&model.UserPointLog{},
|
||||
&model.UserLoginLog{},
|
||||
|
||||
// 档案相关表
|
||||
&model.Profile{},
|
||||
|
||||
// 材质相关表
|
||||
&model.Texture{},
|
||||
&model.UserTextureFavorite{},
|
||||
&model.TextureDownloadLog{},
|
||||
|
||||
// 认证相关表
|
||||
&model.Token{},
|
||||
|
||||
// Yggdrasil相关表(在User之后创建,因为它引用User)
|
||||
&model.Yggdrasil{},
|
||||
|
||||
// 系统配置表
|
||||
&model.SystemConfig{},
|
||||
|
||||
// 审计日志表
|
||||
&model.AuditLog{},
|
||||
|
||||
// Casbin权限规则表
|
||||
&model.CasbinRule{},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("数据库迁移失败", zap.Error(err))
|
||||
return fmt.Errorf("数据库迁移失败: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("数据库迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func Close() error {
|
||||
if dbInstance == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := dbInstance.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sqlDB.Close()
|
||||
}
|
||||
85
pkg/database/manager_test.go
Normal file
85
pkg/database/manager_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"carrotskin/pkg/config"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap/zaptest"
|
||||
)
|
||||
|
||||
// TestGetDB_NotInitialized 测试未初始化时获取数据库实例
|
||||
func TestGetDB_NotInitialized(t *testing.T) {
|
||||
_, err := GetDB()
|
||||
if err == nil {
|
||||
t.Error("未初始化时应该返回错误")
|
||||
}
|
||||
|
||||
expectedError := "数据库未初始化,请先调用 database.Init()"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("错误消息 = %q, want %q", err.Error(), expectedError)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMustGetDB_Panic 测试MustGetDB在未初始化时panic
|
||||
func TestMustGetDB_Panic(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("MustGetDB 应该在未初始化时panic")
|
||||
}
|
||||
}()
|
||||
|
||||
_ = MustGetDB()
|
||||
}
|
||||
|
||||
// TestInit_Database 测试数据库初始化逻辑
|
||||
func TestInit_Database(t *testing.T) {
|
||||
cfg := config.DatabaseConfig{
|
||||
Driver: "postgres",
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "password",
|
||||
Database: "testdb",
|
||||
SSLMode: "disable",
|
||||
Timezone: "Asia/Shanghai",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
ConnMaxLifetime: 0,
|
||||
}
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 验证Init函数存在且可调用
|
||||
// 注意:实际连接可能失败,这是可以接受的
|
||||
err := Init(cfg, logger)
|
||||
if err != nil {
|
||||
t.Logf("Init() 返回错误(可能正常,如果数据库未运行): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoMigrate_ErrorHandling 测试AutoMigrate的错误处理逻辑
|
||||
func TestAutoMigrate_ErrorHandling(t *testing.T) {
|
||||
logger := zaptest.NewLogger(t)
|
||||
|
||||
// 测试未初始化时的错误处理
|
||||
err := AutoMigrate(logger)
|
||||
if err == nil {
|
||||
// 如果数据库已初始化,这是正常的
|
||||
t.Log("AutoMigrate() 成功(数据库可能已初始化)")
|
||||
} else {
|
||||
// 如果数据库未初始化,应该返回错误
|
||||
if err.Error() == "" {
|
||||
t.Error("AutoMigrate() 应该返回有意义的错误消息")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestClose_NotInitialized 测试未初始化时关闭数据库
|
||||
func TestClose_NotInitialized(t *testing.T) {
|
||||
// 未初始化时关闭应该不返回错误
|
||||
err := Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close() 在未初始化时应该返回nil,实际返回: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
73
pkg/database/postgres.go
Normal file
73
pkg/database/postgres.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"carrotskin/pkg/config"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// New 创建新的PostgreSQL数据库连接
|
||||
func New(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
cfg.Host,
|
||||
cfg.Port,
|
||||
cfg.Username,
|
||||
cfg.Password,
|
||||
cfg.Database,
|
||||
cfg.SSLMode,
|
||||
cfg.Timezone,
|
||||
)
|
||||
|
||||
// 配置GORM日志级别
|
||||
var gormLogLevel logger.LogLevel
|
||||
switch {
|
||||
case cfg.Driver == "postgres":
|
||||
gormLogLevel = logger.Info
|
||||
default:
|
||||
gormLogLevel = logger.Silent
|
||||
}
|
||||
|
||||
// 打开数据库连接
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(gormLogLevel),
|
||||
DisableForeignKeyConstraintWhenMigrating: true, // 禁用自动创建外键约束,避免循环依赖问题
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层SQL数据库实例
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// GetDSN 获取数据源名称
|
||||
func GetDSN(cfg config.DatabaseConfig) string {
|
||||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
cfg.Host,
|
||||
cfg.Port,
|
||||
cfg.Username,
|
||||
cfg.Password,
|
||||
cfg.Database,
|
||||
cfg.SSLMode,
|
||||
cfg.Timezone,
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user