This document is relevant for: Inf2, Trn1, Trn2
JAX Neuron Plugin Setup
The JAX Neuron plugin is a set of modularized JAX plugin packages integrating AWS Trainium and Inferentia machine learning accelerators into JAX as pluggable devices. It includes the following Python packages, all hosted on the AWS Neuron pip repository.
- libneuronxla: A package containing Neuron’s integration into JAX’s PJRT runtime, built using the PJRT C-API plugin mechanism. Installing this package lets you use Trainium and Inferentia natively as JAX devices.
- jax-neuronx: A package containing Neuron-specific JAX features, such as the Neuron NKI JAX interface. It also serves as a meta-package that provides a tested combination of the jax-neuronx, jax, jaxlib, libneuronxla, and neuronx-cc packages. Making proper use of the features provided in jax-neuronx lets you take full advantage of Trainium and Inferentia.
Note
If you encounter a connectivity issue during model loading on a Trn1 instance running Ubuntu, it may be caused by an Ubuntu limitation with multiple network interfaces. To solve this problem, follow the steps mentioned here.
We strongly recommend launching instances from a DLAMI (AWS Deep Learning AMI), because DLAMIs come with the required fix.
Launch the Instance
To launch an instance, follow the instructions at Launch an Amazon EC2 instance. Make sure to select the correct instance type on the EC2 console.
For more information about instance sizes and pricing, see Amazon EC2 Trn1 Instances and Amazon EC2 Inf2 Instances.
Select the Ubuntu Server 22 AMI.
When launching a Trn1 instance, set the primary EBS volume size to at least 512 GB.
After launching the instance, follow the instructions in Connect to your instance to connect to it.
Install Drivers and Tools
Ubuntu
# Configure Linux for Neuron repository updates
. /etc/os-release
sudo tee /etc/apt/sources.list.d/neuron.list > /dev/null <<EOF
deb https://apt.repos.neuron.amazonaws.com ${VERSION_CODENAME} main
EOF
wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | sudo apt-key add -
# Update OS packages
sudo apt-get update -y
# Install OS headers
sudo apt-get install linux-headers-$(uname -r) -y
# Install git
sudo apt-get install git -y
# Install Neuron Driver
sudo apt-get install aws-neuronx-dkms=2.* -y
# Install Neuron Runtime
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y
# Install Neuron Tools
sudo apt-get install aws-neuronx-tools=2.* -y
# Add PATH
export PATH=/opt/aws/neuron/bin:$PATH
Amazon Linux 2023
# Configure Linux for Neuron repository updates
sudo tee /etc/yum.repos.d/neuron.repo > /dev/null <<EOF
[neuron]
name=Neuron YUM Repository
baseurl=https://yum.repos.neuron.amazonaws.com
enabled=1
metadata_expire=0
EOF
sudo rpm --import https://yum.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB
# Update OS packages
sudo yum update -y
# Install OS headers
sudo yum install kernel-devel-$(uname -r) kernel-headers-$(uname -r) -y
# Install git
sudo yum install git -y
# Install Neuron Driver
sudo yum install aws-neuronx-dkms-2.* -y
# Install Neuron Runtime
sudo yum install aws-neuronx-collectives-2.* -y
sudo yum install aws-neuronx-runtime-lib-2.* -y
# Install Neuron Tools
sudo yum install aws-neuronx-tools-2.* -y
# Add PATH
export PATH=/opt/aws/neuron/bin:$PATH
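On either operating system, the `export PATH` line only affects the current shell session. It can therefore be useful to confirm that the Neuron tools directory is actually on `PATH` before continuing. The following is a minimal sketch using only the Python standard library. The helpers `dir_on_path` and `find_neuron_tool` are hypothetical and not part of any Neuron package. The sketch assumes the tools live in `/opt/aws/neuron/bin`, as the commands above configure, and uses `neuron-ls` (one of the utilities in `aws-neuronx-tools`) as the probe.

```python
import os
import shutil
from typing import Optional

# Directory added to PATH by the install steps above (assumption: default install location).
NEURON_BIN = "/opt/aws/neuron/bin"


def dir_on_path(directory: str, path_value: str) -> bool:
    """Return True if `directory` is one of the entries in a PATH-style string."""
    wanted = os.path.normpath(directory)
    return any(
        os.path.normpath(entry) == wanted
        for entry in path_value.split(os.pathsep)
        if entry  # skip empty entries, e.g. from a trailing separator
    )


def find_neuron_tool(tool: str = "neuron-ls") -> Optional[str]:
    """Return the full path of a CLI tool if it resolves on PATH, else None."""
    return shutil.which(tool)


if __name__ == "__main__":
    path_value = os.environ.get("PATH", "")
    print(f"{NEURON_BIN} on PATH: {dir_on_path(NEURON_BIN, path_value)}")
    print(f"neuron-ls resolves to: {find_neuron_tool()}")
```

Run it from the same shell in which you executed `export PATH=...`, because child processes inherit that shell's environment.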
Install the JAX Neuron Plugin
We provide two methods for installing the JAX Neuron plugin. The first is to install the jax-neuronx meta-package from the AWS Neuron pip repository. This method provides a production-ready JAX environment in which jax-neuronx’s major dependencies, namely jax, jaxlib, libneuronxla, and neuronx-cc, have undergone thorough testing by the AWS Neuron team and have their versions pinned during installation.
python3 -m pip install "jax-neuronx[stable]" --extra-index-url=https://pip.repos.neuron.amazonaws.com
The second is to install the jax, jaxlib, libneuronxla, and neuronx-cc packages separately, with jax-neuronx as an optional addition. Because libneuronxla supports a broad range of jaxlib versions through the PJRT C-API mechanism, this method provides flexibility when choosing jax and jaxlib versions, enabling JAX users to bring the JAX Neuron plugin into their own JAX environments.
python3 -m pip install jax==0.4.31 jaxlib==0.4.31 jax-neuronx libneuronxla neuronx-cc==2.* --extra-index-url=https://pip.repos.neuron.amazonaws.com
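Whichever method you choose, it can help to record the versions pip actually resolved for the packages named above. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The helper `installed_versions` is hypothetical, and a `None` entry means only that the distribution is not installed in the current Python environment.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names as they appear in the pip install commands above.
NEURON_JAX_PACKAGES = ("jax-neuronx", "jax", "jaxlib", "libneuronxla", "neuronx-cc")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version string, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(NEURON_JAX_PACKAGES).items():
        print(f"{name:14} {ver or 'not installed'}")
```

After the second installation method, the `jax` and `jaxlib` entries should show the versions you pinned (0.4.31 in the command above).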
We can now run a simple JAX program on the Trainium or Inferentia accelerators:
~$ python3 -c 'import jax; print(jax.numpy.multiply(1, 1))'
Platform 'neuron' is experimental and not all JAX functionality may be correctly supported!
.
Compiler status PASS
1
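The one-liner above runs a single operation. As a slightly larger smoke test, the following sketch lists the devices JAX sees and runs a small `jax.jit`-compiled function. The function name `scaled_dot` is hypothetical. The sketch uses only standard JAX APIs (`jax.default_backend`, `jax.devices`, `jax.jit`, `jax.numpy`). On a Neuron instance, the backend is expected to report the `neuron` platform named in the warning above, and the first call triggers a Neuron compilation. On a machine without the plugin, the same code runs on the CPU backend.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_dot(x, y):
    # Element-wise multiply, reduce to a scalar, then scale.
    # jit traces and compiles this once per input shape/dtype combination.
    return 2.0 * jnp.sum(x * y)


if __name__ == "__main__":
    print("backend:", jax.default_backend())
    print("devices:", jax.devices())
    x = jnp.arange(8, dtype=jnp.float32)  # [0., 1., ..., 7.]
    print(scaled_dot(x, x))  # 2 * (0 + 1 + 4 + ... + 49) = 280.0
```

Calling `scaled_dot` a second time with arrays of the same shape and dtype reuses the compiled program, so only the first call pays the compilation cost.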
Compatibility between the jaxlib and libneuronxla packages can be determined from the PJRT C-API version. For more information, see the PJRT integration guide.
To determine compatible JAX versions, you can use the libneuronxla.supported_clients API to query known supported client packages and their versions.
Help on function supported_clients in module libneuronxla.version:
supported_clients()
Return a description of supported client (jaxlib, torch-xla, etc.) versions,
as a list of strings formatted as `"<package> <version> (PJRT C-API <c-api version>)"`.
For example,
>>> import libneuronxla
>>> libneuronxla.supported_clients()
['jaxlib 0.4.31 (PJRT C-API 0.54)', 'torch_xla 2.2.0 (PJRT C-API 0.35)', 'torch_xla 2.3.0 (PJRT C-API 0.46)']
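Each entry returned by `supported_clients()` follows the documented `"<package> <version> (PJRT C-API <c-api version>)"` format. This means you can parse the entries to check whether a particular `jaxlib` version appears in the known-supported list. The following is a minimal sketch; `parse_supported_clients`, `known_c_api_for`, and `ClientSupport` are hypothetical helpers, not part of `libneuronxla`. The sketch is exercised on the sample output shown above. On a Neuron instance, you would pass `libneuronxla.supported_clients()` instead. A version missing from the list is not necessarily incompatible, as noted below.

```python
import re
from typing import List, NamedTuple, Optional

# Matches entries such as "jaxlib 0.4.31 (PJRT C-API 0.54)".
_ENTRY_RE = re.compile(r"^(?P<package>\S+) (?P<version>\S+) \(PJRT C-API (?P<c_api>[^)]+)\)$")


class ClientSupport(NamedTuple):
    package: str
    version: str
    c_api: str


def parse_supported_clients(entries: List[str]) -> List[ClientSupport]:
    """Parse supported_clients() strings, skipping any that do not match the format."""
    parsed = []
    for entry in entries:
        match = _ENTRY_RE.match(entry.strip())
        if match:
            parsed.append(ClientSupport(**match.groupdict()))
    return parsed


def known_c_api_for(entries: List[str], package: str, version: str) -> Optional[str]:
    """Return the PJRT C-API version listed for package==version, or None if not listed."""
    for client in parse_supported_clients(entries):
        if client.package == package and client.version == version:
            return client.c_api
    return None


if __name__ == "__main__":
    # Sample output copied from the documentation above.
    sample = [
        "jaxlib 0.4.31 (PJRT C-API 0.54)",
        "torch_xla 2.2.0 (PJRT C-API 0.35)",
        "torch_xla 2.3.0 (PJRT C-API 0.46)",
    ]
    print(known_c_api_for(sample, "jaxlib", "0.4.31"))  # 0.54
    print(known_c_api_for(sample, "jaxlib", "0.4.35"))  # None: not listed, may still work
```

Entries that do not match the format are skipped rather than raising an error, so the helper tolerates future additions to the list.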
Note that the list of supported client packages and versions covers known versions only and may be incomplete. More versions could be supported, including Google’s future jaxlib releases, provided the PJRT C-API stays compatible with the current release of libneuronxla. As a result, we avoid specifying any dependency relationship between libneuronxla and jaxlib. This provides more freedom when coordinating jax and libneuronxla installations.