Skip to content

Commit 5d453d4

Browse files
author
anna-charlotte
committed
fix: summary for document
Signed-off-by: anna-charlotte <[email protected]>
1 parent d29c468 commit 5d453d4

File tree

1 file changed

+44
-4
lines changed
  • docarray/base_document/mixins

1 file changed

+44
-4
lines changed

docarray/base_document/mixins/plot.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
import colorsys
12
from typing import Any, Optional, TypeVar
23

34
import numpy as np
5+
from rich.color import Color
6+
from rich.console import Console, ConsoleOptions, RenderResult
7+
from rich.measure import Measurement
8+
from rich.segment import Segment
9+
from rich.style import Style
410
from rich.tree import Tree
511

612
import docarray
13+
from docarray.math.helper import minmax_normalize
714
from docarray.typing import ID
815

916
T = TypeVar('T', bound=Any)
@@ -37,7 +44,8 @@ def _plot_recursion(node: T, tree: Optional[Tree] = None) -> Tree:
3744
for i, d in enumerate(value):
3845
if i == 2:
3946
PlotMixin._plot_recursion(
40-
f' ... {len(value) - 2} more Docs', _match_tree
47+
f' ... {len(value) - 2} more {d.__class__} documents',
48+
_match_tree,
4149
)
4250
break
4351
PlotMixin._plot_recursion(d, _match_tree)
@@ -51,6 +59,7 @@ def __rich_console__(self, console, options):
5159
kls = self.__class__.__name__
5260
id_abbrv = getattr(self, 'id')[:7]
5361
yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]"
62+
5463
from collections.abc import Iterable
5564

5665
import torch
@@ -64,7 +73,7 @@ def __rich_console__(self, console, options):
6473
for k, v in self.__dict__.items():
6574
col_1, col_2 = '', ''
6675

67-
if k.startswith('_') or isinstance(v, ID) or v is None:
76+
if isinstance(v, ID) or k.startswith('_') or v is None:
6877
continue
6978
elif isinstance(v, str):
7079
col_1 = f'{k}: {v.__class__.__name__}'
@@ -73,7 +82,14 @@ def __rich_console__(self, console, options):
7382
col_2 += f' ... (length: {len(v)})'
7483
elif isinstance(v, np.ndarray) or isinstance(v, torch.Tensor):
7584
col_1 = f'{k}: {v.__class__.__name__}'
76-
col_2 = f'{type(v)} in shape {v.shape}, dtype: {v.dtype}'
85+
86+
if isinstance(v, torch.Tensor):
87+
v = v.detach().cou().numpy()
88+
if v.squeeze().ndim == 1 and len(v) < 1000:
89+
col_2 = ColorBoxArray(v.squeeze())
90+
else:
91+
col_2 = f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'
92+
7793
elif isinstance(v, tuple) or isinstance(v, list):
7894
col_1 = f'{k}: {v.__class__.__name__}'
7995
for i, x in enumerate(v):
@@ -88,7 +104,31 @@ def __rich_console__(self, console, options):
88104
else:
89105
continue
90106

91-
my_table.add_row(col_1, text.Text(col_2))
107+
if not isinstance(col_2, ColorBoxArray):
108+
col_2 = text.Text(col_2)
109+
my_table.add_row(col_1, col_2)
92110

93111
if my_table.rows:
94112
yield my_table
113+
114+
115+
class ColorBoxArray:
116+
def __init__(self, array):
117+
self._array = minmax_normalize(array, (0, 5))
118+
119+
def __rich_console__(
120+
self, console: Console, options: ConsoleOptions
121+
) -> RenderResult:
122+
h = 0.75
123+
for idx, y in enumerate(self._array):
124+
lightness = 0.1 + ((y / 5) * 0.7)
125+
r, g, b = colorsys.hls_to_rgb(h, lightness + 0.7 / 10, 1.0)
126+
color = Color.from_rgb(r * 255, g * 255, b * 255)
127+
yield Segment('▄', Style(color=color, bgcolor=color))
128+
if idx != 0 and idx % options.max_width == 0:
129+
yield Segment.line()
130+
131+
def __rich_measure__(
132+
self, console: "Console", options: ConsoleOptions
133+
) -> Measurement:
134+
return Measurement(1, options.max_width)

0 commit comments

Comments
 (0)