Skip to content

Commit

Permalink
Remove code that is not needed in tabmat v4 / glum v3 (#741)
Browse files Browse the repository at this point in the history
* Remove check_array from predict()

We don't need it here because predict calls linear_predictor, and the latter already performs this check, so we can avoid doing it twice.

* Remove _name_categorical_variable parts

There is no need for those as Tabmat v4 handles variable names internally.

---------

Co-authored-by: Martin Stancsics <[email protected]>
  • Loading branch information
MatthiasSchmidtblaicherQC and stanmart authored Dec 12, 2023
1 parent 248c1dc commit 512740d
Showing 1 changed file with 0 additions and 36 deletions.
36 changes: 0 additions & 36 deletions src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,6 @@ def _check_offset(
return offset


def _name_categorical_variables(
categories: tuple[str], column_name: str, drop_first: bool
):
new_names = [
f"{column_name}__{category}" for category in categories[int(drop_first) :]
]
if len(new_names) == 0:
raise ValueError(
f"Categorical column: {column_name}, contains only one category. "
+ "This should be dropped from the feature matrix."
)
return new_names


def _parse_formula(
formula: FormulaSpec, include_intercept: bool = True
) -> tuple[Optional[Formula], Formula]:
Expand Down Expand Up @@ -1424,16 +1410,6 @@ def predict(
)
X = self._convert_from_pandas(X, context=captured_context)

X = check_array_tabmat_compliant(
X,
accept_sparse=["csr", "csc", "coo"],
dtype="numeric",
copy=self._should_copy_X(),
ensure_2d=True,
allow_nd=False,
drop_first=getattr(self, "drop_first", False),
)

eta = self.linear_predictor(
X, offset=offset, alpha_index=alpha_index, alpha=alpha
)
Expand Down Expand Up @@ -2718,18 +2694,6 @@ def _set_up_and_check_fit_args(
self.feature_dtypes_ = X.dtypes.to_dict()

if any(X.dtypes == "category"):
self.feature_names_ = list(
chain.from_iterable(
_name_categorical_variables(
dtype.categories,
column,
getattr(self, "drop_first", False),
)
if isinstance(dtype, pd.CategoricalDtype)
else [column]
for column, dtype in zip(X.columns, X.dtypes)
)
)

def _expand_categorical_penalties(penalty, X, drop_first):
"""
Expand Down

0 comments on commit 512740d

Please sign in to comment.