1+ import colorsys
12from typing import Any , Optional , TypeVar
23
34import numpy as np
5+ from rich .color import Color
6+ from rich .console import Console , ConsoleOptions , RenderResult
7+ from rich .measure import Measurement
8+ from rich .segment import Segment
9+ from rich .style import Style
410from rich .tree import Tree
511
612import docarray
13+ from docarray .math .helper import minmax_normalize
714from docarray .typing import ID
815
916T = TypeVar ('T' , bound = Any )
@@ -37,7 +44,8 @@ def _plot_recursion(node: T, tree: Optional[Tree] = None) -> Tree:
3744 for i , d in enumerate (value ):
3845 if i == 2 :
3946 PlotMixin ._plot_recursion (
40- f' ... { len (value ) - 2 } more Docs' , _match_tree
47+ f' ... { len (value ) - 2 } more { d .__class__ } documents' ,
48+ _match_tree ,
4149 )
4250 break
4351 PlotMixin ._plot_recursion (d , _match_tree )
@@ -51,6 +59,7 @@ def __rich_console__(self, console, options):
5159 kls = self .__class__ .__name__
5260 id_abbrv = getattr (self , 'id' )[:7 ]
5361 yield f":page_facing_up: [b]{ kls } " f"[/b]: [cyan]{ id_abbrv } ...[cyan]"
62+
5463 from collections .abc import Iterable
5564
5665 import torch
@@ -64,7 +73,7 @@ def __rich_console__(self, console, options):
6473 for k , v in self .__dict__ .items ():
6574 col_1 , col_2 = '' , ''
6675
67- if k .startswith ('_' ) or isinstance ( v , ID ) or v is None :
76+ if isinstance ( v , ID ) or k .startswith ('_' ) or v is None :
6877 continue
6978 elif isinstance (v , str ):
7079 col_1 = f'{ k } : { v .__class__ .__name__ } '
@@ -73,7 +82,14 @@ def __rich_console__(self, console, options):
7382 col_2 += f' ... (length: { len (v )} )'
7483 elif isinstance (v , np .ndarray ) or isinstance (v , torch .Tensor ):
7584 col_1 = f'{ k } : { v .__class__ .__name__ } '
76- col_2 = f'{ type (v )} in shape { v .shape } , dtype: { v .dtype } '
85+
86+ if isinstance (v , torch .Tensor ):
87+ v = v .detach ().cou ().numpy ()
88+ if v .squeeze ().ndim == 1 and len (v ) < 1000 :
89+ col_2 = ColorBoxArray (v .squeeze ())
90+ else :
91+ col_2 = f'{ type (v )} of shape { v .shape } , dtype: { v .dtype } '
92+
7793 elif isinstance (v , tuple ) or isinstance (v , list ):
7894 col_1 = f'{ k } : { v .__class__ .__name__ } '
7995 for i , x in enumerate (v ):
@@ -88,7 +104,31 @@ def __rich_console__(self, console, options):
88104 else :
89105 continue
90106
91- my_table .add_row (col_1 , text .Text (col_2 ))
107+ if not isinstance (col_2 , ColorBoxArray ):
108+ col_2 = text .Text (col_2 )
109+ my_table .add_row (col_1 , col_2 )
92110
93111 if my_table .rows :
94112 yield my_table
113+
114+
115+ class ColorBoxArray :
116+ def __init__ (self , array ):
117+ self ._array = minmax_normalize (array , (0 , 5 ))
118+
119+ def __rich_console__ (
120+ self , console : Console , options : ConsoleOptions
121+ ) -> RenderResult :
122+ h = 0.75
123+ for idx , y in enumerate (self ._array ):
124+ lightness = 0.1 + ((y / 5 ) * 0.7 )
125+ r , g , b = colorsys .hls_to_rgb (h , lightness + 0.7 / 10 , 1.0 )
126+ color = Color .from_rgb (r * 255 , g * 255 , b * 255 )
127+ yield Segment ('▄' , Style (color = color , bgcolor = color ))
128+ if idx != 0 and idx % options .max_width == 0 :
129+ yield Segment .line ()
130+
131+ def __rich_measure__ (
132+ self , console : "Console" , options : ConsoleOptions
133+ ) -> Measurement :
134+ return Measurement (1 , options .max_width )
0 commit comments