Skip to content

Commit 15b6e0a

Browse files
author
anna-charlotte
committed
fix: wip doc summary
Signed-off-by: anna-charlotte <[email protected]>
1 parent 1a6e1d3 commit 15b6e0a

File tree

1 file changed

+66
-76
lines changed

1 file changed

+66
-76
lines changed

docarray/base_document/document.py

Lines changed: 66 additions & 76 deletions
Original file line number | Diff line number | Diff line change
@@ -14,21 +14,9 @@
1414
from docarray.typing import ID
1515

1616
if TYPE_CHECKING:
17-
# import colorsys
18-
# from typing import Any, Optional, TypeVar
19-
# import numpy as np
20-
# from rich.color import Color
2117
from rich.console import Console, ConsoleOptions, RenderResult
2218
from rich.measure import Measurement
2319

24-
# from rich.segment import Segment
25-
# from rich.style import Style
26-
# from rich.tree import Tree
27-
#
28-
# import docarray
29-
# from docarray.math.helper import minmax_normalize
30-
# from docarray.typing import ID
31-
3220

3321
class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode):
3422
"""
@@ -54,49 +42,62 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']:
5442
"""
5543
return cls.__fields__[field].outer_type_
5644

57-
def _ipython_display_(self):
58-
"""Displays the object in IPython as a side effect"""
59-
self.summary()
60-
6145
def summary(self) -> None:
6246
"""Print non-empty fields and nested structure of this Document object."""
63-
from rich import print
47+
import rich
6448

6549
t = _plot_recursion(node=self)
66-
print(t)
50+
rich.print(t)
6751

68-
def schema_summary(self) -> None:
69-
from rich import print
70-
from rich.panel import Panel
52+
@classmethod
53+
def schema_summary(cls) -> None:
54+
"""Print a summary of the Documents schema."""
55+
import rich
7156

72-
panel = Panel(
73-
self.get_schema(), title='Document Schema', expand=False, padding=(1, 3)
57+
panel = rich.panel.Panel(
58+
cls.get_schema(), title='Document Schema', expand=False, padding=(1, 3)
7459
)
75-
print(panel)
60+
rich.print(panel)
61+
62+
def _ipython_display_(self):
63+
"""Displays the object in IPython as a side effect"""
64+
self.summary()
7665

7766
@classmethod
78-
def get_schema(cls, doc_name: str = None) -> Tree:
67+
def get_schema(cls, doc_name: Optional[str] = None) -> Tree:
7968
import re
8069

8170
from rich.tree import Tree
8271

83-
n = cls.__name__
72+
import docarray
73+
74+
name = cls.__name__
75+
tree = Tree(name) if doc_name is None else Tree(f'{doc_name}: {name}')
76+
77+
for k, v in cls.__annotations__.items():
78+
79+
field_type = cls._get_field_type(k)
8480

85-
tree = Tree(n) if doc_name is None else Tree(f'{doc_name}: {n}')
86-
annotations = cls.__annotations__
87-
for k, v in annotations.items():
88-
x = cls._get_field_type(k)
8981
t = str(v).replace('[', '\[')
9082
t = re.sub('[a-zA-Z_]*[.]', '', t)
9183

92-
if str(v).startswith('typing.Union'):
84+
if str(v).startswith('typing.Union') or str(v).startswith(
85+
'typing.Optional'
86+
):
9387
sub_tree = Tree(f'{k}: {t}')
9488
for arg in v.__args__:
9589
if issubclass(arg, BaseDocument):
9690
sub_tree.add(arg.get_schema())
91+
elif issubclass(arg, docarray.DocumentArray):
92+
sub_tree.add(arg.document_type.get_schema())
93+
tree.add(sub_tree)
94+
elif issubclass(field_type, BaseDocument):
95+
tree.add(field_type.get_schema(doc_name=k))
96+
elif issubclass(field_type, docarray.DocumentArray):
97+
name = v.__name__.replace('[', '\[')
98+
sub_tree = Tree(f'{k}: {name}')
99+
sub_tree.add(field_type.document_type.get_schema())
97100
tree.add(sub_tree)
98-
elif issubclass(x, BaseDocument):
99-
tree.add(x.get_schema(doc_name=k))
100101
else:
101102
tree.add(f'{k}: {t}')
102103
return tree
@@ -106,92 +107,81 @@ def __rich_console__(self, console, options):
106107
id_abbrv = getattr(self, 'id')[:7]
107108
yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]"
108109

109-
from collections.abc import Iterable
110-
111110
import torch
112111
from rich import box, text
113112
from rich.table import Table
114113

