Skip to content

Commit 181b5a9

Browse files
author
anna-charlotte
committed
fix: move minmaxnormalize to comp backend
Signed-off-by: anna-charlotte <[email protected]>
1 parent bf24a1e commit 181b5a9

File tree

8 files changed

+151
-20
lines changed

8 files changed

+151
-20
lines changed

docarray/base_document/mixins/plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from typing_inspect import is_optional_type, is_union_type
88

99
from docarray.base_document.abstract_document import AbstractDocument
10-
from docarray.math.helper import minmax_normalize
1110
from docarray.typing import ID
11+
from docarray.typing.tensor.abstract_tensor import AbstractTensor
1212

1313
if TYPE_CHECKING:
1414
from rich.console import Console, ConsoleOptions, RenderResult
@@ -176,8 +176,8 @@ class ColorBoxArray:
176176
Rich representation of an array as coloured blocks.
177177
"""
178178

179-
def __init__(self, array):
180-
self._array = minmax_normalize(array, (0, 5))
179+
def __init__(self, array: AbstractTensor):
180+
self._array = array.get_comp_backend().minmax_normalize(array, (0, 5))
181181

182182
def __rich_console__(
183183
self, console: 'Console', options: 'ConsoleOptions'

docarray/computation/abstract_comp_backend.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,33 @@ def reshape(tensor: 'TTensor', shape: Tuple[int, ...]) -> 'TTensor':
8585
"""
8686
...
8787

88+
@staticmethod
89+
@abstractmethod
90+
def minmax_normalize(
91+
tensor: 'TTensor',
92+
t_range: Tuple = (0, 1),
93+
x_range: Optional[Tuple] = None,
94+
eps: float = 1e-7,
95+
):
96+
"""
97+
Normalize values in `tensor` into `t_range`.
98+
99+
`tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
100+
normalization is row-based.
101+
102+
.. note::
103+
- with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
104+
- with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
105+
of the data to 0.
106+
107+
:param tensor: the data to be normalized
108+
:param t_range: a tuple represents the target range.
109+
:param x_range: a tuple represents tensors range.
110+
:param eps: a small jitter to avoid divide by zero
111+
:return: normalized data in `t_range`
112+
"""
113+
...
114+
88115
class Retrieval(ABC, typing.Generic[TTensorRetrieval]):
89116
"""
90117
Abstract class for retrieval and ranking functionalities

docarray/computation/numpy_backend.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,38 @@ def reshape(array: 'np.ndarray', shape: Tuple[int, ...]) -> 'np.ndarray':
8585
"""
8686
return array.reshape(shape)
8787

88+
@staticmethod
89+
def minmax_normalize(
90+
tensor: 'np.ndarray',
91+
t_range: Tuple = (0, 1),
92+
x_range: Optional[Tuple] = None,
93+
eps: float = 1e-7,
94+
):
95+
"""
96+
Normalize values in `tensor` into `t_range`.
97+
98+
`tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
99+
normalization is row-based.
100+
101+
.. note::
102+
- with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
103+
- with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
104+
of the data to 0.
105+
106+
:param tensor: the data to be normalized
107+
:param t_range: a tuple represents the target range.
108+
:param x_range: a tuple represents tensors range.
109+
:param eps: a small jitter to avoid divide by zero
110+
:return: normalized data in `t_range`
111+
"""
112+
a, b = t_range
113+
114+
min_d = x_range[0] if x_range else np.min(tensor, axis=-1, keepdims=True)
115+
max_d = x_range[1] if x_range else np.max(tensor, axis=-1, keepdims=True)
116+
r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a
117+
118+
return np.clip(r, *((a, b) if a < b else (b, a)))
119+
88120
class Retrieval(AbstractComputationalBackend.Retrieval[np.ndarray]):
89121
"""
90122
Abstract class for retrieval and ranking functionalities

docarray/computation/torch_backend.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,45 @@ def reshape(tensor: 'torch.Tensor', shape: Tuple[int, ...]) -> 'torch.Tensor':
8989
"""
9090
return tensor.reshape(shape)
9191

