API get_scorer returns a copy and introduce get_scorer_names (scikit-…
adrinjalali authored Mar 19, 2022
1 parent f9321be commit 7dc97a3
Showing 6 changed files with 83 additions and 21 deletions.
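In short, the commit deprecates the public `SCORERS` dictionary in favor of a new `metrics.get_scorer_names` function, and makes `get_scorer` return a deep copy of the registered scorer. A minimal sketch of the resulting behavior (not part of the diff itself):

>>> from sklearn.metrics import get_scorer, get_scorer_names
>>> "roc_auc" in get_scorer_names()
True
>>> # get_scorer now deep-copies, so two calls yield distinct objects
>>> get_scorer("roc_auc") is get_scorer("roc_auc")
False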
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -933,6 +933,7 @@ details.

metrics.check_scoring
metrics.get_scorer
+ metrics.get_scorer_names
metrics.make_scorer

Classification metrics
15 changes: 8 additions & 7 deletions doc/modules/model_evaluation.rst
@@ -115,14 +115,15 @@ Usage examples:
>>> model = svm.SVC()
>>> cross_val_score(model, X, y, cv=5, scoring='wrong_choice')
Traceback (most recent call last):
- ValueError: 'wrong_choice' is not a valid scoring value. Use sorted(sklearn.metrics.SCORERS.keys()) to get valid options.
+ ValueError: 'wrong_choice' is not a valid scoring value. Use
+ sklearn.metrics.get_scorer_names() to get valid options.

.. note::

- The values listed by the ``ValueError`` exception correspond to the functions measuring
- prediction accuracy described in the following sections.
- The scorer objects for those functions are stored in the dictionary
- ``sklearn.metrics.SCORERS``.
+ The values listed by the ``ValueError`` exception correspond to the
+ functions measuring prediction accuracy described in the following
+ sections. You can retrieve the names of all available scorers by calling
+ :func:`~sklearn.metrics.get_scorer_names`.

.. currentmodule:: sklearn.metrics

@@ -563,8 +564,8 @@ or *informedness*.
Machine Learning for Predictive Data Analytics: Algorithms, Worked Examples,
and Case Studies <https://mitpress.mit.edu/books/fundamentals-machine-learning-predictive-data-analytics>`_,
2015.
- .. [Urbanowicz2015] Urbanowicz R.J., Moore, J.H. :doi:`ExSTraCS 2.0: description
- and evaluation of a scalable learning classifier
+ .. [Urbanowicz2015] Urbanowicz R.J., Moore, J.H. :doi:`ExSTraCS 2.0: description
+ and evaluation of a scalable learning classifier
system <10.1007/s12065-015-0128-8>`, Evol. Intel. (2015) 8: 89.
.. _cohen_kappa:
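As the updated note suggests, a scoring string can be validated up front instead of waiting for the ``ValueError``. A quick sketch, assuming the standard "accuracy" scorer is registered as in stock scikit-learn:

>>> from sklearn.metrics import get_scorer_names
>>> "accuracy" in get_scorer_names()
True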
4 changes: 4 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -635,6 +635,10 @@ Changelog
- |Enhancement| :func:`metrics.top_k_accuracy_score` raises an improved error
message when `y_true` is binary and `y_score` is 2d. :pr:`22284` by `Thomas Fan`_.

+ - |API| `metrics.SCORERS` is now deprecated and will be removed in 1.3. Please
+   use :func:`~metrics.get_scorer_names` to retrieve the names of all available
+   scorers. :pr:`22866` by `Adrin Jalali`_.

- |API| :class:`metrics.DistanceMetric` has been moved from
:mod:`sklearn.neighbors` to :mod:`sklearn.metrics`.
Using `neighbors.DistanceMetric` for imports is still valid for
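A minimal migration sketch for the deprecation noted in the changelog entry above (the scorer name is chosen only for illustration):

>>> from sklearn.metrics import get_scorer, get_scorer_names
>>> scorer = get_scorer("accuracy")   # replaces SCORERS["accuracy"]
>>> names = get_scorer_names()        # replaces sorted(SCORERS.keys())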
3 changes: 3 additions & 0 deletions sklearn/metrics/__init__.py
@@ -83,6 +83,8 @@
from ._scorer import make_scorer
from ._scorer import SCORERS
from ._scorer import get_scorer
+ from ._scorer import get_scorer_names


from ._plot.det_curve import plot_det_curve
from ._plot.det_curve import DetCurveDisplay
@@ -170,6 +172,7 @@
"roc_auc_score",
"roc_curve",
"SCORERS",
"get_scorer_names",
"silhouette_samples",
"silhouette_score",
"top_k_accuracy_score",
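With the new export in place, the helper is importable from the `sklearn.metrics` package root. Since the implementation below returns `sorted(_SCORERS.keys())`, the names come back pre-sorted (a sketch, not diff content):

>>> from sklearn.metrics import get_scorer_names
>>> names = get_scorer_names()
>>> names == sorted(names)
True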
50 changes: 45 additions & 5 deletions sklearn/metrics/_scorer.py
@@ -23,6 +23,8 @@
from collections import Counter

import numpy as np
+ import copy
+ import warnings

from . import (
r2_score,
@@ -389,6 +391,8 @@ def get_scorer(scoring):
"""Get a scorer from string.
Read more in the :ref:`User Guide <scoring_parameter>`.
+ :func:`~sklearn.metrics.get_scorer_names` can be used to retrieve the names
+ of all available scorers.
Parameters
----------
@@ -399,14 +403,20 @@ def get_scorer(scoring):
-------
scorer : callable
The scorer.
+ Notes
+ -----
+ When passed a string, this function always returns a copy of the scorer
+ object. Calling `get_scorer` twice for the same scorer results in two
+ separate scorer objects.
"""
if isinstance(scoring, str):
try:
-             scorer = SCORERS[scoring]
+             scorer = copy.deepcopy(_SCORERS[scoring])
except KeyError:
raise ValueError(
"%r is not a valid scoring value. "
"Use sorted(sklearn.metrics.SCORERS.keys()) "
"Use sklearn.metrics.get_scorer_names() "
"to get valid options." % scoring
)
else:
@@ -747,7 +757,21 @@ def make_scorer(
fowlkes_mallows_scorer = make_scorer(fowlkes_mallows_score)


- SCORERS = dict(
+ # TODO(1.3) Remove
+ class _DeprecatedScorers(dict):
+     """A temporary class to deprecate SCORERS."""
+
+     def __getitem__(self, item):
+         warnings.warn(
+             "sklearn.metrics.SCORERS is deprecated and will be removed in v1.3. "
+             "Please use sklearn.metrics.get_scorer_names to get a list of available "
+             "scorers and sklearn.metrics.get_scorer to get a scorer.",
+             FutureWarning,
+         )
+         return super().__getitem__(item)
+
+
+ _SCORERS = dict(
explained_variance=explained_variance_scorer,
r2=r2_scorer,
max_error=max_error_scorer,
@@ -784,13 +808,29 @@ def make_scorer(
)


+ def get_scorer_names():
+     """Get the names of all available scorers.
+
+     These names can be passed to :func:`~sklearn.metrics.get_scorer` to
+     retrieve the scorer object.
+
+     Returns
+     -------
+     list of str
+         Names of all available scorers.
+     """
+     return sorted(_SCORERS.keys())


for name, metric in [
    ("precision", precision_score),
    ("recall", recall_score),
    ("f1", f1_score),
    ("jaccard", jaccard_score),
]:
-     SCORERS[name] = make_scorer(metric, average="binary")
+     _SCORERS[name] = make_scorer(metric, average="binary")
    for average in ["macro", "micro", "samples", "weighted"]:
        qualified_name = "{0}_{1}".format(name, average)
-         SCORERS[qualified_name] = make_scorer(metric, pos_label=None, average=average)
+         _SCORERS[qualified_name] = make_scorer(metric, pos_label=None, average=average)

+ SCORERS = _DeprecatedScorers(_SCORERS)
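A short sketch of the deprecation path wired up above: indexing the shimmed `SCORERS` mapping now emits a `FutureWarning`, while plain import stays silent:

>>> import warnings
>>> from sklearn.metrics import SCORERS
>>> with warnings.catch_warnings(record=True) as caught:
...     warnings.simplefilter("always")
...     _ = SCORERS["roc_auc"]  # __getitem__ triggers the warning
>>> caught[0].category is FutureWarning
True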
31 changes: 22 additions & 9 deletions sklearn/metrics/tests/test_score_objects.py
@@ -41,7 +41,7 @@
_MultimetricScorer,
_check_multimetric_scoring,
)
- from sklearn.metrics import make_scorer, get_scorer, SCORERS
+ from sklearn.metrics import make_scorer, get_scorer, SCORERS, get_scorer_names
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import LinearSVC
from sklearn.pipeline import make_pipeline
@@ -220,8 +220,8 @@ def __call__(self, est, X, y):

def test_all_scorers_repr():
# Test that all scorers have a working repr
-     for name, scorer in SCORERS.items():
-         repr(scorer)
+     for name in get_scorer_names():
+         repr(get_scorer(name))


def check_scoring_validator_for_single_metric_usecases(scoring_validator):
@@ -406,7 +406,7 @@ def test_classification_binary_scores(scorer_name, metric):
clf = LinearSVC(random_state=0)
clf.fit(X_train, y_train)

-     score = SCORERS[scorer_name](clf, X_test, y_test)
+     score = get_scorer(scorer_name)(clf, X_test, y_test)
expected_score = metric(y_test, clf.predict(X_test))
assert_almost_equal(score, expected_score)

@@ -444,7 +444,7 @@ def test_classification_multiclass_scores(scorer_name, metric):

clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)
-     score = SCORERS[scorer_name](clf, X_test, y_test)
+     score = get_scorer(scorer_name)(clf, X_test, y_test)
expected_score = metric(y_test, clf.predict(X_test))
assert score == pytest.approx(expected_score)

@@ -617,7 +617,8 @@ def test_classification_scorer_sample_weight():
# get sensible estimators for each metric
estimator = _make_estimators(X_train, y_train, y_ml_train)

-     for name, scorer in SCORERS.items():
+     for name in get_scorer_names():
+         scorer = get_scorer(name)
if name in REGRESSION_SCORERS:
# skip the regression scores
continue
@@ -672,7 +673,8 @@ def test_regression_scorer_sample_weight():
reg = DecisionTreeRegressor(random_state=0)
reg.fit(X_train, y_train)

-     for name, scorer in SCORERS.items():
+     for name in get_scorer_names():
+         scorer = get_scorer(name)
if name not in REGRESSION_SCORERS:
# skip classification scorers
continue
@@ -701,7 +703,7 @@ def test_regression_scorer_sample_weight():
)


- @pytest.mark.parametrize("name", SCORERS)
+ @pytest.mark.parametrize("name", get_scorer_names())
def test_scorer_memmap_input(name):
# Non-regression test for #6147: some score functions would
# return singleton memmap when computed on memmap data instead of scalar
@@ -715,7 +717,7 @@ def test_scorer_memmap_input(name):

# UndefinedMetricWarning for P / R scores
with ignore_warnings():
-         scorer, estimator = SCORERS[name], ESTIMATORS[name]
+         scorer, estimator = get_scorer(name), ESTIMATORS[name]
if name in MULTILABEL_ONLY_SCORERS:
score = scorer(estimator, X_mm, y_ml_mm_1)
else:
@@ -1120,6 +1122,17 @@ def test_scorer_select_proba_error(scorer):
scorer(lr, X, y)


+ def test_get_scorer_return_copy():
+     # test that get_scorer returns a copy
+     assert get_scorer("roc_auc") is not get_scorer("roc_auc")
+
+
+ # TODO(1.3) Remove
+ def test_SCORERS_deprecated():
+     with pytest.warns(FutureWarning, match="is deprecated and will be removed in v1.3"):
+         SCORERS["roc_auc"]
+

def test_scorer_no_op_multiclass_select_proba():
# check that calling a ProbaScorer on a multiclass problem do not raise
# even if `y_true` would be binary during the scoring.
