
Training Debug

Debug out-of-memory (OOM) errors, NaN losses, and the various distributed-training problems that occur during training.

SKILL.md
---
description: Debug training issues including OOM, NaN losses, and distributed training problems
---

Training Debug

Diagnose and fix common training issues in training_hub.

Out of Memory (OOM) Issues

Symptoms

code
CUDA out of memory. Tried to allocate X GiB
RuntimeError: CUDA error: out of memory

Diagnosis Steps

  1. Check GPU memory usage
bash
nvidia-smi -l 1  # Monitor GPU memory in real-time
  2. Estimate memory requirements
python
from training_hub.utils import estimate_memory

estimate_memory(
    model_path="meta-llama/Llama-3.1-8B",
    batch_size=4,
    seq_length=4096,
    precision="bf16",
)
# Output: Estimated memory: 24.5 GB per GPU
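How `estimate_memory` accounts for memory is not spelled out here. As a rough cross-check, you can work out the dominant fixed cost of full fine-tuning with AdamW by hand. The sketch below is not training_hub's implementation. `rough_model_state_gib` is a hypothetical helper, and the 16 bytes/parameter figure assumes mixed-precision AdamW with fp32 master weights.

```python
def rough_model_state_gib(num_params: float, num_gpus: int = 1,
                          bytes_per_param: int = 16) -> float:
    """Per-GPU memory for weights, gradients and AdamW states, in GiB.

    Assumes mixed-precision AdamW: 2 B bf16 weights + 2 B bf16 gradients
    + 4 B fp32 master weights + 8 B fp32 Adam moments = 16 B per parameter.
    With num_gpus > 1 it assumes all of that state is fully sharded (FULL_SHARD).
    Activations, CUDA context and allocator fragmentation are NOT included.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    return num_params * bytes_per_param / num_gpus / 2**30


# 8B-parameter full fine-tune: unsharded vs. sharded across 8 GPUs
for gpus in (1, 8):
    print(f"{gpus} GPU(s): {rough_model_state_gib(8e9, gpus):.1f} GiB model state per GPU")
```

Whatever memory remains after this fixed cost is available for activations. Activations grow with batch size and sequence length, which is why the first solutions below target those two settings.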

Solutions

Reduce batch size

python
sft(
    per_device_batch_size=2,  # Reduce from 4
    gradient_accumulation_steps=16,  # Increase to maintain effective batch
)
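The comment about maintaining the effective batch relies on the usual relation: effective batch = per-device batch × accumulation steps × number of GPUs. The sketch below assumes training_hub follows that convention. `accumulation_steps_for` is a hypothetical helper that computes how many accumulation steps you need after lowering the per-device batch.

```python
def accumulation_steps_for(target_effective_batch: int, per_device_batch_size: int,
                           world_size: int = 1) -> int:
    """Gradient-accumulation steps that keep the effective batch size fixed.

    effective batch = per_device_batch_size * gradient_accumulation_steps * world_size
    """
    per_step = per_device_batch_size * world_size
    if target_effective_batch % per_step != 0:
        raise ValueError(
            f"{target_effective_batch} is not divisible by {per_step}; "
            "choose a per-device batch size that divides it"
        )
    return target_effective_batch // per_step


# On one GPU with an effective batch of 32: halving the per-device batch doubles accumulation
print(accumulation_steps_for(32, per_device_batch_size=4))
print(accumulation_steps_for(32, per_device_batch_size=2))
```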

Enable gradient checkpointing

python
sft(
    gradient_checkpointing=True,  # Recomputes activations in the backward pass: less memory, more compute
)

Use memory-efficient optimizers

python
sft(
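    optim="adamw_8bit",  # 8-bit Adam states: ~75% less optimizer-state memory than fp32 AdamW
    # OR (pass only one `optim`; repeating the keyword argument is a SyntaxError)
    # optim="paged_adamw_32bit",  # Pages optimizer states to CPU RAM when GPU memory runs short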
)
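The savings from 8-bit Adam come from storing AdamW's two per-parameter moment estimates in 1 byte each instead of 4. A paged optimizer does not shrink that state; it moves it off the GPU under memory pressure. The arithmetic below is a sketch. It ignores the small per-block quantization constants that 8-bit optimizers also keep, and `adam_state_gib` is a hypothetical helper.

```python
def adam_state_gib(num_params: float, bytes_per_state: float) -> float:
    """Memory for AdamW's two per-parameter states (first and second moments), in GiB."""
    return 2 * num_params * bytes_per_state / 2**30


n = 8e9  # e.g. an 8B-parameter model
fp32_states = adam_state_gib(n, 4)  # standard AdamW: both moments in fp32
int8_states = adam_state_gib(n, 1)  # 8-bit AdamW: 1 byte per moment (block-scale overhead ignored)
print(f"fp32 states: {fp32_states:.1f} GiB, 8-bit states: {int8_states:.1f} GiB, "
      f"saving {1 - int8_states / fp32_states:.0%}")
```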

Reduce sequence length

python
sft(
    max_seq_length=2048,  # Reduce from 4096
)

Use FSDP for large models

python
sft(
    fsdp=True,
    fsdp_config={"sharding_strategy": "FULL_SHARD"},
)

NaN/Inf Loss Issues

Symptoms

code
Loss: nan
Loss: inf
Training loss spiked to very large values

Diagnosis

python
# Enable anomaly detection (slows training considerably; use only while debugging)
import torch
torch.autograd.set_detect_anomaly(True)

# Check for NaN in inputs
from training_hub.utils import check_data_quality

issues = check_data_quality("./data/train.jsonl")
print(issues)  # Reports NaN, inf, empty examples
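`check_data_quality` comes from training_hub, and its exact checks are not listed here. For a dependency-free first pass, the sketch below performs similar checks using only the standard library. `scan_jsonl` is hypothetical and assumes the chat-style `messages`/`content` layout used elsewhere in this guide. Python's `json` module accepts the non-standard `NaN` and `Infinity` tokens by default, so such values load as floats and can be flagged.

```python
import json
import math
from typing import Any, Iterable, List, Tuple


def _has_non_finite(value: Any) -> bool:
    """True if any float nested inside value is NaN or +/-inf."""
    if isinstance(value, float):
        return not math.isfinite(value)
    if isinstance(value, dict):
        return any(_has_non_finite(v) for v in value.values())
    if isinstance(value, list):
        return any(_has_non_finite(v) for v in value)
    return False


def _bad_message(m: Any) -> bool:
    """A message is bad if it is not a dict or its content is not a non-empty string."""
    if not isinstance(m, dict):
        return True
    content = m.get("content")
    return not isinstance(content, str) or not content.strip()


def scan_jsonl(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Return (row_index, problem) pairs for malformed, empty or non-finite rows."""
    problems = []
    for i, line in enumerate(lines):
        if not line.strip():
            problems.append((i, "blank line"))
            continue
        try:
            row = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((i, f"invalid JSON: {exc.msg}"))
            continue
        messages = row.get("messages") if isinstance(row, dict) else None
        if not messages:
            problems.append((i, "missing or empty 'messages'"))
        elif any(_bad_message(m) for m in messages):
            problems.append((i, "message with missing, non-string or empty content"))
        if _has_non_finite(row):
            problems.append((i, "NaN or inf value"))
    return problems


sample = [
    '{"messages": [{"role": "user", "content": "hi"}]}',
    '{"messages": []}',
    '{"messages": [{"role": "user", "content": null}]}',
    '{"messages": [{"role": "user", "content": "x"}], "weight": NaN}',
    '{"messages": [',
]
for index, problem in scan_jsonl(sample):
    print(f"Row {index}: {problem}")
```

To scan a real file, pass the open file object, e.g. `with open("./data/train.jsonl") as f: print(scan_jsonl(f))`.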

Solutions

Fix learning rate

python
sft(
    learning_rate=1e-5,  # Lower from 2e-5
    warmup_ratio=0.1,     # Ensure warmup
)
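`warmup_ratio` sets the fraction of training during which the learning rate ramps up from zero. This avoids applying the full rate before the optimizer's running statistics have settled. The sketch below shows plain linear warmup, assuming the common convention warmup_steps = ceil(ratio × total_steps). It is not necessarily training_hub's exact scheduler, and the decay that usually follows warmup is omitted. `linear_warmup_lr` is a hypothetical helper.

```python
import math


def linear_warmup_lr(step: int, base_lr: float, total_steps: int, warmup_ratio: float) -> float:
    """Learning rate at `step`: linear ramp from 0 to base_lr, then constant (decay omitted)."""
    warmup_steps = math.ceil(warmup_ratio * total_steps)
    if warmup_steps == 0 or step >= warmup_steps:
        return base_lr
    return base_lr * step / warmup_steps


for step in (0, 25, 50, 100, 500):
    print(step, linear_warmup_lr(step, base_lr=1e-5, total_steps=1000, warmup_ratio=0.1))
```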

Use loss scaling for fp16

python
sft(
    fp16=True,
    fp16_opt_level="O1",
    loss_scale=128.0,  # Or "dynamic"
)
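With fp16, the loss is multiplied by a scale factor before the backward pass so that small gradients do not underflow. The gradients are then divided by the same factor before the optimizer step. A "dynamic" scale adapts automatically. The class below is a minimal sketch of that policy, modelled on PyTorch's `GradScaler` defaults (start at 2**16, halve on overflow, double after 2000 clean steps). It is illustrative, not training_hub's or PyTorch's code, and it leaves out the unscaling step.

```python
import math
from typing import Iterable


class DynamicLossScaler:
    """Skip the step and shrink the scale on overflow; grow it after a run of clean steps."""

    def __init__(self, init_scale: float = 2.0**16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, scaled_grads: Iterable[float]) -> bool:
        """Return True if the optimizer step should run for these scaled gradients."""
        if any(not math.isfinite(g) for g in scaled_grads):
            self.scale *= self.backoff_factor  # overflow: try a smaller scale next time
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor  # stable for a while: probe a larger scale
            self._clean_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=128.0, growth_interval=3)
print(scaler.update([float("inf")]), scaler.scale)  # overflow: step skipped, scale halved to 64.0
for _ in range(3):
    scaler.update([0.5, -0.25])
print(scaler.scale)  # back to 128.0 after 3 clean steps
```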

Prefer bf16 over fp16

python
sft(
    bf16=True,  # Same exponent range as fp32, so no loss scaling needed (needs Ampere or newer GPUs)
)
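The stability difference comes from exponent range. fp16 has 5 exponent bits, so its largest finite value is 65504. bf16 keeps fp32's 8 exponent bits but only 7 mantissa bits. The sketch below round-trips a value through both formats using only `struct`. The bf16 conversion truncates the low 16 bits of the float32 encoding; hardware usually rounds to nearest instead, but the range is the same.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE half precision; raises OverflowError when x is too large."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round-trip x through bfloat16 by keeping the top 16 bits of its float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


value = 70000.0  # e.g. a large activation or an unscaled loss value
print("bf16:", to_bf16(value))  # representable, but with only ~2-3 significant decimal digits
try:
    print("fp16:", to_fp16(value))
except OverflowError:
    print("fp16: overflow (largest finite fp16 value is 65504)")
```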

Gradient clipping

python
sft(
    max_grad_norm=1.0,  # Clip gradients
)
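`max_grad_norm` clipping normally works on the global norm: one L2 norm computed over all parameters' gradients. If that norm exceeds the limit, every gradient is scaled down by the same factor, which preserves the update direction. The sketch below mirrors the rule used by `torch.nn.utils.clip_grad_norm_` on plain Python lists. `clip_grad_norm` is a hypothetical helper. Logging the returned pre-clip norm is a cheap way to spot the spikes that precede NaN losses.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    coef = min(1.0, max_norm / (total_norm + eps))  # never scales gradients up
    for grad in grads:
        for i, g in enumerate(grad):
            grad[i] = g * coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" with a combined norm of 5
print(clip_grad_norm(grads, max_norm=1.0), grads)
```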

Check and clean data

python
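# Remove problematic examples
from datasets import load_dataset

dataset = load_dataset("json", data_files="./data/train.jsonl", split="train")
dataset = dataset.filter(lambda x: len(x["messages"]) > 0)  # drop empty conversations
dataset = dataset.filter(lambda x: all(
    isinstance(m.get("content", ""), str) for m in x["messages"]  # drop non-string content
))
dataset.to_json("./data/train_clean.jsonl")  # save the cleaned data (output path is illustrative)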

Distributed Training Issues

NCCL Timeout

code
RuntimeError: NCCL timeout
Watchdog caught collective operation timeout

Solutions:

bash
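# Increase timeout
# NCCL_TIMEOUT is not a documented NCCL variable; it only helps if your launcher reads it.
# In PyTorch, the collective timeout is the `timeout` argument of torch.distributed.init_process_group.
export NCCL_TIMEOUT=1800  # 30 minutes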

# Debug NCCL
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

GPU Mismatch

code
CUDA error: invalid device ordinal
RuntimeError: Expected all tensors to be on the same device

Solutions:

bash
# Verify GPU visibility
export CUDA_VISIBLE_DEVICES=0,1,2,3  # Explicitly set GPUs

# Check GPU topology
nvidia-smi topo -m
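`invalid device ordinal` usually means a process asked for a GPU index that is not visible to it. A typical cause is launching 8 processes per node while `CUDA_VISIBLE_DEVICES` exposes only 4. A pre-launch sanity check can be sketched as follows. `visible_gpu_count` and `check_launch` are hypothetical helpers, and the check assumes the common one-GPU-per-process layout in which each rank uses `cuda:<local_rank>`.

```python
from typing import Optional


def visible_gpu_count(cuda_visible_devices: Optional[str], physical_gpus: int) -> int:
    """Number of GPUs a process can address.

    Assumes CUDA_VISIBLE_DEVICES is unset (all GPUs visible) or a plain
    comma-separated list of indices; UUID forms are not handled here.
    Visible devices are renumbered from 0, so with "2,3" the only valid
    ordinals are cuda:0 and cuda:1.
    """
    if cuda_visible_devices is None:
        return physical_gpus
    return len([d for d in cuda_visible_devices.split(",") if d.strip()])


def check_launch(nproc_per_node: int, cuda_visible_devices: Optional[str],
                 physical_gpus: int) -> None:
    """Raise if one-GPU-per-process ranks would ask for a device that is not visible."""
    visible = visible_gpu_count(cuda_visible_devices, physical_gpus)
    if nproc_per_node > visible:
        raise RuntimeError(
            f"{nproc_per_node} processes but only {visible} visible GPU(s); "
            f"local ranks {visible}..{nproc_per_node - 1} would hit 'invalid device ordinal'"
        )


check_launch(4, "0,1,2,3", physical_gpus=8)  # OK: 4 ranks, 4 visible GPUs
try:
    check_launch(8, "0,1,2,3", physical_gpus=8)
except RuntimeError as err:
    print(err)
```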

Data Loading Bottleneck

python
# Symptoms: Low GPU utilization, CPUs maxed out

sft(
    dataloader_num_workers=8,  # Increase workers
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=4,
)
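Before raising `dataloader_num_workers`, confirm that the loader is actually the bottleneck. One way is to split each iteration's wall time into "waiting for the next batch" and "running the step". The sketch below is independent of training_hub. `time_loop` and `slow_batches` are hypothetical, and the sleeps stand in for slow data preparation and a training step.

```python
import time
from typing import Callable, Iterable, Tuple


def time_loop(batches: Iterable, step: Callable[[object], None]) -> Tuple[float, float]:
    """Return (seconds waiting for data, seconds inside step) over the whole loop.

    For real GPU work, call torch.cuda.synchronize() at the end of `step` so
    asynchronous kernels are counted as compute rather than data time.
    """
    data_time = compute_time = 0.0
    iterator = iter(batches)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            break
        t1 = time.perf_counter()
        step(batch)
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time


def slow_batches(n: int):
    for i in range(n):
        time.sleep(0.01)  # stand-in for slow decoding/tokenization
        yield i


data_s, compute_s = time_loop(slow_batches(5), step=lambda b: time.sleep(0.002))
print(f"data: {data_s:.3f}s, compute: {compute_s:.3f}s, "
      f"data share: {data_s / (data_s + compute_s):.0%}")
```

If the data share stays high in a real run, more workers, pinned memory and prefetching are likely to help. If compute dominates, the loader is not the problem.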

FSDP-Specific Issues

Checkpoint Loading Failures

python
# Use FSDP-compatible checkpoint saving
sft(
    fsdp=True,
    fsdp_config={
        "state_dict_type": "SHARDED_STATE_DICT",  # Each rank saves/loads only its own shard
    },
)

Memory Fragmentation

python
sft(
    fsdp=True,
    fsdp_config={
        "limit_all_gathers": True,
        "forward_prefetch": True,
    },
)

Data Issues

Tokenization Errors

python
from training_hub.utils import validate_data

# Check if data tokenizes correctly
errors = validate_data(
    data_path="./data/train.jsonl",
    model_path="meta-llama/Llama-3.1-8B",
)

for error in errors:
    print(f"Row {error['index']}: {error['message']}")

Sequence Length Distribution

python
from training_hub.utils import analyze_data

stats = analyze_data(
    data_path="./data/train.jsonl",
    model_path="meta-llama/Llama-3.1-8B",
)

print(f"Mean length: {stats['mean_tokens']}")
print(f"Max length: {stats['max_tokens']}")
print(f"Examples > 4096: {stats['exceeds_4096']}")
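The exact definitions `analyze_data` uses are not documented here. The arithmetic behind such a report can be sketched from a list of per-example token counts, together with a 95th percentile that helps when choosing `max_seq_length`. `length_stats` is hypothetical, and how you count tokens (e.g. after applying the chat template) is up to you.

```python
import math
from statistics import mean
from typing import Dict, List


def length_stats(token_counts: List[int], limit: int = 4096) -> Dict[str, float]:
    """Summarize a non-empty list of per-example token counts."""
    ordered = sorted(token_counts)
    p95_index = max(0, math.ceil(0.95 * len(ordered)) - 1)  # nearest-rank percentile
    return {
        "mean_tokens": mean(ordered),
        "max_tokens": ordered[-1],
        "p95_tokens": ordered[p95_index],
        f"exceeds_{limit}": sum(1 for n in ordered if n > limit),
    }


stats = length_stats([512, 900, 1500, 3000, 5000, 8000])
print(stats)
```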

Debugging Workflow

  1. Start with minimal config
python
sft(
    model_path="...",
    data_path="...",
    per_device_batch_size=1,
    max_steps=10,
    bf16=True,
)
  2. Gradually increase complexity
python
# Step 1: Increase batch size until OOM, then keep the largest size that fits
# Step 2: Add gradient checkpointing if needed
# Step 3: Add evaluation
# Step 4: Full training run
  3. Enable logging
python
import logging
logging.basicConfig(level=logging.DEBUG)

sft(
    logging_steps=1,
    logging_first_step=True,
)
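"Increase batch size until OOM" can be made systematic: double the batch size until a run fails, then binary-search between the last size that fit and the first that did not. The search is sketched below with a fake probe. In practice `fits` would be a hypothetical function that runs a few steps at that batch size and reports whether an out-of-memory error occurred. Running each probe in a fresh process is safer, because memory is not always fully released after an OOM.

```python
from typing import Callable


def largest_batch_size(fits: Callable[[int], bool], upper: int = 1024) -> int:
    """Largest batch size in [1, upper] for which fits(batch_size) is True.

    Assumes fits is monotone (if b fits, every smaller b fits).
    Returns 0 if even batch size 1 does not fit.
    """
    if not fits(1):
        return 0
    good, bad = 1, None
    while bad is None:  # exponential phase: double until the first failure
        candidate = good * 2
        if candidate > upper:
            if fits(upper):
                return upper
            bad = upper
        elif fits(candidate):
            good = candidate
        else:
            bad = candidate
    while bad - good > 1:  # binary search between last success and first failure
        mid = (good + bad) // 2
        if fits(mid):
            good = mid
        else:
            bad = mid
    return good


print(largest_batch_size(lambda b: b <= 12))  # fake probe: pretend anything above 12 OOMs
```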

Diagnostic Commands

bash
# GPU status
nvidia-smi

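# Check CUDA versions: nvcc reports the installed toolkit, torch.version.cuda the CUDA version PyTorch was built with
nvcc --version
python -c "import torch; print(torch.version.cuda, torch.cuda.is_available())"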

# Check distributed setup (distributed package built? NCCL backend available?)
python -c "import torch.distributed as dist; print(dist.is_available(), dist.is_nccl_available())"

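# Memory profiling (memory_summary() returns a string, so print it; a fresh process has
# no allocations to report, so call it inside the training process for useful output)
python -c "import torch; print(torch.cuda.memory_summary())"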

Related Skills

  • /training-configure - Configure training runs
  • /pipeline-design - Design end-to-end pipelines