Skip to content

Commit 7ad70bf

Browse files
authored
feat: update for inmemory index (#1724)
Signed-off-by: jupyterjazz <[email protected]>
1 parent 410665a commit 7ad70bf

File tree

2 files changed

+95
-11
lines changed

2 files changed

+95
-11
lines changed

docarray/index/backends/in_memory.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
)()
9595

9696
self._embedding_map: Dict[str, Tuple[AnyTensor, Optional[List[int]]]] = {}
97+
self._ids_to_positions: Dict[str, int] = {}
9798

9899
def python_type_to_db_type(self, python_type: Type) -> Any:
99100
"""Map python type to database type.
@@ -163,7 +164,13 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
163164
"""
164165
# implementing the public option because conversion to column dict is not needed
165166
docs = self._validate_docs(docs)
166-
self._docs.extend(docs)
167+
ids_to_positions = self._get_ids_to_positions()
168+
for doc in docs:
169+
if doc.id in ids_to_positions:
170+
self._docs[ids_to_positions[doc.id]] = doc
171+
else:
172+
self._docs.append(doc)
173+
self._ids_to_positions[str(doc.id)] = len(self._ids_to_positions)
167174

168175
# Add parent_id to all sub-index documents and store sub-index documents
169176
data_by_columns = self._get_col_value_dict(docs)
@@ -216,6 +223,7 @@ def _del_items(self, doc_ids: Sequence[str]):
216223
indices.append(i)
217224

218225
del self._docs[indices]
226+
self._update_ids_to_positions()
219227
self._rebuild_embedding()
220228

221229
def _ori_items(self, doc: BaseDoc) -> BaseDoc:
@@ -259,15 +267,18 @@ def _get_items(
259267
"""
260268

261269
out_docs = []
262-
for i, doc in enumerate(self._docs):
263-
if doc.id in doc_ids:
264-
if raw:
265-
out_docs.append(doc)
266-
else:
267-
ori_doc = self._ori_items(doc)
268-
schema_cls = cast(Type[BaseDoc], self.out_schema)
269-
new_doc = schema_cls(**ori_doc.__dict__)
270-
out_docs.append(new_doc)
270+
ids_to_positions = self._get_ids_to_positions()
271+
for doc_id in doc_ids:
272+
if doc_id not in ids_to_positions:
273+
continue
274+
doc = self._docs[ids_to_positions[doc_id]]
275+
if raw:
276+
out_docs.append(doc)
277+
else:
278+
ori_doc = self._ori_items(doc)
279+
schema_cls = cast(Type[BaseDoc], self.out_schema)
280+
new_doc = schema_cls(**ori_doc.__dict__)
281+
out_docs.append(new_doc)
271282

272283
return out_docs
273284

@@ -461,7 +472,7 @@ def _text_search_batched(
461472
raise NotImplementedError(f'{type(self)} does not support text search.')
462473

463474
def _doc_exists(self, doc_id: str) -> bool:
464-
return any(doc.id == doc_id for doc in self._docs)
475+
return doc_id in self._get_ids_to_positions()
465476

466477
def persist(self, file: Optional[str] = None) -> None:
467478
"""Persist InMemoryExactNNIndex into a binary file."""
@@ -500,3 +511,21 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
500511
id, fields[0], '__'.join(fields[1:])
501512
)
502513
return self._get_root_doc_id(cur_root_id, root, '')
514+
515+
def _get_ids_to_positions(self) -> Dict[str, int]:
516+
"""
517+
Obtains a mapping between document IDs and their respective positions
518+
within the DocList. If this mapping hasn't been initialized, it will be created.
519+
520+
:return: A dictionary mapping each document ID to its corresponding position.
521+
"""
522+
if not self._ids_to_positions:
523+
self._update_ids_to_positions()
524+
return self._ids_to_positions
525+
526+
def _update_ids_to_positions(self) -> None:
527+
"""
528+
Generates or updates the mapping between document IDs and their corresponding
529+
positions within the DocList.
530+
"""
531+
self._ids_to_positions = {doc.id: pos for pos, doc in enumerate(self._docs)}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import numpy as np
2+
3+
from docarray import BaseDoc, DocList
4+
from docarray.index import InMemoryExactNNIndex
5+
from docarray.typing import NdArray
6+
7+
8+
class SimpleDoc(BaseDoc):
    """Minimal test schema: a fixed-size embedding plus a text payload."""

    # 128-dimensional vector used as the search field in the tests below.
    embedding: NdArray[128]
    # Free-text payload used to verify that re-indexing updates docs in place.
    text: str
11+
12+
13+
def test_update_payload():
    """Re-indexing docs with existing ids must overwrite their payload in place."""
    num = 100
    docs = DocList[SimpleDoc](
        [SimpleDoc(embedding=np.random.rand(128), text=f'hey {i}') for i in range(num)]
    )
    index = InMemoryExactNNIndex[SimpleDoc]()
    index.index(docs)

    assert index.num_docs() == num

    # Mutate every payload, then index the same docs again: the count must
    # stay constant (update, not duplicate) and every stored text must change.
    for doc in docs:
        doc.text = doc.text + '_changed'

    index.index(docs)
    assert index.num_docs() == num

    res = index.find(query=docs[0], search_field='embedding', limit=num)
    assert len(res.documents) == num
    assert all('_changed' in doc.text for doc in res.documents)
32+
33+
34+
def test_update_embedding():
    """Re-indexing a single doc with a new vector must replace the stored embedding."""
    docs = DocList[SimpleDoc](
        [SimpleDoc(embedding=np.random.rand(128), text=f'hey {i}') for i in range(100)]
    )
    index = InMemoryExactNNIndex[SimpleDoc]()
    index.index(docs)
    assert index.num_docs() == 100

    # Replace one document's embedding and re-index just that document.
    updated = np.random.rand(128)
    docs[0].embedding = updated

    index.index(docs[0])
    assert index.num_docs() == 100

    res = index.find(query=docs[0], search_field='embedding', limit=100)
    assert len(res.documents) == 100

    # The updated document must still be retrievable and carry the new vector.
    matches = [doc for doc in res.documents if doc.id == docs[0].id]
    assert matches
    for doc in matches:
        assert (doc.embedding == updated).all()

0 commit comments

Comments
 (0)