Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
6e9817c
feat: add rich display for doc and da
Jan 17, 2023
2941269
fix: wip plot
Jan 18, 2023
5949e7c
fix: wip plot
Jan 18, 2023
05fa0fa
fix: wip plot
Jan 19, 2023
718fe52
feat: add math package and minmax normalize
Jan 19, 2023
3669de1
fix: summary for document
Jan 19, 2023
c56e975
chore: update poetry lock after rebase
Jan 19, 2023
b0ba3f3
fix: move all from plotmixin to base document
Jan 19, 2023
bd8cf3b
feat: add docs schema summary
Jan 20, 2023
25be9cc
feat: add document array summary
Jan 20, 2023
b7a915b
fix: display doc within doc
Jan 20, 2023
c6ee8ec
fix: in notebook print docs summary
Jan 20, 2023
d45988a
fix: move summary from da to abstract da
Jan 23, 2023
40c8eea
fix: get schema for doc
Jan 23, 2023
3bdb9d0
fix: wip doc summary
Jan 23, 2023
ea12600
fix: wip clean up
Jan 23, 2023
9321c0b
test: add test for da pretty print
Jan 23, 2023
189c33c
docs: update note
Jan 23, 2023
93046af
docs: add some documentation
Jan 23, 2023
fc0deec
fix: apply Sami's suggestion
Jan 23, 2023
c8f3849
fix: mypy checks
Jan 23, 2023
15b94fc
fix: move to plot mixin
Jan 23, 2023
58229aa
fix: remove redundant line
Jan 24, 2023
e55ba3b
fix: remove comments
Jan 24, 2023
147742d
feat: add schema highlighter
Jan 24, 2023
59bd3a6
fix: add plotmixin to mixin init
Jan 24, 2023
fd26a43
fix: adjust da summary
Jan 24, 2023
675b5c5
fix: move minmaxnormalize to comp backend
Jan 24, 2023
a375d19
fix: remove redundant lines
Jan 24, 2023
c3b44bd
fix: add squeeze and detach to comp backend
Jan 24, 2023
0d5653c
fix: apply suggestion from code review
Jan 24, 2023
6d479ab
refactor: rename iterable attrs
Jan 24, 2023
a1c4678
fix: clean up
Jan 24, 2023
3aac1c9
fix: import
Jan 24, 2023
eb75060
fix: iterate over fields instead of annotations
Jan 24, 2023
3cc1b55
fix: remove math package since moved to comp backends
Jan 24, 2023
ab585eb
refactor: use single quotes
Jan 24, 2023
b838ec9
fix: apply suggestions from code review
Jan 24, 2023
c56aa6e
fix: extract summary to doc summary class
Jan 24, 2023
b2b5bdd
fix: add pretty print for base document
Jan 25, 2023
7aa7e58
fix: use rich capture instead of string io
Jan 25, 2023
2ae8d6a
fix: add colors for optional and union and use only single quotes
Jan 25, 2023
0b881b1
fix: extract display classes to display package
Jan 25, 2023
6ba4eff
fix: make da not optional in da summary
Jan 25, 2023
a70142a
fix: set _console instead of initializing a new one every time in __str__
Jan 25, 2023
2a6bd5c
fix: put console at module level
Jan 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: move minmaxnormalize to comp backend
Signed-off-by: anna-charlotte <[email protected]>
  • Loading branch information
anna-charlotte committed Jan 25, 2023
commit 675b5c5f5038441520a3c929c1506ad826402743
6 changes: 3 additions & 3 deletions docarray/base_document/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing_inspect import is_optional_type, is_union_type

from docarray.base_document.abstract_document import AbstractDocument
from docarray.math.helper import minmax_normalize
from docarray.typing import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor

if TYPE_CHECKING:
from rich.console import Console, ConsoleOptions, RenderResult
Expand Down Expand Up @@ -176,8 +176,8 @@ class ColorBoxArray:
Rich representation of an array as coloured blocks.
"""

def __init__(self, array):
self._array = minmax_normalize(array, (0, 5))
def __init__(self, array: AbstractTensor):
    # Normalize the raw tensor into the fixed range (0, 5) using the tensor's
    # own computation backend, so the rich renderer can bucket values into a
    # small number of color blocks regardless of the input's scale.
    self._array = array.get_comp_backend().minmax_normalize(array, (0, 5))

def __rich_console__(
self, console: 'Console', options: 'ConsoleOptions'
Expand Down
27 changes: 27 additions & 0 deletions docarray/computation/abstract_comp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,33 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor':
"""
...

@staticmethod
@abstractmethod
def minmax_normalize(
tensor: 'TTensor',
t_range: Tuple = (0, 1),
x_range: Optional[Tuple] = None,
eps: float = 1e-7,
):
"""
Normalize values in `tensor` into `t_range`.

`tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
normalization is row-based.

.. note::
- with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
- with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
of the data to 0.

:param tensor: the data to be normalized
:param t_range: a tuple represents the target range.
:param x_range: a tuple represents tensors range.
:param eps: a small jitter to avoid divide by zero
:return: normalized data in `t_range`
"""
...

