@@ -94,6 +94,7 @@ def __init__(
9494 )()
9595
9696 self ._embedding_map : Dict [str , Tuple [AnyTensor , Optional [List [int ]]]] = {}
97+ self ._ids_to_positions : Dict [str , int ] = {}
9798
9899 def python_type_to_db_type (self , python_type : Type ) -> Any :
99100 """Map python type to database type.
@@ -163,7 +164,13 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
163164 """
164165 # implementing the public option because conversion to column dict is not needed
165166 docs = self ._validate_docs (docs )
166- self ._docs .extend (docs )
167+ ids_to_positions = self ._get_ids_to_positions ()
168+ for doc in docs :
169+ if doc .id in ids_to_positions :
170+ self ._docs [ids_to_positions [doc .id ]] = doc
171+ else :
172+ self ._docs .append (doc )
173+ self ._ids_to_positions [str (doc .id )] = len (self ._ids_to_positions )
167174
168175 # Add parent_id to all sub-index documents and store sub-index documents
169176 data_by_columns = self ._get_col_value_dict (docs )
@@ -216,6 +223,7 @@ def _del_items(self, doc_ids: Sequence[str]):
216223 indices .append (i )
217224
218225 del self ._docs [indices ]
226+ self ._update_ids_to_positions ()
219227 self ._rebuild_embedding ()
220228
221229 def _ori_items (self , doc : BaseDoc ) -> BaseDoc :
@@ -259,15 +267,18 @@ def _get_items(
259267 """
260268
261269 out_docs = []
262- for i , doc in enumerate (self ._docs ):
263- if doc .id in doc_ids :
264- if raw :
265- out_docs .append (doc )
266- else :
267- ori_doc = self ._ori_items (doc )
268- schema_cls = cast (Type [BaseDoc ], self .out_schema )
269- new_doc = schema_cls (** ori_doc .__dict__ )
270- out_docs .append (new_doc )
270+ ids_to_positions = self ._get_ids_to_positions ()
271+ for doc_id in doc_ids :
272+ if doc_id not in ids_to_positions :
273+ continue
274+ doc = self ._docs [ids_to_positions [doc_id ]]
275+ if raw :
276+ out_docs .append (doc )
277+ else :
278+ ori_doc = self ._ori_items (doc )
279+ schema_cls = cast (Type [BaseDoc ], self .out_schema )
280+ new_doc = schema_cls (** ori_doc .__dict__ )
281+ out_docs .append (new_doc )
271282
272283 return out_docs
273284
@@ -461,7 +472,7 @@ def _text_search_batched(
461472 raise NotImplementedError (f'{ type (self )} does not support text search.' )
462473
463474 def _doc_exists (self , doc_id : str ) -> bool :
464- return any ( doc . id == doc_id for doc in self ._docs )
475+ return doc_id in self ._get_ids_to_positions ( )
465476
466477 def persist (self , file : Optional [str ] = None ) -> None :
467478 """Persist InMemoryExactNNIndex into a binary file."""
@@ -500,3 +511,21 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
500511 id , fields [0 ], '__' .join (fields [1 :])
501512 )
502513 return self ._get_root_doc_id (cur_root_id , root , '' )
514+
515+ def _get_ids_to_positions (self ) -> Dict [str , int ]:
516+ """
517+ Obtains a mapping between document IDs and their respective positions
518+ within the DocList. If this mapping hasn't been initialized, it will be created.
519+
520+ :return: A dictionary mapping each document ID to its corresponding position.
521+ """
522+ if not self ._ids_to_positions :
523+ self ._update_ids_to_positions ()
524+ return self ._ids_to_positions
525+
526+ def _update_ids_to_positions (self ) -> None :
527+ """
528+ Generates or updates the mapping between document IDs and their corresponding
529+ positions within the DocList.
530+ """
531+ self ._ids_to_positions = {doc .id : pos for pos , doc in enumerate (self ._docs )}
0 commit comments