1+ import mimetypes
12import os
23import urllib
34import urllib .parse
45import urllib .request
5- from typing import TYPE_CHECKING , Any , Optional , Type , TypeVar , Union
6+ from typing import TYPE_CHECKING , Any , List , Optional , Type , TypeVar , Union
67
78import numpy as np
89from pydantic import AnyUrl as BaseAnyUrl
2021
# Type variable bound to AnyUrl, so annotated classmethods can express
# "returns an instance of whichever AnyUrl subclass this was called on".
T = TypeVar('T', bound='AnyUrl')

# Initialise the mimetypes registry with an explicit empty file list.
# Per the stdlib documentation this skips the system-wide mime.types files,
# leaving only the well-known built-in mappings available to guess_type.
mimetypes.init([])
25+
2326
2427@_register_proto (proto_type_name = 'any_url' )
2528class AnyUrl (BaseAnyUrl , AbstractType ):
2629 host_required = (
2730 False # turn off host requirement to allow passing of local paths as URL
2831 )
2932
@classmethod
def mime_type(cls) -> str:
    """
    Return the mime type associated with the class.

    This base implementation provides no mime type and always raises.
    `is_extension_allowed` uses the returned string as a prefix that the
    guessed mime type of a URL must start with.

    :return: The mime type (prefix) handled by the class.
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
37+
@classmethod
def extra_extensions(cls) -> List[str]:
    """
    Return the file extensions allowed for the class that are not
    covered by the mimetypes library.

    This base implementation provides no list and always raises.
    `is_extension_allowed` consults the list only when the mime type
    guess did not already accept the URL.

    :return: The additional extensions accepted by the class.
    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
43+
3044 def _to_node_protobuf (self ) -> 'NodeProto' :
3145 """Convert Document into a NodeProto protobuf message. This function should
3246 be called when the Document is nested into another Document that need to
@@ -38,6 +52,48 @@ def _to_node_protobuf(self) -> 'NodeProto':
3852
3953 return NodeProto (text = str (self ), type = self ._proto_type_name )
4054
@staticmethod
def _get_url_extension(url: str) -> str:
    """
    Return the file extension found in the path component of a URL.

    Only the parsed path is inspected, so a query string or fragment
    never contributes to the result.

    :param url: The URL to extract the file extension from.
    :return: The file extension without the leading period, or an empty
        string when the path has no extension.
    """
    url_path = urllib.parse.urlparse(url).path
    _, suffix = os.path.splitext(url_path)

    # splitext keeps the separating dot on the suffix; drop it if present.
    if suffix.startswith('.'):
        return suffix[1:]
    return suffix
71+
@classmethod
def is_extension_allowed(cls, value: Any) -> bool:
    """
    Check if the file extension of the URL is allowed for this class.

    First, the mime type of the file is guessed from the URL with its
    query string and fragment removed. If no mime type is detected, or it
    does not start with `cls.mime_type()`, the extension is looked up
    case-insensitively in `cls.extra_extensions()`.
    Note: This method assumes that any URL without an extension is valid,
    and the base `AnyUrl` class accepts every value.

    :param value: The URL or file path. Values that are not strings are
        converted with `str()` before being inspected.
    :return: True if the extension is allowed, False otherwise
    """
    if cls is AnyUrl:
        return True

    # `value` is typed as Any; normalise it to text so string operations
    # below cannot fail on non-str inputs such as path objects.
    url = str(value)

    extension = cls._get_url_extension(url)
    if not extension:
        return True

    # The extension above comes from the parsed URL path, which excludes
    # both the query and the fragment; strip the same parts here so the
    # mime type guess is made on the same file name.
    guess_target = url.split('?', 1)[0].split('#', 1)[0]
    mimetype, _ = mimetypes.guess_type(guess_target)
    if mimetype and mimetype.startswith(cls.mime_type()):
        return True

    # Compare case-insensitively so 'OBJ' and 'obj' are treated alike.
    allowed = {ext.lower() for ext in cls.extra_extensions()}
    return extension.lower() in allowed
96+
4197 @classmethod
4298 def validate (
4399 cls : Type [T ],
@@ -61,10 +117,12 @@ def validate(
61117
62118 url = super ().validate (abs_path , field , config ) # basic url validation
63119
64- if input_is_relative_path :
65- return cls (str (value ), scheme = None )
66- else :
67- return cls (str (url ), scheme = None )
120+ if not cls .is_extension_allowed (value ):
121+ raise ValueError (
122+ f"The file '{ value } ' is not in a valid format for class '{ cls .__name__ } '."
123+ )
124+
125+ return cls (str (value if input_is_relative_path else url ), scheme = None )
68126
69127 @classmethod
70128 def validate_parts (cls , parts : 'Parts' , validate_port : bool = True ) -> 'Parts' :
0 commit comments