Skip to content

Commit 658c247

Browse files
author
anna-charlotte
committed
feat: add video torch tensor and tests
Signed-off-by: anna-charlotte <[email protected]>
1 parent 1e9e631 commit 658c247

File tree

11 files changed

+323
-44
lines changed

11 files changed

+323
-44
lines changed

docarray/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from docarray.array.array import DocumentArray
44
from docarray.document.document import BaseDocument
5-
from docarray.predefined_document import Audio, Image, Mesh3D, PointCloud3D, Text
5+
from docarray.predefined_document import Audio, Image, Mesh3D, PointCloud3D, Text, Video
66

77
__all__ = [
88
'BaseDocument',
@@ -12,4 +12,5 @@
1212
'Text',
1313
'Mesh3D',
1414
'PointCloud3D',
15+
'Video',
1516
]

docarray/predefined_document/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
from docarray.predefined_document.mesh import Mesh3D
44
from docarray.predefined_document.point_cloud import PointCloud3D
55
from docarray.predefined_document.text import Text
6+
from docarray.predefined_document.video import Video
67

7-
__all__ = ['Text', 'Image', 'Audio', 'Mesh3D', 'PointCloud3D']
8+
__all__ = ['Text', 'Image', 'Audio', 'Mesh3D', 'PointCloud3D', 'Video']
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Optional, TypeVar
2+
3+
from docarray.document import BaseDocument
4+
from docarray.typing import AnyTensor, Embedding
5+
from docarray.typing.tensor.video.video_tensor import VideoTensor
6+
from docarray.typing.url.video_url import VideoUrl
7+
8+
T = TypeVar('T', bound='Video')
9+
10+
11+
class Video(BaseDocument):
12+
"""
13+
Document for handling video.
14+
The Video Document can contain a VideoUrl (`Video.url`), a VideoTensor
15+
(`Video.tensor`), an AnyTensor (`Video.key_frame_indices`), and an Embedding
16+
(`Video.embedding`).
17+
18+
EXAMPLE USAGE:
19+
20+
You can use this Document directly:
21+
22+
You can extend this Document:
23+
24+
You can use this Document for composition:
25+
26+
"""
27+
28+
url: Optional[VideoUrl]
29+
tensor: Optional[VideoTensor]
30+
key_frame_indices: Optional[AnyTensor]
31+
embedding: Optional[Embedding]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from docarray.typing.tensor.video.video_ndarray import VideoNdArray
2+
3+
__all__ = ['VideoNdArray']
4+
5+
try:
6+
import torch # noqa: F401
7+
except ImportError:
8+
pass
9+
else:
10+
from docarray.typing.tensor.video.video_torch_tensor import VideoTorchTensor # noqa
11+
12+
__all__.extend(['VideoTorchTensor'])
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from abc import ABC, abstractmethod
2+
from typing import BinaryIO, Dict, Generator, Optional, Tuple, Type, TypeVar, Union
3+
4+
import numpy as np
5+
6+
from docarray.typing.tensor.abstract_tensor import AbstractTensor
7+
8+
T = TypeVar('T', bound='AbstractVideoTensor')
9+
10+
11+
class AbstractVideoTensor(AbstractTensor, ABC):
12+
@abstractmethod
13+
def to_numpy(self) -> np.ndarray:
14+
"""
15+
Convert video tensor to numpy.ndarray.
16+
"""
17+
...
18+
19+
def save_to_file(
20+
self: 'T',
21+
file_path: Union[str, BinaryIO],
22+
frame_rate: int = 30,
23+
codec: str = 'h264',
24+
) -> None:
25+
"""
26+
Save video tensor to a video file (e.g. an .mp4 file).
27+
28+
29+
:param file_path: path to a video file. If file_path is a string, open the file by
30+
that name, otherwise treat it as a file-like object.
31+
:param frame_rate: frames per second.
32+
:param codec: the name of a decoder/encoder.
33+
"""
34+
np_tensor = self.to_numpy()
35+
36+
video_tensor = np.moveaxis(np.clip(np_tensor, 0, 255), 1, 2).astype('uint8')
37+
38+
import av
39+
40+
with av.open(file_path, mode='w') as container:
41+
stream = container.add_stream(codec, rate=frame_rate)
42+
stream.width = np_tensor.shape[1]
43+
stream.height = np_tensor.shape[2]
44+
stream.pix_fmt = 'yuv420p'
45+
46+
for b in video_tensor:
47+
frame = av.VideoFrame.from_ndarray(b, format='rgb24')
48+
for packet in stream.encode(frame):
49+
container.mux(packet)
50+
51+
for packet in stream.encode():
52+
container.mux(packet)
53+
54+
@classmethod
55+
def generator_from_webcam(
56+
cls: Type['T'],
57+
height_width: Optional[Tuple[int, int]] = None,
58+
show_window: bool = True,
59+
window_title: str = 'webcam',
60+
fps: int = 30,
61+
exit_key: int = 27,
62+
exit_event=None,
63+
tags: Optional[Dict] = None,
64+
) -> Generator['T', None, None]:
65+
...
Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TypeVar
1+
from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union
22

