Skip to content

Commit

Permalink
MAINT move _estimator_has function to utils (#29319)
Browse files Browse the repository at this point in the history
Co-authored-by: Stefanie Senger <[email protected]>
Co-authored-by: adrinjalali <[email protected]>
  • Loading branch information
3 people authored Nov 5, 2024
1 parent 102663d commit 8f620fd
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 125 deletions.
21 changes: 4 additions & 17 deletions sklearn/ensemble/_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_check_method_params,
_check_sample_weight,
_deprecate_positional_args,
_estimator_has,
check_is_fitted,
has_fit_parameter,
validate_data,
Expand Down Expand Up @@ -269,22 +270,6 @@ def _parallel_predict_regression(estimators, estimators_features, X):
)


def _estimator_has(attr, *, delegates=("estimators_", "estimator")):
    """Check if we can delegate a method to the underlying estimator.

    The names in `delegates` are tried in order: every name except the last
    one is only inspected when it exists on the meta-estimator, while the
    last name is the fallback and is read unconditionally. With the default
    value this first checks the first fitted estimator in `estimators_` if
    available, otherwise the `estimator` attribute.

    Parameters
    ----------
    attr : str
        Name of the attribute (typically a method) that must exist on the
        delegate.

    delegates : tuple of str, default=("estimators_", "estimator")
        Attribute names of the meta-estimator to look up, most preferred
        first. When a delegate is a list or tuple, only its first element is
        inspected.

    Returns
    -------
    check : callable
        Function taking the meta-estimator and returning ``True`` if the
        selected delegate has `attr`, ``False`` otherwise. It is meant to be
        passed to `available_if`.

    Raises
    ------
    ValueError
        If `delegates` is empty.
    """
    if not delegates:
        raise ValueError("`delegates` must contain at least one attribute name.")
    *preferred, fallback = delegates

    def _representative(delegate):
        # A collection of sub-estimators is represented by its first member.
        if isinstance(delegate, (list, tuple)):
            return delegate[0]
        return delegate

    def check(self):
        for name in preferred:
            if hasattr(self, name):
                return hasattr(_representative(getattr(self, name)), attr)
        # No preferred delegate is set yet: fall back to the last name.
        return hasattr(_representative(getattr(self, fallback)), attr)

    return check


class BaseBagging(BaseEnsemble, metaclass=ABCMeta):
"""Base class for Bagging meta-estimator.
Expand Down Expand Up @@ -1033,7 +1018,9 @@ def predict_log_proba(self, X):

return log_proba

@available_if(_estimator_has("decision_function"))
@available_if(
_estimator_has("decision_function", delegates=("estimators_", "estimator"))
)
def decision_function(self, X):
"""Average of the decision functions of the base classifiers.
Expand Down
44 changes: 20 additions & 24 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,13 @@
_check_feature_names_in,
_check_response_method,
_deprecate_positional_args,
_estimator_has,
check_is_fitted,
column_or_1d,
)
from ._base import _BaseHeterogeneousEnsemble, _fit_single_estimator


def _estimator_has(attr, *, delegates=("final_estimator_", "final_estimator")):
    """Check if we can delegate a method to the underlying estimator.

    The names in `delegates` are tried in order: every name except the last
    one is only inspected when it exists on the meta-estimator, while the
    last name is the fallback and is read unconditionally. With the default
    value this first checks the fitted `final_estimator_` if available,
    otherwise the unfitted `final_estimator`. The original `AttributeError`
    is raised if `attr` does not exist. This function is used together with
    `available_if`.

    Parameters
    ----------
    attr : str
        Name of the attribute (typically a method) that must exist on the
        delegate.

    delegates : tuple of str, default=("final_estimator_", "final_estimator")
        Attribute names of the meta-estimator to look up, most preferred
        first.

    Returns
    -------
    check : callable
        Function taking the meta-estimator and returning ``True`` when the
        selected delegate exposes `attr`.

    Raises
    ------
    ValueError
        If `delegates` is empty.
    """
    if not delegates:
        raise ValueError("`delegates` must contain at least one attribute name.")
    *preferred, fallback = delegates

    def check(self):
        for name in preferred:
            if hasattr(self, name):
                # Deliberately not guarded: the AttributeError raised by the
                # delegate itself must reach the caller unchanged.
                getattr(getattr(self, name), attr)
                return True
        getattr(getattr(self, fallback), attr)
        return True

    return check


