Skip to content

Commit 1761b10

Browse files
author
Tsotne Tabidze
authored
Add RedshiftDataSource (#1669)
* Add RedshiftDataSource Signed-off-by: Tsotne Tabidze <[email protected]> * Call parent __init__ first Signed-off-by: Tsotne Tabidze <[email protected]>
1 parent 3cb303f commit 1761b10

File tree

15 files changed

+414
-33
lines changed

15 files changed

+414
-33
lines changed

protos/feast/core/DataSource.proto

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ message DataSource {
3333
BATCH_BIGQUERY = 2;
3434
STREAM_KAFKA = 3;
3535
STREAM_KINESIS = 4;
36+
BATCH_REDSHIFT = 5;
3637
}
3738
SourceType type = 1;
3839

@@ -100,11 +101,22 @@ message DataSource {
100101
StreamFormat record_format = 3;
101102
}
102103

104+
// Defines options for DataSource that sources features from a Redshift Query
105+
message RedshiftOptions {
106+
// Redshift table name
107+
string table = 1;
108+
109+
// SQL query that returns a table containing feature data. Must contain an event_timestamp column, and respective
110+
// entity columns
111+
string query = 2;
112+
}
113+
103114
// DataSource options.
104115
oneof options {
105116
FileOptions file_options = 11;
106117
BigQueryOptions bigquery_options = 12;
107118
KafkaOptions kafka_options = 13;
108119
KinesisOptions kinesis_options = 14;
120+
RedshiftOptions redshift_options = 15;
109121
}
110122
}

sdk/python/feast/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
FileSource,
99
KafkaSource,
1010
KinesisSource,
11+
RedshiftSource,
1112
SourceType,
1213
)
1314
from .entity import Entity
@@ -37,6 +38,7 @@
3738
"FileSource",
3839
"KafkaSource",
3940
"KinesisSource",
41+
"RedshiftSource",
4042
"Feature",
4143
"FeatureStore",
4244
"FeatureTable",

sdk/python/feast/data_source.py

Lines changed: 260 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717
from typing import Callable, Dict, Iterable, Optional, Tuple
1818

1919
from pyarrow.parquet import ParquetFile
20+
from tenacity import retry, retry_unless_exception_type, wait_exponential
2021

2122
from feast import type_map
2223
from feast.data_format import FileFormat, StreamFormat
23-
from feast.errors import DataSourceNotFoundException
24+
from feast.errors import (
25+
DataSourceNotFoundException,
26+
RedshiftCredentialsError,
27+
RedshiftQueryError,
28+
)
2429
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
30+
from feast.repo_config import RepoConfig
2531
from feast.value_type import ValueType
2632

2733