33
import numpy as np
44

@@ -7,53 +7,37 @@
77

88
T = TypeVar('T', bound='VideoNdArray')
99

10+
if TYPE_CHECKING:
11+
from pydantic import BaseConfig
12+
from pydantic.fields import ModelField
13+
1014

1115
class VideoNdArray(AbstractVideoTensor, NdArray):
1216
"""
1317
Subclass of NdArray, to represent a video tensor.
14-
15-
Additionally, this allows storing such a tensor as a .wav audio file.
18+
Adds video-specific features to the tensor.
1619
1720
EXAMPLE USAGE
1821
19-
.. code-block:: python
20-
21-
from typing import Optional
22-
from pydantic import parse_obj_as
23-
from docarray import Document
24-
from docarray.typing import AudioNdArray, AudioUrl
25-
import numpy as np
26-
27-
28-
class MyAudioDoc(Document):
29-
title: str
30-
audio_tensor: Optional[AudioNdArray]
31-
url: Optional[AudioUrl]
32-
33-
34-
# from tensor
35-
doc_1 = MyAudioDoc(
36-
title='my_first_audio_doc',
37-
audio_tensor=np.random.rand(1000, 2),
38-
)
39-
doc_1.audio_tensor.save_to_wav_file(file_path='path/to/file_1.wav')
40-
# from url
41-
doc_2 = MyAudioDoc(
42-
title='my_second_audio_doc',
43-
url='https://github.com/docarray/docarray/tree/feat-add-audio-v2/tests/toydata/hello.wav',
44-
)
45-
doc_2.audio_tensor = parse_obj_as(AudioNdArray, doc_2.url.load())
46-
doc_2.audio_tensor.save_to_wav_file(file_path='path/to/file_2.wav')
4722
"""
4823

4924
_PROTO_FIELD_NAME = 'video_ndarray'
5025

51-
def check_shape(self) -> None:
52-
if self.ndim != 4 or self.shape[-1] != 3 or self.dtype != np.uint8:
26+
@classmethod
27+
def validate(
28+
cls: Type[T],
29+
value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
30+
field: 'ModelField',
31+
config: 'BaseConfig',
32+
) -> T:
33+
array = super().validate(value=value, field=field, config=config)
34+
if array.ndim not in [3, 4] or array.shape[-1] != 3:
5335
raise ValueError(
54-
f'expects `` with dtype=uint8 and ndim=4 and the last dimension is 3, '
55-
f'but receiving {self.shape} in {self.dtype}'
36+
f'Expects tensor with 3 or 4 dimensions and the last dimension equal'
37+
f' to 3, but received {array.shape} in {array.dtype}'
5638
)
39+
else:
40+
return array
5741