class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, metaclass=ABCMeta):
"""Base class for stacking method."""

Expand Down Expand Up @@ -364,7 +346,9 @@ def get_feature_names_out(self, input_features=None):

return np.asarray(meta_names, dtype=object)

@available_if(_estimator_has("predict"))
@available_if(
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
)
def predict(self, X, **predict_params):
"""Predict target for X.
Expand Down Expand Up @@ -732,7 +716,9 @@ def fit(self, X, y, *, sample_weight=None, **fit_params):
fit_params["sample_weight"] = sample_weight
return super().fit(X, y_encoded, **fit_params)

@available_if(_estimator_has("predict"))
@available_if(
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
)
def predict(self, X, **predict_params):
"""Predict target for X.
Expand Down Expand Up @@ -785,7 +771,11 @@ def predict(self, X, **predict_params):
y_pred = self._label_encoder.inverse_transform(y_pred)
return y_pred

@available_if(_estimator_has("predict_proba"))
@available_if(
_estimator_has(
"predict_proba", delegates=("final_estimator_", "final_estimator")
)
)
def predict_proba(self, X):
"""Predict class probabilities for `X` using the final estimator.
Expand All @@ -809,7 +799,11 @@ def predict_proba(self, X):
y_pred = np.array([preds[:, 0] for preds in y_pred]).T
return y_pred

@available_if(_estimator_has("decision_function"))
@available_if(
_estimator_has(
"decision_function", delegates=("final_estimator_", "final_estimator")
)
)
def decision_function(self, X):
"""Decision function for samples in `X` using the final estimator.
Expand Down Expand Up @@ -1125,7 +1119,9 @@ def fit_transform(self, X, y, *, sample_weight=None, **fit_params):
fit_params["sample_weight"] = sample_weight
return super().fit_transform(X, y, **fit_params)

@available_if(_estimator_has("predict"))
@available_if(
_estimator_has("predict", delegates=("final_estimator_", "final_estimator"))
)
def predict(self, X, **predict_params):
"""Predict target for X.
Expand Down
20 changes: 1 addition & 19 deletions sklearn/feature_selection/_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..utils.metaestimators import available_if
from ..utils.validation import (
_check_feature_names,
_estimator_has,
_num_features,
check_is_fitted,
check_scalar,
Expand Down Expand Up @@ -76,25 +77,6 @@ def _calculate_threshold(estimator, importances, threshold):
return threshold


def _estimator_has(attr, *, delegates=("estimator_", "estimator")):
    """Check if we can delegate a method to the underlying estimator.

    The names in `delegates` are tried in order: every name except the last
    one is only inspected when it exists on the meta-estimator, while the
    last name is the fallback and is read unconditionally. With the default
    value this first checks the fitted `estimator_` if available, otherwise
    the unfitted `estimator`. The original `AttributeError` is raised if
    `attr` does not exist. This function is used together with
    `available_if`.

    Parameters
    ----------
    attr : str
        Name of the attribute (typically a method) that must exist on the
        delegate.

    delegates : tuple of str, default=("estimator_", "estimator")
        Attribute names of the meta-estimator to look up, most preferred
        first.

    Returns
    -------
    check : callable
        Function taking the meta-estimator and returning ``True`` when the
        selected delegate exposes `attr`.

    Raises
    ------
    ValueError
        If `delegates` is empty.
    """
    if not delegates:
        raise ValueError("`delegates` must contain at least one attribute name.")
    *preferred, fallback = delegates

    def check(self):
        for name in preferred:
            if hasattr(self, name):
                # Not wrapped in try/except on purpose: the delegate's own
                # AttributeError is the error the caller should see.
                getattr(getattr(self, name), attr)
                return True
        getattr(getattr(self, fallback), attr)
        return True

    return check


class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
"""Meta-transformer for selecting features based on importance weights.
Expand Down
20 changes: 1 addition & 19 deletions sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..utils.validation import (
_check_method_params,
_deprecate_positional_args,
_estimator_has,
check_is_fitted,
validate_data,
)
Expand Down Expand Up @@ -64,25 +65,6 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer, routed_params):
return rfe.step_scores_, rfe.step_n_features_


