Skip to content

Commit e6a078d

Browse files
authored
fix(document): allow eq to include ndarray comparison
1 parent 1e9e945 commit e6a078d

File tree

4 files changed

+154
-2
lines changed

4 files changed

+154
-2
lines changed

docarray/document/data.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from dataclasses import dataclass, field, fields
55
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
66

7+
from ..math.ndarray import check_arraylike_equality
8+
79
if TYPE_CHECKING:
810
from ..score import NamedScore
911
from .. import DocumentArray, Document
@@ -32,7 +34,7 @@
3234
_all_mime_types = set(mimetypes.types_map.values())
3335

3436

35-
@dataclass(unsafe_hash=True)
37+
@dataclass(unsafe_hash=True, eq=False)
3638
class DocumentData:
3739
_reference_doc: 'Document' = field(hash=False, compare=False)
3840
id: str = field(
@@ -111,3 +113,44 @@ def _set_default_value_if_none(self, key):
111113
setattr(self, key, defaultdict(NamedScore))
112114
else:
113115
setattr(self, key, v() if callable(v) else v)
116+
117+
@staticmethod
def _embedding_eq(array1: 'ArrayType', array2: 'ArrayType'):
    """Return whether two embedding values are equal.

    Two missing values (both ``None``) are equal. Values whose concrete types
    differ are never equal, which also covers the case where only one side is
    ``None``. Otherwise the comparison is delegated to
    ``check_arraylike_equality``.

    :param array1: first embedding value, or ``None``
    :param array2: second embedding value, or ``None``
    :return: ``True``/``False`` for the early exits, else whatever
        ``check_arraylike_equality`` returns
    """
    both_missing = array1 is None and array2 is None
    if both_missing:
        return True

    if type(array1) != type(array2):
        return False

    return check_arraylike_equality(array1, array2)
127+
128+
@staticmethod
def _tensor_eq(array1: 'ArrayType', array2: 'ArrayType'):
    """Return whether two tensor values are equal.

    Tensors follow the same comparison rules as embeddings, so this simply
    delegates to ``DocumentData._embedding_eq``.

    :param array1: first tensor value, or ``None``
    :param array2: second tensor value, or ``None``
    :return: the result of ``DocumentData._embedding_eq(array1, array2)``
    """
    # The result has to be returned explicitly: a bare call makes this method
    # yield ``None``, which ``__eq__`` does not treat as a mismatch, so two
    # documents with different tensors would compare equal.
    return DocumentData._embedding_eq(array1, array2)
131+
132+
def __eq__(self, other):
    """Compare two ``DocumentData`` objects field by field.

    Both objects must have the same set of non-empty fields. Each non-empty
    field is then compared with a dedicated ``_<field>_eq`` static comparator
    when ``DocumentData`` defines one (used for array-like fields, where ``!=``
    is elementwise), and with plain ``!=`` otherwise.

    :param other: the object to compare against
    :return: ``True`` if all non-empty fields match, ``False`` if not, or
        ``NotImplemented`` when ``other`` is not a ``DocumentData``
    """
    if not isinstance(other, DocumentData):
        # Let Python try the reflected comparison instead of failing with an
        # AttributeError on ``other._non_empty_fields``.
        return NotImplemented

    non_empty_fields = self._non_empty_fields
    if other._non_empty_fields != non_empty_fields:
        return False

    for key in non_empty_fields:
        mine, theirs = getattr(self, key), getattr(other, key)
        comparator = getattr(DocumentData, f'_{key}_eq', None)

        if comparator is None:
            if mine != theirs:
                return False
            continue

        verdict = comparator(mine, theirs)
        # Only an explicit falsy verdict (``False``, ``numpy.False_``, ...) is
        # a mismatch; ``None`` carries no verdict and is skipped, exactly as
        # the former ``are_equal == False`` check behaved.
        if verdict is not None and not verdict:
            return False

    return True

docarray/math/ndarray.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,69 @@ def get_array_rows(array: 'ArrayType') -> Tuple[int, int]:
192192
raise ValueError
193193

194194
return num_rows, ndim
195+
196+
197+
def check_arraylike_equality(x: 'ArrayType', y: 'ArrayType'):
    """Check whether two array-like objects hold the same data.

    The framework and sparsity of each argument are detected with
    ``get_array_type``. Arrays from different frameworks, or a sparse array
    paired with a dense one, are never equal. Arrays of an unrecognised
    framework are reported as not equal.

    Examples

    >>> import numpy as np
    >>> x = np.array([[1,2,0,0,3],[1,2,0,0,3]])
    >>> check_arraylike_equality(x,x)
    True

    >>> from scipy import sparse as sp
    >>> x = sp.csr_matrix([[1,2,0,0,3],[1,2,0,0,3]])
    >>> check_arraylike_equality(x,x)
    True

    >>> import torch
    >>> x = torch.tensor([1,2,3])
    >>> check_arraylike_equality(x,x)
    True

    :param x: first array-like object
    :param y: second array-like object
    :return: truthy if both objects hold the same data, falsy otherwise
    """
    x_type, x_is_sparse = get_array_type(x)
    y_type, y_is_sparse = get_array_type(y)

    if x_type != y_type or x_is_sparse != y_is_sparse:
        return False

    if x_type == 'python':
        return x == y

    if x_type == 'numpy':
        # Numpy does not support sparse tensors
        import numpy as np

        return np.array_equal(x, y)

    if x_type == 'torch':
        import torch

        if not x_is_sparse:
            return torch.equal(x, y)
        # torch.equal raises NotImplementedError for sparse tensors, so compare
        # through the difference instead. Subtracting tensors of different
        # shapes raises, hence the explicit shape guard.
        if x.shape != y.shape:
            return False
        return bool(((x - y).coalesce().values() == 0).all())

    if x_type == 'scipy':
        # Not implemented in scipy; this should work for all sparse formats.
        # Note: you can't simply look at nonzero values because they can be in
        # different positions.
        if x.shape != y.shape:
            return False
        return (x != y).nnz == 0

    if x_type == 'tensorflow':
        if x_is_sparse:
            return x == y
        # Only elementwise equality is available, so reduce with .all().
        # Guard on shape first: elementwise == could otherwise broadcast
        # arrays of different shapes into a match.
        if x.shape != y.shape:
            return False
        return bool((x == y).numpy().all())

    if x_type == 'paddle':
        # Paddle does not support sparse tensor on 11/8/2021
        # https://github.com/PaddlePaddle/Paddle/issues/36697
        # Only elementwise equality is available, so reduce with .all();
        # shapes are checked first for the same reason as for tensorflow.
        if x.shape != y.shape:
            return False
        return bool((x == y).numpy().all())

    return False

tests/unit/document/test_docdata.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,29 @@ def test_doc_hash_complicate_content():
3535
d1 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
3636
d2 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
3737
assert d1 == d2
38+
assert d2 == d1
3839
assert hash(d1) == hash(d2)
3940

4041

42+
def test_doc_difference_complicate_content():
    """Documents whose embeddings differ must compare unequal in both directions."""
    # ``!=`` has to be symmetric, so every pair is checked as left != right and
    # right != left. The underlying __eq__ is implemented at DocumentData level
    # in docarray/document/data.py.
    embedding_pairs = [
        # same shape, different values
        ({'embedding': np.array([1, 2, 3])}, {'embedding': np.array([1, 2, 4])}),
        # different shapes
        ({'embedding': np.array([1, 2, 3, 5])}, {'embedding': np.array([1, 2, 4])}),
        # embedding present on one side only
        ({}, {'embedding': np.array([1, 2, 4])}),
    ]

    for left_extra, right_extra in embedding_pairs:
        left = Document(text='hello', **left_extra, id=1)
        right = Document(text='hello', **right_extra, id=1)
        assert left != right
        assert right != left
59+
60+
4161
def test_pop_field():
4262
d1 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
4363
assert d1.non_empty_fields == ('id', 'text', 'embedding')

tests/unit/math/test_ndarray.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from scipy.sparse import csr_matrix, coo_matrix, bsr_matrix, csc_matrix, issparse
77

8-
from docarray.math.ndarray import get_array_rows
8+
from docarray.math.ndarray import get_array_rows, check_arraylike_equality
99
from docarray.proto.docarray_pb2 import NdArrayProto
1010
from docarray.proto.io import flush_ndarray, read_ndarray
1111

@@ -51,3 +51,26 @@ def test_get_array_rows(data, expected_result, arraytype, ndarray_type):
5151
assert isinstance(r_data_array, list)
5252
elif ndarray_type == 'numpy':
5353
assert isinstance(r_data_array, np.ndarray)
54+
55+
56+
def get_ndarrays():
    """Build one random 10x3 matrix and return it in every array format under test.

    :return: list with the same data as numpy array, nested Python list, torch,
        tensorflow and paddle dense tensors, a torch sparse tensor, and scipy
        csr/bsr/coo/csc matrices (in that order)
    """
    dense = np.random.random([10, 3])
    # Zero out the larger entries so the sparse formats contain actual zeros.
    dense[dense > 0.5] = 0

    dense_variants = [
        dense,
        dense.tolist(),
        torch.tensor(dense),
        tf.constant(dense),
        paddle.to_tensor(dense),
    ]
    sparse_variants = [torch.tensor(dense).to_sparse()]
    sparse_variants.extend(
        to_sparse(dense)
        for to_sparse in (csr_matrix, bsr_matrix, coo_matrix, csc_matrix)
    )
    return dense_variants + sparse_variants
71+
72+
73+
@pytest.mark.parametrize('ndarray_val', get_ndarrays())
def test_check_arraylike_equality(ndarray_val):
    """An array must equal itself and differ from ``ndarray_val + ndarray_val``."""
    same_result = check_arraylike_equality(ndarray_val, ndarray_val)
    assert same_result == True

    # For the plain Python list ``+`` concatenates, so that operand differs in
    # length rather than in values.
    doubled = ndarray_val + ndarray_val
    different_result = check_arraylike_equality(ndarray_val, doubled)
    assert different_result == False

0 commit comments

Comments
 (0)