Skip to content

Commit 49fd592

Browse files
authored
feat: validate file formats in url (#1606)
Signed-off-by: Mohammad Kalim Akram <[email protected]>
1 parent 3fc6ecb commit 49fd592

File tree

16 files changed

+288
-15
lines changed

16 files changed

+288
-15
lines changed

docarray/typing/url/any_url.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import mimetypes
12
import os
23
import urllib
34
import urllib.parse
45
import urllib.request
5-
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
6+
from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union
67

78
import numpy as np
89
from pydantic import AnyUrl as BaseAnyUrl
@@ -27,6 +28,17 @@ class AnyUrl(BaseAnyUrl, AbstractType):
2728
False # turn off host requirement to allow passing of local paths as URL
2829
)
2930

31+
@classmethod
32+
def mime_type(cls) -> str:
33+
"""Returns the mime type this class deals with."""
34+
raise NotImplementedError
35+
36+
@classmethod
37+
def extra_extensions(cls) -> List[str]:
38+
"""Returns a list of allowed file extensions for this class which
39+
falls outside the scope of mimetypes library."""
40+
raise NotImplementedError
41+
3042
def _to_node_protobuf(self) -> 'NodeProto':
3143
"""Convert Document into a NodeProto protobuf message. This function should
3244
be called when the Document is nested into another Document that need to
@@ -38,6 +50,37 @@ def _to_node_protobuf(self) -> 'NodeProto':
3850

3951
return NodeProto(text=str(self), type=self._proto_type_name)
4052

53+
@classmethod
54+
def is_extension_allowed(cls, value: Any) -> bool:
55+
"""
56+
Check if the file extension of the url is allowed for that class.
57+
First read the mime type of the file, if it fails, then check the file extension.
58+
59+
:param value: url to the file
60+
:return: True if the extension is allowed, False otherwise
61+
"""
62+
if cls == AnyUrl: # no check for AnyUrl class
63+
return True
64+
mimetype, _ = mimetypes.guess_type(value.split("?")[0])
65+
if mimetype:
66+
return mimetype.startswith(cls.mime_type())
67+
else:
68+
# check if the extension is among the extra extensions of that class
69+
return any(
70+
value.endswith(ext) or value.split("?")[0].endswith(ext)
71+
for ext in cls.extra_extensions()
72+
)
73+
74+
@classmethod
75+
def is_special_case(cls, value: Any) -> bool:
76+
"""
77+
Check if the url is a special case.
78+
79+
:param value: url to the file
80+
:return: True if the url is a special case, False otherwise
81+
"""
82+
return False
83+
4184
@classmethod
4285
def validate(
4386
cls: Type[T],
@@ -61,10 +104,14 @@ def validate(
61104

62105
url = super().validate(abs_path, field, config) # basic url validation
63106

64-
if input_is_relative_path:
65-
return cls(str(value), scheme=None)
66-
else:
67-
return cls(str(url), scheme=None)
107+
# perform check only for subclasses of AnyUrl
108+
if not cls.is_extension_allowed(value):
109+
if not cls.is_special_case(value): # check for special cases
110+
raise ValueError(
111+
f'file {value} is not a valid file format for class {cls}'
112+
)
113+
114+
return cls(str(value if input_is_relative_path else url), scheme=None)
68115

69116
@classmethod
70117
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':

docarray/typing/url/audio_url.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, Tuple, TypeVar
2+
from typing import List, Optional, Tuple, TypeVar
33

44
from docarray.typing import AudioNdArray
55
from docarray.typing.bytes.audio_bytes import AudioBytes
@@ -17,6 +17,15 @@ class AudioUrl(AnyUrl):
1717
Can be remote (web) URL, or a local file path.
1818
"""
1919

20+
@classmethod
21+
def mime_type(cls) -> str:
22+
return 'audio'
23+
24+
@classmethod
25+
def extra_extensions(cls) -> List[str]:
26+
# add only those extensions that can not be identified by the mimetypes library but are valid
27+
return []
28+
2029
def load(self: T) -> Tuple[AudioNdArray, int]:
2130
"""
2231
Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray]

docarray/typing/url/image_url.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
2+
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar
33

44
from docarray.typing import ImageBytes
55
from docarray.typing.proto_register import _register_proto
@@ -20,6 +20,15 @@ class ImageUrl(AnyUrl):
2020
Can be remote (web) URL, or a local file path.
2121
"""
2222

23+
@classmethod
24+
def mime_type(cls) -> str:
25+
return 'image'
26+
27+
@classmethod
28+
def extra_extensions(cls) -> List[str]:
29+
# add only those extensions that can not be identified by the mimetypes library but are valid
30+
return []
31+
2332
def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
2433
"""
2534
Load the image from the bytes into a `PIL.Image.Image` instance

docarray/typing/url/text_url.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, TypeVar
1+
from typing import List, Optional, TypeVar
22

33
from docarray.typing.proto_register import _register_proto
44
from docarray.typing.url.any_url import AnyUrl
@@ -13,6 +13,32 @@ class TextUrl(AnyUrl):
1313
Can be remote (web) URL, or a local file path.
1414
"""
1515

