Run PyTorch code on TPU Pod slices
PyTorch/XLA requires all TPU VMs to be able to access the model code and data. You can use a startup script to download the software needed to distribute the model data to all TPU VMs.
If you are connecting your TPU VMs to a Virtual Private Cloud (VPC), you must add a firewall rule in your project to allow ingress for ports 8470-8479. For more information about adding firewall rules, see Using firewall rules.
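As a rough illustration of the firewall requirement above, the following sketch assembles a `gcloud compute firewall-rules create` command that allows ingress on TCP ports 8470-8479. The rule name, network name, and source range are hypothetical placeholders, not values from this guide. The sketch only prints the command so you can review it before running it.

```shell
# Hypothetical values (assumptions): replace them with your own VPC settings.
RULE_NAME="allow-tpu-ports"
NETWORK_NAME="my-vpc-network"
SOURCE_RANGE="10.0.0.0/8"

# Assemble the command as a string so it can be reviewed before it is run.
FIREWALL_CMD="gcloud compute firewall-rules create ${RULE_NAME} \
--network=${NETWORK_NAME} \
--direction=INGRESS \
--allow=tcp:8470-8479 \
--source-ranges=${SOURCE_RANGE}"

# Print the command; run it yourself once the values look right.
echo "${FIREWALL_CMD}"
```

Restricting the source range to the address range of your own network is usually preferable to opening the ports to all addresses.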
Set up your environment
In the Cloud Shell, run the following command to make sure you are running the current version of gcloud:
$ gcloud components update
If you need to install gcloud, use the following command:
$ sudo apt install -y google-cloud-sdk
Create some environment variables:
$ export PROJECT_ID=project-id
$ export TPU_NAME=tpu-name
$ export ZONE=us-central2-b
$ export RUNTIME_VERSION=tpu-ubuntu2204-base
$ export ACCELERATOR_TYPE=v4-32
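Because every later command depends on these variables, it can help to confirm that none of them is empty before creating any resources. The following is a minimal sketch. The function name `check_tpu_env` is a hypothetical helper, not part of gcloud.

```shell
# Hypothetical helper: reports which of the variables used in this guide
# are unset or empty. Returns 0 when all are set, 1 otherwise.
check_tpu_env() {
  missing=""
  for var in PROJECT_ID TPU_NAME ZONE RUNTIME_VERSION ACCELERATOR_TYPE; do
    # Indirectly read the variable whose name is stored in $var.
    eval "value=\${${var}:-}"
    [ -n "${value}" ] || missing="${missing} ${var}"
  done
  if [ -n "${missing}" ]; then
    echo "Missing variables:${missing}" >&2
    return 1
  fi
  echo "All TPU environment variables are set."
}
```

Calling `check_tpu_env` in the same shell session where you exported the variables prints either a confirmation or the names that still need a value.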
Create the TPU VM
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}
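If you create the TPU VM from a script, you may want to confirm that it has reached the READY state before continuing. The sketch below polls `gcloud compute tpus tpu-vm describe` for the `state` field. The function names, attempt count, and delay are hypothetical choices for illustration, and the code assumes the environment variables from the previous step are set.

```shell
# Hypothetical helper: prints the current state of the TPU VM.
get_tpu_state() {
  gcloud compute tpus tpu-vm describe "${TPU_NAME}" \
    --zone="${ZONE}" \
    --project="${PROJECT_ID}" \
    --format="value(state)"
}

# Hypothetical helper: polls until the state is READY.
# $1: maximum number of attempts (default 30)
# $2: seconds to wait between attempts (default 10)
wait_for_ready() {
  max_attempts="${1:-30}"
  delay_seconds="${2:-10}"
  attempt=1
  while [ "${attempt}" -le "${max_attempts}" ]; do
    state="$(get_tpu_state)" || state="UNKNOWN"
    if [ "${state}" = "READY" ]; then
      echo "TPU is READY"
      return 0
    fi
    echo "Attempt ${attempt}: state=${state}"
    sleep "${delay_seconds}"
    attempt=$((attempt + 1))
  done
  return 1
}
```

For example, `wait_for_ready 60 5` would check every 5 seconds for up to 60 attempts and return a non-zero status if the TPU VM never becomes READY.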
Configure and run the training script
Add your Compute Engine SSH key to the SSH agent:
ssh-add ~/.ssh/google_compute_engine
Install PyTorch/XLA on all TPU VM workers
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--worker=all \
--command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
Clone XLA on all TPU VM workers
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--worker=all \
--command="git clone -b r2.5 https://github.com/pytorch/xla.git"
Run the training script on all workers
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--worker=all \
--command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py \
--fake_data \
--model=resnet50 \
--num_epochs=1 2>&1 | tee ~/logs.txt"
The training takes about 5 minutes. When it completes, you should see a message similar to the following:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
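Because the training command also writes its output to ~/logs.txt, you can check the result after the run instead of watching the terminal. The following sketch extracts the last reported "Max Accuracy" value from a log file. The function name `extract_max_accuracy` is hypothetical, and the code assumes the log uses the message format shown above and is readable on the machine where you run it (for example, on one of the workers).

```shell
# Hypothetical helper: prints the last "Max Accuracy" percentage found
# in the log file given as $1 (without the trailing % sign).
extract_max_accuracy() {
  grep 'Max Accuracy:' "$1" \
    | tail -n 1 \
    | sed 's/.*Max Accuracy: *\([0-9.]*\)%.*/\1/'
}

# Example usage (not executed here):
#   extract_max_accuracy ~/logs.txt
```

If the log contains no "Max Accuracy" line, the function prints nothing, which can serve as a simple signal that the run did not finish.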
Clean up
When you are done with your TPU VM, follow these steps to clean up your resources.
Disconnect from the Compute Engine instance, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.
Delete your Cloud TPU and Compute Engine resources:
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--zone=${ZONE}
Verify the resources have been deleted by running gcloud compute tpus tpu-vm list. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:
$ gcloud compute tpus tpu-vm list --zone=${ZONE}
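To make this check scriptable, you can test whether the TPU name still appears in the list output. The sketch below is one possible approach. The function name `tpu_is_listed` is hypothetical, and the code assumes the list command is run with `--format="value(name)"` so that it prints one resource name per line, either as a short name or as a full resource path ending in the name.

```shell
# Hypothetical helper: succeeds if the TPU name in $1 appears on stdin,
# either as a whole line or as the last component of a resource path.
tpu_is_listed() {
  grep -Eq -- "(^|/)${1}\$"
}

# Example usage (not executed here):
#   if gcloud compute tpus tpu-vm list --zone="${ZONE}" --format="value(name)" \
#       | tpu_is_listed "${TPU_NAME}"; then
#     echo "TPU ${TPU_NAME} still exists"
#   else
#     echo "TPU ${TPU_NAME} has been deleted"
#   fi
```

Anchoring the match to the start of the line or to a preceding slash keeps a TPU called `tpu-name` from being confused with one called `my-tpu-name`.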