Skip to content

FIX Fix device detection when array API dispatch is disabled #362

FIX Fix device detection when array API dispatch is disabled

FIX Fix device detection when array API dispatch is disabled #362

Workflow file for this run

name: CUDA GPU

# This workflow is triggered whenever a label is added to a Pull Request.
# The jobs below additionally check that the Pull Request carries the
# 'CUDA CI' label, so in practice it only does work for PRs labeled 'CUDA CI'.
# Because 'labeled' is the only activity type listed, pushing new commits does
# not start a new run on its own.
on:
  pull_request:
    types:
      - labeled
jobs:
  # Build a wheel from the Pull Request on a regular hosted runner and publish
  # it as the 'cibw-wheels' artifact, which the `tests` job installs.
  build_wheel:
    # Skip the job unless the Pull Request carries the 'CUDA CI' label.
    if: contains(github.event.pull_request.labels.*.name, 'CUDA CI')
    runs-on: "ubuntu-latest"
    name: Build wheel for Pull Request
    steps:
      - uses: actions/checkout@v4

      - name: Build wheels
        # NOTE(review): the ref after 'pypa/' looks like it was mangled by
        # e-mail obfuscation in the copy under review; restore the pinned
        # cibuildwheel release tag (pypa/cibuildwheel@<version>) before use.
        uses: pypa/[email protected]
        env:
          # Build only the CPython 3.12 manylinux x86_64 wheel.
          CIBW_BUILD: cp312-manylinux_x86_64
          CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014
          CIBW_BUILD_VERBOSITY: 1
          CIBW_ARCHS: x86_64

      - uses: actions/upload-artifact@v4
        with:
          name: cibw-wheels
          path: ./wheelhouse/*.whl

  # Install the wheel built above on the GPU runner group and run the
  # array API subset of the test suite.
  tests:
    # Skip the job unless the Pull Request carries the 'CUDA CI' label.
    if: contains(github.event.pull_request.labels.*.name, 'CUDA CI')
    needs: [build_wheel]
    runs-on:
      group: cuda-gpu-runner-group
    # Set this high enough so that the tests can comfortably run. We set a
    # timeout to make abusing this workflow less attractive.
    timeout-minutes: 20
    name: Run Array API unit tests
    steps:
      # The install step below expects the wheel under ~/dist/cibw-wheels/.
      - uses: actions/download-artifact@v4
        with:
          pattern: cibw-wheels
          path: ~/dist

      - uses: actions/setup-python@v5
        with:
          # XXX: The 3.12.4 release of Python on GitHub Actions is corrupted:
          # https://github.com/actions/setup-python/issues/886
          python-version: '3.12.3'

      - name: Checkout main repository
        uses: actions/checkout@v4

      # The cache key is derived from the environment creation script and the
      # conda lock file, so changing either one invalidates the cached ~/conda.
      - name: Cache conda environment
        id: cache-conda
        uses: actions/cache@v4
        with:
          path: ~/conda
          key: ${{ runner.os }}-build-${{ hashFiles('build_tools/github/create_gpu_environment.sh') }}-${{ hashFiles('build_tools/github/pylatest_conda_forge_cuda_array-api_linux-64_conda.lock') }}

      # Only (re)create the environment when it was not restored from cache.
      - name: Install miniforge
        if: ${{ steps.cache-conda.outputs.cache-hit != 'true' }}
        run: bash build_tools/github/create_gpu_environment.sh

      - name: Install scikit-learn
        run: |
          source "${HOME}/conda/etc/profile.d/conda.sh"
          conda activate sklearn
          # Use a glob instead of parsing `ls` output: it is robust to
          # unusual file names and only ever matches wheel files.
          pip install "${HOME}"/dist/cibw-wheels/*.whl

      - name: Run array API tests
        run: |
          source "${HOME}/conda/etc/profile.d/conda.sh"
          conda activate sklearn
          python -c "import sklearn; sklearn.show_versions()"
          SCIPY_ARRAY_API=1 pytest --pyargs sklearn -k 'array_api'
        # Run in /home/runner to not load sklearn from the checkout repo
        working-directory: /home/runner