Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
docs: add workaround for torch compile
Signed-off-by: Johannes Messner <[email protected]>
  • Loading branch information
JohannesMessner committed Aug 15, 2023
commit 49eb47b30c745c6813b91f8afde5702b5d882b8f
39 changes: 38 additions & 1 deletion docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TorchTensor(
"""
Subclass of `torch.Tensor`, intended for use in a Document.
This enables (de)serialization from/to protobuf and json, data validation,
and coersion from compatible types like numpy.ndarray.
and coercion from compatible types like numpy.ndarray.

This type can also be used in a parametrized way,
specifying the shape of the tensor.
Expand Down Expand Up @@ -112,6 +112,43 @@ class MyDoc(BaseDoc):
```

---

## Compatibility with `torch.compile()`

PyTorch 2 [introduced compilation support](https://pytorch.org/blog/pytorch-2.0-release/) in the form of `torch.compile()`.

Currently, **`torch.compile()` does not properly support subclasses of `torch.Tensor` such as `TorchTensor`**.
The PyTorch team is currently working on a [fix for this issue](https://github.com/pytorch/pytorch/pull/105167#issuecomment-1678050808).

In the meantime, you can use the following workaround:

### Workaround: Convert `TorchTensor` to `torch.Tensor` before calling `torch.compile()`

Converting any `TorchTensor`s tor `torch.Tensor` before calling `torch.compile()` side-steps the issue:

```python
from docarray import BaseDoc
from docarray.typing import TorchTensor
import torch


class MyDoc(BaseDoc):
tensor: TorchTensor


doc = MyDoc(tensor=torch.zeros(128))


def foo(tensor: torch.Tensor):
return tensor @ tensor.t()


foo_compiled = torch.compile(foo)

# unwrap the tensor before passing it to torch.compile()
foo_compiled(doc.tensor.unwrap())
```

"""

__parametrized_meta__ = metaTorchAndNode
Expand Down
9 changes: 9 additions & 0 deletions docs/data_types/tensor/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,12 @@ assert isinstance(docs.tensor, NdArray)

- you don't specify the `tensor_type` parameter
- your tensor field is a Union of tensor or [`AnyTensor`][docarray.typing.tensor.AnyTensor]

## Compatibility of `TorchTensor` and `torch.compile()`

PyTorch 2 [introduced compilation support](https://pytorch.org/blog/pytorch-2.0-release/) in the form of `torch.compile()`.

Currently, **`torch.compile()` does not properly support subclasses of `torch.Tensor` such as [`TorchTensor`][docarray.typing.tensor.TorchTensor]**.
The PyTorch team is currently working on a [fix for this issue](https://github.com/pytorch/pytorch/pull/105167#issuecomment-1678050808).

For a workaround to this issue, see [`TorchTensor`][docarray.typing.tensor.TorchTensor#Compatibility-with-torch-compile]