Skip to content

Commit 26391b0

Browse files
authored
fix: Substrait ODFVs for online (#4064)
* fix substrait odfvs for online, add tests Signed-off-by: tokoko <[email protected]> * fix formatting Signed-off-by: tokoko <[email protected]> * change odfv substrait test dates relative to start_date and end_date Signed-off-by: tokoko <[email protected]> * force tests rerun Signed-off-by: tokoko <[email protected]> --------- Signed-off-by: tokoko <[email protected]>
1 parent d82d1ec commit 26391b0

File tree

4 files changed

+43
-21
lines changed

4 files changed

+43
-21
lines changed

sdk/python/feast/feature_store.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2037,11 +2037,10 @@ def _augment_response_with_on_demand_transforms(
20372037

20382038
proto_values = []
20392039
for selected_feature in selected_subset:
2040-
if odfv.mode in ["python", "pandas"]:
2041-
feature_vector = transformed_features[selected_feature]
2042-
proto_values.append(
2043-
python_values_to_proto_values(feature_vector, ValueType.UNKNOWN)
2044-
)
2040+
feature_vector = transformed_features[selected_feature]
2041+
proto_values.append(
2042+
python_values_to_proto_values(feature_vector, ValueType.UNKNOWN)
2043+
)
20452044

20462045
odfv_result_names |= set(selected_subset)
20472046

sdk/python/feast/infra/offline_stores/offline_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def to_arrow(
128128
features_df = self._to_df_internal(timeout=timeout)
129129
if self.on_demand_feature_views:
130130
for odfv in self.on_demand_feature_views:
131-
if odfv.mode != "pandas":
131+
if odfv.mode not in {"pandas", "substrait"}:
132132
raise Exception(
133133
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
134134
)

sdk/python/feast/on_demand_feature_view.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,9 @@ def get_transformed_features(
465465
return self.get_transformed_features_dict(
466466
feature_dict=features,
467467
)
468-
elif self.mode == "pandas" and isinstance(features, pd.DataFrame):
468+
elif self.mode in {"pandas", "substrait"} and isinstance(
469+
features, pd.DataFrame
470+
):
469471
return self.get_transformed_features_df(
470472
df_with_features=features,
471473
full_feature_names=full_feature_names,

sdk/python/tests/unit/test_on_demand_substrait_transformation.py renamed to sdk/python/tests/unit/test_substrait_transformation.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def test_ibis_pandas_parity():
6060
@on_demand_feature_view(
6161
sources=[driver_stats_fv],
6262
schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
63+
mode="pandas",
6364
)
6465
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
6566
df = pd.DataFrame()
@@ -84,30 +85,50 @@ def substrait_view(inputs: Table) -> Table:
8485
[driver, driver_stats_source, driver_stats_fv, substrait_view, pandas_view]
8586
)
8687

88+
store.materialize(
89+
start_date=start_date,
90+
end_date=end_date,
91+
)
92+
8793
entity_df = pd.DataFrame.from_dict(
8894
{
8995
# entity's join key -> entity values
9096
"driver_id": [1001, 1002, 1003],
9197
# "event_timestamp" (reserved key) -> timestamps
9298
"event_timestamp": [
93-
datetime(2021, 4, 12, 10, 59, 42),
94-
datetime(2021, 4, 12, 8, 12, 10),
95-
datetime(2021, 4, 12, 16, 40, 26),
99+
start_date + timedelta(days=4),
100+
start_date + timedelta(days=5),
101+
start_date + timedelta(days=6),
96102
],
97103
}
98104
)
99105

106+
requested_features = [
107+
"driver_hourly_stats:conv_rate",
108+
"driver_hourly_stats:acc_rate",
109+
"driver_hourly_stats:avg_daily_trips",
110+
"substrait_view:conv_rate_plus_acc_substrait",
111+
"pandas_view:conv_rate_plus_acc",
112+
]
113+
100114
training_df = store.get_historical_features(
101-
entity_df=entity_df,
102-
features=[
103-
"driver_hourly_stats:conv_rate",
104-
"driver_hourly_stats:acc_rate",
105-
"driver_hourly_stats:avg_daily_trips",
106-
"substrait_view:conv_rate_plus_acc_substrait",
107-
"pandas_view:conv_rate_plus_acc",
108-
],
109-
).to_df()
115+
entity_df=entity_df, features=requested_features
116+
)
117+
118+
assert training_df.to_df()["conv_rate_plus_acc"].equals(
119+
training_df.to_df()["conv_rate_plus_acc_substrait"]
120+
)
121+
122+
assert training_df.to_arrow()["conv_rate_plus_acc"].equals(
123+
training_df.to_arrow()["conv_rate_plus_acc_substrait"]
124+
)
125+
126+
online_response = store.get_online_features(
127+
features=requested_features,
128+
entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}],
129+
)
110130

111-
assert training_df["conv_rate_plus_acc"].equals(
112-
training_df["conv_rate_plus_acc_substrait"]
131+
assert (
132+
online_response.to_dict()["conv_rate_plus_acc"]
133+
== online_response.to_dict()["conv_rate_plus_acc_substrait"]
113134
)

0 commit comments

Comments
 (0)