@@ -83,6 +83,80 @@ async def getsockname(self):
8383 return self .stream .socket .getsockname ()
8484
8585
86+ try :
87+ import httpx
88+
89+ import httpcore
90+ import httpcore .backends .base
91+ import httpcore .backends .trio
92+
93+ from dns .query import _compute_times , _remaining , _expiration_for_this_attempt
94+
95+ class _NetworkBackend (httpcore .backends .base .AsyncNetworkBackend ):
96+ def __init__ (self , resolver , local_port , bootstrap_address , family ):
97+ super ().__init__ ()
98+ self ._local_port = local_port
99+ self ._resolver = resolver
100+ self ._bootstrap_address = bootstrap_address
101+ self ._family = family
102+
103+ async def connect_tcp (self , host , port , timeout , local_address ):
104+ addresses = []
105+ now , expiration = _compute_times (timeout )
106+ if dns .inet .is_address (host ):
107+ addresses .append (host )
108+ elif self ._bootstrap_address is not None :
109+ addresses .append (self ._bootstrap_address )
110+ else :
111+ timeout = _remaining (expiration )
112+ family = self ._family
113+ if local_address :
114+ family = dns .inet .af_for_address (local_address )
115+ answers = await self ._resolver .resolve_name (
116+ host , family = family , lifetime = timeout
117+ )
118+ addresses = answers .addresses ()
119+ for address in addresses :
120+ try :
121+ af = dns .inet .af_for_address (address )
122+ if local_address is not None or self ._local_port != 0 :
123+ source = (local_address , self ._local_port )
124+ else :
125+ source = None
126+ destination = (address , port )
127+ attempt_expiration = _expiration_for_this_attempt (2.0 , expiration )
128+ timeout = _remaining (attempt_expiration )
129+ sock = await Backend ().make_socket (
130+ af , socket .SOCK_STREAM , 0 , source , destination , timeout
131+ )
132+ return httpcore .backends .trio .TrioStream (sock .stream )
133+ except Exception :
134+ continue
135+ raise httpcore .ConnectError
136+
137+ class _HTTPTransport (httpx .AsyncHTTPTransport ):
138+ def __init__ (
139+ self ,
140+ * args ,
141+ local_port = 0 ,
142+ bootstrap_address = None ,
143+ resolver = None ,
144+ family = socket .AF_UNSPEC ,
145+ ** kwargs ,
146+ ):
147+ if resolver is None :
148+ import dns .asyncresolver
149+
150+ resolver = dns .asyncresolver .Resolver ()
151+ super ().__init__ (* args , ** kwargs )
152+ self ._pool ._network_backend = _NetworkBackend (
153+ resolver , local_port , bootstrap_address , family
154+ )
155+
156+ except ImportError :
157+ _HTTPTransport = dns ._asyncbackend .NullTransport # type: ignore
158+
159+
86160class Backend (dns ._asyncbackend .Backend ):
87161 def name (self ):
88162 return "trio"
@@ -104,8 +178,14 @@ async def make_socket(
104178 if source :
105179 await s .bind (_lltuple (source , af ))
106180 if socktype == socket .SOCK_STREAM :
181+ connected = False
107182 with _maybe_timeout (timeout ):
108183 await s .connect (_lltuple (destination , af ))
184+ connected = True
185+ if not connected :
186+ raise dns .exception .Timeout (
187+ timeout = timeout
188+ ) # lgtm[py/unreachable-statement]
109189 except Exception : # pragma: no cover
110190 s .close ()
111191 raise
@@ -130,3 +210,6 @@ async def make_socket(
130210
131211 async def sleep (self , interval ):
132212 await trio .sleep (interval )
213+
214+ def get_transport_class (self ):
215+ return _HTTPTransport
0 commit comments