1414from docarray .typing import ID
1515
1616if TYPE_CHECKING :
17- # import colorsys
18- # from typing import Any, Optional, TypeVar
19- # import numpy as np
20- # from rich.color import Color
2117 from rich .console import Console , ConsoleOptions , RenderResult
2218 from rich .measure import Measurement
2319
24- # from rich.segment import Segment
25- # from rich.style import Style
26- # from rich.tree import Tree
27- #
28- # import docarray
29- # from docarray.math.helper import minmax_normalize
30- # from docarray.typing import ID
31-
3220
3321class BaseDocument (BaseModel , ProtoMixin , AbstractDocument , BaseNode ):
3422 """
@@ -54,49 +42,62 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']:
5442 """
5543 return cls .__fields__ [field ].outer_type_
5644
57- def _ipython_display_ (self ):
58- """Displays the object in IPython as a side effect"""
59- self .summary ()
60-
6145 def summary (self ) -> None :
6246 """Print non-empty fields and nested structure of this Document object."""
63- from rich import print
47+ import rich
6448
6549 t = _plot_recursion (node = self )
66- print (t )
50+ rich . print (t )
6751
68- def schema_summary (self ) -> None :
69- from rich import print
70- from rich .panel import Panel
52+ @classmethod
53+ def schema_summary (cls ) -> None :
54+ """Print a summary of the Documents schema."""
55+ import rich
7156
72- panel = Panel (
73- self .get_schema (), title = 'Document Schema' , expand = False , padding = (1 , 3 )
57+ panel = rich . panel . Panel (
58+ cls .get_schema (), title = 'Document Schema' , expand = False , padding = (1 , 3 )
7459 )
75- print (panel )
60+ rich .print (panel )
61+
62+ def _ipython_display_ (self ):
63+ """Displays the object in IPython as a side effect"""
64+ self .summary ()
7665
7766 @classmethod
78- def get_schema (cls , doc_name : str = None ) -> Tree :
67+ def get_schema (cls , doc_name : Optional [ str ] = None ) -> Tree :
7968 import re
8069
8170 from rich .tree import Tree
8271
83- n = cls .__name__
72+ import docarray
73+
74+ name = cls .__name__
75+ tree = Tree (name ) if doc_name is None else Tree (f'{ doc_name } : { name } ' )
76+
77+ for k , v in cls .__annotations__ .items ():
78+
79+ field_type = cls ._get_field_type (k )
8480
85- tree = Tree (n ) if doc_name is None else Tree (f'{ doc_name } : { n } ' )
86- annotations = cls .__annotations__
87- for k , v in annotations .items ():
88- x = cls ._get_field_type (k )
8981 t = str (v ).replace ('[' , '\[' )
9082 t = re .sub ('[a-zA-Z_]*[.]' , '' , t )
9183
92- if str (v ).startswith ('typing.Union' ):
84+ if str (v ).startswith ('typing.Union' ) or str (v ).startswith (
85+ 'typing.Optional'
86+ ):
9387 sub_tree = Tree (f'{ k } : { t } ' )
9488 for arg in v .__args__ :
9589 if issubclass (arg , BaseDocument ):
9690 sub_tree .add (arg .get_schema ())
91+ elif issubclass (arg , docarray .DocumentArray ):
92+ sub_tree .add (arg .document_type .get_schema ())
93+ tree .add (sub_tree )
94+ elif issubclass (field_type , BaseDocument ):
95+ tree .add (field_type .get_schema (doc_name = k ))
96+ elif issubclass (field_type , docarray .DocumentArray ):
97+ name = v .__name__ .replace ('[' , '\[' )
98+ sub_tree = Tree (f'{ k } : { name } ' )
99+ sub_tree .add (field_type .document_type .get_schema ())
97100 tree .add (sub_tree )
98- elif issubclass (x , BaseDocument ):
99- tree .add (x .get_schema (doc_name = k ))
100101 else :
101102 tree .add (f'{ k } : { t } ' )
102103 return tree
@@ -106,92 +107,81 @@ def __rich_console__(self, console, options):
106107 id_abbrv = getattr (self , 'id' )[:7 ]
107108 yield f":page_facing_up: [b]{ kls } " f"[/b]: [cyan]{ id_abbrv } ...[cyan]"
108109
109- from collections .abc import Iterable
110-
111110 import torch
112111 from rich import box , text
113112 from rich .table import Table
114113
115- my_table = Table (
116- 'Attribute' , 'Value' , width = 80 , box = box .ROUNDED , highlight = True
117- )
114+ import docarray
118115
119- for k , v in self .__dict__ .items ():
120- col_1 , col_2 = '' , ''
116+ table = Table ('Attribute' , 'Value' , width = 80 , box = box .ROUNDED , highlight = True )
121117
122- if isinstance (v , ID ) or k .startswith ('_' ) or v is None :
118+ for k , v in self .__dict__ .items ():
119+ col_1 = f'{ k } : { v .__class__ .__name__ } '
120+ if (
121+ isinstance (v , ID | docarray .DocumentArray | docarray .BaseDocument )
122+ or k .startswith ('_' )
123+ or v is None
124+ ):
123125 continue
124126 elif isinstance (v , str ):
125- col_1 = f'{ k } : { v .__class__ .__name__ } '
126127 col_2 = str (v )[:50 ]
127128 if len (v ) > 50 :
128129 col_2 += f' ... (length: { len (v )} )'
129- elif isinstance (v , np .ndarray ) or isinstance (v , torch .Tensor ):
130- col_1 = f'{ k } : { v .__class__ .__name__ } '
131-
130+ table .add_row (col_1 , text .Text (col_2 ))
131+ elif isinstance (v , np .ndarray | torch .Tensor ):
132132 if isinstance (v , torch .Tensor ):
133133 v = v .detach ().cpu ().numpy ()
134- if v .squeeze ().ndim == 1 and len (v ) < 1000 :
135- col_2 = ColorBoxArray (v .squeeze ())
134+ if v .squeeze ().ndim == 1 and len (v ) < 50 :
135+ table . add_row ( col_1 , ColorBoxArray (v .squeeze () ))
136136 else :
137- col_2 = f'{ type (v )} of shape { v .shape } , dtype: { v .dtype } '
138-
139- elif isinstance (v , tuple ) or isinstance (v , list ):
140- col_1 = f'{ k } : { v .__class__ .__name__ } '
137+ table .add_row (
138+ col_1 ,
139+ text .Text (f'{ type (v )} of shape { v .shape } , dtype: { v .dtype } ' ),
140+ )
141+ elif isinstance (v , tuple | list ):
142+ col_2 = ''
141143 for i , x in enumerate (v ):
142144 if len (col_2 ) + len (str (x )) < 50 :
143145 col_2 = str (v [:i ])
144146 else :
145147 col_2 = f'{ col_2 [:- 1 ]} , ...] (length: { len (v )} )'
146148 break
147- elif not isinstance (v , Iterable ):
148- col_1 = f'{ k } : { v .__class__ .__name__ } '
149- col_2 = str (v )
150- else :
151- continue
149+ table .add_row (col_1 , text .Text (col_2 ))
152150
153- if not isinstance (col_2 , ColorBoxArray ):
154- col_2 = text .Text (col_2 )
155- my_table .add_row (col_1 , col_2 )
156-
157- if my_table .rows :
158- yield my_table
151+ if table .rows :
152+ yield table
159153
160154
161155def _plot_recursion (node : Any , tree : Optional [Tree ] = None ) -> Tree :
162156 import docarray
163157
164158 tree = Tree (node ) if tree is None else tree .add (node )
165159
166- try :
160+ if hasattr ( node , '__dict__' ) :
167161 iterable_attrs = [
168162 k
169163 for k , v in node .__dict__ .items ()
170- if isinstance (v , docarray .DocumentArray )
171- or isinstance (v , docarray .BaseDocument )
164+ if isinstance (v , docarray .DocumentArray | docarray .BaseDocument )
172165 ]
173166 for attr in iterable_attrs :
174- _icon = ':diamond_with_a_dot:'
175167 value = getattr (node , attr )
168+ attr_type = value .__class__ .__name__
169+ icon = ':diamond_with_a_dot:'
170+
176171 if isinstance (value , docarray .BaseDocument ):
177- _icon = ':large_orange_diamond:'
178- _match_tree = tree .add (
179- f'{ _icon } [b]{ attr } : ' f'{ value .__class__ .__name__ } [/b]'
180- )
181- if isinstance (value , docarray .BaseDocument ):
172+ icon = ':large_orange_diamond:'
182173 value = [value ]
174+
175+ _match_tree = tree .add (f'{ icon } [b]{ attr } : ' f'{ attr_type } [/b]' )
183176 for i , d in enumerate (value ):
184177 if i == 2 :
178+ doc_cls = d .__class__ .__name__
185179 _plot_recursion (
186- f'... { len (value ) - 2 } more { d .__class__ .__name__ } documents\n ' ,
187- _match_tree ,
180+ f'... { len (value ) - 2 } more { doc_cls } documents\n ' , _match_tree
188181 )
189182 break
190183 _plot_recursion (d , _match_tree )
191184
192- except Exception :
193- pass
194-
195185 return tree
196186
197187
0 commit comments