5842
def to_numpy(self) -> np.ndarray:
5943
return self
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Union
2+
3+
from docarray.typing.tensor.video.video_ndarray import VideoNdArray
4+
5+
try:
6+
import torch # noqa: F401
7+
except ImportError:
8+
VideoTensor = VideoNdArray
9+
10+
else:
11+
from docarray.typing.tensor.video.video_torch_tensor import VideoTorchTensor
12+
13+
VideoTensor = Union[VideoNdArray, VideoTorchTensor] # type: ignore
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union
2+
3+
import numpy as np
4+
5+
from docarray.typing.tensor.torch_tensor import TorchTensor, metaTorchAndNode
6+
from docarray.typing.tensor.video.abstract_video_tensor import AbstractVideoTensor
7+
8+
T = TypeVar('T', bound='VideoTorchTensor')
9+
10+
if TYPE_CHECKING:
11+
from pydantic import BaseConfig
12+
from pydantic.fields import ModelField
13+
14+
15+
class VideoTorchTensor(AbstractVideoTensor, TorchTensor, metaclass=metaTorchAndNode):
16+
"""
17+
Subclass of TorchTensor, to represent a video tensor.
18+
Adds video-specific features to the tensor.
19+
20+
EXAMPLE USAGE
21+
22+
"""
23+
24+
_PROTO_FIELD_NAME = 'video_torch_tensor'
25+
26+
@classmethod
27+
def validate(
28+
cls: Type[T],
29+
value: Union[T, np.ndarray, List[Any], Tuple[Any], Any],
30+
field: 'ModelField',
31+
config: 'BaseConfig',
32+
) -> T:
33+
tensor = super().validate(value=value, field=field, config=config)
34+
if tensor.ndim not in [3, 4] or tensor.shape[-1] != 3:
35+
raise ValueError(
36+
f'Expects tensor with 3 or 4 dimensions and the last dimension equal '
37+
f'to 3, but received {tensor.shape} in {tensor.dtype}'
38+
)
39+
else:
40+
return tensor
41+
42+
def to_numpy(self) -> np.ndarray:
43+
return self.cpu().detach().numpy()

docarray/typing/url/video_url.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,17 @@ def validate(
5050
return cls(str(url), scheme=None)
5151

5252
def load(
53-
self: T, only_keyframes: bool = False, **kwargs
54-
) -> Union[VideoNdArray, Tuple[VideoNdArray, VideoNdArray]]:
53+
self: T, only_keyframes: bool = False, dtype: str = 'int32', **kwargs
54+
) -> Union[VideoNdArray, Tuple[VideoNdArray, np.ndarray]]:
5555
"""
56-
Load the data from the url into a numpy.ndarray.
56+
Load the data from the url into a VideoNdArray or Tuple of VideoNdArray and
57+
np.ndarray.
5758
5859
5960
6061
:param only_keyframes: if True keep only the keyframes, if False keep all frames
6162
and additionally return the indices of the keyframes as a second array
63+
:param dtype: Data-type of the returned array; default: int32.
6264
:param kwargs: supports all keyword arguments that are being supported by
6365
av.open() as described in:
6466
https://pyav.org/docs/stable/api/_globals.html?highlight=open#av.open
@@ -86,7 +88,4 @@ def load(
8688
if only_keyframes:
8789
return frames
8890
else:
89-
indices = parse_obj_as(
90-
VideoNdArray, np.ndarray(keyframe_indices, dtype=np.int32)
91-
)
92-
return frames, indices
91+
return frames, np.ndarray(keyframe_indices, dtype=dtype)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import os
2+
3+
import numpy as np
4+
import pytest
5+
6+
from docarray import Video
7+
from docarray.typing import VideoNdArray
8+
from tests import TOYDATA_DIR
9+
10+
LOCAL_VIDEO_FILE = str(TOYDATA_DIR / 'mov_bbb.mp4')
11+
REMOTE_VIDEO_FILE = 'https://github.com/docarray/docarray/blob/feat-rewrite-v2/tests/toydata/mov_bbb.mp4?raw=true' # noqa: E501
12+
13+
14+
@pytest.mark.slow
15+
@pytest.mark.internet
16+
@pytest.mark.parametrize('file_url', [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE])
17+
def test_video(file_url):
18+
video = Video(url=file_url)
19+
video.tensor, video.key_frame_indices = video.url.load()
20+
21+
assert isinstance(video.tensor, np.ndarray)
22+
assert isinstance(video.tensor, VideoNdArray)
23+
assert isinstance(video.key_frame_indices, np.ndarray)
24+
25+
26+
@pytest.mark.slow
27+
@pytest.mark.internet
28+
@pytest.mark.parametrize('file_url', [LOCAL_VIDEO_FILE, REMOTE_VIDEO_FILE])
29+
def test_save_video_ndarray(file_url, tmpdir):
30+
tmp_file = str(tmpdir / 'tmp.mp4')
31+
32+
video = Video(url=file_url)
33+
video.tensor, _ = video.url.load()
34+
35+
assert isinstance(video.tensor, np.ndarray)
36+
assert isinstance(video.tensor, VideoNdArray)
37+
38+
video.tensor.save_to_file(tmp_file)
39+
assert os.path.isfile(tmp_file)
40+
41+
video_from_file = Video(url=tmp_file)
42+
video_from_file.tensor = video_from_file.url.load()
43+
assert np.allclose(video.tensor, video_from_file.tensor)

0 commit comments

Comments
 (0)