package database import ( "context" "time" "gorm.io/gorm" ) // QueryConfig 查询配置 type QueryConfig struct { Timeout time.Duration // 查询超时时间 Select []string // 只查询指定字段 Preload []string // 预加载关联 } // WithContext 为查询添加 context 超时控制 func WithContext(ctx context.Context, db *gorm.DB, timeout time.Duration) *gorm.DB { if timeout > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, timeout) // 注意:这里不能 defer cancel(),因为查询可能在函数返回后才执行 // cancel 会在查询完成后自动调用 _ = cancel } return db.WithContext(ctx) } // SelectOptimized 只查询需要的字段,减少数据传输 func SelectOptimized(db *gorm.DB, fields []string) *gorm.DB { if len(fields) > 0 { return db.Select(fields) } return db } // PreloadOptimized 预加载关联,避免 N+1 查询 func PreloadOptimized(db *gorm.DB, preloads []string) *gorm.DB { for _, preload := range preloads { db = db.Preload(preload) } return db } // FindOne 优化的单条查询 func FindOne[T any](ctx context.Context, db *gorm.DB, cfg QueryConfig, condition interface{}, args ...interface{}) (*T, error) { var result T query := WithContext(ctx, db, cfg.Timeout) query = SelectOptimized(query, cfg.Select) query = PreloadOptimized(query, cfg.Preload) err := query.Where(condition, args...).First(&result).Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, nil } return nil, err } return &result, nil } // FindMany 优化的多条查询 func FindMany[T any](ctx context.Context, db *gorm.DB, cfg QueryConfig, condition interface{}, args ...interface{}) ([]T, error) { var results []T query := WithContext(ctx, db, cfg.Timeout) query = SelectOptimized(query, cfg.Select) query = PreloadOptimized(query, cfg.Preload) err := query.Where(condition, args...).Find(&results).Error if err != nil { return nil, err } return results, nil } // BatchFind 批量查询优化,使用 IN 查询 func BatchFind[T any](ctx context.Context, db *gorm.DB, fieldName string, ids []interface{}) ([]T, error) { if len(ids) == 0 { return []T{}, nil } var results []T query := WithContext(ctx, db, 5*time.Second) // 分批查询,每次最多1000条,避免 IN 子句过长 batchSize := 1000 for i := 0; i < len(ids); i += batchSize { end := i + batchSize if end > len(ids) { end = len(ids) } var batch []T if err := query.Where(fieldName+" IN ?", ids[i:end]).Find(&batch).Error; err != nil { return nil, err } results = append(results, batch...) } return results, nil } // CountWithTimeout 带超时的计数查询 func CountWithTimeout(ctx context.Context, db *gorm.DB, model interface{}, timeout time.Duration) (int64, error) { var count int64 query := WithContext(ctx, db, timeout) err := query.Model(model).Count(&count).Error return count, err } // ExistsOptimized 优化的存在性检查 func ExistsOptimized(ctx context.Context, db *gorm.DB, model interface{}, condition interface{}, args ...interface{}) (bool, error) { var count int64 query := WithContext(ctx, db, 3*time.Second) // 使用 SELECT 1 优化,不需要查询所有字段 err := query.Model(model).Select("1").Where(condition, args...).Limit(1).Count(&count).Error if err != nil { return false, err } return count > 0, nil } // UpdateOptimized 优化的更新操作 func UpdateOptimized(ctx context.Context, db *gorm.DB, model interface{}, updates map[string]interface{}) error { query := WithContext(ctx, db, 3*time.Second) return query.Model(model).Updates(updates).Error } // BulkInsert 批量插入优化 func BulkInsert[T any](ctx context.Context, db *gorm.DB, records []T, batchSize int) error { if len(records) == 0 { return nil } query := WithContext(ctx, db, 10*time.Second) // 使用 CreateInBatches 分批插入 if batchSize <= 0 { batchSize = 100 } return query.CreateInBatches(records, batchSize).Error } // TransactionWithTimeout 带超时的事务 func TransactionWithTimeout(ctx context.Context, db *gorm.DB, timeout time.Duration, fn func(*gorm.DB) error) error { query := WithContext(ctx, db, timeout) return query.Transaction(fn) }