Skip to content

Commit c02a97a

Browse files
authored
perf: improve doc init speed (#179)
1 parent 5078e1e commit c02a97a

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

docarray/base.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import copy as cp
22
from dataclasses import fields
3+
from functools import lru_cache
34
from typing import TYPE_CHECKING, Optional, Tuple, Dict
45

56
from .helper import typename
@@ -8,6 +9,11 @@
89
from .types import T
910

1011

12+
@lru_cache()
13+
def _get_fields(dc):
14+
return [f.name for f in fields(dc)]
15+
16+
1117
class BaseDCType:
1218
_data_class = None
1319

@@ -32,9 +38,9 @@ def __init__(
3238
if field_resolver:
3339
kwargs = {field_resolver.get(k, k): v for k, v in kwargs.items()}
3440

35-
_fields = fields(self._data_class)
41+
_fields = _get_fields(self._data_class)
3642
_unknown_kwargs = None
37-
_unresolved = set(kwargs.keys()).difference({f.name for f in _fields})
43+
_unresolved = set(kwargs.keys()).difference(_fields)
3844

3945
if _unresolved:
4046
if unknown_fields_handler == 'raise':
@@ -46,9 +52,9 @@ def __init__(
4652

4753
self._data = self._data_class(self)
4854

49-
for field in _fields:
50-
if field.name in kwargs:
51-
setattr(self._data, field.name, kwargs[field.name])
55+
for f in _fields:
56+
if f in kwargs:
57+
setattr(self._data, f, kwargs[f])
5258

5359
if _unknown_kwargs and unknown_fields_handler == 'catch':
5460
getattr(self, self._unresolved_fields_dest).update(_unknown_kwargs)

docarray/document/data.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -65,8 +65,6 @@ def __setattr__(self, key, value):
6565
self.text = None
6666
self.tensor = None
6767
self.blob = None
68-
if key == 'text':
69-
self.mime_type = 'text/plain'
7068
elif key == 'uri':
7169
mime_type = mimetypes.guess_type(value)[0]
7270

tests/unit/document/test_converters.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,7 @@ def test_convert_blob_to_uri(converter):
146146
'converter', ['convert_text_to_datauri', 'convert_content_to_datauri']
147147
)
148148
def test_convert_text_to_uri(converter):
149-
d = Document(content=open(__file__).read())
149+
d = Document(content=open(__file__).read(), mime_type='text/plain')
150150
assert d.text
151151
getattr(d, converter)()
152152
assert d.uri.startswith('data:text/plain;')
@@ -183,7 +183,7 @@ def test_convert_text_to_uri_and_back():
183183
text_from_file = open(__file__).read()
184184
doc = Document(content=text_from_file)
185185
assert doc.text
186-
assert doc.mime_type == 'text/plain'
186+
assert not doc.mime_type
187187
doc.convert_text_to_datauri()
188188
doc.load_uri_to_text()
189189
assert doc.mime_type == 'text/plain'

tests/unit/document/test_docdata.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,10 +40,10 @@ def test_doc_hash_complicate_content():
4040

4141
def test_pop_field():
4242
d1 = Document(text='hello', embedding=np.array([1, 2, 3]), id=1)
43-
assert d1.non_empty_fields == ('id', 'mime_type', 'text', 'embedding')
43+
assert d1.non_empty_fields == ('id', 'text', 'embedding')
4444
d1.pop('text')
45-
assert d1.non_empty_fields == ('id', 'mime_type', 'embedding')
46-
d1.pop('id', 'embedding', 'mime_type')
45+
assert d1.non_empty_fields == ('id', 'embedding')
46+
d1.pop('id', 'embedding')
4747
assert d1.non_empty_fields == tuple()
4848

4949
d1.pop('foobar')
@@ -138,7 +138,7 @@ def test_offset():
138138

139139
def test_exclusive_content_2():
140140
d = Document(text='hello', blob=b'sda')
141-
assert len(d.non_empty_fields) == 3
141+
assert len(d.non_empty_fields) == 2
142142
d.content = b'sda'
143143
assert d.content == b'sda'
144144
assert 'blob' in d.non_empty_fields

0 commit comments

Comments
 (0)