Skip to content

Commit 3e4bebe

Browse files
authored
perf: optimize Document init (#184)
1 parent fc06231 commit 3e4bebe

File tree

5 files changed

+101
-71
lines changed

5 files changed

+101
-71
lines changed

docarray/base.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,36 @@ def __init__(
3535
kwargs.update(_obj)
3636

3737
if kwargs:
38-
if field_resolver:
39-
kwargs = {field_resolver.get(k, k): v for k, v in kwargs.items()}
40-
41-
_fields = _get_fields(self._data_class)
42-
_unknown_kwargs = None
43-
_unresolved = set(kwargs.keys()).difference(_fields)
44-
45-
if _unresolved:
38+
try:
39+
self._data = self._data_class(self, **kwargs)
40+
except TypeError as ex:
4641
if unknown_fields_handler == 'raise':
47-
raise AttributeError(f'unknown attributes: {_unresolved}')
48-
49-
_unknown_kwargs = {k: kwargs[k] for k in _unresolved}
50-
for k in _unresolved:
51-
kwargs.pop(k)
52-
53-
self._data = self._data_class(self)
54-
55-
for f in _fields:
56-
if f in kwargs:
57-
setattr(self._data, f, kwargs[f])
58-
59-
if _unknown_kwargs and unknown_fields_handler == 'catch':
60-
getattr(self, self._unresolved_fields_dest).update(_unknown_kwargs)
42+
raise AttributeError(f'unknown attributes') from ex
43+
else:
44+
if field_resolver:
45+
kwargs = {
46+
field_resolver.get(k, k): v for k, v in kwargs.items()
47+
}
48+
49+
_fields = _get_fields(self._data_class)
50+
_unknown_kwargs = None
51+
_unresolved = set(kwargs.keys()).difference(_fields)
52+
53+
if _unresolved:
54+
_unknown_kwargs = {k: kwargs[k] for k in _unresolved}
55+
for k in _unresolved:
56+
kwargs.pop(k)
57+
58+
self._data = self._data_class(self, **kwargs)
59+
60+
if _unknown_kwargs and unknown_fields_handler == 'catch':
61+
getattr(self, self._unresolved_fields_dest).update(
62+
_unknown_kwargs
63+
)
64+
65+
for k in self._post_init_fields:
66+
if k in kwargs:
67+
setattr(self, k, kwargs[k])
6168

6269
if not _obj and not kwargs and self._data is None:
6370
self._data = self._data_class(self)

docarray/document/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@
1111
class Document(AllMixins, BaseDCType):
1212
_data_class = DocumentData
1313
_unresolved_fields_dest = 'tags'
14+
_post_init_fields = (
15+
'text',
16+
'blob',
17+
'tensor',
18+
'content',
19+
'uri',
20+
'mime_type',
21+
'chunks',
22+
'matches',
23+
)
1424

1525
@overload
1626
def __init__(self):

docarray/document/data.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import mimetypes
2-
import uuid
2+
import os
33
from collections import defaultdict
44
from dataclasses import dataclass, field, fields
55
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -35,7 +35,7 @@
3535
@dataclass(unsafe_hash=True)
3636
class DocumentData:
3737
_reference_doc: 'Document' = field(hash=False, compare=False)
38-
id: str = field(default_factory=lambda: uuid.uuid1().hex)
38+
id: str = field(default_factory=lambda: os.urandom(16).hex())
3939
parent_id: Optional[str] = None
4040
granularity: Optional[int] = None
4141
adjacency: Optional[int] = None
@@ -56,45 +56,6 @@ class DocumentData:
5656
chunks: Optional['DocumentArray'] = None
5757
matches: Optional['DocumentArray'] = None
5858

59-
def __setattr__(self, key, value):
60-
if value is not None:
61-
if key == 'text' or key == 'tensor' or key == 'blob':
62-
# enable mutual exclusivity for content field
63-
dv = default_values.get(key)
64-
if type(value) != type(dv) or value != dv:
65-
self.text = None
66-
self.tensor = None
67-
self.blob = None
68-
elif key == 'uri':
69-
mime_type = mimetypes.guess_type(value)[0]
70-
71-
if mime_type:
72-
self.mime_type = mime_type
73-
elif key == 'mime_type':
74-
if value not in _all_mime_types:
75-
# given but not recognizable, do best guess
76-
r = mimetypes.guess_type(f'*.{value}')[0]
77-
value = r or value
78-
elif key == 'content':
79-
if isinstance(value, bytes):
80-
self.blob = value
81-
elif isinstance(value, str):
82-
self.text = value
83-
else:
84-
self.tensor = value
85-
value = None
86-
elif key == 'chunks':
87-
from ..array.chunk import ChunkArray
88-
89-
if not isinstance(value, ChunkArray):
90-
value = ChunkArray(value, reference_doc=self._reference_doc)
91-
elif key == 'matches':
92-
from ..array.match import MatchArray
93-
94-
if not isinstance(value, MatchArray):
95-
value = MatchArray(value, reference_doc=self._reference_doc)
96-
self.__dict__[key] = value
97-
9859
@property
9960
def _non_empty_fields(self) -> Tuple[str]:
10061
r = []

docarray/document/mixins/property.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from ._property import _PropertyMixin
55

66
if TYPE_CHECKING:
7-
from ...types import DocumentContentType
7+
from ...types import DocumentContentType, ArrayType
8+
from ... import DocumentArray
89

910
_all_mime_types = set(mimetypes.types_map.values())
1011

@@ -22,16 +23,66 @@ def content(self) -> Optional['DocumentContentType']:
2223
if ct:
2324
return getattr(self, ct)
2425

26+
@_PropertyMixin.text.setter
27+
def text(self, value: str):
28+
self._clear_content()
29+
self._data.text = value
30+
31+
@_PropertyMixin.blob.setter
32+
def blob(self, value: bytes):
33+
self._clear_content()
34+
self._data.blob = value
35+
36+
@_PropertyMixin.tensor.setter
37+
def tensor(self, value: 'ArrayType'):
38+
self._clear_content()
39+
self._data.tensor = value
40+
2541
@content.setter
2642
def content(self, value: 'DocumentContentType'):
27-
if value is None:
28-
self._clear_content()
29-
elif isinstance(value, bytes):
30-
self.blob = value
43+
self._clear_content()
44+
if isinstance(value, bytes):
45+
self._data.blob = value
3146
elif isinstance(value, str):
32-
self.text = value
33-
else:
34-
self.tensor = value
47+
self._data.text = value
48+
elif value is not None:
49+
self._data.tensor = value
50+
51+
@_PropertyMixin.uri.setter
52+
def uri(self, value: str):
53+
if value:
54+
mime_type = mimetypes.guess_type(value)[0]
55+
56+
if mime_type:
57+
self._data.mime_type = mime_type
58+
self._data.uri = value
59+
60+
@_PropertyMixin.mime_type.setter
61+
def mime_type(self, value: str):
62+
if value and value not in _all_mime_types:
63+
# given but not recognizable, do best guess
64+
r = mimetypes.guess_type(f'*.{value}')[0]
65+
value = r or value
66+
67+
self._data.mime_type = value
68+
69+
@_PropertyMixin.chunks.setter
70+
def chunks(self, value: 'DocumentArray'):
71+
from ...array.chunk import ChunkArray
72+
73+
if not isinstance(value, ChunkArray):
74+
value = ChunkArray(value, reference_doc=self._data._reference_doc)
75+
76+
self._data.chunks = value
77+
78+
@_PropertyMixin.matches.setter
79+
def matches(self, value: 'DocumentArray'):
80+
from ...array.match import MatchArray
81+
82+
if not isinstance(value, MatchArray):
83+
value = MatchArray(value, reference_doc=self._data._reference_doc)
84+
85+
self._data.matches = value
3586

3687
@property
3788
def content_type(self) -> Optional[str]:

docarray/score/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55

66
class NamedScore(AllMixins, BaseDCType):
77
_data_class = NamedScoreData
8+
_post_init_fields = ()

0 commit comments

Comments
 (0)