3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
@@ -94,7 +94,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.8]
+        python-version: [3.9]
         pydantic-version: ["pydantic-v2", "pydantic-v1"]
         test-path: [tests/integrations, tests/units, tests/documentation]
     steps:
@@ -112,6 +112,7 @@ jobs:
           ./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
           poetry run pip uninstall -y torch
           poetry run pip install torch
+          poetry run pip install numpy==1.26.1
           sudo apt-get update
           sudo apt-get install --no-install-recommends ffmpeg

10 changes: 5 additions & 5 deletions docarray/base_doc/doc.py
@@ -22,7 +22,7 @@
 import typing_extensions
 from pydantic import BaseModel, Field
 from pydantic.fields import FieldInfo
-from typing_inspect import is_optional_type
+from typing_inspect import get_args, is_optional_type

 from docarray.utils._internal.pydantic import is_pydantic_v2

@@ -185,7 +185,7 @@ def _get_field_annotation(cls, field: str) -> Type:
             if is_optional_type(
                 annotation
             ): # this is equivalent to `outer_type_` in pydantic v1
-                return annotation.__args__[0]
+                return get_args(annotation)[0]
             else:
                 return annotation
         else:
@@ -205,12 +205,12 @@ def _get_field_inner_type(cls, field: str) -> Type:
             if is_optional_type(
                 annotation
             ): # this is equivalent to `outer_type_` in pydantic v1
-                return annotation.__args__[0]
+                return get_args(annotation)[0]
             elif annotation == Tuple:
-                if len(annotation.__args__) == 0:
+                if len(get_args(annotation)) == 0:
                     return Any
                 else:
-                    annotation.__args__[0]
+                    get_args(annotation)[0]
             else:
                 return annotation
         else:

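Note (not part of the diff): switching from reading `__args__` directly to `get_args` matters because unparametrized annotations generally carry no usable `__args__`, while `get_args` degrades gracefully to an empty tuple — which is what the `annotation == Tuple` branch above relies on. A minimal sketch of the difference, shown here with the stdlib `typing.get_args` (`typing_inspect.get_args`, used in this file, behaves the same for these cases):

    from typing import List, Optional, Tuple, get_args

    # Parametrized annotations expose their arguments either way:
    assert get_args(Optional[str]) == (str, type(None))
    assert Optional[str].__args__ == (str, type(None))
    assert get_args(List[int]) == (int,)

    # Unparametrized annotations have no `__args__` attribute at all,
    # but `get_args` simply returns an empty tuple instead of raising:
    assert get_args(int) == ()
    assert get_args(Tuple) == ()
    assert not hasattr(int, '__args__')
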
2 changes: 1 addition & 1 deletion docarray/base_doc/mixins/io.py
@@ -336,7 +336,7 @@ def _get_content_from_node_proto(
                 field_type = None

             if isinstance(field_type, GenericAlias):
-                field_type = field_type.__args__[0]
+                field_type = get_args(field_type)[0]

             return_field = arg_to_container[content_key](
                 cls._get_content_from_node_proto(node, field_type=field_type)

4 changes: 2 additions & 2 deletions docarray/display/document_summary.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union, get_args

 from rich.highlighter import RegexHighlighter
 from rich.theme import Theme
@@ -83,7 +83,7 @@ def _get_schema(

         if is_union_type(field_type) or is_optional_type(field_type):
             sub_tree = Tree(node_name, highlight=True)
-            for arg in field_type.__args__:
+            for arg in get_args(field_type):
                 if safe_issubclass(arg, BaseDoc):
                     sub_tree.add(
                         DocumentSummary._get_schema(

47 changes: 46 additions & 1 deletion docarray/helper.py
@@ -15,7 +15,23 @@
     Union,
 )

+import numpy as np
+
 from docarray.utils._internal._typing import safe_issubclass
+from docarray.utils._internal.misc import (
+    is_jax_available,
+    is_tf_available,
+    is_torch_available,
+)
+
+if is_torch_available():
+    import torch
+
+if is_jax_available():
+    import jax
+
+if is_tf_available():
+    import tensorflow as tf

 if TYPE_CHECKING:
     from docarray import BaseDoc
@@ -54,6 +70,35 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
     return result


+def _is_none_like(val: Any) -> bool:
+    """
+    :param val: any value
+    :return: true iff `val` equals to `None`, `'None'` or `''`
+    """
+    # Convoluted implementation, but fixes https://github.com/docarray/docarray/issues/1821
+
+    # tensor-like types can have unexpected (= broadcast) `==`/`in` semantics,
+    # so treat separately
+    is_np_arr = isinstance(val, np.ndarray)
+    if is_np_arr:
+        return False
+
+    is_torch_tens = is_torch_available() and isinstance(val, torch.Tensor)
+    if is_torch_tens:
+        return False
+
+    is_tf_tens = is_tf_available() and isinstance(val, tf.Tensor)
+    if is_tf_tens:
+        return False
+
+    is_jax_arr = is_jax_available() and isinstance(val, jax.numpy.ndarray)
+    if is_jax_arr:
+        return False
+
+    # "normal" case
+    return val in ['', 'None', None]
+
+
 def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]:
     """
     Convert a dict, where the keys are access paths ("__"-separated) to a nested dictionary.
@@ -76,7 +121,7 @@ def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]:
     for access_path, value in access_path2val.items():
         field2val = _access_path_to_dict(
             access_path=access_path,
-            value=value if value not in ['', 'None'] else None,
+            value=None if _is_none_like(value) else value,
         )
         _update_nested_dicts(to_update=nested_dict, update_with=field2val)
     return nested_dict

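Note (not part of the diff): the replaced check `value not in ['', 'None']` is fragile once `value` is a tensor, because array-like `==` is element-wise and the membership test then has to collapse an element-wise result into a single bool (see the comment in `_is_none_like` above and issue #1821). The new helper short-circuits to `False` for ndarray/tensor inputs before ever reaching the `in` check. A minimal NumPy-only sketch of the underlying problem:

    import numpy as np

    arr = np.zeros(3)

    # Comparing a tensor-like value against None is element-wise,
    # not a single truth value (deliberate `== None` to mirror the check):
    print(arr == None)  # [False False False]

    # Collapsing that into one bool, as a membership test must, is ambiguous:
    try:
        bool(arr == None)
    except ValueError as err:
        print(err)  # "The truth value of an array with more than one element is ambiguous..."
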
11 changes: 8 additions & 3 deletions tests/units/array/test_array_from_to_pandas.py
@@ -136,7 +136,8 @@ class BasisUnion(BaseDoc):


 @pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor])
-def test_from_to_pandas_tensor_type(tensor_type):
+@pytest.mark.parametrize('tensor_len', [0, 5])
+def test_from_to_pandas_tensor_type(tensor_type, tensor_len):
     class MyDoc(BaseDoc):
         embedding: tensor_type
         text: str
@@ -145,9 +146,13 @@ class MyDoc(BaseDoc):
     da = DocVec[MyDoc](
         [
             MyDoc(
-                embedding=[1, 2, 3, 4, 5], text='hello', image=ImageDoc(url='aux.png')
+                embedding=list(range(tensor_len)),
+                text='hello',
+                image=ImageDoc(url='aux.png'),
             ),
+            MyDoc(
+                embedding=list(range(tensor_len)), text='hello world', image=ImageDoc()
+            ),
-            MyDoc(embedding=[5, 4, 3, 2, 1], text='hello world', image=ImageDoc()),
         ],
         tensor_type=tensor_type,
     )

4 changes: 2 additions & 2 deletions tests/units/typing/tensor/test_ndarray.py
@@ -200,9 +200,9 @@ def test_parametrized_instance():
 def test_parametrized_equality():
     t1 = parse_obj_as(NdArray[128], np.zeros(128))
     t2 = parse_obj_as(NdArray[128], np.zeros(128))
-    t3 = parse_obj_as(NdArray[256], np.zeros(256))
+    t3 = parse_obj_as(NdArray[128], np.ones(128))
     assert (t1 == t2).all()
-    assert not t1 == t3
+    assert not (t1 == t3).any()


 def test_parametrized_operations():

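Note (not part of the diff): with two same-shape arrays, `t1 == t3` is an element-wise boolean array, so `assert not t1 == t3` would need the truth value of the whole array; the original version only sidestepped that because the shapes (128 vs 256) did not match, which makes the comparison degenerate or raise depending on the NumPy version. Asserting `not (t1 == t3).any()` instead checks that no element is equal. A plain-NumPy sketch:

    import numpy as np

    a = np.zeros(128)
    b = np.ones(128)

    # A single, well-defined bool: no element of `a` equals the corresponding one in `b`
    assert not (a == b).any()

    # Whereas `not a == b` asks for the truth value of a 128-element boolean array:
    try:
        assert not a == b
    except ValueError as err:
        print(err)  # "The truth value of an array with more than one element is ambiguous..."
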