
LouisDesdoigts/zodiax


Zodiax



Zodiax is a lightweight extension to Equinox, the object-oriented Jax framework. Equinox lets you write differentiable classes that Jax recognises as valid types (pytrees). Zodiax adds simple methods that make interfacing with these classes easier! Zodiax was originally built during the development of dLux. It was designed to make working with large, nested class structures simple and flexible.

Zodiax is directly integrated with both Jax and Equinox, so it inherits all of their core features.

Documentation: louisdesdoigts.github.io/zodiax/

Contributors: Louis Desdoigts

Requires: Python 3.8+, Jax 0.4.3+

Installation: pip install zodiax

Docs installation: pip install "zodiax[docs]"

Test installation: pip install "zodiax[tests]"


Quickstart

Create a regular class that inherits from zodiax.Base:

import jax
import zodiax as zdx
import jax.numpy as np

class Linear(zdx.Base):
    # Parameters are declared as dataclass-style fields
    m : jax.Array
    b : jax.Array

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def model(self, x):
        # A straight line: y = m * x + b
        return self.m * x + self.b

linear = Linear(1., 1.)

It's that simple! The Linear class is now a fully differentiable object that gives us all the benefits of Jax with an object-oriented interface. Let's see how to jit-compile and take gradients of a function of this class.

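# jax.grad differentiates the loss with respect to its first argument (the model),
# and jax.jit compiles the resulting gradient function
@jax.jit
@jax.grad
def loss_fn(model, xs, ys):
    return np.square(model.model(xs) - ys).sum()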

xs = np.arange(5)
ys = 2*np.arange(5)
grads = loss_fn(linear, xs, ys)
print(grads)
print(grads.m, grads.b)
> Linear(m=f32[], b=f32[])
> -40.0 -10.0
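The printed values can be checked by hand. Take the residuals r = m * x + b - y and the loss L = sum(r ** 2). The partial derivatives are then dL/dm = 2 * sum(r * x) and dL/db = 2 * sum(r). This sketch evaluates both at m = b = 1 and compares them with jax.grad applied to a plain function of the two scalars. It uses only Jax, not Zodiax.

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(5.0)
ys = 2 * jnp.arange(5.0)
m, b = 1.0, 1.0

# Residuals of the line y = m * x + b at m = b = 1: [1, 0, -1, -2, -3]
residuals = m * xs + b - ys

# Analytic gradients of sum(residuals ** 2)
grad_m = 2 * jnp.sum(residuals * xs)  # 2 * (0 + 0 - 2 - 6 - 12) = -40
grad_b = 2 * jnp.sum(residuals)       # 2 * (1 + 0 - 1 - 2 - 3) = -10


# The same gradients from autodiff on a plain function of (m, b)
def loss(m, b):
    return jnp.sum(jnp.square(m * xs + b - ys))


auto_m, auto_b = jax.grad(loss, argnums=(0, 1))(m, b)
```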

The grads object is an instance of the Linear class whose fields hold the gradients of the loss with respect to each parameter!