Skip to content

Commit

Permalink
[azure parallel] use device=None rather than device="cpu" to make arr…
Browse files Browse the repository at this point in the history
…ay-api-strict happy
  • Loading branch information
lesteve committed Dec 11, 2024
1 parent aee11aa commit d90dbb9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _single_array_device(array):
# to do np.asarray so that the resulting array will be on the CPU.
or not get_config()["array_api_dispatch"]
):
return "cpu"
return None
else:
return array.device

Expand Down
9 changes: 5 additions & 4 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,10 +1116,11 @@ def check_array_api_input(
try:
np.asarray(X_xp)
np.asarray(y_xp)
# TODO For some reason there are a few errors with array-api-strict.
# Probably not worth investigating for now, since using
# array-api-strict with array API disabled does not seem a very
# relevant use case.
# TODO There are a few errors in SearchCV with array-api-strict because
# we end up doing X[train_indices] where X is a array-api-strict array
# and indices a numpy array. Probably not worth investigating for now,
# since using array-api-strict with array API disabled does not seem a
# very relevant.
numpy_asarray_works = xp.__name__ != "array_api_strict"

except TypeError:
Expand Down
20 changes: 10 additions & 10 deletions sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,21 +277,21 @@ def __init__(self, device_name):
# early for different devices would prevent the np.asarray conversion to
# happen. For example, `r2_score(np.ones(5), torch.ones(5))` should work
# fine with array API disabled.
assert device(Array("cpu"), Array("mygpu")) == "cpu"
assert device(Array("cpu"), Array("mygpu")) is None

# Test raise if on different devices and array API dispatch is enabled
err_msg = "Input arrays use different devices: cpu, mygpu"
with config_context(array_api_dispatch=True):
with pytest.raises(ValueError, match=err_msg):
device(Array("cpu"), Array("mygpu"))

# Test expected value is returned otherwise
array1 = Array("device")
array2 = Array("device")
# Test expected value is returned otherwise
array1 = Array("device")
array2 = Array("device")

assert array1.device == device(array1)
assert array1.device == device(array1, array2)
assert array1.device == device(array1, array1, array2)
assert array1.device == device(array1)
assert array1.device == device(array1, array2)
assert array1.device == device(array1, array1, array2)


# TODO: add cupy to the list of libraries once the the following upstream issue
Expand Down Expand Up @@ -560,7 +560,7 @@ def test_get_namespace_and_device():
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
assert namespace is get_namespace(some_numpy_array)[0]
assert not is_array_api
assert device == "cpu"
assert device is None

# Otherwise, expose the torch namespace and device via array API compat
# wrapper.
Expand Down Expand Up @@ -628,8 +628,8 @@ def test_sparse_device(csr_container, dispatch):
try:
with config_context(array_api_dispatch=dispatch):
assert device(a, b) is None
assert device(a, numpy.array([1])) == "cpu"
assert device(a, numpy.array([1])) is None
assert get_namespace_and_device(a, b)[2] is None
assert get_namespace_and_device(a, numpy.array([1]))[2] == "cpu"
assert get_namespace_and_device(a, numpy.array([1]))[2] is None
except ImportError:
raise SkipTest("array_api_compat is not installed")

0 comments on commit d90dbb9

Please sign in to comment.