This document is relevant for: Inf2, Trn1, Trn2

JAX Neuron Plugin Setup

The JAX Neuron plugin is a set of modular JAX plugin packages that integrate AWS Trainium and Inferentia machine learning accelerators into JAX as pluggable devices. It includes the following Python packages, all hosted on the AWS Neuron pip repository.

  • libneuronxla: A package containing Neuron’s integration with PJRT, JAX’s runtime interface, built using the PJRT C-API plugin mechanism. Installing this package lets JAX use Trainium and Inferentia natively as devices.

  • jax-neuronx: A package containing Neuron-specific JAX features, such as the Neuron NKI JAX interface. It also serves as a meta-package that provides a tested combination of the jax-neuronx, jax, jaxlib, libneuronxla, and neuronx-cc packages. Using the features provided in jax-neuronx helps you get the full performance out of Trainium and Inferentia.

Note

If you are facing a connectivity issue during the model loading process on a Trn1 instance with Ubuntu, it is likely caused by an Ubuntu limitation with multiple network interfaces. To solve this problem, follow the steps mentioned here.

We strongly recommend launching instances from a DLAMI (AWS Deep Learning AMI), since DLAMIs include the required fix.

Launch the Instance
Install Drivers and Tools

Ubuntu

# Configure Linux for Neuron repository updates
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -

# Update OS packages 
sudo apt-get update -y

# Install OS headers 
sudo apt-get install linux-headers-$(uname -r) -y

# Install git 
sudo apt-get install git -y

# Install Neuron Driver
sudo apt-get install aws-neuronx-dkms=2.* -y

# Install Neuron Runtime 
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Install Neuron Tools 
sudo apt-get install aws-neuronx-tools=2.* -y

# Add the Neuron tools directory to PATH (current shell session)
export PATH=/opt/aws/neuron/bin:$PATH

Amazon Linux 2023

# Configure Linux for Neuron repository updates
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB

# Update OS packages 
sudo yum update -y

# Install OS headers 
sudo yum install kernel-devel-$(uname -r) kernel-headers-$(uname -r) -y

# Install git 
sudo yum install git -y

# Install Neuron Driver
sudo yum install aws-neuronx-dkms-2.* -y

# Install Neuron Runtime 
sudo yum install aws-neuronx-collectives-2.* -y
sudo yum install aws-neuronx-runtime-lib-2.* -y

# Install Neuron Tools 
sudo yum install aws-neuronx-tools-2.* -y

# Add the Neuron tools directory to PATH (current shell session)
export PATH=/opt/aws/neuron/bin:$PATH
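
After exporting PATH on either distribution, you can confirm from Python that the Neuron tools are visible. This is a minimal sketch: it assumes the tools live in /opt/aws/neuron/bin (the directory exported above) and that aws-neuronx-tools provides the neuron-ls utility; the helper names neuron_bin_on_path and find_neuron_ls are hypothetical.

```python
import os
import shutil
from typing import Optional

# Directory exported to PATH in the commands above.
NEURON_BIN = "/opt/aws/neuron/bin"


def neuron_bin_on_path(path_env: Optional[str] = None) -> bool:
    """Return True if NEURON_BIN is one of the PATH entries.

    path_env defaults to the PATH of the current process.
    """
    if path_env is None:
        path_env = os.environ.get("PATH", "")
    # Normalize entries so a trailing slash does not cause a false negative.
    entries = [os.path.normpath(p) for p in path_env.split(os.pathsep) if p]
    return os.path.normpath(NEURON_BIN) in entries


def find_neuron_ls() -> Optional[str]:
    """Return the full path of neuron-ls if it resolves on PATH, else None."""
    return shutil.which("neuron-ls")


if __name__ == "__main__":
    print("Neuron bin on PATH:", neuron_bin_on_path())
    print("neuron-ls:", find_neuron_ls() or "not found")
```

Note that `export` only affects the current shell session and the processes started from it; to keep the setting across logins, add the same line to your shell startup file (for example ~/.bashrc).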

Install the JAX Neuron Plugin

We provide two methods for installing the JAX Neuron plugin. The first is to install the jax-neuronx meta-package from the AWS Neuron pip repository. This method provides a production-ready JAX environment in which jax-neuronx’s major dependencies (jax, jaxlib, libneuronxla, and neuronx-cc) have been thoroughly tested together by the AWS Neuron team and are pinned to those tested versions during installation.

python3 -m pip install "jax-neuronx[stable]" --extra-index-url=https://pip.repos.neuron.amazonaws.com

The second method is to install jax, jaxlib, libneuronxla, and neuronx-cc separately, with jax-neuronx as an optional addition. Because libneuronxla supports a broad range of jaxlib versions through the PJRT C-API mechanism, this method gives you flexibility in choosing jax and jaxlib versions and lets you bring the JAX Neuron plugin into your own JAX environment.

python3 -m pip install jax==0.4.31 jaxlib==0.4.31 jax-neuronx libneuronxla neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com
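
Whichever method you use, it can help to confirm which versions actually ended up in the environment. The sketch below queries installed distributions with the standard-library importlib.metadata module; the helper name installed_versions is hypothetical, and it only reports versions without judging compatibility.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names used in the installation commands above.
PACKAGES = ("jax", "jaxlib", "jax-neuronx", "libneuronxla", "neuronx-cc")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:14} {ver or 'not installed'}")
```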

We can now run a simple JAX program on the Trainium or Inferentia accelerators:

~$ python3 -c 'import jax; print(jax.numpy.multiply(1, 1))'
Platform 'neuron' is experimental and not all JAX functionality may be correctly supported!
.
Compiler status PASS
1
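
Beyond the one-liner above, the following sketch reports which backend JAX selected and runs a small jit-compiled function. It uses only standard JAX APIs (jax.default_backend, jax.devices, jax.jit). We assume that with the plugin installed the backend is the 'neuron' platform named in the warning above; without it, the same script runs on CPU. The function name weighted_sum is illustrative.

```python
import jax
import jax.numpy as jnp

# Report the platform JAX picked. With the Neuron plugin installed this is
# assumed to be 'neuron'; otherwise JAX falls back to another backend such as 'cpu'.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())


@jax.jit
def weighted_sum(x, w):
    # Traced and compiled on the first call for a given input shape/dtype,
    # then the compiled version is reused on later calls with the same signature.
    return jnp.sum(x * w)


x = jnp.arange(4.0)          # [0., 1., 2., 3.]
w = jnp.full(4, 2.0)         # [2., 2., 2., 2.]
result = weighted_sum(x, w)  # 2 * (0 + 1 + 2 + 3) = 12.0
print(float(result))
```

As in the one-liner, the first call may print Neuron compiler output (such as the "Compiler status PASS" line shown above) before the result.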

Compatibility between jaxlib and libneuronxla can be determined from the PJRT C-API version that each package supports. For more information, see the PJRT integration guide.

To determine compatible JAX versions, use the libneuronxla.supported_clients API, which lists the known supported client packages and their versions. Its help text reads:

Help on function supported_clients in module libneuronxla.version:

supported_clients()
    Return a description of supported client (jaxlib, torch-xla, etc.) versions,
    as a list of strings formatted as `"<package> <version> (PJRT C-API <c-api version>)"`.
    For example,
    >>> import libneuronxla
    >>> libneuronxla.supported_clients()
    ['jaxlib 0.4.31 (PJRT C-API 0.54)', 'torch_xla 2.2.0 (PJRT C-API 0.35)', 'torch_xla 2.3.0 (PJRT C-API 0.46)']
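
Because each entry follows the documented `"<package> <version> (PJRT C-API <c-api version>)"` format, it can be parsed into structured records, for example to compare C-API versions. This is a sketch under one extra assumption: that the C-API version always has the `major.minor` numeric form seen in the example. The names ClientInfo and parse_supported_clients are hypothetical.

```python
import re
from typing import List, NamedTuple, Tuple

# Matches entries formatted as "<package> <version> (PJRT C-API <major>.<minor>)".
_CLIENT_RE = re.compile(r"^(\S+) (\S+) \(PJRT C-API (\d+)\.(\d+)\)$")


class ClientInfo(NamedTuple):
    package: str
    version: str
    c_api: Tuple[int, int]  # (major, minor) PJRT C-API version


def parse_supported_clients(entries: List[str]) -> List[ClientInfo]:
    """Parse supported_clients()-style strings; raise ValueError on unknown formats."""
    parsed = []
    for entry in entries:
        m = _CLIENT_RE.match(entry.strip())
        if m is None:
            raise ValueError(f"unrecognized entry: {entry!r}")
        pkg, ver, major, minor = m.groups()
        parsed.append(ClientInfo(pkg, ver, (int(major), int(minor))))
    return parsed


if __name__ == "__main__":
    # Example entries copied from the docstring above; on a Neuron instance,
    # pass libneuronxla.supported_clients() instead.
    example = [
        "jaxlib 0.4.31 (PJRT C-API 0.54)",
        "torch_xla 2.2.0 (PJRT C-API 0.35)",
        "torch_xla 2.3.0 (PJRT C-API 0.46)",
    ]
    for info in parse_supported_clients(example):
        print(info.package, info.version, "-> PJRT C-API", info.c_api)
```

Raising on an unrecognized entry, rather than skipping it, makes a future change in the string format visible instead of silently dropping clients.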

Note that the list of supported client packages and versions covers known versions only and may be incomplete. Other versions, including future jaxlib releases from Google, may also work as long as their PJRT C-API remains compatible with the installed libneuronxla release. For this reason, libneuronxla does not declare any dependency on jaxlib, which gives you more freedom when coordinating jax and libneuronxla installations.
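
Since the list is not exhaustive, a jaxlib version missing from it is best read as untested rather than unsupported. The sketch below reports the installed jaxlib as "known" or "unlisted" by matching the documented string prefix; the helpers classify_jaxlib and installed_jaxlib are hypothetical, and the libneuronxla call is guarded so the script also runs where the plugin is absent.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import List, Optional


def classify_jaxlib(installed: Optional[str], entries: List[str]) -> str:
    """Return 'not installed', 'known', or 'unlisted' for a jaxlib version."""
    if installed is None:
        return "not installed"
    # Entries look like "jaxlib <version> (PJRT C-API <c-api version>)".
    prefix = f"jaxlib {installed} ("
    if any(entry.startswith(prefix) for entry in entries):
        return "known"
    # Not listed: may still work if its PJRT C-API is compatible.
    return "unlisted"


def installed_jaxlib() -> Optional[str]:
    try:
        return version("jaxlib")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    try:
        import libneuronxla  # present only where the plugin is installed
    except ImportError:
        print("libneuronxla is not installed")
    else:
        status = classify_jaxlib(installed_jaxlib(), libneuronxla.supported_clients())
        print("jaxlib status:", status)
```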
