Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Add array api support for precision, recall and fbeta_score #30395

Merged
merged 6 commits into from
Jan 2, 2025

Conversation

OmarManzoor
Copy link
Contributor

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

  • Adds add api support for precision, recall and fbeta_score

Any other comments?

CC: @ogrisel @adrinjalali

Copy link

github-actions bot commented Dec 3, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 0b2d3e6. Link to the linter CI: here

Copy link
Member

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just one minor suggestion:

@betatim
Copy link
Member

betatim commented Dec 5, 2024

LGTM, just one question about removing the assert statement.

@OmarManzoor
Copy link
Contributor Author

@ogrisel Do you think we can merge this PR?

Comment on lines 1874 to 1876
denom = beta2 * xp.asarray(
true_sum, dtype=max_float_type, device=device_
) + xp.asarray(pred_sum, dtype=max_float_type, device=device_)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't true_sum and pred_sum already be on the device? I don't like that a simple multiplication and sum has become such a long piece of code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array api strict is the only one which causes such issues. They require that the arrays be of the same type. Over here since beta2 is a float it doesn't accept ints and raises the following error:

TypeError: array_api_strict.float64 and array_api_strict.int64 cannot be type promoted together

Anyways I made some changes to only cast it to the correct dtype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I verified that the CUDA tests work fine on a kaggle kernel

@adrinjalali adrinjalali merged commit e5b9dff into scikit-learn:main Jan 2, 2025
30 checks passed
@OmarManzoor OmarManzoor deleted the array_api_prec_recall branch January 3, 2025 05:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants