
Update _coordinate_descent.py #30416

Open · wants to merge 5 commits into main

Conversation

prempraneethkota

Purpose

The goal is to allow users to disable the refitting of cross-validation estimators, such as LassoCV, on the full training set after the best hyperparameters have been found. This saves computational resources, which is especially valuable with large datasets.

Changes Made

1. Added refit Parameter

  • Introduced a new refit parameter in the LassoCV class constructor. This parameter enables users to control whether the model should be refitted on the entire training set.
  • Code Changes:
    class LassoCV(RegressorMixin, LinearModelCV):
        def __init__(self, ..., refit=True):
            self.refit = refit
            super().__init__(...)

2. Modified fit Method

Updated the fit method to refit the model conditionally, based on the refit parameter. If refit is True, the model is refitted on the full training set; otherwise, the refit step is skipped.
Code Changes:

def fit(self, X, y, sample_weight=None, **params):
    # Perform the initial fit using cross-validation to find the best alpha
    super().fit(X, y, sample_weight=sample_weight, **params)
    
    # Conditionally refit the model with the parameters selected
    if self.refit:
        self._fit(X, y)
    
    return self

3. Outcome

  • Flexibility: Users can choose to skip the refitting step when it is unnecessary, giving them more control over the model fitting process.
  • Efficiency: Disabling refitting saves significant computational resources, which is especially beneficial for large datasets (see the timing sketch below).
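To make the efficiency point concrete, the sketch below times a fit with and without the refit step. Note that refit is the parameter proposed in this PR and is not available in released scikit-learn, so this assumes the patched branch is installed; the dataset size is arbitrary.

import time

from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV

# Synthetic data; any reasonably large regression problem works here
X, y = make_regression(n_samples=20_000, n_features=200, noise=0.1, random_state=0)

for refit in (True, False):
    start = time.perf_counter()
    # refit is the parameter proposed in this PR (not in released scikit-learn)
    LassoCV(cv=5, refit=refit).fit(X, y)
    print(f"refit={refit}: {time.perf_counter() - start:.2f}s")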

4. Usage

from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV

# Example training data (any regression dataset works here)
X_train, y_train = make_regression(n_samples=1000, n_features=20, noise=0.1, random_state=0)

# Initialize LassoCV without refitting
lasso_cv = LassoCV(refit=False)
lasso_cv.fit(X_train, y_train)

# Retrieve the best alpha (hyperparameter) selected by cross-validation
best_alpha = lasso_cv.alpha_
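
With refitting disabled, no final model is trained on the full data; if one is needed, it can be fitted separately with the selected alpha. A minimal sketch, reusing best_alpha and the training data from above:

from sklearn.linear_model import Lasso

# Train a final model manually with the alpha selected by cross-validation,
# e.g. on the full training set or on a subsample of it
final_model = Lasso(alpha=best_alpha)
final_model.fit(X_train, y_train)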

The implemented feature allows users to disable the refitting of LassoCV on the full training set after finding the best hyperparameters, enhancing both flexibility and efficiency in model training.


github-actions bot commented Dec 6, 2024

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


black

black detected issues. Please run black . locally and push the changes. Here you can see the detected issues. Note that running black might also fix some of the issues which might be detected by ruff. Note that the installed black version is black=24.3.0.


--- /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/_coordinate_descent.py	2024-12-06 10:57:46.194073+00:00
+++ /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/_coordinate_descent.py	2024-12-06 10:58:01.262375+00:00
@@ -908,11 +908,11 @@
         copy_X=True,
         tol=1e-4,
         warm_start=False,
         positive=False,
         random_state=None,
-        selection="cyclic"
+        selection="cyclic",
     ):
         self.alpha = alpha
         self.l1_ratio = l1_ratio
         self.fit_intercept = fit_intercept
         self.precompute = precompute
@@ -2082,11 +2082,11 @@
         tags = super().__sklearn_tags__()
         tags.target_tags.multi_output = False
         return tags
 
     def fit(self, X, y, sample_weight=None, **params):
-        params.pop('refit', None)
+        params.pop("refit", None)
         """Fit Lasso model with coordinate descent.
 
         Fit is on grid of alphas and best alpha estimated by cross-validation.
 
         Parameters
@@ -2127,10 +2127,11 @@
 
         if self.refit:
             self._fit(X, y)
 
         return self
+
 
 class ElasticNetCV(RegressorMixin, LinearModelCV):
     """Elastic Net model with iterative fitting along a regularization path.
 
     See glossary entry for :term:`cross-validation estimator`.
would reformat /home/runner/work/scikit-learn/scikit-learn/sklearn/linear_model/_coordinate_descent.py

Oh no! 💥 💔 💥
1 file would be reformatted, 923 files would be left unchanged.

Generated for commit: e96ee60. Link to the linter CI: here
