Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX Fix device detection when array API dispatch is disabled #30454

3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.metrics/30454.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Fix a regression where calling a scikit-learn metric on PyTorch CPU tensors
  would raise an error (with array API dispatch disabled, which is the default).
By :user:`Loïc Estève <lesteve>`
34 changes: 34 additions & 0 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,40 @@
if isinstance(multioutput, np.ndarray):
metric_kwargs["multioutput"] = xp.asarray(multioutput, device=device)

# When array API dispatch is disabled, and np.asarray works (for example PyTorch
# with CPU device), calling the metric function with such numpy compatible inputs
# should work (albeit by implicitly converting to numpy arrays instead of
# dispatching to the array library).
try:
np.asarray(a_xp)
np.asarray(b_xp)
numpy_as_array_works = True
except TypeError:

Check warning on line 1828 in sklearn/metrics/tests/test_common.py

View check run for this annotation

Codecov / codecov/patch

sklearn/metrics/tests/test_common.py#L1828

Added line #L1828 was not covered by tests
# PyTorch with CUDA device and CuPy raise TypeError consistently.
# Exception type may need to be updated in the future for other
# libraries.
numpy_as_array_works = False

Check warning on line 1832 in sklearn/metrics/tests/test_common.py

View check run for this annotation

Codecov / codecov/patch

sklearn/metrics/tests/test_common.py#L1832

Added line #L1832 was not covered by tests

if numpy_as_array_works:
metric_xp = metric(a_xp, b_xp, **metric_kwargs)
assert_allclose(
metric_xp,
metric_np,
atol=_atol_for_type(dtype_name),
)
metric_xp_mixed_1 = metric(a_np, b_xp, **metric_kwargs)
assert_allclose(
metric_xp_mixed_1,
metric_np,
atol=_atol_for_type(dtype_name),
)
metric_xp_mixed_2 = metric(a_xp, b_np, **metric_kwargs)
assert_allclose(
metric_xp_mixed_2,
metric_np,
atol=_atol_for_type(dtype_name),
)

with config_context(array_api_dispatch=True):
metric_xp = metric(a_xp, b_xp, **metric_kwargs)

Expand Down
13 changes: 10 additions & 3 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,17 @@ def _check_array_api_dispatch(array_api_dispatch):

def _single_array_device(array):
"""Hardware device where the array data resides on."""
if isinstance(array, (numpy.ndarray, numpy.generic)) or not hasattr(
array, "device"
if (
isinstance(array, (numpy.ndarray, numpy.generic))
or not hasattr(array, "device")
# When array API dispatch is disabled, we expect the scikit-learn code
# to use np.asarray so that the resulting NumPy array will implicitly use the
# CPU. In this case, scikit-learn should stay as device neutral as possible,
# hence the use of `device=None` which is accepted by all libraries, before
# and after the expected conversion to NumPy via np.asarray.
or not get_config()["array_api_dispatch"]
):
return "cpu"
return None
else:
return array.device

Expand Down
34 changes: 34 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,40 @@
"transform",
)

try:
np.asarray(X_xp)
np.asarray(y_xp)
# TODO There are a few errors in SearchCV with array-api-strict because
# we end up doing X[train_indices] where X is an array-api-strict array
# and train_indices is a numpy array. array-api-strict insists
# train_indices should be an array-api-strict array. On the other hand,
# all the array API libraries (PyTorch, jax, CuPy) accept indexing with a
# numpy array. This is probably not worth doing anything about for
# now since array-api-strict seems a bit too strict ...
numpy_asarray_works = xp.__name__ != "array_api_strict"

except TypeError:

Check warning on line 1128 in sklearn/utils/estimator_checks.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/estimator_checks.py#L1128

Added line #L1128 was not covered by tests
# PyTorch with CUDA device and CuPy raise TypeError consistently.
# Exception type may need to be updated in the future for other
# libraries.
numpy_asarray_works = False

Check warning on line 1132 in sklearn/utils/estimator_checks.py

View check run for this annotation

Codecov / codecov/patch

sklearn/utils/estimator_checks.py#L1132

Added line #L1132 was not covered by tests

if numpy_asarray_works:
# In this case, array_api_dispatch is disabled and we rely on np.asarray
# being called to convert the non-NumPy inputs to NumPy arrays when needed.
est_fitted_with_as_array = clone(est).fit(X_xp, y_xp)
# We only do a smoke test for now, in order to avoid complicating the
# test function even further.
for method_name in methods:
method = getattr(est_fitted_with_as_array, method_name, None)
if method is None:
continue

if method_name == "score":
method(X_xp, y_xp)
else:
method(X_xp)

for method_name in methods:
method = getattr(est, method_name, None)
if method is None:
Expand Down
33 changes: 21 additions & 12 deletions sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def test_device_none_if_no_input():
assert device(None, "name") is None


@skip_if_array_api_compat_not_configured
def test_device_inspection():
class Device:
def __init__(self, name):
Expand All @@ -273,18 +274,26 @@ def __init__(self, device_name):
with pytest.raises(TypeError):
hash(Array("device").device)

# Test raise if on different devices
# If array API dispatch is disabled the device should be ignored. Erroring
# early for different devices would prevent the np.asarray conversion from
# happening. For example, `r2_score(np.ones(5), torch.ones(5))` should work
# fine with array API disabled.
assert device(Array("cpu"), Array("mygpu")) is None

# Test that ValueError is raised if on different devices and array API dispatch is
# enabled.
err_msg = "Input arrays use different devices: cpu, mygpu"
with pytest.raises(ValueError, match=err_msg):
device(Array("cpu"), Array("mygpu"))
with config_context(array_api_dispatch=True):
with pytest.raises(ValueError, match=err_msg):
device(Array("cpu"), Array("mygpu"))

# Test expected value is returned otherwise
array1 = Array("device")
array2 = Array("device")
# Test expected value is returned otherwise
array1 = Array("device")
array2 = Array("device")

assert array1.device == device(array1)
assert array1.device == device(array1, array2)
assert array1.device == device(array1, array1, array2)
assert array1.device == device(array1)
assert array1.device == device(array1, array2)
assert array1.device == device(array1, array1, array2)


# TODO: add cupy to the list of libraries once the following upstream issue
Expand Down Expand Up @@ -553,7 +562,7 @@ def test_get_namespace_and_device():
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
assert namespace is get_namespace(some_numpy_array)[0]
assert not is_array_api
assert device.type == "cpu"
assert device is None

# Otherwise, expose the torch namespace and device via array API compat
# wrapper.
Expand Down Expand Up @@ -621,8 +630,8 @@ def test_sparse_device(csr_container, dispatch):
try:
with config_context(array_api_dispatch=dispatch):
assert device(a, b) is None
assert device(a, numpy.array([1])) == "cpu"
assert device(a, numpy.array([1])) is None
assert get_namespace_and_device(a, b)[2] is None
assert get_namespace_and_device(a, numpy.array([1]))[2] == "cpu"
assert get_namespace_and_device(a, numpy.array([1]))[2] is None
except ImportError:
raise SkipTest("array_api_compat is not installed")
Loading