92+
@staticmethod
93+
def minmax_normalize(
94+
tensor: 'torch.Tensor',
95+
t_range: Tuple = (0, 1),
96+
x_range: Optional[Tuple] = None,
97+
eps: float = 1e-7,
98+
):
99+
"""
100+
Normalize values in `tensor` into `t_range`.
101+
102+
`tensor` can be a 1D array or a 2D array. When `tensor` is a 2D array, then
103+
normalization is row-based.
104+
105+
.. note::
106+
- with `t_range=(0, 1)` will normalize the min-value of data to 0, max to 1;
107+
- with `t_range=(1, 0)` will normalize the min-value of data to 1, max value
108+
of the data to 0.
109+
110+
:param tensor: the data to be normalized
111+
:param t_range: a tuple represents the target range.
112+
:param x_range: a tuple represents tensors range.
113+
:param eps: a small jitter to avoid divide by zero
114+
:return: normalized data in `t_range`
115+
"""
116+
a, b = t_range
117+
118+
min_d = (
119+
x_range[0] if x_range else torch.min(tensor, dim=-1, keepdim=True).values
120+
)
121+
max_d = (
122+
x_range[1] if x_range else torch.max(tensor, dim=-1, keepdim=True).values
123+
)
124+
r = (b - a) * (tensor - min_d) / (max_d - min_d + eps) + a
125+
126+
dtype = tensor.dtype
127+
x = torch.clip(r, *((a, b) if a < b else (b, a)))
128+
z = x.to(dtype)
129+
return z
130+
92131
class Retrieval(AbstractComputationalBackend.Retrieval[torch.Tensor]):
93132
"""
94133
Abstract class for retrieval and ranking functionalities

tests/units/computation_backends/numpy_backend/test_basics.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,23 @@ def test_empty_dtype():
5050
def test_empty_device():
5151
with pytest.raises(NotImplementedError):
5252
NumpyCompBackend.empty((10, 3), device='meta')
53+
54+
55+
@pytest.mark.parametrize(
56+
'array,t_range,x_range,result',
57+
[
58+
(np.array([0, 1, 2, 3, 4, 5]), (0, 10), None, np.array([0, 2, 4, 6, 8, 10])),
59+
(np.array([0, 1, 2, 3, 4, 5]), (0, 10), (0, 10), np.array([0, 1, 2, 3, 4, 5])),
60+
(
61+
np.array([[0.0, 1.0], [0.0, 1.0]]),
62+
(0, 10),
63+
None,
64+
np.array([[0.0, 10.0], [0.0, 10.0]]),
65+
),
66+
],
67+
)
68+
def test_minmax_normalize(array, t_range, x_range, result):
69+
output = NumpyCompBackend.minmax_normalize(
70+
tensor=array, t_range=t_range, x_range=x_range
71+
)
72+
assert np.allclose(output, result)

tests/units/computation_backends/torch_backend/test_basics.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,33 @@ def test_empty_device():
5353
tensor = TorchCompBackend.empty((10, 3), device='meta')
5454
assert tensor.shape == (10, 3)
5555
assert tensor.device == torch.device('meta')
56+
57+
58+
@pytest.mark.parametrize(
59+
'array,t_range,x_range,result',
60+
[
61+
(
62+
torch.tensor([0, 1, 2, 3, 4, 5]),
63+
(0, 10),
64+
None,
65+
torch.tensor([0, 2, 4, 6, 8, 10]),
66+
),
67+
(
68+
torch.tensor([0, 1, 2, 3, 4, 5]),
69+
(0, 10),
70+
(0, 10),
71+
torch.tensor([0, 1, 2, 3, 4, 5]),
72+
),
73+
(
74+
torch.tensor([[0.0, 1.0], [0.0, 1.0]]),
75+
(0, 10),
76+
None,
77+
torch.tensor([[0.0, 10.0], [0.0, 10.0]]),
78+
),
79+
],
80+
)
81+
def test_minmax_normalize(array, t_range, x_range, result):
82+
output = TorchCompBackend.minmax_normalize(
83+
tensor=array, t_range=t_range, x_range=x_range
84+
)
85+
assert torch.allclose(output, result)

tests/units/math/__init__.py

Whitespace-only changes.

tests/units/math/test_helper.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)