3333
3434try :
3535 import boto3
36+ from aiobotocore import session
37+ from boto3 .dynamodb .types import TypeDeserializer
3638 from botocore .config import Config
3739 from botocore .exceptions import ClientError
3840except ImportError as e :
@@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore):
8082
8183 _dynamodb_client = None
8284 _dynamodb_resource = None
85+ _aioboto_session = None
8386
8487 def update (
8588 self ,
@@ -223,69 +226,103 @@ def online_read(
223226 """
224227 online_config = config .online_store
225228 assert isinstance (online_config , DynamoDBOnlineStoreConfig )
229+
226230 dynamodb_resource = self ._get_dynamodb_resource (
227231 online_config .region , online_config .endpoint_url
228232 )
229233 table_instance = dynamodb_resource .Table (
230234 _get_table_name (online_config , config , table )
231235 )
232236
233- result : List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]] = []
234- entity_ids = [
235- compute_entity_id (
236- entity_key ,
237- entity_key_serialization_version = config .entity_key_serialization_version ,
238- )
239- for entity_key in entity_keys
240- ]
241237 batch_size = online_config .batch_size
238+ entity_ids = self ._to_entity_ids (config , entity_keys )
242239 entity_ids_iter = iter (entity_ids )
240+ result : List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]] = []
241+
243242 while True :
244243 batch = list (itertools .islice (entity_ids_iter , batch_size ))
245- batch_result : List [
246- Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]
247- ] = []
244+
248245 # No more items to insert
249246 if len (batch ) == 0 :
250247 break
251- batch_entity_ids = {
252- table_instance .name : {
253- "Keys" : [{"entity_id" : entity_id } for entity_id in batch ],
254- "ConsistentRead" : online_config .consistent_reads ,
255- }
256- }
248+ batch_entity_ids = self ._to_resource_batch_get_payload (
249+ online_config , table_instance .name , batch
250+ )
257251 response = dynamodb_resource .batch_get_item (
258252 RequestItems = batch_entity_ids ,
259253 )
260- response = response .get ("Responses" )
261- table_responses = response .get (table_instance .name )
262- if table_responses :
263- table_responses = self ._sort_dynamodb_response (
264- table_responses , entity_ids
265- )
266- entity_idx = 0
267- for tbl_res in table_responses :
268- entity_id = tbl_res ["entity_id" ]
269- while entity_id != batch [entity_idx ]:
270- batch_result .append ((None , None ))
271- entity_idx += 1
272- res = {}
273- for feature_name , value_bin in tbl_res ["values" ].items ():
274- val = ValueProto ()
275- val .ParseFromString (value_bin .value )
276- res [feature_name ] = val
277- batch_result .append (
278- (datetime .fromisoformat (tbl_res ["event_ts" ]), res )
279- )
280- entity_idx += 1
281-
282- # Not all entities in a batch may have responses
283- # Pad with remaining values in batch that were not found
284- batch_size_nones = ((None , None ),) * (len (batch ) - len (batch_result ))
285- batch_result .extend (batch_size_nones )
254+ batch_result = self ._process_batch_get_response (
255+ table_instance .name , response , entity_ids , batch
256+ )
286257 result .extend (batch_result )
287258 return result
288259
260+ async def online_read_async (
261+ self ,
262+ config : RepoConfig ,
263+ table : FeatureView ,
264+ entity_keys : List [EntityKeyProto ],
265+ requested_features : Optional [List [str ]] = None ,
266+ ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
267+ """
268+ Reads features values for the given entity keys asynchronously.
269+
270+ Args:
271+ config: The config for the current feature store.
272+ table: The feature view whose feature values should be read.
273+ entity_keys: The list of entity keys for which feature values should be read.
274+ requested_features: The list of features that should be read.
275+
276+ Returns:
277+ A list of the same length as entity_keys. Each item in the list is a tuple where the first
278+ item is the event timestamp for the row, and the second item is a dict mapping feature names
279+ to values, which are returned in proto format.
280+ """
281+ online_config = config .online_store
282+ assert isinstance (online_config , DynamoDBOnlineStoreConfig )
283+
284+ batch_size = online_config .batch_size
285+ entity_ids = self ._to_entity_ids (config , entity_keys )
286+ entity_ids_iter = iter (entity_ids )
287+ result : List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]] = []
288+ table_name = _get_table_name (online_config , config , table )
289+
290+ deserialize = TypeDeserializer ().deserialize
291+
292+ def to_tbl_resp (raw_client_response ):
293+ return {
294+ "entity_id" : deserialize (raw_client_response ["entity_id" ]),
295+ "event_ts" : deserialize (raw_client_response ["event_ts" ]),
296+ "values" : deserialize (raw_client_response ["values" ]),
297+ }
298+
299+ async with self ._get_aiodynamodb_client (online_config .region ) as client :
300+ while True :
301+ batch = list (itertools .islice (entity_ids_iter , batch_size ))
302+
303+ # No more items to insert
304+ if len (batch ) == 0 :
305+ break
306+ batch_entity_ids = self ._to_client_batch_get_payload (
307+ online_config , table_name , batch
308+ )
309+ response = await client .batch_get_item (
310+ RequestItems = batch_entity_ids ,
311+ )
312+ batch_result = self ._process_batch_get_response (
313+ table_name , response , entity_ids , batch , to_tbl_response = to_tbl_resp
314+ )
315+ result .extend (batch_result )
316+ return result
317+
318+ def _get_aioboto_session (self ):
319+ if self ._aioboto_session is None :
320+ self ._aioboto_session = session .get_session ()
321+ return self ._aioboto_session
322+
323+ def _get_aiodynamodb_client (self , region : str ):
324+ return self ._get_aioboto_session ().create_client ("dynamodb" , region_name = region )
325+
289326 def _get_dynamodb_client (self , region : str , endpoint_url : Optional [str ] = None ):
290327 if self ._dynamodb_client is None :
291328 self ._dynamodb_client = _initialize_dynamodb_client (region , endpoint_url )
@@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
298335 )
299336 return self ._dynamodb_resource
300337
301- def _sort_dynamodb_response (self , responses : list , order : list ) -> Any :
338+ def _sort_dynamodb_response (
339+ self ,
340+ responses : list ,
341+ order : list ,
342+ to_tbl_response : Callable = lambda raw_dict : raw_dict ,
343+ ) -> Any :
302344 """DynamoDB Batch Get Item doesn't return items in a particular order."""
303345 # Assign an index to order
304346 order_with_index = {value : idx for idx , value in enumerate (order )}
305347 # Sort table responses by index
306348 table_responses_ordered : Any = [
307- (order_with_index [tbl_res ["entity_id" ]], tbl_res ) for tbl_res in responses
349+ (order_with_index [tbl_res ["entity_id" ]], tbl_res )
350+ for tbl_res in map (to_tbl_response , responses )
308351 ]
309352 table_responses_ordered = sorted (
310353 table_responses_ordered , key = lambda tup : tup [0 ]
@@ -341,6 +384,64 @@ def _write_batch_non_duplicates(
341384 if progress :
342385 progress (1 )
343386
387+ def _process_batch_get_response (
388+ self , table_name , response , entity_ids , batch , ** sort_kwargs
389+ ):
390+ response = response .get ("Responses" )
391+ table_responses = response .get (table_name )
392+
393+ batch_result = []
394+ if table_responses :
395+ table_responses = self ._sort_dynamodb_response (
396+ table_responses , entity_ids , ** sort_kwargs
397+ )
398+ entity_idx = 0
399+ for tbl_res in table_responses :
400+ entity_id = tbl_res ["entity_id" ]
401+ while entity_id != batch [entity_idx ]:
402+ batch_result .append ((None , None ))
403+ entity_idx += 1
404+ res = {}
405+ for feature_name , value_bin in tbl_res ["values" ].items ():
406+ val = ValueProto ()
407+ val .ParseFromString (value_bin .value )
408+ res [feature_name ] = val
409+ batch_result .append ((datetime .fromisoformat (tbl_res ["event_ts" ]), res ))
410+ entity_idx += 1
411+ # Not all entities in a batch may have responses
412+ # Pad with remaining values in batch that were not found
413+ batch_size_nones = ((None , None ),) * (len (batch ) - len (batch_result ))
414+ batch_result .extend (batch_size_nones )
415+ return batch_result
416+
417+ @staticmethod
418+ def _to_entity_ids (config : RepoConfig , entity_keys : List [EntityKeyProto ]):
419+ return [
420+ compute_entity_id (
421+ entity_key ,
422+ entity_key_serialization_version = config .entity_key_serialization_version ,
423+ )
424+ for entity_key in entity_keys
425+ ]
426+
427+ @staticmethod
428+ def _to_resource_batch_get_payload (online_config , table_name , batch ):
429+ return {
430+ table_name : {
431+ "Keys" : [{"entity_id" : entity_id } for entity_id in batch ],
432+ "ConsistentRead" : online_config .consistent_reads ,
433+ }
434+ }
435+
436+ @staticmethod
437+ def _to_client_batch_get_payload (online_config , table_name , batch ):
438+ return {
439+ table_name : {
440+ "Keys" : [{"entity_id" : {"S" : entity_id }} for entity_id in batch ],
441+ "ConsistentRead" : online_config .consistent_reads ,
442+ }
443+ }
444+
344445
345446def _initialize_dynamodb_client (region : str , endpoint_url : Optional [str ] = None ):
346447 return boto3 .client (
0 commit comments