115-
my_table = Table(
116-
'Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True
117-
)
114+
import docarray
118115

119-
for k, v in self.__dict__.items():
120-
col_1, col_2 = '', ''
116+
table = Table('Attribute', 'Value', width=80, box=box.ROUNDED, highlight=True)
121117

122-
if isinstance(v, ID) or k.startswith('_') or v is None:
118+
for k, v in self.__dict__.items():
119+
col_1 = f'{k}: {v.__class__.__name__}'
120+
if (
121+
isinstance(v, ID | docarray.DocumentArray | docarray.BaseDocument)
122+
or k.startswith('_')
123+
or v is None
124+
):
123125
continue
124126
elif isinstance(v, str):
125-
col_1 = f'{k}: {v.__class__.__name__}'
126127
col_2 = str(v)[:50]
127128
if len(v) > 50:
128129
col_2 += f' ... (length: {len(v)})'
129-
elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor):
130-
col_1 = f'{k}: {v.__class__.__name__}'
131-
130+
table.add_row(col_1, text.Text(col_2))
131+
elif isinstance(v, np.ndarray | torch.Tensor):
132132
if isinstance(v, torch.Tensor):
133133
v = v.detach().cpu().numpy()
134-
if v.squeeze().ndim == 1 and len(v) < 1000:
135-
col_2 = ColorBoxArray(v.squeeze())
134+
if v.squeeze().ndim == 1 and len(v) < 50:
135+
table.add_row(col_1, ColorBoxArray(v.squeeze()))
136136
else:
137-
col_2 = f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'
138-
139-
elif isinstance(v, tuple) or isinstance(v, list):
140-
col_1 = f'{k}: {v.__class__.__name__}'
137+
table.add_row(
138+
col_1,
139+
text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'),
140+
)
141+
elif isinstance(v, tuple | list):
142+
col_2 = ''
141143
for i, x in enumerate(v):
142144
if len(col_2) + len(str(x)) < 50:
143145
col_2 = str(v[:i])
144146
else:
145147
col_2 = f'{col_2[:-1]}, ...] (length: {len(v)})'
146148
break
147-
elif not isinstance(v, Iterable):
148-
col_1 = f'{k}: {v.__class__.__name__}'
149-
col_2 = str(v)
150-
else:
151-
continue
149+
table.add_row(col_1, text.Text(col_2))
152150

153-
if not isinstance(col_2, ColorBoxArray):
154-
col_2 = text.Text(col_2)
155-
my_table.add_row(col_1, col_2)
156-
157-
if my_table.rows:
158-
yield my_table
151+
if table.rows:
152+
yield table
159153

160154

161155
def _plot_recursion(node: Any, tree: Optional[Tree] = None) -> Tree:
162156
import docarray
163157

164158
tree = Tree(node) if tree is None else tree.add(node)
165159

166-
try:
160+
if hasattr(node, '__dict__'):
167161
iterable_attrs = [
168162
k
169163
for k, v in node.__dict__.items()
170-
if isinstance(v, docarray.DocumentArray)
171-
or isinstance(v, docarray.BaseDocument)
164+
if isinstance(v, docarray.DocumentArray | docarray.BaseDocument)
172165
]
173166
for attr in iterable_attrs:
174-
_icon = ':diamond_with_a_dot:'
175167
value = getattr(node, attr)
168+
attr_type = value.__class__.__name__
169+
icon = ':diamond_with_a_dot:'
170+
176171
if isinstance(value, docarray.BaseDocument):
177-
_icon = ':large_orange_diamond:'
178-
_match_tree = tree.add(
179-
f'{_icon} [b]{attr}: ' f'{value.__class__.__name__}[/b]'
180-
)
181-
if isinstance(value, docarray.BaseDocument):
172+
icon = ':large_orange_diamond:'
182173
value = [value]
174+
175+
_match_tree = tree.add(f'{icon} [b]{attr}: ' f'{attr_type}[/b]')
183176
for i, d in enumerate(value):
184177
if i == 2:
178+
doc_cls = d.__class__.__name__
185179
_plot_recursion(
186-
f'... {len(value) - 2} more {d.__class__.__name__} documents\n',
187-
_match_tree,
180+
f'... {len(value) - 2} more {doc_cls} documents\n', _match_tree
188181
)
189182
break
190183
_plot_recursion(d, _match_tree)
191184

192-
except Exception:
193-
pass
194-
195185
return tree
196186

197187

0 commit comments

Comments
 (0)