Skip to content

Commit 6bf7516

Browse files
authored
refactor: Rework get_online_features helper functions (#5060)
* rework _populate_response_from_feature_data

Signed-off-by: Artem Petrov <[email protected]>

* rework _convert_rows_to_protobuf

Signed-off-by: Artem Petrov <[email protected]>

* fix typing

Signed-off-by: Artem Petrov <[email protected]>

---------

Signed-off-by: Artem Petrov <[email protected]>
1 parent 0fffe21 commit 6bf7516

File tree

4 files changed

+79
-77
lines changed

4 files changed

+79
-77
lines changed

sdk/python/feast/feature_store.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2018,7 +2018,7 @@ def _retrieve_from_online_store_v2(
20182018
entity_key_dict[key] = []
20192019
entity_key_dict[key].append(python_value)
20202020

2021-
table_entity_values, idxs = utils._get_unique_entities_from_values(
2021+
table_entity_values, idxs, output_len = utils._get_unique_entities_from_values(
20222022
entity_key_dict,
20232023
)
20242024

@@ -2040,6 +2040,7 @@ def _retrieve_from_online_store_v2(
20402040
full_feature_names=False,
20412041
requested_features=features_to_request,
20422042
table=table,
2043+
output_len=output_len,
20432044
)
20442045

20452046
return OnlineResponse(online_features_response)

sdk/python/feast/infra/online_stores/online_store.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_online_features(
187187

188188
for table, requested_features in grouped_refs:
189189
# Get the correct set of entity values with the correct join keys.
190-
table_entity_values, idxs = utils._get_unique_entities(
190+
table_entity_values, idxs, output_len = utils._get_unique_entities(
191191
table,
192192
join_key_values,
193193
entity_name_to_join_key_map,
@@ -215,6 +215,7 @@ def get_online_features(
215215
full_feature_names,
216216
requested_features,
217217
table,
218+
output_len,
218219
)
219220

220221
if requested_on_demand_feature_views:
@@ -274,7 +275,7 @@ async def get_online_features_async(
274275

275276
async def query_table(table, requested_features):
276277
# Get the correct set of entity values with the correct join keys.
277-
table_entity_values, idxs = utils._get_unique_entities(
278+
table_entity_values, idxs, output_len = utils._get_unique_entities(
278279
table,
279280
join_key_values,
280281
entity_name_to_join_key_map,
@@ -290,7 +291,7 @@ async def query_table(table, requested_features):
290291
requested_features=requested_features,
291292
)
292293

293-
return idxs, read_rows
294+
return idxs, read_rows, output_len
294295

295296
all_responses = await asyncio.gather(
296297
*[
@@ -299,7 +300,7 @@ async def query_table(table, requested_features):
299300
]
300301
)
301302

302-
for (idxs, read_rows), (table, requested_features) in zip(
303+
for (idxs, read_rows, output_len), (table, requested_features) in zip(
303304
all_responses, grouped_refs
304305
):
305306
feature_data = utils._convert_rows_to_protobuf(
@@ -314,6 +315,7 @@ async def query_table(table, requested_features):
314315
full_feature_names,
315316
requested_features,
316317
table,
318+
output_len,
317319
)
318320

319321
if requested_on_demand_feature_views:

sdk/python/feast/utils.py

Lines changed: 67 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -490,16 +490,28 @@ def _group_feature_refs(
490490
return fvs_result, odfvs_result
491491

492492

493-
def apply_list_mapping(
494-
lst: Iterable[Any], mapping_indexes: Iterable[List[int]]
495-
) -> Iterable[Any]:
496-
output_len = sum(len(item) for item in mapping_indexes)
497-
output = [None] * output_len
498-
for elem, destinations in zip(lst, mapping_indexes):
493+
def construct_response_feature_vector(
494+
values_vector: Iterable[Any],
495+
statuses_vector: Iterable[Any],
496+
timestamp_vector: Iterable[Any],
497+
mapping_indexes: Iterable[List[int]],
498+
output_len: int,
499+
) -> GetOnlineFeaturesResponse.FeatureVector:
500+
values_output: Iterable[Any] = [None] * output_len
501+
statuses_output: Iterable[Any] = [None] * output_len
502+
timestamp_output: Iterable[Any] = [None] * output_len
503+
504+
for i, destinations in enumerate(mapping_indexes):
499505
for idx in destinations:
500-
output[idx] = elem
501-
502-
return output
506+
values_output[idx] = values_vector[i] # type: ignore[index]
507+
statuses_output[idx] = statuses_vector[i] # type: ignore[index]
508+
timestamp_output[idx] = timestamp_vector[i] # type: ignore[index]
509+
510+
return GetOnlineFeaturesResponse.FeatureVector(
511+
values=values_output,
512+
statuses=statuses_output,
513+
event_timestamps=timestamp_output,
514+
)
503515

504516

505517
def _augment_response_with_on_demand_transforms(
@@ -674,7 +686,7 @@ def _get_unique_entities(
674686
table: "FeatureView",
675687
join_key_values: Dict[str, List[ValueProto]],
676688
entity_name_to_join_key_map: Dict[str, str],
677-
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
689+
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]:
678690
"""Return the set of unique composite Entities for a Feature View and the indexes at which they appear.
679691
680692
This method allows us to query the OnlineStore for data we need only once
@@ -712,7 +724,7 @@ def _get_unique_entities(
712724

713725
# If there are no rows, return empty tuples.
714726
if not rowise:
715-
return (), ()
727+
return (), (), 0
716728

717729
# Sort rowise so that rows with the same join key values are adjacent.
718730
rowise.sort(key=lambda row: tuple(getattr(x, x.WhichOneof("val")) for x in row[1]))
@@ -725,16 +737,16 @@ def _get_unique_entities(
725737

726738
# If no groups were formed (should not happen for valid input), return empty tuples.
727739
if not groups:
728-
return (), ()
740+
return (), (), 0
729741

730742
# Unpack the unique entities and their original row indexes.
731743
unique_entities, indexes = tuple(zip(*groups))
732-
return unique_entities, indexes
744+
return unique_entities, indexes, len(rowise)
733745

734746

735747
def _get_unique_entities_from_values(
736748
table_entity_values: Dict[str, List[ValueProto]],
737-
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...]]:
749+
) -> Tuple[Tuple[Dict[str, ValueProto], ...], Tuple[List[int], ...], int]:
738750
"""Return the set of unique composite Entities for a Feature View and the indexes at which they appear.
739751
740752
This method allows us to query the OnlineStore for data we need only once
@@ -758,7 +770,7 @@ def _get_unique_entities_from_values(
758770
]
759771
)
760772
)
761-
return unique_entities, indexes
773+
return unique_entities, indexes, len(rowise)
762774

763775

764776
def _drop_unneeded_columns(
@@ -835,6 +847,7 @@ def _populate_response_from_feature_data(
835847
full_feature_names: bool,
836848
requested_features: Iterable[str],
837849
table: "FeatureView",
850+
output_len: int,
838851
):
839852
"""Populate the GetOnlineFeaturesResponse with feature data.
840853
@@ -853,33 +866,22 @@ def _populate_response_from_feature_data(
853866
requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the
854867
data in `feature_data`.
855868
table: The FeatureView that `feature_data` was retrieved from.
869+
output_len: The number of result rows in `online_features_response`.
856870
"""
857871
# Add the feature names to the response.
872+
table_name = table.projection.name_to_use()
858873
requested_feature_refs = [
859-
(
860-
f"{table.projection.name_to_use()}__{feature_name}"
861-
if full_feature_names
862-
else feature_name
863-
)
874+
f"{table_name}__{feature_name}" if full_feature_names else feature_name
864875
for feature_name in requested_features
865876
]
866877
online_features_response.metadata.feature_names.val.extend(requested_feature_refs)
867878

868-
timestamps, statuses, values = zip(*feature_data)
869-
870-
# Populate the result with data fetched from the OnlineStore
871-
# which is guaranteed to be aligned with `requested_features`.
872-
for (
873-
feature_idx,
874-
(timestamp_vector, statuses_vector, values_vector),
875-
) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
876-
online_features_response.results.append(
877-
GetOnlineFeaturesResponse.FeatureVector(
878-
values=apply_list_mapping(values_vector, indexes),
879-
statuses=apply_list_mapping(statuses_vector, indexes),
880-
event_timestamps=apply_list_mapping(timestamp_vector, indexes),
881-
)
879+
# Process each feature vector in a single pass
880+
for timestamp_vector, statuses_vector, values_vector in feature_data:
881+
response_vector = construct_response_feature_vector(
882+
values_vector, statuses_vector, timestamp_vector, indexes, output_len
882883
)
884+
online_features_response.results.append(response_vector)
883885

884886

885887
def _populate_response_from_feature_data_v2(
@@ -891,6 +893,7 @@ def _populate_response_from_feature_data_v2(
891893
indexes: Iterable[List[int]],
892894
online_features_response: GetOnlineFeaturesResponse,
893895
requested_features: Iterable[str],
896+
output_len: int,
894897
):
895898
"""Populate the GetOnlineFeaturesResponse with feature data.
896899
@@ -908,6 +911,7 @@ def _populate_response_from_feature_data_v2(
908911
"customer_fv__daily_transactions").
909912
requested_features: The names of the features in `feature_data`. This should be ordered in the same way as the
910913
data in `feature_data`.
914+
output_len: The number of result rows in `online_features_response`.
911915
"""
912916
# Add the feature names to the response.
913917
requested_feature_refs = [(feature_name) for feature_name in requested_features]
@@ -917,17 +921,11 @@ def _populate_response_from_feature_data_v2(
917921

918922
# Populate the result with data fetched from the OnlineStore
919923
# which is guaranteed to be aligned with `requested_features`.
920-
for (
921-
feature_idx,
922-
(timestamp_vector, statuses_vector, values_vector),
923-
) in enumerate(zip(zip(*timestamps), zip(*statuses), zip(*values))):
924-
online_features_response.results.append(
925-
GetOnlineFeaturesResponse.FeatureVector(
926-
values=apply_list_mapping(values_vector, indexes),
927-
statuses=apply_list_mapping(statuses_vector, indexes),
928-
event_timestamps=apply_list_mapping(timestamp_vector, indexes),
929-
)
924+
for timestamp_vector, statuses_vector, values_vector in feature_data:
925+
response_vector = construct_response_feature_vector(
926+
values_vector, statuses_vector, timestamp_vector, indexes, output_len
930927
)
928+
online_features_response.results.append(response_vector)
931929

932930

933931
def _convert_entity_key_to_proto_to_dict(
@@ -1246,33 +1244,32 @@ def _convert_rows_to_protobuf(
12461244
requested_features: List[str],
12471245
read_rows: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]],
12481246
) -> List[Tuple[List[Timestamp], List["FieldStatus.ValueType"], List[ValueProto]]]:
1249-
# Each row is a set of features for a given entity key.
1250-
# We only need to convert the data to Protobuf once.
1247+
# Pre-calculate the length to avoid repeated calculations
1248+
n_rows = len(read_rows)
1249+
1250+
# Create single instances of commonly used values
12511251
null_value = ValueProto()
1252-
read_row_protos = []
1253-
for read_row in read_rows:
1254-
row_ts_proto = Timestamp()
1255-
row_ts, feature_data = read_row
1256-
# TODO (Ly): reuse whatever timestamp if row_ts is None?
1257-
if row_ts is not None:
1258-
row_ts_proto.FromDatetime(row_ts)
1259-
event_timestamps = [row_ts_proto] * len(requested_features)
1260-
if feature_data is None:
1261-
statuses = [FieldStatus.NOT_FOUND] * len(requested_features)
1262-
values = [null_value] * len(requested_features)
1263-
else:
1264-
statuses = []
1265-
values = []
1266-
for feature_name in requested_features:
1267-
# Make sure order of data is the same as requested_features.
1268-
if feature_name not in feature_data:
1269-
statuses.append(FieldStatus.NOT_FOUND)
1270-
values.append(null_value)
1271-
else:
1272-
statuses.append(FieldStatus.PRESENT)
1273-
values.append(feature_data[feature_name])
1274-
read_row_protos.append((event_timestamps, statuses, values))
1275-
return read_row_protos
1252+
null_status = FieldStatus.NOT_FOUND
1253+
null_timestamp = Timestamp()
1254+
present_status = FieldStatus.PRESENT
1255+
1256+
requested_features_vectors = []
1257+
for feature_name in requested_features:
1258+
ts_vector = [null_timestamp] * n_rows
1259+
status_vector = [null_status] * n_rows
1260+
value_vector = [null_value] * n_rows
1261+
for idx, read_row in enumerate(read_rows):
1262+
row_ts_proto = Timestamp()
1263+
row_ts, feature_data = read_row
1264+
# TODO (Ly): reuse whatever timestamp if row_ts is None?
1265+
if row_ts is not None:
1266+
row_ts_proto.FromDatetime(row_ts)
1267+
ts_vector[idx] = row_ts_proto
1268+
if (feature_data is not None) and (feature_name in feature_data):
1269+
status_vector[idx] = present_status
1270+
value_vector[idx] = feature_data[feature_name]
1271+
requested_features_vectors.append((ts_vector, status_vector, value_vector))
1272+
return requested_features_vectors
12761273

12771274

12781275
def has_all_tags(

sdk/python/tests/unit/test_unit_feature_store.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_get_unique_entities_success():
3838
projection=MockFeatureViewProjection(join_key_map={}),
3939
)
4040

41-
unique_entities, indexes = utils._get_unique_entities(
41+
unique_entities, indexes, output_len = utils._get_unique_entities(
4242
table=fv,
4343
join_key_values=entity_values,
4444
entity_name_to_join_key_map=entity_name_to_join_key_map,
@@ -51,6 +51,7 @@ def test_get_unique_entities_success():
5151

5252
assert unique_entities == expected_entities
5353
assert indexes == expected_indexes
54+
assert output_len == 3
5455

5556

5657
def test_get_unique_entities_missing_join_key_success():
@@ -74,7 +75,7 @@ def test_get_unique_entities_missing_join_key_success():
7475
projection=MockFeatureViewProjection(join_key_map={}),
7576
)
7677

77-
unique_entities, indexes = utils._get_unique_entities(
78+
unique_entities, indexes, output_len = utils._get_unique_entities(
7879
table=fv,
7980
join_key_values=entity_values,
8081
entity_name_to_join_key_map=entity_name_to_join_key_map,
@@ -87,6 +88,7 @@ def test_get_unique_entities_missing_join_key_success():
8788

8889
assert unique_entities == expected_entities
8990
assert indexes == expected_indexes
91+
assert output_len == 3
9092
# We're not saying anything about entity_1 being missing from the unique_entities list
9193
assert "entity_1" not in [entity.keys() for entity in unique_entities]
9294

0 commit comments

Comments
 (0)