class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
32 changes: 32 additions & 0 deletions docarray/computation/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,38 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray':
"""
return array.reshape(shape)

@staticmethod
def minmax_normalize(
tensor: 'np.ndarray',
t_range: Tuple = (0, 1),
x_range: Optional[Tuple] = None,
eps: float = 1e-7,
):
"""
Normalize values in `tensor` into `t_range`.

`tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
normalization is row-based.

.. note::
- with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
- with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
of the data to 0.

:param tensor: the data to be normalized
:param t_range: a tuple represents the target range.
:param x_range: a tuple represents tensors range.
:param eps: a small jitter to avoid divide by zero
:return: normalized data in `t_range`
"""
a, b = t_range

min_d = x_range[0] if x_range else np.min(tensor, axis=-1, keepdims=True)
max_d = x_range[1] if x_range else np.max(tensor, axis=-1, keepdims=True)
r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a

return np.clip(r, *((a, b) if a < b else (b, a)))

class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
39 changes: 39 additions & 0 deletions docarray/computation/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,45 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor':
"""
return tensor.reshape(shape)

@staticmethod
def minmax_normalize(
tensor: 'torch.Tensor',
t_range: Tuple = (0, 1),
x_range: Optional[Tuple] = None,
eps: float = 1e-7,
):
"""
Normalize values in `tensor` into `t_range`.

`tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
normalization is row-based.

.. note::
- with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
- with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
of the data to 0.

:param tensor: the data to be normalized
:param t_range: a tuple represents the target range.
:param x_range: a tuple represents tensors range.
:param eps: a small jitter to avoid divide by zero
:return: normalized data in `t_range`
"""
a, b = t_range

min_d = (
x_range[0] if x_range else torch.min(tensor, dim=-1, keepdim=True).values
)
max_d = (
x_range[1] if x_range else torch.max(tensor, dim=-1, keepdim=True).values
)
r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a

dtype = tensor.dtype
x = torch.clip(r, *((a, b) if a < b else (b, a)))
z = x.to(dtype)
return z

class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]):
"""
Abstract class for retrieval and ranking functionalities
Expand Down
20 changes: 20 additions & 0 deletions tests/units/computation_backends/numpy_backend/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,23 @@ def test_empty_dtype():
def test_empty_device():
with pytest.raises(NotImplementedError):
NumpyCompBackend.empty((10, 3), device='meta')


@pytest.mark.parametrize(
    'array,t_range,x_range,result',
    [
        # 1D input scaled from its own (0, 5) range into (0, 10)
        (np.array([0, 1, 2, 3, 4, 5]), (0, 10), None, np.array([0, 2, 4, 6, 8, 10])),
        # explicit x_range equal to t_range: identity mapping
        (np.array([0, 1, 2, 3, 4, 5]), (0, 10), (0, 10), np.array([0, 1, 2, 3, 4, 5])),
        # 2D input: normalization is applied per row
        (
            np.array([[0.0, 1.0], [0.0, 1.0]]),
            (0, 10),
            None,
            np.array([[0.0, 10.0], [0.0, 10.0]]),
        ),
        # reversed target range: documented behavior maps min to 1, max to 0
        (
            np.array([0, 1, 2, 3, 4, 5]),
            (1, 0),
            None,
            np.array([1.0, 0.8, 0.6, 0.4, 0.2, 0.0]),
        ),
    ],
)
def test_minmax_normalize(array, t_range, x_range, result):
    """minmax_normalize maps `array` from `x_range` (or its own min/max) into `t_range`."""
    output = NumpyCompBackend.minmax_normalize(
        tensor=array, t_range=t_range, x_range=x_range
    )
    assert np.allclose(output, result)
30 changes: 30 additions & 0 deletions tests/units/computation_backends/torch_backend/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,33 @@ def test_empty_device():
tensor = TorchCompBackend.empty((10, 3), device='meta')
assert tensor.shape == (10, 3)
assert tensor.device == torch.device('meta')


@pytest.mark.parametrize(
    'array,t_range,x_range,result',
    [
        # 1D input scaled from its own (0, 5) range into (0, 10)
        (
            torch.tensor([0, 1, 2, 3, 4, 5]),
            (0, 10),
            None,
            torch.tensor([0, 2, 4, 6, 8, 10]),
        ),
        # explicit x_range equal to t_range: identity mapping
        (
            torch.tensor([0, 1, 2, 3, 4, 5]),
            (0, 10),
            (0, 10),
            torch.tensor([0, 1, 2, 3, 4, 5]),
        ),
        # 2D input: normalization is applied per row
        (
            torch.tensor([[0.0, 1.0], [0.0, 1.0]]),
            (0, 10),
            None,
            torch.tensor([[0.0, 10.0], [0.0, 10.0]]),
        ),
        # reversed target range: documented behavior maps min to 1, max to 0
        # (float input so the dtype round-trip does not truncate the result)
        (
            torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]),
            (1, 0),
            None,
            torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2, 0.0]),
        ),
    ],
)
def test_minmax_normalize(array, t_range, x_range, result):
    """minmax_normalize maps `array` from `x_range` (or its own min/max) into `t_range`."""
    output = TorchCompBackend.minmax_normalize(
        tensor=array, t_range=t_range, x_range=x_range
    )
    assert torch.allclose(output, result)
Empty file removed tests/units/math/__init__.py
Empty file.
17 changes: 0 additions & 17 deletions tests/units/math/test_helper.py

This file was deleted.