Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: polish up the code
Signed-off-by: jupyterjazz <[email protected]>
  • Loading branch information
jupyterjazz committed Jun 26, 2023
commit 4ee48f3c0d7e4fe0f793b6428b3a602b784b69eb
33 changes: 18 additions & 15 deletions docarray/typing/url/any_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ class AnyUrl(BaseAnyUrl, AbstractType):

@classmethod
def mime_type(cls) -> str:
"""Returns the mime type this class deals with."""
"""Returns the mime type associated with the class."""
raise NotImplementedError

@classmethod
def extra_extensions(cls) -> List[str]:
"""Returns a list of allowed file extensions for this class which
falls outside the scope of mimetypes library."""
"""Returns a list of allowed file extensions for the class
that are not covered by the mimetypes library."""
raise NotImplementedError

def _to_node_protobuf(self) -> 'NodeProto':
Expand All @@ -55,24 +55,25 @@ def _to_node_protobuf(self) -> 'NodeProto':
@classmethod
def is_extension_allowed(cls, value: Any) -> bool:
"""
Check if the file extension of the url is allowed for that class.
First read the mime type of the file, if it fails, then check the file extension.
Check if the file extension of the URL is allowed for this class.
First, it guesses the mime type of the file. If it fails to detect the
mime type, it then checks the extra file extension.

:param value: url to the file
:param value: The URL or file path.
:return: True if the extension is allowed, False otherwise
"""
if cls == AnyUrl: # no check for AnyUrl class
if cls is AnyUrl:
return True
mimetype, _ = mimetypes.guess_type(value.split("?")[0])
print('mimetype for value', mimetype, value, value.split("?")[0])

url_parts = value.split("?")
mimetype, _ = mimetypes.guess_type(url_parts[0])
if mimetype and mimetype.startswith(cls.mime_type()):
return True
filename = value.split("?")[0].split('.')
if len(filename) > 1:
extension = filename[-1]
return extension in cls.extra_extensions()

return False
filename = url_parts[0].split('.')
extension = filename[-1] if len(filename) > 1 else None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I am being overly cautious here, but do we know for a fact that there are no corner cases where this splitting into filename and extension could break? Is there some resource or standard that we can reference?
Alternatively, I think pydantic implements some of this internally; maybe we could repurpose some of their logic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, there are many edge cases indeed. I already changed that part; can you take a look again? Here are the unit tests:
https://github.com/docarray/docarray/pull/1669/files#diff-f1502e8b25d6058d51f22b4de5d853aeba8e107952a8b597848f8a918cb055fd

I'll explore how pydantic handles this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But I think this is OK for now. What do you think?


return extension in cls.extra_extensions()

