Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ jobs:
poetry install --without dev
poetry run pip install tensorflow==2.12.0
poetry run pip install jax
poetry run pip uninstall -y torch
poetry run pip install torch
- name: Test basic import
run: poetry run python -c 'from docarray import DocList, BaseDoc'

Expand Down Expand Up @@ -106,6 +108,8 @@ jobs:
python -m pip install poetry
poetry install --all-extras
poetry run pip install elasticsearch==8.6.2
poetry run pip uninstall -y torch
poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg

Expand Down Expand Up @@ -153,6 +157,8 @@ jobs:
python -m pip install poetry
poetry install --all-extras
poetry run pip install elasticsearch==8.6.2
poetry run pip uninstall -y torch
poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg

Expand Down Expand Up @@ -199,6 +205,8 @@ jobs:
python -m pip install poetry
poetry install --all-extras
poetry run pip install protobuf==3.20.0 # we check that we support 3.19
poetry run pip uninstall -y torch
poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg
- name: Test
Expand Down Expand Up @@ -244,6 +252,8 @@ jobs:
poetry install --all-extras
poetry run pip install protobuf==3.20.0
poetry run pip install tensorflow==2.12.0
poetry run pip uninstall -y torch
poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg

Expand Down Expand Up @@ -290,6 +300,8 @@ jobs:
poetry run pip install protobuf==3.20.0
poetry run pip install tensorflow==2.12.0
poetry run pip install elasticsearch==8.6.2
poetry run pip uninstall -y torch
poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg

Expand Down Expand Up @@ -334,6 +346,8 @@ jobs:
poetry install --all-extras
poetry run pip install protobuf==3.20.0
poetry run pip install tensorflow==2.12.0
poetry run pip uninstall -y torch
poetry run pip install torch
sudo apt-get update
sudo apt-get install --no-install-recommends ffmpeg

Expand Down Expand Up @@ -376,6 +390,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install poetry
poetry install --all-extras
poetry run pip uninstall -y torch
poetry run pip install torch
poetry run pip install jaxlib
poetry run pip install jax

Expand Down Expand Up @@ -420,6 +436,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install poetry
poetry install --all-extras
poetry run pip uninstall -y torch
poetry run pip install torch

- name: Test
id: test
Expand Down
41 changes: 40 additions & 1 deletion docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TorchTensor(
"""
Subclass of `torch.Tensor`, intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
and coersion from compatible types like numpy.ndarray.
and coercion from compatible types like numpy.ndarray.

This type can also be used in a parametrized way,
specifying the shape of the tensor.
Expand Down Expand Up @@ -112,6 +112,45 @@ class MyDoc(BaseDoc):
```

---


## Compatibility with `torch.compile()`


PyTorch 2 [introduced compilation support](https://pytorch.org/blog/pytorch-2.0-release/) in the form of `torch.compile()`.

Currently, **`torch.compile()` does not properly support subclasses of `torch.Tensor` such as `TorchTensor`**.
The PyTorch team is currently working on a [fix for this issue](https://github.com/pytorch/pytorch/pull/105167#issuecomment-1678050808).

In the meantime, you can use the following workaround:

### Workaround: Convert `TorchTensor` to `torch.Tensor` before calling `torch.compile()`

Converting any `TorchTensor`s tor `torch.Tensor` before calling `torch.compile()` side-steps the issue:

```python
from docarray import BaseDoc
from docarray.typing import TorchTensor
import torch


class MyDoc(BaseDoc):
tensor: TorchTensor


doc = MyDoc(tensor=torch.zeros(128))


def foo(tensor: torch.Tensor):
return tensor @ tensor.t()


foo_compiled = torch.compile(foo)

# unwrap the tensor before passing it to torch.compile()
foo_compiled(doc.tensor.unwrap())
```

"""

__parametrized_meta__ = metaTorchAndNode
Expand Down
9 changes: 9 additions & 0 deletions docs/data_types/tensor/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,12 @@ assert isinstance(docs.tensor, NdArray)

- you don't specify the `tensor_type` parameter
- your tensor field is a Union of tensor or [`AnyTensor`][docarray.typing.tensor.AnyTensor]

## Compatibility of `TorchTensor` and `torch.compile()`

PyTorch 2 [introduced compilation support](https://pytorch.org/blog/pytorch-2.0-release/) in the form of `torch.compile()`.

Currently, **`torch.compile()` does not properly support subclasses of `torch.Tensor` such as [`TorchTensor`][docarray.typing.tensor.TorchTensor]**.
The PyTorch team is currently working on a [fix for this issue](https://github.com/pytorch/pytorch/pull/105167#issuecomment-1678050808).

For a workaround to this issue, see the [`TorchTensor` API reference][docarray.typing.tensor.TorchTensor].
141 changes: 30 additions & 111 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.