
Commit ac28cf3

Authored by Charlotte Gerhaher
feat(v2): rich display for doc and da (#1043)
* feat: add rich display for doc and da
* fix: wip plot
* fix: wip plot
* fix: wip plot
* feat: add math package and minmax normalize
* fix: summary for document
* chore: update poetry lock after rebase
* fix: move all from plotmixin to base document
* feat: add docs schema summary
* feat: add document array summary
* fix: display doc within doc
* fix: in notebook print docs summary
* fix: move summary from da to abstract da
* fix: get schema for doc
* fix: wip doc summary
* fix: wip clean up
* test: add test for da pretty print
* docs: update note
* docs: add some documentation
* fix: apply samis suggestion
* fix: mypy checks
* fix: move to plot mixin
* fix: remove redundant line
* fix: remove comments
* feat: add schema highlighter
* fix: add plotmixin to mixin init
* fix: adjust da summary
* fix: move minmaxnormalize to comp backend
* fix: remove redundant lines
* fix: add squeeze and detach to comp backend
* fix: apply suggestion from code review
* refactor: rename iterable attrs
* fix: clean up
* fix: import
* fix: iterate over fields instead of annotations
* fix: remove math package since moved to comp backends
* refactor: use single quotes
* fix: apply suggestions from code review
* fix: extract summary to doc summary class
* fix: add pretty print for base document
* fix: use rich capture instead of string io
* fix: add colors for optional and union and use only single quotes
* fix: extract display classes to display package
* fix: make da not optional in da summary
* fix: set _console instead of initializing new one everytime in __str__
* fix: put console at module level

Signed-off-by: anna-charlotte <[email protected]>
1 parent ec241e0 commit ac28cf3
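
The commit message above describes the new rich display entry points for documents and document arrays. As a rough illustration of the resulting user-facing API, a minimal sketch follows; the top-level import paths, the example classes, and the constructor call are assumptions about this v2 branch, not part of the diff itself:

from docarray import BaseDocument, DocumentArray  # assumed top-level exports on this branch


class MyDoc(BaseDocument):
    title: str
    text: str


doc = MyDoc(title='hello', text='world')
doc.summary()           # rich panel listing the non-empty fields of this document
MyDoc.schema_summary()  # rich view of the MyDoc schema

da = DocumentArray[MyDoc]([doc, doc])
da.summary()            # DocumentArray panel (type, length) plus the schema summary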

File tree

16 files changed: +589 -6 lines changed


docarray/array/abstract_array.py

Lines changed: 11 additions & 0 deletions

@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, Any, Generic, List, Sequence, Type, TypeVar, Union

 from docarray.base_document import BaseDocument
+from docarray.display.document_array_summary import DocumentArraySummary
 from docarray.typing import NdArray
 from docarray.typing.abstract_type import AbstractType

@@ -17,6 +18,9 @@ class AnyDocumentArray(Sequence[BaseDocument], Generic[T_doc], AbstractType):
     document_type: Type[BaseDocument]
     tensor_type: Type['AbstractTensor'] = NdArray

+    def __repr__(self):
+        return f'<{self.__class__.__name__} (length={len(self)})>'
+
     def __class_getitem__(cls, item: Type[BaseDocument]):
         if not issubclass(item, BaseDocument):
             raise ValueError(
@@ -209,3 +213,10 @@ def _flatten_one_level(sequence: List[Any]) -> List[Any]:
             return sequence
         else:
             return [item for sublist in sequence for item in sublist]
+
+    def summary(self):
+        """
+        Print a summary of this DocumentArray object and a summary of the schema of its
+        Document type.
+        """
+        DocumentArraySummary(self).summary()
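
For context, the new `__repr__` and `summary()` behave roughly as sketched below; the import paths, the example class, and the exact class name shown in the repr are assumptions rather than output taken from this commit:

from docarray import BaseDocument, DocumentArray  # assumed import paths

class TitleDoc(BaseDocument):
    title: str

da = DocumentArray[TitleDoc]([TitleDoc(title=str(i)) for i in range(3)])
print(repr(da))  # something like '<DocumentArray[TitleDoc] (length=3)>'
da.summary()     # delegates to DocumentArraySummary(da).summary()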

docarray/base_document/document.py

Lines changed: 11 additions & 2 deletions

@@ -3,15 +3,18 @@

 import orjson
 from pydantic import BaseModel, Field, parse_obj_as
+from rich.console import Console

 from docarray.base_document.abstract_document import AbstractDocument
 from docarray.base_document.base_node import BaseNode
 from docarray.base_document.io.json import orjson_dumps, orjson_dumps_and_decode
-from docarray.base_document.mixins import ProtoMixin
+from docarray.base_document.mixins import PlotMixin, ProtoMixin
 from docarray.typing import ID

+_console: Console = Console()

-class BaseDocument(BaseModel, ProtoMixin, AbstractDocument, BaseNode):
+
+class BaseDocument(BaseModel, PlotMixin, ProtoMixin, AbstractDocument, BaseNode):
     """
     The base class for Document
     """
@@ -34,3 +37,9 @@ def _get_field_type(cls, field: str) -> Type['BaseDocument']:
         :return:
         """
         return cls.__fields__[field].outer_type_
+
+    def __str__(self):
+        with _console.capture() as capture:
+            _console.print(self)
+
+        return capture.get().strip()
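
The `__str__` added above relies on rich's console capture API. Below is a standalone sketch of that pattern, independent of docarray; the `Greeting` class is purely illustrative:

from rich.console import Console

_console = Console()  # module-level console, as in the diff above

class Greeting:
    def __rich__(self):
        # rich renders whatever __rich__ returns; console markup in a str is supported
        return '[bold]hello[/bold] world'

    def __str__(self):
        with _console.capture() as capture:  # temporarily redirect console output
            _console.print(self)
        return capture.get().strip()         # captured, rendered text

print(str(Greeting()))  # prints the rendered text, e.g. 'hello world'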

docarray/base_document/mixins/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
+from docarray.base_document.mixins.plot import PlotMixin
 from docarray.base_document.mixins.proto import ProtoMixin

-__all__ = ['ProtoMixin']
+__all__ = ['PlotMixin', 'ProtoMixin']

docarray/base_document/mixins/plot.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+from docarray.base_document.abstract_document import AbstractDocument
+from docarray.display.document_summary import DocumentSummary
+
+
+class PlotMixin(AbstractDocument):
+    def summary(self) -> None:
+        """Print non-empty fields and nested structure of this Document object."""
+        DocumentSummary(doc=self).summary()
+
+    @classmethod
+    def schema_summary(cls) -> None:
+        """Print a summary of the Documents schema."""
+        DocumentSummary.schema_summary(cls)
+
+    def _ipython_display_(self):
+        """Displays the object in IPython as a summary"""
+        self.summary()
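
Usage-wise, `PlotMixin` gives every `BaseDocument` subclass these entry points; a short sketch, where the `ImageDoc` class and the import path are assumptions:

from docarray import BaseDocument  # assumed import path

class ImageDoc(BaseDocument):
    url: str

ImageDoc.schema_summary()        # class-level: prints the schema of ImageDoc
ImageDoc(url='a.png').summary()  # instance-level: prints the non-empty fields
# In a Jupyter/IPython cell, evaluating an ImageDoc instance calls _ipython_display_
# and therefore shows the same summary automatically.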

docarray/computation/abstract_comp_backend.py

Lines changed: 46 additions & 0 deletions

@@ -37,6 +37,14 @@ def n_dim(array: 'TTensor') -> int:
         """
         ...

+    @staticmethod
+    @abstractmethod
+    def squeeze(tensor: 'TTensor') -> 'TTensor':
+        """
+        Returns a tensor with all the dimensions of tensor of size 1 removed.
+        """
+        ...
+
     @staticmethod
     @abstractmethod
     def to_numpy(array: 'TTensor') -> 'np.ndarray':
@@ -85,6 +93,44 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor':
         """
         ...

+    @staticmethod
+    @abstractmethod
+    def detach(tensor: 'TTensor') -> 'TTensor':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def minmax_normalize(
+        tensor: 'TTensor',
+        t_range: Tuple = (0, 1),
+        x_range: Optional[Tuple] = None,
+        eps: float = 1e-7,
+    ):
+        """
+        Normalize values in `tensor` into `t_range`.
+
+        `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
+        normalization is row-based.
+
+        .. note::
+            - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
+            - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
+              of the data to 0.
+
+        :param tensor: the data to be normalized
+        :param t_range: a tuple represents the target range.
+        :param x_range: a tuple represents tensors range.
+        :param eps: a small jitter to avoid divide by zero
+        :return: normalized data in `t_range`
+        """
+        ...
+
     class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
         """
         Abstract class for retrieval and ranking functionalities
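
A small worked example of the normalization formula used by the concrete backends below, r = (b - a) * (x - min) / (max - min + eps) + a, with the tiny eps term ignored for readability:

x = [2.0, 4.0, 6.0]            # min = 2, max = 6
# t_range = (0, 1):  r =  1 * (x - 2) / 4 + 0  ->  [0.0, 0.5, 1.0]  (min -> 0, max -> 1)
# t_range = (1, 0):  r = -1 * (x - 2) / 4 + 1  ->  [1.0, 0.5, 0.0]  (min -> 1, max -> 0)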

docarray/computation/numpy_backend.py

Lines changed: 49 additions & 0 deletions

@@ -49,6 +49,13 @@ def to_device(tensor: 'np.ndarray', device: str) -> 'np.ndarray':
     def n_dim(array: 'np.ndarray') -> int:
         return array.ndim

+    @staticmethod
+    def squeeze(tensor: 'np.ndarray') -> 'np.ndarray':
+        """
+        Returns a tensor with all the dimensions of tensor of size 1 removed.
+        """
+        return tensor.squeeze()
+
     @staticmethod
     def to_numpy(array: 'np.ndarray') -> 'np.ndarray':
         return array
@@ -85,6 +92,48 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray':
         """
         return array.reshape(shape)

+    @staticmethod
+    def detach(tensor: 'np.ndarray') -> 'np.ndarray':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        return tensor
+
+    @staticmethod
+    def minmax_normalize(
+        tensor: 'np.ndarray',
+        t_range: Tuple = (0, 1),
+        x_range: Optional[Tuple] = None,
+        eps: float = 1e-7,
+    ):
+        """
+        Normalize values in `tensor` into `t_range`.
+
+        `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
+        normalization is row-based.
+
+        .. note::
+            - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
+            - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
+              of the data to 0.
+
+        :param tensor: the data to be normalized
+        :param t_range: a tuple represents the target range.
+        :param x_range: a tuple represents tensors range.
+        :param eps: a small jitter to avoid divide by zero
+        :return: normalized data in `t_range`
+        """
+        a, b = t_range
+
+        min_d = x_range[0] if x_range else np.min(tensor, axis=-1, keepdims=True)
+        max_d = x_range[1] if x_range else np.max(tensor, axis=-1, keepdims=True)
+        r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a
+
+        return np.clip(r, *((a, b) if a < b else (b, a)))
+
     class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
         """
         Abstract class for retrieval and ranking functionalities
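
To check the row-wise behaviour without installing docarray, the following standalone NumPy function mirrors the implementation above; the free-standing function name and the test values are purely illustrative:

import numpy as np

def minmax_normalize(tensor, t_range=(0, 1), x_range=None, eps=1e-7):
    # same steps as the backend method above: per-row min/max, affine map, clip
    a, b = t_range
    min_d = x_range[0] if x_range else np.min(tensor, axis=-1, keepdims=True)
    max_d = x_range[1] if x_range else np.max(tensor, axis=-1, keepdims=True)
    r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a
    return np.clip(r, *((a, b) if a < b else (b, a)))

print(minmax_normalize(np.array([[1.0, 2.0, 3.0],
                                 [0.0, 0.0, 10.0]])))
# each row is normalized independently, roughly:
# [[0.  0.5 1. ]
#  [0.  0.  1. ]]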

docarray/computation/torch_backend.py

Lines changed: 54 additions & 0 deletions

@@ -63,6 +63,13 @@ def empty(
     def n_dim(array: 'torch.Tensor') -> int:
         return array.ndim

+    @staticmethod
+    def squeeze(tensor: 'torch.Tensor') -> 'torch.Tensor':
+        """
+        Returns a tensor with all the dimensions of tensor of size 1 removed.
+        """
+        return torch.squeeze(tensor)
+
     @staticmethod
     def to_numpy(array: 'torch.Tensor') -> 'np.ndarray':
         return array.cpu().detach().numpy()
@@ -89,6 +96,53 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor':
         """
         return tensor.reshape(shape)

+    @staticmethod
+    def detach(tensor: 'torch.Tensor') -> 'torch.Tensor':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        return tensor.detach()
+
+    @staticmethod
+    def minmax_normalize(
+        tensor: 'torch.Tensor',
+        t_range: Tuple = (0, 1),
+        x_range: Optional[Tuple] = None,
+        eps: float = 1e-7,
+    ):
+        """
+        Normalize values in `tensor` into `t_range`.
+
+        `tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
+        normalization is row-based.
+
+        .. note::
+            - with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
+            - with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
+              of the data to 0.
+
+        :param tensor: the data to be normalized
+        :param t_range: a tuple represents the target range.
+        :param x_range: a tuple represents tensors range.
+        :param eps: a small jitter to avoid divide by zero
+        :return: normalized data in `t_range`
+        """
+        a, b = t_range
+
+        min_d = (
+            x_range[0] if x_range else torch.min(tensor, dim=-1, keepdim=True).values
+        )
+        max_d = (
+            x_range[1] if x_range else torch.max(tensor, dim=-1, keepdim=True).values
+        )
+        r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a
+
+        normalized = torch.clip(r, *((a, b) if a < b else (b, a)))
+        return normalized.to(tensor.dtype)
+
     class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]):
         """
         Abstract class for retrieval and ranking functionalities
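
The torch variant additionally detaches via tensor.detach() and casts the result back to the input dtype. A standalone PyTorch sketch of the same steps, with illustrative values (assumes torch is installed):

import torch

t = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
print(t.detach().requires_grad)  # False -- what detach() returns: same data, no autograd graph

a, b, eps = 0, 1, 1e-7
min_d = torch.min(t, dim=-1, keepdim=True).values
max_d = torch.max(t, dim=-1, keepdim=True).values
r = (b - a) * (t - min_d) / (max_d - min_d + eps) + a
normalized = torch.clip(r, *((a, b) if a < b else (b, a))).to(t.dtype)
print(normalized)  # roughly tensor([[0.0000, 0.5000, 1.0000]], grad_fn=...)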

docarray/display/__init__.py

Whitespace-only changes.

docarray/display/document_array_summary.py

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from docarray.array.abstract_array import AnyDocumentArray
+
+
+class DocumentArraySummary:
+    def __init__(self, da: 'AnyDocumentArray'):
+        self.da = da
+
+    def summary(self) -> None:
+        """
+        Print a summary of this DocumentArray object and a summary of the schema of its
+        Document type.
+        """
+        from rich import box
+        from rich.console import Console
+        from rich.panel import Panel
+        from rich.table import Table
+
+        table = Table(box=box.SIMPLE, highlight=True)
+        table.show_header = False
+        table.add_row('Type', self.da.__class__.__name__)
+        table.add_row('Length', str(len(self.da)))
+
+        Console().print(Panel(table, title='DocumentArray Summary', expand=False))
+        self.da.document_type.schema_summary()
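
A standalone rich snippet reproducing the panel that DocumentArraySummary.summary() builds, runnable without docarray; the type name and length shown are hypothetical stand-ins for the values read from the array:

from rich import box
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

table = Table(box=box.SIMPLE, highlight=True)
table.show_header = False
table.add_row('Type', 'DocumentArray[MyDoc]')  # hypothetical value for self.da.__class__.__name__
table.add_row('Length', '3')                   # hypothetical value for str(len(self.da))

Console().print(Panel(table, title='DocumentArray Summary', expand=False))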
