3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
@@ -94,7 +94,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.8]
+ python-version: [3.9]
pydantic-version: ["pydantic-v2", "pydantic-v1"]
test-path: [tests/integrations, tests/units, tests/documentation]
steps:
@@ -112,6 +112,7 @@ jobs:
./scripts/install_pydantic_v2.sh ${{ matrix.pydantic-version }}
poetry run pip uninstall -y torch
poetry run pip install torch
+ poetry run pip install numpy==1.26.1
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
10 changes: 5 additions & 5 deletions docarray/base_doc/doc.py
@@ -22,7 +22,7 @@
import typing_extensions
from pydantic import BaseModel, Field
from pydantic.fields import FieldInfo
- from typing_inspect import is_optional_type
+ from typing_inspect import get_args, is_optional_type

from docarray.utils._internal.pydantic import is_pydantic_v2

@@ -185,7 +185,7 @@ def _get_field_annotation(cls, field: str) -> Type:
if is_optional_type(
annotation
): # this is equivalent to `outer_type_` in pydantic v1
- return annotation.__args__[0]
+ return get_args(annotation)[0]
else:
return annotation
else:
@@ -205,12 +205,12 @@ def _get_field_inner_type(cls, field: str) -> Type:
if is_optional_type(
annotation
): # this is equivalent to `outer_type_` in pydantic v1
- return annotation.__args__[0]
+ return get_args(annotation)[0]
elif annotation == Tuple:
- if len(annotation.__args__) == 0:
+ if len(get_args(annotation)) == 0:
return Any
else:
- annotation.__args__[0]
+ get_args(annotation)[0]
else:
return annotation
else:
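This and the similar `__args__` → `get_args` swaps in the files below read type arguments through `typing_inspect.get_args` instead of touching `__args__` directly. As a rough illustration (not part of the diff), the stdlib `typing.get_args`, which behaves the same way for these simple cases, degrades gracefully on unparametrized forms where a direct attribute lookup can fail:

```python
# Sketch only: why reading type arguments via get_args() is safer than
# accessing __args__ directly. Uses typing.get_args; typing_inspect.get_args
# (imported in the diff) gives the same results for these cases.
from typing import Optional, Tuple, get_args

print(get_args(Optional[str]))    # (<class 'str'>, <class 'NoneType'>)
print(get_args(Tuple[int, str]))  # (<class 'int'>, <class 'str'>)
print(get_args(Tuple))            # () -- no args, no AttributeError

# Direct attribute access on the bare special form may blow up instead:
# Tuple.__args__  ->  AttributeError on recent Python versions
```

Returning an empty tuple instead of raising is what lets `_get_field_inner_type` fall back to `Any` for a bare `Tuple` annotation in the hunk above.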
2 changes: 1 addition & 1 deletion docarray/base_doc/mixins/io.py
@@ -336,7 +336,7 @@ def _get_content_from_node_proto(
field_type = None

if isinstance(field_type, GenericAlias):
- field_type = field_type.__args__[0]
+ field_type = get_args(field_type)[0]

return_field = arg_to_container[content_key](
cls._get_content_from_node_proto(node, field_type=field_type)
4 changes: 2 additions & 2 deletions docarray/display/document_summary.py
@@ -1,4 +1,4 @@
- from typing import Any, List, Optional, Type, Union
+ from typing import Any, List, Optional, Type, Union, get_args

from rich.highlighter import RegexHighlighter
from rich.theme import Theme
@@ -83,7 +83,7 @@ def _get_schema(

if is_union_type(field_type) or is_optional_type(field_type):
sub_tree = Tree(node_name, highlight=True)
- for arg in field_type.__args__:
+ for arg in get_args(field_type):
if safe_issubclass(arg, BaseDoc):
sub_tree.add(
DocumentSummary._get_schema(
47 changes: 46 additions & 1 deletion docarray/helper.py
@@ -15,7 +15,23 @@
Union,
)

+ import numpy as np
+
from docarray.utils._internal._typing import safe_issubclass
+ from docarray.utils._internal.misc import (
+     is_jax_available,
+     is_tf_available,
+     is_torch_available,
+ )
+
+ if is_torch_available():
+     import torch
+
+ if is_jax_available():
+     import jax
+
+ if is_tf_available():
+     import tensorflow as tf

if TYPE_CHECKING:
from docarray import BaseDoc
@@ -54,6 +70,35 @@ def _access_path_to_dict(access_path: str, value) -> Dict[str, Any]:
return result


+ def _is_none_like(val: Any) -> bool:
+     """
+     :param val: any value
+     :return: True iff `val` equals `None`, `'None'` or `''`
+     """
+     # Convoluted implementation, but fixes https://github.com/docarray/docarray/issues/1821
+
+     # tensor-like types can have unexpected (= broadcast) `==`/`in` semantics,
+     # so treat separately
+     is_np_arr = isinstance(val, np.ndarray)
+     if is_np_arr:
+         return False
+
+     is_torch_tens = is_torch_available() and isinstance(val, torch.Tensor)
+     if is_torch_tens:
+         return False
+
+     is_tf_tens = is_tf_available() and isinstance(val, tf.Tensor)
+     if is_tf_tens:
+         return False
+
+     is_jax_arr = is_jax_available() and isinstance(val, jax.numpy.ndarray)
+     if is_jax_arr:
+         return False
+
+     # "normal" case
+     return val in ['', 'None', None]
+
+
def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]:
"""
Convert a dict, where the keys are access paths ("__"-separated) to a nested dictionary.
@@ -76,7 +121,7 @@ def _access_path_dict_to_nested_dict(access_path2val: Dict[str, Any]) -> Dict[Any, Any]:
for access_path, value in access_path2val.items():
field2val = _access_path_to_dict(
access_path=access_path,
- value=value if value not in ['', 'None'] else None,
+ value=None if _is_none_like(value) else value,
)
_update_nested_dicts(to_update=nested_dict, update_with=field2val)
return nested_dict
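The new `_is_none_like` helper exists because `val in ['', 'None', None]` delegates to `==`, and NumPy/PyTorch/TensorFlow/JAX tensors overload `==` element-wise, so the membership test cannot be collapsed to a single bool. A minimal sketch of the failure mode, using plain NumPy only (illustration, not part of the diff):

```python
# Sketch of the bug fixed here (docarray issue #1821): membership tests on
# tensor-like values go through broadcast `==`, which does not reduce to a
# single truth value. Assumes only NumPy is installed.
import numpy as np

val = np.zeros(3)
try:
    val in ['', 'None', None]  # `in` compares val against each list element
except ValueError as err:
    # Typically: "The truth value of an array with more than one element is
    # ambiguous..." (exact behaviour varies slightly across NumPy versions).
    print(err)

# _is_none_like (added above) returns False for ndarray/tensor inputs before
# the `in` check is ever reached, so such values are passed through unchanged.
```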
4 changes: 2 additions & 2 deletions tests/units/typing/tensor/test_ndarray.py
@@ -200,9 +200,9 @@ def test_parametrized_instance():
def test_parametrized_equality():
t1 = parse_obj_as(NdArray[128], np.zeros(128))
t2 = parse_obj_as(NdArray[128], np.zeros(128))
- t3 = parse_obj_as(NdArray[256], np.zeros(256))
+ t3 = parse_obj_as(NdArray[128], np.ones(128))
assert (t1 == t2).all()
- assert not t1 == t3
+ assert not (t1 == t3).any()


def test_parametrized_operations():
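For context on the assertion change (illustration only, not part of the diff): `==` on same-shaped arrays is element-wise, so the result has to be reduced with `.any()`/`.all()` before it can drive an `assert`; the old `assert not t1 == t3` relied on comparing arrays of different shapes, where NumPy's behaviour is deprecated and version-dependent.

```python
# Sketch with plain NumPy (docarray's NdArray is ndarray-backed, so it behaves
# the same way): element-wise equality must be reduced before asserting on it.
import numpy as np

a = np.zeros(128)
b = np.ones(128)

assert not (a == b).any()   # no element matches -> the arrays differ everywhere
# `assert not a == b` would instead call bool() on a 128-element boolean array
# and raise "The truth value of an array ... is ambiguous".
```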