FEAT allow metadata to be transformed in a Pipeline (#28901)
Co-authored-by: Jiaming Yuan <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
3 people authored Nov 15, 2024
1 parent 6f12d3f commit 56a4adb
Showing 5 changed files with 364 additions and 11 deletions.
@@ -0,0 +1,3 @@
- :class:`pipeline.Pipeline` can now transform metadata up to the step requiring
  it; the metadata to be transformed is set via the `transform_input` parameter.
  By `Adrin Jalali`_
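
For context, a minimal usage sketch of the new parameter. This is hedged: `EstimatorWithValidation` is a hypothetical estimator whose `fit` accepts an `X_val` argument and requests it via metadata routing, and `X_train`, `y_train`, and `X_val` are assumed to be defined.

import sklearn
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

sklearn.set_config(enable_metadata_routing=True)

pipe = Pipeline(
    [
        ("scaler", StandardScaler()),
        # hypothetical estimator requesting a validation set as metadata
        ("estimator", EstimatorWithValidation().set_fit_request(X_val=True)),
    ],
    # X_val is transformed by the already-fitted scaler before it reaches
    # the estimator's fit()
    transform_input=["X_val"],
)
pipe.fit(X_train, y_train, X_val=X_val)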
195 changes: 186 additions & 9 deletions sklearn/pipeline.py
@@ -31,6 +31,7 @@
MethodMapping,
_raise_for_params,
_routing_enabled,
get_routing_for_object,
process_routing,
)
from .utils.metaestimators import _BaseComposition, available_if
@@ -80,6 +81,46 @@ def check(self):
return check


def _cached_transform(
sub_pipeline, *, cache, param_name, param_value, transform_params
):
"""Transform a parameter value using a sub-pipeline and cache the result.
Parameters
----------
sub_pipeline : Pipeline
The sub-pipeline to be used for transformation.
cache : dict
The cache dictionary to store the transformed values.
param_name : str
The name of the parameter to be transformed.
param_value : object
The value of the parameter to be transformed.
transform_params : dict
The metadata to be used for transformation. This passed to the
`transform` method of the sub-pipeline.
Returns
-------
transformed_value : object
The transformed value of the parameter.
"""
if param_name not in cache:
# If the parameter is a tuple, transform each element of the
# tuple. This is needed to support the pattern present in
# `lightgbm` and `xgboost` where users can pass multiple
# validation sets.
if isinstance(param_value, tuple):
cache[param_name] = tuple(
sub_pipeline.transform(element, **transform_params)
for element in param_value
)
else:
cache[param_name] = sub_pipeline.transform(param_value, **transform_params)

return cache[param_name]

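To illustrate the caching and the tuple handling above, a self-contained sketch using a dummy stand-in for the sub-pipeline (`_Doubler` is not a real scikit-learn class):

class _Doubler:
    def transform(self, X, **params):
        return [x * 2 for x in X]

cache = {}
first = _cached_transform(
    _Doubler(), cache=cache, param_name="X_val",
    param_value=[1, 2, 3], transform_params={},
)
second = _cached_transform(  # cache hit: transform is not run again
    _Doubler(), cache=cache, param_name="X_val",
    param_value=[1, 2, 3], transform_params={},
)
assert first is second  # same cached object

# A tuple value is transformed element-wise (the lightgbm/xgboost
# multiple-validation-sets pattern):
cache.clear()
evals = _cached_transform(
    _Doubler(), cache=cache, param_name="eval_set",
    param_value=([1], [2]), transform_params={},
)
assert evals == ([2], [4])
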

class Pipeline(_BaseComposition):
"""
A sequence of data transformers with an optional final predictor.
@@ -119,6 +160,20 @@ class Pipeline(_BaseComposition):
must define `fit`. All non-last steps must also define `transform`. See
:ref:`Combining Estimators <combining_estimators>` for more details.
transform_input : list of str, default=None
    The names of the :term:`metadata` parameters that should be transformed
    by the pipeline before being passed to the step consuming them.

    This enables input arguments to ``fit`` (other than ``X``) to be
    transformed by the steps of the pipeline up to the step that requires
    them; the requirement is declared via :ref:`metadata routing
    <metadata_routing>`. For instance, this can be used to pass a validation
    set through the pipeline.

    You can only set this if metadata routing is enabled, which you can do
    with ``sklearn.set_config(enable_metadata_routing=True)``.

    .. versionadded:: 1.6
memory : str or object with the joblib.Memory interface, default=None
Used to cache the fitted transformers of the pipeline. The last step
will never be cached, even if it is a transformer. By default, no
@@ -184,12 +239,14 @@ class Pipeline(_BaseComposition):
# BaseEstimator interface
_parameter_constraints: dict = {
"steps": [list, Hidden(tuple)],
"transform_input": [list, None],
"memory": [None, str, HasMethods(["cache"])],
"verbose": ["boolean"],
}

def __init__(self, steps, *, memory=None, verbose=False):
def __init__(self, steps, *, transform_input=None, memory=None, verbose=False):
self.steps = steps
self.transform_input = transform_input
self.memory = memory
self.verbose = verbose

@@ -412,9 +469,92 @@ def _check_method_params(self, method, props, **kwargs):
fit_params_steps[step]["fit_predict"][param] = pval
return fit_params_steps

def _get_metadata_for_step(self, *, step_idx, step_params, all_params):
"""Get params (metadata) for step `name`.
This transforms the metadata up to this step if required, which is
indicated by the `transform_input` parameter.
If a param in `step_params` is included in the `transform_input` list,
it will be transformed.
Parameters
----------
step_idx : int
Index of the step in the pipeline.
step_params : dict
Parameters specific to the step. These are routed parameters, e.g.
`routed_params[name]`. If a parameter name here is included in the
`pipeline.transform_input`, then it will be transformed. Note that
these parameters are *after* routing, so the aliases are already
resolved.
all_params : dict
All parameters passed by the user. Here this is used to call
`transform` on the slice of the pipeline itself.
Returns
-------
dict
Parameters to be passed to the step. The ones which should be
transformed are transformed.
"""
if (
self.transform_input is None
or not all_params
or not step_params
or step_idx == 0
):
# we only need to process step_params if transform_input is set
# and metadata is given by the user.
return step_params

sub_pipeline = self[:step_idx]
sub_metadata_routing = get_routing_for_object(sub_pipeline)
# here we get the metadata required by sub_pipeline.transform
transform_params = {
key: value
for key, value in all_params.items()
if key
in sub_metadata_routing.consumes(
method="transform", params=all_params.keys()
)
}
transformed_params = dict() # this is to be returned
transformed_cache = dict() # used to transform each param once
# `step_params` is the output of `process_routing`, so it has a dict for each
# method (e.g. fit, transform, predict), which are the args to be passed to
# those methods. We need to transform the parameters which are in the
# `transform_input`, before returning these dicts.
for method, method_params in step_params.items():
transformed_params[method] = Bunch()
for param_name, param_value in method_params.items():
# An example of `(param_name, param_value)` is
# `('sample_weight', array([0.5, 0.5, ...]))`
if param_name in self.transform_input:
# This parameter needs to be transformed by the sub_pipeline up to
# this step. We cache these computations to avoid repeating them.
transformed_params[method][param_name] = _cached_transform(
sub_pipeline,
cache=transformed_cache,
param_name=param_name,
param_value=param_value,
transform_params=transform_params,
)
else:
transformed_params[method][param_name] = param_value
return transformed_params
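
To make the loop above concrete, a sketch of the shapes involved (illustrative values only): `step_params` is the per-step output of `process_routing`, keyed by method, and only the keys listed in `transform_input` are replaced by their transformed version.

# Input, as produced by process_routing for one step (illustrative):
#   step_params = {"fit": {"sample_weight": w, "X_val": X_val}, "score": {}}
# With transform_input=["X_val"], the method returns:
#   {"fit": {"sample_weight": w,
#            "X_val": self[:step_idx].transform(X_val, **transform_params)},
#    "score": {}}
# sample_weight passes through untouched; X_val is transformed once and
# cached, even if it appears under several methods.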

# Estimator interface

def _fit(self, X, y=None, routed_params=None):
def _fit(self, X, y=None, routed_params=None, raw_params=None):
"""Fit the pipeline except the last step.
routed_params is the output of `process_routing`
raw_params is the parameters passed by the user, used when `transform_input`
is set by the user, to transform metadata using a sub-pipeline.
"""
# shallow copy of steps - this should really be steps_
self.steps = list(self.steps)
self._validate_steps()
@@ -437,14 +577,20 @@ def _fit(self, X, y=None, routed_params=None):
else:
cloned_transformer = clone(transformer)
# Fit or load from cache the current transformer
step_params = self._get_metadata_for_step(
step_idx=step_idx,
step_params=routed_params[name],
all_params=raw_params,
)

X, fitted_transformer = fit_transform_one_cached(
cloned_transformer,
X,
y,
None,
weight=None,
message_clsname="Pipeline",
message=self._log_message(step_idx),
params=routed_params[name],
params=step_params,
)
# Replace the transformer of the step with the fitted
# transformer. This is necessary when loading the transformer
@@ -495,11 +641,22 @@ def fit(self, X, y=None, **params):
self : object
Pipeline with fitted steps.
"""
if not _routing_enabled() and self.transform_input is not None:
raise ValueError(
"The `transform_input` parameter can only be set if metadata "
"routing is enabled. You can enable metadata routing using "
"`sklearn.set_config(enable_metadata_routing=True)`."
)

routed_params = self._check_method_params(method="fit", props=params)
Xt = self._fit(X, y, routed_params)
Xt = self._fit(X, y, routed_params, raw_params=params)
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if self._final_estimator != "passthrough":
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = self._get_metadata_for_step(
step_idx=len(self) - 1,
step_params=routed_params[self.steps[-1][0]],
all_params=params,
)
self._final_estimator.fit(Xt, y, **last_step_params["fit"])

return self
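
With the guard and the transformed last-step params in place, `fit` behaves as in this sketch (hedged: `pipe` is a hypothetical pipeline built with `transform_input=["X_val"]` and an estimator requesting `X_val`):

import sklearn

sklearn.set_config(enable_metadata_routing=False)
pipe.fit(X, y, X_val=X_val)   # raises the ValueError above

sklearn.set_config(enable_metadata_routing=True)
pipe.fit(X, y, X_val=X_val)   # OK: X_val is transformed up to its consumer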
@@ -562,7 +719,11 @@ def fit_transform(self, X, y=None, **params):
with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
if last_step == "passthrough":
return Xt
last_step_params = routed_params[self.steps[-1][0]]
last_step_params = self._get_metadata_for_step(
step_idx=len(self) - 1,
step_params=routed_params[self.steps[-1][0]],
all_params=params,
)
if hasattr(last_step, "fit_transform"):
return last_step.fit_transform(
Xt, y, **last_step_params["fit_transform"]
@@ -1270,7 +1431,7 @@ def _name_estimators(estimators):
return list(zip(names, estimators))


def make_pipeline(*steps, memory=None, verbose=False):
def make_pipeline(*steps, memory=None, transform_input=None, verbose=False):
"""Construct a :class:`Pipeline` from the given estimators.
This is a shorthand for the :class:`Pipeline` constructor; it does not
@@ -1292,6 +1453,17 @@ def make_pipeline(*steps, memory=None, verbose=False):
or ``steps`` to inspect estimators within the pipeline. Caching the
transformers is advantageous when fitting is time consuming.
transform_input : list of str, default=None
    This enables input arguments to ``fit`` (other than ``X``) to be
    transformed by the steps of the pipeline up to the step that requires
    them; the requirement is declared via :ref:`metadata routing
    <metadata_routing>`. This can be used to pass a validation set through
    the pipeline, for instance.

    You can only set this if metadata routing is enabled, which you can do
    with ``sklearn.set_config(enable_metadata_routing=True)``.

    .. versionadded:: 1.6
verbose : bool, default=False
If True, the time elapsed while fitting each step will be printed as it
is completed.
@@ -1315,7 +1487,12 @@
Pipeline(steps=[('standardscaler', StandardScaler()),
('gaussiannb', GaussianNB())])
"""
return Pipeline(_name_estimators(steps), memory=memory, verbose=verbose)
return Pipeline(
_name_estimators(steps),
transform_input=transform_input,
memory=memory,
verbose=verbose,
)
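
The same feature through the shorthand, as a sketch (again assuming routing is enabled and the hypothetical `EstimatorWithValidation` requesting `X_val`):

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

pipe = make_pipeline(
    StandardScaler(),
    EstimatorWithValidation().set_fit_request(X_val=True),
    transform_input=["X_val"],  # scale X_val with the fitted scaler
)
pipe.fit(X_train, y_train, X_val=X_val)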


def _transform_one(transformer, X, y, weight, params=None):
1 change: 1 addition & 0 deletions sklearn/tests/metadata_routing_common.py
@@ -347,6 +347,7 @@ def fit(self, X, y=None, sample_weight="default", metadata="default"):
record_metadata_not_default(
self, sample_weight=sample_weight, metadata=metadata
)
self.fitted_ = True
return self

def transform(self, X, sample_weight="default", metadata="default"):
