Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: apply review suggestions
Signed-off-by: Mohammad Kalim Akram <[email protected]>
  • Loading branch information
makram93 committed Jun 23, 2023
commit aa073128bad51eeb66da000843e2e5b3068b9bbf
67 changes: 39 additions & 28 deletions docarray/typing/url/any_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ class AnyUrl(BaseAnyUrl, AbstractType):

@classmethod
def mime_type(cls) -> str:
"""Returns the mime type this class deals with."""
raise NotImplementedError

@classmethod
def allowed_extensions(cls) -> List[str]:
def extra_extensions(cls) -> List[str]:
"""Returns a list of allowed file extensions for this class which
falls outside the scope of mimetypes library."""
raise NotImplementedError

def _to_node_protobuf(self) -> 'NodeProto':
Expand All @@ -47,6 +50,37 @@ def _to_node_protobuf(self) -> 'NodeProto':

return NodeProto(text=str(self), type=self._proto_type_name)

@classmethod
def is_extension_allowed(cls, value: 'AnyUrl') -> bool:
"""
Check if the file extension of the url is allowed for that class.
First read the mime type of the file, if it fails, then check the file extension.

:param value: url to the file
:return: True if the extension is allowed, False otherwise
"""
if not issubclass(cls, AnyUrl): # no check for AnyUrl class
return True
mimetype, _ = mimetypes.guess_type(value.split("?")[0])
if mimetype:
return mimetype.startswith(cls.mime_type())
else:
# check if the extension is among the extra extensions of that class
return any(
value.endswith(ext) or value.split("?")[0].endswith(ext)
for ext in cls.extra_extensions()
)

@classmethod
def is_special_case(cls, value: 'AnyUrl') -> bool:
"""
Check if the url is a special case.

:param value: url to the file
:return: True if the url is a special case, False otherwise
"""
return False

@classmethod
def validate(
cls: Type[T],
Expand All @@ -70,37 +104,14 @@ def validate(

url = super().validate(abs_path, field, config) # basic url validation

# Use mimetypes to validate file formats
mimetype, encoding = mimetypes.guess_type(value.split("?")[0])
if not mimetype:
# try reading from the request headers if mimetypes failed - could be slow
try:
r = urllib.request.urlopen(value)
except Exception: # noqa
pass # should we raise an error/warning here, since url is not reachable(invalid)?
else:
mimetype = r.headers.get_content_maintype()

skip_check = False
if not mimetype: # not able to automatically detect mimetype
# check if the file extension is among one of the allowed extensions
if not any(
value.endswith(ext) or value.split("?")[0].endswith(ext)
for ext in cls.allowed_extensions()
):
# perform check only for subclasses of AnyUrl
if not cls.is_extension_allowed(url):
if not cls.is_special_case(url): # check for special cases
raise ValueError(
f'file {value} is not a valid file format for class {cls}'
)
else:
skip_check = True # one of the allowed extensions, skip the check

if not skip_check and not mimetype.startswith(cls.mime_type()):
raise ValueError(f'file {value} is not a {cls.mime_type()} file format')

if input_is_relative_path:
return cls(str(value), scheme=None)
else:
return cls(str(url), scheme=None)
return cls(str(value if input_is_relative_path else url), scheme=None)

@classmethod
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
Expand Down
2 changes: 1 addition & 1 deletion docarray/typing/url/audio_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def mime_type(cls) -> str:
return 'audio'

@classmethod
def allowed_extensions(cls) -> List[str]:
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
return []

Expand Down
2 changes: 1 addition & 1 deletion docarray/typing/url/image_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def mime_type(cls) -> str:
return 'image'

@classmethod
def allowed_extensions(cls) -> List[str]:
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
return []

Expand Down
20 changes: 18 additions & 2 deletions docarray/typing/url/text_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,26 @@ def mime_type(cls) -> str:
return 'text'

@classmethod
def allowed_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
def extra_extensions(cls) -> List[str]:
"""
List of extra file extensions for this type of URL (outside the scope of mimetype library).
"""
return ['.md']

@classmethod
def is_special_case(cls, value: 'AnyUrl') -> bool:
"""
Check if the url is a special case that needs to be handled differently.

:param value: url to the file
:return: True if the url is a special case, False otherwise
"""
if value.startswith('http') or value.startswith('https'):
if len(value.split('/')[-1].split('.')) == 1:
# This handles the case where the value is a URL without a file extension
# for e.g. https://de.wikipedia.org/wiki/Brixen
return True

def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
"""
Load the text file into a string.
Expand Down
2 changes: 1 addition & 1 deletion docarray/typing/url/url_3d/mesh_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Mesh3DUrl(Url3D):
"""

@classmethod
def allowed_extensions(cls) -> List[str]:
def extra_extensions(cls) -> List[str]:
# return list of allowed extensions to be used for mesh if mimetypes fail to detect
# generated with the help of chatGPT and definitely this list is not exhaustive
# bit hacky because of black formatting, making it a long vertical list
Expand Down
2 changes: 1 addition & 1 deletion docarray/typing/url/url_3d/point_cloud_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class PointCloud3DUrl(Url3D):
"""

@classmethod
def allowed_extensions(cls) -> List[str]:
def extra_extensions(cls) -> List[str]:
# return list of file format for point cloud if mimetypes fail to detect
# generated with the help of chatGPT and definitely this list is not exhaustive
# bit hacky because of black formatting, making it a long vertical list
Expand Down
2 changes: 1 addition & 1 deletion docarray/typing/url/video_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def mime_type(cls) -> str:
return 'video'

@classmethod
def allowed_extensions(cls) -> List[str]:
def extra_extensions(cls) -> List[str]:
# add only those extensions that can not be identified by the mimetypes library but are valid
return []

Expand Down
2 changes: 1 addition & 1 deletion tests/units/typing/url/test_text_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_load_to_bytes(url):
@pytest.mark.proto
@pytest.mark.slow
@pytest.mark.internet
@pytest.mark.parametrize('url', [REMOTE_TEXT_FILE, *LOCAL_TEXT_FILES])
@pytest.mark.parametrize('url', [REMOTE_TEXT_FILE])
def test_proto_text_url(url):
uri = parse_obj_as(TextUrl, url)

Expand Down