16+
@classmethod
17+
def mime_type(cls) -> str:
18+
return 'text'
19+
20+
@classmethod
21+
def extra_extensions(cls) -> List[str]:
22+
"""
23+
List of extra file extensions for this type of URL (outside the scope of mimetype library).
24+
"""
25+
return ['.md']
26+
27+
@classmethod
28+
def is_special_case(cls, value: 'AnyUrl') -> bool:
29+
"""
30+
Check if the url is a special case that needs to be handled differently.
31+
32+
:param value: url to the file
33+
:return: True if the url is a special case, False otherwise
34+
"""
35+
if value.startswith('http') or value.startswith('https'):
36+
if len(value.split('/')[-1].split('.')) == 1:
37+
# This handles the case where the value is a URL without a file extension
38+
# for e.g. https://de.wikipedia.org/wiki/Brixen
39+
return True
40+
return False
41+
1642
def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
1743
"""
1844
Load the text file into a string.

docarray/typing/url/url_3d/mesh_url.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
22

33
import numpy as np
44
from pydantic import parse_obj_as
@@ -20,6 +20,19 @@ class Mesh3DUrl(Url3D):
2020
Can be remote (web) URL, or a local file path.
2121
"""
2222

23+
@classmethod
24+
def extra_extensions(cls) -> List[str]:
25+
# return list of allowed extensions to be used for mesh if mimetypes fail to detect
26+
# generated with the help of chatGPT and definitely this list is not exhaustive
27+
# bit hacky because of black formatting, making it a long vertical list
28+
list_a = ['3ds', '3mf', 'ac', 'ac3d', 'amf', 'assimp', 'bvh', 'cob', 'collada']
29+
list_b = ['ctm', 'dxf', 'e57', 'fbx', 'gltf', 'glb', 'ifc', 'lwo', 'lws', 'lxo']
30+
list_c = ['md2', 'md3', 'md5', 'mdc', 'm3d', 'mdl', 'ms3d', 'nff', 'obj', 'off']
31+
list_d = ['pcd', 'pod', 'pmd', 'pmx', 'ply', 'q3o', 'q3s', 'raw', 'sib', 'smd']
32+
list_e = ['stl', 'ter' 'terragen', 'vtk', 'vrml', 'x3d', 'xaml', 'xgl', 'xml']
33+
list_f = ['xyz', 'zgl', 'vta']
34+
return list_a + list_b + list_c + list_d + list_e + list_f
35+
2336
def load(
2437
self: T,
2538
skip_materials: bool = True,

docarray/typing/url/url_3d/point_cloud_url.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
22

33
import numpy as np
44
from pydantic import parse_obj_as
@@ -21,6 +21,17 @@ class PointCloud3DUrl(Url3D):
2121
Can be remote (web) URL, or a local file path.
2222
"""
2323

24+
@classmethod
25+
def extra_extensions(cls) -> List[str]:
26+
# return list of file format for point cloud if mimetypes fail to detect
27+
# generated with the help of chatGPT and definitely this list is not exhaustive
28+
# bit hacky because of black formatting, making it a long vertical list
29+
list_a = ['ascii', 'bin', 'b3dm', 'bpf', 'dp', 'dxf', 'e57', 'fls', 'fls']
30+
list_b = ['glb', 'ply', 'gpf', 'las', 'obj', 'osgb', 'pcap', 'pcd', 'pdal']
31+
list_c = ['pfm', 'ply', 'ply2', 'pod', 'pods', 'pnts', 'ptg', 'ptx', 'pts']
32+
list_d = ['rcp', 'xyz', 'zfs']
33+
return list_a + list_b + list_c + list_d
34+
2435
def load(
2536
self: T,
2637
samples: int,

docarray/typing/url/url_3d/url_3d.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class Url3D(AnyUrl, ABC):
1818
Can be remote (web) URL, or a local file path.
1919
"""
2020

21+
@classmethod
22+
def mime_type(cls) -> str:
23+
return 'application'
24+
2125
def _load_trimesh_instance(
2226
self: T,
2327
force: Optional[str] = None,

docarray/typing/url/video_url.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, TypeVar
2+
from typing import List, Optional, TypeVar
33

44
from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult
55
from docarray.typing.proto_register import _register_proto
@@ -16,6 +16,15 @@ class VideoUrl(AnyUrl):
1616
Can be remote (web) URL, or a local file path.
1717
"""
1818

19+
@classmethod
20+
def mime_type(cls) -> str:
21+
return 'video'
22+
23+
@classmethod
24+
def extra_extensions(cls) -> List[str]:
25+
# add only those extensions that can not be identified by the mimetypes library but are valid
26+
return []
27+
1928
def load(self: T, **kwargs) -> VideoLoadResult:
2029
"""
2130
Load the data from the url into a `NamedTuple` of

tests/index/weaviate/test_index_get_del_weaviate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ class MyMultiModalDoc(BaseDoc):
403403

404404

405405
def test_index_document_with_bytes(weaviate_client):
406-
doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo")
406+
doc = ImageDoc(id="1", url="www.foo.com/test.png", bytes_=b"foo")
407407

408408
index = WeaviateDocumentIndex[ImageDoc]()
409409
index.index([doc])

tests/integrations/predefined_document/test_audio.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
str(TOYDATA_DIR / 'hello.ogg'),
3030
str(TOYDATA_DIR / 'hello.wma'),
3131
str(TOYDATA_DIR / 'hello.aac'),
32-
str(TOYDATA_DIR / 'hello'),
3332
]
3433

3534
LOCAL_AUDIO_FILES_AND_FORMAT = [
@@ -40,7 +39,6 @@
4039
(str(TOYDATA_DIR / 'hello.ogg'), 'ogg'),
4140
(str(TOYDATA_DIR / 'hello.wma'), 'asf'),
4241
(str(TOYDATA_DIR / 'hello.aac'), 'adts'),
43-
(str(TOYDATA_DIR / 'hello'), 'wav'),
4442
]
4543

4644
NON_AUDIO_FILES = [

0 commit comments

Comments
 (0)