JAX Scientific Computing
Core Rules
- Pure functions: no side effects
- JIT outer functions: `@jax.jit` on hot paths
- vmap not loops: `jax.vmap(fn)` instead of list comprehensions
- Split RNG keys: never reuse keys
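The rules above can be combined in a minimal sketch. The function `predict`, the weights `w`, and the inputs `xs` are hypothetical names chosen for illustration; they are not part of the original notes.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Pure: the output depends only on the arguments (no globals, no in-place writes).
    return jnp.tanh(jnp.dot(w, x))

# vmap over the leading axis of x only; w is shared across the batch (in_axes=None).
# jit is applied once, on the outermost function.
batched_predict = jax.jit(jax.vmap(predict, in_axes=(None, 0)))

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)   # batch of 4 inputs, each of size 3
ys = batched_predict(w, xs)           # one output per row of xs
```

Wrapping only the outermost function in `jax.jit` is usually enough, since everything it calls is traced and compiled along with it.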
Patterns
```python
import jax

# RNG: always split
key, k1, k2 = jax.random.split(key, 3)

# Batching: vmap not loops
batched = jax.vmap(fn)(inputs)

# Loops: use scan
_, results = jax.lax.scan(step_fn, init, xs)
```
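As a runnable illustration of the split-and-scan patterns, here is a small random-walk sketch. The names `step_fn`, `walk_key`, and `step_keys`, the seed, and the step count of 5 are assumptions made for this example only.

```python
import jax
import jax.numpy as jnp

def step_fn(carry, key):
    # carry is the running position; key is this step's own RNG key.
    new_pos = carry + jax.random.normal(key)
    return new_pos, new_pos  # (next carry, per-step output)

key = jax.random.PRNGKey(0)
key, walk_key = jax.random.split(key)       # keep `key` for later use
step_keys = jax.random.split(walk_key, 5)   # one fresh key per step

# scan threads the carry through the steps and stacks the per-step outputs.
final_pos, path = jax.lax.scan(step_fn, jnp.zeros(()), step_keys)
```

Passing the pre-split keys as the scanned input keeps `step_fn` pure and ensures no key is consumed twice.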
Gotchas
- Arrays are immutable
- No Python control flow on traced values inside JIT: use `jax.lax.cond`, `jax.lax.scan`
- Check NaNs: `jnp.isnan(x).any()`
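Each gotcha can be seen in a short sketch. The function `safe_sqrt` and the sample values are hypothetical and exist only to illustrate the points above.

```python
import jax
import jax.numpy as jnp

# Immutability: .at[...].set(...) returns a new array and leaves x unchanged.
x = jnp.zeros(3)
y = x.at[1].set(5.0)

@jax.jit
def safe_sqrt(v):
    # v is a traced value here, so a Python `if v >= 0:` would fail under jit.
    # lax.cond selects a branch based on the runtime value instead.
    return jax.lax.cond(v >= 0, jnp.sqrt, jnp.zeros_like, v)

pos = safe_sqrt(jnp.asarray(4.0, dtype=jnp.float32))
neg = safe_sqrt(jnp.asarray(-1.0, dtype=jnp.float32))

# NaN check: log of a negative number yields NaN.
has_nan = bool(jnp.isnan(jnp.log(jnp.array([-1.0, 1.0]))).any())
```

For debugging, `jax.config.update("jax_debug_nans", True)` can also make JAX raise an error when a NaN is produced, at some cost in speed.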