forked from feast-dev/feast
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfeature_server.py
More file actions
171 lines (144 loc) · 5.98 KB
/
feature_server.py
File metadata and controls
171 lines (144 loc) · 5.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import json
import traceback
import warnings
import gunicorn.app.base
import pandas as pd
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.logger import logger
from fastapi.params import Depends
from google.protobuf.json_format import MessageToDict, Parse
from pydantic import BaseModel
import feast
from feast import proto_json
from feast.data_source import PushMode
from feast.errors import PushSourceNotFoundException
from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest
# TODO: deprecate this in favor of push features
class WriteToFeatureStoreRequest(BaseModel):
    """Request body schema for the deprecated ``/write-to-online-store`` route.

    The handler builds this model from the decoded JSON body, converts ``df``
    to a ``pandas.DataFrame`` and hands it to
    ``FeatureStore.write_to_online_store``.
    """

    # Name of the feature view the rows are written to.
    feature_view_name: str
    # Data handed to ``pd.DataFrame(...)``; presumably a mapping of column
    # name -> list of values -- confirm against callers.
    df: dict
    # Forwarded unchanged as ``allow_registry_cache`` to the store call.
    allow_registry_cache: bool = True
class PushFeaturesRequest(BaseModel):
    """Request body schema for the ``/push`` route.

    The handler builds this model from the decoded JSON body, converts ``df``
    to a ``pandas.DataFrame`` and forwards it to ``FeatureStore.push``.
    """

    # Name of the push source that receives the rows.
    push_source_name: str
    # Data handed to ``pd.DataFrame(...)``; presumably a mapping of column
    # name -> list of values -- confirm against callers.
    df: dict
    # Forwarded unchanged as ``allow_registry_cache`` to the store call.
    allow_registry_cache: bool = True
    # Push destination: "online", "offline" or "online_and_offline".
    to: str = "online"
def get_app(store: "feast.FeatureStore") -> FastAPI:
    """Build the FastAPI application that serves ``store`` over HTTP.

    Routes registered on the returned app:

    * ``POST /get-online-features`` -- the body is parsed as the JSON form of
      a ``GetOnlineFeaturesRequest`` protobuf; the response is the online
      feature lookup result converted to a dict.
    * ``POST /push`` -- the body is loaded into ``PushFeaturesRequest`` and
      forwarded to ``store.push``.
    * ``POST /write-to-online-store`` -- deprecated; the body is loaded into
      ``WriteToFeatureStoreRequest`` and forwarded to
      ``store.write_to_online_store``.
    * ``GET /health`` -- replies with an empty 200 response.

    Args:
        store: The feature store every route delegates to.

    Returns:
        The configured ``FastAPI`` application.
    """
    # Presumably adjusts protobuf <-> JSON conversion used by the handlers
    # below; see ``feast.proto_json`` for what is actually patched.
    proto_json.patch()

    app = FastAPI()

    # Accepted values of ``PushFeaturesRequest.to`` and the mode each selects.
    push_modes = {
        "offline": PushMode.OFFLINE,
        "online": PushMode.ONLINE,
        "online_and_offline": PushMode.ONLINE_AND_OFFLINE,
    }

    async def get_body(request: Request):
        """Dependency that yields the raw, undecoded request body."""
        return await request.body()

    @app.post("/get-online-features")
    def get_online_features(body=Depends(get_body)):
        """Look up online features for the entities named in the request."""
        try:
            # Validate and parse the request data into GetOnlineFeaturesRequest Protobuf object
            request_proto = GetOnlineFeaturesRequest()
            Parse(body, request_proto)

            # Initialize parameters for FeatureStore.get_online_features(...) call
            if request_proto.HasField("feature_service"):
                features = store.get_feature_service(
                    request_proto.feature_service, allow_cache=True
                )
            else:
                features = list(request_proto.features.val)

            full_feature_names = request_proto.full_feature_names

            batch_sizes = [len(v.val) for v in request_proto.entities.values()]
            if not batch_sizes:
                # Without this guard ``batch_sizes[0]`` raised IndexError and
                # the client only saw "list index out of range".
                raise HTTPException(status_code=500, detail="No entities provided")
            num_entities = batch_sizes[0]
            if any(batch_size != num_entities for batch_size in batch_sizes):
                raise HTTPException(status_code=500, detail="Uneven number of columns")

            response_proto = store._get_online_features(
                features=features,
                entity_values=request_proto.entities,
                full_feature_names=full_feature_names,
                native_entity_values=False,
            ).proto

            # Convert the Protobuf object to JSON and return it
            return MessageToDict(  # type: ignore
                response_proto, preserving_proto_field_name=True, float_precision=18
            )
        except HTTPException:
            # Deliberate HTTP errors raised above must reach the client as-is
            # instead of being re-wrapped with ``str(e)`` as the detail.
            raise
        except Exception as e:
            # Print the original exception on the server side
            logger.exception(traceback.format_exc())
            # Raise HTTPException to return the error message to the client
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/push")
    def push(body=Depends(get_body)):
        """Push a dataframe to the online and/or offline store."""
        try:
            request = PushFeaturesRequest(**json.loads(body))
            df = pd.DataFrame(request.df)
            if request.to not in push_modes:
                raise ValueError(
                    f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']."
                )
            store.push(
                push_source_name=request.push_source_name,
                df=df,
                allow_registry_cache=request.allow_registry_cache,
                to=push_modes[request.to],
            )
        except PushSourceNotFoundException as e:
            # Print the original exception on the server side
            logger.exception(traceback.format_exc())
            # Raise HTTPException to return the error message to the client
            raise HTTPException(status_code=422, detail=str(e)) from e
        except Exception as e:
            # Print the original exception on the server side
            logger.exception(traceback.format_exc())
            # Raise HTTPException to return the error message to the client
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.post("/write-to-online-store")
    def write_to_online_store(body=Depends(get_body)):
        """Deprecated: write a dataframe straight to the online store."""
        warnings.warn(
            "write_to_online_store is deprecated. Please consider using /push instead",
            RuntimeWarning,
        )
        try:
            request = WriteToFeatureStoreRequest(**json.loads(body))
            df = pd.DataFrame(request.df)
            store.write_to_online_store(
                feature_view_name=request.feature_view_name,
                df=df,
                allow_registry_cache=request.allow_registry_cache,
            )
        except Exception as e:
            # Print the original exception on the server side
            logger.exception(traceback.format_exc())
            # Raise HTTPException to return the error message to the client
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    def health():
        """Liveness probe: always answers with an empty 200 response."""
        return Response(status_code=status.HTTP_200_OK)

    return app
