Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions docarray/typing/url/any_url.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import mimetypes
import os
import urllib
import urllib.parse
import urllib.request
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, List, Optional, Type, TypeVar, Union

import numpy as np
from pydantic import AnyUrl as BaseAnyUrl
Expand All @@ -27,6 +28,17 @@ class AnyUrl(BaseAnyUrl, AbstractType):
False # turn off host requirement to allow passing of local paths as URL
)

@classmethod
def mime_type(cls) -> str:
"""Returns the mime type this class deals with."""
raise NotImplementedError

@classmethod
def extra_extensions(cls) -> List[str]:
"""Returns a list of allowed file extensions for this class which
falls outside the scope of mimetypes library."""
raise NotImplementedError

def _to_node_protobuf(self) -> 'NodeProto':
"""Convert Document into a NodeProto protobuf message. This function should
be called when the Document is nested into another Document that need to
Expand All @@ -38,6 +50,37 @@ def _to_node_protobuf(self) -> 'NodeProto':

return NodeProto(text=str(self), type=self._proto_type_name)

@classmethod
def is_extension_allowed(cls, value: Any) -> bool:
"""
Check if the file extension of the url is allowed for that class.
First read the mime type of the file, if it fails, then check the file extension.

:param value: url to the file
:return: True if the extension is allowed, False otherwise
"""
if cls == AnyUrl: # no check for AnyUrl class
return True
mimetype, _ = mimetypes.guess_type(value.split("?")[0])
if mimetype:
return mimetype.startswith(cls.mime_type())
else:
# check if the extension is among the extra extensions of that class
return any(
value.endswith(ext) or value.split("?")[0].endswith(ext)
for ext in cls.extra_extensions()
)

@classmethod
def is_special_case(cls, value: Any) -> bool:
"""
Check if the url is a special case.

:param value: url to the file
:return: True if the url is a special case, False otherwise
"""
return False

@classmethod
def validate(
cls: Type[T],
Expand All @@ -61,10 +104,14 @@ def validate(

url = super().validate(abs_path, field, config) # basic url validation

if input_is_relative_path:
return cls(str(value), scheme=None)
else:
return cls(str(url), scheme=None)
# perform check only for subclasses of AnyUrl
if not cls.is_extension_allowed(value):
if not cls.is_special_case(value): # check for special cases
raise ValueError(
f'file {value} is not a valid file format for class {cls}'
)

return cls(str(value if input_is_relative_path else url), scheme=None)

@classmethod
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/audio_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, Tuple, TypeVar
from typing import List, Optional, Tuple, TypeVar

from docarray.typing import AudioNdArray
from docarray.typing.bytes.audio_bytes import AudioBytes
Expand All @@ -17,6 +17,15 @@ class AudioUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
return 'audio'

@classmethod
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
return []

def load(self: T) -> Tuple[AudioNdArray, int]:
"""
Load the data from the url into an [`AudioNdArray`][docarray.typing.AudioNdArray]
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/image_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar

from docarray.typing import ImageBytes
from docarray.typing.proto_register import _register_proto
Expand All @@ -20,6 +20,15 @@ class ImageUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
return 'image'

@classmethod
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
return []

def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
"""
Load the image from the bytes into a `PIL.Image.Image` instance
Expand Down
28 changes: 27 additions & 1 deletion docarray/typing/url/text_url.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, TypeVar
from typing import List, Optional, TypeVar

from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
Expand All @@ -13,6 +13,32 @@ class TextUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
return 'text'

@classmethod
def extra_extensions(cls) -> List[str]:
"""
List of extra file extensions for this type of URL (outside the scope of mimetype library).
"""
return ['.md']

@classmethod
def is_special_case(cls, value: 'AnyUrl') -> bool:
"""
Check if the url is a special case that needs to be handled differently.

:param value: url to the file
:return: True if the url is a special case, False otherwise
"""
if value.startswith('http') or value.startswith('https'):
if len(value.split('/')[-1].split('.')) == 1:
# This handles the case where the value is a URL without a file extension
# for e.g. https://de.wikipedia.org/wiki/Brixen
return True
return False

def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
"""
Load the text file into a string.
Expand Down
15 changes: 14 additions & 1 deletion docarray/typing/url/url_3d/mesh_url.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar

