Skip to content

Commit 224d0a1

Browse files
committed
fix linter
Signed-off-by: Yassin Nouh <[email protected]>
1 parent f4adef3 commit 224d0a1

File tree

1 file changed

+65
-66
lines changed
  • sdk/python/feast/infra/online_stores/milvus_online_store

1 file changed

+65
-66
lines changed

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

Lines changed: 65 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ class MilvusOnlineStore(OnlineStore):
114114

115115
def _get_db_path(self, config: RepoConfig) -> str:
116116
assert (
117-
config.online_store.type == "milvus"
118-
or config.online_store.type.endswith("MilvusOnlineStore")
117+
config.online_store.type == "milvus"
118+
or config.online_store.type.endswith("MilvusOnlineStore")
119119
)
120120

121121
if config.repo_path and not Path(config.online_store.path).is_absolute():
@@ -140,7 +140,7 @@ def _connect(self, config: RepoConfig) -> MilvusClient:
140140
return self.client
141141

142142
def _get_or_create_collection(
143-
self, config: RepoConfig, table: FeatureView
143+
self, config: RepoConfig, table: FeatureView
144144
) -> Dict[str, Any]:
145145
self.client = self._connect(config)
146146
vector_field_dict = {k.name: k for k in table.schema if k.vector_index}
@@ -199,12 +199,12 @@ def _get_or_create_collection(
199199
index_params = self.client.prepare_index_params()
200200
for vector_field in schema.fields:
201201
if (
202-
vector_field.dtype
203-
in [
204-
DataType.FLOAT_VECTOR,
205-
DataType.BINARY_VECTOR,
206-
]
207-
and vector_field.name in vector_field_dict
202+
vector_field.dtype
203+
in [
204+
DataType.FLOAT_VECTOR,
205+
DataType.BINARY_VECTOR,
206+
]
207+
and vector_field.name in vector_field_dict
208208
):
209209
metric = vector_field_dict[
210210
vector_field.name
@@ -229,18 +229,18 @@ def _get_or_create_collection(
229229
return self._collections[collection_name]
230230

231231
def online_write_batch(
232-
self,
233-
config: RepoConfig,
234-
table: FeatureView,
235-
data: List[
236-
Tuple[
237-
EntityKeyProto,
238-
Dict[str, ValueProto],
239-
datetime,
240-
Optional[datetime],
241-
]
242-
],
243-
progress: Optional[Callable[[int], Any]],
232+
self,
233+
config: RepoConfig,
234+
table: FeatureView,
235+
data: List[
236+
Tuple[
237+
EntityKeyProto,
238+
Dict[str, ValueProto],
239+
datetime,
240+
Optional[datetime],
241+
]
242+
],
243+
progress: Optional[Callable[[int], Any]],
244244
) -> None:
245245
self.client = self._connect(config)
246246
collection = self._get_or_create_collection(config, table)
@@ -287,8 +287,8 @@ def online_write_batch(
287287
single_entity_record[field] = ""
288288
# Store only the latest event timestamp per entity
289289
if (
290-
entity_key_str not in unique_entities
291-
or unique_entities[entity_key_str]["event_ts"] < timestamp_int
290+
entity_key_str not in unique_entities
291+
or unique_entities[entity_key_str]["event_ts"] < timestamp_int
292292
):
293293
unique_entities[entity_key_str] = single_entity_record
294294

@@ -302,12 +302,12 @@ def online_write_batch(
302302
)
303303

304304
def online_read(
305-
self,
306-
config: RepoConfig,
307-
table: FeatureView,
308-
entity_keys: List[EntityKeyProto],
309-
requested_features: Optional[List[str]] = None,
310-
full_feature_names: bool = False,
305+
self,
306+
config: RepoConfig,
307+
table: FeatureView,
308+
entity_keys: List[EntityKeyProto],
309+
requested_features: Optional[List[str]] = None,
310+
full_feature_names: bool = False,
311311
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
312312
self.client = self._connect(config)
313313
collection_name = _table_id(config.project, table)
@@ -316,9 +316,9 @@ def online_read(
316316
composite_key_name = _get_composite_key_name(table)
317317

318318
output_fields = (
319-
[composite_key_name]
320-
+ (requested_features if requested_features else [])
321-
+ ["created_ts", "event_ts"]
319+
[composite_key_name]
320+
+ (requested_features if requested_features else [])
321+
+ ["created_ts", "event_ts"]
322322
)
323323
assert all(
324324
field in [f["name"] for f in collection["fields"]]
@@ -335,9 +335,9 @@ def online_read(
335335
composite_entities.append(entity_key_str)
336336

337337
query_filter_for_entities = (
338-
f"{composite_key_name} in ["
339-
+ ", ".join([f"'{e}'" for e in composite_entities])
340-
+ "]"
338+
f"{composite_key_name} in ["
339+
+ ", ".join([f"'{e}'" for e in composite_entities])
340+
+ "]"
341341
)
342342
self.client.load_collection(collection_name)
343343
results = self.client.query(
@@ -441,13 +441,13 @@ def online_read(
441441
return result_list
442442

443443
def update(
444-
self,
445-
config: RepoConfig,
446-
tables_to_delete: Sequence[FeatureView],
447-
tables_to_keep: Sequence[FeatureView],
448-
entities_to_delete: Sequence[Entity],
449-
entities_to_keep: Sequence[Entity],
450-
partial: bool,
444+
self,
445+
config: RepoConfig,
446+
tables_to_delete: Sequence[FeatureView],
447+
tables_to_keep: Sequence[FeatureView],
448+
entities_to_delete: Sequence[Entity],
449+
entities_to_keep: Sequence[Entity],
450+
partial: bool,
451451
):
452452
self.client = self._connect(config)
453453
for table in tables_to_keep:
@@ -460,15 +460,15 @@ def update(
460460
self._collections.pop(collection_name, None)
461461

462462
def plan(
463-
self, config: RepoConfig, desired_registry_proto: RegistryProto
463+
self, config: RepoConfig, desired_registry_proto: RegistryProto
464464
) -> List[InfraObject]:
465465
raise NotImplementedError
466466

467467
def teardown(
468-
self,
469-
config: RepoConfig,
470-
tables: Sequence[FeatureView],
471-
entities: Sequence[Entity],
468+
self,
469+
config: RepoConfig,
470+
tables: Sequence[FeatureView],
471+
entities: Sequence[Entity],
472472
):
473473
self.client = self._connect(config)
474474
for table in tables:
@@ -478,14 +478,14 @@ def teardown(
478478
self._collections.pop(collection_name, None)
479479

480480
def retrieve_online_documents_v2(
481-
self,
482-
config: RepoConfig,
483-
table: FeatureView,
484-
requested_features: List[str],
485-
embedding: Optional[List[float]],
486-
top_k: int,
487-
distance_metric: Optional[str] = None,
488-
query_string: Optional[str] = None,
481+
self,
482+
config: RepoConfig,
483+
table: FeatureView,
484+
requested_features: List[str],
485+
embedding: Optional[List[float]],
486+
top_k: int,
487+
distance_metric: Optional[str] = None,
488+
query_string: Optional[str] = None,
489489
) -> List[
490490
Tuple[
491491
Optional[datetime],
@@ -514,7 +514,6 @@ def retrieve_online_documents_v2(
514514
self.client = self._connect(config)
515515
collection_name = _table_id(config.project, table)
516516
collection = self._get_or_create_collection(config, table)
517-
518517
if not config.online_store.vector_enabled:
519518
raise ValueError("Vector search is not enabled in the online store config")
520519

@@ -524,9 +523,9 @@ def retrieve_online_documents_v2(
524523
composite_key_name = _get_composite_key_name(table)
525524

526525
output_fields = (
527-
[composite_key_name]
528-
+ (requested_features if requested_features else [])
529-
+ ["created_ts", "event_ts"]
526+
[composite_key_name]
527+
+ (requested_features if requested_features else [])
528+
+ ["created_ts", "event_ts"]
530529
)
531530
assert all(
532531
field in [f["name"] for f in collection["fields"]]
@@ -656,14 +655,14 @@ def retrieve_online_documents_v2(
656655
)
657656
res[ann_search_field] = serialized_embedding
658657
elif entity_name_feast_primitive_type_map.get(
659-
field, PrimitiveFeastType.INVALID
658+
field, PrimitiveFeastType.INVALID
660659
) in [
661660
PrimitiveFeastType.STRING,
662661
PrimitiveFeastType.BYTES,
663662
]:
664663
res[field] = ValueProto(string_val=str(field_value))
665664
elif entity_name_feast_primitive_type_map.get(
666-
field, PrimitiveFeastType.INVALID
665+
field, PrimitiveFeastType.INVALID
667666
) in [
668667
PrimitiveFeastType.INT64,
669668
PrimitiveFeastType.INT32,
@@ -694,9 +693,9 @@ def _get_composite_key_name(table: FeatureView) -> str:
694693

695694

696695
def _extract_proto_values_to_dict(
697-
input_dict: Dict[str, Any],
698-
vector_cols: List[str],
699-
serialize_to_string=False,
696+
input_dict: Dict[str, Any],
697+
vector_cols: List[str],
698+
serialize_to_string=False,
700699
) -> Dict[str, Any]:
701700
numeric_vector_list_types = [
702701
k
@@ -724,8 +723,8 @@ def _extract_proto_values_to_dict(
724723
vector_values = getattr(feature_values, proto_val_type).val
725724
else:
726725
if (
727-
serialize_to_string
728-
and proto_val_type not in ["string_val"] + numeric_types
726+
serialize_to_string
727+
and proto_val_type not in ["string_val"] + numeric_types
729728
):
730729
vector_values = feature_values.SerializeToString().decode()
731730
else:

0 commit comments

Comments
 (0)