1+ import mimetypes
12import os
23import urllib
34import urllib .parse
45import urllib .request
5- from typing import TYPE_CHECKING , Any , Optional , Type , TypeVar , Union
6+ from typing import TYPE_CHECKING , Any , List , Optional , Type , TypeVar , Union
67
78import numpy as np
89from pydantic import AnyUrl as BaseAnyUrl
@@ -27,6 +28,17 @@ class AnyUrl(BaseAnyUrl, AbstractType):
2728 False # turn off host requirement to allow passing of local paths as URL
2829 )
2930
31+ @classmethod
32+ def mime_type (cls ) -> str :
33+ """Returns the mime type this class deals with."""
34+ raise NotImplementedError
35+
36+ @classmethod
37+ def extra_extensions (cls ) -> List [str ]:
38+ """Returns a list of allowed file extensions for this class which
39+ falls outside the scope of mimetypes library."""
40+ raise NotImplementedError
41+
3042 def _to_node_protobuf (self ) -> 'NodeProto' :
3143 """Convert Document into a NodeProto protobuf message. This function should
3244 be called when the Document is nested into another Document that need to
@@ -38,6 +50,37 @@ def _to_node_protobuf(self) -> 'NodeProto':
3850
3951 return NodeProto (text = str (self ), type = self ._proto_type_name )
4052
53+ @classmethod
54+ def is_extension_allowed (cls , value : Any ) -> bool :
55+ """
56+ Check if the file extension of the url is allowed for that class.
57+ First read the mime type of the file, if it fails, then check the file extension.
58+
59+ :param value: url to the file
60+ :return: True if the extension is allowed, False otherwise
61+ """
62+ if cls == AnyUrl : # no check for AnyUrl class
63+ return True
64+ mimetype , _ = mimetypes .guess_type (value .split ("?" )[0 ])
65+ if mimetype :
66+ return mimetype .startswith (cls .mime_type ())
67+ else :
68+ # check if the extension is among the extra extensions of that class
69+ return any (
70+ value .endswith (ext ) or value .split ("?" )[0 ].endswith (ext )
71+ for ext in cls .extra_extensions ()
72+ )
73+
74+ @classmethod
75+ def is_special_case (cls , value : Any ) -> bool :
76+ """
77+ Check if the url is a special case.
78+
79+ :param value: url to the file
80+ :return: True if the url is a special case, False otherwise
81+ """
82+ return False
83+
4184 @classmethod
4285 def validate (
4386 cls : Type [T ],
@@ -61,10 +104,14 @@ def validate(
61104
62105 url = super ().validate (abs_path , field , config ) # basic url validation
63106
64- if input_is_relative_path :
65- return cls (str (value ), scheme = None )
66- else :
67- return cls (str (url ), scheme = None )
107+ # perform check only for subclasses of AnyUrl
108+ if not cls .is_extension_allowed (value ):
109+ if not cls .is_special_case (value ): # check for special cases
110+ raise ValueError (
111+ f'file { value } is not a valid file format for class { cls } '
112+ )
113+
114+ return cls (str (value if input_is_relative_path else url ), scheme = None )
68115
69116 @classmethod
70117 def validate_parts (cls , parts : 'Parts' , validate_port : bool = True ) -> 'Parts' :
0 commit comments