@@ -477,6 +483,15 @@ def from_proto(data_source):
477483
date_partition_column=data_source.date_partition_column,
478484
query=data_source.bigquery_options.query,
479485
)
486+
elif data_source.redshift_options.table or data_source.redshift_options.query:
487+
data_source_obj = RedshiftSource(
488+
field_mapping=data_source.field_mapping,
489+
table=data_source.redshift_options.table,
490+
event_timestamp_column=data_source.event_timestamp_column,
491+
created_timestamp_column=data_source.created_timestamp_column,
492+
date_partition_column=data_source.date_partition_column,
493+
query=data_source.redshift_options.query,
494+
)
480495
elif (
481496
data_source.kafka_options.bootstrap_servers
482497
and data_source.kafka_options.topic
@@ -520,12 +535,27 @@ def to_proto(self) -> DataSourceProto:
520535
"""
521536
raise NotImplementedError
522537

523-
def validate(self):
538+
def validate(self, config: RepoConfig):
524539
"""
525540
Validates the underlying data source.
526541
"""
527542
raise NotImplementedError
528543

544+
@staticmethod
545+
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
546+
"""
547+
Get the callable method that returns Feast type given the raw column type
548+
"""
549+
raise NotImplementedError
550+
551+
def get_table_column_names_and_types(
552+
self, config: RepoConfig
553+
) -> Iterable[Tuple[str, str]]:
554+
"""
555+
Get the list of column names and raw column types
556+
"""
557+
raise NotImplementedError
558+
529559

530560
class FileSource(DataSource):
531561
def __init__(
@@ -622,15 +652,17 @@ def to_proto(self) -> DataSourceProto:
622652

623653
return data_source_proto
624654

625-
def validate(self):
655+
def validate(self, config: RepoConfig):
626656
# TODO: validate a FileSource
627657
pass
628658

629659
@staticmethod
630660
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
631661
return type_map.pa_to_feast_value_type
632662

633-
def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
663+
def get_table_column_names_and_types(
664+
self, config: RepoConfig
665+
) -> Iterable[Tuple[str, str]]:
634666
schema = ParquetFile(self.path).schema_arrow
635667
return zip(schema.names, map(str, schema.types))
636668

@@ -703,7 +735,7 @@ def to_proto(self) -> DataSourceProto:
703735

704736
return data_source_proto
705737

706-
def validate(self):
738+
def validate(self, config: RepoConfig):
707739
if not self.query:
708740
from google.api_core.exceptions import NotFound
709741
from google.cloud import bigquery
@@ -725,7 +757,9 @@ def get_table_query_string(self) -> str:
725757
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
726758
return type_map.bq_to_feast_value_type
727759

728-
def get_table_column_names_and_types(self) -> Iterable[Tuple[str, str]]:
760+
def get_table_column_names_and_types(
761+
self, config: RepoConfig
762+
) -> Iterable[Tuple[str, str]]:
729763
from google.cloud import bigquery
730764

731765
client = bigquery.Client()
@@ -875,3 +909,223 @@ def to_proto(self) -> DataSourceProto:
875909
data_source_proto.date_partition_column = self.date_partition_column
876910

877911
return data_source_proto
912+
913+
914+
class RedshiftOptions:
    """
    DataSource Redshift options used to source features from Redshift query
    """

    def __init__(self, table: Optional[str], query: Optional[str]):
        # Exactly one of table/query is expected to be set by callers;
        # both are stored as-is and exposed through properties below.
        self._table = table
        self._query = query

    @property
    def table(self):
        """
        Returns the table name of this Redshift table
        """
        return self._table

    @table.setter
    def table(self, table_name):
        """
        Sets the table ref of this Redshift table
        """
        self._table = table_name

    @property
    def query(self):
        """
        Returns the Redshift SQL query referenced by this source
        """
        return self._query

    @query.setter
    def query(self, query):
        """
        Sets the Redshift SQL query referenced by this source
        """
        self._query = query

    @classmethod
    def from_proto(cls, redshift_options_proto: "DataSourceProto.RedshiftOptions"):
        """
        Creates a RedshiftOptions from a protobuf representation of a Redshift option

        Args:
            redshift_options_proto: A protobuf representation of a DataSource

        Returns:
            Returns a RedshiftOptions object based on the redshift_options protobuf
        """
        return cls(
            table=redshift_options_proto.table,
            query=redshift_options_proto.query,
        )

    def to_proto(self) -> "DataSourceProto.RedshiftOptions":
        """
        Converts an RedshiftOptionsProto object to its protobuf representation.

        Returns:
            RedshiftOptionsProto protobuf
        """
        return DataSourceProto.RedshiftOptions(table=self.table, query=self.query)
982+
983+
984+
class RedshiftSource(DataSource):
    """
    A DataSource that sources feature data from a Redshift table or SQL query.

    Exactly one of `table` or `query` should be provided; `table` takes
    precedence when reading column metadata.
    """

    def __init__(
        self,
        event_timestamp_column: Optional[str] = "",
        table: Optional[str] = None,
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = "",
        query: Optional[str] = None,
    ):
        super().__init__(
            event_timestamp_column,
            created_timestamp_column,
            field_mapping,
            date_partition_column,
        )

        self._redshift_options = RedshiftOptions(table=table, query=query)

    def __eq__(self, other):
        if not isinstance(other, RedshiftSource):
            raise TypeError(
                "Comparisons should only involve RedshiftSource class objects."
            )

        return (
            self.redshift_options.table == other.redshift_options.table
            and self.redshift_options.query == other.redshift_options.query
            and self.event_timestamp_column == other.event_timestamp_column
            and self.created_timestamp_column == other.created_timestamp_column
            and self.field_mapping == other.field_mapping
        )

    @property
    def table(self):
        """Returns the Redshift table name, or None if this source is query-based."""
        return self._redshift_options.table

    @property
    def query(self):
        """Returns the SQL query of this source, or None if it is table-based."""
        return self._redshift_options.query

    @property
    def redshift_options(self):
        """
        Returns the Redshift options of this data source
        """
        return self._redshift_options

    @redshift_options.setter
    def redshift_options(self, _redshift_options):
        """
        Sets the Redshift options of this data source
        """
        self._redshift_options = _redshift_options

    def to_proto(self) -> DataSourceProto:
        """Converts this source to its DataSource protobuf representation."""
        data_source_proto = DataSourceProto(
            type=DataSourceProto.BATCH_REDSHIFT,
            field_mapping=self.field_mapping,
            redshift_options=self.redshift_options.to_proto(),
        )

        data_source_proto.event_timestamp_column = self.event_timestamp_column
        data_source_proto.created_timestamp_column = self.created_timestamp_column
        data_source_proto.date_partition_column = self.date_partition_column

        return data_source_proto

    def validate(self, config: RepoConfig):
        # As long as the query gets successfully executed, or the table exists,
        # the data source is validated. We don't need the results though.
        self.get_table_column_names_and_types(config)

    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL"""
        if self.table:
            # Redshift (like Postgres) quotes identifiers with double quotes;
            # backticks are MySQL/BigQuery syntax and are invalid in Redshift SQL.
            return f'"{self.table}"'
        else:
            return f"({self.query})"

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        return type_map.redshift_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """
        Returns the list of (column name, upper-cased Redshift type name) pairs.

        Uses the Redshift Data API: `describe_table` when a table name is
        configured, otherwise executes `SELECT * FROM (<query>) LIMIT 1` and
        inspects the result metadata.

        Raises:
            DataSourceNotFoundException: If the configured table does not exist.
            RedshiftCredentialsError: If the API rejects the credentials.
            RedshiftQueryError: If the query finishes in a non-success state.
        """
        import boto3
        from botocore.config import Config
        from botocore.exceptions import ClientError

        from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig

        assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)

        client = boto3.client(
            "redshift-data", config=Config(region_name=config.offline_store.region)
        )

        try:
            if self.table is not None:
                table = client.describe_table(
                    ClusterIdentifier=config.offline_store.cluster_id,
                    Database=config.offline_store.database,
                    DbUser=config.offline_store.user,
                    Table=self.table,
                )
                # The API returns valid JSON with empty column list when the table doesn't exist
                if len(table["ColumnList"]) == 0:
                    raise DataSourceNotFoundException(self.table)

                columns = table["ColumnList"]
            else:
                statement = client.execute_statement(
                    ClusterIdentifier=config.offline_store.cluster_id,
                    Database=config.offline_store.database,
                    DbUser=config.offline_store.user,
                    Sql=f"SELECT * FROM ({self.query}) LIMIT 1",
                )

                # Need to retry client.describe_statement(...) until the task is finished. We don't want to bombard
                # Redshift with queries, and neither do we want to wait for a long time on the initial call.
                # The solution is exponential backoff. The backoff starts with 0.1 seconds and doubles exponentially
                # until reaching 30 seconds, at which point the backoff is fixed.
                @retry(
                    wait=wait_exponential(multiplier=0.1, max=30),
                    retry=retry_unless_exception_type(RedshiftQueryError),
                )
                def wait_for_statement():
                    desc = client.describe_statement(Id=statement["Id"])
                    if desc["Status"] in ("SUBMITTED", "STARTED", "PICKED"):
                        raise Exception  # Retry
                    if desc["Status"] != "FINISHED":
                        raise RedshiftQueryError(desc)  # Don't retry. Raise exception.

                wait_for_statement()

                result = client.get_statement_result(Id=statement["Id"])

                columns = result["ColumnMetadata"]
        except ClientError as e:
            # The Data API reports bad credentials as a ValidationException.
            if e.response["Error"]["Code"] == "ValidationException":
                raise RedshiftCredentialsError() from e
            raise

        return [(column["name"], column["typeName"].upper()) for column in columns]

sdk/python/feast/errors.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,13 @@ def __init__(self, repo_obj_type: str, specific_issue: str):
124124
f"Inference to fill in missing information for {repo_obj_type} failed. {specific_issue}. "
125125
"Try filling the information explicitly."
126126
)
127+
128+
129+
class RedshiftCredentialsError(Exception):
    """Raised when a Redshift API call is rejected because of bad credentials."""

    def __init__(self):
        super().__init__("Redshift API failed due to incorrect credentials")
132+
133+
134+
class RedshiftQueryError(Exception):
    """Raised when a Redshift SQL statement finishes in a non-success state."""

    def __init__(self, details):
        super().__init__(f"Redshift SQL Query failed to finish. Details: {details}")

0 commit comments

Comments
 (0)