Overview

JAX is NumPy + autodiff + GPU/TPU.

It enables fast scientific computing and machine learning
with the familiar NumPy API
(plus additional APIs for specialized accelerator ops when needed).
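For example, a minimal sketch of the NumPy-style API (the array values here are arbitrary):

    import jax.numpy as jnp

    x = jnp.linspace(0.0, 1.0, 5)  # familiar NumPy-style array creation
    y = jnp.sin(x) + x ** 2        # elementwise math, run on CPU/GPU/TPU
    print(y.sum())                 # reductions work as in NumPy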

JAX comes with powerful primitives, which you can compose arbitrarily (minimal sketches of each follow the list):

  • Autodiff (jax.grad): Efficient any-order gradients w.r.t. any variables

  • JIT compilation (jax.jit): Trace any function ⟶ fused accelerator ops

  • Vectorization (jax.vmap): Automatically batch code written for individual samples

  • Parallelization (jax.pmap): Automatically parallelize code across multiple accelerators (including across hosts, e.g. for TPU pods)
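
A minimal autodiff sketch (the loss function, shapes, and data are illustrative):

    import jax
    import jax.numpy as jnp

    def loss(w, x, y):
        return jnp.mean((x @ w - y) ** 2)

    grad_fn = jax.grad(loss)                 # gradient w.r.t. the first argument, w
    hessian_fn = jax.jacfwd(jax.grad(loss))  # compose transforms for second-order derivatives

    w, x, y = jnp.zeros(3), jnp.ones((8, 3)), jnp.ones(8)
    print(grad_fn(w, x, y))                  # same shape as w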
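
A minimal JIT sketch (the update rule and the 0.1 step size are illustrative):

    import jax
    import jax.numpy as jnp

    @jax.jit                  # traced once per input shape/dtype,
    def update(w, x, y):      # then compiled to fused accelerator ops
        grads = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
        return w - 0.1 * grads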
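
A minimal vectorization sketch (predict and the batch size are illustrative):

    import jax
    import jax.numpy as jnp

    def predict(w, x):        # written for a single sample x of shape (3,)
        return jnp.dot(w, x)

    batched = jax.vmap(predict, in_axes=(None, 0))         # batch over x's leading axis only
    print(batched(jnp.ones(3), jnp.ones((10, 3))).shape)   # (10,)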
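
A minimal parallelization sketch; it runs on however many local devices are attached, including a single one:

    import jax
    import jax.numpy as jnp

    n = jax.local_device_count()             # number of attached accelerators
    xs = jnp.arange(n * 4.0).reshape(n, 4)   # leading axis must equal the device count
    ys = jax.pmap(lambda x: x ** 2)(xs)      # each device computes its own slice
    print(ys.shape)                          # (n, 4)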

If you don’t know JAX but just want to learn what you need to use Flax, check out our JAX for the impatient notebook.