class FeastServeApplication(gunicorn.app.base.BaseApplication):
    """Gunicorn application wrapper around the Feast FastAPI app.

    Keyword options given to the constructor are applied as gunicorn settings
    in ``load_config``; the worker class is always set to
    ``uvicorn.workers.UvicornWorker``.
    """

    def __init__(self, store: "feast.FeatureStore", **options):
        """Create the ASGI app for ``store`` and remember the gunicorn options."""
        # Both attributes are assigned before the base initializer runs.
        # NOTE(review): gunicorn's BaseApplication is expected to call
        # load_config() during construction -- keep this ordering.
        self._app = get_app(store=store)
        self._options = options
        super().__init__()

    def load_config(self):
        """Copy recognised, non-``None`` options into the gunicorn config."""
        for name, value in self._options.items():
            setting = name.lower()
            # Skip unset values and names gunicorn does not know about.
            if value is None or setting not in self.cfg.settings:
                continue
            self.cfg.set(setting, value)

        self.cfg.set("worker_class", "uvicorn.workers.UvicornWorker")

    def load(self):
        """Return the application object built in ``__init__``."""
        return self._app
def start_server(
    store: "feast.FeatureStore",
    host: str,
    port: int,
    no_access_log: bool,
    workers: int,
    keep_alive_timeout: int,
):
    """Run the feature server for ``store`` under gunicorn.

    Args:
        store: Feature store the server exposes.
        host: Host part of the gunicorn ``bind`` address.
        port: Port part of the gunicorn ``bind`` address.
        no_access_log: When true, ``accesslog`` is passed as ``None``;
            otherwise it is passed as ``"-"``.
        workers: Value passed as the gunicorn ``workers`` option.
        keep_alive_timeout: Value passed as the gunicorn ``keepalive`` option.
    """
    bind_address = f"{host}:{port}"
    access_log_target = None if no_access_log else "-"

    application = FeastServeApplication(
        store=store,
        bind=bind_address,
        accesslog=access_log_target,
        workers=workers,
        keepalive=keep_alive_timeout,
    )
    application.run()