Skip to content

Commit

Permalink
FIX Make RFECV thread-safe when used with joblib threading backend (#…
Browse files: browse the repository at this point in the history
  • Loading branch information
lesteve authored Nov 1, 2024
1 parent 56bbb5a commit a2448b5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ def fit(self, X, y, *, groups=None, **params):
func = delayed(_rfe_single_fit)

scores_features = parallel(
func(rfe, self.estimator, X, y, train, test, scorer, routed_params)
func(clone(rfe), self.estimator, X, y, train, test, scorer, routed_params)
for train, test in cv.split(X, y, **routed_params.splitter.split)
)
scores, step_n_features = zip(*scores_features)
Expand Down
19 changes: 19 additions & 0 deletions sklearn/feature_selection/tests/test_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pytest
from joblib import parallel_backend
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal

from sklearn.base import BaseEstimator, ClassifierMixin
Expand Down Expand Up @@ -703,3 +704,21 @@ def test_rfe_with_sample_weight():
rfe_sw_2.fit(X, y, sample_weight=sample_weight_2)

assert not np.array_equal(rfe_sw_2.ranking_, rfe.ranking_)


def test_rfe_with_joblib_threading_backend(global_random_seed):
    """Check RFECV yields the same ranking under the joblib threading backend.

    The selector is fitted once without any backend override to record a
    reference ranking, then fitted again on the same data inside a
    ``parallel_backend("threading")`` context. Both rankings must be equal.
    """
    X, y = make_classification(random_state=global_random_seed)
    selector = RFECV(estimator=LogisticRegression(), n_jobs=2)

    # Reference fit, performed outside of any explicit backend context.
    selector.fit(X, y)
    expected_ranking = selector.ranking_

    # Refit the very same selector with the threading backend forced.
    with parallel_backend("threading"):
        selector.fit(X, y)
    threaded_ranking = selector.ranking_

    assert_array_equal(expected_ranking, threaded_ranking)

0 comments on commit a2448b5

Please sign in to comment.