@classmethod
def validate(
Expand All @@ -98,7 +99,9 @@ def validate(
url = super().validate(abs_path, field, config) # basic url validation

if not cls.is_extension_allowed(value):
raise ValueError(f'file {value} is not a valid file format for class {cls}')
raise ValueError(
f"The file '{value}' is not in a valid format for class '{cls.__name__}'."
)

return cls(str(value if input_is_relative_path else url), scheme=None)

Expand Down
5 changes: 4 additions & 1 deletion docarray/typing/url/audio_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def mime_type(cls) -> str:

@classmethod
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
"""
Returns a list of additional file extensions that are valid for this class
but cannot be identified by the mimetypes library.
"""
return []

def load(self: T) -> Tuple[AudioNdArray, int]:
Expand Down
88 changes: 88 additions & 0 deletions docarray/typing/url/extra_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Extensions returned by Mesh3DUrl.extra_extensions(): mesh file formats that
# are accepted even when the mimetypes library does not identify them.
# Kept as a whitespace-separated table; order matches the original list.
MESH_EXTRA_EXTENSIONS = """
    3ds 3mf ac ac3d amf assimp bvh cob collada
    ctm dxf e57 fbx gltf glb ifc lwo lws lxo
    md2 md3 md5 mdc m3d mdl ms3d nff obj off
    pcd pod pmd pmx ply q3o q3s raw sib smd
    stl ter terragen vtk vrml x3d xaml xgl xml
    xyz zgl vta
""".split()

# Extensions returned by TextUrl.extra_extensions(): text file formats that
# are accepted in addition to whatever the mimetypes library identifies.
TEXT_EXTRA_EXTENSIONS = [
    'md',
    'log',
]

# Extensions returned by PointCloud3DUrl.extra_extensions(): point-cloud file
# formats that are accepted even when the mimetypes library does not identify
# them. Each extension is listed exactly once (the list previously repeated
# 'fls' and 'ply'); first occurrences and their relative order are preserved.
POINT_CLOUD_EXTRA_EXTENSIONS = [
    'ascii',
    'bin',
    'b3dm',
    'bpf',
    'dp',
    'dxf',
    'e57',
    'fls',
    'glb',
    'ply',
    'gpf',
    'las',
    'obj',
    'osgb',
    'pcap',
    'pcd',
    'pdal',
    'pfm',
    'ply2',
    'pod',
    'pods',
    'pnts',
    'ptg',
    'ptx',
    'pts',
    'rcp',
    'xyz',
    'zfs',
]
5 changes: 4 additions & 1 deletion docarray/typing/url/image_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def mime_type(cls) -> str:

@classmethod
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
"""
Returns a list of additional file extensions that are valid for this class
but cannot be identified by the mimetypes library.
"""
return []

def load_pil(self, timeout: Optional[float] = None) -> 'PILImage.Image':
Expand Down
6 changes: 4 additions & 2 deletions docarray/typing/url/text_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from docarray.typing.proto_register import _register_proto
from docarray.typing.url.any_url import AnyUrl
from docarray.typing.url.extra_extensions import TEXT_EXTRA_EXTENSIONS

T = TypeVar('T', bound='TextUrl')

Expand All @@ -20,9 +21,10 @@ def mime_type(cls) -> str:
@classmethod
def extra_extensions(cls) -> List[str]:
"""
List of extra file extensions for this type of URL (outside the scope of mimetype library).
Returns a list of additional file extensions that are valid for this class
but cannot be identified by the mimetypes library.
"""
return ['md', 'log']
return TEXT_EXTRA_EXTENSIONS

def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
"""
Expand Down
16 changes: 6 additions & 10 deletions docarray/typing/url/url_3d/mesh_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.ndarray import NdArray
from docarray.typing.url.extra_extensions import MESH_EXTRA_EXTENSIONS
from docarray.typing.url.url_3d.url_3d import Url3D

if TYPE_CHECKING:
Expand All @@ -22,16 +23,11 @@ class Mesh3DUrl(Url3D):

@classmethod
def extra_extensions(cls) -> List[str]:
# return list of allowed extensions to be used for mesh if mimetypes fail to detect
# generated with the help of chatGPT and definitely this list is not exhaustive
# bit hacky because of black formatting, making it a long vertical list
list_a = ['3ds', '3mf', 'ac', 'ac3d', 'amf', 'assimp', 'bvh', 'cob', 'collada']
list_b = ['ctm', 'dxf', 'e57', 'fbx', 'gltf', 'glb', 'ifc', 'lwo', 'lws', 'lxo']
list_c = ['md2', 'md3', 'md5', 'mdc', 'm3d', 'mdl', 'ms3d', 'nff', 'obj', 'off']
list_d = ['pcd', 'pod', 'pmd', 'pmx', 'ply', 'q3o', 'q3s', 'raw', 'sib', 'smd']
list_e = ['stl', 'ter' 'terragen', 'vtk', 'vrml', 'x3d', 'xaml', 'xgl', 'xml']
list_f = ['xyz', 'zgl', 'vta']
return list_a + list_b + list_c + list_d + list_e + list_f
"""
Returns a list of additional file extensions that are valid for this class
but cannot be identified by the mimetypes library.
"""
return MESH_EXTRA_EXTENSIONS

def load(
self: T,
Expand Down
20 changes: 10 additions & 10 deletions docarray/typing/url/url_3d/point_cloud_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from docarray.typing.proto_register import _register_proto
from docarray.typing.tensor.ndarray import NdArray
from docarray.typing.url.extra_extensions import POINT_CLOUD_EXTRA_EXTENSIONS
from docarray.typing.url.url_3d.url_3d import Url3D

if TYPE_CHECKING:
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
from docarray.documents.point_cloud.points_and_colors import \
PointsAndColors


T = TypeVar('T', bound='PointCloud3DUrl')
Expand All @@ -23,14 +25,11 @@ class PointCloud3DUrl(Url3D):

@classmethod
def extra_extensions(cls) -> List[str]:
# return list of file format for point cloud if mimetypes fail to detect
# generated with the help of chatGPT and definitely this list is not exhaustive
# bit hacky because of black formatting, making it a long vertical list
list_a = ['ascii', 'bin', 'b3dm', 'bpf', 'dp', 'dxf', 'e57', 'fls', 'fls']
list_b = ['glb', 'ply', 'gpf', 'las', 'obj', 'osgb', 'pcap', 'pcd', 'pdal']
list_c = ['pfm', 'ply', 'ply2', 'pod', 'pods', 'pnts', 'ptg', 'ptx', 'pts']
list_d = ['rcp', 'xyz', 'zfs']
return list_a + list_b + list_c + list_d
"""
Returns a list of additional file extensions that are valid for this class
but cannot be identified by the mimetypes library.
"""
return POINT_CLOUD_EXTRA_EXTENSIONS

def load(
self: T,
Expand Down Expand Up @@ -75,7 +74,8 @@ class MyDoc(BaseDoc):

:return: np.ndarray representing the point cloud
"""
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
from docarray.documents.point_cloud.points_and_colors import \
PointsAndColors

if not trimesh_args:
trimesh_args = {}
Expand Down
5 changes: 4 additions & 1 deletion docarray/typing/url/video_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def mime_type(cls) -> str:

@classmethod
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
"""
Returns a list of additional file extensions that are valid for this class
but cannot be identified by the mimetypes library.
"""
return []

def load(self: T, **kwargs) -> VideoLoadResult:
Expand Down