Skip to content

Commit 4ed107c

Browse files
authored
fix: Implement apply_materialization and infra methods in sql registry (feast-dev#2775)
Signed-off-by: Achal Shah <[email protected]>
1 parent 846ff4a commit 4ed107c

1 file changed

Lines changed: 68 additions & 31 deletions

File tree

  • sdk/python/feast/infra/registry_stores

sdk/python/feast/infra/registry_stores/sql.py

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import datetime
22
from pathlib import Path
33
from threading import Lock
4-
from typing import Any, List, Optional
4+
from typing import Any, List, Optional, Union
55

66
from sqlalchemy import ( # type: ignore
77
BigInteger,
@@ -39,6 +39,7 @@
3939
FeatureService as FeatureServiceProto,
4040
)
4141
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
42+
from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto
4243
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
4344
OnDemandFeatureView as OnDemandFeatureViewProto,
4445
)
@@ -138,6 +139,14 @@
138139
Column("validation_reference_proto", LargeBinary, nullable=False),
139140
)
140141

142+
managed_infra = Table(
143+
"managed_infra",
144+
metadata,
145+
Column("infra_name", String(50), primary_key=True),
146+
Column("last_updated_timestamp", BigInteger, nullable=False),
147+
Column("infra_proto", LargeBinary, nullable=False),
148+
)
149+
141150

142151
class SqlRegistry(BaseRegistry):
143152
def __init__(
@@ -168,6 +177,7 @@ def teardown(self):
168177
conn.execute(stmt)
169178

170179
def refresh(self):
180+
# This method is a no-op since we're always reading the latest values from the db.
171181
pass
172182

173183
def get_stream_feature_view(
@@ -353,16 +363,7 @@ def apply_data_source(
353363
def apply_feature_view(
354364
self, feature_view: BaseFeatureView, project: str, commit: bool = True
355365
):
356-
if isinstance(feature_view, StreamFeatureView):
357-
fv_table = stream_feature_views
358-
elif isinstance(feature_view, FeatureView):
359-
fv_table = feature_views
360-
elif isinstance(feature_view, OnDemandFeatureView):
361-
fv_table = on_demand_feature_views
362-
elif isinstance(feature_view, RequestFeatureView):
363-
fv_table = request_feature_views
364-
else:
365-
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
366+
fv_table = self._infer_fv_table(feature_view)
366367

367368
return self._apply_object(
368369
fv_table, "feature_view_name", feature_view, "feature_view_proto"
@@ -457,7 +458,25 @@ def apply_materialization(
457458
end_date: datetime,
458459
commit: bool = True,
459460
):
460-
pass
461+
table = self._infer_fv_table(feature_view)
462+
python_class, proto_class = self._infer_fv_classes(feature_view)
463+
464+
if python_class in {RequestFeatureView, OnDemandFeatureView}:
465+
raise ValueError(
466+
f"Cannot apply materialization for feature {feature_view.name} of type {python_class}"
467+
)
468+
fv: Union[FeatureView, StreamFeatureView] = self._get_object(
469+
table,
470+
feature_view.name,
471+
project,
472+
proto_class,
473+
python_class,
474+
"feature_view_name",
475+
"feature_view_proto",
476+
FeatureViewNotFoundException,
477+
)
478+
fv.materialization_intervals.append((start_date, end_date))
479+
self._apply_object(table, "feature_view_name", fv, "feature_view_proto")
461480

462481
def delete_validation_reference(self, name: str, project: str, commit: bool = True):
463482
self._delete_object(
@@ -469,27 +488,29 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr
469488
)
470489

471490
def update_infra(self, infra: Infra, project: str, commit: bool = True):
472-
pass
491+
self._apply_object(
492+
managed_infra, "infra_name", infra, "infra_proto", name="infra_obj"
493+
)
473494

474495
def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
475-
return Infra()
496+
return self._get_object(
497+
managed_infra,
498+
"infra_obj",
499+
project,
500+
InfraProto,
501+
Infra,
502+
"infra_name",
503+
"infra_proto",
504+
None,
505+
)
476506

477507
def apply_user_metadata(
478508
self,
479509
project: str,
480510
feature_view: BaseFeatureView,
481511
metadata_bytes: Optional[bytes],
482512
):
483-
if isinstance(feature_view, StreamFeatureView):
484-
table = stream_feature_views
485-
elif isinstance(feature_view, FeatureView):
486-
table = feature_views
487-
elif isinstance(feature_view, OnDemandFeatureView):
488-
table = on_demand_feature_views
489-
elif isinstance(feature_view, RequestFeatureView):
490-
table = request_feature_views
491-
else:
492-
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
513+
table = self._infer_fv_table(feature_view)
493514

494515
name = feature_view.name
495516
with self.engine.connect() as conn:
@@ -511,9 +532,7 @@ def apply_user_metadata(
511532
else:
512533
raise FeatureViewNotFoundException(feature_view.name, project=project)
513534

514-
def get_user_metadata(
515-
self, project: str, feature_view: BaseFeatureView
516-
) -> Optional[bytes]:
535+
def _infer_fv_table(self, feature_view):
517536
if isinstance(feature_view, StreamFeatureView):
518537
table = stream_feature_views
519538
elif isinstance(feature_view, FeatureView):
@@ -524,6 +543,25 @@ def get_user_metadata(
524543
table = request_feature_views
525544
else:
526545
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
546+
return table
547+
548+
def _infer_fv_classes(self, feature_view):
549+
if isinstance(feature_view, StreamFeatureView):
550+
python_class, proto_class = StreamFeatureView, StreamFeatureViewProto
551+
elif isinstance(feature_view, FeatureView):
552+
python_class, proto_class = FeatureView, FeatureViewProto
553+
elif isinstance(feature_view, OnDemandFeatureView):
554+
python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto
555+
elif isinstance(feature_view, RequestFeatureView):
556+
python_class, proto_class = RequestFeatureView, RequestFeatureViewProto
557+
else:
558+
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
559+
return python_class, proto_class
560+
561+
def get_user_metadata(
562+
self, project: str, feature_view: BaseFeatureView
563+
) -> Optional[bytes]:
564+
table = self._infer_fv_table(feature_view)
527565

528566
name = feature_view.name
529567
with self.engine.connect() as conn:
@@ -556,12 +594,11 @@ def proto(self) -> RegistryProto:
556594
return r
557595

558596
def commit(self):
597+
# This method is a no-op since we're always writing values eagerly to the db.
559598
pass
560599

561-
def _apply_object(
562-
self, table, id_field_name, obj, proto_field_name,
563-
):
564-
name = obj.name
600+
def _apply_object(self, table, id_field_name, obj, proto_field_name, name=None):
601+
name = name or obj.name
565602
with self.engine.connect() as conn:
566603
stmt = select(table).where(getattr(table.c, id_field_name) == name)
567604
row = conn.execute(stmt).first()

0 commit comments

Comments
 (0)