Skip to content

Commit 95fe8c2

Browse files
chore: Updated tests to confirm behavior of inconsistent offline/online return values (#4615)
1 parent 9e13636 commit 95fe8c2

File tree

1 file changed

+152
-1
lines changed

1 file changed

+152
-1
lines changed

sdk/python/tests/unit/test_on_demand_python_transformation.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,13 @@ def setUp(self):
425425
Field(name="avg_daily_trip_rank_names", dtype=Array(String)),
426426
],
427427
)
428+
input_request = RequestSource(
429+
name="vals_to_add",
430+
schema=[
431+
Field(name="val_to_add", dtype=Int64),
432+
Field(name="val_to_add_2", dtype=Int64),
433+
],
434+
)
428435

429436
@on_demand_feature_view(
430437
sources=[request_source, driver_stats_fv],
@@ -476,8 +483,37 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
476483
output["achieved_ranks"] = ranks
477484
return output
478485

486+
@on_demand_feature_view(
487+
sources=[
488+
driver_stats_fv,
489+
input_request,
490+
],
491+
schema=[
492+
Field(name="conv_rate_plus_val1", dtype=Float64),
493+
Field(name="conv_rate_plus_val2", dtype=Float64),
494+
],
495+
mode="pandas",
496+
)
497+
def pandas_view(features_df: pd.DataFrame) -> pd.DataFrame:
498+
df = pd.DataFrame()
499+
df["conv_rate_plus_val1"] = (
500+
features_df["conv_rate"] + features_df["val_to_add"]
501+
)
502+
df["conv_rate_plus_val2"] = (
503+
features_df["conv_rate"] + features_df["val_to_add_2"]
504+
)
505+
return df
506+
479507
self.store.apply(
480-
[driver, driver_stats_source, driver_stats_fv, python_view]
508+
[
509+
driver,
510+
driver_stats_source,
511+
driver_stats_fv,
512+
python_view,
513+
pandas_view,
514+
input_request,
515+
request_source,
516+
]
481517
)
482518
fv_applied = self.store.get_feature_view("driver_hourly_stats")
483519
assert fv_applied.entities == [driver.name]
@@ -488,6 +524,121 @@ def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
488524
feature_view_name="driver_hourly_stats", df=driver_df
489525
)
490526

527+
batch_sample = pd.DataFrame(driver_entities, columns=["driver_id"])
528+
batch_sample["val_to_add"] = 0
529+
batch_sample["val_to_add_2"] = 1
530+
batch_sample["event_timestamp"] = start_date
531+
batch_sample["created"] = start_date
532+
fv_only_cols = ["driver_id", "event_timestamp", "created"]
533+
534+
resp_base_fv = self.store.get_historical_features(
535+
entity_df=batch_sample[fv_only_cols],
536+
features=[
537+
"driver_hourly_stats:conv_rate",
538+
"driver_hourly_stats:acc_rate",
539+
"driver_hourly_stats:avg_daily_trips",
540+
],
541+
).to_df()
542+
assert resp_base_fv is not None
543+
assert sorted(resp_base_fv.columns) == [
544+
"acc_rate",
545+
"avg_daily_trips",
546+
"conv_rate",
547+
"created__",
548+
"driver_id",
549+
"event_timestamp",
550+
]
551+
resp = self.store.get_historical_features(
552+
entity_df=batch_sample,
553+
features=[
554+
"driver_hourly_stats:conv_rate",
555+
"driver_hourly_stats:acc_rate",
556+
"driver_hourly_stats:avg_daily_trips",
557+
"pandas_view:conv_rate_plus_val1",
558+
"pandas_view:conv_rate_plus_val2",
559+
],
560+
).to_df()
561+
assert resp is not None
562+
assert resp["conv_rate_plus_val1"].isnull().sum() == 0
563+
564+
# Now testing feature retrieval for driver ids not in the dataset
565+
missing_batch_sample = pd.DataFrame([1234567890], columns=["driver_id"])
566+
missing_batch_sample["val_to_add"] = 0
567+
missing_batch_sample["val_to_add_2"] = 1
568+
missing_batch_sample["event_timestamp"] = start_date
569+
missing_batch_sample["created"] = start_date
570+
resp_offline = self.store.get_historical_features(
571+
entity_df=missing_batch_sample,
572+
features=[
573+
"driver_hourly_stats:conv_rate",
574+
"driver_hourly_stats:acc_rate",
575+
"driver_hourly_stats:avg_daily_trips",
576+
"pandas_view:conv_rate_plus_val1",
577+
"pandas_view:conv_rate_plus_val2",
578+
],
579+
).to_df()
580+
assert resp_offline is not None
581+
assert resp_offline["conv_rate_plus_val1"].isnull().sum() == 1
582+
assert sorted(resp_offline.columns) == [
583+
"acc_rate",
584+
"avg_daily_trips",
585+
"conv_rate",
586+
"conv_rate_plus_val1",
587+
"conv_rate_plus_val2",
588+
"created__",
589+
"driver_id",
590+
"event_timestamp",
591+
"val_to_add",
592+
"val_to_add_2",
593+
]
594+
with pytest.raises(TypeError):
595+
_ = self.store.get_online_features(
596+
entity_rows=[
597+
{"driver_id": 1234567890, "val_to_add": 0, "val_to_add_2": 1}
598+
],
599+
features=[
600+
"driver_hourly_stats:conv_rate",
601+
"driver_hourly_stats:acc_rate",
602+
"driver_hourly_stats:avg_daily_trips",
603+
"pandas_view:conv_rate_plus_val1",
604+
"pandas_view:conv_rate_plus_val2",
605+
],
606+
)
607+
resp_online = self.store.get_online_features(
608+
entity_rows=[{"driver_id": 1001, "val_to_add": 0, "val_to_add_2": 1}],
609+
features=[
610+
"driver_hourly_stats:conv_rate",
611+
"driver_hourly_stats:acc_rate",
612+
"driver_hourly_stats:avg_daily_trips",
613+
"pandas_view:conv_rate_plus_val1",
614+
"pandas_view:conv_rate_plus_val2",
615+
],
616+
).to_df()
617+
assert resp_online is not None
618+
assert sorted(resp_online.columns) == [
619+
"acc_rate",
620+
"avg_daily_trips",
621+
"conv_rate",
622+
"conv_rate_plus_val1",
623+
"conv_rate_plus_val2",
624+
"driver_id",
625+
# It does not have the items below
626+
# "created__",
627+
# "event_timestamp",
628+
# "val_to_add",
629+
# "val_to_add_2",
630+
]
631+
# Note online and offline columns will not match because:
632+
# you want to be space efficient online when considering the impact of network latency so you want to send
633+
# and receive the minimally required set of data, which means after transformation you only need to send the
634+
# output in the response.
635+
# Offline, you will probably prioritize reproducibility and being able to iterate, which means you will want
636+
# the underlying inputs into your transformation, so the extra data is tolerable.
637+
assert sorted(resp_online.columns) != sorted(resp_offline.columns)
638+
639+
def test_setup(self):
640+
pass
641+
491642
def test_python_transformation_returning_all_data_types(self):
492643
entity_rows = [
493644
{

0 commit comments

Comments
 (0)