Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix: mypy and unit tests
Signed-off-by: Mohammad Kalim Akram <[email protected]>
  • Loading branch information
makram93 committed Jun 23, 2023
commit a7f515d1658d389546f831c9632bf9d742b22601
10 changes: 5 additions & 5 deletions docarray/typing/url/any_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ def _to_node_protobuf(self) -> 'NodeProto':
return NodeProto(text=str(self), type=self._proto_type_name)

@classmethod
def is_extension_allowed(cls, value: 'AnyUrl') -> bool:
def is_extension_allowed(cls, value: Any) -> bool:
"""
Check if the file extension of the url is allowed for that class.
First read the mime type of the file, if it fails, then check the file extension.

:param value: url to the file
:return: True if the extension is allowed, False otherwise
"""
if not issubclass(cls, AnyUrl): # no check for AnyUrl class
if cls == AnyUrl: # no check for AnyUrl class
return True
mimetype, _ = mimetypes.guess_type(value.split("?")[0])
if mimetype:
Expand All @@ -72,7 +72,7 @@ def is_extension_allowed(cls, value: 'AnyUrl') -> bool:
)

@classmethod
def is_special_case(cls, value: 'AnyUrl') -> bool:
def is_special_case(cls, value: Any) -> bool:
"""
Check if the url is a special case.

Expand Down Expand Up @@ -105,8 +105,8 @@ def validate(
url = super().validate(abs_path, field, config) # basic url validation

# perform check only for subclasses of AnyUrl
if not cls.is_extension_allowed(url):
if not cls.is_special_case(url): # check for special cases
if not cls.is_extension_allowed(value):
if not cls.is_special_case(value): # check for special cases
raise ValueError(
f'file {value} is not a valid file format for class {cls}'
)
Expand Down
1 change: 1 addition & 0 deletions docarray/typing/url/text_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def is_special_case(cls, value: 'AnyUrl') -> bool:
# This handles the case where the value is a URL without a file extension
# for e.g. https://de.wikipedia.org/wiki/Brixen
return True
return False

def load(self, charset: str = 'utf-8', timeout: Optional[float] = None) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/index/weaviate/test_index_get_del_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class MyMultiModalDoc(BaseDoc):


def test_index_document_with_bytes(weaviate_client):
doc = ImageDoc(id="1", url="www.foo.com", bytes_=b"foo")
doc = ImageDoc(id="1", url="www.foo.com/test.png", bytes_=b"foo")

index = WeaviateDocumentIndex[ImageDoc]()
index.index([doc])
Expand Down
2 changes: 0 additions & 2 deletions tests/integrations/predefined_document/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
str(TOYDATA_DIR / 'hello.ogg'),
str(TOYDATA_DIR / 'hello.wma'),
str(TOYDATA_DIR / 'hello.aac'),
str(TOYDATA_DIR / 'hello'),
]

LOCAL_AUDIO_FILES_AND_FORMAT = [
Expand All @@ -40,7 +39,6 @@
(str(TOYDATA_DIR / 'hello.ogg'), 'ogg'),
(str(TOYDATA_DIR / 'hello.wma'), 'asf'),
(str(TOYDATA_DIR / 'hello.aac'), 'adts'),
(str(TOYDATA_DIR / 'hello'), 'wav'),
]

NON_AUDIO_FILES = [
Expand Down