Skip to content

Commit 62ad22a

Browse files
author
Joan Fontanals
authored
fix: use safe_issubclass everywhere (#1691)
Signed-off-by: Joan Fontanals Martinez <[email protected]>
1 parent e0afb5e commit 62ad22a

File tree

16 files changed

+69
-57
lines changed

16 files changed

+69
-57
lines changed

docarray/array/any_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from docarray.display.document_array_summary import DocArraySummary
2424
from docarray.exceptions.exceptions import UnusableObjectError
2525
from docarray.typing.abstract_type import AbstractType
26-
from docarray.utils._internal._typing import change_cls_name
26+
from docarray.utils._internal._typing import change_cls_name, safe_issubclass
2727

2828
if TYPE_CHECKING:
2929
from docarray.proto import DocListProto, NodeProto
@@ -53,7 +53,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
5353
if not isinstance(item, type):
5454
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
5555
# this does nothing other than check that item is a valid type var or str
56-
if not issubclass(item, BaseDoc):
56+
if not safe_issubclass(item, BaseDoc):
5757
raise ValueError(
5858
f'{cls.__name__}[item] item should be a Document not a {item} '
5959
)

docarray/array/doc_list/doc_list.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
2525
from docarray.base_doc import AnyDoc, BaseDoc
2626
from docarray.typing import NdArray
27+
from docarray.utils._internal._typing import safe_issubclass
2728

2829
if TYPE_CHECKING:
2930
from pydantic import BaseConfig
@@ -158,7 +159,9 @@ def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
158159

159160
def _validate_one_doc(self, doc: T_doc) -> T_doc:
160161
"""Validate if a Document is compatible with this `DocList`"""
161-
if not issubclass(self.doc_type, AnyDoc) and not isinstance(doc, self.doc_type):
162+
if not safe_issubclass(self.doc_type, AnyDoc) and not isinstance(
163+
doc, self.doc_type
164+
):
162165
raise ValueError(f'{doc} is not a {self.doc_type}')
163166
return doc
164167

@@ -218,7 +221,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
218221
not is_union_type(field_type)
219222
and self.__class__.doc_type.__fields__[field].required
220223
and isinstance(field_type, type)
221-
and issubclass(field_type, BaseDoc)
224+
and safe_issubclass(field_type, BaseDoc)
222225
):
223226
# calling __class_getitem__ ourselves is a hack, otherwise mypy complains
224227
# most likely a bug in mypy though
@@ -272,7 +275,7 @@ def validate(
272275
return value
273276
elif isinstance(value, DocVec):
274277
if (
275-
issubclass(value.doc_type, cls.doc_type)
278+
safe_issubclass(value.doc_type, cls.doc_type)
276279
or value.doc_type == cls.doc_type
277280
):
278281
return cast(T, value.to_doc_list())
@@ -326,7 +329,7 @@ def __getitem__(self, item):
326329
@classmethod
327330
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
328331

329-
if isinstance(item, type) and issubclass(item, BaseDoc):
332+
if isinstance(item, type) and safe_issubclass(item, BaseDoc):
330333
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
331334
else:
332335
return super().__class_getitem__(item)

docarray/array/doc_vec/doc_vec.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from docarray.base_doc.mixins.io import _type_to_protobuf
3232
from docarray.typing import NdArray
3333
from docarray.typing.tensor.abstract_tensor import AbstractTensor
34-
from docarray.utils._internal._typing import is_tensor_union
34+
from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
3535
from docarray.utils._internal.misc import is_tf_available, is_torch_available
3636

3737
if TYPE_CHECKING:
@@ -198,7 +198,7 @@ def __init__(
198198
for field_name, field in self.doc_type.__fields__.items():
199199
# here we iterate over the fields of the docs schema, and we collect the data
200200
# from each document and put them in the corresponding column
201-
field_type = self.doc_type._get_field_type(field_name)
201+
field_type: Type = self.doc_type._get_field_type(field_name)
202202

203203
is_field_required = self.doc_type.__fields__[field_name].required
204204

@@ -231,19 +231,19 @@ def _check_doc_field_not_none(field_name, doc):
231231
field_type = tensor_type
232232
# all generic tensor types such as AnyTensor, ImageTensor, etc. are subclasses of AbstractTensor.
233233
# Perform check only if the field_type is not an alias and is a subclass of AbstractTensor
234-
elif not isinstance(field_type, typingGenericAlias) and issubclass(
234+
elif not isinstance(field_type, typingGenericAlias) and safe_issubclass(
235235
field_type, AbstractTensor
236236
):
237237
# check if the tensor associated with the field_name in the document is a subclass of the tensor_type
238238
# e.g. if the field_type is AnyTensor but the type(docs[0][field_name]) is ImageTensor,
239239
# then we change the field_type to ImageTensor, since AnyTensor is a union of all the tensor types
240240
# and does not override any methods of specific tensor types
241241
tensor = getattr(docs[0], field_name)
242-
if issubclass(tensor.__class__, tensor_type):
242+
if safe_issubclass(tensor.__class__, tensor_type):
243243
field_type = tensor_type
244244

245245
if isinstance(field_type, type):
246-
if tf_available and issubclass(field_type, TensorFlowTensor):
246+
if tf_available and safe_issubclass(field_type, TensorFlowTensor):
247247
# tf.Tensor does not allow item assignment, therefore the
248248
# optimized way
249249
# of initializing an empty array and assigning values to it
@@ -263,7 +263,7 @@ def _check_doc_field_not_none(field_name, doc):
263263
stacked: tf.Tensor = tf.stack(tf_stack)
264264
tensor_columns[field_name] = TensorFlowTensor(stacked)
265265

266-
elif issubclass(field_type, AbstractTensor):
266+
elif safe_issubclass(field_type, AbstractTensor):
267267
if first_doc_is_none:
268268
_verify_optional_field_of_docs(docs)
269269
tensor_columns[field_name] = None
@@ -291,7 +291,7 @@ def _check_doc_field_not_none(field_name, doc):
291291
val = getattr(doc, field_name)
292292
cast(AbstractTensor, tensor_columns[field_name])[i] = val
293293

294-
elif issubclass(field_type, BaseDoc):
294+
elif safe_issubclass(field_type, BaseDoc):
295295
if first_doc_is_none:
296296
_verify_optional_field_of_docs(docs)
297297
doc_columns[field_name] = None
@@ -307,7 +307,7 @@ def _check_doc_field_not_none(field_name, doc):
307307
tensor_type=self.tensor_type
308308
)
309309

310-
elif issubclass(field_type, AnyDocArray):
310+
elif safe_issubclass(field_type, AnyDocArray):
311311
if first_doc_is_none:
312312
_verify_optional_field_of_docs(docs)
313313
docs_vec_columns[field_name] = None
@@ -362,7 +362,7 @@ def validate(
362362
return value
363363
elif isinstance(value, DocList):
364364
if (
365-
issubclass(value.doc_type, cls.doc_type)
365+
safe_issubclass(value.doc_type, cls.doc_type)
366366
or value.doc_type == cls.doc_type
367367
):
368368
return cast(T, value.to_doc_vec())
@@ -481,7 +481,7 @@ def _set_data_and_columns(
481481
# set data and prepare columns
482482
processed_value: T
483483
if isinstance(value, DocList):
484-
if not issubclass(value.doc_type, self.doc_type):
484+
if not safe_issubclass(value.doc_type, self.doc_type):
485485
raise TypeError(
486486
f'{value} schema : {value.doc_type} is not compatible with '
487487
f'this DocVec schema : {self.doc_type}'
@@ -491,7 +491,7 @@ def _set_data_and_columns(
491491
) # we need to copy data here
492492

493493
elif isinstance(value, DocVec):
494-
if not issubclass(value.doc_type, self.doc_type):
494+
if not safe_issubclass(value.doc_type, self.doc_type):
495495
raise TypeError(
496496
f'{value} schema : {value.doc_type} is not compatible with '
497497
f'this DocVec schema : {self.doc_type}'

docarray/base_doc/doc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from docarray.base_doc.mixins import IOMixin, UpdateMixin
2828
from docarray.typing import ID
2929
from docarray.typing.tensor.abstract_tensor import AbstractTensor
30+
from docarray.utils._internal._typing import safe_issubclass
3031

3132
if TYPE_CHECKING:
3233
from pydantic import Protocol
@@ -351,7 +352,7 @@ def _exclude_docarray(
351352

352353
type_ = self._get_field_type(field)
353354
if isinstance(type_, type) and (
354-
issubclass(type_, DocList) or issubclass(type_, DocVec)
355+
safe_issubclass(type_, DocList) or safe_issubclass(type_, DocVec)
355356
):
356357
docarray_exclude_fields.append(field)
357358

docarray/base_doc/mixins/update.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar
33

44
from typing_inspect import get_origin
5+
from docarray.utils._internal._typing import safe_issubclass
56

67
T = TypeVar('T', bound='UpdateMixin')
78

@@ -108,7 +109,9 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
108109
if field_name not in FORBIDDEN_FIELDS_TO_UPDATE:
109110
field_type = doc._get_field_type(field_name)
110111

111-
if isinstance(field_type, type) and issubclass(field_type, DocList):
112+
if isinstance(field_type, type) and safe_issubclass(
113+
field_type, DocList
114+
):
112115
nested_docarray_fields.append(field_name)
113116
else:
114117
origin = get_origin(field_type)

docarray/data/torch_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from docarray import BaseDoc, DocList, DocVec
66
from docarray.typing import TorchTensor
7-
from docarray.utils._internal._typing import change_cls_name
7+
from docarray.utils._internal._typing import change_cls_name, safe_issubclass
88

99
T_doc = TypeVar('T_doc', bound=BaseDoc)
1010

@@ -141,7 +141,7 @@ def collate_fn(cls, batch: List[T_doc]):
141141

142142
@classmethod
143143
def __class_getitem__(cls, item: Type[BaseDoc]) -> Type['MultiModalDataset']:
144-
if not issubclass(item, BaseDoc):
144+
if not safe_issubclass(item, BaseDoc):
145145
raise ValueError(
146146
f'{cls.__name__}[item] item should be a Document not a {item} '
147147
)

docarray/documents/helper.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic import create_model, create_model_from_typeddict
44
from pydantic.config import BaseConfig
55
from typing_extensions import TypedDict
6+
from docarray.utils._internal._typing import safe_issubclass
67

78
from docarray import BaseDoc
89

@@ -38,8 +39,8 @@ def create_doc(
3839
tensor=(AudioNdArray, ...),
3940
)
4041
41-
assert issubclass(MyAudio, BaseDoc)
42-
assert issubclass(MyAudio, Audio)
42+
assert safe_issubclass(MyAudio, BaseDoc)
43+
assert safe_issubclass(MyAudio, Audio)
4344
```
4445
4546
:param __model_name: name of the created model
@@ -54,7 +55,7 @@ def create_doc(
5455
:return: the new Document class
5556
"""
5657

57-
if not issubclass(__base__, BaseDoc):
58+
if not safe_issubclass(__base__, BaseDoc):
5859
raise ValueError(f'{type(__base__)} is not a BaseDoc or its subclass')
5960

6061
doc = create_model(
@@ -96,8 +97,8 @@ class MyAudio(TypedDict):
9697
9798
Doc = create_doc_from_typeddict(MyAudio, __base__=Audio)
9899
99-
assert issubclass(Doc, BaseDoc)
100-
assert issubclass(Doc, Audio)
100+
assert safe_issubclass(Doc, BaseDoc)
101+
assert safe_issubclass(Doc, Audio)
101102
```
102103
103104
---
@@ -108,7 +109,7 @@ class MyAudio(TypedDict):
108109
"""
109110

110111
if '__base__' in kwargs:
111-
if not issubclass(kwargs['__base__'], BaseDoc):
112+
if not safe_issubclass(kwargs['__base__'], BaseDoc):
112113
raise ValueError(f'{kwargs["__base__"]} is not a BaseDoc or its subclass')
113114
else:
114115
kwargs['__base__'] = BaseDoc
@@ -136,7 +137,7 @@ def create_doc_from_dict(model_name: str, data_dict: Dict[str, Any]) -> Type['T_
136137
137138
MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict=data_dict)
138139
139-
assert issubclass(MyDoc, BaseDoc)
140+
assert safe_issubclass(MyDoc, BaseDoc)
140141
```
141142
142143
---

docarray/index/abstract.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,9 @@ def __getitem__(
362362
for field_name, type_, _ in self._flatten_schema(
363363
cast(Type[BaseDoc], self._schema)
364364
):
365-
if issubclass(type_, AnyDocArray) and isinstance(doc_sequence[0], Dict):
365+
if safe_issubclass(type_, AnyDocArray) and isinstance(
366+
doc_sequence[0], Dict
367+
):
366368
for doc in doc_sequence:
367369
self._get_subindex_doclist(doc, field_name) # type: ignore
368370

@@ -534,7 +536,7 @@ def find_batched(
534536
if search_field:
535537
if '__' in search_field:
536538
fields = search_field.split('__')
537-
if issubclass(self._schema._get_field_type(fields[0]), AnyDocArray): # type: ignore
539+
if safe_issubclass(self._schema._get_field_type(fields[0]), AnyDocArray): # type: ignore
538540
return self._subindices[fields[0]].find_batched(
539541
queries,
540542
search_field='__'.join(fields[1:]),
@@ -799,7 +801,7 @@ def __class_getitem__(cls, item: Type[TSchema]):
799801
# do nothing
800802
# enables use in static contexts with type vars, e.g. as type annotation
801803
return Generic.__class_getitem__.__func__(cls, item)
802-
if not issubclass(item, BaseDoc):
804+
if not safe_issubclass(item, BaseDoc):
803805
raise ValueError(
804806
f'{cls.__name__}[item] `item` should be a Document not a {item} '
805807
)
@@ -849,7 +851,7 @@ def _flatten_schema(
849851
# treat as if it was a single non-optional type
850852
for t_arg in union_args:
851853
if t_arg is not type(None):
852-
if issubclass(t_arg, BaseDoc):
854+
if safe_issubclass(t_arg, BaseDoc):
853855
names_types_fields.extend(
854856
cls._flatten_schema(t_arg, name_prefix=inner_prefix)
855857
)
@@ -1044,15 +1046,15 @@ def _convert_dict_to_doc(
10441046
for field_name, _ in schema.__fields__.items():
10451047
t_ = schema._get_field_type(field_name)
10461048

1047-
if not is_union_type(t_) and issubclass(t_, AnyDocArray):
1049+
if not is_union_type(t_) and safe_issubclass(t_, AnyDocArray):
10481050
self._get_subindex_doclist(doc_dict, field_name)
10491051

10501052
if is_optional_type(t_):
10511053
for t_arg in get_args(t_):
10521054
if t_arg is not type(None):
10531055
t_ = t_arg
10541056

1055-
if not is_union_type(t_) and issubclass(t_, BaseDoc):
1057+
if not is_union_type(t_) and safe_issubclass(t_, BaseDoc):
10561058
inner_dict = {}
10571059

10581060
fields = [
@@ -1125,7 +1127,7 @@ def _find_subdocs(
11251127
) -> FindResult:
11261128
"""Find documents in the subindex and return subindex docs and scores."""
11271129
fields = subindex.split('__')
1128-
if not subindex or not issubclass(
1130+
if not subindex or not safe_issubclass(
11291131
self._schema._get_field_type(fields[0]), AnyDocArray # type: ignore
11301132
):
11311133
raise ValueError(f'subindex {subindex} is not valid')

docarray/index/backends/elastic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(self, db_config=None, **kwargs):
9595
self._logger.debug('Mappings have been updated with db_config.index_mappings')
9696

9797
for col_name, col in self._column_infos.items():
98-
if issubclass(col.docarray_type, AnyDocArray):
98+
if safe_issubclass(col.docarray_type, AnyDocArray):
9999
continue
100100
if col.db_type == 'dense_vector' and (
101101
not col.n_dim and col.config['dims'] < 0
@@ -336,7 +336,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
336336
self._logger.debug(f'Mapping Python type {python_type} to database type')
337337

338338
for allowed_type in ELASTIC_PY_VEC_TYPES:
339-
if issubclass(python_type, allowed_type):
339+
if safe_issubclass(python_type, allowed_type):
340340
self._logger.info(
341341
f'Mapped Python type {python_type} to database type "dense_vector"'
342342
)
@@ -354,7 +354,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
354354
}
355355

356356
for type in elastic_py_types.keys():
357-
if issubclass(python_type, type):
357+
if safe_issubclass(python_type, type):
358358
self._logger.info(
359359
f'Mapped Python type {python_type} to database type "{elastic_py_types[type]}"'
360360
)
@@ -381,7 +381,7 @@ def _index(
381381
'_id': row['id'],
382382
}
383383
for col_name, col in self._column_infos.items():
384-
if issubclass(col.docarray_type, AnyDocArray):
384+
if safe_issubclass(col.docarray_type, AnyDocArray):
385385
continue
386386
if col.db_type == 'dense_vector' and np.all(row[col_name] == 0):
387387
row[col_name] = row[col_name] + 1.0e-9

docarray/index/backends/qdrant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
265265
if any(issubclass(python_type, vt) for vt in QDRANT_PY_VECTOR_TYPES):
266266
return 'vector'
267267

268-
if issubclass(python_type, docarray.typing.id.ID):
268+
if safe_issubclass(python_type, docarray.typing.id.ID):
269269
return 'id'
270270

271271
return 'payload'
@@ -587,7 +587,7 @@ def _build_point_from_row(self, row: Dict[str, Any]) -> rest.PointStruct:
587587
vectors: Dict[str, List[float]] = {}
588588
payload: Dict[str, Any] = {'__generated_vectors': []}
589589
for column_name, column_info in self._column_infos.items():
590-
if issubclass(column_info.docarray_type, AnyDocArray):
590+
if safe_issubclass(column_info.docarray_type, AnyDocArray):
591591
continue
592592
if column_info.db_type in ['id', 'payload']:
593593
payload[column_name] = row.get(column_name)

0 commit comments

Comments
 (0)