SKILL.md
--- frontmatter
name: lightbox
description: Use this skill when working with LightningReflow, LightningTune, or DataPorter integration. Covers training commands, HPO patterns, torch.compile configuration (memory leak avoidance), common pitfalls, and debugging tips for Lightning-based ML projects.

LightBox Integration Guide

This skill provides guidance for working with projects generated by LightBox, which integrates:

  • LightningReflow - Pause/resume training with deterministic checkpoints
  • LightningTune - Hyperparameter optimization with Optuna
  • DataPorter - Multi-dataset loading with caching (optional)

Quick Reference

Training Commands

bash
# Standard training
python scripts/train_{{cookiecutter.model_name}}.py fit --config configs/{{cookiecutter.model_name}}.yaml

# Override settings
python scripts/train_{{cookiecutter.model_name}}.py fit --config configs/{{cookiecutter.model_name}}.yaml --trainer.max_epochs 200

# Resume from pause checkpoint
python scripts/train_{{cookiecutter.model_name}}.py resume --checkpoint-path pause_checkpoints/<file>.ckpt

# Resume from regular checkpoint
python scripts/train_{{cookiecutter.model_name}}.py fit --config configs/{{cookiecutter.model_name}}.yaml --ckpt_path <path>

HPO Commands

bash
# Run HPO
python scripts/{{cookiecutter.model_name}}_hpo.py --config configs/{{cookiecutter.model_name}}.yaml --n-trials 100

# Resume HPO
python scripts/{{cookiecutter.model_name}}_hpo.py --config configs/{{cookiecutter.model_name}}.yaml --resume-from tmp/hpo_checkpoints

# Test mode (quick validation)
python scripts/{{cookiecutter.model_name}}_hpo.py --config configs/{{cookiecutter.model_name}}.yaml --n-trials 3 --trial-steps 100 --test-mode

LightningReflow Integration

LightningReflowCLI Setup

The training script uses LightningReflowCLI for pause/resume support:

python
from lightning_reflow import LightningReflowCLI

cli = LightningReflowCLI(
    BaseModel,
    BaseDataModule,
    auto_configure_optimizers=False,  # Model handles optimizer config
    seed_everything_default=42,
    subclass_mode_model=True,  # Allow nested config instantiation
    subclass_mode_data=True,
    run=True,
)

TorchCompileCallback Configuration

CRITICAL: Choose the TorchCompileCallback mode carefully for models with dynamic shapes (e.g., autoregressive rollouts, transformers with variable sequence lengths). The wrong mode can leak GPU memory.

Safe Configuration (Recommended)

yaml
trainer:
  callbacks:
    - class_path: lightning_reflow.callbacks.TorchCompileCallback
      init_args:
        enabled: true
        mode: "reduce-overhead"  # Safe for dynamic shapes
        dynamic: true
        inductor_config:
          triton.cudagraph_skip_dynamic_graphs: true
        target_modules:
          - "dynamics"  # or "backbone"
          - "decoders"
        verbose: true

Known Issue: Memory Leak with max-autotune

DO NOT use mode: "max-autotune" with modules that have dynamic shapes (transformers, AR rollouts):

yaml
# DANGEROUS - causes unbounded VRAM growth!
mode: "max-autotune"
target_modules:
  - "dynamics"  # Transformer with causal mask = dynamic shapes

Why: With max-autotune, Triton keeps searching for optimal kernels every time the input shapes change, so memory grows by roughly 100 MB per epoch until the process runs out of memory.

Solution: Use mode: "reduce-overhead", which keeps memory stable and still provides full compilation coverage.

When max-autotune IS Safe

  • Static shape modules only (e.g., CNN decoders with fixed output resolution)
  • Modules that never see varying input shapes

WandB Artifact Checkpoints

yaml
trainer:
  callbacks:
    - class_path: lightning_reflow.callbacks.WandbArtifactCheckpoint
      init_args:
        upload_best_model: true
        upload_last_model: true
        use_compression: true

LightningTune HPO Integration

Search Space Definition

CRITICAL: Always call copy.deepcopy(config) at the start of search_space so that one trial's changes cannot leak into the next:

python
import copy

def search_space(trial, config: dict) -> dict:
    """Define hyperparameter search space."""
    config = copy.deepcopy(config)  # CRITICAL: Prevent trial pollution

    # Suggest hyperparameters
    config["model"]["init_args"]["learning_rate"] = trial.suggest_float(
        "learning_rate", 1e-5, 1e-3, log=True
    )
    config["model"]["init_args"]["weight_decay"] = trial.suggest_float(
        "weight_decay", 1e-6, 1e-3, log=True
    )

    return config

HPO Config Constants

Define excluded parameters and production overrides in hpo/config.py:

python
# Parameters to exclude from CLI (set via search_space)
EXCLUDED_CLI_PARAMS = [
    "model.init_args.learning_rate",
    "model.init_args.weight_decay",
]

# Production config overrides (longer training, etc.)
PRODUCTION_CONFIG_OVERRIDES = {
    "trainer.max_epochs": 200,
    "trainer.check_val_every_n_epoch": 10,
}
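
This guide does not show how these constants are consumed. As a rough sketch, assuming the dotted keys in PRODUCTION_CONFIG_OVERRIDES address nested entries of the config dict, a helper along these lines could apply them. The name apply_overrides is hypothetical and is not part of LightningTune:

```python
import copy
from typing import Any


def apply_overrides(config: dict, overrides: dict[str, Any]) -> dict:
    """Return a copy of config with dotted-key overrides applied (hypothetical helper)."""
    result = copy.deepcopy(config)  # never mutate the caller's config
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = result
        for key in parents:
            node = node.setdefault(key, {})  # create missing intermediate dicts
        node[leaf] = value
    return result


base = {"trainer": {"max_epochs": 10}}
production = apply_overrides(
    base,
    {"trainer.max_epochs": 200, "trainer.check_val_every_n_epoch": 10},
)
```

Returning a copy keeps the HPO config and the production config independent, for the same reason search_space deep-copies its input.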

DataPorter Integration

MultiDatasetModule Setup

For projects using DataPorter's MultiDatasetModule:

yaml
data:
  class_path: dataporter.MultiDatasetModule
  init_args:
    batch_size: 256
    num_workers: 4
    context_length: 8  # For temporal models
    ar_steps: 0        # Autoregressive steps (0 for perception)
    resolution:
      - 160
      - 160
    datasets:
      - tartanair
      - nyu_depth
    dataset_weights: null  # Uniform if null
    modality_dropout:
      rgb: 0.0   # Required modality
      depth: 0.1  # 10% dropout
    seed: 42
    pin_memory: true
    enable_cache: true  # Disk caching for faster loading

Caching Configuration

DataPorter reads the cache location from environment variables. The first one that is set wins, in this order of precedence:

  1. HF_DATASETS_CACHE
  2. HF_HOME
  3. XDG_CACHE_HOME
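
The precedence rule can be sketched as a lookup that stops at the first variable that is set. This is an illustration only: first_cache_variable is a hypothetical name, and how DataPorter turns the value into a final directory (for example, any subdirectory it appends) is not covered here.

```python
import os
from typing import Mapping, Optional, Tuple

# Highest precedence first, matching the list above.
CACHE_ENV_VARS = ("HF_DATASETS_CACHE", "HF_HOME", "XDG_CACHE_HOME")


def first_cache_variable(env: Mapping[str, str] = os.environ) -> Optional[Tuple[str, str]]:
    """Return (name, value) of the highest-precedence cache variable that is set."""
    for name in CACHE_ENV_VARS:
        value = env.get(name)
        if value:  # treat unset and empty variables the same way
            return name, value
    return None
```

Calling first_cache_variable() before a long run is a quick way to confirm which variable is actually in effect on a given machine.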

Pre-populate cache with:

bash
python scripts/preprocess_tartanair.py

Common Pitfalls and Solutions

1. OneCycleLR Scheduler Errors

Problem: ZeroDivisionError with small max_steps values.

Cause: OneCycleLR ends its warm-up phase at step pct_start * total_steps - 1, and that value must be > 0 for the phase calculations to be valid. With PyTorch's default pct_start of 0.3, this means total_steps >= 4.

Solution: Ensure realistic max_steps in configs and tests:

yaml
trainer:
  max_steps: 25000  # Not 2 or 3

For tests, use at least max_steps=5.
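
The arithmetic behind that limit can be checked directly. The sketch below assumes pct_start = 0.3, which is PyTorch's default; the project's scheduler may be configured with a different value, in which case the threshold shifts. Both function names are illustrative.

```python
def warmup_end_step(total_steps: int, pct_start: float = 0.3) -> float:
    """Step index at which the warm-up phase ends: pct_start * total_steps - 1."""
    return pct_start * total_steps - 1


def has_valid_warmup(total_steps: int, pct_start: float = 0.3) -> bool:
    """True when the warm-up phase ends after step 0, as the scheduler needs."""
    return warmup_end_step(total_steps, pct_start) > 0


for steps in (2, 3, 4, 5, 25000):
    print(steps, warmup_end_step(steps), has_valid_warmup(steps))
```

With these assumptions, 2 and 3 steps fail the check while 4 and above pass, which is why max_steps=5 leaves a small safety margin in tests.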

2. CUDA Graphs with Dynamic Shapes

Problem: torch.compile with CUDA graphs crashes or behaves incorrectly with dynamic shapes.

Solution: Disable CUDA graphs for dynamic modules:

yaml
dynamic: true
inductor_config:
  triton.cudagraph_skip_dynamic_graphs: true

3. Sanity Check Phase Issues

Problem: Callbacks that log to WandB fail during Lightning's sanity check.

Cause: WandB may not be fully initialized during sanity check.

Solution: Guard callbacks:

python
def on_validation_epoch_end(self, trainer, pl_module):
    if trainer.sanity_checking:
        return  # Skip during sanity check
    # ... rest of callback

4. AMP Compatibility

Problem: Some modules fail with mixed precision (16-mixed).

Common culprits:

  • Pose decoders with small output ranges
  • Custom loss functions with numerical instability

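Solution: Disable autocast locally around the problematic module, and cast its inputs to float32 because tensors produced under autocast may still be half precision:

python
# Run the pose decoder in full precision inside an otherwise mixed-precision step
with torch.amp.autocast(device_type='cuda', enabled=False):
    pose_output = self.pose_decoder(features.float())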

5. Gradient Accumulation with Validation

Problem: Validation loss spikes when using accumulate_grad_batches > 1.

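Solution: Configure accumulation only through the Trainer option below. It applies to the training loop, so validation should not add any accumulation of its own: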

yaml
trainer:
  accumulate_grad_batches: 2  # Training only

6. HPO Trial Pollution

Problem: Trials affect each other's configs.

Cause: Missing copy.deepcopy() in search_space function.

Solution: Always deepcopy at function start (see HPO section above).
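
The failure mode is easy to reproduce without Optuna. The sketch below uses a made-up config dict to show that a shallow copy still shares the nested init_args dict, while copy.deepcopy does not:

```python
import copy

base_config = {"model": {"init_args": {"learning_rate": 1e-3}}}

# A shallow copy duplicates only the top-level dict; nested dicts are shared.
shallow = dict(base_config)
shallow["model"]["init_args"]["learning_rate"] = 1e-5
leaked = base_config["model"]["init_args"]["learning_rate"]  # the base config changed too

# Reset, then repeat with a deep copy, which duplicates every nested container.
base_config["model"]["init_args"]["learning_rate"] = 1e-3
deep = copy.deepcopy(base_config)
deep["model"]["init_args"]["learning_rate"] = 1e-5
preserved = base_config["model"]["init_args"]["learning_rate"]  # the base config is untouched

print(leaked, preserved)
```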


Config Patterns

Nested Class Instantiation

LightningCLI uses class_path and init_args for dependency injection:

yaml
model:
  class_path: mypackage.models.MyModel
  init_args:
    encoder:
      class_path: mypackage.encoders.ConvEncoder
      init_args:
        hidden_dim: 256
    decoder:
      class_path: mypackage.decoders.ConvDecoder
      init_args:
        output_channels: 3
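
To make the class_path / init_args convention concrete, here is a deliberately simplified sketch of recursive instantiation. It is not how LightningCLI works internally (that is handled by jsonargparse, with type checking and much more); instantiate is a hypothetical helper, and types.SimpleNamespace stands in for real model classes so the example runs without the project installed.

```python
import importlib
from typing import Any


def instantiate(spec: Any) -> Any:
    """Recursively build objects from {"class_path": ..., "init_args": ...} dicts."""
    if not (isinstance(spec, dict) and "class_path" in spec):
        return spec  # plain value: pass through unchanged
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    # Nested specs are built first, then injected as constructor arguments.
    kwargs = {key: instantiate(value) for key, value in spec.get("init_args", {}).items()}
    return cls(**kwargs)


model = instantiate({
    "class_path": "types.SimpleNamespace",
    "init_args": {
        "encoder": {
            "class_path": "types.SimpleNamespace",
            "init_args": {"hidden_dim": 256},
        },
    },
})
print(model.encoder.hidden_dim)
```

The inner object is constructed before the outer one, which is what lets a model receive a ready-made encoder or decoder as a constructor argument.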

Callback Configuration

yaml
trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: "val/loss"
        mode: "min"
        save_top_k: 3
        save_last: true
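        filename: "model-epoch{epoch:04d}-val_loss{val/loss:.4f}"
        auto_insert_metric_name: false  # Keeps "val/loss" from being inserted into the filename as an extra folder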

    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"

Logger Configuration

yaml
trainer:
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: "my-project"
      save_dir: "tmp/wandb_logs"
      name: "experiment-name"  # Optional run name

Visualization Callbacks

Why Log Training vs Validation Inputs?

  • Validation inputs are deterministic - same samples every time, not useful for repeated logging
  • Training inputs show distribution - different samples each time due to shuffling/augmentation
  • Log both to understand what the model sees

Validation-Based Logging Schedule

Use a validation counter instead of epoch-based logging when using step-based validation:

python
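import lightning as L
import torch


# _detach_batch, _log_training_inputs and _log_reconstruction_comparison
# are helper methods that are not shown in this excerpt.
class BaseVisualizerCallback(L.Callback):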
    def __init__(self, log_every_n_validations: int = 5):
        self.log_every_n_validations = log_every_n_validations
        self._validation_count = 0
        self._last_train_batch = None
        self._last_batch = None

    def _should_log_this_validation(self, trainer) -> bool:
        return self._validation_count % self.log_every_n_validations == 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Capture first training batch of each epoch
        if batch_idx == 0:
            self._last_train_batch = self._detach_batch(batch)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # Capture first validation batch
        if batch_idx == 0:
            self._last_batch = self._detach_batch(batch)

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip during sanity check - wandb may not be initialized
        if trainer.sanity_checking:
            self._last_batch = None
            return

        should_log = self._should_log_this_validation(trainer)
        self._validation_count += 1

        if not should_log or self._last_batch is None:
            self._last_batch = None
            self._last_train_batch = None
            return

        try:
            # Log training inputs (shows distribution)
            if self._last_train_batch is not None:
                self._log_training_inputs(self._last_train_batch)
            # Log validation inputs + reconstructions
            self._log_reconstruction_comparison(pl_module, self._last_batch)
        finally:
            self._last_batch = None
            self._last_train_batch = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
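
The counter logic can be checked in isolation. This small simulation mirrors the order of operations in on_validation_epoch_end (test the counter, then increment it) and ignores sanity-check runs, which the callback skips before touching the counter. The function name is illustrative.

```python
def logged_validations(num_validations: int, log_every_n_validations: int = 5) -> list[int]:
    """Return the 0-based indices of validation runs that would be logged."""
    logged = []
    validation_count = 0
    for _ in range(num_validations):
        should_log = validation_count % log_every_n_validations == 0  # same test as the callback
        validation_count += 1  # incremented after the check, as in the callback
        if should_log:
            logged.append(validation_count - 1)
    return logged


print(logged_validations(12))
```

Because the check happens before the increment, the very first real validation run is always logged, followed by every fifth run after it.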

Image Grid Utility

python
import torch
from torch import Tensor


def create_image_grid(images: Tensor, nrow: int = 4, padding: int = 2) -> Tensor:
    """Create grid from batch of images [B, C, H, W] -> [C, H', W'].

    nrow is the number of images per row, so the grid has ceil(B / nrow) rows.
    """
    B, C, H, W = images.shape
    n_rows = (B + nrow - 1) // nrow  # ceil(B / nrow)

    grid_h = n_rows * H + (n_rows + 1) * padding
    grid_w = nrow * W + (nrow + 1) * padding
    # Allocate on the input's device and dtype to avoid cross-device copies
    grid = torch.zeros(C, grid_h, grid_w, device=images.device, dtype=images.dtype)

    for idx in range(B):
        row, col = idx // nrow, idx % nrow
        y = padding + row * (H + padding)
        x = padding + col * (W + padding)
        grid[:, y:y+H, x:x+W] = images[idx].clamp(0, 1)

    return grid

Media Logging Keys Convention

code
media/train_inputs/rgb_samples     # Training data (changes each log)
media/train_inputs/depth_samples   # Training depth maps
media/val_inputs/rgb_samples       # Validation data (deterministic)
media/val_reconstruction/rgb_video # GT vs prediction video
media/val_reconstruction/rgb_grid  # GT vs prediction grid comparison

Testing Patterns

Mock Trial for HPO Tests

python
from tests.helpers.hpo_utils import MockTrial

def test_search_space():
    config = load_config("configs/model.yaml")
    trial = MockTrial({
        "learning_rate": 1e-4,
        "weight_decay": 1e-5,
    })

    modified = search_space(trial, config)
    assert modified["model"]["init_args"]["learning_rate"] == 1e-4
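
The project's tests.helpers.hpo_utils.MockTrial is not shown in this guide. As an assumption about what such a helper might look like, the sketch below returns preset values through the same suggest_* method signatures that Optuna trials expose, and rejects values outside the requested range so a test also catches a badly chosen preset:

```python
class MockTrial:
    """Stand-in for an Optuna trial that returns preset values (illustrative sketch)."""

    def __init__(self, params: dict):
        self.params = dict(params)

    def _lookup(self, name: str):
        if name not in self.params:
            raise KeyError(f"MockTrial has no preset value for {name!r}")
        return self.params[name]

    def suggest_float(self, name, low, high, *, step=None, log=False):
        value = self._lookup(name)
        if not low <= value <= high:
            raise ValueError(f"{name}={value} is outside [{low}, {high}]")
        return value

    def suggest_int(self, name, low, high, *, step=1, log=False):
        return int(self.suggest_float(name, low, high))

    def suggest_categorical(self, name, choices):
        value = self._lookup(name)
        if value not in choices:
            raise ValueError(f"{name}={value!r} is not one of {list(choices)}")
        return value
```

Raising on a missing name makes a test fail loudly when search_space starts suggesting a hyperparameter that the test did not anticipate.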

GPU Test Markers

python
import pytest

@pytest.mark.gpu
def test_training_on_gpu():
    ...

@pytest.mark.slow
def test_full_training_run():
    ...

Run specific markers:

bash
pytest -m gpu      # GPU tests only
pytest -m "not slow"  # Skip slow tests
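
Custom markers such as gpu and slow need to be registered, otherwise pytest warns about unknown marks (and fails under --strict-markers). Where this project registers them is not stated here; one common option, sketched below as an assumption, is a pytest_configure hook in conftest.py:

```python
# conftest.py (assumed location; markers can also be declared in pytest's ini configuration)
def pytest_configure(config):
    """Register the custom markers used by the test suite."""
    config.addinivalue_line("markers", "gpu: test requires a CUDA device")
    config.addinivalue_line("markers", "slow: long-running test")
```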

Debugging Tips

VRAM Monitoring

bash
# Watch GPU memory every 2 seconds
watch -n 2 nvidia-smi

# Log to CSV for analysis
nvidia-smi --query-gpu=timestamp,utilization.gpu,memory.used,memory.total --format=csv -l 5 > vram_log.csv
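
To turn the CSV log into a leak signal, the memory.used column can be parsed and its average growth computed. This is a minimal sketch with hypothetical function names. It assumes a single GPU and the column headers and "MiB" suffix that nvidia-smi writes with --format=csv; on a multi-GPU machine the rows would first need to be split per device.

```python
import csv
from typing import Iterable


def memory_used_mib(lines: Iterable[str]) -> list[float]:
    """Extract the memory.used column (in MiB) from nvidia-smi CSV output."""
    reader = csv.reader(lines, skipinitialspace=True)
    header = next(reader)
    column = next(i for i, name in enumerate(header) if name.startswith("memory.used"))
    # Values look like "1234 MiB"; keep only the numeric part.
    return [float(row[column].split()[0]) for row in reader if row]


def growth_per_sample(values: list[float]) -> float:
    """Average change in MiB between consecutive samples (0.0 if fewer than two)."""
    if len(values) < 2:
        return 0.0
    return (values[-1] - values[0]) / (len(values) - 1)
```

Open vram_log.csv and pass the file object to memory_used_mib. A growth value that stays clearly positive over many epochs points to the kind of unbounded VRAM growth described in the max-autotune section.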

Gradient Debugging

Log per-module gradient norms:

python
def on_before_optimizer_step(self, optimizer):
    for name, module in self.named_modules():
        grad_norm = self._compute_grad_norm(module)
        if grad_norm is not None:
            self.log(f"grad_norm/{name}", grad_norm)
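
The _compute_grad_norm helper is not defined in this guide. One possible implementation, offered as an assumption rather than the project's actual code, takes the L2 norm over all gradients of a module and returns None when nothing has a gradient yet:

```python
from typing import Optional

import torch
from torch import nn


def compute_grad_norm(module: nn.Module) -> Optional[torch.Tensor]:
    """L2 norm over all gradients in module, or None if no parameter has a gradient."""
    grads = [p.grad.detach().flatten() for p in module.parameters() if p.grad is not None]
    if not grads:
        return None  # frozen module, or backward() has not run yet
    return torch.linalg.vector_norm(torch.cat(grads))
```

Note that module.parameters() recurses into submodules, so the norm logged for a parent module includes the gradients of all its children.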

Prediction Statistics

Log prediction statistics to detect mode collapse:

python
def _log_prediction_diagnostics(self, predictions, prefix="train"):
    for name, pred in predictions.items():
        self.log(f"{prefix}/{name}_pred_mean", pred.mean())
        self.log(f"{prefix}/{name}_pred_std", pred.std())
        # Batch variance - low = mode collapse
        batch_var = pred.mean(dim=list(range(1, pred.ndim))).var()
        self.log(f"{prefix}/{name}_pred_batch_var", batch_var)