 from logging.config import dictConfig
 from typing import Any, Dict, List, NamedTuple, Optional
 
-import numpy as np
-import pandas as pd
 from pyspark.sql import DataFrame, SparkSession, Window
-from pyspark.sql.functions import (
-    col,
-    expr,
-    monotonically_increasing_id,
-    row_number,
-    struct,
-)
-from pyspark.sql.pandas.functions import PandasUDFType, pandas_udf
-from pyspark.sql.types import BooleanType, LongType
+from pyspark.sql import functions as func
+from pyspark.sql.functions import col, monotonically_increasing_id, row_number
+from pyspark.sql.types import LongType
 
 EVENT_TIMESTAMP_ALIAS = "event_timestamp"
+ENTITY_EVENT_TIMESTAMP_ALIAS = "event_timestamp_entity"
 CREATED_TIMESTAMP_ALIAS = "created_timestamp"
 
 
@@ -283,14 +276,13 @@ class FeatureTable(NamedTuple):
         entities (List[Field]): Primary keys for the features.
         features (List[Field]): Feature list.
         max_age (int): In seconds. determines the lower bound of the timestamp of the retrieved feature.
-            If not specified, this would be unbounded
         project (str): Feast project name.
     """
 
     name: str
     entities: List[Field]
     features: List[Field]
-    max_age: Optional[int] = None
+    max_age: int
     project: Optional[str] = None
 
     @property
@@ -427,14 +419,10 @@ def as_of_join(
 
     join_cond = (
         entity_with_id[entity_event_timestamp_column]
-        >= aliased_feature_table_df[feature_event_timestamp_column_with_prefix]
+        == aliased_feature_table_df[
+            f"{feature_table.name}__{ENTITY_EVENT_TIMESTAMP_ALIAS}"
+        ]
     )
-    if feature_table.max_age:
-        join_cond = join_cond & (
-            aliased_feature_table_df[feature_event_timestamp_column_with_prefix]
-            >= entity_with_id[entity_event_timestamp_column]
-            - expr(f"INTERVAL {feature_table.max_age} seconds")
-        )
 
     for key in feature_table.entity_names:
         join_cond = join_cond & (
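With the max_age handling moved into the filtering step, the as-of join condition reduces to an equality on the entity event timestamp that was stamped onto each feature row, plus equality on every entity key. A minimal sketch of that condition on toy data; the table name "transactions" and the driver_id/total_trips columns are illustrative assumptions, not part of this change:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Entity rows requesting features at two timestamps (epoch seconds).
entity_df = spark.createDataFrame(
    [(1001, 100), (1001, 200)], ["driver_id", "event_timestamp"]
)

# Feature rows as they look after filtering: each row already carries the
# entity event timestamp it was matched to, under a prefixed alias.
feature_df = spark.createDataFrame(
    [(1001, 100, 5), (1001, 200, 7)],
    ["driver_id", "transactions__event_timestamp_entity", "transactions__total_trips"],
)

# Equality on the stamped entity timestamp plus equality on every entity key.
join_cond = (
    entity_df["event_timestamp"]
    == feature_df["transactions__event_timestamp_entity"]
) & (entity_df["driver_id"] == feature_df["driver_id"])

entity_df.join(feature_df, join_cond, "left_outer").show()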
@@ -557,50 +545,19 @@ class SchemaError(Exception):
     pass
 
 
-def _make_time_filter_pandas_udf(
-    spark: SparkSession,
-    entity_pandas: pd.DataFrame,
-    feature_table: FeatureTable,
-    entity_event_timestamp_column: str,
-):
-    entity_br = spark.sparkContext.broadcast(
-        entity_pandas.rename(
-            columns={entity_event_timestamp_column: EVENT_TIMESTAMP_ALIAS}
-        )
-    )
-    entity_names = feature_table.entity_names
-    max_age = feature_table.max_age
-
-    @pandas_udf(BooleanType(), PandasUDFType.SCALAR)
-    def within_time_boundaries(features: pd.DataFrame) -> pd.Series:
-        features["_row_id"] = np.arange(len(features))
-        merged = features.merge(
-            entity_br.value,
-            how="left",
-            on=entity_names,
-            suffixes=("_feature", "_entity"),
-        )
-        merged["distance"] = (
-            merged[f"{EVENT_TIMESTAMP_ALIAS}_entity"]
-            - merged[f"{EVENT_TIMESTAMP_ALIAS}_feature"]
-        )
-        merged["within"] = merged["distance"].dt.total_seconds().between(0, max_age)
-
-        return merged.groupby(["_row_id"]).max()["within"]
-
-    return within_time_boundaries
-
-
-def _filter_feature_table_by_time_range(
-    spark: SparkSession,
+def filter_feature_table_by_time_range(
     feature_table_df: DataFrame,
     feature_table: FeatureTable,
     feature_event_timestamp_column: str,
-    entity_pandas: pd.DataFrame,
+    entity_df: DataFrame,
     entity_event_timestamp_column: str,
-):
-    entity_max_timestamp = entity_pandas[entity_event_timestamp_column].max()
-    entity_min_timestamp = entity_pandas[entity_event_timestamp_column].min()
+) -> DataFrame:
+    entity_max_timestamp = entity_df.agg(
+        {entity_event_timestamp_column: "max"}
+    ).collect()[0][0]
+    entity_min_timestamp = entity_df.agg(
+        {entity_event_timestamp_column: "min"}
+    ).collect()[0][0]
 
     feature_table_timestamp_filter = (
         col(feature_event_timestamp_column).between(
@@ -613,17 +570,32 @@ def _filter_feature_table_by_time_range(
 
     time_range_filtered_df = feature_table_df.filter(feature_table_timestamp_filter)
 
-    if feature_table.max_age:
-        within_time_boundaries_udf = _make_time_filter_pandas_udf(
-            spark, entity_pandas, feature_table, entity_event_timestamp_column
+    time_range_filtered_df = (
+        time_range_filtered_df.join(
+            entity_df.withColumnRenamed(
+                entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS
+            ),
+            on=feature_table.entity_names,
+            how="inner",
         )
-
-        time_range_filtered_df = time_range_filtered_df.withColumn(
-            "within_time_boundaries",
-            within_time_boundaries_udf(
-                struct(feature_table.entity_names + [feature_event_timestamp_column])
+        .withColumn(
+            "distance",
+            col(ENTITY_EVENT_TIMESTAMP_ALIAS).cast("long")
+            - col(EVENT_TIMESTAMP_ALIAS).cast("long"),
+        )
+        .where((col("distance") >= 0) & (col("distance") <= feature_table.max_age))
+        .withColumn(
+            "min_distance",
+            func.min("distance").over(
+                Window.partitionBy(
+                    feature_table.entity_names + [ENTITY_EVENT_TIMESTAMP_ALIAS]
+                )
             ),
-        ).filter("within_time_boundaries = true")
+        )
+        .where(col("distance") == col("min_distance"))
+        .select(time_range_filtered_df.columns + [ENTITY_EVENT_TIMESTAMP_ALIAS])
+        .localCheckpoint()
+    )
 
     return time_range_filtered_df
 
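The broadcast-based pandas UDF is replaced here by a plain DataFrame pipeline: inner-join feature rows to entity rows on the entity keys, compute the distance in seconds between the entity timestamp and the feature timestamp, drop rows outside [0, max_age], then keep only the closest feature row per (entity keys, entity timestamp) window before checkpointing. A self-contained sketch of the same pattern on toy data; the driver_id/total_trips columns and the integer epoch-second timestamps are illustrative assumptions, while the real job operates on timestamp columns:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as func
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Entity rows: one driver asking for features at two points in time.
entity_df = spark.createDataFrame(
    [(1001, 100), (1001, 200)], ["driver_id", "event_timestamp"]
)
# Feature rows: several historical observations for the same driver.
feature_df = spark.createDataFrame(
    [(1001, 90, 1), (1001, 95, 2), (1001, 170, 3)],
    ["driver_id", "event_timestamp", "total_trips"],
)
max_age = 60  # seconds

filtered = (
    feature_df.join(
        entity_df.withColumnRenamed("event_timestamp", "event_timestamp_entity"),
        on=["driver_id"],
        how="inner",
    )
    # Seconds from the feature row to the entity request; negative means the
    # feature comes from the future relative to the request and is dropped.
    .withColumn(
        "distance",
        col("event_timestamp_entity").cast("long") - col("event_timestamp").cast("long"),
    )
    .where((col("distance") >= 0) & (col("distance") <= max_age))
    # Keep only the most recent qualifying feature row per entity request.
    .withColumn(
        "min_distance",
        func.min("distance").over(
            Window.partitionBy("driver_id", "event_timestamp_entity")
        ),
    )
    .where(col("distance") == col("min_distance"))
    .select(feature_df.columns + ["event_timestamp_entity"])
)

# Expect (1001, 95, 2, 100) and (1001, 170, 3, 200).
filtered.show()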
@@ -807,15 +779,14 @@ def retrieve_historical_features(
807779 f"{ expected_entity .name } ({ expected_entity .spark_type } ) is not present in the entity dataframe."
808780 )
809781
810- entity_pandas = entity_df .toPandas ()
782+ entity_df .cache ()
811783
812784 feature_table_dfs = [
813- _filter_feature_table_by_time_range (
814- spark ,
785+ filter_feature_table_by_time_range (
815786 feature_table_df ,
816787 feature_table ,
817788 feature_table_source .event_timestamp_column ,
818- entity_pandas ,
789+ entity_df ,
819790 entity_source .event_timestamp_column ,
820791 )
821792 for feature_table_df , feature_table , feature_table_source in zip (
@@ -873,11 +844,15 @@ def _get_args():
 
 
 def _feature_table_from_dict(dct: Dict[str, Any]) -> FeatureTable:
+    assert (
+        dct.get("max_age") is not None and dct["max_age"] > 0
+    ), "FeatureTable.maxAge must not be None and should be a positive number"
+
     return FeatureTable(
         name=dct["name"],
         entities=[Field(**e) for e in dct["entities"]],
         features=[Field(**f) for f in dct["features"]],
-        max_age=dct.get("max_age"),
+        max_age=dct["max_age"],
         project=dct.get("project"),
     )
 
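Since max_age is now a required, positive field, the loader asserts on it before constructing the FeatureTable. A small usage sketch under the assumption that Field is built from name/type pairs; all field values below are made up for illustration:

table_spec = {
    "name": "transactions",
    "entities": [{"name": "driver_id", "type": "int64"}],
    "features": [{"name": "total_trips", "type": "int64"}],
    "max_age": 86400,  # seconds; must be present and > 0 or the assert fires
    "project": "default",
}

feature_table = _feature_table_from_dict(table_spec)
assert feature_table.max_age == 86400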