Skip to content

Commit a38e0fc

Browse files
authored
feat(protobuf): add arg for compatible ndarray type (#169)
* feat(protobuf): add arg for compatible ndarray type * feat(protobuf): add arg for compatible ndarray type
1 parent 2dfa095 commit a38e0fc

File tree

6 files changed

+43
-12
lines changed

6 files changed

+43
-12
lines changed

docarray/array/mixins/io/binary.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,18 @@ def to_bytes(
311311
if not _file_ctx:
312312
return bf.getvalue()
313313

314-
def to_protobuf(self) -> 'DocumentArrayProto':
314+
def to_protobuf(self, ndarray_type: Optional[str] = None) -> 'DocumentArrayProto':
315+
"""Convert DocumentArray into a Protobuf message.
316+
317+
:param ndarray_type: can be ``list`` or ``numpy``; if set, it will force all ndarray-like objects from all
318+
Documents to ``List`` or ``numpy.ndarray``.
319+
:return: the protobuf message
320+
"""
315321
from ....proto.docarray_pb2 import DocumentArrayProto
316322

317323
dap = DocumentArrayProto()
318324
for d in self:
319-
dap.docs.append(d.to_protobuf())
325+
dap.docs.append(d.to_protobuf(ndarray_type))
320326
return dap
321327

322328
@classmethod

docarray/document/mixins/protobuf.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Type
1+
from typing import TYPE_CHECKING, Type, Optional
22

33
if TYPE_CHECKING:
44
from ...types import T
@@ -12,7 +12,12 @@ def from_protobuf(cls: Type['T'], pb_msg: 'DocumentProto') -> 'T':
1212

1313
return parse_proto(pb_msg)
1414

15-
def to_protobuf(self) -> 'DocumentProto':
15+
def to_protobuf(self, ndarray_type: Optional[str] = None) -> 'DocumentProto':
16+
"""Convert Document into a Protobuf message.
17+
18+
:param ndarray_type: can be ``list`` or ``numpy``; if set, it will force all ndarray-like objects to be ``List`` or ``numpy.ndarray``.
19+
:return: the protobuf message
20+
"""
1621
from ...proto.io import flush_proto
1722

18-
return flush_proto(self)
23+
return flush_proto(self, ndarray_type)

docarray/proto/io/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import defaultdict
2-
from typing import TYPE_CHECKING
2+
from typing import TYPE_CHECKING, Optional
33

44
from google.protobuf.json_format import MessageToDict
55
from google.protobuf.struct_pb2 import Struct
@@ -37,13 +37,13 @@ def parse_proto(pb_msg: 'DocumentProto') -> 'Document':
3737
return Document(**fields)
3838

3939

40-
def flush_proto(doc: 'Document') -> 'DocumentProto':
40+
def flush_proto(doc: 'Document', ndarray_type: Optional[str] = None) -> 'DocumentProto':
4141
pb_msg = DocumentProto()
4242
for key in doc.non_empty_fields:
4343
try:
4444
value = getattr(doc, key)
4545
if key in ('tensor', 'embedding'):
46-
flush_ndarray(getattr(pb_msg, key), value)
46+
flush_ndarray(getattr(pb_msg, key), value, ndarray_type=ndarray_type)
4747
elif key in ('chunks', 'matches'):
4848
for d in value:
4949
d: Document

docarray/proto/io/ndarray.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, Optional
22

33
import numpy as np
44

5-
from ...math.ndarray import get_array_type
5+
from ...math.ndarray import get_array_type, to_numpy_array
66

77
if TYPE_CHECKING:
88
from ...types import ArrayType
@@ -44,7 +44,14 @@ def read_ndarray(pb_msg: 'NdArrayProto') -> 'ArrayType':
4444
return _to_framework_array(x, framework)
4545

4646

47-
def flush_ndarray(pb_msg: 'NdArrayProto', value: 'ArrayType'):
47+
def flush_ndarray(
48+
pb_msg: 'NdArrayProto', value: 'ArrayType', ndarray_type: Optional[str] = None
49+
):
50+
if ndarray_type == 'list':
51+
value = to_numpy_array(value).tolist()
52+
elif ndarray_type == 'numpy':
53+
value = to_numpy_array(value)
54+
4855
framework, is_sparse = get_array_type(value)
4956

5057
if framework == 'docarray':

docs/fundamentals/document/serialization.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ mime_type: "image/jpeg"
240240

241241
One can refer to the [Protobuf specification of `Document`](../../proto/index.md) for details.
242242

243+
When `.tensor` or `.embedding` contains a framework-specific ndarray-like object, you can use `.to_protobuf(..., ndarray_type='numpy')` or `.to_protobuf(..., ndarray_type='list')` to cast it into a `numpy.ndarray` or a `list`, respectively. This helps ensure maximum compatibility between different microservices.
243244

244245
## What's next?
245246

tests/unit/math/test_ndarray.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import numpy as np
12
import paddle
23
import pytest
34
import tensorflow as tf
45
import torch
56
from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix, issparse
67

78
from docarray.math.ndarray import get_array_rows
9+
from docarray.proto.docarray_pb2 import NdArrayProto
10+
from docarray.proto.io import flush_ndarray, read_ndarray
811

912

1013
@pytest.mark.parametrize(
@@ -30,7 +33,8 @@
3033
csc_matrix,
3134
],
3235
)
33-
def test_get_array_rows(data, expected_result, arraytype):
36+
@pytest.mark.parametrize('ndarray_type', ['list', 'numpy'])
37+
def test_get_array_rows(data, expected_result, arraytype, ndarray_type):
3438
data_array = arraytype(data)
3539

3640
num_rows, ndim = get_array_rows(data_array)
@@ -39,3 +43,11 @@ def test_get_array_rows(data, expected_result, arraytype):
3943
assert expected_result[0] == num_rows
4044
else:
4145
assert expected_result == (num_rows, ndim)
46+
47+
na_proto = NdArrayProto()
48+
flush_ndarray(na_proto, value=data_array, ndarray_type=ndarray_type)
49+
r_data_array = read_ndarray(na_proto)
50+
if ndarray_type == 'list':
51+
assert isinstance(r_data_array, list)
52+
elif ndarray_type == 'numpy':
53+
assert isinstance(r_data_array, np.ndarray)

0 commit comments

Comments
 (0)