ENH Array API support for confusion_matrix #30440

Open
wants to merge 10 commits into
base: main

Conversation

Contributor

@StefanieSenger StefanieSenger commented Dec 9, 2024

Reference Issues/PRs

towards #26024

What does this implement/fix? Explain your changes.

This PR aims to add Array API support to confusion_matrix(). I have run the CUDA tests on Colab, and they pass too.

@OmarManzoor @ogrisel @lesteve: do you want to have a look?


github-actions bot commented Dec 9, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 49f75b7. Link to the linter CI: here

-     return np.zeros((n_labels, n_labels), dtype=int)
- elif len(np.intersect1d(y_true, labels)) == 0:
+     return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_)
+ elif not xp.isin(labels, y_true).any():
Contributor Author

This line is currently only tested for ndarrays.

xp.isin() does not exist in array_api_strict, so I am trying to find an alternative that also works under the strict definition.

Member

Can you use sklearn.utils._array_api._isin? Looking at the status of isin in the array API, I found data-apis/array-api#854, which mentions that scikit-learn has an implementation of it.
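
For readers following along, here is a minimal sketch of an isin-style membership test built only from functions that are part of the array API standard (reshape, elementwise ==, any). It is an illustration, not the private scikit-learn helper named above: `isin_sketch` is a hypothetical name, NumPy stands in for the namespace `xp`, and both inputs are assumed to be one-dimensional. Note that it builds an intermediate boolean matrix of size len(element) × len(test_elements).

```python
import numpy as np


def isin_sketch(element, test_elements, xp=np):
    """Return a boolean mask: which entries of `element` occur in `test_elements`.

    Hypothetical helper for illustration; assumes 1-D inputs.
    """
    # Compare every entry of `element` (as a column) with every entry of
    # `test_elements` (as a row); matches[i, j] is True when they are equal.
    matches = xp.reshape(element, (-1, 1)) == xp.reshape(test_elements, (1, -1))
    # An entry is "in" the test set if it matches at least one test element.
    return xp.any(matches, axis=1)


labels = np.asarray([0, 1, 2])
y_true = np.asarray([1, 1, 5])
mask = isin_sketch(labels, y_true)
# The branch discussed above fires when no label occurs in y_true at all.
no_label_present = not bool(np.any(mask))
```

Using the function form `xp.any(...)` rather than the `.any()` method matters here, because array objects in the standard are not required to expose that method.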

Contributor Author

Oh yes, I can! ✨
Thanks for the suggestion.

Member

The fact that it was not failing with xp.isin suggests that additional tests are needed to cover this part of the code with array API (non-NumPy) inputs.

Contributor Author

I will add a test.

Comment on lines 377 to 384
  if need_index_conversion:
      label_to_ind = {y: x for x, y in enumerate(labels)}
-     y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
-     y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
+     y_pred = xp.asarray(
+         [label_to_ind.get(x, n_labels + 1) for x in y_pred], device=device_
+     )
+     y_true = xp.asarray(
+         [label_to_ind.get(x, n_labels + 1) for x in y_true], device=device_
+     )
Contributor Author

@StefanieSenger StefanieSenger Dec 9, 2024

This code block inside the if need_index_conversion condition is only tested for ndarrays because of the way our tests are written. It should work for the other array libraries that we currently integrate, but I feel we should actually test this part of the code.

Contributor Author

@StefanieSenger StefanieSenger Dec 10, 2024

It fails in label_to_ind = {y: x for x, y in enumerate(labels)} for array_api_strict, because these array elements are not hashable.

I will try to fix this (possibly by refactoring, no spoilers please) and add a test.
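
One possible direction, sketched here purely as an illustration and not as what this PR ends up doing: the label-to-index conversion can avoid the dict (and therefore hashing of array elements) by comparing `y` against `labels` with broadcasting. `labels_to_indices_sketch` is a hypothetical name, NumPy stands in for `xp`, the inputs are assumed to be 1-D, and `labels` is assumed to contain unique values. Values missing from `labels` are mapped to `n_labels + 1`, matching the default used in the code above.

```python
import numpy as np


def labels_to_indices_sketch(y, labels, xp=np):
    """Map each value in `y` to its position in `labels`.

    Values that do not occur in `labels` are mapped to n_labels + 1.
    Hypothetical helper for illustration; assumes 1-D inputs and unique labels.
    """
    n_labels = labels.shape[0]
    # matches[i, j] is True where y[i] equals labels[j].
    matches = xp.reshape(y, (-1, 1)) == xp.reshape(labels, (1, -1))
    # Convert the boolean matrix to integers before calling argmax.
    as_int = xp.where(matches, xp.asarray(1), xp.asarray(0))
    # Position of the first (and, for unique labels, only) match per row.
    positions = xp.argmax(as_int, axis=1)
    # Rows without any match get the out-of-range sentinel.
    found = xp.any(matches, axis=1)
    sentinel = xp.full_like(positions, n_labels + 1)
    return xp.where(found, positions, sentinel)


labels = np.asarray([10, 20, 30])
y = np.asarray([30, 10, 99, 20])
indices = labels_to_indices_sketch(y, labels)
```

The trade-off is an intermediate matrix of size len(y) × len(labels), which is acceptable only when the number of labels is small.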

Comment on lines +408 to +409
+ for true, pred, weight in zip(y_true, y_pred, sample_weight):
+     cm[true, pred] += weight
Contributor Author

@StefanieSenger StefanieSenger Dec 9, 2024

Is this performant enough? I think it could be, because we are mostly dealing with small matrices at this point. But finding another way might be better. I am not sure how to do this with the tools available in array_api_strict though.
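
As a point of comparison, the accumulation can in principle be written without a Python loop and without advanced indexing, by building one-hot encodings through comparisons and combining them with a matrix product. The following is only a sketch under several assumptions: `confusion_via_matmul_sketch` is a hypothetical name, NumPy stands in for `xp`, `y_true` and `y_pred` already hold integer indices in the range [0, n_labels), `sample_weight` is floating point, and device placement is ignored. It needs two intermediate arrays of size n_samples × n_labels.

```python
import numpy as np


def confusion_via_matmul_sketch(y_true, y_pred, sample_weight, n_labels, xp=np):
    """Weighted confusion matrix from one-hot encodings and one matmul.

    Hypothetical helper for illustration; assumes integer class indices
    and a floating-point `sample_weight`.
    """
    classes = xp.arange(n_labels)
    one = xp.asarray(1.0)
    zero = xp.asarray(0.0)
    # true_onehot has shape (n_labels, n_samples): entry [k, i] is 1 if y_true[i] == k.
    true_onehot = xp.where(
        xp.reshape(classes, (-1, 1)) == xp.reshape(y_true, (1, -1)), one, zero
    )
    # pred_onehot has shape (n_samples, n_labels): entry [i, k] is 1 if y_pred[i] == k.
    pred_onehot = xp.where(
        xp.reshape(y_pred, (-1, 1)) == xp.reshape(classes, (1, -1)), one, zero
    )
    # Scale each sample's row by its weight, then sum over samples via matmul.
    weighted = pred_onehot * xp.reshape(sample_weight, (-1, 1))
    return xp.matmul(true_onehot, weighted)


y_true = np.asarray([0, 1, 1, 2])
y_pred = np.asarray([0, 1, 2, 2])
sample_weight = np.asarray([1.0, 2.0, 1.0, 0.5])
cm = confusion_via_matmul_sketch(y_true, y_pred, sample_weight, n_labels=3)
```

Whether this is actually faster than converting to NumPy would need to be measured; it trades the loop for extra memory.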

Member

I have the feeling that this loop will kill any computational benefit of array API support. We might as well ensure that y_true and y_pred are numpy arrays using _convert_to_numpy and rely on the coo_matrix trick instead. This would keep the code simpler.

That being said, I think convenience array API support for classification metrics that rely on confusion matrix internally is useful as discussed in #30439 (comment).

Contributor

I agree that Python loops do not go well with GPUs. However, there doesn't seem to be an alternative within the array API, because it doesn't support any sort of advanced indexing.

So either we use the loop, if we insist on following the array API, or we simply keep the original code and use _convert_to_numpy as @ogrisel suggested.
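
To make the two options concrete, here is a small sketch that compares the per-sample loop with a vectorized accumulation once the inputs are NumPy arrays. It assumes the conversion to NumPy has already happened (the sketch does not call _convert_to_numpy), and it uses `np.add.at` as a NumPy-only stand-in for the coo_matrix trick so that no SciPy import is needed. Both function names are hypothetical.

```python
import numpy as np


def confusion_loop_sketch(y_true, y_pred, sample_weight, n_labels):
    """Accumulate the confusion matrix one sample at a time (the loop variant)."""
    cm = np.zeros((n_labels, n_labels), dtype=np.float64)
    for true, pred, weight in zip(y_true, y_pred, sample_weight):
        cm[true, pred] += weight
    return cm


def confusion_vectorized_sketch(y_true, y_pred, sample_weight, n_labels):
    """Accumulate all samples at once with unbuffered in-place addition."""
    cm = np.zeros((n_labels, n_labels), dtype=np.float64)
    # np.add.at adds repeatedly at duplicate (true, pred) index pairs,
    # which plain fancy-index assignment would not do.
    np.add.at(cm, (y_true, y_pred), sample_weight)
    return cm


y_true = np.asarray([0, 1, 1, 2, 2, 2])
y_pred = np.asarray([0, 1, 2, 2, 2, 0])
sample_weight = np.ones(6)
cm_loop = confusion_loop_sketch(y_true, y_pred, sample_weight, n_labels=3)
cm_vectorized = confusion_vectorized_sketch(y_true, y_pred, sample_weight, n_labels=3)
```

Both variants produce the same matrix; the difference is that the loop runs n_samples Python-level iterations, each of which would touch the device individually for a GPU-backed array.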

else:
cm = xp.zeros((n_labels, n_labels), dtype=dtype, device=device_)
for true, pred, weight in zip(y_true, y_pred, sample_weight):
cm[true, pred] += weight

with np.errstate(all="ignore"):
Contributor Author

I did a quick search to check whether any of the other array libraries (besides NumPy) also raise these warnings for division by zero. It doesn't seem that they do.

But just to be sure: is there any need to handle these warnings for any other array library?

Member

I don't think there is any standardization of warnings and exceptions in the array API standard at this point, unfortunately:

https://data-apis.org/array-api/latest/design_topics/exceptions.html
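
For reference, this is the NumPy behavior that the np.errstate guard addresses, shown as a small sketch with made-up data. Normalizing a confusion matrix whose row sums include zero triggers a 0/0 division; NumPy would normally emit a RuntimeWarning, and np.errstate(all="ignore") suppresses it. The sketch only demonstrates NumPy; it makes no claim about how other array libraries behave.

```python
import warnings

import numpy as np

# Second row sums to zero, so row-normalization divides 0 by 0 there.
cm = np.asarray([[2.0, 1.0], [0.0, 0.0]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with np.errstate(all="ignore"):
        normalized = cm / cm.sum(axis=1, keepdims=True)

# The undefined entries come out as NaN and can be replaced afterwards.
has_nan = bool(np.isnan(normalized).any())
cleaned = np.nan_to_num(normalized)
```

Since np.errstate only affects NumPy's own floating-point error handling, leaving it in place should be harmless for other namespaces, although that is an assumption rather than something the standard guarantees.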

@ogrisel ogrisel added the CUDA CI label Dec 9, 2024
@github-actions github-actions bot removed the CUDA CI label Dec 9, 2024
Contributor

@OmarManzoor OmarManzoor left a comment

Thanks for the PR @StefanieSenger

Contributor Author

@StefanieSenger StefanieSenger left a comment

Thank you for reviewing, @OmarManzoor.
I have implemented your suggestions. Would you mind having another look?

(I am still working on it, though: I found a problem. So no rush.)

def _nan_to_num(array, xp=None):
    """Substitutes NaN values with 0 and inf values with the maximum or minimum
    numbers available for the dtype respectively; like np.nan_to_num."""
    if xp is None:
Contributor

We don't really need this check. get_namespace already handles the case where xp is passed in and simply returns it.

Suggested change
- if xp is None:
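
The reasoning behind the suggestion can be illustrated with a stand-in. `get_namespace_sketch` below is hypothetical and only mimics the one behavior described in the comment (return the namespace that was passed in); it is not scikit-learn's get_namespace, and the boolean it returns is a placeholder. With such a function, the guarded and unguarded call sites resolve to the same namespace.

```python
import types

import numpy as np


def get_namespace_sketch(array, xp=None):
    """Hypothetical stand-in: hand back the caller's namespace if one was given."""
    if xp is not None:
        return xp, True  # placeholder flag, not meaningful in this sketch
    return np, False  # fallback used only by this sketch


def resolve_with_guard(array, xp=None):
    # Variant with the extra `if xp is None:` check.
    if xp is None:
        xp, _ = get_namespace_sketch(array, xp=xp)
    return xp


def resolve_without_guard(array, xp=None):
    # Variant after applying the suggestion: always delegate.
    xp, _ = get_namespace_sketch(array, xp=xp)
    return xp


data = np.asarray([1.0, 2.0])
other_namespace = types.SimpleNamespace(name="other")  # fake namespace for the demo
same_when_given = (
    resolve_with_guard(data, xp=other_namespace)
    is resolve_without_guard(data, xp=other_namespace)
)
same_when_missing = resolve_with_guard(data) is resolve_without_guard(data)
```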

Comment on lines 1109 to 1110
if xp is None:
    xp, _ = get_namespace(array, xp=xp)
Member

Suggested change
- if xp is None:
-     xp, _ = get_namespace(array, xp=xp)
+ xp, _ = get_namespace(array, xp=xp)
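
For context on what the helper's docstring describes, here is a minimal sketch of a nan_to_num-style function written with operations from the array API standard (isnan, where, finfo, full_like). It is not the PR's implementation: `nan_to_num_sketch` is a hypothetical name, NumPy stands in for `xp`, and the input is assumed to have a floating-point dtype (finfo is not defined for integers).

```python
import numpy as np


def nan_to_num_sketch(array, xp=np):
    """Replace NaN with 0 and +/-inf with the largest/smallest finite value.

    Hypothetical helper for illustration; assumes a floating-point array.
    """
    info = xp.finfo(array.dtype)
    # NaN becomes zero.
    array = xp.where(xp.isnan(array), xp.zeros_like(array), array)
    # Positive infinity becomes the largest finite value of the dtype.
    array = xp.where(array == xp.inf, xp.full_like(array, info.max), array)
    # Negative infinity becomes the smallest (most negative) finite value.
    array = xp.where(array == -xp.inf, xp.full_like(array, info.min), array)
    return array


values = np.asarray([np.nan, np.inf, -np.inf, 1.5])
result = nan_to_num_sketch(values)
```

On NumPy input the result agrees with np.nan_to_num using its default arguments, which is the behavior the docstring refers to.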

4 participants