Skip to content

Commit

Permalink
Keep decision_function minimal
Browse files Browse the repository at this point in the history
  • Loading branch information
ogrisel committed Dec 11, 2024
1 parent 84535ef commit b64baaa
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions sklearn/linear_model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@
from ..utils._array_api import (
_asarray_with_order,
_average,
_convert_to_numpy,
get_namespace,
get_namespace_and_device,
indexing_dtype,
make_converter,
supported_float_dtypes,
)
from ..utils._seq_dataset import (
Expand Down Expand Up @@ -348,16 +346,11 @@ def decision_function(self, X):
this class would be predicted.
"""
check_is_fitted(self)
# X must be in the same namespace as self.coef_; self.intercept_ is
# either a python float or in the same namespace as self.coef_
xp, _ = get_namespace(X)
follow_X = make_converter(X)

X = validate_data(self, X, accept_sparse="csr", reset=False)
coef = follow_X(self.coef_)
if coef.ndim == 2:
coef = coef.T
scores = safe_sparse_dot(X, coef, dense_output=True) + follow_X(self.intercept_)
coef = self.coef_.T if self.coef_.ndim == 2 else self.coef_
scores = safe_sparse_dot(X, coef, dense_output=True) + self.intercept_
return (
xp.reshape(scores, (-1,))
if (scores.ndim > 1 and scores.shape[1] == 1)
Expand Down Expand Up @@ -385,15 +378,7 @@ def predict(self, X):
else:
indices = xp.argmax(scores, axis=1)

classes_xp, _ = get_namespace(self.classes_)
if classes_xp == xp:
return xp.take(self.classes_, indices, axis=0)

return np.take(
_convert_to_numpy(self.classes_, classes_xp),
_convert_to_numpy(indices, xp),
axis=0,
)
return xp.take(self.classes_, indices, axis=0)

def _predict_proba_lr(self, X):
"""Probability estimation for OvR logistic regression.
Expand Down

0 comments on commit b64baaa

Please sign in to comment.