Skip to content

Commit 97725ed

Browse files
davidbphanxiao
andauthored
fix: single doc set in docarray (#42)
Co-authored-by: Han Xiao <[email protected]>
1 parent edf836d commit 97725ed

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

docarray/array/document.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def __setitem__(
211211
index: 'DocumentArrayIndexType',
212212
value: Union['Document', Sequence['Document']],
213213
):
214+
214215
if isinstance(index, (int, np.generic)) and not isinstance(index, bool):
215216
index = int(index)
216217
self._data[index] = value
@@ -274,8 +275,11 @@ def __setitem__(
274275
elif _a == 'embedding':
275276
_docs.embeddings = _v
276277
else:
277-
for _d, _vv in zip(_docs, _v):
278-
setattr(_d, _a, _vv)
278+
if len(_docs) == 1:
279+
setattr(_docs[0], _a, _v)
280+
else:
281+
for _d, _vv in zip(_docs, _v):
282+
setattr(_d, _a, _vv)
279283
elif isinstance(index[0], bool):
280284
if len(index) != len(self._data):
281285
raise IndexError(

docarray/types.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,13 @@
4848
DocumentArrayMultipleIndexType = Union[
4949
slice, Sequence[int], Sequence[str], Sequence[bool], Ellipsis
5050
]
51-
DocumentArraySingleAttributeType = Tuple[slice, str]
52-
DocumentArrayMultipleAttributeType = Tuple[slice, Sequence[str]]
51+
DocumentArraySingleAttributeType = Tuple[
52+
Union[DocumentArraySingletonIndexType, DocumentArrayMultipleIndexType], str
53+
]
54+
DocumentArrayMultipleAttributeType = Tuple[
55+
Union[DocumentArraySingletonIndexType, DocumentArrayMultipleIndexType],
56+
Sequence[str],
57+
]
5358
DocumentArrayIndexType = Union[
5459
DocumentArraySingletonIndexType,
5560
DocumentArrayMultipleIndexType,

tests/unit/array/mixins/test_getset.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,22 @@ def test_texts_getter_da(da):
9494
assert not da.texts
9595

9696

97+
@pytest.mark.parametrize('da', da_and_dam())
98+
def test_setter_by_sequences_in_selected_docs_da(da):
99+
100+
da[[0], 'text'] = 'jina'
101+
assert ['jina'] == da[[0], 'text']
102+
103+
da[[0, 1], 'text'] = ['jina', 'jana']
104+
assert ['jina', 'jana'] == da[[0, 1], 'text']
105+
106+
da[[0], 'id'] = '12'
107+
assert ['12'] == da[[0], 'id']
108+
109+
da[[0, 1], 'id'] = ['12', '34']
110+
assert ['12', '34'] == da[[0, 1], 'id']
111+
112+
97113
@pytest.mark.parametrize('da', da_and_dam())
98114
def test_texts_wrong_len(da):
99115
texts = ['hello']

0 commit comments

Comments
 (0)