Skip to content

Commit df993df

Browse files
committed
add spark transformation
Signed-off-by: HaoXuAI <[email protected]>
1 parent 0083303 commit df993df

File tree

12 files changed

+260
-11
lines changed

12 files changed

+260
-11
lines changed

sdk/python/feast/batch_feature_view.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from datetime import datetime, timedelta
44
from types import FunctionType
5-
from typing import Dict, List, Optional, Tuple, Union
5+
from typing import Dict, List, Optional, Tuple, Union, Callable, Any
66

77
import dill
88

@@ -61,7 +61,7 @@ class BatchFeatureView(FeatureView):
6161
owner: str
6262
timestamp_field: str
6363
materialization_intervals: List[Tuple[datetime, datetime]]
64-
udf: Optional[FunctionType]
64+
udf: Optional[Callable[[Any], Any]]
6565
udf_string: Optional[str]
6666
feature_transformation: Transformation
6767

@@ -78,7 +78,7 @@ def __init__(
7878
description: str = "",
7979
owner: str = "",
8080
schema: Optional[List[Field]] = None,
81-
udf: Optional[FunctionType] = None,
81+
udf: Optional[Callable[[Any], Any]],
8282
udf_string: Optional[str] = "",
8383
feature_transformation: Optional[Transformation] = None,
8484
):

sdk/python/feast/infra/compute_engines/__init__.py

Whitespace-only changes.

sdk/python/feast/infra/compute_engines/spark/__init__.py

Whitespace-only changes.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Optional, Dict
2+
3+
from feast.repo_config import FeastConfigBaseModel
4+
from pydantic import StrictStr
5+
6+
7+
class SparkComputeConfig(FeastConfigBaseModel):
    """Repo-config section selecting and configuring the Spark compute engine.

    Only plain, serializable values live here; the SparkSession itself is
    created lazily elsewhere (see the note on ``spark_conf`` below).
    """

    type: StrictStr = "spark"
    """ Spark Compute type selector"""

    spark_conf: Optional[Dict[str, str]] = None
    """ Configuration overlay for the spark session """
    # NOTE: the live SparkSession is deliberately NOT stored in this config —
    # it is not serializable, and we don't want to pass it around as an argument.

    staging_location: Optional[StrictStr] = None
    """ Remote path for batch materialization jobs"""

    region: Optional[StrictStr] = None
    """ AWS Region if applicable for s3-based staging locations"""
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Optional, Dict
2+
3+
from pyspark import SparkConf
4+
from pyspark.sql import SparkSession
5+
6+
7+
def get_or_create_new_spark_session(
    spark_config: Optional[Dict[str, str]] = None
) -> SparkSession:
    """Return the currently active SparkSession, creating one if needed.

    When a session is already active it is returned as-is and any
    ``spark_config`` is ignored. Otherwise a new session is built, with the
    entries of ``spark_config`` (if given) applied as a SparkConf overlay.
    """
    active = SparkSession.getActiveSession()
    if active:
        return active

    builder = SparkSession.builder
    if spark_config:
        # Fold the caller's key/value overrides into a single SparkConf.
        overlay = SparkConf().setAll(list(spark_config.items()))
        builder = builder.config(conf=overlay)
    return builder.getOrCreate()

sdk/python/feast/stream_feature_view.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,9 @@ def get_feature_transformation(self) -> Optional[Transformation]:
151151
if self.mode in (
152152
TransformationMode.PANDAS,
153153
TransformationMode.PYTHON,
154-
TransformationMode.SPARK,
155-
) or self.mode in ("pandas", "python", "spark"):
154+
TransformationMode.SPARK_SQL,
155+
TransformationMode.SPARK
156+
) or self.mode in ("pandas", "python", "spark_sql", "spark"):
156157
return Transformation(
157158
mode=self.mode, udf=self.udf, udf_string=self.udf_string or ""
158159
)

sdk/python/feast/transformation/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
description: str = "",
8282
owner: str = "",
8383
):
84-
self.mode = mode if isinstance(mode, str) else mode.value
84+
self.mode = mode
8585
self.udf = udf
8686
self.udf_string = udf_string
8787
self.name = name
@@ -99,7 +99,7 @@ def to_proto(self) -> Union[UserDefinedFunctionProto, SubstraitTransformationPro
9999
def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "Transformation":
100100
return Transformation(mode=self.mode, udf=self.udf, udf_string=self.udf_string)
101101

102-
def transform(self, inputs: Any) -> Any:
102+
def transform(self, *inputs: Any) -> Any:
103103
raise NotImplementedError
104104

105105
def transform_arrow(self, *args, **kwargs) -> Any:

sdk/python/feast/transformation/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"pandas": "feast.transformation.pandas_transformation.PandasTransformation",
66
"substrait": "feast.transformation.substrait_transformation.SubstraitTransformation",
77
"sql": "feast.transformation.sql_transformation.SQLTransformation",
8+
"spark_sql": "feast.transformation.spark_transformation.SparkTransformation",
89
"spark": "feast.transformation.spark_transformation.SparkTransformation",
910
}
1011

