Skip to content

Commit

Permalink
MAINT conversion old->new/new->old tags (bis) (#30327)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrin Jalali <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Loïc Estève <[email protected]>
  • Loading branch information
4 people authored Nov 23, 2024
1 parent 27a903b commit 46a7c9a
Show file tree
Hide file tree
Showing 3 changed files with 840 additions and 12 deletions.
27 changes: 27 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,33 @@ def __setstate__(self, state):
except AttributeError:
self.__dict__.update(state)

# TODO(1.7): Remove this method
def _more_tags(self):
"""This code should never be reached since our `get_tags` will fallback on
`__sklearn_tags__` implemented below. We keep it for backward compatibility.
It is tested in `test_base_estimator_more_tags` in
`sklearn/utils/testing/test_tags.py`."""
from sklearn.utils._tags import _to_old_tags, default_tags

warnings.warn(
"The `_more_tags` method is deprecated in 1.6 and will be removed in "
"1.7. Please implement the `__sklearn_tags__` method.",
category=FutureWarning,
)
return _to_old_tags(default_tags(self))

# TODO(1.7): Remove this method
def _get_tags(self):
from sklearn.utils._tags import _to_old_tags, get_tags

warnings.warn(
"The `_get_tags` method is deprecated in 1.6 and will be removed in "
"1.7. Please implement the `__sklearn_tags__` method.",
category=FutureWarning,
)

return _to_old_tags(get_tags(self))

def __sklearn_tags__(self):
return Tags(
estimator_type=None,
Expand Down
271 changes: 260 additions & 11 deletions sklearn/utils/_tags.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from itertools import chain

from .fixes import _dataclass_args

Expand Down Expand Up @@ -290,6 +292,71 @@ def default_tags(estimator) -> Tags:
)


# TODO(1.7): Remove this function
def _find_tags_provider(estimator, warn=True):
"""Find the tags provider for an estimator.
Parameters
----------
estimator : estimator object
The estimator to find the tags provider for.
warn : bool, default=True
Whether to warn if the tags provider is not found.
Returns
-------
tag_provider : str
The tags provider for the estimator. Can be one of:
- "_get_tags": to use the old tags infrastructure
- "__sklearn_tags__": to use the new tags infrastructure
"""
mro_model = type(estimator).mro()
tags_mro = OrderedDict()
for klass in mro_model:
tags_provider = []
if "_more_tags" in vars(klass):
tags_provider.append("_more_tags")
if "_get_tags" in vars(klass):
tags_provider.append("_get_tags")
if "__sklearn_tags__" in vars(klass):
tags_provider.append("__sklearn_tags__")
tags_mro[klass.__name__] = tags_provider

all_providers = set(chain.from_iterable(tags_mro.values()))
if "__sklearn_tags__" not in all_providers:
# default on the old tags infrastructure
return "_get_tags"

tag_provider = "__sklearn_tags__"
for klass in tags_mro:
has_get_or_more_tags = any(
provider in tags_mro[klass] for provider in ("_get_tags", "_more_tags")
)
has_sklearn_tags = "__sklearn_tags__" in tags_mro[klass]

if tags_mro[klass] and tag_provider == "__sklearn_tags__": # is it empty
if has_get_or_more_tags and not has_sklearn_tags:
# Case where a class does not implement __sklearn_tags__ and we fallback
# to _get_tags. We should therefore warn for implementing
# __sklearn_tags__.
tag_provider = "_get_tags"
break

if warn and tag_provider == "_get_tags":
warnings.warn(
f"The {estimator.__class__.__name__} or classes from which it inherits "
"use `_get_tags` and `_more_tags`. Please define the "
"`__sklearn_tags__` method, or inherit from `sklearn.base.BaseEstimator` "
"and/or other appropriate mixins such as `sklearn.base.TransformerMixin`, "
"`sklearn.base.ClassifierMixin`, `sklearn.base.RegressorMixin`, and "
"`sklearn.base.OutlierMixin`. From scikit-learn 1.7, not defining "
"`__sklearn_tags__` will raise an error.",
category=FutureWarning,
)
return tag_provider


def get_tags(estimator) -> Tags:
"""Get estimator tags.
Expand All @@ -316,19 +383,201 @@ def get_tags(estimator) -> Tags:
The estimator tags.
"""

if hasattr(estimator, "__sklearn_tags__"):
tag_provider = _find_tags_provider(estimator)

if tag_provider == "__sklearn_tags__":
tags = estimator.__sklearn_tags__()
else:
warnings.warn(
f"Estimator {estimator} has no __sklearn_tags__ attribute, which is "
"defined in `sklearn.base.BaseEstimator`. This will raise an error in "
"scikit-learn 1.8. Please define the __sklearn_tags__ method, or inherit "
"from `sklearn.base.BaseEstimator` and other appropriate mixins such as "
"`sklearn.base.TransformerMixin`, `sklearn.base.ClassifierMixin`, "
"`sklearn.base.RegressorMixin`, and `sklearn.base.ClusterMixin`, and "
"`sklearn.base.OutlierMixin`.",
category=FutureWarning,
# TODO(1.7): Remove this branch of the code
# Let's go through the MRO and patch each class implementing _more_tags
sklearn_tags_provider = {}
more_tags_provider = {}
class_order = []
for klass in reversed(type(estimator).mro()):
if "__sklearn_tags__" in vars(klass):
sklearn_tags_provider[klass] = klass.__sklearn_tags__(estimator) # type: ignore[attr-defined]
class_order.append(klass)
elif "_more_tags" in vars(klass):
more_tags_provider[klass] = klass._more_tags(estimator) # type: ignore[attr-defined]
class_order.append(klass)

# Find differences between consecutive in the case of __sklearn_tags__
# inheritance
sklearn_tags_diff = {}
items = list(sklearn_tags_provider.items())
for current_item, next_item in zip(items[:-1], items[1:]):
current_name, current_tags = current_item
next_name, next_tags = next_item
current_tags = _to_old_tags(current_tags)
next_tags = _to_old_tags(next_tags)

# Compare tags and store differences
diff = {}
for key in current_tags:
if current_tags[key] != next_tags[key]:
diff[key] = next_tags[key]

sklearn_tags_diff[next_name] = diff

tags = {}
for klass in class_order:
if klass in sklearn_tags_diff:
tags.update(sklearn_tags_diff[klass])
elif klass in more_tags_provider:
tags.update(more_tags_provider[klass])

tags = _to_new_tags(
{**_to_old_tags(default_tags(estimator)), **tags}, estimator
)
tags = default_tags(estimator)

return tags


# TODO(1.7): Remove this function
def _safe_tags(estimator, key=None):
warnings.warn(
"The `_safe_tags` function is deprecated in 1.6 and will be removed in "
"1.7. Use the public `get_tags` function instead and make sure to implement "
"the `__sklearn_tags__` method.",
category=FutureWarning,
)
tags = _to_old_tags(get_tags(estimator))

if key is not None:
if key not in tags:
raise ValueError(
f"The key {key} is not defined for the class "
f"{estimator.__class__.__name__}."
)
return tags[key]
return tags


# TODO(1.7): Remove this function
def _to_new_tags(old_tags, estimator=None):
"""Utility function convert old tags (dictionary) to new tags (dataclass)."""
input_tags = InputTags(
one_d_array="1darray" in old_tags["X_types"],
two_d_array="2darray" in old_tags["X_types"],
three_d_array="3darray" in old_tags["X_types"],
sparse="sparse" in old_tags["X_types"],
categorical="categorical" in old_tags["X_types"],
string="string" in old_tags["X_types"],
dict="dict" in old_tags["X_types"],
positive_only=old_tags["requires_positive_X"],
allow_nan=old_tags["allow_nan"],
pairwise=old_tags["pairwise"],
)
target_tags = TargetTags(
required=old_tags["requires_y"],
one_d_labels="1dlabels" in old_tags["X_types"],
two_d_labels="2dlabels" in old_tags["X_types"],
positive_only=old_tags["requires_positive_y"],
multi_output=old_tags["multioutput"] or old_tags["multioutput_only"],
single_output=not old_tags["multioutput_only"],
)
if estimator is not None and (
hasattr(estimator, "transform") or hasattr(estimator, "fit_transform")
):
transformer_tags = TransformerTags(
preserves_dtype=old_tags["preserves_dtype"],
)
else:
transformer_tags = None
estimator_type = getattr(estimator, "_estimator_type", None)
if estimator_type == "classifier":
classifier_tags = ClassifierTags(
poor_score=old_tags["poor_score"],
multi_class=not old_tags["binary_only"],
multi_label=old_tags["multilabel"],
)
else:
classifier_tags = None
if estimator_type == "regressor":
regressor_tags = RegressorTags(
poor_score=old_tags["poor_score"],
multi_label=old_tags["multilabel"],
)
else:
regressor_tags = None
return Tags(
estimator_type=estimator_type,
target_tags=target_tags,
transformer_tags=transformer_tags,
classifier_tags=classifier_tags,
regressor_tags=regressor_tags,
input_tags=input_tags,
array_api_support=old_tags["array_api_support"],
no_validation=old_tags["no_validation"],
non_deterministic=old_tags["non_deterministic"],
requires_fit=old_tags["requires_fit"],
_skip_test=old_tags["_skip_test"],
)


# TODO(1.7): Remove this function
def _to_old_tags(new_tags):
"""Utility function convert old tags (dictionary) to new tags (dataclass)."""
if new_tags.classifier_tags:
binary_only = not new_tags.classifier_tags.multi_class
multilabel_clf = new_tags.classifier_tags.multi_label
poor_score_clf = new_tags.classifier_tags.poor_score
else:
binary_only = False
multilabel_clf = False
poor_score_clf = False

if new_tags.regressor_tags:
multilabel_reg = new_tags.regressor_tags.multi_label
poor_score_reg = new_tags.regressor_tags.poor_score
else:
multilabel_reg = False
poor_score_reg = False

if new_tags.transformer_tags:
preserves_dtype = new_tags.transformer_tags.preserves_dtype
else:
preserves_dtype = ["float64"]

tags = {
"allow_nan": new_tags.input_tags.allow_nan,
"array_api_support": new_tags.array_api_support,
"binary_only": binary_only,
"multilabel": multilabel_clf or multilabel_reg,
"multioutput": new_tags.target_tags.multi_output,
"multioutput_only": (
not new_tags.target_tags.single_output and new_tags.target_tags.multi_output
),
"no_validation": new_tags.no_validation,
"non_deterministic": new_tags.non_deterministic,
"pairwise": new_tags.input_tags.pairwise,
"preserves_dtype": preserves_dtype,
"poor_score": poor_score_clf or poor_score_reg,
"requires_fit": new_tags.requires_fit,
"requires_positive_X": new_tags.input_tags.positive_only,
"requires_y": new_tags.target_tags.required,
"requires_positive_y": new_tags.target_tags.positive_only,
"_skip_test": new_tags._skip_test,
"stateless": new_tags.requires_fit,
}
X_types = []
if new_tags.input_tags.one_d_array:
X_types.append("1darray")
if new_tags.input_tags.two_d_array:
X_types.append("2darray")
if new_tags.input_tags.three_d_array:
X_types.append("3darray")
if new_tags.input_tags.sparse:
X_types.append("sparse")
if new_tags.input_tags.categorical:
X_types.append("categorical")
if new_tags.input_tags.string:
X_types.append("string")
if new_tags.input_tags.dict:
X_types.append("dict")
if new_tags.target_tags.one_d_labels:
X_types.append("1dlabels")
if new_tags.target_tags.two_d_labels:
X_types.append("2dlabels")
tags["X_types"] = X_types
return tags
Loading

0 comments on commit 46a7c9a

Please sign in to comment.