Run Cloud TPU applications in a Docker container
Docker containers make configuring applications easier by combining your code and all needed dependencies in one distributable package. You can run Docker containers within TPU VMs to simplify configuring and sharing your Cloud TPU applications. This document describes how to set up a Docker container for each ML framework supported by Cloud TPU.
Train a TensorFlow model in a Docker container
TPU device
Create a file named Dockerfile in your current directory and paste the following text:
FROM python:3.8
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
RUN git clone https://github.com/tensorflow/models.git
WORKDIR ./models
RUN pip install -r official/requirements.txt
ENV PYTHONPATH=/models
Create a Cloud Storage bucket
gcloud storage buckets create gs://your-bucket-name --location=europe-west4
Create a TPU VM
gcloud compute tpus tpu-vm create your-tpu-name \
  --zone=europe-west4-a \
  --accelerator-type=v2-8 \
  --version=tpu-vm-tf-2.18.0-pjrt
Copy the Dockerfile to your TPU VM
gcloud compute tpus tpu-vm scp ./Dockerfile your-tpu-name: --zone=europe-west4-a
SSH into the TPU VM
gcloud compute tpus tpu-vm ssh your-tpu-name \
  --zone=europe-west4-a
Build the Docker image
sudo docker build -t your-image-name .
Start the Docker container
sudo docker run -ti --rm --net=host --name your-container-name --privileged your-image-name bash
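Optionally, before training you can check that TensorFlow inside the container can see the TPU. This is an extra sanity check, not part of the original procedure, and assumes the libtpu bundled in the image initializes correctly:
python3 -c "
import tensorflow as tf
# Connect to the local TPU and list its logical TPU devices (8 cores on a v2-8).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print(tf.config.list_logical_devices('TPU'))
"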
Set environment variables
export STORAGE_BUCKET=gs://your-bucket-name
export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
export MODEL_DIR=${STORAGE_BUCKET}/resnet-2x
Train ResNet
python3 official/vision/train.py \
  --tpu=local \
  --experiment=resnet_imagenet \
  --mode=train_and_eval \
  --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
  --model_dir=${MODEL_DIR} \
  --params_override="task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*,trainer.train_steps=100"
When the training script completes, make sure you clean up the resources.
- Type exit to exit from the Docker container.
- Type exit to exit from the TPU VM.
- Delete the TPU VM:
  $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
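- Optionally, if you no longer need the Cloud Storage bucket you created earlier, delete the bucket and its contents (this removes everything under gs://your-bucket-name):
  $ gcloud storage rm -r gs://your-bucket-name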
TPU Pod
Create a file named Dockerfile in your current directory and paste the following text:
FROM python:3.8
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.12.0/tensorflow-2.12.0-cp38-cp38-linux_x86_64.whl
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.6.0/libtpu.so -o /lib/libtpu.so
RUN git clone https://github.com/tensorflow/models.git
WORKDIR ./models
RUN pip install -r official/requirements.txt
ENV PYTHONPATH=/models
Create a TPU VM
gcloud compute tpus tpu-vm create your-tpu-name \
  --zone=europe-west4-a \
  --accelerator-type=v3-32 \
  --version=tpu-vm-tf-2.18.0-pod-pjrt
Copy the Dockerfile to your TPU VM
gcloud compute tpus tpu-vm scp ./Dockerfile your-tpu-name: --zone=europe-west4-a
SSH into the TPU VM
gcloud compute tpus tpu-vm ssh your-tpu-name \
  --zone=europe-west4-a
Build the Docker image
sudo docker build -t your-image-name .
Start a Docker container
sudo docker run -ti --rm --net=host --name your-container-name --privileged your-image-name bash
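The training command in the next step reads ${DATA_DIR} and ${MODEL_DIR}. If you have not already set them inside the container, export them as in the single-device instructions (the bucket name below is a placeholder for a Cloud Storage bucket you have created):
export STORAGE_BUCKET=gs://your-bucket-name
export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
export MODEL_DIR=${STORAGE_BUCKET}/resnet-2x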
Train ResNet
python3 official/vision/train.py \
  --tpu=local \
  --experiment=resnet_imagenet \
  --mode=train_and_eval \
  --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
  --model_dir=${MODEL_DIR} \
  --params_override="task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
When the training script completes, make sure you clean up the resources.
- Type exit to exit from the Docker container.
- Type exit to exit from the TPU VM.
- Delete the TPU VM:
  $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
Train a PyTorch model in a Docker container
TPU device
Create Cloud TPU VM
gcloud compute tpus tpu-vm create your-tpu-name \
  --zone=europe-west4-a \
  --accelerator-type=v2-8 \
  --version=tpu-ubuntu2204-base
SSH into TPU VM
gcloud compute tpus tpu-vm ssh your-tpu-name \
  --zone=europe-west4-a
Start a container in the TPU VM using the PyTorch/XLA r2.0 image.
sudo docker run -ti --rm --name your-container-name --privileged gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm bash
Configure TPU runtime
There are two PyTorch/XLA runtime options: PJRT and XRT. We recommend you use PJRT unless you have a reason to use XRT. To learn more about the different runtime configurations, see the PJRT runtime documentation.
PJRT
export PJRT_DEVICE=TPU
XRT
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
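As an optional sanity check (not part of the original steps), you can confirm that PyTorch/XLA can acquire a TPU device with the runtime you just configured:
python3 -c "
import torch_xla.core.xla_model as xm
# Acquires the default XLA (TPU) device; prints something like xla:0 on a v2-8.
print(xm.xla_device())
"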
Clone the PyTorch XLA repo
git clone --recursive https://github.com/pytorch/xla.git
Train ResNet50
python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
When the training script completes, make sure you clean up the resources.
- Type exit to exit from the Docker container.
- Type exit to exit from the TPU VM.
- Delete the TPU VM:
  $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
TPU Pod
When you run PyTorch code on a TPU Pod, you must run your code on all TPU
workers at the same time. One way to do this is to use the
gcloud compute tpus tpu-vm ssh command with the --worker=all and --command
flags. The following procedure shows you how to use a prebuilt PyTorch/XLA
Docker image to make setting up each TPU worker easier.
Create a TPU VM
gcloud compute tpus tpu-vm create your-tpu-name \
  --zone=us-central2-b \
  --accelerator-type=v4-32 \
  --version=tpu-ubuntu2204-base
Add the current user to the docker group
gcloud compute tpus tpu-vm ssh your-tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="sudo usermod -a -G docker $USER"
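Optionally, you can verify that each worker can now run Docker commands without sudo (an extra check, not part of the original procedure):
gcloud compute tpus tpu-vm ssh your-tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="docker ps"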
Run the training script in a container on all TPU workers.
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=us-central2-b \
  --command="docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU gcr.io/tpu-pytorch/xla:r2.0_3.8_tpuvm python /pytorch/xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"
Docker command flags:
- --rm removes the container after its process terminates.
- --privileged exposes the TPU device to the container.
- --net=host binds all of the container's ports to the TPU VM to allow communication between the hosts in the Pod.
- -e sets environment variables.
When the training script completes, make sure you clean up the resources.
Delete the TPU VM using the following command:
$ gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b
Train a JAX model in a Docker container
TPU device
Create the TPU VM
gcloud compute tpus tpu-vm create your-tpu-name \
  --zone=europe-west4-a \
  --accelerator-type=v2-8 \
  --version=tpu-ubuntu2204-base
SSH into TPU VM
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=europe-west4-a
Start Docker daemon in TPU VM
sudo systemctl start docker
Start Docker container
sudo docker run -ti --rm --name your-container-name --privileged --network=host python:3.8 bash
Install JAX
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
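Optionally, verify that JAX can see the TPU cores (an extra check, not part of the original steps; on a v2-8 this should print 8):
python3 -c "import jax; print(jax.device_count())"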
Install FLAX
pip install --upgrade clu
git clone https://github.com/google/flax.git
pip install --user -e flax
Run the FLAX MNIST training script
cd flax/examples/mnist
python3 main.py --workdir=/tmp/mnist \
  --config=configs/default.py \
  --config.learning_rate=0.05 \
  --config.num_epochs=5
When the training script completes, make sure you clean up the resources.
- Type exit to exit from the Docker container.
- Type exit to exit from the TPU VM.
- Delete the TPU VM:
  $ gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a
TPU Pod
When you run JAX code on a TPU Pod, you must run your JAX code on all TPU
workers at the same time. One way to do this is to use the
gcloud compute tpus tpu-vm ssh command with the --worker=all and --command
flags. The following procedure shows you how to create a Docker image to make
setting up each TPU worker easier.
Create a file named Dockerfile in your current directory and paste the following text:
FROM python:3.8
RUN pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install --upgrade clu
RUN git clone https://github.com/google/flax.git
RUN pip install --user -e flax
WORKDIR ./flax/examples/mnist
Build the Docker image
docker build -t your-image-name .
Add a tag to your Docker image before pushing it to the Artifact Registry. For more information on working with Artifact Registry, see Work with container images.
docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
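If Docker on the machine where you build the image is not yet authenticated to Artifact Registry, you may first need to configure a credential helper for the repository's region (the same command the procedure later runs on the TPU workers):
gcloud auth configure-docker europe-west4-docker.pkg.dev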
Push your Docker image to the Artifact Registry
docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
Create a TPU VM
gcloud compute tpus tpu-vm create your-tpu-name \
  --zone=europe-west4-a \
  --accelerator-type=v2-8 \
  --version=tpu-ubuntu2204-base
Pull the Docker image from the Artifact Registry on all TPU workers.
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="sudo usermod -a -G docker ${USER}"
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
Run the container on all TPU workers.
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
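Optionally, before starting training you can confirm that the container on every worker sees its local TPU devices and the full slice (an extra check, not part of the original steps):
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="docker exec your-container-name python3 -c 'import jax; print(jax.process_index(), jax.local_device_count(), jax.device_count())'"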
Run the training script on all TPU workers:
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.learning_rate=0.05 --config.num_epochs=5"
When the training script completes, make sure you clean up the resources.
Shut down the container on all workers:
gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
  --zone=europe-west4-a \
  --command="docker kill your-container-name"
Delete the TPU VM using the following command:
$ gcloud compute tpus tpu-vm delete your-tpu-name \
  --zone=europe-west4-a
What's next
- Cloud TPU Tutorials
- Manage TPUs
- Cloud TPU System Architecture
- Run TensorFlow code on TPU Pod slices
- Run JAX code on TPU Pod slices