|
| 1 | +name: test-tpu |
| 2 | + |
| 3 | +on: |
| 4 | + pull_request: |
| 5 | + paths: |
| 6 | + - .github/workflows/test-tpu.yml |
| 7 | + - .github/actions/check-tpu/** |
| 8 | + workflow_dispatch: |
| 9 | + |
| 10 | +jobs: |
| 11 | + test-check-tpu-action: |
| 12 | + runs-on: linux.google.tpuv7x.1 |
| 13 | + steps: |
| 14 | + - name: Checkout PyTorch |
| 15 | + uses: actions/checkout@v4 |
| 16 | + |
| 17 | + - name: Setup Linux |
| 18 | + uses: ./.github/actions/setup-linux |
| 19 | + |
| 20 | + - name: Check TPU Availability |
| 21 | + id: check-tpu |
| 22 | + uses: ./.github/actions/check-tpu |
| 23 | + |
| 24 | + - name: Verify TPU was detected |
| 25 | + run: | |
| 26 | + echo "has_tpu output: ${{ steps.check-tpu.outputs.has_tpu }}" |
| 27 | + if [[ "${{ steps.check-tpu.outputs.has_tpu }}" != "true" ]]; then |
| 28 | + echo "ERROR: TPU should have been detected on this runner!" |
| 29 | + exit 1 |
| 30 | + fi |
| 31 | + echo "SUCCESS: TPU detected as expected" |
| 32 | +
|
| 33 | + test-tpu-docker-flags: |
| 34 | + runs-on: linux.google.tpuv7x.1 |
| 35 | + steps: |
| 36 | + - name: Checkout PyTorch |
| 37 | + uses: actions/checkout@v4 |
| 38 | + |
| 39 | + - name: Setup Linux |
| 40 | + uses: ./.github/actions/setup-linux |
| 41 | + |
| 42 | + - name: Check TPU Availability |
| 43 | + id: check-tpu |
| 44 | + uses: ./.github/actions/check-tpu |
| 45 | + |
| 46 | + - name: Setup TPU docker flags |
| 47 | + if: steps.check-tpu.outputs.has_tpu == 'true' |
| 48 | + run: echo "TPU_DOCKER_FLAGS=--privileged --network=host -e PJRT_DEVICE=TPU -e TPU_SKIP_MDS_QUERY -e TPU_TOPOLOGY -e TPU_WORKER_ID -e TPU_TOPOLOGY_WRAP -e TPU_CHIPS_PER_HOST_BOUNDS -e TPU_ACCELERATOR_TYPE -e TPU_RUNTIME_METRICS_PORTS -e TPU_TOPOLOGY_ALT -e HOST_BOUNDS -e TPU_HOST_BOUNDS -e VBAR_CONTROL_SERVICE_URL -e CHIPS_PER_HOST_BOUNDS -e TPU_WORKER_HOSTNAMES" >> "$GITHUB_ENV" |
| 49 | + |
| 50 | + - name: Pull uv docker image |
| 51 | + run: docker pull ghcr.io/astral-sh/uv:python3.12-bookworm |
| 52 | + |
| 53 | + - name: Test JAX TPU in Docker |
| 54 | + run: | |
| 55 | + set -x |
| 56 | + # shellcheck disable=SC2086 |
| 57 | + docker run --rm \ |
| 58 | + ${TPU_DOCKER_FLAGS:-} \ |
| 59 | + ghcr.io/astral-sh/uv:python3.12-bookworm \ |
| 60 | + sh -c 'uv pip install --system "jax[tpu]" && uv run python -c "import jax; print(\"JAX devices:\", jax.devices())"' |
| 61 | +
|
| 62 | + - name: Verify output |
| 63 | + run: echo "If we got here, JAX successfully detected TPU devices!" |
0 commit comments