Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: validate file formats in url
Signed-off-by: Mohammad Kalim Akram <[email protected]>
  • Loading branch information
makram93 committed May 31, 2023
commit 5243e2364c0c8171a9ee479dd31dfa957904a2dd
38 changes: 37 additions & 1 deletion docarray/typing/url/any_url.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import mimetypes
import os
import urllib
import urllib.parse
import urllib.request
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union

import numpy as np
from pydantic import AnyUrl as BaseAnyUrl
Expand All @@ -27,6 +28,14 @@ class AnyUrl(BaseAnyUrl, AbstractType):
False # turn off host requirement to allow passing of local paths as URL
)

@classmethod
def mime_type(cls) -> str:
    """Return the top-level MIME type that URLs of this class must have.

    The base class has no media type of its own, so this placeholder always
    raises. Subclasses are expected to override it.

    :return: the MIME type prefix (never returned by this base implementation)
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return the file extensions accepted when no MIME type can be detected.

    The base class defines no extensions, so this placeholder always raises.
    Subclasses are expected to override it.

    :return: list of extensions (never returned by this base implementation)
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError

def _to_node_protobuf(self) -> 'NodeProto':
"""Convert Document into a NodeProto protobuf message. This function should
be called when the Document is nested into another Document that need to
Expand Down Expand Up @@ -61,6 +70,33 @@ def validate(

url = super().validate(abs_path, field, config) # basic url validation

# Use mimetypes to validate file formats
mimetype, encoding = mimetypes.guess_type(value.split("?")[0])
if not mimetype:
# try reading from the request headers if mimetypes failed - could be slow
try:
r = urllib.request.urlopen(value)
except Exception: # noqa
pass # should we raise an error/warning here, since url is not reachable(invalid)?
else:
mimetype = r.headers.get_content_maintype()

skip_check = False
if not mimetype: # not able to automatically detect mimetype
# check if the file extension is among one of the allowed extensions
if not any(
value.endswith(ext) or value.split("?")[0].endswith(ext)
for ext in cls.allowed_extensions()
):
raise ValueError(
f'file {value} is not a valid file format for class {cls}'
)
else:
skip_check = True # one of the allowed extensions, skip the check

if not skip_check and not mimetype.startswith(cls.mime_type()):
raise ValueError(f'file {value} is not a {cls.mime_type()} file format')

if input_is_relative_path:
return cls(str(value), scheme=None)
else:
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/audio_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar

from docarray.typing import AudioNdArray
from docarray.typing.bytes.audio_bytes import AudioBytes
Expand All @@ -17,6 +17,15 @@ class AudioUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
    """Return the top-level MIME type that audio URLs must have.

    :return: the string ``'audio'``
    """
    return 'audio'

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return extra audio file extensions accepted without a detected MIME type.

    Only extensions that are valid but cannot be identified by the
    ``mimetypes`` library belong here; currently there are none.

    :return: an empty list
    """
    extra_extensions: List[str] = []
    return extra_extensions

def load(self: T) -> Tuple[AudioNdArray, int]:
"""
Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray]
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/image_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar

from docarray.typing import ImageBytes
from docarray.typing.proto_register import _register_proto
Expand All @@ -20,6 +20,15 @@ class ImageUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
    """Return the top-level MIME type that image URLs must have.

    :return: the string ``'image'``
    """
    return 'image'

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return extra image file extensions accepted without a detected MIME type.

    Only extensions that are valid but cannot be identified by the
    ``mimetypes`` library belong here; currently there are none.

    :return: an empty list
    """
    extra_extensions: List[str] = []
    return extra_extensions

def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
"""
Load the image from the bytes into a `PIL.Image.Image` instance
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/text_url.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import List, Optional, TypeVar

from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
Expand All @@ -13,6 +13,15 @@ class TextUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
    """Return the top-level MIME type that text URLs must have.

    :return: the string ``'text'``
    """
    return 'text'

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return extra text file extensions accepted without a detected MIME type.

    Only extensions that are valid but cannot be identified by the
    ``mimetypes`` library belong here.

    :return: a list containing ``'.md'``
    """
    extra_extensions: List[str] = ['.md']
    return extra_extensions

def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
"""
Load the text file into a string.
Expand Down
15 changes: 14 additions & 1 deletion docarray/typing/url/url_3d/mesh_url.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar

