Mechanistic Interpretability
Overview
Mechanistic interpretability (mech interp) is the science of reverse-engineering neural networks to understand the algorithms they learn. The core question: "What computation is this model performing, and how?"
Key concepts:
- Residual stream: The main highway through the model; each layer reads from and writes to it
- Features: Directions in activation space representing interpretable concepts
- Circuits: Subgraphs implementing specific behaviors
- Superposition: Models represent more features than dimensions using non-orthogonal directions
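To make superposition concrete, here is a minimal toy sketch, not taken from any real model: five hypothetical features share a 2-dimensional activation space by using unit directions spaced evenly around a circle. The names `W`, `encode`, and `readout` are illustrative assumptions.

```python
import numpy as np

n_features, d_model = 5, 2

# Each feature gets a unit-norm direction; 5 directions cannot all be orthogonal in 2-D.
angles = 2 * np.pi * np.arange(n_features) / n_features
W = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # [n_features, d_model]

def encode(feature_values: np.ndarray) -> np.ndarray:
    """Write feature values into the shared activation space."""
    return feature_values @ W  # [d_model]

def readout(activation: np.ndarray) -> np.ndarray:
    """Read every feature back by projecting onto its direction, then apply ReLU."""
    return np.maximum(activation @ W.T, 0.0)  # [n_features]

# Sparse input: only feature 2 is active.
x = np.zeros(n_features)
x[2] = 1.0
recovered = readout(encode(x))

# Off-diagonal entries of W @ W.T measure interference between feature directions.
interference = W @ W.T - np.eye(n_features)
print("recovered:", recovered.round(3))
print("max |interference|:", np.abs(interference).max().round(3))
```

With a sparse input, the active feature still has the largest readout, while the nonzero off-diagonal terms show the interference cost of packing features into non-orthogonal directions.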
Why it matters:
- Understand model capabilities and limitations
- Debug unexpected behaviors
- Verify safety properties
- Build interpretable AI systems
Core Workflow
Phase 1: Environment Setup
- Detect compute resources

  ```python
  import torch

  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Device: {device}")
  if device == "cuda":
      print(f"GPU: {torch.cuda.get_device_name(0)}")
      print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
  ```

- Install libraries based on model size

  | Model Size | Primary Tool | Install |
  |---|---|---|
  | ≤2B params | TransformerLens | pip install transformer-lens |
  | 2B-13B params | nnsight | pip install nnsight |
  | SAE training | SAELens | pip install sae-lens[train] |
  | Both ecosystems | nnterp | pip install nnterp |

- Load model

  ```python
  from transformer_lens import HookedTransformer

  model = HookedTransformer.from_pretrained(
      "gpt2-small",
      device=device,
  )
  ```
See references/tools.md for detailed tool setup.
Phase 2: Experiment Design
Before writing code, clarify:
- Research question: What specifically are you trying to understand?
  - "What does attention head L5H3 do?"
  - "How does the model represent 'is_capital_of' relationships?"
  - "Which components contribute to this prediction?"
- Hypothesis: What do you expect to find?
  - Phrase as testable predictions
  - Include what would falsify the hypothesis
- Technique selection (see table below)
- Validation plan: How will you verify findings?
  - Causal interventions
  - Held-out examples
  - Alternative explanations to rule out
Phase 3: Implementation
Select technique based on your goal:
| Goal | Technique | Tool/Method | Reference |
|---|---|---|---|
| What tokens does the model predict at each layer? | Logit Lens | resid @ W_U | techniques.md |
| Which component affects this output? | Activation Patching | run_with_hooks | techniques.md |
| How much does each head contribute to logit? | Direct Logit Attribution | Decompose residual | techniques.md |
| What information does this head move? | OV Circuit Analysis | W_V @ W_O | techniques.md |
| What attends to what? | QK Circuit Analysis | W_Q @ W_K.T | techniques.md |
| Is information X represented here? | Probing | Train classifier | techniques.md |
| Find interpretable features | SAE | Train/load SAE | sae-guide.md |
| Which feature represents concept Y? | Feature Search | Max activating examples | sae-guide.md |
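The QK and OV rows above rely on collapsing two weight matrices into one. The following toy sketch uses random matrices as stand-ins for a single head's weights and checks that the combined matrices give the same results as the step-by-step computation. The per-head shapes (`[d_model, d_head]` for `W_Q`, `W_K`, `W_V` and `[d_head, d_model]` for `W_O`) are an assumption that matches TransformerLens conventions; biases, LayerNorm, and positional information are left out.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head = 16, 4

# Random stand-ins for one attention head's weights (assumed per-head shapes).
W_Q = rng.normal(size=(d_model, d_head))
W_K = rng.normal(size=(d_model, d_head))
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))

# QK circuit: bilinear form that scores a (destination, source) pair of residual vectors.
W_QK = W_Q @ W_K.T  # [d_model, d_model], rank at most d_head
# OV circuit: linear map from a source residual vector to what the head writes.
W_OV = W_V @ W_O    # [d_model, d_model], rank at most d_head

x_dst = rng.normal(size=d_model)  # residual vector at the query position
x_src = rng.normal(size=d_model)  # residual vector at the key position

# Pre-softmax attention score, computed stepwise and through the combined matrix.
score_stepwise = (x_dst @ W_Q) @ (x_src @ W_K) / np.sqrt(d_head)
score_circuit = x_dst @ W_QK @ x_src / np.sqrt(d_head)

# Head output for one attended source vector, computed both ways.
out_stepwise = (x_src @ W_V) @ W_O
out_circuit = x_src @ W_OV

print(np.isclose(score_stepwise, score_circuit), np.allclose(out_stepwise, out_circuit))
```

Working with `W_QK` and `W_OV` directly lets you study a head's behavior as a function of the residual stream alone, without reference to the head's internal basis.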
Phase 4: Analysis
- Run experiments
  - Cache activations: `logits, cache = model.run_with_cache(tokens)`
  - Always use `torch.no_grad()` for inference
  - Save intermediate results
- Visualize results
  - Attention heatmaps
  - Patching effect matrices
  - Feature activation distributions
- Iterate
  - Refine hypothesis based on findings
  - Test edge cases
  - Look for counterexamples
Phase 5: Validation
Before claiming findings, verify:
- Causal evidence: Ablating/patching changes behavior as predicted
- Held-out data: Results replicate on unseen examples
- Multiple seeds: Not an artifact of specific randomness
- Alternative explanations: Ruled out simpler stories
- Effect size: Practically meaningful, not just statistically significant
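One possible way to report effect size on held-out examples is a percentile bootstrap over per-example effects. The sketch below uses only the standard library; the helper `bootstrap_ci` is hypothetical and the effect values are invented placeholders, not results from any model.

```python
import random
import statistics

def bootstrap_ci(values, n_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap confidence interval for the mean of `values`."""
    rng = random.Random(seed)
    means = []
    for _ in range(n_resamples):
        # Resample with replacement, keeping the original sample size.
        sample = [rng.choice(values) for _ in values]
        means.append(statistics.fmean(sample))
    means.sort()
    low = means[int((alpha / 2) * n_resamples)]
    high = means[int((1 - alpha / 2) * n_resamples) - 1]
    return low, high

# Placeholder per-example effects (e.g. drop in logit difference after ablation).
# These numbers are invented for illustration only.
held_out_effects = [0.9, 1.1, 0.7, 1.3, 0.8, 1.0, 1.2, 0.6, 1.4, 0.95]

mean_effect = statistics.fmean(held_out_effects)
ci_low, ci_high = bootstrap_ci(held_out_effects)
print(f"mean effect = {mean_effect:.2f}, 95% CI = [{ci_low:.2f}, {ci_high:.2f}]")
```

An interval that stays well away from zero on held-out data is stronger evidence than a single averaged number; repeating the analysis with different seeds covers the "multiple seeds" check.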
See references/pitfalls.md for common mistakes.
Technique Quick Reference
Logit Lens
Project intermediate representations through unembedding to see evolving predictions.
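```python
# Assumes `logits, cache = model.run_with_cache(tokens)` has already been run.
for layer in range(model.cfg.n_layers):
    resid = cache["resid_post", layer]        # [batch, pos, d_model]
    resid_normed = model.ln_final(resid)      # apply the final LayerNorm before unembedding
    layer_logits = resid_normed @ model.W_U   # [batch, pos, d_vocab]; omits the bias b_U
    top_token = layer_logits[0, -1].argmax()  # most likely next token at the last position
    print(f"Layer {layer}: {model.to_str_tokens(top_token)}")
```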
Activation Patching
Measure causal effect by swapping activations between runs.
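```python
# Assumes `clean_cache` comes from model.run_with_cache(clean_tokens), and that
# `hook_point` (e.g. "blocks.5.hook_resid_pre") and `pos` are chosen by you.
def patch_hook(activation, hook):
    # Overwrite the corrupted-run activation at `pos` with the clean-run value
    activation[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return activation

patched_logits = model.run_with_hooks(
    corrupted_tokens,
    fwd_hooks=[(hook_point, patch_hook)],
)
```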
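Patched logits are usually reduced to a single number before comparison. A common choice, sketched here under the assumption that the task has one correct token and one distractor token, is the logit difference normalized so that 0 matches the corrupted run and 1 matches the clean run. The helper names and the toy logits are illustrative, not part of any library.

```python
def logit_diff(logits_last, correct_idx, incorrect_idx):
    """Logit of the correct answer minus logit of a distractor at the final position."""
    return logits_last[correct_idx] - logits_last[incorrect_idx]

def normalized_patching_effect(patched, clean, corrupted):
    """0.0 = no recovery (matches corrupted run), 1.0 = full recovery (matches clean run)."""
    denom = clean - corrupted
    if denom == 0:
        raise ValueError("clean and corrupted runs give the same metric; pick a different pair")
    return (patched - corrupted) / denom

# Toy final-position logits over a 4-token vocabulary (invented values).
clean_logits = [0.1, 3.0, 0.5, 0.2]
corrupted_logits = [0.1, 0.5, 3.0, 0.2]
patched_logits_toy = [0.1, 2.0, 1.0, 0.2]

clean_ld = logit_diff(clean_logits, 1, 2)
corrupted_ld = logit_diff(corrupted_logits, 1, 2)
patched_ld = logit_diff(patched_logits_toy, 1, 2)

effect = normalized_patching_effect(patched_ld, clean_ld, corrupted_ld)
print(f"normalized effect: {effect:.2f}")
```

Normalizing this way makes effects comparable across prompts whose clean and corrupted baselines differ in scale.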
Direct Logit Attribution
Decompose final logits into per-component contributions.
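```python
target_dir = model.W_U[:, target_token_idx]       # [d_model]
# The final LayerNorm rescales the residual stream before unembedding, so divide
# by its cached scale (assumes the default fold_ln=True and centered weights).
ln_scale = cache["ln_final.hook_scale"][0, -1]    # [1]
for layer in range(model.cfg.n_layers):
    attn_contribution = (cache["attn_out", layer][0, -1] / ln_scale) @ target_dir
    mlp_contribution = (cache["mlp_out", layer][0, -1] / ln_scale) @ target_dir
    print(f"L{layer} attn: {attn_contribution.item():.3f}, mlp: {mlp_contribution.item():.3f}")
```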
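The decomposition works because the residual stream is a sum of component outputs and the unembedding is linear. The toy check below, using random arrays rather than a real model, confirms that per-component contributions add up to the total logit. It deliberately leaves out the final LayerNorm, whose scale has to be handled separately in a real model; all names and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, d_model, d_vocab = 6, 8, 10

# Toy stand-ins: each row is what one component (embedding, an attention layer,
# an MLP layer, ...) wrote to the residual stream at the final position.
component_outputs = rng.normal(size=(n_components, d_model))
W_U = rng.normal(size=(d_model, d_vocab))
target_token_idx = 3

final_resid = component_outputs.sum(axis=0)     # residual stream = sum of all writes
target_dir = W_U[:, target_token_idx]           # unembedding direction for the target token

per_component = component_outputs @ target_dir  # [n_components] contributions
total_logit = final_resid @ target_dir          # logit computed from the full residual

print(per_component.round(3), total_logit.round(3))
```

If the contributions in a real analysis do not sum to the observed logit (up to the embedding, biases, and LayerNorm handling), that usually signals a missing component in the decomposition.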
SAE Feature Analysis
Find interpretable features in activations.
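```python
from sae_lens import SAE

# Recent SAELens releases return the SAE directly; older ones returned a
# (sae, cfg_dict, sparsity) tuple, so unpack accordingly for your version.
sae = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre", device=device)
feature_acts = sae.encode(cache["resid_pre", 8])   # [batch, pos, d_sae]
top_features = feature_acts[0, -1].topk(10)        # values and indices of the 10 most active features
```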
Model Size Guidance
| Model | Library | Memory (FP16) | Notes |
|---|---|---|---|
| GPT-2-small | TransformerLens | ~0.25GB | Best for learning |
| GPT-2-medium/large | TransformerLens | ~0.7-1.5GB | Good balance |
| GPT-2-xl | TransformerLens | ~3GB | Needs decent GPU |
| Pythia-70M to 410M | TransformerLens | ~0.15-0.8GB | Checkpoints available |
| Pythia-1B to 2.8B | TransformerLens | ~2-5.5GB | Pushes memory |
| Pythia-6.9B+ | nnsight | ~14GB+ | Use nnsight for efficiency |
| Llama-2-7B, Mistral-7B | nnsight | ~14GB | Needs 24GB+ GPU |
| Llama-2-13B+ | nnsight | ~26GB+ | Need A100/multi-GPU |
See references/compute-awareness.md for memory estimation.
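The FP16 column above follows from simple arithmetic: parameter count times bytes per parameter. A rough sketch, counting weights only and using approximate parameter counts (activation caches, gradients, and framework overhead come on top of this):

```python
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}

def weight_memory_gb(n_params: float, dtype: str = "fp16") -> float:
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9

# Approximate parameter counts, for illustration.
for name, n_params in [("GPT-2-small", 124e6), ("GPT-2-xl", 1.5e9), ("Llama-2-7B", 7e9)]:
    print(f"{name}: ~{weight_memory_gb(n_params):.2f} GB in FP16")
```

Caching every activation with `run_with_cache` can add substantially to this figure on long prompts, so leave headroom beyond the weight footprint.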
When to Ask the User
Ask before proceeding when:
- Research question unclear: "What specific behavior or component are you trying to understand?"
- Compute constraints unknown: "What GPU do you have available? This model needs ~XGB VRAM."
- Multiple valid approaches: "We could use activation patching (causal) or probing (correlational). Which do you prefer?"
- Unexpected results: "The results don't match expectations. Should we investigate further or try a different approach?"
- Scaling decisions: "Initial results look promising on GPT-2-small. Want to scale up to a larger model?"
Common Tasks
"Set up a mech interp project"
- Create project structure (see repo-maintenance.md)
- Install dependencies based on target model
- Set up CLAUDE.md with project-specific instructions
- Configure experiment tracking (wandb or simple JSON logging)
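For the simple JSON logging option, one minimal approach is an append-only JSON-lines file with one record per experiment. The function names, record fields, and file layout below are assumptions for illustration, not a convention prescribed by this guide.

```python
import json
import platform
import time
from pathlib import Path

def log_experiment(path, name, config, results):
    """Append one experiment record as a JSON line and return the record."""
    record = {
        "name": name,
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "python": platform.python_version(),
        "config": config,    # e.g. model name, seed, hook point
        "results": results,  # plain numbers/lists so they serialize cleanly
    }
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    return record

def load_experiments(path):
    """Read back every record from a JSON-lines log."""
    with Path(path).open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```

A call such as `log_experiment("logs/runs.jsonl", "patch_resid", {"model": "gpt2-small", "seed": 0}, {"effect": 0.7})` records enough context to reproduce the run later; convert tensors to Python floats or lists before logging.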
"What does this attention head do?"
- Visualize attention patterns across diverse inputs
- Analyze QK circuit (what attends to what)
- Analyze OV circuit (what information moves)
- Test with activation patching (is it necessary?)
- Check for known patterns (induction, copying, etc.)
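As one example of checking for a known pattern, an induction head can be scored on a sequence of random tokens repeated twice: in the second copy, it should attend to the token right after the previous occurrence of the current token. The sketch below scores toy attention patterns built by hand; it assumes no BOS token is prepended (a prepended token shifts every index by one), and `induction_score` is a hypothetical helper.

```python
import numpy as np

def induction_score(pattern: np.ndarray, seq_len: int) -> float:
    """Mean attention on the induction offset for a [2*seq_len, 2*seq_len] pattern.

    Rows are query positions, columns are key positions. For a query at
    position i in the second copy, the induction target is i - seq_len + 1.
    """
    queries = np.arange(seq_len, 2 * seq_len)
    targets = queries - seq_len + 1
    return float(pattern[queries, targets].mean())

seq_len = 4
n = 2 * seq_len

# Toy pattern 1: first-half queries attend to position 0, second-half queries
# put all their attention on the induction target.
perfect = np.zeros((n, n))
perfect[np.arange(seq_len), 0] = 1.0
for i in range(seq_len, n):
    perfect[i, i - seq_len + 1] = 1.0

# Toy pattern 2: uniform causal attention, a baseline with no induction behavior.
uniform = np.tril(np.ones((n, n)))
uniform /= uniform.sum(axis=1, keepdims=True)

print(induction_score(perfect, seq_len), round(induction_score(uniform, seq_len), 3))
```

In practice the pattern would come from the cached attention of a single head; a score near 1 on repeated random tokens is suggestive, and should still be confirmed with patching or ablation.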
"Find the circuit for behavior X"
- Design clean/corrupted input pairs
- Patch residual stream: layer × position heatmap
- Narrow to specific layers
- Patch individual heads
- Validate with ablation
- Analyze winning components
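The layer × position sweep in the steps above is a double loop around a patched forward pass. The skeleton below keeps the model out of the picture by taking the patched run as a callable; `patching_sweep`, `top_cells`, and `fake_run_patched` are hypothetical names, and the fake function only pretends that one cell matters.

```python
from typing import Callable, List

def patching_sweep(
    n_layers: int,
    n_positions: int,
    run_patched: Callable[[int, int], float],
) -> List[List[float]]:
    """Return a [n_layers][n_positions] grid of patching effects."""
    return [
        [run_patched(layer, pos) for pos in range(n_positions)]
        for layer in range(n_layers)
    ]

def top_cells(grid, k=3):
    """The k (layer, position, effect) cells with the largest absolute effect."""
    cells = [
        (layer, pos, effect)
        for layer, row in enumerate(grid)
        for pos, effect in enumerate(row)
    ]
    return sorted(cells, key=lambda c: abs(c[2]), reverse=True)[:k]

# Hypothetical stand-in for a real patched forward pass: pretends that
# layer 2, position 1 carries most of the effect.
def fake_run_patched(layer: int, pos: int) -> float:
    return 0.9 if (layer, pos) == (2, 1) else 0.05

grid = patching_sweep(n_layers=4, n_positions=3, run_patched=fake_run_patched)
print(top_cells(grid, k=1))
```

In a real sweep, `run_patched` would wrap a hooked forward pass and return a scalar metric; the resulting grid is what gets plotted as the heatmap, and the top cells tell you which layers to narrow down to.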
"Train an SAE"
- Choose layer and hook point
- Estimate memory requirements
- Set hyperparameters (expansion factor, L1 coefficient)
- Run training with monitoring (L0, reconstruction loss, dead features)
- Evaluate quality before analysis
See references/sae-guide.md for detailed guidance.
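To show what the monitored quantities mean, the sketch below computes L0, reconstruction loss, the L1 term, and the dead-feature fraction for a randomly initialized toy SAE in NumPy. It assumes a plain ReLU autoencoder with a decoder bias subtracted before encoding; real SAE libraries may use different architectures, and the sizes here are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_sae = 256, 16, 64   # expansion factor 4 (toy sizes)

# Untrained, randomly initialized parameters, used only to exercise the metrics.
W_enc = rng.normal(scale=0.1, size=(d_model, d_sae))
b_enc = np.zeros(d_sae)
W_dec = rng.normal(scale=0.1, size=(d_sae, d_model))
b_dec = np.zeros(d_model)

x = rng.normal(size=(n_tokens, d_model))                     # stand-in activations

feature_acts = np.maximum((x - b_dec) @ W_enc + b_enc, 0.0)  # ReLU encoder
x_hat = feature_acts @ W_dec + b_dec                         # linear decoder

l0 = (feature_acts > 0).sum(axis=1).mean()           # average active features per token
mse = ((x - x_hat) ** 2).mean()                      # reconstruction loss
l1 = np.abs(feature_acts).sum(axis=1).mean()         # sparsity term, before the L1 coefficient
dead = ((feature_acts > 0).sum(axis=0) == 0).mean()  # fraction of features never active here

print(f"L0={l0:.1f}  MSE={mse:.3f}  L1={l1:.2f}  dead={dead:.2%}")
```

During training, a falling reconstruction loss with a stable, low L0 and few dead features is the usual sign of a healthy run; dead features should be measured over many batches rather than one.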
"Interpret SAE features"
- Load pretrained SAE or train your own
- Find max activating examples for features of interest
- Look for patterns in activating contexts
- Test hypothesis with feature steering/ablation
- Validate causal role
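Finding max activating examples amounts to ranking contexts by a feature's peak activation. The standard-library sketch below does this for one hypothetical feature; the texts, the word-level tokenization, and the activation values are invented for illustration.

```python
import heapq

def max_activating_examples(feature_acts, texts, k=3):
    """Return the k (activation, token_index, text) triples with the highest peak activation.

    feature_acts[i] holds the per-token activations of one feature on texts[i].
    """
    peaks = []
    for acts, text in zip(feature_acts, texts):
        peak_idx = max(range(len(acts)), key=acts.__getitem__)
        peaks.append((acts[peak_idx], peak_idx, text))
    return heapq.nlargest(k, peaks, key=lambda p: p[0])

# Invented toy data for one hypothetical feature.
texts = ["Paris is in France", "I like apples", "Berlin is in Germany", "2 + 2 = 4"]
feature_acts = [
    [4.1, 0.0, 0.0, 2.3],
    [0.0, 0.1, 0.0],
    [3.7, 0.0, 0.0, 2.9],
    [0.0, 0.0, 0.0, 0.0, 0.0],
]

for act, idx, text in max_activating_examples(feature_acts, texts, k=2):
    print(f"{act:.1f} at token {idx}: {text}")
```

Top examples only show where a feature fires most strongly; looking at mid-range activations and then steering or ablating the feature guards against over-reading a clean-looking top-k list.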
Quality Checklist
Before concluding analysis:
- Research question clearly stated
- Appropriate technique selected
- Code runs without errors
- Results visualized
- Causal validation performed
- Edge cases tested
- Alternative explanations considered
- Results documented with reproducibility info
Reference Files
| File | Contents |
|---|---|
| tools.md | TransformerLens, nnsight, SAELens setup |
| techniques.md | Patching, logit lens, circuits, probing |
| sae-guide.md | SAE training and analysis |
| visualization.md | Plotting patterns and dashboards |
| pitfalls.md | Common mistakes and validation |
| repo-maintenance.md | Project structure templates |
| vocabulary.md | Glossary of terms |
| compute-awareness.md | GPU/memory guidance |