Skip to content

Commit ea98df6

Browse files
azayzalaeddine-13
andauthored
test: test when device is cuda (docarray#293)
Co-authored-by: AlaeddineAbdessalem <[email protected]>
1 parent 5cf5462 commit ea98df6

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import numpy as np
2+
import paddle
3+
import pytest
4+
import tensorflow as tf
5+
import torch
6+
7+
from docarray.math.distance import cdist, pdist
8+
9+
10+
def test_pdist():
11+
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
12+
np.testing.assert_almost_equal(
13+
pdist(tensor, 'cosine'),
14+
cdist(tensor, tensor, 'cosine'),
15+
decimal=3,
16+
)
17+
18+
19+
def test_cdist_raise_error():
20+
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
21+
y = np.array([[1, 2, 3], [4, 5, 6]])
22+
with pytest.raises(ValueError):
23+
cdist(x, y, 'cosine')
24+
25+
26+
def test_not_supported_metric():
27+
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
28+
y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
29+
with pytest.raises(NotImplementedError):
30+
cdist(x, y, 'fake_metric')
31+
32+
33+
@pytest.mark.parametrize(
34+
'x_mat, y_mat, result',
35+
(
36+
(
37+
torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
38+
torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
39+
np.array([[0, 27], [27, 0]]),
40+
),
41+
(
42+
tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32),
43+
tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32),
44+
np.array([[0, 27], [27, 0]]),
45+
),
46+
(
47+
paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
48+
paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
49+
np.array([[0, 27], [27, 0]]),
50+
),
51+
),
52+
)
53+
def test_seqeuclidean(x_mat, y_mat, result):
54+
np.testing.assert_almost_equal(
55+
cdist(x_mat, y_mat, metric='sqeuclidean'), result, decimal=3
56+
)
57+
58+
59+
@pytest.mark.parametrize(
60+
'x_mat, y_mat, result',
61+
(
62+
(
63+
torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
64+
torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
65+
np.array([[0, 5.196], [5.196, 0]]),
66+
),
67+
(
68+
tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32),
69+
tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32),
70+
np.array([[0, 5.196], [5.196, 0]]),
71+
),
72+
(
73+
paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
74+
paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype='float32'),
75+
np.array([[0, 5.196], [5.196, 0]]),
76+
),
77+
),
78+
)
79+
def test_euclidean(x_mat, y_mat, result):
80+
np.testing.assert_almost_equal(
81+
cdist(x_mat, y_mat, metric='euclidean'), result, decimal=3
82+
)

0 commit comments

Comments
 (0)