feat: Enhance dependency injection and service integration

- Updated main.go to initialize email service and include it in the dependency injection container.
- Refactored handlers to utilize context in service method calls, improving consistency and error handling.
- Introduced new service options for upload, security, and captcha services, enhancing modularity and testability.
- Removed unused repository implementations to streamline the codebase.

This commit continues the effort to improve the architecture by ensuring all services are properly injected and utilized across the application.
This commit is contained in:
lan
2025-12-02 22:52:33 +08:00
parent 792e96b238
commit 034e02e93a
54 changed files with 2305 additions and 2708 deletions

442
pkg/database/cache.go Normal file
View File

@@ -0,0 +1,442 @@
package database
import (
"context"
"encoding/json"
"fmt"
"time"
"carrotskin/pkg/redis"
)
// CacheConfig 缓存配置
type CacheConfig struct {
Prefix string // 缓存键前缀
Expiration time.Duration // 过期时间
Enabled bool // 是否启用缓存
}
// CacheManager 缓存管理器
type CacheManager struct {
redis *redis.Client
config CacheConfig
}
// NewCacheManager 创建缓存管理器
func NewCacheManager(redisClient *redis.Client, config CacheConfig) *CacheManager {
if config.Prefix == "" {
config.Prefix = "db:"
}
if config.Expiration == 0 {
config.Expiration = 5 * time.Minute
}
return &CacheManager{
redis: redisClient,
config: config,
}
}
// buildKey 构建缓存键
func (cm *CacheManager) buildKey(key string) string {
return cm.config.Prefix + key
}
// Get 获取缓存
func (cm *CacheManager) Get(ctx context.Context, key string, dest interface{}) error {
if !cm.config.Enabled || cm.redis == nil {
return fmt.Errorf("cache not enabled")
}
data, err := cm.redis.GetBytes(ctx, cm.buildKey(key))
if err != nil || data == nil {
return fmt.Errorf("cache miss")
}
return json.Unmarshal(data, dest)
}
// Set 设置缓存
func (cm *CacheManager) Set(ctx context.Context, key string, value interface{}, expiration ...time.Duration) error {
if !cm.config.Enabled || cm.redis == nil {
return nil
}
data, err := json.Marshal(value)
if err != nil {
return err
}
exp := cm.config.Expiration
if len(expiration) > 0 && expiration[0] > 0 {
exp = expiration[0]
}
return cm.redis.Set(ctx, cm.buildKey(key), data, exp)
}
// Delete 删除缓存
func (cm *CacheManager) Delete(ctx context.Context, keys ...string) error {
if !cm.config.Enabled || cm.redis == nil {
return nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = cm.buildKey(key)
}
return cm.redis.Del(ctx, fullKeys...)
}
// DeletePattern 删除匹配模式的缓存
// 使用 Redis SCAN 命令安全地删除匹配的键,避免阻塞
func (cm *CacheManager) DeletePattern(ctx context.Context, pattern string) error {
if !cm.config.Enabled || cm.redis == nil {
return nil
}
// 构建完整的匹配模式
fullPattern := cm.buildKey(pattern)
// 使用 SCAN 命令迭代查找匹配的键
var cursor uint64
var deletedCount int
for {
// 每次扫描100个键
keys, nextCursor, err := cm.redis.Client.Scan(ctx, cursor, fullPattern, 100).Result()
if err != nil {
return fmt.Errorf("扫描缓存键失败: %w", err)
}
// 批量删除找到的键
if len(keys) > 0 {
if err := cm.redis.Client.Del(ctx, keys...).Err(); err != nil {
return fmt.Errorf("删除缓存键失败: %w", err)
}
deletedCount += len(keys)
}
// 更新游标
cursor = nextCursor
// cursor == 0 表示扫描完成
if cursor == 0 {
break
}
// 检查 context 是否已取消
select {
case <-ctx.Done():
return ctx.Err()
default:
}
}
return nil
}
// GetOrSet 获取缓存,如果不存在则执行回调并设置缓存
func (cm *CacheManager) GetOrSet(ctx context.Context, key string, dest interface{}, fn func() (interface{}, error), expiration ...time.Duration) error {
// 尝试从缓存获取
err := cm.Get(ctx, key, dest)
if err == nil {
return nil // 缓存命中
}
// 缓存未命中,执行回调获取数据
result, err := fn()
if err != nil {
return err
}
// 设置缓存
if err := cm.Set(ctx, key, result, expiration...); err != nil {
// 缓存设置失败不影响主流程,只记录日志
// logger.Warn("failed to set cache", zap.Error(err))
}
// 将结果转换为目标类型
data, err := json.Marshal(result)
if err != nil {
return err
}
return json.Unmarshal(data, dest)
}
// Cached 缓存装饰器 - 为查询函数添加缓存
func Cached[T any](
ctx context.Context,
cache *CacheManager,
key string,
queryFn func() (*T, error),
expiration ...time.Duration,
) (*T, error) {
// 尝试从缓存获取
var result T
if err := cache.Get(ctx, key, &result); err == nil {
return &result, nil
}
// 缓存未命中,执行查询
data, err := queryFn()
if err != nil {
return nil, err
}
// 设置缓存(异步,不阻塞)
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = cache.Set(cacheCtx, key, data, expiration...)
}()
return data, nil
}
// CachedList 缓存装饰器 - 为列表查询添加缓存
func CachedList[T any](
ctx context.Context,
cache *CacheManager,
key string,
queryFn func() ([]T, error),
expiration ...time.Duration,
) ([]T, error) {
// 尝试从缓存获取
var result []T
if err := cache.Get(ctx, key, &result); err == nil {
return result, nil
}
// 缓存未命中,执行查询
data, err := queryFn()
if err != nil {
return nil, err
}
// 设置缓存(异步,不阻塞)
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_ = cache.Set(cacheCtx, key, data, expiration...)
}()
return data, nil
}
// InvalidateCache 使缓存失效的辅助函数
type CacheInvalidator struct {
cache *CacheManager
}
// NewCacheInvalidator 创建缓存失效器
func NewCacheInvalidator(cache *CacheManager) *CacheInvalidator {
return &CacheInvalidator{cache: cache}
}
// OnCreate 创建时使缓存失效
func (ci *CacheInvalidator) OnCreate(ctx context.Context, keys ...string) {
_ = ci.cache.Delete(ctx, keys...)
}
// OnUpdate 更新时使缓存失效
func (ci *CacheInvalidator) OnUpdate(ctx context.Context, keys ...string) {
_ = ci.cache.Delete(ctx, keys...)
}
// OnDelete 删除时使缓存失效
func (ci *CacheInvalidator) OnDelete(ctx context.Context, keys ...string) {
_ = ci.cache.Delete(ctx, keys...)
}
// BatchInvalidate 批量使缓存失效(支持模式匹配)
func (ci *CacheInvalidator) BatchInvalidate(ctx context.Context, pattern string) {
_ = ci.cache.DeletePattern(ctx, pattern)
}
// CacheKeyBuilder 缓存键构建器
type CacheKeyBuilder struct {
prefix string
}
// NewCacheKeyBuilder 创建缓存键构建器
func NewCacheKeyBuilder(prefix string) *CacheKeyBuilder {
return &CacheKeyBuilder{prefix: prefix}
}
// User 构建用户相关缓存键
func (b *CacheKeyBuilder) User(userID int64) string {
return fmt.Sprintf("%suser:id:%d", b.prefix, userID)
}
// UserByEmail 构建邮箱查询缓存键
func (b *CacheKeyBuilder) UserByEmail(email string) string {
return fmt.Sprintf("%suser:email:%s", b.prefix, email)
}
// UserByUsername 构建用户名查询缓存键
func (b *CacheKeyBuilder) UserByUsername(username string) string {
return fmt.Sprintf("%suser:username:%s", b.prefix, username)
}
// Profile 构建档案缓存键
func (b *CacheKeyBuilder) Profile(uuid string) string {
return fmt.Sprintf("%sprofile:uuid:%s", b.prefix, uuid)
}
// ProfileList 构建用户档案列表缓存键
func (b *CacheKeyBuilder) ProfileList(userID int64) string {
return fmt.Sprintf("%sprofile:user:%d:list", b.prefix, userID)
}
// Texture 构建材质缓存键
func (b *CacheKeyBuilder) Texture(textureID int64) string {
return fmt.Sprintf("%stexture:id:%d", b.prefix, textureID)
}
// TextureList 构建材质列表缓存键
func (b *CacheKeyBuilder) TextureList(userID int64, page int) string {
return fmt.Sprintf("%stexture:user:%d:page:%d", b.prefix, userID, page)
}
// Token 构建令牌缓存键
func (b *CacheKeyBuilder) Token(accessToken string) string {
return fmt.Sprintf("%stoken:%s", b.prefix, accessToken)
}
// UserPattern 用户相关的所有缓存键模式
func (b *CacheKeyBuilder) UserPattern(userID int64) string {
return fmt.Sprintf("%suser:*:%d*", b.prefix, userID)
}
// ProfilePattern 档案相关的所有缓存键模式
func (b *CacheKeyBuilder) ProfilePattern(userID int64) string {
return fmt.Sprintf("%sprofile:*:%d*", b.prefix, userID)
}
// Exists 检查缓存键是否存在
func (cm *CacheManager) Exists(ctx context.Context, key string) (bool, error) {
if !cm.config.Enabled || cm.redis == nil {
return false, nil
}
count, err := cm.redis.Exists(ctx, cm.buildKey(key))
if err != nil {
return false, err
}
return count > 0, nil
}
// TTL 获取缓存键的剩余过期时间
func (cm *CacheManager) TTL(ctx context.Context, key string) (time.Duration, error) {
if !cm.config.Enabled || cm.redis == nil {
return 0, fmt.Errorf("cache not enabled")
}
return cm.redis.TTL(ctx, cm.buildKey(key))
}
// Expire 设置缓存键的过期时间
func (cm *CacheManager) Expire(ctx context.Context, key string, expiration time.Duration) error {
if !cm.config.Enabled || cm.redis == nil {
return nil
}
return cm.redis.Expire(ctx, cm.buildKey(key), expiration)
}
// MGet 批量获取多个缓存
func (cm *CacheManager) MGet(ctx context.Context, keys []string) (map[string]interface{}, error) {
if !cm.config.Enabled || cm.redis == nil {
return nil, fmt.Errorf("cache not enabled")
}
if len(keys) == 0 {
return make(map[string]interface{}), nil
}
// 构建完整的键
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = cm.buildKey(key)
}
// 批量获取
values, err := cm.redis.Client.MGet(ctx, fullKeys...).Result()
if err != nil {
return nil, err
}
// 解析结果
result := make(map[string]interface{})
for i, val := range values {
if val != nil {
result[keys[i]] = val
}
}
return result, nil
}
// MSet 批量设置多个缓存
func (cm *CacheManager) MSet(ctx context.Context, values map[string]interface{}, expiration time.Duration) error {
if !cm.config.Enabled || cm.redis == nil {
return nil
}
if len(values) == 0 {
return nil
}
// 逐个设置Redis MSet 不支持过期时间)
for key, value := range values {
if err := cm.Set(ctx, key, value, expiration); err != nil {
return err
}
}
return nil
}
// Increment 递增缓存值
func (cm *CacheManager) Increment(ctx context.Context, key string) (int64, error) {
if !cm.config.Enabled || cm.redis == nil {
return 0, fmt.Errorf("cache not enabled")
}
return cm.redis.Incr(ctx, cm.buildKey(key))
}
// Decrement 递减缓存值
func (cm *CacheManager) Decrement(ctx context.Context, key string) (int64, error) {
if !cm.config.Enabled || cm.redis == nil {
return 0, fmt.Errorf("cache not enabled")
}
return cm.redis.Decr(ctx, cm.buildKey(key))
}
// IncrementWithExpire 递增并设置过期时间
func (cm *CacheManager) IncrementWithExpire(ctx context.Context, key string, expiration time.Duration) (int64, error) {
if !cm.config.Enabled || cm.redis == nil {
return 0, fmt.Errorf("cache not enabled")
}
fullKey := cm.buildKey(key)
// 递增
val, err := cm.redis.Incr(ctx, fullKey)
if err != nil {
return 0, err
}
// 设置过期时间(如果是新键)
if val == 1 {
_ = cm.redis.Expire(ctx, fullKey, expiration)
}
return val, nil
}

