Flax
Neural networks with JAX
Flax delivers an end-to-end, flexible user experience for researchers who use JAX with neural networks. Flax exposes the full power of JAX. It is made up of loosely coupled libraries, which are showcased through end-to-end integrated guides and examples.
Flax is used by hundreds of projects (and growing), both in the open source community (like Hugging Face) and at Google (like PaLM, Imagen, Scenic, and Big Vision).
Features
Safety
Flax is designed for correctness and safety. Its immutable Modules and functional API help prevent the bugs that arise when handling state in JAX.
Control
Flax grants more fine-grained control and expressivity than most neural network frameworks through its Variable Collections, RNG Collections, and mutability conditions.
Functional API
Flax’s functional API extends what Modules can do through lifted transformations such as vmap and scan, and it integrates seamlessly with other JAX libraries such as Optax and Chex.
Terse code
Flax’s compact Modules enable submodules to be defined directly at their call site, which makes code easier to read and avoids repetition.
Installation
pip install flax
# or to install the latest version of Flax:
pip install --upgrade git+https://github.com/google/flax.git
Flax installs the vanilla CPU version of JAX. If you need a custom version (for example, with GPU or TPU support), please check out JAX’s installation page.
Basic usage
import jax.numpy as jnp
import flax.linen as nn
from jax.random import PRNGKey

class MLP(nn.Module):                    # create a Flax Module dataclass
  out_dims: int

  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(128)(x)                 # create inline Flax Module submodules
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)       # shape inference
    return x

model = MLP(out_dims=10)                 # instantiate the MLP model

x = jnp.empty((4, 28, 28, 1))            # dummy input batch (values are not random)
variables = model.init(PRNGKey(42), x)   # initialize the weights
y = model.apply(variables, x)            # make forward pass
Learn more
Getting started
Guides
Examples
Glossary
Developer notes
The Flax philosophy
API reference