import numpy as np
from pydantic import parse_obj_as
Expand All @@ -20,6 +20,19 @@ class Mesh3DUrl(Url3D):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def extra_extensions(cls) -> List[str]:
# return list of allowed extensions to be used for mesh if mimetypes fail to detect
# generated with the help of chatGPT and definitely this list is not exhaustive
# bit hacky because of black formatting, making it a long vertical list
list_a = ['3ds', '3mf', 'ac', 'ac3d', 'amf', 'assimp', 'bvh', 'cob', 'collada']
list_b = ['ctm', 'dxf', 'e57', 'fbx', 'gltf', 'glb', 'ifc', 'lwo', 'lws', 'lxo']
list_c = ['md2', 'md3', 'md5', 'mdc', 'm3d', 'mdl', 'ms3d', 'nff', 'obj', 'off']
list_d = ['pcd', 'pod', 'pmd', 'pmx', 'ply', 'q3o', 'q3s', 'raw', 'sib', 'smd']
list_e = ['stl', 'ter' 'terragen', 'vtk', 'vrml', 'x3d', 'xaml', 'xgl', 'xml']
list_f = ['xyz', 'zgl', 'vta']
return list_a + list_b + list_c + list_d + list_e + list_f

def load(
self: T,
skip_materials: bool = True,
Expand Down
13 changes: 12 additions & 1 deletion docarray/typing/url/url_3d/point_cloud_url.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar

import numpy as np
from pydantic import parse_obj_as
Expand All @@ -21,6 +21,17 @@ class PointCloud3DUrl(Url3D):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def extra_extensions(cls) -> List[str]:
# return list of file format for point cloud if mimetypes fail to detect
# generated with the help of chatGPT and definitely this list is not exhaustive
# bit hacky because of black formatting, making it a long vertical list
list_a = ['ascii', 'bin', 'b3dm', 'bpf', 'dp', 'dxf', 'e57', 'fls', 'fls']
list_b = ['glb', 'ply', 'gpf', 'las', 'obj', 'osgb', 'pcap', 'pcd', 'pdal']
list_c = ['pfm', 'ply', 'ply2', 'pod', 'pods', 'pnts', 'ptg', 'ptx', 'pts']
list_d = ['rcp', 'xyz', 'zfs']
return list_a + list_b + list_c + list_d

def load(
self: T,
samples: int,
Expand Down
4 changes: 4 additions & 0 deletions docarray/typing/url/url_3d/url_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class Url3D(AnyUrl, ABC):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
return 'application'

def _load_trimesh_instance(
self: T,
force: Optional[str] = None,
Expand Down
11 changes: 10 additions & 1 deletion docarray/typing/url/video_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional, TypeVar
from typing import List, Optional, TypeVar

from docarray.typing.bytes.video_bytes import VideoBytes, VideoLoadResult
from docarray.typing.proto_register import _register_proto
Expand All @@ -16,6 +16,15 @@ class VideoUrl(AnyUrl):
Can be remote (web) URL, or a local file path.
"""

@classmethod
def mime_type(cls) -> str:
return 'video'

@classmethod
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
return []

def load(self: T, **kwargs) -> VideoLoadResult:
"""
Load the data from the url into a `NamedTuple` of
Expand Down
2 changes: 1 addition & 1 deletion tests/index/weaviate/test_index_get_del_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class MyMultiModalDoc(BaseDoc):


def test_index_document_with_bytes(weaviate_client):
doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo")
doc = ImageDoc(id="1", url="www.foo.com/test.png", bytes_=b"foo")

index = WeaviateDocumentIndex[ImageDoc]()
index.index([doc])
Expand Down
2 changes: 0 additions & 2 deletions tests/integrations/predefined_document/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
str(TOYDATA_DIR / 'hello.ogg'),
str(TOYDATA_DIR / 'hello.wma'),
str(TOYDATA_DIR / 'hello.aac'),
str(TOYDATA_DIR / 'hello'),
]

LOCAL_AUDIO_FILES_AND_FORMAT = [
Expand All @@ -40,7 +39,6 @@
(str(TOYDATA_DIR / 'hello.ogg'), 'ogg'),
(str(TOYDATA_DIR / 'hello.wma'), 'asf'),
(str(TOYDATA_DIR / 'hello.aac'), 'adts'),
(str(TOYDATA_DIR / 'hello'), 'wav'),
]

NON_AUDIO_FILES = [
Expand Down
23 changes: 23 additions & 0 deletions tests/units/typing/url/test_audio_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -123,3 +124,25 @@ def test_load_bytes():
assert isinstance(audio_bytes, bytes)
assert isinstance(audio_bytes, AudioBytes)
assert len(audio_bytes) > 0


@pytest.mark.parametrize(
'file_type, file_source',
[
('audio', AUDIO_FILES[0]),
('audio', AUDIO_FILES[1]),
('audio', REMOTE_AUDIO_FILE),
('image', os.path.join(TOYDATA_DIR, 'test.png')),
('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')),
('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')),
('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
('application', os.path.join(TOYDATA_DIR, 'test.glb')),
],
)
def test_file_validation(file_type, file_source):
if file_type != AudioUrl.mime_type():
with pytest.raises(ValueError):
parse_obj_as(AudioUrl, file_source)
else:
parse_obj_as(AudioUrl, file_source)
25 changes: 25 additions & 0 deletions tests/units/typing/url/test_image_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from docarray.base_doc.io.json import orjson_dumps
from docarray.typing import ImageUrl
from tests import TOYDATA_DIR

CUR_DIR = os.path.dirname(os.path.abspath(__file__))
PATH_TO_IMAGE_DATA = os.path.join(CUR_DIR, '..', '..', '..', 'toydata', 'image-data')
Expand Down Expand Up @@ -174,3 +175,27 @@ def test_validation(path_to_img):
url = parse_obj_as(ImageUrl, path_to_img)
assert isinstance(url, ImageUrl)
assert isinstance(url, str)


@pytest.mark.parametrize(
'file_type, file_source',
[
('image', IMAGE_PATHS['png']),
('image', IMAGE_PATHS['jpg']),
('image', IMAGE_PATHS['jpeg']),
('image', REMOTE_JPG),
('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
('audio', os.path.join(TOYDATA_DIR, 'hello.wav')),
('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')),
('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')),
('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
('application', os.path.join(TOYDATA_DIR, 'test.glb')),
],
)
def test_file_validation(file_type, file_source):
if file_type != ImageUrl.mime_type():
with pytest.raises(ValueError):
parse_obj_as(ImageUrl, file_source)
else:
parse_obj_as(ImageUrl, file_source)
27 changes: 27 additions & 0 deletions tests/units/typing/url/test_mesh_url.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pytest
from pydantic.tools import parse_obj_as, schema_json_of
Expand Down Expand Up @@ -75,3 +77,28 @@ def test_validation(path_to_file):
def test_proto_mesh_url():
uri = parse_obj_as(Mesh3DUrl, REMOTE_OBJ_FILE)
uri._to_node_protobuf()


@pytest.mark.parametrize(
'file_type, file_source',
[
('application', MESH_FILES['obj']),
('application', MESH_FILES['glb']),
('application', MESH_FILES['ply']),
('application', REMOTE_OBJ_FILE),
('audio', os.path.join(TOYDATA_DIR, 'hello.aac')),
('audio', os.path.join(TOYDATA_DIR, 'hello.mp3')),
('audio', os.path.join(TOYDATA_DIR, 'hello.ogg')),
('video', os.path.join(TOYDATA_DIR, 'mov_bbb.mp4')),
('image', os.path.join(TOYDATA_DIR, 'test.png')),
('text', os.path.join(TOYDATA_DIR, 'test' 'test.html')),
('text', os.path.join(TOYDATA_DIR, 'test' 'test.md')),
('text', os.path.join(TOYDATA_DIR, 'penal_colony.txt')),
],
)
def test_file_validation(file_type, file_source):
if file_type != Mesh3DUrl.mime_type():
with pytest.raises(ValueError):
parse_obj_as(Mesh3DUrl, file_source)
else:
parse_obj_as(Mesh3DUrl, file_source)
Loading