Skip to content

Commit 8f25887

Browse files
feat: i/o for DocVec (#1562)
Signed-off-by: Johannes Messner <[email protected]>
1 parent bcb60ca commit 8f25887

File tree

13 files changed

+571
-72
lines changed

13 files changed

+571
-72
lines changed

docarray/array/doc_list/doc_list.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T:
306306
"""
307307
return super().from_protobuf(pb_msg)
308308

309+
@classmethod
310+
def _get_proto_class(cls: Type[T]):
311+
from docarray.proto import DocListProto
312+
313+
return DocListProto
314+
309315
@overload
310316
def __getitem__(self, item: SupportsIndex) -> T_doc:
311317
...

docarray/array/doc_list/io.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
if TYPE_CHECKING:
4141
import pandas as pd
4242

43-
from docarray import DocList
4443
from docarray.proto import DocListProto
4544

4645
T = TypeVar('T', bound='IOMixinArray')
@@ -332,11 +331,11 @@ def to_json(self) -> bytes:
332331

333332
@classmethod
334333
def from_csv(
335-
cls,
334+
cls: Type['T'],
336335
file_path: str,
337336
encoding: str = 'utf-8',
338337
dialect: Union[str, csv.Dialect] = 'excel',
339-
) -> 'DocList':
338+
) -> 'T':
340339
"""
341340
Load a DocList from a csv file following the schema defined in the
342341
[`.doc_type`][docarray.DocList] attribute.
@@ -358,10 +357,10 @@ def from_csv(
358357
359358
:return: `DocList` object
360359
"""
361-
if cls.doc_type == AnyDoc:
360+
if cls.doc_type == AnyDoc or cls.doc_type == BaseDoc:
362361
raise TypeError(
363362
'There is no document schema defined. '
364-
'Please specify the DocList\'s Document type using `DocList[MyDoc]`.'
363+
f'Please specify the {cls}\'s Document type using `{cls}[MyDoc]`.'
365364
)
366365

367366
if file_path.startswith('http'):
@@ -376,14 +375,15 @@ def from_csv(
376375

377376
@classmethod
378377
def _from_csv_file(
379-
cls, file: Union[StringIO, TextIOWrapper], dialect: Union[str, csv.Dialect]
380-
) -> 'DocList':
381-
from docarray import DocList
378+
cls: Type['T'],
379+
file: Union[StringIO, TextIOWrapper],
380+
dialect: Union[str, csv.Dialect],
381+
) -> 'T':
382382

383383
rows = csv.DictReader(file, dialect=dialect)
384384

385385
doc_type = cls.doc_type
386-
docs = DocList.__class_getitem__(doc_type)()
386+
docs = []
387387

388388
field_names: List[str] = (
389389
[] if rows.fieldnames is None else [str(f) for f in rows.fieldnames]
@@ -405,7 +405,7 @@ def _from_csv_file(
405405
doc_dict: Dict[Any, Any] = _access_path_dict_to_nested_dict(access_path2val)
406406
docs.append(doc_type.parse_obj(doc_dict))
407407

408-
return docs
408+
return cls(docs)
409409

410410
def to_csv(
411411
self, file_path: str, dialect: Union[str, csv.Dialect] = 'excel'
@@ -426,11 +426,11 @@ def to_csv(
426426
`'unix'` (for csv file generated on UNIX systems).
427427
428428
"""
429-
if self.doc_type == AnyDoc:
429+
if self.doc_type == AnyDoc or self.doc_type == BaseDoc:
430430
raise TypeError(
431-
'DocList must be homogeneous to be converted to a csv.'
431+
f'{type(self)} must be homogeneous to be converted to a csv.'
432432
'There is no document schema defined. '
433-
'Please specify the DocList\'s Document type using `DocList[MyDoc]`.'
433+
f'Please specify the {type(self)}\'s Document type using `{type(self)}[MyDoc]`.'
434434
)
435435
fields = self.doc_type._get_access_paths()
436436

@@ -443,7 +443,7 @@ def to_csv(
443443
writer.writerow(doc_dict)
444444

445445
@classmethod
446-
def from_dataframe(cls, df: 'pd.DataFrame') -> 'DocList':
446+
def from_dataframe(cls: Type['T'], df: 'pd.DataFrame') -> 'T':
447447
"""
448448
Load a `DocList` from a `pandas.DataFrame` following the schema
449449
defined in the [`.doc_type`][docarray.DocList] attribute.
@@ -486,10 +486,10 @@ class Person(BaseDoc):
486486
"""
487487
from docarray import DocList
488488

489-
if cls.doc_type == AnyDoc:
489+
if cls.doc_type == AnyDoc or cls.doc_type == BaseDoc:
490490
raise TypeError(
491491
'There is no document schema defined. '
492-
'Please specify the DocList\'s Document type using `DocList[MyDoc]`.'
492+
f'Please specify the {cls}\'s Document type using `{cls}[MyDoc]`.'
493493
)
494494

495495
doc_type = cls.doc_type
@@ -515,6 +515,8 @@ class Person(BaseDoc):
515515
doc_dict = _access_path_dict_to_nested_dict(access_path2val)
516516
docs.append(doc_type.parse_obj(doc_dict))
517517

518+
if not isinstance(docs, cls):
519+
return cls(docs)
518520
return docs
519521

520522
def to_dataframe(self) -> 'pd.DataFrame':
@@ -563,6 +565,11 @@ def _stream_header(self) -> bytes:
563565
num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False)
564566
return version_byte + num_docs_as_bytes
565567

568+
@classmethod
569+
@abstractmethod
570+
def _get_proto_class(cls: Type[T]):
571+
...
572+
566573
@classmethod
567574
def _load_binary_all(
568575
cls: Type[T],
@@ -593,12 +600,10 @@ def _load_binary_all(
593600
compress = None
594601

595602
if protocol is not None and protocol == 'protobuf-array':
596-
from docarray.proto import DocListProto
597-
598-
dap = DocListProto()
599-
dap.ParseFromString(d)
603+
proto = cls._get_proto_class()()
604+
proto.ParseFromString(d)
600605

601-
return cls.from_protobuf(dap)
606+
return cls.from_protobuf(proto)
602607
elif protocol is not None and protocol == 'pickle-array':
603608
return pickle.loads(d)
604609

docarray/array/doc_vec/column_storage.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ItemsView,
77
Iterable,
88
MutableMapping,
9+
NamedTuple,
910
Optional,
1011
Type,
1112
TypeVar,
@@ -26,6 +27,13 @@
2627
T = TypeVar('T', bound='ColumnStorage')
2728

2829

30+
class ColumnsJsonCompatible(NamedTuple):
31+
tensor_columns: Dict[str, Any]
32+
doc_columns: Dict[str, Any]
33+
docs_vec_columns: Dict[str, Any]
34+
any_columns: Dict[str, Any]
35+
36+
2937
class ColumnStorage:
3038
"""
3139
ColumnStorage is a container to store the columns of the
@@ -91,6 +99,25 @@ def __getitem__(self: T, item: IndexIterType) -> T:
9199
self.tensor_type,
92100
)
93101

102+
def columns_json_compatible(self) -> ColumnsJsonCompatible:
103+
tens_cols = {
104+
key: value._docarray_to_json_compatible() if value is not None else value
105+
for key, value in self.tensor_columns.items()
106+
}
107+
doc_cols = {
108+
key: value._docarray_to_json_compatible() if value is not None else value
109+
for key, value in self.doc_columns.items()
110+
}
111+
doc_vec_cols = {
112+
key: [vec._docarray_to_json_compatible() for vec in value]
113+
if value is not None
114+
else value
115+
for key, value in self.docs_vec_columns.items()
116+
}
117+
return ColumnsJsonCompatible(
118+
tens_cols, doc_cols, doc_vec_cols, self.any_columns
119+
)
120+
94121
def __eq__(self, other: Any) -> bool:
95122
if not isinstance(other, ColumnStorage):
96123
return False
@@ -146,6 +173,11 @@ def __getitem__(self, name: str) -> Any:
146173
return None
147174
return col[self.index]
148175

176+
def __reduce__(self):
177+
# implementing __reduce__ to solve a pickle issue when subclassing dict
178+
# see here: https://stackoverflow.com/questions/21144845/how-can-i-unpickle-a-subclass-of-dict-that-validates-with-setitem-in-pytho
179+
return (ColumnStorageView, (self.index, self.storage))
180+
149181
def __setitem__(self, name, value) -> None:
150182
if self.storage.columns[name] is None:
151183
raise ValueError(

0 commit comments

Comments (0)