import numpy as np
from pydantic import parse_obj_as
Expand All @@ -20,6 +20,19 @@ class Mesh3DUrl(Url3D):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return the mesh file extensions accepted when MIME detection fails.

    This is a hand-assembled fallback list and is not exhaustive. Entries are
    written without a leading dot.

    :return: list of mesh file extensions
    """
    # Kept as several short lists so that black does not reformat the
    # literal into one very long vertical list.
    list_a = ['3ds', '3mf', 'ac', 'ac3d', 'amf', 'assimp', 'bvh', 'cob', 'collada']
    list_b = ['ctm', 'dxf', 'e57', 'fbx', 'gltf', 'glb', 'ifc', 'lwo', 'lws', 'lxo']
    list_c = ['md2', 'md3', 'md5', 'mdc', 'm3d', 'mdl', 'ms3d', 'nff', 'obj', 'off']
    list_d = ['pcd', 'pod', 'pmd', 'pmx', 'ply', 'q3o', 'q3s', 'raw', 'sib', 'smd']
    # 'ter' and 'terragen' are two separate entries; without the comma the
    # adjacent literals would be concatenated into 'terterragen'.
    list_e = ['stl', 'ter', 'terragen', 'vtk', 'vrml', 'x3d', 'xaml', 'xgl', 'xml']
    list_f = ['xyz', 'zgl', 'vta']
    return list_a + list_b + list_c + list_d + list_e + list_f

def load(
self: T,
skip_materials: bool = True,
Expand Down
13 changes: 12 additions & 1 deletion docarray/typing/url/url_3d/point_cloud_url.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar

import numpy as np
from pydantic import parse_obj_as
Expand All @@ -21,6 +21,17 @@ class PointCloud3DUrl(Url3D):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return the point cloud file extensions accepted when MIME detection fails.

    This is a hand-assembled fallback list and is not exhaustive. Entries are
    written without a leading dot and each extension appears exactly once.

    :return: list of point cloud file extensions
    """
    # Kept as several short lists so that black does not reformat the
    # literal into one very long vertical list.
    list_a = ['ascii', 'bin', 'b3dm', 'bpf', 'dp', 'dxf', 'e57', 'fls']
    list_b = ['glb', 'ply', 'gpf', 'las', 'obj', 'osgb', 'pcap', 'pcd', 'pdal']
    list_c = ['pfm', 'ply2', 'pod', 'pods', 'pnts', 'ptg', 'ptx', 'pts']
    list_d = ['rcp', 'xyz', 'zfs']
    return list_a + list_b + list_c + list_d