sdk/python/feast/transformation/mode.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
class TransformationMode(Enum):
55
PYTHON = "python"
66
PANDAS = "pandas"
7+
SPARK_SQL = "spark_sql"
78
SPARK = "spark"
89
SQL = "sql"
910
SUBSTRAIT = "substrait"
Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,91 @@
1-
from typing import Any
1+
from typing import Any, Union, Dict, Optional, cast
2+
3+
import pandas as pd
4+
import pyspark.sql
25

36
from feast.transformation.base import Transformation
7+
from feast.transformation.mode import TransformationMode
8+
from feast.infra.compute_engines.spark.utils import get_or_create_new_spark_session
49

510

611
class SparkTransformation(Transformation):
    """Transformation executed via Spark, in one of two modes.

    In ``spark_sql`` mode the ``udf`` is expected to return a SQL string and
    receives temp-view names for any DataFrame inputs; in plain ``spark``
    mode the ``udf`` is applied directly to the inputs.
    """

    def __new__(
        cls,
        mode: Union[TransformationMode, str],
        udf: Any,
        udf_string: str,
        spark_config: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        description: str = "",
        owner: str = "",
        *args,
        **kwargs,
    ) -> "SparkTransformation":
        # Delegate to Transformation.__new__ so that mode-based subclass
        # dispatch in the base class keeps working.
        instance = super(SparkTransformation, cls).__new__(
            cls,
            mode=mode,
            spark_config=spark_config,
            udf=udf,
            udf_string=udf_string,
            name=name,
            tags=tags,
            description=description,
            owner=owner,
        )
        return cast(SparkTransformation, instance)

    def __init__(
        self,
        mode: Union[TransformationMode, str],
        udf: Any,
        udf_string: str,
        spark_config: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None,
        description: str = "",
        owner: str = "",
        *args,
        **kwargs,
    ):
        """Create a Spark transformation.

        Args:
            mode: ``TransformationMode.SPARK``/``SPARK_SQL`` or the equivalent
                string ("spark" / "spark_sql").
            udf: Callable run by :meth:`transform`; in ``spark_sql`` mode it
                must return a SQL string.
            udf_string: Source representation of the udf (for serialization).
            spark_config: Optional session config overrides, applied only when
                no SparkSession is active. Defaults to ``None`` instead of the
                original shared mutable ``{}`` default (identical behavior:
                the value is only ever truth-tested downstream).
            name/tags/description/owner: Standard Transformation metadata.
        """
        super().__init__(
            mode=mode,
            udf=udf,
            name=name,
            udf_string=udf_string,
            tags=tags,
            description=description,
            owner=owner,
        )
        self.spark_session = get_or_create_new_spark_session(spark_config)

    def transform(self, *inputs: Union[str, pd.DataFrame]) -> pd.DataFrame:
        """Run the transformation on ``inputs`` and return the result."""
        # Transformation.__init__ stores ``mode`` as given, so it may be either
        # the enum member or its string value (e.g. from the factory registry);
        # accept both instead of only the enum (bug in original comparison).
        if self.mode in (
            TransformationMode.SPARK_SQL,
            TransformationMode.SPARK_SQL.value,
        ):
            return self._transform_spark_sql(*inputs)
        return self._transform_spark_udf(*inputs)

    @staticmethod
    def _create_temp_view_for_dataframe(df: pyspark.sql.DataFrame, name: str) -> str:
        """Register ``df`` as a temp view and return the generated view name."""
        df_temp_view = f"feast_transformation_temp_view_{name}"
        df.createOrReplaceTempView(df_temp_view)
        return df_temp_view

    def _transform_spark_sql(
        self, *inputs: Union[pyspark.sql.DataFrame, str]
    ) -> pd.DataFrame:
        """Expose DataFrame inputs as temp views, then execute the udf's SQL."""
        inputs_str = [
            self._create_temp_view_for_dataframe(v, f"index_{i}")
            if isinstance(v, pyspark.sql.DataFrame)
            else v
            for i, v in enumerate(inputs)
        ]
        return self.spark_session.sql(self.udf(*inputs_str))

    def _transform_spark_udf(self, *inputs: Any) -> pd.DataFrame:
        """Apply the udf directly to the raw inputs."""
        return self.udf(*inputs)

    def infer_features(self, *args, **kwargs) -> Any:
        # Feature inference is not implemented for Spark transformations yet.
        pass

0 commit comments

Comments
 (0)