Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docarray/array/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from docarray.display.document_array_summary import DocArraySummary
from docarray.exceptions.exceptions import UnusableObjectError
from docarray.typing.abstract_type import AbstractType
from docarray.utils._internal._typing import change_cls_name
from docarray.utils._internal._typing import change_cls_name, safe_issubclass

if TYPE_CHECKING:
from docarray.proto import DocListProto, NodeProto
Expand Down Expand Up @@ -53,7 +53,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
if not isinstance(item, type):
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
# this do nothing that checking that item is valid type var or str
if not issubclass(item, BaseDoc):
if not safe_issubclass(item, BaseDoc):
raise ValueError(
f'{cls.__name__}[item] item should be a Document not a {item} '
)
Expand Down
11 changes: 7 additions & 4 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray
from docarray.utils._internal._typing import safe_issubclass

if TYPE_CHECKING:
from pydantic import BaseConfig
Expand Down Expand Up @@ -158,7 +159,9 @@ def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:

def _validate_one_doc(self, doc: T_doc) -> T_doc:
"""Validate if a Document is compatible with this `DocList`"""
if not issubclass(self.doc_type, AnyDoc) and not isinstance(doc, self.doc_type):
if not safe_issubclass(self.doc_type, AnyDoc) and not isinstance(
doc, self.doc_type
):
raise ValueError(f'{doc} is not a {self.doc_type}')
return doc

Expand Down Expand Up @@ -218,7 +221,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
not is_union_type(field_type)
and self.__class__.doc_type.__fields__[field].required
and isinstance(field_type, type)
and issubclass(field_type, BaseDoc)
and safe_issubclass(field_type, BaseDoc)
):
# calling __class_getitem__ ourselves is a hack otherwise mypy complain
# most likely a bug in mypy though
Expand Down Expand Up @@ -272,7 +275,7 @@ def validate(
return value
elif isinstance(value, DocVec):
if (
issubclass(value.doc_type, cls.doc_type)
safe_issubclass(value.doc_type, cls.doc_type)
or value.doc_type == cls.doc_type
):
return cast(T, value.to_doc_list())
Expand Down Expand Up @@ -326,7 +329,7 @@ def __getitem__(self, item):
@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):

if isinstance(item, type) and issubclass(item, BaseDoc):
if isinstance(item, type) and safe_issubclass(item, BaseDoc):
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
else:
return super().__class_getitem__(item)
Expand Down
22 changes: 11 additions & 11 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from docarray.base_doc.mixins.io import _type_to_protobuf
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union
from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
from docarray.utils._internal.misc import is_tf_available, is_torch_available

