ENH move estimator type to tags (#30122)
adrinjalali authored Nov 5, 2024
1 parent e887583 commit 613cff9
Showing 37 changed files with 489 additions and 249 deletions.
1 change: 0 additions & 1 deletion doc/api_reference.py
@@ -1176,7 +1176,6 @@ def _get_submodule(module_name, submodule_name):
"ClassifierTags",
"RegressorTags",
"TransformerTags",
"default_tags",
"get_tags",
],
},
33 changes: 13 additions & 20 deletions doc/developers/develop.rst
@@ -449,26 +449,19 @@ accepts an optional ``y``.

Estimator types
---------------
Some common functionality depends on the kind of estimator passed.
For example, cross-validation in :class:`model_selection.GridSearchCV` and
:func:`model_selection.cross_val_score` defaults to being stratified when used
on a classifier, but not otherwise. Similarly, scorers for average precision
that take a continuous prediction need to call ``decision_function`` for classifiers,
but ``predict`` for regressors. This distinction between classifiers and regressors
is implemented using the ``_estimator_type`` attribute, which takes a string value.
This attribute should have the following values to work as expected:

- ``"classifier"`` for classifiers
- ``"regressor"`` for regressors
- ``"clusterer"`` for clustering methods
- ``"outlier_detector"`` for outlier detectors
- ``"DensityEstimator"`` for density estimators

Inheriting from :class:`~base.ClassifierMixin`, :class:`~base.RegressorMixin`, :class:`~base.ClusterMixin`,
:class:`~base.OutlierMixin` or :class:`~base.DensityMixin`,
will set the attribute automatically. When a meta-estimator needs to distinguish
among estimator types, instead of checking ``_estimator_type`` directly, helpers
like :func:`base.is_classifier` should be used.
Some common functionality depends on the kind of estimator passed. For example,
cross-validation in :class:`model_selection.GridSearchCV` and
:func:`model_selection.cross_val_score` defaults to being stratified when used on a
classifier, but not otherwise. Similarly, scorers for average precision that take a
continuous prediction need to call ``decision_function`` for classifiers, but
``predict`` for regressors. This distinction between classifiers and regressors is
implemented by inheriting from :class:`~base.ClassifierMixin`,
:class:`~base.RegressorMixin`, :class:`~base.ClusterMixin`, :class:`~base.OutlierMixin`
or :class:`~base.DensityMixin`, which will set the corresponding :term:`estimator tags`
correctly.

When a meta-estimator needs to distinguish among estimator types, instead of checking
the value of the tags directly, helpers like :func:`base.is_classifier` should be used.

Specific models
---------------
17 changes: 4 additions & 13 deletions doc/glossary.rst
@@ -416,15 +416,6 @@ General Concepts
the :term:`duck typing` of methods like ``predict_proba`` and through
some special attributes on estimator objects:

.. glossary::

``_estimator_type``
This string-valued attribute identifies an estimator as being a
classifier, regressor, etc. It is set by mixins such as
:class:`base.ClassifierMixin`, but needs to be more explicitly
adopted on a :term:`meta-estimator`. Its value should usually be
checked by way of a helper such as :func:`base.is_classifier`.

For more detailed info, see :ref:`estimator_tags`.

feature
@@ -859,8 +850,8 @@ Class APIs and Estimator Types
strategy over the binary classification problem.

Classifiers must store a :term:`classes_` attribute after fitting,
and usually inherit from :class:`base.ClassifierMixin`, which sets
their :term:`_estimator_type` attribute.
and inherit from :class:`base.ClassifierMixin`, which sets
their corresponding :term:`estimator tags` correctly.

A classifier can be distinguished from other estimators with
:func:`~base.is_classifier`.
@@ -1003,8 +994,8 @@ Class APIs and Estimator Types
A :term:`supervised` (or :term:`semi-supervised`) :term:`predictor`
with :term:`continuous` output values.

Regressors usually inherit from :class:`base.RegressorMixin`, which
sets their :term:`_estimator_type` attribute.
Regressors inherit from :class:`base.RegressorMixin`, which sets their
:term:`estimator tags` correctly.

A regressor can be distinguished from other estimators with
:func:`~base.is_regressor`.
24 changes: 12 additions & 12 deletions doc/sphinxext/allow_nan_estimators.py
@@ -24,18 +24,18 @@ def make_paragraph_for_estimator_type(estimator_type):
# sub-estimator.
est = next(_construct_instances(est_class))

if est.__sklearn_tags__().input_tags.allow_nan:
module_name = ".".join(est_class.__module__.split(".")[:2])
class_title = f"{est_class.__name__}"
class_url = f"./generated/{module_name}.{class_title}.html"
item = nodes.list_item()
para = nodes.paragraph()
para += nodes.reference(
class_title, text=class_title, internal=False, refuri=class_url
)
exists = True
item += para
lst += item
if est.__sklearn_tags__().input_tags.allow_nan:
module_name = ".".join(est_class.__module__.split(".")[:2])
class_title = f"{est_class.__name__}"
class_url = f"./generated/{module_name}.{class_title}.html"
item = nodes.list_item()
para = nodes.paragraph()
para += nodes.reference(
class_title, text=class_title, internal=False, refuri=class_url
)
exists = True
item += para
lst += item
intro += lst
return [intro] if exists else None

5 changes: 5 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.base/30122.api.rst
@@ -0,0 +1,5 @@
- Passing a class object to :func:`~sklearn.base.is_classifier`,
:func:`~sklearn.base.is_regressor`, :func:`~sklearn.base.is_transformer`, and
:func:`~sklearn.base.is_outlier_detector` is now deprecated. Pass an instance
instead.
By `Adrin Jalali`_
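
As a rough illustration of the migration this entry asks for, the sketch below uses a hypothetical stand-in helper (`is_classifier_sketch`) and a toy class instead of scikit-learn itself, so it runs with only the standard library. The attribute name, the warning category, and the message are assumptions made for the example, not the library's actual behaviour.

```python
import inspect
import warnings


class ToyClassifier:
    """Hypothetical estimator; the attribute below is illustrative only."""

    estimator_type = "classifier"


def is_classifier_sketch(estimator):
    """Stand-in type check that accepts classes but warns about them."""
    if inspect.isclass(estimator):
        # Assumed warning category and wording, chosen for the sketch.
        warnings.warn(
            "Passing a class is deprecated; pass an instance instead.",
            FutureWarning,
        )
    return getattr(estimator, "estimator_type", None) == "classifier"


# Preferred call style: pass an instance.
result_instance = is_classifier_sketch(ToyClassifier())

# Deprecated call style: passing the class still answers, but warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result_class = is_classifier_sketch(ToyClassifier)
```

In calling code, the corresponding change is usually just to construct the estimator before handing it to the helper.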
6 changes: 6 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.utils/30122.api.rst
@@ -0,0 +1,6 @@
- Using `_estimator_type` to set the estimator type is deprecated. Inherit from
:class:`~sklearn.base.ClassifierMixin`, :class:`~sklearn.base.RegressorMixin`,
:class:`~sklearn.base.TransformerMixin`, or :class:`~sklearn.base.OutlierMixin`
instead. Alternatively, you can set `estimator_type` in :class:`~sklearn.utils.Tags`
in the `__sklearn_tags__` method.
By `Adrin Jalali`_
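
The two options named in this entry (inherit from a mixin, or set `estimator_type` in `__sklearn_tags__`) can be pictured with a much-reduced, standard-library-only model. Everything here (`ToyTags`, `ToyBase`, `ToyClassifierMixin`, `estimator_type_of`) is hypothetical and stands in for the real classes; it is a sketch of the cooperative-override pattern, not scikit-learn's implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTags:
    """Hypothetical, much-reduced stand-in for a tags object."""

    estimator_type: Optional[str] = None


class ToyBase:
    """Stand-in base class: produces default tags with no estimator type."""

    def __sklearn_tags__(self):
        return ToyTags()


class ToyClassifierMixin:
    """Stand-in mixin: fills in the estimator type on top of the defaults."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "classifier"
        return tags


class ViaMixin(ToyClassifierMixin, ToyBase):
    """Option 1: the type comes from inheriting the mixin."""


class ViaTags(ToyBase):
    """Option 2: the type is set directly in the tags method."""

    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.estimator_type = "regressor"
        return tags


def estimator_type_of(estimator):
    """Read the type from the tags instead of a private attribute."""
    return estimator.__sklearn_tags__().estimator_type
```

In this sketch the mixin is listed before the base class so that its `__sklearn_tags__` runs first and can extend the defaults returned by `super()`.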