|
1 | | -from typing import TypeVar |
| 1 | +from typing import TYPE_CHECKING, Any, List, Tuple, Type, TypeVar, Union |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 |
|
|
7 | 7 |
|
8 | 8 | T = TypeVar('T', bound='VideoNdArray') |
9 | 9 |
|
| 10 | +if TYPE_CHECKING: |
| 11 | + from pydantic import BaseConfig |
| 12 | + from pydantic.fields import ModelField |
| 13 | + |
10 | 14 |
|
11 | 15 | class VideoNdArray(AbstractVideoTensor, NdArray): |
12 | 16 | """ |
13 | 17 | Subclass of NdArray, to represent a video tensor. |
14 | | -
|
15 | | - Additionally, this allows storing such a tensor as a .wav audio file. |
| 18 | + Adds video-specific features to the tensor. |
16 | 19 |
|
17 | 20 | EXAMPLE USAGE |
18 | 21 |
|
19 | | - .. code-block:: python |
20 | | -
|
21 | | - from typing import Optional |
22 | | - from pydantic import parse_obj_as |
23 | | - from docarray import Document |
24 | | - from docarray.typing import AudioNdArray, AudioUrl |
25 | | - import numpy as np |
26 | | -
|
27 | | -
|
28 | | - class MyAudioDoc(Document): |
29 | | - title: str |
30 | | - audio_tensor: Optional[AudioNdArray] |
31 | | - url: Optional[AudioUrl] |
32 | | -
|
33 | | -
|
34 | | - # from tensor |
35 | | - doc_1 = MyAudioDoc( |
36 | | - title='my_first_audio_doc', |
37 | | - audio_tensor=np.random.rand(1000, 2), |
38 | | - ) |
39 | | - doc_1.audio_tensor.save_to_wav_file(file_path='path/to/file_1.wav') |
40 | | - # from url |
41 | | - doc_2 = MyAudioDoc( |
42 | | - title='my_second_audio_doc', |
43 | | - url='https://github.com/docarray/docarray/tree/feat-add-audio-v2/tests/toydata/hello.wav', |
44 | | - ) |
45 | | - doc_2.audio_tensor = parse_obj_as(AudioNdArray, doc_2.url.load()) |
46 | | - doc_2.audio_tensor.save_to_wav_file(file_path='path/to/file_2.wav') |
47 | 22 | """ |
48 | 23 |
|
49 | 24 | _PROTO_FIELD_NAME = 'video_ndarray' |
50 | 25 |
|
51 | | - def check_shape(self) -> None: |
52 | | - if self.ndim != 4 or self.shape[-1] != 3 or self.dtype != np.uint8: |
| 26 | + @classmethod |
| 27 | + def validate( |
| 28 | + cls: Type[T], |
| 29 | + value: Union[T, np.ndarray, List[Any], Tuple[Any], Any], |
| 30 | + field: 'ModelField', |
| 31 | + config: 'BaseConfig', |
| 32 | + ) -> T: |
| 33 | + array = super().validate(value=value, field=field, config=config) |
| 34 | + if array.ndim not in [3, 4] or array.shape[-1] != 3: |
53 | 35 | raise ValueError( |
54 | | - f'expects `` with dtype=uint8 and ndim=4 and the last dimension is 3, ' |
55 | | - f'but receiving {self.shape} in {self.dtype}' |
| 36 | + f'Expects tensor with 3 or 4 dimensions and the last dimension equal' |
| 37 | + f' to 3, but received {array.shape} in {array.dtype}' |
56 | 38 | ) |
| 39 | + else: |
| 40 | + return array |
57 | 41 |
|
58 | 42 | def to_numpy(self) -> np.ndarray: |
59 | 43 | return self |
0 commit comments