FIX methods in model_selection/_validation accept params=None with metadata routing enabled #30451

Open
wants to merge 4 commits into
base: main

Conversation

adrinjalali
Member

Fixes #30447

This fixes an issue with functions in model_selection/_validation.py where they'd raise if params=None and metadata routing is enabled.

cc @StefanieSenger @OmarManzoor @jeremiedbb

github-actions bot commented Dec 10, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 54aa72c. Link to the linter CI: here

],
)
@config_context(enable_metadata_routing=True)
def test_cross_validate_params_none(func, extra_args):
Contributor

Can we add a docstring for the test? It would be difficult to understand what it's testing later otherwise.

Or maybe, we can join this test with the next one (test_passed_unrequested_metadata) and modify the docstring there.

Member Author

done

Contributor

Perfect, thank you. :)

Contributor

OmarManzoor left a comment

LGTM. Thanks @adrinjalali

@@ -1636,6 +1637,7 @@ def permutation_test_score(
params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")

X, y, groups = indexable(X, y, groups)
params = params or {}
Member

This one seems redundant since we're calling _check_params_groups_deprecation just before, which already does this. I don't think params can be None here

Member Author

nice catch!

Contributor

StefanieSenger Dec 10, 2024

Wouldn't it make sense to leave it there, because the call to _check_params_groups_deprecation will be removed later?

Member Author

We'd add it there when we remove that function. And we don't have to worry about it breaking code, because the test will fail once the function is removed.

- params = {} if params is None else params
+ params = params or {}
Member

betatim Dec 10, 2024

Why change? The original is more specific in only replacing None, and not things like {} and the various other things that are false-y.

(same for the other two instances of this in this PR)
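
For reference, a minimal sketch of how the two idioms differ. It is plain Python with no scikit-learn involved, and the helper names `replace_none_only` and `replace_any_falsy` are made up for illustration:

```python
def replace_none_only(params):
    # Only None is swapped for a new dict; every other value,
    # falsy or not, is returned unchanged.
    return {} if params is None else params


def replace_any_falsy(params):
    # Any falsy value (None, {}, 0, "", [] ...) is swapped for a new dict.
    return params or {}


shared = {}

# An empty dict passed by the caller is kept as the same object ...
kept_same_object = replace_none_only(shared) is shared

# ... whereas `or` discards it and returns a different (still empty) dict.
got_new_object = replace_any_falsy(shared) is not shared
```

For an argument that is always either `None` or a dict, both versions return equal values. They differ only in object identity for an empty dict, and in how they treat non-dict falsy inputs such as `0`.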

Member Author

it's always a dictionary anyway

@@ -1172,6 +1172,7 @@ def cross_val_predict(
"""
_check_groups_routing_disabled(groups)
X, y = indexable(X, y)
params = params or {}
Member

betatim Dec 10, 2024

We want to pass params to process_routing later, which motivates the "check if this is None", but can we not find a way to pass it to process_routing without having to have this line everywhere? Mostly because I think we will keep forgetting it :-/ and because it is repetitive.

The only way I can think of is to not use **params but something like kw_params=params and then checking for None in process_routing. Which isn't a super solution, but maybe someone smarter has a good idea?
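
To make the trade-off concrete, here is a rough sketch in plain Python. `route` is a hypothetical stand-in for a callee that takes metadata as keyword arguments; it is not the real `process_routing`, and the `kw_params` argument only mirrors the name suggested above:

```python
def route(**kwargs):
    # Hypothetical callee: receives metadata as keyword arguments.
    return dict(kwargs)


def unpack(params):
    # Unpacking with ** requires a mapping, so params=None fails here.
    return route(**params)


def route_from_mapping(kw_params=None):
    # Alternative: accept the mapping itself and do the None check
    # once, inside the callee's wrapper, instead of at every call site.
    kw_params = {} if kw_params is None else kw_params
    return route(**kw_params)


try:
    unpack(None)
    unpack_none_failed = False
except TypeError:
    # Python refuses to unpack None with **.
    unpack_none_failed = True
```

With the second form, callers could forward `params` untouched, at the cost of a different calling convention for the routing helper.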

Member Author

These checks are only needed in functions, not where we have estimators and route from their methods. There, the metadata is passed to the method as **kwargs and is therefore never None. We have very few functions where we accept a params arg and then route it. As you can see in the implementation of the routing here, it's quite convoluted, and the routing machinery is not designed for these functions. So I think this single line is not adding any complexity compared to how we've implemented routing here.
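
A small sketch of that distinction, using toy names (`ToyEstimator` and `toy_cross_validate` are hypothetical and not scikit-learn code):

```python
class ToyEstimator:
    def fit(self, X, y, **fit_params):
        # With a ** parameter, Python always binds a dict here,
        # even when the caller passes no extra keyword arguments.
        self.received_ = fit_params
        return self


def toy_cross_validate(estimator, X, y, params=None):
    # A function that takes a single `params` argument can receive None,
    # so it has to normalise the value before unpacking it.
    params = {} if params is None else params
    return estimator.fit(X, y, **params)


no_kwargs = ToyEstimator().fit([[0.0]], [0]).received_
via_function = toy_cross_validate(ToyEstimator(), [[0.0]], [0]).received_
```

In both calls the estimator ends up with an empty dict, but only the function needed an explicit `None` check to get there.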

Member

jeremiedbb left a comment

LGTM


Successfully merging this pull request may close these issues.

cross_validate raises an exception when metadata routing is enabled
5 participants