def load(
self: T,
samples: int,
Expand Down
4 changes: 4 additions & 0 deletions docarray/typing/url/url_3d/url_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class Url3D(AnyUrl, ABC):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
    """Return the top-level MIME type that 3D URLs must have.

    :return: the string ``'application'``
    """
    return 'application'

def _load_trimesh_instance(
self: T,
force: Optional[str] = None,
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/video_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, TypeVar
from typing import List, Optional, TypeVar

from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult
from docarray.typing.proto_register import _register_proto
Expand All @@ -16,6 +16,15 @@ class VideoUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
    """Return the top-level MIME type that video URLs must have.

    :return: the string ``'video'``
    """
    return 'video'

@classmethod
def allowed_extensions(cls) -> List[str]:
    """Return extra video file extensions accepted without a detected MIME type.

    Only extensions that are valid but cannot be identified by the
    ``mimetypes`` library belong here; currently there are none.

    :return: an empty list
    """
    extra_extensions: List[str] = []
    return extra_extensions

def load(self: T, **kwargs) -> VideoLoadResult:
"""
Load the data from the url into a `NamedTuple` of
Expand Down
23 changes: 23 additions & 0 deletions tests/units/typing/url/test_audio_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -123,3 +124,25 @@ def test_load_bytes():
assert isinstance(audio_bytes, bytes)
assert isinstance(audio_bytes, AudioBytes)
assert len(audio_bytes) > 0


@pytest.mark.parametrize(
    'file_type, file_source',
    [
        ('audio', AUDIO_FILES[0]),
        ('audio', AUDIO_FILES[1]),
        ('audio', REMOTE_AUDIO_FILE),
        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
        # 'test' is a sub-directory of the toy data dir, so it must be its own
        # path component rather than being glued onto the file name.
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
    ],
)
def test_file_validation(file_type, file_source):
    """Check that AudioUrl accepts audio sources and rejects every other type.

    :param file_type: top-level MIME type the source is expected to have
    :param file_source: local path or remote URL to validate
    """
    if file_type != AudioUrl.mime_type():
        # Any non-audio source must be refused during validation.
        with pytest.raises(ValueError):
            parse_obj_as(AudioUrl, file_source)
    else:
        url = parse_obj_as(AudioUrl, file_source)
        assert isinstance(url, AudioUrl)
25 changes: 25 additions & 0 deletions tests/units/typing/url/test_image_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from docarray.base_doc.io.json import orjson_dumps
from docarray.typing import ImageUrl
from tests import TOYDATA_DIR

CUR_DIR = os.path.dirname(os.path.abspath(__file__))
PATH_TO_IMAGE_DATA = os.path.join(CUR_DIR, '..', '..', '..', 'toydata', 'image-data')
Expand Down Expand Up @@ -174,3 +175,27 @@ def test_validation(path_to_img):
url = parse_obj_as(ImageUrl, path_to_img)
assert isinstance(url, ImageUrl)
assert isinstance(url, str)


@pytest.mark.parametrize(
    'file_type, file_source',
    [
        ('image', IMAGE_PATHS['png']),
        ('image', IMAGE_PATHS['jpg']),
        ('image', IMAGE_PATHS['jpeg']),
        ('image', REMOTE_JPG),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.wav')),
        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
        # 'test' is a sub-directory of the toy data dir, so it must be its own
        # path component rather than being glued onto the file name.
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
    ],
)
def test_file_validation(file_type, file_source):
    """Check that ImageUrl accepts image sources and rejects every other type.

    :param file_type: top-level MIME type the source is expected to have
    :param file_source: local path or remote URL to validate
    """
    if file_type != ImageUrl.mime_type():
        # Any non-image source must be refused during validation.
        with pytest.raises(ValueError):
            parse_obj_as(ImageUrl, file_source)
    else:
        url = parse_obj_as(ImageUrl, file_source)
        assert isinstance(url, ImageUrl)
27 changes: 27 additions & 0 deletions tests/units/typing/url/test_mesh_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pytest
from pydantic.tools import parse_obj_as, schema_json_of
Expand Down Expand Up @@ -75,3 +77,28 @@ def test_validation(path_to_file):
def test_proto_mesh_url():
uri = parse_obj_as(Mesh3DUrl, REMOTE_OBJ_FILE)
uri._to_node_protobuf()


@pytest.mark.parametrize(
    'file_type, file_source',
    [
        ('application', MESH_FILES['obj']),
        ('application', MESH_FILES['glb']),
        ('application', MESH_FILES['ply']),
        ('application', REMOTE_OBJ_FILE),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
        # 'test' is a sub-directory of the toy data dir, so it must be its own
        # path component rather than being glued onto the file name.
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
    ],
)
def test_file_validation(file_type, file_source):
    """Check that Mesh3DUrl accepts 3D sources and rejects every other type.

    :param file_type: top-level MIME type the source is expected to have
    :param file_source: local path or remote URL to validate
    """
    if file_type != Mesh3DUrl.mime_type():
        # Any non-3D source must be refused during validation.
        with pytest.raises(ValueError):
            parse_obj_as(Mesh3DUrl, file_source)
    else:
        url = parse_obj_as(Mesh3DUrl, file_source)
        assert isinstance(url, Mesh3DUrl)
27 changes: 27 additions & 0 deletions tests/units/typing/url/test_point_cloud_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pytest
from pydantic.tools import parse_obj_as, schema_json_of
Expand Down Expand Up @@ -79,3 +81,28 @@ def test_validation(path_to_file):
def test_proto_point_cloud_url():
uri = parse_obj_as(PointCloud3DUrl, REMOTE_OBJ_FILE)
uri._to_node_protobuf()


@pytest.mark.parametrize(
    'file_type, file_source',
    [
        ('application', MESH_FILES['obj']),
        ('application', MESH_FILES['glb']),
        ('application', MESH_FILES['ply']),
        ('application', REMOTE_OBJ_FILE),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
        # 'test' is a sub-directory of the toy data dir, so it must be its own
        # path component rather than being glued onto the file name.
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.html')),
        ('text', os.path.join(TOYDATA_DIR, 'test', 'test.md')),
        ('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
    ],
)
def test_file_validation(file_type, file_source):
    """Check that PointCloud3DUrl accepts 3D sources and rejects every other type.

    :param file_type: top-level MIME type the source is expected to have
    :param file_source: local path or remote URL to validate
    """
    if file_type != PointCloud3DUrl.mime_type():
        # Any non-3D source must be refused during validation.
        with pytest.raises(ValueError):
            parse_obj_as(PointCloud3DUrl, file_source)
    else:
        url = parse_obj_as(PointCloud3DUrl, file_source)
        assert isinstance(url, PointCloud3DUrl)
21 changes: 21 additions & 0 deletions tests/units/typing/url/test_text_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,24 @@ def test_validation(path_to_file):
url = parse_obj_as(TextUrl, path_to_file)
assert isinstance(url, TextUrl)
assert isinstance(url, str)


@pytest.mark.parametrize(
    'file_type, file_source',
    [
        *[('text', file) for file in LOCAL_TEXT_FILES],
        ('text', REMOTE_TEXT_FILE),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
        ('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
        ('image', os.path.join(TOYDATA_DIR, 'test.png')),
        ('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
        ('application', os.path.join(TOYDATA_DIR, 'test.glb')),
    ],
)
def test_file_validation(file_type, file_source):
    """Check that TextUrl accepts text sources and rejects every other type.

    :param file_type: top-level MIME type the source is expected to have
    :param file_source: local path or remote URL to validate
    """
    is_text_source = file_type == TextUrl.mime_type()
    if is_text_source:
        # Matching sources only need to validate without raising.
        parse_obj_as(TextUrl, file_source)
        return

    # Everything else must be refused during validation.
    with pytest.raises(ValueError):
        parse_obj_as(TextUrl, file_source)
Loading