Skip to content

Commit

Permalink
FIX roc_auc_curve: Return np.nan instead of 0.0 for single class (#30103)
Browse files Browse the repository at this point in the history

Co-authored-by: Guillaume Lemaitre <[email protected]>
  • Loading branch information
janezd and glemaitre authored Oct 29, 2024
1 parent 087d8b7 commit fff920e
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions doc/whats_new/upcoming_changes/sklearn.metrics/27412.fix.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
- :func:`metrics.roc_auc_score` will now correctly return 0.0 and
- :func:`metrics.roc_auc_score` will now correctly return np.nan and
warn user if only one class is present in the labels.
By :user:`Gleb Levitski <glevv>`
By :user:`Gleb Levitski <glevv>` and :user:`Janez Demšar <janezd>`
3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.metrics/30013.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- :func:`metrics.roc_auc_score` will now correctly return np.nan and
warn user if only one class is present in the labels.
By :user:`Gleb Levitski <glevv>` and :user:`Janez Demšar <janezd>`
5 changes: 2 additions & 3 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,11 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
warnings.warn(
(
"Only one class is present in y_true. ROC AUC score "
"is not defined in that case. The score is set to "
"0.0."
"is not defined in that case."
),
UndefinedMetricWarning,
)
return 0.0
return np.nan

fpr, tpr, _ = roc_curve(y_true, y_score, sample_weight=sample_weight)
if max_fpr is None or max_fpr == 1:
Expand Down
5 changes: 3 additions & 2 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from functools import partial
from inspect import signature
from itertools import chain, permutations, product
Expand Down Expand Up @@ -843,9 +844,9 @@ def test_format_invariance_with_1d_vectors(name):
):
if "roc_auc" in name:
# for consistency between the `roc_curve` and `roc_auc_score`
# 0.0 is returned and an `UndefinedMetricWarning` is raised
# np.nan is returned and an `UndefinedMetricWarning` is raised
with pytest.warns(UndefinedMetricWarning):
assert metric(y1_row, y2_row) == pytest.approx(0.0)
assert math.isnan(metric(y1_row, y2_row))
else:
with pytest.raises(ValueError):
metric(y1_row, y2_row)
Expand Down
7 changes: 5 additions & 2 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import re

import numpy as np
Expand Down Expand Up @@ -370,7 +371,8 @@ def test_roc_curve_toydata():
"ROC AUC score is not defined in that case."
)
with pytest.warns(UndefinedMetricWarning, match=expected_message):
roc_auc_score(y_true, y_score)
auc = roc_auc_score(y_true, y_score)
assert math.isnan(auc)

# case with no negative samples
y_true = [1, 1]
Expand All @@ -388,7 +390,8 @@ def test_roc_curve_toydata():
"ROC AUC score is not defined in that case."
)
with pytest.warns(UndefinedMetricWarning, match=expected_message):
roc_auc_score(y_true, y_score)
auc = roc_auc_score(y_true, y_score)
assert math.isnan(auc)

# Multi-label classification task
y_true = np.array([[0, 1], [0, 1]])
Expand Down

0 comments on commit fff920e

Please sign in to comment.