View File

@@ -90,28 +90,10 @@ func AutoMigrate(logger *zap.Logger) error {
&model.CasbinRule{},
}
// 逐个迁移表,以便更好地定位问题
for _, table := range tables {
tableName := fmt.Sprintf("%T", table)
logger.Info("正在迁移表", zap.String("table", tableName))
if err := db.AutoMigrate(table); err != nil {
logger.Error("数据库迁移失败", zap.Error(err), zap.String("table", tableName))
// 如果是 User 表且错误是 insufficient arguments可能是 Properties 字段问题
if tableName == "*model.User" {
logger.Warn("User 表迁移失败,可能是 Properties 字段问题,尝试修复...")
// 尝试手动添加 properties 字段(如果不存在)
if err := db.Exec("ALTER TABLE \"user\" ADD COLUMN IF NOT EXISTS properties jsonb").Error; err != nil {
logger.Error("添加 properties 字段失败", zap.Error(err))
}
// 再次尝试迁移
if err := db.AutoMigrate(table); err != nil {
return fmt.Errorf("数据库迁移失败 (表: %T): %w", table, err)
}
} else {
return fmt.Errorf("数据库迁移失败 (表: %T): %w", table, err)
}
}
logger.Info("表迁移成功", zap.String("table", tableName))
// 批量迁移表
if err := db.AutoMigrate(tables...); err != nil {
logger.Error("数据库迁移失败", zap.Error(err))
return fmt.Errorf("数据库迁移失败: %w", err)
}
logger.Info("数据库迁移完成")

View File

@@ -0,0 +1,155 @@
package database
import (
"context"
"time"
"gorm.io/gorm"
)
// QueryConfig 查询配置
type QueryConfig struct {
Timeout time.Duration // 查询超时时间
Select []string // 只查询指定字段
Preload []string // 预加载关联
}
// WithContext 为查询添加 context 超时控制
func WithContext(ctx context.Context, db *gorm.DB, timeout time.Duration) *gorm.DB {
if timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
// 注意:这里不能 defer cancel(),因为查询可能在函数返回后才执行
// cancel 会在查询完成后自动调用
_ = cancel
}
return db.WithContext(ctx)
}
// SelectOptimized 只查询需要的字段,减少数据传输
func SelectOptimized(db *gorm.DB, fields []string) *gorm.DB {
if len(fields) > 0 {
return db.Select(fields)
}
return db
}
// PreloadOptimized 预加载关联,避免 N+1 查询
func PreloadOptimized(db *gorm.DB, preloads []string) *gorm.DB {
for _, preload := range preloads {
db = db.Preload(preload)
}
return db
}
// FindOne 优化的单条查询
func FindOne[T any](ctx context.Context, db *gorm.DB, cfg QueryConfig, condition interface{}, args ...interface{}) (*T, error) {
var result T
query := WithContext(ctx, db, cfg.Timeout)
query = SelectOptimized(query, cfg.Select)
query = PreloadOptimized(query, cfg.Preload)
err := query.Where(condition, args...).First(&result).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &result, nil
}
// FindMany 优化的多条查询
func FindMany[T any](ctx context.Context, db *gorm.DB, cfg QueryConfig, condition interface{}, args ...interface{}) ([]T, error) {
var results []T
query := WithContext(ctx, db, cfg.Timeout)
query = SelectOptimized(query, cfg.Select)
query = PreloadOptimized(query, cfg.Preload)
err := query.Where(condition, args...).Find(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// BatchFind 批量查询优化,使用 IN 查询
func BatchFind[T any](ctx context.Context, db *gorm.DB, fieldName string, ids []interface{}) ([]T, error) {
if len(ids) == 0 {
return []T{}, nil
}
var results []T
query := WithContext(ctx, db, 5*time.Second)
// 分批查询每次最多1000条避免 IN 子句过长
batchSize := 1000
for i := 0; i < len(ids); i += batchSize {
end := i + batchSize
if end > len(ids) {
end = len(ids)
}
var batch []T
if err := query.Where(fieldName+" IN ?", ids[i:end]).Find(&batch).Error; err != nil {
return nil, err
}
results = append(results, batch...)
}
return results, nil
}
// CountWithTimeout 带超时的计数查询
func CountWithTimeout(ctx context.Context, db *gorm.DB, model interface{}, timeout time.Duration) (int64, error) {
var count int64
query := WithContext(ctx, db, timeout)
err := query.Model(model).Count(&count).Error
return count, err
}
// ExistsOptimized 优化的存在性检查
func ExistsOptimized(ctx context.Context, db *gorm.DB, model interface{}, condition interface{}, args ...interface{}) (bool, error) {
var count int64
query := WithContext(ctx, db, 3*time.Second)
// 使用 SELECT 1 优化,不需要查询所有字段
err := query.Model(model).Select("1").Where(condition, args...).Limit(1).Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// UpdateOptimized 优化的更新操作
func UpdateOptimized(ctx context.Context, db *gorm.DB, model interface{}, updates map[string]interface{}) error {
query := WithContext(ctx, db, 3*time.Second)
return query.Model(model).Updates(updates).Error
}
// BulkInsert 批量插入优化
func BulkInsert[T any](ctx context.Context, db *gorm.DB, records []T, batchSize int) error {
if len(records) == 0 {
return nil
}
query := WithContext(ctx, db, 10*time.Second)
// 使用 CreateInBatches 分批插入
if batchSize <= 0 {
batchSize = 100
}
return query.CreateInBatches(records, batchSize).Error
}
// TransactionWithTimeout 带超时的事务
func TransactionWithTimeout(ctx context.Context, db *gorm.DB, timeout time.Duration, fn func(*gorm.DB) error) error {
query := WithContext(ctx, db, timeout)
return query.Transaction(fn)
}

View File

@@ -2,9 +2,12 @@ package database
import (
"fmt"
"log"
"os"
"time"
"carrotskin/pkg/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
@@ -22,19 +25,23 @@ func New(cfg config.DatabaseConfig) (*gorm.DB, error) {
cfg.Timezone,
)
// 配置GORM日志级别
var gormLogLevel logger.LogLevel
switch {
case cfg.Driver == "postgres":
gormLogLevel = logger.Info
default:
gormLogLevel = logger.Silent
}
// 配置慢查询监控
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: 200 * time.Millisecond, // 慢查询阈值200ms
LogLevel: logger.Warn, // 只记录警告和错误
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
Colorful: false, // 生产环境禁用彩色
},
)
// 打开数据库连接
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(gormLogLevel),
DisableForeignKeyConstraintWhenMigrating: true, // 禁用自动创建外键约束,避免循环依赖问题
Logger: newLogger,
DisableForeignKeyConstraintWhenMigrating: true, // 禁用外键约束
PrepareStmt: true, // 启用预编译语句缓存
QueryFields: true, // 明确指定查询字段
})
if err != nil {
return nil, fmt.Errorf("连接PostgreSQL数据库失败: %w", err)
@@ -46,10 +53,26 @@ func New(cfg config.DatabaseConfig) (*gorm.DB, error) {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
// 配置连接池
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)
// 优化连接池配置
maxIdleConns := cfg.MaxIdleConns
if maxIdleConns <= 0 {
maxIdleConns = 10
}
maxOpenConns := cfg.MaxOpenConns
if maxOpenConns <= 0 {
maxOpenConns = 100
}
connMaxLifetime := cfg.ConnMaxLifetime
if connMaxLifetime <= 0 {
connMaxLifetime = 1 * time.Hour
}
sqlDB.SetMaxIdleConns(maxIdleConns)
sqlDB.SetMaxOpenConns(maxOpenConns)
sqlDB.SetConnMaxLifetime(connMaxLifetime)
sqlDB.SetConnMaxIdleTime(10 * time.Minute)
// 测试连接
if err := sqlDB.Ping(); err != nil {