# Flax

Neural Networks for JAX
Flax provides a flexible end-to-end user experience for researchers and developers building neural networks with JAX, giving you access to the full power of JAX.
At the core of Flax is NNX, a simplified API that makes it easier to create, inspect, debug, and analyze neural networks in JAX. Flax NNX has first-class support for Python reference semantics, enabling users to express their models using regular Python objects. Flax NNX is an evolution of the previous Flax Linen API; it draws on years of experience to deliver a simpler and more user-friendly API.
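As a minimal sketch of what reference semantics look like in practice (using a single `nnx.Linear` layer for illustration), modules can be constructed, inspected, and mutated like any other Python object:

```python
from flax import nnx

# Modules are plain Python objects: construct one and inspect it directly.
layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
print(layer.kernel.value.shape)  # (2, 3)

# Reference semantics allow in-place state updates.
layer.bias.value += 1.0
```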
**Note:** The Flax Linen API will not be deprecated in the near future, as most Flax users still rely on it. However, new users are encouraged to use Flax NNX. Check out Why Flax NNX for a comparison between Flax NNX and Linen, and for the reasoning behind the new API.

To move your Flax Linen codebase to Flax NNX, familiarize yourself with the API in NNX Basics, then follow the evolution guide.
## Features
- Flax NNX supports the use of regular Python objects, providing an intuitive and predictable development experience.
- Flax NNX relies on Python's object model, which keeps things simple for the user and speeds up development.
- Flax NNX allows fine-grained control of the model's state via its Filter system (see the sketch after this list).
- Flax NNX makes it easy to integrate objects with regular JAX code via the Functional API (also shown in the sketch below).
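A minimal sketch of the last two points, assuming a simple `nnx.Linear` model: Filters such as `nnx.Param` select subsets of the model's state, and `nnx.split`/`nnx.merge` move between the object view and a functional view that plain JAX transforms understand:

```python
from flax import nnx
import jax
import jax.numpy as jnp

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# The Filter system: select only the Param variables from the model's state.
params = nnx.state(model, nnx.Param)

# The Functional API: split the object into static structure plus state...
graphdef, state = nnx.split(model)

@jax.jit  # ...so it can flow through regular JAX transforms.
def forward(graphdef, state, x):
  model = nnx.merge(graphdef, state)  # rebuild the object inside the transform
  return model(x)

y = forward(graphdef, state, jnp.ones((1, 2)))
```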
## Basic usage
```python
from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)


model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing


@nnx.jit  # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # in-place updates

  return loss
```
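To exercise the step above, a hedged sketch with toy data (the shapes and values here are made up for illustration):

```python
import jax.numpy as jnp

x = jnp.ones((32, 2))   # hypothetical batch of inputs matching din=2
y = jnp.zeros((32, 3))  # hypothetical targets matching dout=3

model.train()  # enable Dropout and BatchNorm training behavior
loss = train_step(model, optimizer, x, y)
```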
## Installation
Install via pip:

```shell
pip install flax
```
Or install the latest version from the repository:

```shell
pip install git+https://github.com/google/flax.git
```
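A quick sanity check that the install worked (a minimal sketch):

```shell
python -c "import flax; print(flax.__version__)"
```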