Flax

Neural networks with JAX

Flax delivers a flexible, end-to-end user experience for researchers who use JAX for neural networks. Flax exposes the full power of JAX. It is made up of loosely coupled libraries, which are showcased with end-to-end integrated guides and examples.

Flax is used by hundreds of projects (and growing), both in the open source community (like Hugging Face) and at Google (like PaLM, Imagen, Scenic, and Big Vision).

Features#

Safety

Flax is designed for correctness and safety. Thanks to its immutable Modules and Functional API, Flax helps mitigate bugs that arise when handling state in JAX.
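
A minimal sketch of what this means in practice: parameters live outside the Module, so applying a model is a pure function of its inputs (the layer size and input shape here are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.Dense(features=3)
x = jnp.ones((1, 4))
params = layer.init(jax.random.PRNGKey(0), x)  # parameters are returned, never stored on the Module
y1 = layer.apply(params, x)                    # apply is a pure function of (params, x)
y2 = layer.apply(params, x)
assert (y1 == y2).all()                        # no hidden mutable state between calls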

Control

Flax grants finer-grained control and expressivity than most neural network frameworks via its Variable Collections, RNG Collections, and Mutability conditions, as sketched below.
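
As a hedged sketch of these mechanisms, a counter Module can store a value in a custom variable collection, and updates are only allowed when that collection is explicitly marked mutable (the collection name 'counter' is an arbitrary choice here):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
  @nn.compact
  def __call__(self):
    # a variable in the custom 'counter' collection, initialized to 0
    count = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
    count.value += 1
    return count.value

counter = Counter()
variables = counter.init(jax.random.PRNGKey(0))
# updating 'count' is only legal because the collection is marked mutable
y, updated_vars = counter.apply(variables, mutable=['counter'])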

Functional API

Flax’s functional API radically redefines what Modules can do via lifted transformations like vmap and scan, while also enabling seamless integration with other JAX libraries like Optax and Chex.
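
For instance, a hedged sketch of a lifted transformation: nn.vmap can turn a single Dense layer into an ensemble of independently initialized heads (the head count and feature sizes are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

# lift jax.vmap to the Module level: map over a leading axis of heads
VmapDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},  # each head gets its own parameters
    split_rngs={'params': True},  # ...initialized from its own RNG split
)

heads = VmapDense(features=3)
x = jnp.ones((5, 4))                           # 5 heads, 4 input features each
params = heads.init(jax.random.PRNGKey(0), x)  # params carry a leading axis of 5
y = heads.apply(params, x)                     # y.shape == (5, 3)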

Terse code

Flax’s compact Modules enable submodules to be defined directly at their call site, leading to code that is easier to read and avoids repetition.
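
To illustrate the difference, here is a hedged sketch of the same two-layer pattern written without @nn.compact, using the more verbose setup() style (layer sizes are illustrative):

import flax.linen as nn

class MLPSetup(nn.Module):
  out_dims: int

  def setup(self):
    # submodules must be declared ahead of time...
    self.dense1 = nn.Dense(128)
    self.dense2 = nn.Dense(self.out_dims)

  def __call__(self, x):
    # ...and referenced by name at the call site
    x = x.reshape((x.shape[0], -1))
    x = nn.relu(self.dense1(x))
    return self.dense2(x)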

Installation#

pip install flax

# or to install the latest version of Flax:
pip install --upgrade git+https://github.com/google/flax.git

Flax installs the vanilla CPU version of JAX; if you need a custom version (for example, with GPU support), please check out JAX’s installation page.
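
For example, at the time of writing the JAX docs describe a CUDA wheel installable as below, though the exact extra name changes between JAX releases, so verify against JAX’s installation page:

# assumption: CUDA 12 wheel name as currently documented by JAX
pip install --upgrade "jax[cuda12]"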

Basic usage#

import jax.numpy as jnp
from jax.random import PRNGKey
import flax.linen as nn

class MLP(nn.Module):                   # create a Flax Module dataclass
  out_dims: int

  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(128)(x)                # create inline Flax Module submodules
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)      # shape inference
    return x

model = MLP(out_dims=10)                # instantiate the MLP model
x = jnp.empty((4, 28, 28, 1))           # create dummy input data
variables = model.init(PRNGKey(42), x)  # initialize the weights
y = model.apply(variables, x)           # make forward pass
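
Building on the model above, a hedged sketch of how Flax’s functional design composes with Optax for training (the optimizer, learning rate, and loss are illustrative choices, not prescribed by Flax):

import jax
import optax

tx = optax.sgd(learning_rate=0.1)
opt_state = tx.init(variables)

@jax.jit
def train_step(variables, opt_state, x, labels):
  def loss_fn(vs):
    logits = model.apply(vs, x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
  grads = jax.grad(loss_fn)(variables)            # gradients w.r.t. the whole variables pytree
  updates, opt_state = tx.update(grads, opt_state)
  return optax.apply_updates(variables, updates), opt_state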

Learn more#

Getting started

Guides

Examples

Glossary

Developer notes

The Flax philosophy

API reference