if TYPE_CHECKING:
Expand Down Expand Up @@ -198,7 +198,7 @@ def __init__(
for field_name, field in self.doc_type.__fields__.items():
# here we iterate over the field of the docs schema, and we collect the data
# from each document and put them in the corresponding column
field_type = self.doc_type._get_field_type(field_name)
field_type: Type = self.doc_type._get_field_type(field_name)

is_field_required = self.doc_type.__fields__[field_name].required

Expand Down Expand Up @@ -231,19 +231,19 @@ def _check_doc_field_not_none(field_name, doc):
field_type = tensor_type
# all generic tensor types such as AnyTensor, ImageTensor, etc. are subclasses of AbstractTensor.
# Perform check only if the field_type is not an alias and is a subclass of AbstractTensor
elif not isinstance(field_type, typingGenericAlias) and issubclass(
elif not isinstance(field_type, typingGenericAlias) and safe_issubclass(
field_type, AbstractTensor
):
# check if the tensor associated with the field_name in the document is a subclass of the tensor_type
# e.g. if the field_type is AnyTensor but the type(docs[0][field_name]) is ImageTensor,
# then we change the field_type to ImageTensor, since AnyTensor is a union of all the tensor types
# and does not override any methods of specific tensor types
tensor = getattr(docs[0], field_name)
if issubclass(tensor.__class__, tensor_type):
if safe_issubclass(tensor.__class__, tensor_type):
field_type = tensor_type

if isinstance(field_type, type):
if tf_available and issubclass(field_type, TensorFlowTensor):
if tf_available and safe_issubclass(field_type, TensorFlowTensor):
# tf.Tensor does not allow item assignment, therefore the
# optimized way
# of initializing an empty array and assigning values to it
Expand All @@ -263,7 +263,7 @@ def _check_doc_field_not_none(field_name, doc):
stacked: tf.Tensor = tf.stack(tf_stack)
tensor_columns[field_name] = TensorFlowTensor(stacked)

elif issubclass(field_type, AbstractTensor):
elif safe_issubclass(field_type, AbstractTensor):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
tensor_columns[field_name] = None
Expand Down Expand Up @@ -291,7 +291,7 @@ def _check_doc_field_not_none(field_name, doc):
val = getattr(doc, field_name)
cast(AbstractTensor, tensor_columns[field_name])[i] = val

elif issubclass(field_type, BaseDoc):
elif safe_issubclass(field_type, BaseDoc):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
doc_columns[field_name] = None
Expand All @@ -307,7 +307,7 @@ def _check_doc_field_not_none(field_name, doc):
tensor_type=self.tensor_type
)

elif issubclass(field_type, AnyDocArray):
elif safe_issubclass(field_type, AnyDocArray):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
docs_vec_columns[field_name] = None
Expand Down Expand Up @@ -362,7 +362,7 @@ def validate(
return value
elif isinstance(value, DocList):
if (
issubclass(value.doc_type, cls.doc_type)
safe_issubclass(value.doc_type, cls.doc_type)
or value.doc_type == cls.doc_type
):
return cast(T, value.to_doc_vec())
Expand Down Expand Up @@ -481,7 +481,7 @@ def _set_data_and_columns(
# set data and prepare columns
processed_value: T
if isinstance(value, DocList):
if not issubclass(value.doc_type, self.doc_type):
if not safe_issubclass(value.doc_type, self.doc_type):
raise TypeError(
f'{value} schema : {value.doc_type} is not compatible with '
f'this DocVec schema : {self.doc_type}'
Expand All @@ -491,7 +491,7 @@ def _set_data_and_columns(
) # we need to copy data here

elif isinstance(value, DocVec):
if not issubclass(value.doc_type, self.doc_type):
if not safe_issubclass(value.doc_type, self.doc_type):
raise TypeError(
f'{value} schema : {value.doc_type} is not compatible with '
f'this DocVec schema : {self.doc_type}'
Expand Down
3 changes: 2 additions & 1 deletion docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from docarray.base_doc.mixins import IOMixin, UpdateMixin
from docarray.typing import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import safe_issubclass

if TYPE_CHECKING:
from pydantic import Protocol
Expand Down Expand Up @@ -351,7 +352,7 @@ def _exclude_docarray(

type_ = self._get_field_type(field)
if isinstance(type_, type) and (
issubclass(type_, DocList) or issubclass(type_, DocVec)
safe_issubclass(type_, DocList) or safe_issubclass(type_, DocVec)
):
docarray_exclude_fields.append(field)

Expand Down
5 changes: 4 additions & 1 deletion docarray/base_doc/mixins/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar

from typing_inspect import get_origin
from docarray.utils._internal._typing import safe_issubclass

T = TypeVar('T', bound='UpdateMixin')

Expand Down Expand Up @@ -108,7 +109,9 @@ def _group_fields(doc: 'UpdateMixin') -> _FieldGroups:
if field_name not in FORBIDDEN_FIELDS_TO_UPDATE:
field_type = doc._get_field_type(field_name)

if isinstance(field_type, type) and issubclass(field_type, DocList):
if isinstance(field_type, type) and safe_issubclass(
field_type, DocList
):
nested_docarray_fields.append(field_name)
else:
origin = get_origin(field_type)
Expand Down
4 changes: 2 additions & 2 deletions docarray/data/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from docarray import BaseDoc, DocList, DocVec
from docarray.typing import TorchTensor
from docarray.utils._internal._typing import change_cls_name
from docarray.utils._internal._typing import change_cls_name, safe_issubclass

T_doc = TypeVar('T_doc', bound=BaseDoc)

Expand Down Expand Up @@ -141,7 +141,7 @@ def collate_fn(cls, batch: List[T_doc]):

@classmethod
def __class_getitem__(cls, item: Type[BaseDoc]) -> Type['MultiModalDataset']:
if not issubclass(item, BaseDoc):
if not safe_issubclass(item, BaseDoc):
raise ValueError(
f'{cls.__name__}[item] item should be a Document not a {item} '
)
Expand Down
15 changes: 8 additions & 7 deletions docarray/documents/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import create_model, create_model_from_typeddict
from pydantic.config import BaseConfig
from typing_extensions import TypedDict
from docarray.utils._internal._typing import safe_issubclass

from docarray import BaseDoc

Expand Down Expand Up @@ -38,8 +39,8 @@ def create_doc(
tensor=(AudioNdArray, ...),
)

assert issubclass(MyAudio, BaseDoc)
assert issubclass(MyAudio, Audio)
assert safe_issubclass(MyAudio, BaseDoc)
assert safe_issubclass(MyAudio, Audio)
```

:param __model_name: name of the created model
Expand All @@ -54,7 +55,7 @@ def create_doc(
:return: the new Document class
"""

if not issubclass(__base__, BaseDoc):
if not safe_issubclass(__base__, BaseDoc):
raise ValueError(f'{type(__base__)} is not a BaseDoc or its subclass')

doc = create_model(
Expand Down Expand Up @@ -96,8 +97,8 @@ class MyAudio(TypedDict):

Doc = create_doc_from_typeddict(MyAudio, __base__=Audio)

assert issubclass(Doc, BaseDoc)
assert issubclass(Doc, Audio)
assert safe_issubclass(Doc, BaseDoc)
assert safe_issubclass(Doc, Audio)
```

---
Expand All @@ -108,7 +109,7 @@ class MyAudio(TypedDict):
"""

if '__base__' in kwargs:
if not issubclass(kwargs['__base__'], BaseDoc):
if not safe_issubclass(kwargs['__base__'], BaseDoc):
raise ValueError(f'{kwargs["__base__"]} is not a BaseDoc or its subclass')
else:
kwargs['__base__'] = BaseDoc
Expand Down Expand Up @@ -136,7 +137,7 @@ def create_doc_from_dict(model_name: str, data_dict: Dict[str, Any]) -> Type['T_

MyDoc = create_doc_from_dict(model_name='MyDoc', data_dict=data_dict)

assert issubclass(MyDoc, BaseDoc)
assert safe_issubclass(MyDoc, BaseDoc)
```

---
Expand Down
16 changes: 9 additions & 7 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def __getitem__(
for field_name, type_, _ in self._flatten_schema(
cast(Type[BaseDoc], self._schema)
):
if issubclass(type_, AnyDocArray) and isinstance(doc_sequence[0], Dict):
if safe_issubclass(type_, AnyDocArray) and isinstance(
doc_sequence[0], Dict
):
for doc in doc_sequence:
self._get_subindex_doclist(doc, field_name) # type: ignore

Expand Down Expand Up @@ -534,7 +536,7 @@ def find_batched(
if search_field:
if '__' in search_field:
fields = search_field.split('__')
if issubclass(self._schema._get_field_type(fields[0]), AnyDocArray): # type: ignore
if safe_issubclass(self._schema._get_field_type(fields[0]), AnyDocArray): # type: ignore
return self._subindices[fields[0]].find_batched(
queries,
search_field='__'.join(fields[1:]),
Expand Down Expand Up @@ -799,7 +801,7 @@ def __class_getitem__(cls, item: Type[TSchema]):
# do nothing
# enables use in static contexts with type vars, e.g. as type annotation
return Generic.__class_getitem__.__func__(cls, item)
if not issubclass(item, BaseDoc):
if not safe_issubclass(item, BaseDoc):
raise ValueError(
f'{cls.__name__}[item] `item` should be a Document not a {item} '
)
Expand Down Expand Up @@ -849,7 +851,7 @@ def _flatten_schema(
# treat as if it was a single non-optional type
for t_arg in union_args:
if t_arg is not type(None):
if issubclass(t_arg, BaseDoc):
if safe_issubclass(t_arg, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_arg, name_prefix=inner_prefix)
)
Expand Down Expand Up @@ -1044,15 +1046,15 @@ def _convert_dict_to_doc(
for field_name, _ in schema.__fields__.items():
t_ = schema._get_field_type(field_name)

if not is_union_type(t_) and issubclass(t_, AnyDocArray):
if not is_union_type(t_) and safe_issubclass(t_, AnyDocArray):
self._get_subindex_doclist(doc_dict, field_name)

if is_optional_type(t_):
for t_arg in get_args(t_):
if t_arg is not type(None):
t_ = t_arg

if not is_union_type(t_) and issubclass(t_, BaseDoc):
if not is_union_type(t_) and safe_issubclass(t_, BaseDoc):
inner_dict = {}

fields = [
Expand Down Expand Up @@ -1125,7 +1127,7 @@ def _find_subdocs(
) -> FindResult:
"""Find documents in the subindex and return subindex docs and scores."""
fields = subindex.split('__')
if not subindex or not issubclass(
if not subindex or not safe_issubclass(
self._schema._get_field_type(fields[0]), AnyDocArray # type: ignore
):
raise ValueError(f'subindex {subindex} is not valid')
Expand Down
8 changes: 4 additions & 4 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, db_config=None, **kwargs):
self._logger.debug('Mappings have been updated with db_config.index_mappings')

for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
if safe_issubclass(col.docarray_type, AnyDocArray):
continue
if col.db_type == 'dense_vector' and (
not col.n_dim and col.config['dims'] < 0
Expand Down Expand Up @@ -336,7 +336,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
self._logger.debug(f'Mapping Python type {python_type} to database type')

for allowed_type in ELASTIC_PY_VEC_TYPES:
if issubclass(python_type, allowed_type):
if safe_issubclass(python_type, allowed_type):
self._logger.info(
f'Mapped Python type {python_type} to database type "dense_vector"'
)
Expand All @@ -354,7 +354,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
}

for type in elastic_py_types.keys():
if issubclass(python_type, type):
if safe_issubclass(python_type, type):
self._logger.info(
f'Mapped Python type {python_type} to database type "{elastic_py_types[type]}"'
)
Expand All @@ -381,7 +381,7 @@ def _index(
'_id': row['id'],
}
for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
if safe_issubclass(col.docarray_type, AnyDocArray):
continue
if col.db_type == 'dense_vector' and np.all(row[col_name] == 0):
row[col_name] = row[col_name] + 1.0e-9
Expand Down
4 changes: 2 additions & 2 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
if any(issubclass(python_type, vt) for vt in QDRANT_PY_VECTOR_TYPES):
return 'vector'

if issubclass(python_type, docarray.typing.id.ID):
if safe_issubclass(python_type, docarray.typing.id.ID):
return 'id'

return 'payload'
Expand Down Expand Up @@ -587,7 +587,7 @@ def _build_point_from_row(self, row: Dict[str, Any]) -> rest.PointStruct:
vectors: Dict[str, List[float]] = {}
payload: Dict[str, Any] = {'__generated_vectors': []}
for column_name, column_info in self._column_infos.items():
if issubclass(column_info.docarray_type, AnyDocArray):
if safe_issubclass(column_info.docarray_type, AnyDocArray):
continue
if column_info.db_type in ['id', 'payload']:
payload[column_name] = row.get(column_name)
Expand Down
Loading