4040if TYPE_CHECKING :
4141 import pandas as pd
4242
43- from docarray import DocList
4443 from docarray .proto import DocListProto
4544
4645T = TypeVar ('T' , bound = 'IOMixinArray' )
@@ -332,11 +331,11 @@ def to_json(self) -> bytes:
332331
333332 @classmethod
334333 def from_csv (
335- cls ,
334+ cls : Type [ 'T' ] ,
336335 file_path : str ,
337336 encoding : str = 'utf-8' ,
338337 dialect : Union [str , csv .Dialect ] = 'excel' ,
339- ) -> 'DocList ' :
338+ ) -> 'T ' :
340339 """
341340 Load a DocList from a csv file following the schema defined in the
342341 [`.doc_type`][docarray.DocList] attribute.
@@ -358,10 +357,10 @@ def from_csv(
358357
359358 :return: `DocList` object
360359 """
361- if cls .doc_type == AnyDoc :
360+ if cls .doc_type == AnyDoc or cls . doc_type == BaseDoc :
362361 raise TypeError (
363362 'There is no document schema defined. '
364- 'Please specify the DocList \' s Document type using `DocList [MyDoc]`.'
363+ f 'Please specify the { cls } \' s Document type using `{ cls } [MyDoc]`.'
365364 )
366365
367366 if file_path .startswith ('http' ):
@@ -376,14 +375,15 @@ def from_csv(
376375
377376 @classmethod
378377 def _from_csv_file (
379- cls , file : Union [StringIO , TextIOWrapper ], dialect : Union [str , csv .Dialect ]
380- ) -> 'DocList' :
381- from docarray import DocList
378+ cls : Type ['T' ],
379+ file : Union [StringIO , TextIOWrapper ],
380+ dialect : Union [str , csv .Dialect ],
381+ ) -> 'T' :
382382
383383 rows = csv .DictReader (file , dialect = dialect )
384384
385385 doc_type = cls .doc_type
386- docs = DocList . __class_getitem__ ( doc_type )()
386+ docs = []
387387
388388 field_names : List [str ] = (
389389 [] if rows .fieldnames is None else [str (f ) for f in rows .fieldnames ]
@@ -405,7 +405,7 @@ def _from_csv_file(
405405 doc_dict : Dict [Any , Any ] = _access_path_dict_to_nested_dict (access_path2val )
406406 docs .append (doc_type .parse_obj (doc_dict ))
407407
408- return docs
408+ return cls ( docs )
409409
410410 def to_csv (
411411 self , file_path : str , dialect : Union [str , csv .Dialect ] = 'excel'
@@ -426,11 +426,11 @@ def to_csv(
426426 `'unix'` (for csv file generated on UNIX systems).
427427
428428 """
429- if self .doc_type == AnyDoc :
429+ if self .doc_type == AnyDoc or self . doc_type == BaseDoc :
430430 raise TypeError (
431- 'DocList must be homogeneous to be converted to a csv.'
431+ f' { type ( self ) } must be homogeneous to be converted to a csv.'
432432 'There is no document schema defined. '
433- 'Please specify the DocList \' s Document type using `DocList [MyDoc]`.'
433+ f 'Please specify the { type ( self ) } \' s Document type using `{ type ( self ) } [MyDoc]`.'
434434 )
435435 fields = self .doc_type ._get_access_paths ()
436436
@@ -443,7 +443,7 @@ def to_csv(
443443 writer .writerow (doc_dict )
444444
445445 @classmethod
446- def from_dataframe (cls , df : 'pd.DataFrame' ) -> 'DocList ' :
446+ def from_dataframe (cls : Type [ 'T' ] , df : 'pd.DataFrame' ) -> 'T ' :
447447 """
448448 Load a `DocList` from a `pandas.DataFrame` following the schema
449449 defined in the [`.doc_type`][docarray.DocList] attribute.
@@ -486,10 +486,10 @@ class Person(BaseDoc):
486486 """
487487 from docarray import DocList
488488
489- if cls .doc_type == AnyDoc :
489+ if cls .doc_type == AnyDoc or cls . doc_type == BaseDoc :
490490 raise TypeError (
491491 'There is no document schema defined. '
492- 'Please specify the DocList \' s Document type using `DocList [MyDoc]`.'
492+ f 'Please specify the { cls } \' s Document type using `{ cls } [MyDoc]`.'
493493 )
494494
495495 doc_type = cls .doc_type
@@ -515,6 +515,8 @@ class Person(BaseDoc):
515515 doc_dict = _access_path_dict_to_nested_dict (access_path2val )
516516 docs .append (doc_type .parse_obj (doc_dict ))
517517
518+ if not isinstance (docs , cls ):
519+ return cls (docs )
518520 return docs
519521
520522 def to_dataframe (self ) -> 'pd.DataFrame' :
@@ -563,6 +565,11 @@ def _stream_header(self) -> bytes:
563565 num_docs_as_bytes = len (self ).to_bytes (8 , 'big' , signed = False )
564566 return version_byte + num_docs_as_bytes
565567
568+ @classmethod
569+ @abstractmethod
570+ def _get_proto_class (cls : Type [T ]):
571+ ...
572+
566573 @classmethod
567574 def _load_binary_all (
568575 cls : Type [T ],
@@ -593,12 +600,10 @@ def _load_binary_all(
593600 compress = None
594601
595602 if protocol is not None and protocol == 'protobuf-array' :
596- from docarray .proto import DocListProto
597-
598- dap = DocListProto ()
599- dap .ParseFromString (d )
603+ proto = cls ._get_proto_class ()()
604+ proto .ParseFromString (d )
600605
601- return cls .from_protobuf (dap )
606+ return cls .from_protobuf (proto )
602607 elif protocol is not None and protocol == 'pickle-array' :
603608 return pickle .loads (d )
604609
0 commit comments