def _estimator_has(attr, *, delegates=("estimator_", "estimator")):
    """Check if we can delegate a method to the underlying estimator.

    The names in `delegates` are tried in order: every name except the last
    one is only inspected when it exists on the meta-estimator, while the
    last name is the fallback and is read unconditionally. With the default
    value this first checks the fitted `estimator_` if available, otherwise
    the unfitted `estimator`. The original `AttributeError` is raised if
    `attr` does not exist. This function is used together with
    `available_if`.

    Parameters
    ----------
    attr : str
        Name of the attribute (typically a method) that must exist on the
        delegate.

    delegates : tuple of str, default=("estimator_", "estimator")
        Attribute names of the meta-estimator to look up, most preferred
        first.

    Returns
    -------
    check : callable
        Function taking the meta-estimator and returning ``True`` when the
        selected delegate exposes `attr`.

    Raises
    ------
    ValueError
        If `delegates` is empty.
    """
    if not delegates:
        raise ValueError("`delegates` must contain at least one attribute name.")
    *preferred, fallback = delegates

    def check(self):
        chosen = fallback
        for name in preferred:
            if hasattr(self, name):
                chosen = name
                break
        # A missing `attr` surfaces here as the delegate's own AttributeError.
        getattr(getattr(self, chosen), attr)
        return True

    return check


class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
"""Feature ranking with recursive feature elimination.
Expand Down
18 changes: 1 addition & 17 deletions sklearn/model_selection/_classification_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
_check_method_params,
_estimator_has,
_num_samples,
check_is_fitted,
indexable,
Expand All @@ -50,23 +51,6 @@ def _check_is_fitted(estimator):
check_is_fitted(estimator, "estimator_")


def _estimator_has(attr, *, delegates=("estimator_", "estimator")):
    """Check if we can delegate a method to the underlying estimator.

    The names in `delegates` are tried in order: every name except the last
    one is only inspected when it exists on the meta-estimator, while the
    last name is the fallback and is read unconditionally. With the default
    value this first checks the fitted estimator if available, otherwise the
    unfitted estimator. If the selected delegate lacks `attr`, its
    `AttributeError` propagates to the caller.

    Parameters
    ----------
    attr : str
        Name of the attribute (typically a method) that must exist on the
        delegate.

    delegates : tuple of str, default=("estimator_", "estimator")
        Attribute names of the meta-estimator to look up, most preferred
        first.

    Returns
    -------
    check : callable
        Function taking the meta-estimator and returning ``True`` when the
        selected delegate exposes `attr`.

    Raises
    ------
    ValueError
        If `delegates` is empty.
    """
    if not delegates:
        raise ValueError("`delegates` must contain at least one attribute name.")
    *preferred, fallback = delegates

    def check(self):
        present = [name for name in preferred if hasattr(self, name)]
        target = present[0] if present else fallback
        # Plain attribute access: a missing `attr` raises AttributeError here.
        getattr(getattr(self, target), attr)
        return True

    return check


class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
"""Base class for binary classifiers that set a non-default decision threshold.
Expand Down
18 changes: 9 additions & 9 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _check_refit(search_cv, attr):
)


def _estimator_has(attr):
def _search_estimator_has(attr):
"""Check if we can delegate a method to the underlying estimator.
Calling a prediction method will only be available if `refit=True`. In
Expand Down Expand Up @@ -555,7 +555,7 @@ def score(self, X, y=None, **params):
score = score[self.refit]
return score

@available_if(_estimator_has("score_samples"))
@available_if(_search_estimator_has("score_samples"))
def score_samples(self, X):
"""Call score_samples on the estimator with the best found parameters.
Expand All @@ -578,7 +578,7 @@ def score_samples(self, X):
check_is_fitted(self)
return self.best_estimator_.score_samples(X)

