Skip to content

Commit 6c5f0c9

Browse files
authored
Better DNS-over-HTTPS support. (#908)
This change: Allows resolution hostnames in URLs using dnspython's resolver or via a bootstrap address, without rewriting URLs. Adds full support for source addresses and ports to httpx, except for asyncio I/O where only the source address can be specified. Removes support for requests.
1 parent 1aeec72 commit 6c5f0c9

15 files changed

Lines changed: 411 additions & 225 deletions

.github/workflows/codeql-analysis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
sudo apt install -y gnome-keyring
6161
python -m pip install --upgrade pip
6262
python -m pip install poetry
63-
poetry install -E dnssec -E doh -E idna -E trio
63+
poetry install -E dnssec -E doh -E idna -E trio -E curio
6464
6565
- name: Perform CodeQL Analysis
6666
uses: github/codeql-action/analyze@v2

dns/_asyncbackend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ async def recv(self, size, timeout):
6161
raise NotImplementedError
6262

6363

64+
class NullTransport:
65+
async def connect_tcp(self, host, port, timeout, local_address):
66+
raise NotImplementedError
67+
68+
6469
class Backend: # pragma: no cover
6570
def name(self):
6671
return "unknown"
@@ -83,3 +88,6 @@ def datagram_connection_required(self):
8388

8489
async def sleep(self, interval):
8590
raise NotImplementedError
91+
92+
def get_transport_class(self):
93+
raise NotImplementedError

dns/_asyncio_backend.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,82 @@ async def getsockname(self):
113113
return self.writer.get_extra_info("sockname")
114114

115115

116+
try:
117+
import anyio
118+
import httpx
119+
120+
import httpcore
121+
import httpcore.backends.base
122+
import httpcore.backends.asyncio
123+
124+
from dns.query import _compute_times, _remaining, _expiration_for_this_attempt
125+
126+
class _NetworkBackend(httpcore.backends.base.AsyncNetworkBackend):
127+
def __init__(self, resolver, local_port, bootstrap_address, family):
128+
super().__init__()
129+
self._local_port = local_port
130+
self._resolver = resolver
131+
self._bootstrap_address = bootstrap_address
132+
self._family = family
133+
if local_port != 0:
134+
raise NotImplementedError(
135+
"the asyncio transport for HTTPX cannot set the local port"
136+
)
137+
138+
async def connect_tcp(self, host, port, timeout, local_address):
139+
addresses = []
140+
now, expiration = _compute_times(timeout)
141+
if dns.inet.is_address(host):
142+
addresses.append(host)
143+
elif self._bootstrap_address is not None:
144+
addresses.append(self._bootstrap_address)
145+
else:
146+
timeout = _remaining(expiration)
147+
family = self._family
148+
if local_address:
149+
family = dns.inet.af_for_address(local_address)
150+
answers = await self._resolver.resolve_name(
151+
host, family=family, lifetime=timeout
152+
)
153+
addresses = answers.addresses()
154+
for address in addresses:
155+
try:
156+
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
157+
timeout = _remaining(attempt_expiration)
158+
with anyio.fail_after(timeout):
159+
stream = await anyio.connect_tcp(
160+
remote_host=host,
161+
remote_port=port,
162+
local_host=local_address,
163+
)
164+
return httpcore.backends.asyncio.AsyncIOStream(stream)
165+
except Exception:
166+
pass
167+
raise httpcore.ConnectError
168+
169+
class _HTTPTransport(httpx.AsyncHTTPTransport):
170+
def __init__(
171+
self,
172+
*args,
173+
local_port=0,
174+
bootstrap_address=None,
175+
resolver=None,
176+
family=socket.AF_UNSPEC,
177+
**kwargs,
178+
):
179+
if resolver is None:
180+
import dns.asyncresolver
181+
182+
resolver = dns.asyncresolver.Resolver()
183+
super().__init__(*args, **kwargs)
184+
self._pool._network_backend = _NetworkBackend(
185+
resolver, local_port, bootstrap_address, family
186+
)
187+
188+
except ImportError:
189+
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
190+
191+
116192
class Backend(dns._asyncbackend.Backend):
117193
def name(self):
118194
return "asyncio"
@@ -171,3 +247,6 @@ async def sleep(self, interval):
171247

172248
def datagram_connection_required(self):
173249
return _is_win32
250+
251+
def get_transport_class(self):
252+
return _HTTPTransport

dns/_trio_backend.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,80 @@ async def getsockname(self):
8383
return self.stream.socket.getsockname()
8484

8585

86+
try:
87+
import httpx
88+
89+
import httpcore
90+
import httpcore.backends.base
91+
import httpcore.backends.trio
92+
93+
from dns.query import _compute_times, _remaining, _expiration_for_this_attempt
94+
95+
class _NetworkBackend(httpcore.backends.base.AsyncNetworkBackend):
96+
def __init__(self, resolver, local_port, bootstrap_address, family):
97+
super().__init__()
98+
self._local_port = local_port
99+
self._resolver = resolver
100+
self._bootstrap_address = bootstrap_address
101+
self._family = family
102+
103+
async def connect_tcp(self, host, port, timeout, local_address):
104+
addresses = []
105+
now, expiration = _compute_times(timeout)
106+
if dns.inet.is_address(host):
107+
addresses.append(host)
108+
elif self._bootstrap_address is not None:
109+
addresses.append(self._bootstrap_address)
110+
else:
111+
timeout = _remaining(expiration)
112+
family = self._family
113+
if local_address:
114+
family = dns.inet.af_for_address(local_address)
115+
answers = await self._resolver.resolve_name(
116+
host, family=family, lifetime=timeout
117+
)
118+
addresses = answers.addresses()
119+
for address in addresses:
120+
try:
121+
af = dns.inet.af_for_address(address)
122+
if local_address is not None or self._local_port != 0:
123+
source = (local_address, self._local_port)
124+
else:
125+
source = None
126+
destination = (address, port)
127+
attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
128+
timeout = _remaining(attempt_expiration)
129+
sock = await Backend().make_socket(
130+
af, socket.SOCK_STREAM, 0, source, destination, timeout
131+
)
132+
return httpcore.backends.trio.TrioStream(sock.stream)
133+
except Exception:
134+
continue
135+
raise httpcore.ConnectError
136+
137+
class _HTTPTransport(httpx.AsyncHTTPTransport):
138+
def __init__(
139+
self,
140+
*args,
141+
local_port=0,
142+
bootstrap_address=None,
143+
resolver=None,
144+
family=socket.AF_UNSPEC,
145+
**kwargs,
146+
):
147+
if resolver is None:
148+
import dns.asyncresolver
149+
150+
resolver = dns.asyncresolver.Resolver()
151+
super().__init__(*args, **kwargs)
152+
self._pool._network_backend = _NetworkBackend(
153+
resolver, local_port, bootstrap_address, family
154+
)
155+
156+
except ImportError:
157+
_HTTPTransport = dns._asyncbackend.NullTransport # type: ignore
158+
159+
86160
class Backend(dns._asyncbackend.Backend):
87161
def name(self):
88162
return "trio"
@@ -104,8 +178,14 @@ async def make_socket(
104178
if source:
105179
await s.bind(_lltuple(source, af))
106180
if socktype == socket.SOCK_STREAM:
181+
connected = False
107182
with _maybe_timeout(timeout):
108183
await s.connect(_lltuple(destination, af))
184+
connected = True
185+
if not connected:
186+
raise dns.exception.Timeout(
187+
timeout=timeout
188+
) # lgtm[py/unreachable-statement]
109189
except Exception: # pragma: no cover
110190
s.close()
111191
raise
@@ -130,3 +210,6 @@ async def make_socket(
130210

131211
async def sleep(self, interval):
132212
await trio.sleep(interval)
213+
214+
def get_transport_class(self):
215+
return _HTTPTransport

dns/asyncquery.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
BadResponse,
4444
ssl,
4545
UDPMode,
46-
_have_httpx,
46+
have_doh,
4747
_have_http2,
4848
NoDOH,
4949
NoDOQ,
5050
)
5151

52-
if _have_httpx:
52+
if have_doh:
5353
import httpx
5454

5555
# for brevity
@@ -495,6 +495,9 @@ async def https(
495495
path: str = "/dns-query",
496496
post: bool = True,
497497
verify: Union[bool, str] = True,
498+
bootstrap_address: Optional[str] = None,
499+
resolver: Optional["dns.asyncresolver.Resolver"] = None,
500+
family: Optional[int] = socket.AF_UNSPEC,
498501
) -> dns.message.Message:
499502
"""Return the response obtained after sending a query via DNS-over-HTTPS.
500503
@@ -508,8 +511,10 @@ async def https(
508511
parameters, exceptions, and return type of this method.
509512
"""
510513

511-
if not _have_httpx:
514+
if not have_doh:
512515
raise NoDOH("httpx is not available.") # pragma: no cover
516+
if client and not isinstance(client, httpx.AsyncClient):
517+
raise ValueError("session parameter must be an httpx.AsyncClient")
513518

514519
wire = q.to_wire()
515520
try:
@@ -518,15 +523,30 @@ async def https(
518523
af = None
519524
transport = None
520525
headers = {"accept": "application/dns-message"}
521-
if af is not None:
526+
if af is not None and dns.inet.is_address(where):
522527
if af == socket.AF_INET:
523528
url = "https://{}:{}{}".format(where, port, path)
524529
elif af == socket.AF_INET6:
525530
url = "https://[{}]:{}{}".format(where, port, path)
526531
else:
527532
url = where
528-
if source is not None:
529-
transport = httpx.AsyncHTTPTransport(local_address=source[0])
533+
534+
backend = dns.asyncbackend.get_default_backend()
535+
536+
if source is None:
537+
local_address = None
538+
local_port = 0
539+
else:
540+
local_address = source
541+
local_port = source_port
542+
transport = backend.get_transport_class()(
543+
local_address=local_address,
544+
verify=verify,
545+
local_port=local_port,
546+
bootstrap_address=bootstrap_address,
547+
resolver=resolver,
548+
family=family,
549+
)
530550

531551
if client:
532552
cm: contextlib.AbstractAsyncContextManager = NullContext(client)

dns/inet.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,12 @@ def low_level_address_tuple(
171171
return tup
172172
else:
173173
raise NotImplementedError(f"unknown address family {af}")
174+
175+
176+
def any_for_af(af):
177+
"""Return the 'any' address for the specified address family."""
178+
if af == socket.AF_INET:
179+
return "0.0.0.0"
180+
elif af == socket.AF_INET6:
181+
return "::"
182+
raise NotImplementedError(f"unknown address family {af}")

0 commit comments

Comments
 (0)