Skip to content

Commit 3b169d4

Browse files
authored
feat: add from to dict io (#160)
1 parent 45a35a5 commit 3b169d4

File tree

6 files changed

+91
-5
lines changed

6 files changed

+91
-5
lines changed

docarray/array/mixins/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .io.common import CommonIOMixin
1313
from .io.csv import CsvIOMixin
1414
from .io.dataframe import DataframeIOMixin
15+
from .io.dict import DictIOMixin
1516
from .io.from_gen import FromGeneratorMixin
1617
from .io.json import JsonIOMixin
1718
from .io.pushpull import PushPullMixin
@@ -41,6 +42,7 @@ class AllMixins(
4142
CsvIOMixin,
4243
JsonIOMixin,
4344
BinaryIOMixin,
45+
DictIOMixin,
4446
CommonIOMixin,
4547
EmbedMixin,
4648
PushPullMixin,

docarray/array/mixins/io/dict.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import TYPE_CHECKING, Type, List
2+
3+
if TYPE_CHECKING:
4+
from ....types import T
5+
6+
7+
class DictIOMixin:
8+
"""Save/load a DocumentArray into a dict of the form `{offset_0: doc_0, offset_1: doc_1, ...}`"""
9+
10+
def to_dict(self, protocol: str = 'jsonschema', **kwargs) -> List:
11+
"""Convert the object into a Python dict of the form `{offset_0: doc_0, offset_1: doc_1, ...}`
12+
13+
:param protocol: `jsonschema` or `protobuf`
14+
:return: a Python list
15+
"""
16+
return {k: d.to_dict(protocol=protocol, **kwargs) for k, d in enumerate(self)}
17+
18+
@classmethod
19+
def from_dict(cls: Type['T'], input_dict: dict, *args, **kwargs) -> 'T':
20+
"""Import a :class:`DocumentArray` from a :class:`dict` object of the form `{offset_0: doc_0, offset_1: doc_1, ...}`
21+
22+
:param input_dict: a `dict` object.
23+
:return: a :class:`DocumentArray` object
24+
"""
25+
from .... import Document, DocumentArray
26+
27+
da = cls.empty(len(input_dict), *args, **kwargs)
28+
29+
for offset, d in input_dict.items():
30+
da[offset] = Document(
31+
{k: v for k, v in d.items() if (not isinstance(v, float) or v == v)}
32+
)
33+
return da

docarray/array/mixins/io/json.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,3 @@ def to_json(self, protocol: str = 'jsonschema', **kwargs) -> str:
9999
:return: a Python list
100100
"""
101101
return json.dumps(self.to_list(protocol=protocol, **kwargs))
102-
103-
# to comply with Document interfaces but less semantically accurate
104-
to_dict = to_list
105-
from_dict = from_list

docs/fundamentals/documentarray/serialization.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Moreover, there is the ability to store/load `DocumentArray` objects to/from dis
1212
- Base64 (compressed): `.from_base64()`/`.to_base64()`
1313
- Protobuf Message: `.from_protobuf()`/`.to_protobuf()`
1414
- Python List: `.from_list()`/`.to_list()`
15+
- Python Dict: `.from_dict()`/`.to_dict()`
1516
- Pandas Dataframe: `.from_dataframe()`/`.to_dataframe()`
1617
- Cloud: `.push()`/`.pull()`
1718

@@ -324,6 +325,25 @@ da.to_list()
324325
More parameters and usages can be found in the Document-level {ref}`doc-dict`.
325326
```
326327

328+
329+
## From/to dict
330+
331+
332+
Serializing to/from Python dict is less frequently used for the same reason as `Document.to_dict()`: it is often an intermediate step of serializing to JSON. You can do:
333+
334+
```python
335+
from docarray import DocumentArray, Document
336+
337+
da = DocumentArray([Document(text='hello'), Document(text='world')])
338+
da.to_dict()
339+
```
340+
341+
```text
342+
{0: {'id': '3b31cb4c993f11ec8d12787b8ab3f5de', 'mime_type': 'text/plain', 'text': 'hello', 1: {'id': '3b31cca0993f11ec8d12787b8ab3f5de', 'mime_type': 'text/plain', 'text': 'world'}}```
343+
```
344+
345+
346+
327347
## From/to dataframe
328348

329349
```{important}

tests/unit/array/mixins/test_io.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,34 @@ def test_from_to_pd_dataframe(da_cls, config, start_storage):
172172
assert da2[1].tags == {}
173173

174174

175+
@pytest.mark.parametrize(
176+
'da_cls, config',
177+
[
178+
(DocumentArrayInMemory, lambda: None),
179+
(DocumentArraySqlite, lambda: None),
180+
(DocumentArrayPqlite, lambda: PqliteConfig(n_dim=3)),
181+
(DocumentArrayWeaviate, lambda: WeaviateConfig(n_dim=3)),
182+
(DocumentArrayQdrant, lambda: QdrantConfig(n_dim=3)),
183+
],
184+
)
185+
def test_from_to_dict(da_cls, config, start_storage):
186+
da_dict = da_cls.empty(2, config=config()).to_dict()
187+
assert len(da_cls.from_dict(da_dict, config=config())) == 2
188+
189+
# more complicated
190+
da = da_cls.empty(2, config=config())
191+
192+
da[:, 'embedding'] = [[1, 2, 3], [4, 5, 6]]
193+
da[:, 'tensor'] = [[1, 2], [2, 1]]
194+
da[0, 'tags'] = {'hello': 'world'}
195+
da_dict = da.to_dict()
196+
197+
da2 = da_cls.from_dict(da_dict, config=config())
198+
199+
assert da2[0].tags == {'hello': 'world'}
200+
assert da2[1].tags == {}
201+
202+
175203
@pytest.mark.parametrize(
176204
'da_cls, config',
177205
[

tests/unit/array/test_from_to_bytes.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,21 @@ def test_from_to_protobuf(target_da):
9595

9696
@pytest.mark.parametrize('target', [DocumentArray.empty(10), random_docs(10)])
9797
@pytest.mark.parametrize('protocol', ['jsonschema', 'protobuf'])
98-
@pytest.mark.parametrize('to_fn', ['dict', 'json'])
98+
@pytest.mark.parametrize('to_fn', ['list', 'json'])
9999
def test_from_to_safe_list(target, protocol, to_fn):
100100
da_r = getattr(DocumentArray, f'from_{to_fn}')(
101101
getattr(target, f'to_{to_fn}')(protocol=protocol), protocol=protocol
102102
)
103103
assert da_r == target
104104

105105

106+
@pytest.mark.parametrize('target', [DocumentArray.empty(10), random_docs(10)])
107+
def test_from_to_safe_dict(target):
108+
target_dict = getattr(target, f'to_dict')(target)
109+
da_r = getattr(DocumentArray, f'from_dict')(target_dict)
110+
assert da_r == target
111+
112+
106113
@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
107114
@pytest.mark.parametrize('show_progress', [True, False])
108115
def test_push_pull_show_progress(show_progress, protocol):

0 commit comments

Comments
 (0)