@available_if(_estimator_has("predict"))
@available_if(_search_estimator_has("predict"))
def predict(self, X):
"""Call predict on the estimator with the best found parameters.
Expand All @@ -600,7 +600,7 @@ def predict(self, X):
check_is_fitted(self)
return self.best_estimator_.predict(X)

@available_if(_estimator_has("predict_proba"))
@available_if(_search_estimator_has("predict_proba"))
def predict_proba(self, X):
"""Call predict_proba on the estimator with the best found parameters.
Expand All @@ -623,7 +623,7 @@ def predict_proba(self, X):
check_is_fitted(self)
return self.best_estimator_.predict_proba(X)

@available_if(_estimator_has("predict_log_proba"))
@available_if(_search_estimator_has("predict_log_proba"))
def predict_log_proba(self, X):
"""Call predict_log_proba on the estimator with the best found parameters.
Expand All @@ -646,7 +646,7 @@ def predict_log_proba(self, X):
check_is_fitted(self)
return self.best_estimator_.predict_log_proba(X)

@available_if(_estimator_has("decision_function"))
@available_if(_search_estimator_has("decision_function"))
def decision_function(self, X):
"""Call decision_function on the estimator with the best found parameters.
Expand All @@ -669,7 +669,7 @@ def decision_function(self, X):
check_is_fitted(self)
return self.best_estimator_.decision_function(X)

@available_if(_estimator_has("transform"))
@available_if(_search_estimator_has("transform"))
def transform(self, X):
"""Call transform on the estimator with the best found parameters.
Expand All @@ -691,7 +691,7 @@ def transform(self, X):
check_is_fitted(self)
return self.best_estimator_.transform(X)

@available_if(_estimator_has("inverse_transform"))
@available_if(_search_estimator_has("inverse_transform"))
def inverse_transform(self, X=None, Xt=None):
"""Call inverse_transform on the estimator with the best found params.
Expand Down Expand Up @@ -746,7 +746,7 @@ def classes_(self):
Only available when `refit=True` and the estimator is a classifier.
"""
_estimator_has("classes_")(self)
_search_estimator_has("classes_")(self)
return self.best_estimator_.classes_

def _run_search(self, evaluate_candidates):
Expand Down
21 changes: 1 addition & 20 deletions sklearn/semi_supervised/_self_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,14 @@
process_routing,
)
from ..utils.metaestimators import available_if
from ..utils.validation import check_is_fitted, validate_data
from ..utils.validation import _estimator_has, check_is_fitted, validate_data

__all__ = ["SelfTrainingClassifier"]

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause


def _estimator_has(attr, *, delegates=("estimator_", "estimator")):
    """Check if we can delegate a method to the underlying estimator.

    The names in `delegates` are tried in order: every name except the last
    one is only inspected when it exists on the meta-estimator, while the
    last name is the fallback and is read unconditionally. With the default
    value this first checks the fitted `estimator_` if available, otherwise
    the unfitted `estimator`. The original `AttributeError` is raised if
    `attr` does not exist. This function is used together with
    `available_if`.

    Parameters
    ----------
    attr : str
        Name of the attribute (typically a method) that must exist on the
        delegate.

    delegates : tuple of str, default=("estimator_", "estimator")
        Attribute names of the meta-estimator to look up, most preferred
        first.

    Returns
    -------
    check : callable
        Function taking the meta-estimator and returning ``True`` when the
        selected delegate exposes `attr`.

    Raises
    ------
    ValueError
        If `delegates` is empty.
    """
    if not delegates:
        raise ValueError("`delegates` must contain at least one attribute name.")
    *preferred, fallback = delegates

    def check(self):
        for name in preferred:
            if not hasattr(self, name):
                continue
            # The delegate's AttributeError is intentionally left uncaught.
            getattr(getattr(self, name), attr)
            return True
        getattr(getattr(self, fallback), attr)
        return True

    return check


class SelfTrainingClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
"""Self-training classifier.
Expand Down
Loading

0 comments on commit 8f620fd

Please sign in to comment.