@@ -114,8 +114,8 @@ class MilvusOnlineStore(OnlineStore):
114114
115115 def _get_db_path (self , config : RepoConfig ) -> str :
116116 assert (
117- config .online_store .type == "milvus"
118- or config .online_store .type .endswith ("MilvusOnlineStore" )
117+ config .online_store .type == "milvus"
118+ or config .online_store .type .endswith ("MilvusOnlineStore" )
119119 )
120120
121121 if config .repo_path and not Path (config .online_store .path ).is_absolute ():
@@ -140,7 +140,7 @@ def _connect(self, config: RepoConfig) -> MilvusClient:
140140 return self .client
141141
142142 def _get_or_create_collection (
143- self , config : RepoConfig , table : FeatureView
143+ self , config : RepoConfig , table : FeatureView
144144 ) -> Dict [str , Any ]:
145145 self .client = self ._connect (config )
146146 vector_field_dict = {k .name : k for k in table .schema if k .vector_index }
@@ -199,12 +199,12 @@ def _get_or_create_collection(
199199 index_params = self .client .prepare_index_params ()
200200 for vector_field in schema .fields :
201201 if (
202- vector_field .dtype
203- in [
204- DataType .FLOAT_VECTOR ,
205- DataType .BINARY_VECTOR ,
206- ]
207- and vector_field .name in vector_field_dict
202+ vector_field .dtype
203+ in [
204+ DataType .FLOAT_VECTOR ,
205+ DataType .BINARY_VECTOR ,
206+ ]
207+ and vector_field .name in vector_field_dict
208208 ):
209209 metric = vector_field_dict [
210210 vector_field .name
@@ -229,18 +229,18 @@ def _get_or_create_collection(
229229 return self ._collections [collection_name ]
230230
231231 def online_write_batch (
232- self ,
233- config : RepoConfig ,
234- table : FeatureView ,
235- data : List [
236- Tuple [
237- EntityKeyProto ,
238- Dict [str , ValueProto ],
239- datetime ,
240- Optional [datetime ],
241- ]
242- ],
243- progress : Optional [Callable [[int ], Any ]],
232+ self ,
233+ config : RepoConfig ,
234+ table : FeatureView ,
235+ data : List [
236+ Tuple [
237+ EntityKeyProto ,
238+ Dict [str , ValueProto ],
239+ datetime ,
240+ Optional [datetime ],
241+ ]
242+ ],
243+ progress : Optional [Callable [[int ], Any ]],
244244 ) -> None :
245245 self .client = self ._connect (config )
246246 collection = self ._get_or_create_collection (config , table )
@@ -287,8 +287,8 @@ def online_write_batch(
287287 single_entity_record [field ] = ""
288288 # Store only the latest event timestamp per entity
289289 if (
290- entity_key_str not in unique_entities
291- or unique_entities [entity_key_str ]["event_ts" ] < timestamp_int
290+ entity_key_str not in unique_entities
291+ or unique_entities [entity_key_str ]["event_ts" ] < timestamp_int
292292 ):
293293 unique_entities [entity_key_str ] = single_entity_record
294294
@@ -302,12 +302,12 @@ def online_write_batch(
302302 )
303303
304304 def online_read (
305- self ,
306- config : RepoConfig ,
307- table : FeatureView ,
308- entity_keys : List [EntityKeyProto ],
309- requested_features : Optional [List [str ]] = None ,
310- full_feature_names : bool = False ,
305+ self ,
306+ config : RepoConfig ,
307+ table : FeatureView ,
308+ entity_keys : List [EntityKeyProto ],
309+ requested_features : Optional [List [str ]] = None ,
310+ full_feature_names : bool = False ,
311311 ) -> List [Tuple [Optional [datetime ], Optional [Dict [str , ValueProto ]]]]:
312312 self .client = self ._connect (config )
313313 collection_name = _table_id (config .project , table )
@@ -316,9 +316,9 @@ def online_read(
316316 composite_key_name = _get_composite_key_name (table )
317317
318318 output_fields = (
319- [composite_key_name ]
320- + (requested_features if requested_features else [])
321- + ["created_ts" , "event_ts" ]
319+ [composite_key_name ]
320+ + (requested_features if requested_features else [])
321+ + ["created_ts" , "event_ts" ]
322322 )
323323 assert all (
324324 field in [f ["name" ] for f in collection ["fields" ]]
@@ -335,9 +335,9 @@ def online_read(
335335 composite_entities .append (entity_key_str )
336336
337337 query_filter_for_entities = (
338- f"{ composite_key_name } in ["
339- + ", " .join ([f"'{ e } '" for e in composite_entities ])
340- + "]"
338+ f"{ composite_key_name } in ["
339+ + ", " .join ([f"'{ e } '" for e in composite_entities ])
340+ + "]"
341341 )
342342 self .client .load_collection (collection_name )
343343 results = self .client .query (
@@ -441,13 +441,13 @@ def online_read(
441441 return result_list
442442
443443 def update (
444- self ,
445- config : RepoConfig ,
446- tables_to_delete : Sequence [FeatureView ],
447- tables_to_keep : Sequence [FeatureView ],
448- entities_to_delete : Sequence [Entity ],
449- entities_to_keep : Sequence [Entity ],
450- partial : bool ,
444+ self ,
445+ config : RepoConfig ,
446+ tables_to_delete : Sequence [FeatureView ],
447+ tables_to_keep : Sequence [FeatureView ],
448+ entities_to_delete : Sequence [Entity ],
449+ entities_to_keep : Sequence [Entity ],
450+ partial : bool ,
451451 ):
452452 self .client = self ._connect (config )
453453 for table in tables_to_keep :
@@ -460,15 +460,15 @@ def update(
460460 self ._collections .pop (collection_name , None )
461461
462462 def plan (
463- self , config : RepoConfig , desired_registry_proto : RegistryProto
463+ self , config : RepoConfig , desired_registry_proto : RegistryProto
464464 ) -> List [InfraObject ]:
465465 raise NotImplementedError
466466
467467 def teardown (
468- self ,
469- config : RepoConfig ,
470- tables : Sequence [FeatureView ],
471- entities : Sequence [Entity ],
468+ self ,
469+ config : RepoConfig ,
470+ tables : Sequence [FeatureView ],
471+ entities : Sequence [Entity ],
472472 ):
473473 self .client = self ._connect (config )
474474 for table in tables :
@@ -478,14 +478,14 @@ def teardown(
478478 self ._collections .pop (collection_name , None )
479479
480480 def retrieve_online_documents_v2 (
481- self ,
482- config : RepoConfig ,
483- table : FeatureView ,
484- requested_features : List [str ],
485- embedding : Optional [List [float ]],
486- top_k : int ,
487- distance_metric : Optional [str ] = None ,
488- query_string : Optional [str ] = None ,
481+ self ,
482+ config : RepoConfig ,
483+ table : FeatureView ,
484+ requested_features : List [str ],
485+ embedding : Optional [List [float ]],
486+ top_k : int ,
487+ distance_metric : Optional [str ] = None ,
488+ query_string : Optional [str ] = None ,
489489 ) -> List [
490490 Tuple [
491491 Optional [datetime ],
@@ -514,7 +514,6 @@ def retrieve_online_documents_v2(
514514 self .client = self ._connect (config )
515515 collection_name = _table_id (config .project , table )
516516 collection = self ._get_or_create_collection (config , table )
517-
518517 if not config .online_store .vector_enabled :
519518 raise ValueError ("Vector search is not enabled in the online store config" )
520519
@@ -524,9 +523,9 @@ def retrieve_online_documents_v2(
524523 composite_key_name = _get_composite_key_name (table )
525524
526525 output_fields = (
527- [composite_key_name ]
528- + (requested_features if requested_features else [])
529- + ["created_ts" , "event_ts" ]
526+ [composite_key_name ]
527+ + (requested_features if requested_features else [])
528+ + ["created_ts" , "event_ts" ]
530529 )
531530 assert all (
532531 field in [f ["name" ] for f in collection ["fields" ]]
@@ -656,14 +655,14 @@ def retrieve_online_documents_v2(
656655 )
657656 res [ann_search_field ] = serialized_embedding
658657 elif entity_name_feast_primitive_type_map .get (
659- field , PrimitiveFeastType .INVALID
658+ field , PrimitiveFeastType .INVALID
660659 ) in [
661660 PrimitiveFeastType .STRING ,
662661 PrimitiveFeastType .BYTES ,
663662 ]:
664663 res [field ] = ValueProto (string_val = str (field_value ))
665664 elif entity_name_feast_primitive_type_map .get (
666- field , PrimitiveFeastType .INVALID
665+ field , PrimitiveFeastType .INVALID
667666 ) in [
668667 PrimitiveFeastType .INT64 ,
669668 PrimitiveFeastType .INT32 ,
@@ -694,9 +693,9 @@ def _get_composite_key_name(table: FeatureView) -> str:
694693
695694
696695def _extract_proto_values_to_dict (
697- input_dict : Dict [str , Any ],
698- vector_cols : List [str ],
699- serialize_to_string = False ,
696+ input_dict : Dict [str , Any ],
697+ vector_cols : List [str ],
698+ serialize_to_string = False ,
700699) -> Dict [str , Any ]:
701700 numeric_vector_list_types = [
702701 k
@@ -724,8 +723,8 @@ def _extract_proto_values_to_dict(
724723 vector_values = getattr (feature_values , proto_val_type ).val
725724 else :
726725 if (
727- serialize_to_string
728- and proto_val_type not in ["string_val" ] + numeric_types
726+ serialize_to_string
727+ and proto_val_type not in ["string_val" ] + numeric_types
729728 ):
730729 vector_values = feature_values .SerializeToString ().decode ()
731730 else :
0 commit comments