
gerlero/parajax


Parajax

Automagic parallelization of calls to JAX-based functions


Features

  • 🚀 Device-parallel execution: run across multiple CPUs, GPUs or TPUs automatically
  • 🧩 Fully composable with @jax.jit, @jax.vmap, and other JAX transformations
  • 🪄 Automatic handling of input shapes not divisible by the number of devices
  • 🎯 Simple interface: just decorate your function with @parallelize
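The first and third features (device-parallel execution and handling of uneven input sizes) can be illustrated with plain JAX. What follows is a minimal sketch of the general pad, shard, and unpad technique. It is not parajax's actual implementation. The helper name parallel_map is hypothetical, and the sketch assumes that f processes each element along the leading axis independently (as square does) and that xs is non-empty.

```python
import jax
import jax.numpy as jnp


def parallel_map(f, xs):
    """Hypothetical helper: apply f to xs split across all local devices.

    Assumes f treats rows along the leading axis independently and that xs
    is non-empty. Not parajax's implementation; a sketch of the technique.
    """
    n_devices = jax.local_device_count()
    n = xs.shape[0]
    # Round the length up to the next multiple of the device count
    padded_n = -(-n // n_devices) * n_devices
    pad = padded_n - n
    # Pad with copies of the last row so f only ever sees real-looking inputs
    xs_padded = jnp.concatenate([xs, jnp.repeat(xs[-1:], pad, axis=0)])
    # One equally sized shard per device: shape (n_devices, chunk, ...)
    shards = xs_padded.reshape((n_devices, -1) + xs.shape[1:])
    # pmap compiles f and runs one shard on each device in parallel
    ys = jax.pmap(f)(shards)
    # Merge the device axis back into the leading axis, then drop the padding
    return ys.reshape((padded_n,) + ys.shape[2:])[:n]


def square(xs):
    return xs**2


xs = jnp.arange(12_345)
ys = parallel_map(square, xs)
print(ys.shape, ys[:5])
```

Padding with a copy of a real element rather than with zeros keeps f from receiving values it may not handle (a logarithm of zero, for instance). The padded outputs are discarded in any case.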

Installation

pip install parajax

Example

import multiprocessing

import jax
import jax.numpy as jnp
from parajax import parallelize

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())
# ^ Only needed on CPU: expose one JAX device per CPU core. This must run
#   before JAX initializes its backends (e.g., before any array is created).

@parallelize
def square(xs):
    return xs**2

xs = jnp.arange(12_345)
ys = square(xs)

That's it! Calls to square are now automatically parallelized across all available devices.
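If your installed JAX release does not recognize the jax_num_cpu_devices option used above, a widely used alternative is to request several host (CPU) devices through the XLA_FLAGS environment variable before JAX is first imported. This is a general JAX technique, not something the parajax README documents. The sketch assumes that, like the config option, it is all parajax needs in order to see multiple CPU devices. It also lists the devices JAX ends up with, which is a quick way to check either setup.

```python
import multiprocessing
import os

# Alternative to jax.config.update("jax_num_cpu_devices", ...) for the CPU
# backend. The flag is read at backend startup, so set it before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + f" --xla_force_host_platform_device_count={multiprocessing.cpu_count()}"
)

import jax  # noqa: E402  (imported only after XLA_FLAGS is set)

# Devices JAX can see; on a CPU-only machine this should be one per core.
devices = jax.devices()
print(f"{len(devices)} device(s), platform: {devices[0].platform}")
```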

Documentation

For more details, check out the documentation.
