Skip to content

Commit e0e5cd8

Browse files
feat: validate file formats in url (#1606) (#1669)
Signed-off-by: Mohammad Kalim Akram <[email protected]> Signed-off-by: jupyterjazz <[email protected]> Co-authored-by: Mohammad Kalim Akram <[email protected]>
1 parent 3fc6ecb commit e0e5cd8

File tree

19 files changed

+450
-17
lines changed

19 files changed

+450
-17
lines changed

docarray/documents/text.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class TextDoc(BaseDoc):
2424
from docarray.documents import TextDoc
2525
2626
# use it directly
27-
txt_doc = TextDoc(url='http://www.jina.ai/')
27+
txt_doc = TextDoc(url='https://www.gutenberg.org/files/1065/1065-0.txt')
2828
txt_doc.text = txt_doc.url.load()
2929
# model = MyEmbeddingModel()
3030
# txt_doc.embedding = model(txt_doc.text)
@@ -51,7 +51,7 @@ class MyText(TextDoc):
5151
second_embedding: Optional[AnyEmbedding]
5252
5353
54-
txt_doc = MyText(url='http://www.jina.ai/')
54+
txt_doc = MyText(url='https://www.gutenberg.org/files/1065/1065-0.txt')
5555
txt_doc.text = txt_doc.url.load()
5656
# model = MyEmbeddingModel()
5757
# txt_doc.embedding = model(txt_doc.text)
@@ -93,8 +93,8 @@ class MultiModalDoc(BaseDoc):
9393
```python
9494
from docarray.documents import TextDoc
9595
96-
doc = TextDoc(text='This is the main text', url='exampleurl.com')
97-
doc2 = TextDoc(text='This is the main text', url='exampleurl.com')
96+
doc = TextDoc(text='This is the main text', url='exampleurl.com/file')
97+
doc2 = TextDoc(text='This is the main text', url='exampleurl.com/file')
9898
9999
doc == 'This is the main text' # True
100100
doc == doc2 # True

docarray/typing/url/any_url.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import mimetypes
12
import os
23
import urllib
34
import urllib.parse
45
import urllib.request
5-
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
6+
from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union
67

78
import numpy as np
89
from pydantic import AnyUrl as BaseAnyUrl
@@ -20,13 +21,26 @@
2021

2122
T = TypeVar('T', bound='AnyUrl')
2223

24+
mimetypes.init([])
25+
2326

2427
@_register_proto(proto_type_name='any_url')
2528
class AnyUrl(BaseAnyUrl, AbstractType):
2629
host_required = (
2730
False # turn off host requirement to allow passing of local paths as URL
2831
)
2932

33+
@classmethod
34+
def mime_type(cls) -> str:
35+
"""Returns the mime type associated with the class."""
36+
raise NotImplementedError
37+
38+
@classmethod
39+
def extra_extensions(cls) -> List[str]:
40+
"""Returns a list of allowed file extensions for the class
41+
that are not covered by the mimetypes library."""
42+
raise NotImplementedError
43+
3044
def _to_node_protobuf(self) -> 'NodeProto':
3145
"""Convert Document into a NodeProto protobuf message. This function should
3246
be called when the Document is nested into another Document that need to
@@ -38,6 +52,48 @@ def _to_node_protobuf(self) -> 'NodeProto':
3852

3953
return NodeProto(text=str(self), type=self._proto_type_name)
4054

55+
@staticmethod
56+
def _get_url_extension(url: str) -> str:
57+
"""
58+
Extracts and returns the file extension from a given URL.
59+
If no file extension is present, the function returns an empty string.
60+
61+
62+
:param url: The URL to extract the file extension from.
63+
:return: The file extension without the period, if one exists,
64+
otherwise an empty string.
65+
"""
66+
67+
parsed_url = urllib.parse.urlparse(url)
68+
ext = os.path.splitext(parsed_url.path)[1]
69+
ext = ext[1:] if ext.startswith('.') else ext
70+
return ext
71+
72+
@classmethod
73+
def is_extension_allowed(cls, value: Any) -> bool:
74+
"""
75+
Check if the file extension of the URL is allowed for this class.
76+
First, it guesses the mime type of the file. If it fails to detect the
77+
mime type, it then checks the extra file extensions.
78+
Note: This method assumes that any URL without an extension is valid.
79+
80+
:param value: The URL or file path.
81+
:return: True if the extension is allowed, False otherwise
82+
"""
83+
if cls is AnyUrl:
84+
return True
85+
86+
url_parts = value.split('?')
87+
extension = cls._get_url_extension(value)
88+
if not extension:
89+
return True
90+
91+
mimetype, _ = mimetypes.guess_type(url_parts[0])
92+
if mimetype and mimetype.startswith(cls.mime_type()):
93+
return True
94+
95+
return extension in cls.extra_extensions()
96+
4197
@classmethod
4298
def validate(
4399
cls: Type[T],
@@ -61,10 +117,12 @@ def validate(
61117

62118
url = super().validate(abs_path, field, config) # basic url validation
63119

64-
if input_is_relative_path:
65-
return cls(str(value), scheme=None)
66-
else:
67-
return cls(str(url), scheme=None)
120+
if not cls.is_extension_allowed(value):
121+
raise ValueError(
122+
f"The file '{value}' is not in a valid format for class '{cls.__name__}'."
123+
)
124+
125+
return cls(str(value if input_is_relative_path else url), scheme=None)
68126

69127
@classmethod
70128
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':

docarray/typing/url/audio_url.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import warnings
2-
from typing import Optional, Tuple, TypeVar
2+
from typing import List, Optional, Tuple, TypeVar
33

44
from docarray.typing import AudioNdArray
55
from docarray.typing.bytes.audio_bytes import AudioBytes
66
from docarray.typing.proto_register import _register_proto
77
from docarray.typing.url.any_url import AnyUrl
8+
from docarray.typing.url.mimetypes import AUDIO_MIMETYPE
89
from docarray.utils._internal.misc import is_notebook
910

1011
T = TypeVar('T', bound='AudioUrl')
@@ -17,6 +18,18 @@ class AudioUrl(AnyUrl):
1718
Can be remote (web) URL, or a local file path.
1819
"""
1920

21+
@classmethod
22+
def mime_type(cls) -> str:
23+
return AUDIO_MIMETYPE
24+
25+
@classmethod
26+
def extra_extensions(cls) -> List[str]:
27+
"""
28+
Returns a list of additional file extensions that are valid for this class
29+
but cannot be identified by the mimetypes library.
30+
"""
31+
return []
32+
2033
def load(self: T) -> Tuple[AudioNdArray, int]:
2134
"""
2235
Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray]

docarray/typing/url/image_url.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import warnings
2-
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
2+
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar
33

44
from docarray.typing import ImageBytes
55
from docarray.typing.proto_register import _register_proto
66
from docarray.typing.tensor.image import ImageNdArray
77
from docarray.typing.url.any_url import AnyUrl
8+
from docarray.typing.url.mimetypes import IMAGE_MIMETYPE
89
from docarray.utils._internal.misc import is_notebook
910

1011
if TYPE_CHECKING:
@@ -20,6 +21,18 @@ class ImageUrl(AnyUrl):
2021
Can be remote (web) URL, or a local file path.
2122
"""
2223

24+
@classmethod
25+
def mime_type(cls) -> str:
26+
return IMAGE_MIMETYPE
27+
28+
@classmethod
29+
def extra_extensions(cls) -> List[str]:
30+
"""
31+
Returns a list of additional file extensions that are valid for this class
32+
but cannot be identified by the mimetypes library.
33+
"""
34+
return []
35+
2336
def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
2437
"""
2538
Load the image from the bytes into a `PIL.Image.Image` instance

docarray/typing/url/mimetypes.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
TEXT_MIMETYPE = 'text'
2+
AUDIO_MIMETYPE = 'audio'
3+
IMAGE_MIMETYPE = 'image'
4+
OBJ_MIMETYPE = 'application/x-tgif'
5+
VIDEO_MIMETYPE = 'video'
6+
7+
MESH_EXTRA_EXTENSIONS = [
8+
'3ds',
9+
'3mf',
10+
'ac',
11+
'ac3d',
12+
'amf',
13+
'assimp',
14+
'bvh',
15+
'cob',
16+
'collada',
17+
'ctm',
18+
'dxf',
19+
'e57',
20+
'fbx',
21+
'gltf',
22+
'glb',
23+
'ifc',
24+
'lwo',
25+
'lws',
26+
'lxo',
27+
'md2',
28+
'md3',
29+
'md5',
30+
'mdc',
31+
'm3d',
32+
'mdl',
33+
'ms3d',
34+
'nff',
35+
'obj',
36+
'off',
37+
'pcd',
38+
'pod',
39+
'pmd',
40+
'pmx',
41+
'ply',
42+
'q3o',
43+
'q3s',
44+
'raw',
45+
'sib',
46+
'smd',
47+
'stl',
48+
'ter',
49+
'terragen',
50+
'vtk',
51+
'vrml',
52+
'x3d',
53+
'xaml',
54+
'xgl',
55+
'xml',
56+
'xyz',
57+
'zgl',
58+
'vta',
59+
]
60+
61+
TEXT_EXTRA_EXTENSIONS = ['md', 'log']
62+
63+
POINT_CLOUD_EXTRA_EXTENSIONS = [
64+
'ascii',
65+
'bin',
66+
'b3dm',
67+
'bpf',
68+
'dp',
69+
'dxf',
70+
'e57',
71+
'fls',
72+
'fls',
73+
'glb',
74+
'ply',
75+
'gpf',
76+
'las',
77+
'obj',
78+
'osgb',
79+
'pcap',
80+
'pcd',
81+
'pdal',
82+
'pfm',
83+
'ply',
84+
'ply2',
85+
'pod',
86+
'pods',
87+
'pnts',
88+
'ptg',
89+
'ptx',
90+
'pts',
91+
'rcp',
92+
'xyz',
93+
'zfs',
94+
]

docarray/typing/url/text_url.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import Optional, TypeVar
1+
from typing import List, Optional, TypeVar
22

33
from docarray.typing.proto_register import _register_proto
44
from docarray.typing.url.any_url import AnyUrl
5+
from docarray.typing.url.mimetypes import TEXT_EXTRA_EXTENSIONS, TEXT_MIMETYPE
56

67
T = TypeVar('T', bound='TextUrl')
78

@@ -13,6 +14,18 @@ class TextUrl(AnyUrl):
1314
Can be remote (web) URL, or a local file path.
1415
"""
1516

17+
@classmethod
18+
def mime_type(cls) -> str:
19+
return TEXT_MIMETYPE
20+
21+
@classmethod
22+
def extra_extensions(cls) -> List[str]:
23+
"""
24+
Returns a list of additional file extensions that are valid for this class
25+
but cannot be identified by the mimetypes library.
26+
"""
27+
return TEXT_EXTRA_EXTENSIONS
28+
1629
def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
1730
"""
1831
Load the text file into a string.

docarray/typing/url/url_3d/mesh_url.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
22

33
import numpy as np
44
from pydantic import parse_obj_as
55

66
from docarray.typing.proto_register import _register_proto
77
from docarray.typing.tensor.ndarray import NdArray
8+
from docarray.typing.url.mimetypes import MESH_EXTRA_EXTENSIONS
89
from docarray.typing.url.url_3d.url_3d import Url3D
910

1011
if TYPE_CHECKING:
@@ -20,6 +21,14 @@ class Mesh3DUrl(Url3D):
2021
Can be remote (web) URL, or a local file path.
2122
"""
2223

24+
@classmethod
25+
def extra_extensions(cls) -> List[str]:
26+
"""
27+
Returns a list of additional file extensions that are valid for this class
28+
but cannot be identified by the mimetypes library.
29+
"""
30+
return MESH_EXTRA_EXTENSIONS
31+
2332
def load(
2433
self: T,
2534
skip_materials: bool = True,

docarray/typing/url/url_3d/point_cloud_url.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar
22

33
import numpy as np
44
from pydantic import parse_obj_as
55

66
from docarray.typing.proto_register import _register_proto
77
from docarray.typing.tensor.ndarray import NdArray
8+
from docarray.typing.url.mimetypes import POINT_CLOUD_EXTRA_EXTENSIONS
89
from docarray.typing.url.url_3d.url_3d import Url3D
910

1011
if TYPE_CHECKING:
@@ -21,6 +22,14 @@ class PointCloud3DUrl(Url3D):
2122
Can be remote (web) URL, or a local file path.
2223
"""
2324

25+
@classmethod
26+
def extra_extensions(cls) -> List[str]:
27+
"""
28+
Returns a list of additional file extensions that are valid for this class
29+
but cannot be identified by the mimetypes library.
30+
"""
31+
return POINT_CLOUD_EXTRA_EXTENSIONS
32+
2433
def load(
2534
self: T,
2635
samples: int,

docarray/typing/url/url_3d/url_3d.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from docarray.typing.proto_register import _register_proto
55
from docarray.typing.url.any_url import AnyUrl
6+
from docarray.typing.url.mimetypes import OBJ_MIMETYPE
67
from docarray.utils._internal.misc import import_library
78

89
if TYPE_CHECKING:
@@ -18,6 +19,10 @@ class Url3D(AnyUrl, ABC):
1819
Can be remote (web) URL, or a local file path.
1920
"""
2021

22+
@classmethod
23+
def mime_type(cls) -> str:
24+
return OBJ_MIMETYPE
25+
2126
def _load_trimesh_instance(
2227
self: T,
2328
force: Optional[str] = None,

0 commit comments

Comments
 (0)