Skip to content

Commit 5c1fb55

Browse files
author
Joan Fontanals
authored
fix: fix match score (#32)
1 parent 3a906db commit 5c1fb55

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

docarray/array/mixins/match.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def match(
121121
if only_id:
122122
d = Document(id=rhv[_id].id)
123123
else:
124-
d = rhv[int(_id)] # type: Document
124+
d = Document(rhv[int(_id)], copy=True) # type: Document
125125

126126
if d.id in lhv:
127127
d = Document(

tests/unit/array/mixins/test_match.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,34 @@ def test_diff_framework_match(ndarray_val):
515515
da = DocumentArray.empty(10)
516516
da.embeddings = ndarray_val
517517
da.match(da)
518+
519+
520+
def test_match_ensure_scores_unique():
521+
import numpy as np
522+
from docarray import DocumentArray
523+
524+
da1 = DocumentArray.empty(4)
525+
da1.embeddings = np.array(
526+
[[0, 0, 0, 0, 1], [1, 0, 0, 0, 0], [1, 1, 1, 1, 0], [1, 2, 2, 1, 0]]
527+
)
528+
529+
da2 = DocumentArray.empty(5)
530+
da2.embeddings = np.array(
531+
[
532+
[0.0, 0.1, 0.0, 0.0, 0.0],
533+
[1.0, 0.1, 0.0, 0.0, 0.0],
534+
[1.0, 1.2, 1.0, 1.0, 0.0],
535+
[1.0, 2.2, 2.0, 1.0, 0.0],
536+
[4.0, 5.2, 2.0, 1.0, 0.0],
537+
]
538+
)
539+
540+
da1.match(da2, metric='euclidean', only_id=False, limit=5)
541+
542+
assert len(da1) == 4
543+
for query in da1:
544+
previous_score = -10000
545+
assert len(query.matches) == 5
546+
for m in query.matches:
547+
assert m.scores['euclidean'].value >= previous_score
548+
previous_score = m.scores['euclidean'].value

0 commit comments

Comments
 (0)