Add precise handler for $ARGUMENTS
Add a precise propagation handler for the $ARGUMENTS primitive,
replacing the conservative fallback.
Workflow
1. Research
Do all research in this step using extended planning mode.
Before writing code:
- •Read the JAX docs for the primitive: fetch
https://docs.jax.dev/en/latest/_autosummary/jax.lax.$ARGUMENTS.html - •Read
src/asdex/_interpret/CLAUDE.mdfor conventions (docstring style, semantic line breaks, handler structure) - •Read
src/asdex/_interpret/_commons.pyto understand available utilities - •Read an existing handler with similar structure (e.g.
_pad.py,_transpose.py,_reduction.py) as a reference - •Read the existing test in
tests/_interpret/test_internals.pyfor the primitive (search for$ARGUMENTS) - •Read
src/asdex/_interpret/__init__.pyto see the current dispatch and fallback setup
Understand the primitive's semantics: how do input and output element indices map to each other? What is the Jacobian structure (permutation, selection, block-diagonal, etc.)?
2. Implement handler
- •Create
src/asdex/_interpret/_$ARGUMENTS.pywithprop_$ARGUMENTS(eqn, deps). - •Follow the handler docstring style from
_interpret/CLAUDE.md.
3. Wire up dispatch
In src/asdex/_interpret/__init__.py:
- •Import the new handler
- •Add a
case "$ARGUMENTS":branch inprop_dispatchcalling the handler - •Remove
"$ARGUMENTS"from the conservative fallbackcasegroup
4. Update tests
In tests/_interpret/test_internals.py:
- •Update the existing test: change expected values from dense (
np.ones) to the precise pattern - •Remove the
@pytest.mark.fallbackmarker andTODOcomments
Create tests/_interpret/test_$ARGUMENTS.py with thorough tests:
- •Multiple dimensionalities (1D, 2D, 3D, 4D where applicable)
- •Edge cases (size-1 dimensions, identity/trivial parameters)
- •Real-world usage patterns (e.g.
jnpfunctions that lower to this primitive)
5. Verify
Run in order:
uv run ruff check src/asdex/_interpret/_$ARGUMENTS.py uv run pytest tests/_interpret/test_$ARGUMENTS.py -v uv run pytest tests/_interpret/test_internals.py -v uv run pytest tests/ -x
6. Adversarial tests
Reread the JAX docs for the primitive: fetch https://docs.jax.dev/en/latest/_autosummary/jax.lax.$ARGUMENTS.html
Try to break the implementation by testing inputs the handler might not expect:
- •Dimensionality: 1D, 2D, 3D, and higher — if any are missing, add them.
- •Degenerate shapes: size-0 dimensions, size-1 dimensions, scalar inputs (where the primitive supports them).
- •Boundary parameters: empty parameter lists, all-dimensions, single-dimension, negative indices (if applicable).
- •Compositions: the primitive chained with itself (e.g. double-reverse, transpose-of-transpose) or with related ops.
- •Non-contiguous patterns: inputs where dependencies are not simply
{i}per element (e.g. from a prior broadcast or reduction) to verify.copy()and set merging behave correctly.
For each new test, verify the expected output by hand or against jax.jacobian.
Update and re-verify the handler if any test reveals a bug.
7. Simplify
Review the implementation with fresh eyes and look for opportunities to reduce complexity:
- •Vectorize loops: can per-element Python loops be replaced with numpy operations?
Pattern: build a flat permutation or index array with
np.arange,np.flip,np.transpose,np.indices, ornp.ravel_multi_index, then index intoin_indicesin a single list comprehension. See_rev.py,_reshape.py,_concatenate.py, and_broadcast.pyfor examples. - •Remove unused imports: after vectorizing, utilities like
flat_to_coords,row_strides, andnumelmay no longer be needed. - •Eliminate intermediate variables: if a value is computed and used only once, inline it.
- •Simplify special cases: can a special-case branch be absorbed into the general case?
After any change, re-run verification (step 6).
8. Update docs
- •
TODO.md: check off the primitive and its test items - •
src/asdex/_interpret/CLAUDE.md: add the new module to the file listing