Skip to content

Conversation

@OmarManzoor
Copy link
Contributor

Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

  • Adds add api support for precision, recall and fbeta_score

Any other comments?

CC: @ogrisel @adrinjalali

@github-actions
Copy link

github-actions bot commented Dec 3, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 0b2d3e6. Link to the linter CI: here

Copy link
Member

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just one minor suggestion:

@betatim
Copy link
Member

betatim commented Dec 5, 2024

LGTM, just one question about removing the assert statement.

@OmarManzoor
Copy link
Contributor Author

@ogrisel Do you think we can merge this PR?

Comment on lines 1874 to 1876
denom = beta2 * xp.asarray(
true_sum, dtype=max_float_type, device=device_
) + xp.asarray(pred_sum, dtype=max_float_type, device=device_)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't true_sum and pred_sum already be on the device? I don't like that a simple multiplication and sum has become such a long piece of code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array api strict is the only one which causes such issues. They require that the arrays be of the same type. Over here since beta2 is a float it doesn't accept ints and raises the following error:

TypeError: array_api_strict.float64 and array_api_strict.int64 cannot be type promoted together

Anyways I made some changes to only cast it to the correct dtype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I verified that the CUDA tests work fine on a kaggle kernel

@adrinjalali adrinjalali merged commit e5b9dff into scikit-learn:main Jan 2, 2025
30 checks passed
@OmarManzoor OmarManzoor deleted the array_api_prec_recall branch January 3, 2025 05:57
check_array_api_multilabel_classification_metric,
],
fbeta_score: [
check_array_api_multiclass_classification_metric,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@OmarManzoor naive question - why not include check_array_api_binary_classification_metric here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it works for the binary case I think we can include it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah thanks! I was just wondering if it was left out for a specific reason!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants