Skip to content

Commit 96141ae

Browse files
authored
Speed up join in historical retrieval by replacing pandas with native spark (#89)
1 parent 9f2f084 commit 96141ae

3 files changed

Lines changed: 91 additions & 164 deletions

File tree

python/feast_spark/pyspark/historical_feature_retrieval_job.py

Lines changed: 49 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,13 @@
88
from logging.config import dictConfig
99
from typing import Any, Dict, List, NamedTuple, Optional
1010

11-
import numpy as np
12-
import pandas as pd
1311
from pyspark.sql import DataFrame, SparkSession, Window
14-
from pyspark.sql.functions import (
15-
col,
16-
expr,
17-
monotonically_increasing_id,
18-
row_number,
19-
struct,
20-
)
21-
from pyspark.sql.pandas.functions import PandasUDFType, pandas_udf
22-
from pyspark.sql.types import BooleanType, LongType
12+
from pyspark.sql import functions as func
13+
from pyspark.sql.functions import col, monotonically_increasing_id, row_number
14+
from pyspark.sql.types import LongType
2315

2416
EVENT_TIMESTAMP_ALIAS = "event_timestamp"
17+
ENTITY_EVENT_TIMESTAMP_ALIAS = "event_timestamp_entity"
2518
CREATED_TIMESTAMP_ALIAS = "created_timestamp"
2619

2720

@@ -283,14 +276,13 @@ class FeatureTable(NamedTuple):
283276
entities (List[Field]): Primary keys for the features.
284277
features (List[Field]): Feature list.
285278
max_age (int): In seconds. determines the lower bound of the timestamp of the retrieved feature.
286-
If not specified, this would be unbounded
287279
project (str): Feast project name.
288280
"""
289281

290282
name: str
291283
entities: List[Field]
292284
features: List[Field]
293-
max_age: Optional[int] = None
285+
max_age: int
294286
project: Optional[str] = None
295287

296288
@property
@@ -427,14 +419,10 @@ def as_of_join(
427419

428420
join_cond = (
429421
entity_with_id[entity_event_timestamp_column]
430-
>= aliased_feature_table_df[feature_event_timestamp_column_with_prefix]
422+
== aliased_feature_table_df[
423+
f"{feature_table.name}__{ENTITY_EVENT_TIMESTAMP_ALIAS}"
424+
]
431425
)
432-
if feature_table.max_age:
433-
join_cond = join_cond & (
434-
aliased_feature_table_df[feature_event_timestamp_column_with_prefix]
435-
>= entity_with_id[entity_event_timestamp_column]
436-
- expr(f"INTERVAL {feature_table.max_age} seconds")
437-
)
438426

439427
for key in feature_table.entity_names:
440428
join_cond = join_cond & (
@@ -557,50 +545,19 @@ class SchemaError(Exception):
557545
pass
558546

559547

560-
def _make_time_filter_pandas_udf(
561-
spark: SparkSession,
562-
entity_pandas: pd.DataFrame,
563-
feature_table: FeatureTable,
564-
entity_event_timestamp_column: str,
565-
):
566-
entity_br = spark.sparkContext.broadcast(
567-
entity_pandas.rename(
568-
columns={entity_event_timestamp_column: EVENT_TIMESTAMP_ALIAS}
569-
)
570-
)
571-
entity_names = feature_table.entity_names
572-
max_age = feature_table.max_age
573-
574-
@pandas_udf(BooleanType(), PandasUDFType.SCALAR)
575-
def within_time_boundaries(features: pd.DataFrame) -> pd.Series:
576-
features["_row_id"] = np.arange(len(features))
577-
merged = features.merge(
578-
entity_br.value,
579-
how="left",
580-
on=entity_names,
581-
suffixes=("_feature", "_entity"),
582-
)
583-
merged["distance"] = (
584-
merged[f"{EVENT_TIMESTAMP_ALIAS}_entity"]
585-
- merged[f"{EVENT_TIMESTAMP_ALIAS}_feature"]
586-
)
587-
merged["within"] = merged["distance"].dt.total_seconds().between(0, max_age)
588-
589-
return merged.groupby(["_row_id"]).max()["within"]
590-
591-
return within_time_boundaries
592-
593-
594-
def _filter_feature_table_by_time_range(
595-
spark: SparkSession,
548+
def filter_feature_table_by_time_range(
596549
feature_table_df: DataFrame,
597550
feature_table: FeatureTable,
598551
feature_event_timestamp_column: str,
599-
entity_pandas: pd.DataFrame,
552+
entity_df: DataFrame,
600553
entity_event_timestamp_column: str,
601-
):
602-
entity_max_timestamp = entity_pandas[entity_event_timestamp_column].max()
603-
entity_min_timestamp = entity_pandas[entity_event_timestamp_column].min()
554+
) -> DataFrame:
555+
entity_max_timestamp = entity_df.agg(
556+
{entity_event_timestamp_column: "max"}
557+
).collect()[0][0]
558+
entity_min_timestamp = entity_df.agg(
559+
{entity_event_timestamp_column: "min"}
560+
).collect()[0][0]
604561

605562
feature_table_timestamp_filter = (
606563
col(feature_event_timestamp_column).between(
@@ -613,17 +570,32 @@ def _filter_feature_table_by_time_range(
613570

614571
time_range_filtered_df = feature_table_df.filter(feature_table_timestamp_filter)
615572

616-
if feature_table.max_age:
617-
within_time_boundaries_udf = _make_time_filter_pandas_udf(
618-
spark, entity_pandas, feature_table, entity_event_timestamp_column
573+
time_range_filtered_df = (
574+
time_range_filtered_df.join(
575+
entity_df.withColumnRenamed(
576+
entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS
577+
),
578+
on=feature_table.entity_names,
579+
how="inner",
619580
)
620-
621-
time_range_filtered_df = time_range_filtered_df.withColumn(
622-
"within_time_boundaries",
623-
within_time_boundaries_udf(
624-
struct(feature_table.entity_names + [feature_event_timestamp_column])
581+
.withColumn(
582+
"distance",
583+
col(ENTITY_EVENT_TIMESTAMP_ALIAS).cast("long")
584+
- col(EVENT_TIMESTAMP_ALIAS).cast("long"),
585+
)
586+
.where((col("distance") >= 0) & (col("distance") <= feature_table.max_age))
587+
.withColumn(
588+
"min_distance",
589+
func.min("distance").over(
590+
Window.partitionBy(
591+
feature_table.entity_names + [ENTITY_EVENT_TIMESTAMP_ALIAS]
592+
)
625593
),
626-
).filter("within_time_boundaries = true")
594+
)
595+
.where(col("distance") == col("min_distance"))
596+
.select(time_range_filtered_df.columns + [ENTITY_EVENT_TIMESTAMP_ALIAS])
597+
.localCheckpoint()
598+
)
627599

628600
return time_range_filtered_df
629601

@@ -807,15 +779,14 @@ def retrieve_historical_features(
807779
f"{expected_entity.name} ({expected_entity.spark_type}) is not present in the entity dataframe."
808780
)
809781

810-
entity_pandas = entity_df.toPandas()
782+
entity_df.cache()
811783

812784
feature_table_dfs = [
813-
_filter_feature_table_by_time_range(
814-
spark,
785+
filter_feature_table_by_time_range(
815786
feature_table_df,
816787
feature_table,
817788
feature_table_source.event_timestamp_column,
818-
entity_pandas,
789+
entity_df,
819790
entity_source.event_timestamp_column,
820791
)
821792
for feature_table_df, feature_table, feature_table_source in zip(
@@ -873,11 +844,15 @@ def _get_args():
873844

874845

875846
def _feature_table_from_dict(dct: Dict[str, Any]) -> FeatureTable:
847+
assert (
848+
dct.get("max_age") is not None and dct["max_age"] > 0
849+
), "FeatureTable.maxAge must not be None and should be a positive number"
850+
876851
return FeatureTable(
877852
name=dct["name"],
878853
entities=[Field(**e) for e in dct["entities"]],
879854
features=[Field(**f) for f in dct["features"]],
880-
max_age=dct.get("max_age"),
855+
max_age=dct["max_age"],
881856
project=dct.get("project"),
882857
)
883858

0 commit comments

Comments (0)