Skip to content

Commit b5ef384

Browse files
authored
feat: Add online_read_async for dynamodb (#4244)
1 parent 15524ce commit b5ef384

File tree

10 files changed

+790
-156
lines changed

10 files changed

+790
-156
lines changed

sdk/python/feast/infra/online_stores/dynamodb.py

Lines changed: 146 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,8 @@
3333

3434
try:
3535
import boto3
36+
from aiobotocore import session
37+
from boto3.dynamodb.types import TypeDeserializer
3638
from botocore.config import Config
3739
from botocore.exceptions import ClientError
3840
except ImportError as e:
@@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore):
8082

8183
_dynamodb_client = None
8284
_dynamodb_resource = None
85+
_aioboto_session = None
8386

8487
def update(
8588
self,
@@ -223,69 +226,103 @@ def online_read(
223226
"""
224227
online_config = config.online_store
225228
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
229+
226230
dynamodb_resource = self._get_dynamodb_resource(
227231
online_config.region, online_config.endpoint_url
228232
)
229233
table_instance = dynamodb_resource.Table(
230234
_get_table_name(online_config, config, table)
231235
)
232236

233-
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
234-
entity_ids = [
235-
compute_entity_id(
236-
entity_key,
237-
entity_key_serialization_version=config.entity_key_serialization_version,
238-
)
239-
for entity_key in entity_keys
240-
]
241237
batch_size = online_config.batch_size
238+
entity_ids = self._to_entity_ids(config, entity_keys)
242239
entity_ids_iter = iter(entity_ids)
240+
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
241+
243242
while True:
244243
batch = list(itertools.islice(entity_ids_iter, batch_size))
245-
batch_result: List[
246-
Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
247-
] = []
244+
248245
# No more items to insert
249246
if len(batch) == 0:
250247
break
251-
batch_entity_ids = {
252-
table_instance.name: {
253-
"Keys": [{"entity_id": entity_id} for entity_id in batch],
254-
"ConsistentRead": online_config.consistent_reads,
255-
}
256-
}
248+
batch_entity_ids = self._to_resource_batch_get_payload(
249+
online_config, table_instance.name, batch
250+
)
257251
response = dynamodb_resource.batch_get_item(
258252
RequestItems=batch_entity_ids,
259253
)
260-
response = response.get("Responses")
261-
table_responses = response.get(table_instance.name)
262-
if table_responses:
263-
table_responses = self._sort_dynamodb_response(
264-
table_responses, entity_ids
265-
)
266-
entity_idx = 0
267-
for tbl_res in table_responses:
268-
entity_id = tbl_res["entity_id"]
269-
while entity_id != batch[entity_idx]:
270-
batch_result.append((None, None))
271-
entity_idx += 1
272-
res = {}
273-
for feature_name, value_bin in tbl_res["values"].items():
274-
val = ValueProto()
275-
val.ParseFromString(value_bin.value)
276-
res[feature_name] = val
277-
batch_result.append(
278-
(datetime.fromisoformat(tbl_res["event_ts"]), res)
279-
)
280-
entity_idx += 1
281-
282-
# Not all entities in a batch may have responses
283-
# Pad with remaining values in batch that were not found
284-
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
285-
batch_result.extend(batch_size_nones)
254+
batch_result = self._process_batch_get_response(
255+
table_instance.name, response, entity_ids, batch
256+
)
286257
result.extend(batch_result)
287258
return result
288259

260+
async def online_read_async(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """
    Reads feature values for the given entity keys asynchronously.

    Args:
        config: The config for the current feature store.
        table: The feature view whose feature values should be read.
        entity_keys: The list of entity keys for which feature values should be read.
        requested_features: The list of features that should be read. This
            method does not currently use it to narrow the read.

    Returns:
        A list of the same length as entity_keys. Each item in the list is a tuple where the first
        item is the event timestamp for the row, and the second item is a dict mapping feature names
        to values, which are returned in proto format.
    """
    online_config = config.online_store
    assert isinstance(online_config, DynamoDBOnlineStoreConfig)

    entity_ids = self._to_entity_ids(config, entity_keys)
    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    if not entity_ids:
        # Nothing to look up: skip creating an AWS client altogether.
        return result

    batch_size = online_config.batch_size
    table_name = _get_table_name(online_config, config, table)

    deserialize = TypeDeserializer().deserialize

    def to_tbl_resp(raw_client_response):
        # The low-level client returns attribute values in DynamoDB's typed
        # wire format; convert the fields we read into plain Python values.
        return {
            "entity_id": deserialize(raw_client_response["entity_id"]),
            "event_ts": deserialize(raw_client_response["event_ts"]),
            "values": deserialize(raw_client_response["values"]),
        }

    # Create the client here so the configured endpoint_url is honoured, the
    # same way the synchronous online_read passes it to its resource.
    client_context = self._get_aioboto_session().create_client(
        "dynamodb",
        region_name=online_config.region,
        endpoint_url=online_config.endpoint_url,
    )

    entity_ids_iter = iter(entity_ids)
    async with client_context as client:
        while True:
            batch = list(itertools.islice(entity_ids_iter, batch_size))

            # No more items to read
            if not batch:
                break
            batch_entity_ids = self._to_client_batch_get_payload(
                online_config, table_name, batch
            )
            # NOTE(review): only "Responses" is consumed downstream; if the
            # service returns "UnprocessedKeys" those entities would be
            # reported as missing -- confirm whether a retry is needed.
            response = await client.batch_get_item(
                RequestItems=batch_entity_ids,
            )
            batch_result = self._process_batch_get_response(
                table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp
            )
            result.extend(batch_result)
    return result
317+
318+
def _get_aioboto_session(self):
    """Return the aiobotocore session stored on this object, creating it on first use."""
    existing = self._aioboto_session
    if existing is not None:
        return existing
    # First call: build the session and keep it for later calls.
    self._aioboto_session = session.get_session()
    return self._aioboto_session
322+
323+
def _get_aiodynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
    """Create an async DynamoDB client context from the shared aiobotocore session.

    Args:
        region: AWS region name passed to the client as ``region_name``.
        endpoint_url: Optional custom endpoint, mirroring the parameter of
            ``_get_dynamodb_client``. ``None`` leaves the default endpoint.

    Returns:
        Whatever the session's ``create_client`` returns; callers use it
        with ``async with``.
    """
    return self._get_aioboto_session().create_client(
        "dynamodb", region_name=region, endpoint_url=endpoint_url
    )
325+
289326
def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
290327
if self._dynamodb_client is None:
291328
self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url)
@@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
298335
)
299336
return self._dynamodb_resource
300337

301-
def _sort_dynamodb_response(self, responses: list, order: list) -> Any:
338+
def _sort_dynamodb_response(
339+
self,
340+
responses: list,
341+
order: list,
342+
to_tbl_response: Callable = lambda raw_dict: raw_dict,
343+
) -> Any:
302344
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
303345
# Assign an index to order
304346
order_with_index = {value: idx for idx, value in enumerate(order)}
305347
# Sort table responses by index
306348
table_responses_ordered: Any = [
307-
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses
349+
(order_with_index[tbl_res["entity_id"]], tbl_res)
350+
for tbl_res in map(to_tbl_response, responses)
308351
]
309352
table_responses_ordered = sorted(
310353
table_responses_ordered, key=lambda tup: tup[0]
@@ -341,6 +384,64 @@ def _write_batch_non_duplicates(
341384
if progress:
342385
progress(1)
343386

387+
def _process_batch_get_response(
    self, table_name, response, entity_ids, batch, **sort_kwargs
):
    """Turn one batch-get response into per-entity results aligned with ``batch``.

    Args:
        table_name: Key under ``response["Responses"]`` holding this table's rows.
        response: The batch-get response mapping.
        entity_ids: Full ordered list of requested entity ids, forwarded to
            ``_sort_dynamodb_response`` as the sort order.
        batch: The entity ids requested in this particular call.
        **sort_kwargs: Extra keyword arguments forwarded to
            ``_sort_dynamodb_response``.

    Returns:
        A list with one entry per id in ``batch``: ``(event_ts, features)``
        for ids that came back, ``(None, None)`` for ids that did not.
    """
    rows = response.get("Responses").get(table_name)

    batch_result = []
    if rows:
        ordered_rows = self._sort_dynamodb_response(rows, entity_ids, **sort_kwargs)
        cursor = 0
        for row in ordered_rows:
            row_entity_id = row["entity_id"]
            # Every requested id skipped before reaching this row had no result.
            while row_entity_id != batch[cursor]:
                batch_result.append((None, None))
                cursor += 1
            features = {}
            for feature_name, value_bin in row["values"].items():
                proto = ValueProto()
                proto.ParseFromString(value_bin.value)
                features[feature_name] = proto
            event_ts = datetime.fromisoformat(row["event_ts"])
            batch_result.append((event_ts, features))
            cursor += 1

    # Ids at the tail of the batch (or the whole batch) may have no row;
    # pad so the output length matches the batch length.
    missing_count = len(batch) - len(batch_result)
    batch_result.extend([(None, None)] * missing_count)
    return batch_result
416+
417+
@staticmethod
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
    """Compute the entity id for each entity key, preserving input order.

    Each key is passed to ``compute_entity_id`` together with the
    ``entity_key_serialization_version`` read from ``config``.
    """
    entity_ids = []
    for entity_key in entity_keys:
        entity_id = compute_entity_id(
            entity_key,
            entity_key_serialization_version=config.entity_key_serialization_version,
        )
        entity_ids.append(entity_id)
    return entity_ids
426+
427+
@staticmethod
def _to_resource_batch_get_payload(online_config, table_name, batch):
    """Build the ``RequestItems`` mapping for a resource-level batch get.

    Keys are given as plain values (``{"entity_id": <id>}``); the read
    consistency flag is taken from ``online_config.consistent_reads``.
    """
    keys = [dict(entity_id=entity_id) for entity_id in batch]
    table_request = {
        "Keys": keys,
        "ConsistentRead": online_config.consistent_reads,
    }
    return {table_name: table_request}
435+
436+
@staticmethod
def _to_client_batch_get_payload(online_config, table_name, batch):
    """Build the ``RequestItems`` mapping for a client-level batch get.

    Unlike the resource-level payload, each key value is wrapped in a
    string type descriptor: ``{"entity_id": {"S": <id>}}``. The read
    consistency flag is taken from ``online_config.consistent_reads``.
    """
    typed_keys = [{"entity_id": {"S": entity_id}} for entity_id in batch]
    table_request = {
        "Keys": typed_keys,
        "ConsistentRead": online_config.consistent_reads,
    }
    return {table_name: table_request}
444+
344445

345446
def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
346447
return boto3.client(

0 commit comments

Comments
 (0)