Merged
Commits (46)
6e9817c - feat: add rich display for doc and da (Jan 17, 2023)
2941269 - fix: wip plot (Jan 18, 2023)
5949e7c - fix: wip plot (Jan 18, 2023)
05fa0fa - fix: wip plot (Jan 19, 2023)
718fe52 - feat: add math package and minmax normalize (Jan 19, 2023)
3669de1 - fix: summary for document (Jan 19, 2023)
c56e975 - chore: update poetry lock after rebase (Jan 19, 2023)
b0ba3f3 - fix: move all from plotmixin to base document (Jan 19, 2023)
bd8cf3b - feat: add docs schema summary (Jan 20, 2023)
25be9cc - feat: add document array summary (Jan 20, 2023)
b7a915b - fix: display doc within doc (Jan 20, 2023)
c6ee8ec - fix: in notebook print docs summary (Jan 20, 2023)
d45988a - fix: move summary from da to abstract da (Jan 23, 2023)
40c8eea - fix: get schema for doc (Jan 23, 2023)
3bdb9d0 - fix: wip doc summary (Jan 23, 2023)
ea12600 - fix: wip clean up (Jan 23, 2023)
9321c0b - test: add test for da pretty print (Jan 23, 2023)
189c33c - docs: update note (Jan 23, 2023)
93046af - docs: add some documentation (Jan 23, 2023)
fc0deec - fix: apply samis suggestion (Jan 23, 2023)
c8f3849 - fix: mypy checks (Jan 23, 2023)
15b94fc - fix: move to plot mixin (Jan 23, 2023)
58229aa - fix: remove redundant line (Jan 24, 2023)
e55ba3b - fix: remove comments (Jan 24, 2023)
147742d - feat: add schema highlighter (Jan 24, 2023)
59bd3a6 - fix: add plotmixin to mixin init (Jan 24, 2023)
fd26a43 - fix: adjust da summary (Jan 24, 2023)
675b5c5 - fix: move minmaxnormalize to comp backend (Jan 24, 2023)
a375d19 - fix: remove redundant lines (Jan 24, 2023)
c3b44bd - fix: add squeeze and detach to comp backend (Jan 24, 2023)
0d5653c - fix: apply suggestion from code review (Jan 24, 2023)
6d479ab - refactor: rename iterable attrs (Jan 24, 2023)
a1c4678 - fix: clean up (Jan 24, 2023)
3aac1c9 - fix: import (Jan 24, 2023)
eb75060 - fix: iterate over fields instead of annotations (Jan 24, 2023)
3cc1b55 - fix: remove math package since moved to comp backends (Jan 24, 2023)
ab585eb - refactor: use single quotes (Jan 24, 2023)
b838ec9 - fix: apply suggestions from code review (Jan 24, 2023)
c56aa6e - fix: extract summary to doc summary class (Jan 24, 2023)
b2b5bdd - fix: add pretty print for base document (Jan 25, 2023)
7aa7e58 - fix: use rich capture instead of string io (Jan 25, 2023)
2ae8d6a - fix: add colors for optional and union and use only single quotes (Jan 25, 2023)
0b881b1 - fix: extract display classes to display package (Jan 25, 2023)
6ba4eff - fix: make da not optional in da summary (Jan 25, 2023)
a70142a - fix: set _console instead of initializing new one everytime in __str__ (Jan 25, 2023)
2a6bd5c - fix: put console at module level (Jan 25, 2023)
Commit c3b44bd4da85285cb4b15bef4431a79507fa2d72
fix: add squeeze and detach to comp backend
Signed-off-by: anna-charlotte <[email protected]>
anna-charlotte committed Jan 25, 2023
docarray/base_document/mixins/plot.py (8 additions, 9 deletions)

@@ -1,6 +1,5 @@
 from typing import TYPE_CHECKING, Any, Optional

-import numpy as np
 import rich
 from rich.highlighter import RegexHighlighter
 from rich.tree import Tree
@@ -84,7 +83,6 @@ def __rich_console__(self, console, options):
         id_abbrv = getattr(self, 'id')[:7]
         yield f":page_facing_up: [b]{kls}" f"[/b]: [cyan]{id_abbrv} ...[cyan]"

-        import torch
         from rich import box, text
         from rich.table import Table

@@ -105,15 +103,15 @@ def __rich_console__(self, console, options):
                 if len(v) > 50:
                     col_2 += f' ... (length: {len(v)})'
                 table.add_row(col_1, text.Text(col_2))
-            elif isinstance(v, (np.ndarray, torch.Tensor)):
-                if isinstance(v, torch.Tensor):
-                    v = v.detach().cpu().numpy()
-                if v.squeeze().ndim == 1 and len(v) < 200:
-                    table.add_row(col_1, ColorBoxArray(v.squeeze()))
+            elif isinstance(v, AbstractTensor):
+                comp = v.get_comp_backend()
+                v_squeezed = comp.squeeze(comp.detach(v))
+                if comp.n_dim(v_squeezed) == 1 and comp.shape(v_squeezed)[0] < 200:
+                    table.add_row(col_1, ColorBoxArray(v_squeezed))
                 else:
                     table.add_row(
                         col_1,
-                        text.Text(f'{type(v)} of shape {v.shape}, dtype: {v.dtype}'),
+                        text.Text(f'{type(v)} of shape {comp.shape(v)}'),
                     )
             elif isinstance(v, (tuple, list)):
                 col_2 = ''
@@ -177,7 +175,8 @@ class ColorBoxArray:
     """

     def __init__(self, array: AbstractTensor):
-        self._array = array.get_comp_backend().minmax_normalize(array, (0, 5))
+        comp_be = array.get_comp_backend()
+        self._array = comp_be.minmax_normalize(comp_be.detach(array), (0, 5))

     def __rich_console__(
         self, console: 'Console', options: 'ConsoleOptions'
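The net effect in this file: the plot mixin no longer imports numpy or torch at display time, and instead dispatches through each tensor's computation backend. A minimal sketch of that dispatch pattern, assuming any tensor type that implements `AbstractTensor` (the helper name is hypothetical, not part of the PR):

```python
# Sketch of the backend-dispatch pattern used in __rich_console__ above.
# `flatten_for_display` is a hypothetical name, for illustration only.
def flatten_for_display(tensor):
    comp = tensor.get_comp_backend()  # framework-specific backend (numpy, torch, ...)
    # detach first (a no-op for numpy), then drop all size-1 dimensions
    return comp.squeeze(comp.detach(tensor))
```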
docarray/computation/abstract_comp_backend.py (19 additions, 0 deletions)

@@ -37,6 +37,14 @@ def n_dim(array: 'TTensor') -> int:
         """
         ...

+    @staticmethod
+    @abstractmethod
+    def squeeze(tensor: 'TTensor') -> 'TTensor':
+        """
+        Returns a tensor with all the dimensions of tensor of size 1 removed.
+        """
+        ...
+
     @staticmethod
     @abstractmethod
     def to_numpy(array: 'TTensor') -> 'np.ndarray':
@@ -85,6 +93,17 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor':
         """
         ...

+    @staticmethod
+    @abstractmethod
+    def detach(tensor: 'TTensor') -> 'TTensor':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        ...
+
     @staticmethod
     @abstractmethod
     def minmax_normalize(
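These two abstract static methods extend the backend contract: every concrete computation backend must now provide `squeeze` and `detach`. As a sketch of what a third backend would have to supply (a TensorFlow backend is hypothetical here, not part of this PR):

```python
# Hypothetical sketch: a TensorFlow backend satisfying the new contract.
# Not part of this PR; shown only to illustrate the abstract interface.
import tensorflow as tf


class TensorFlowCompBackend:
    @staticmethod
    def squeeze(tensor: 'tf.Tensor') -> 'tf.Tensor':
        # remove all dimensions of size 1
        return tf.squeeze(tensor)

    @staticmethod
    def detach(tensor: 'tf.Tensor') -> 'tf.Tensor':
        # closest TF analogue to detaching from the autograd graph
        return tf.stop_gradient(tensor)
```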
docarray/computation/numpy_backend.py (17 additions, 0 deletions)

@@ -49,6 +49,13 @@ def to_device(tensor: 'np.ndarray', device: str) -> 'np.ndarray':
     def n_dim(array: 'np.ndarray') -> int:
         return array.ndim

+    @staticmethod
+    def squeeze(tensor: 'np.ndarray') -> 'np.ndarray':
+        """
+        Returns a tensor with all the dimensions of tensor of size 1 removed.
+        """
+        return tensor.squeeze()
+
     @staticmethod
     def to_numpy(array: 'np.ndarray') -> 'np.ndarray':
         return array
@@ -85,6 +92,16 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray':
         """
         return array.reshape(shape)

+    @staticmethod
+    def detach(tensor: 'np.ndarray') -> 'np.ndarray':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        return tensor
+
     @staticmethod
     def minmax_normalize(
         tensor: 'np.ndarray',
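For numpy, `detach` is deliberately the identity function, since plain ndarrays carry no autograd graph. A quick usage sketch (the import path is assumed from the file location):

```python
import numpy as np

from docarray.computation.numpy_backend import NumpyCompBackend  # assumed import path

x = np.zeros((1, 1, 3, 1))
assert NumpyCompBackend.squeeze(x).shape == (3,)
assert NumpyCompBackend.detach(x) is x  # identity: nothing to detach for numpy
```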
docarray/computation/torch_backend.py (17 additions, 0 deletions)

@@ -63,6 +63,13 @@ def empty(
     def n_dim(array: 'torch.Tensor') -> int:
         return array.ndim

+    @staticmethod
+    def squeeze(tensor: 'torch.Tensor') -> 'torch.Tensor':
+        """
+        Returns a tensor with all the dimensions of tensor of size 1 removed.
+        """
+        return torch.squeeze(tensor)
+
     @staticmethod
     def to_numpy(array: 'torch.Tensor') -> 'np.ndarray':
         return array.cpu().detach().numpy()
@@ -89,6 +96,16 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor':
         """
         return tensor.reshape(shape)

+    @staticmethod
+    def detach(tensor: 'torch.Tensor') -> 'torch.Tensor':
+        """
+        Returns the tensor detached from its current graph.
+
+        :param tensor: tensor to be detached
+        :return: a detached tensor with the same data.
+        """
+        return tensor.detach()
+
     @staticmethod
     def minmax_normalize(
         tensor: 'torch.Tensor',
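For torch, `detach` is the operation that actually matters: display code should not hold a reference into the autograd graph. A quick behavioral sketch (import path assumed from the file location):

```python
import torch

from docarray.computation.torch_backend import TorchCompBackend  # assumed import path

t = torch.ones(1, 3, 1, requires_grad=True)
d = TorchCompBackend.detach(t)
assert not d.requires_grad  # detached tensors are no longer tracked by autograd
assert TorchCompBackend.squeeze(d).shape == (3,)
```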
tests/units/computation_backends/numpy_backend/test_basics.py (6 additions, 0 deletions)

@@ -52,6 +52,12 @@ def test_empty_device():
         NumpyCompBackend.empty((10, 3), device='meta')


+def test_squeeze():
+    tensor = np.zeros(shape=(1, 1, 3, 1))
+    squeezed = NumpyCompBackend.squeeze(tensor)
+    assert squeezed.shape == (3,)
+
+
 @pytest.mark.parametrize(
     'array,t_range,x_range,result',
     [
tests/units/computation_backends/torch_backend/test_basics.py (6 additions, 0 deletions)

@@ -55,6 +55,12 @@ def test_empty_device():
     assert tensor.device == torch.device('meta')


+def test_squeeze():
+    tensor = torch.zeros(size=(1, 1, 3, 1))
+    squeezed = TorchCompBackend.squeeze(tensor)
+    assert squeezed.shape == (3,)
+
+
 @pytest.mark.parametrize(
     'array,t_range,x_range,result',
     [
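Note that the commit adds tests for `squeeze` in both backends but none for the new `detach`. A minimal sketch of what a torch-side `detach` test could look like (not part of this diff; import path assumed):

```python
import torch

from docarray.computation.torch_backend import TorchCompBackend  # assumed import path


def test_detach():
    tensor = torch.zeros(size=(1, 3), requires_grad=True)
    detached = TorchCompBackend.detach(tensor)
    # the returned tensor should carry the same data but no autograd history
    assert not detached.requires_grad
```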