Skip to content

Commit 7792477

Browse files
authored
test: enable evaluate mixin tests for backends (#174)
1 parent 4539a2c commit 7792477

File tree

1 file changed

+98
-20
lines changed

1 file changed

+98
-20
lines changed

tests/unit/array/mixins/test_eval_class.py

Lines changed: 98 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,16 @@
66
from docarray import DocumentArray, Document
77

88

9+
@pytest.mark.parametrize(
10+
'storage, config',
11+
[
12+
('memory', {}),
13+
('weaviate', {}),
14+
('sqlite', {}),
15+
('annlite', {'n_dim': 256}),
16+
('qdrant', {'n_dim': 256}),
17+
],
18+
)
919
@pytest.mark.parametrize(
1020
'metric_fn, kwargs',
1121
[
@@ -19,17 +29,28 @@
1929
('ndcg_at_k', {}),
2030
],
2131
)
22-
def test_eval_mixin_perfect_match(metric_fn, kwargs):
23-
da = DocumentArray.empty(10)
24-
da.embeddings = np.random.random([10, 256])
25-
da.match(da, exclude_self=True)
26-
r = da.evaluate(da, metric=metric_fn, **kwargs)
32+
def test_eval_mixin_perfect_match(metric_fn, kwargs, storage, config, start_storage):
33+
da1 = DocumentArray.empty(10)
34+
da1.embeddings = np.random.random([10, 256])
35+
da1_index = DocumentArray(da1, storage=storage, config=config)
36+
da1.match(da1_index, exclude_self=True)
37+
r = da1.evaluate(da1, metric=metric_fn, strict=False, **kwargs)
2738
assert isinstance(r, float)
2839
assert r == 1.0
29-
for d in da:
40+
for d in da1:
3041
assert d.evaluations[metric_fn].value == 1.0
3142

3243

44+
@pytest.mark.parametrize(
45+
'storage, config',
46+
[
47+
('memory', {}),
48+
('weaviate', {}),
49+
('sqlite', {}),
50+
('annlite', {'n_dim': 256}),
51+
('qdrant', {'n_dim': 256}),
52+
],
53+
)
3354
@pytest.mark.parametrize(
3455
'metric_fn, kwargs',
3556
[
@@ -43,14 +64,16 @@ def test_eval_mixin_perfect_match(metric_fn, kwargs):
4364
('ndcg_at_k', {}),
4465
],
4566
)
46-
def test_eval_mixin_zero_match(metric_fn, kwargs):
67+
def test_eval_mixin_zero_match(storage, config, metric_fn, kwargs):
4768
da1 = DocumentArray.empty(10)
4869
da1.embeddings = np.random.random([10, 256])
49-
da1.match(da1, exclude_self=True)
70+
da1_index = DocumentArray(da1, storage=storage, config=config)
71+
da1.match(da1_index, exclude_self=True)
5072

5173
da2 = copy.deepcopy(da1)
5274
da2.embeddings = np.random.random([10, 256])
53-
da2.match(da2, exclude_self=True)
75+
da2_index = DocumentArray(da2, storage=storage, config=config)
76+
da2.match(da2_index, exclude_self=True)
5477

5578
r = da1.evaluate(da2, metric=metric_fn, **kwargs)
5679
assert isinstance(r, float)
@@ -60,27 +83,59 @@ def test_eval_mixin_zero_match(metric_fn, kwargs):
6083
assert d.evaluations[metric_fn].value == 1.0
6184

6285

63-
def test_diff_len_should_raise():
86+
@pytest.mark.parametrize(
87+
'storage, config',
88+
[
89+
('memory', {}),
90+
('weaviate', {}),
91+
('sqlite', {}),
92+
('annlite', {'n_dim': 256}),
93+
('qdrant', {'n_dim': 256}),
94+
],
95+
)
96+
def test_diff_len_should_raise(storage, config):
6497
da1 = DocumentArray.empty(10)
65-
da2 = DocumentArray.empty(5)
98+
da2 = DocumentArray.empty(5, storage=storage, config=config)
6699
with pytest.raises(ValueError):
67100
da1.evaluate(da2, metric='precision_at_k')
68101

69102

70-
def test_diff_hash_fun_should_raise():
103+
@pytest.mark.parametrize(
104+
'storage, config',
105+
[
106+
('memory', {}),
107+
('weaviate', {}),
108+
('sqlite', {}),
109+
('annlite', {'n_dim': 256}),
110+
('qdrant', {'n_dim': 256}),
111+
],
112+
)
113+
def test_diff_hash_fun_should_raise(storage, config):
71114
da1 = DocumentArray.empty(10)
72-
da2 = DocumentArray.empty(10)
115+
da2 = DocumentArray.empty(10, storage=storage, config=config)
73116
with pytest.raises(ValueError):
74117
da1.evaluate(da2, metric='precision_at_k')
75118

76119

77-
def test_same_hash_same_len_fun_should_work():
120+
@pytest.mark.parametrize(
121+
'storage, config',
122+
[
123+
('memory', {}),
124+
('weaviate', {}),
125+
('sqlite', {}),
126+
('annlite', {'n_dim': 3}),
127+
('qdrant', {'n_dim': 3}),
128+
],
129+
)
130+
def test_same_hash_same_len_fun_should_work(storage, config):
78131
da1 = DocumentArray.empty(10)
79132
da1.embeddings = np.random.random([10, 3])
80-
da1.match(da1)
133+
da1_index = DocumentArray(da1, storage=storage, config=config)
134+
da1.match(da1_index)
81135
da2 = DocumentArray.empty(10)
82136
da2.embeddings = np.random.random([10, 3])
83-
da2.match(da2)
137+
da2_index = DocumentArray(da1, storage=storage, config=config)
138+
da2.match(da2_index)
84139
with pytest.raises(ValueError):
85140
da1.evaluate(da2, metric='precision_at_k')
86141
for d1, d2 in zip(da1, da2):
@@ -89,11 +144,22 @@ def test_same_hash_same_len_fun_should_work():
89144
da1.evaluate(da2, metric='precision_at_k')
90145

91146

92-
def test_adding_noise():
147+
@pytest.mark.parametrize(
148+
'storage, config',
149+
[
150+
('memory', {}),
151+
('weaviate', {}),
152+
('sqlite', {}),
153+
('annlite', {'n_dim': 3}),
154+
('qdrant', {'n_dim': 3}),
155+
],
156+
)
157+
def test_adding_noise(storage, config):
93158
da = DocumentArray.empty(10)
94159

95160
da.embeddings = np.random.random([10, 3])
96-
da.match(da, exclude_self=True)
161+
da_index = DocumentArray(da, storage=storage, config=config)
162+
da.match(da_index, exclude_self=True)
97163

98164
da2 = copy.deepcopy(da)
99165

@@ -107,21 +173,33 @@ def test_adding_noise():
107173
assert 0.0 < d.evaluations['precision_at_k'].value < 1.0
108174

109175

176+
@pytest.mark.parametrize(
177+
'storage, config',
178+
[
179+
('memory', {}),
180+
('weaviate', {}),
181+
('sqlite', {}),
182+
('annlite', {'n_dim': 128}),
183+
('qdrant', {'n_dim': 128}),
184+
],
185+
)
110186
@pytest.mark.parametrize(
111187
'metric_fn, kwargs',
112188
[
113189
('recall_at_k', {}),
114190
('f1_score_at_k', {}),
115191
],
116192
)
117-
def test_diff_match_len_in_gd(metric_fn, kwargs):
193+
def test_diff_match_len_in_gd(storage, config, metric_fn, kwargs):
118194
da1 = DocumentArray.empty(10)
119195
da1.embeddings = np.random.random([10, 128])
196+
da1_index = DocumentArray(da1, storage=storage, config=config)
120197
da1.match(da1, exclude_self=True)
121198

122199
da2 = copy.deepcopy(da1)
123200
da2.embeddings = np.random.random([10, 128])
124-
da2.match(da2, exclude_self=True)
201+
da2_index = DocumentArray(da2, storage=storage, config=config)
202+
da2.match(da2_index, exclude_self=True)
125203
# pop some matches from first document
126204
da2[0].matches.pop(8)
127205

0 commit comments

Comments (0)