Skip to content

Commit ef8e772

Browse files
skvarkjoein
andcommitted
Feat: bearer token authentication support (#591)
* bearer token authentication provider support * add tests and checks, move auth file to separate dir * fix error message * remove locks * rename var * refactoring: refactor exceptions, fix mypy * fix: regen async * tests: extend token tests to check token updates * new: add warning when auth token provider is used with an insecure connection * fix: propagate auth token to rest client even with prefer_grpc set --------- Co-authored-by: George Panchuk <[email protected]>
1 parent efb0309 commit ef8e772

File tree

9 files changed

+363
-86
lines changed

9 files changed

+363
-86
lines changed

qdrant_client/async_qdrant_client.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,19 @@
1010
# ****** WARNING: THIS FILE IS AUTOGENERATED ******
1111

1212
import warnings
13-
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
13+
from typing import (
14+
Any,
15+
Awaitable,
16+
Callable,
17+
Dict,
18+
Iterable,
19+
List,
20+
Mapping,
21+
Optional,
22+
Sequence,
23+
Tuple,
24+
Union,
25+
)
1426

1527
from qdrant_client import grpc as grpc
1628
from qdrant_client.async_client_base import AsyncQdrantBase
@@ -68,6 +80,7 @@ class AsyncQdrantClient(AsyncQdrantFastembedMixin):
6880
force_disable_check_same_thread:
6981
For QdrantLocal, force disable check_same_thread. Default: `False`
7082
Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient.
83+
auth_token_provider: Callback function to get Bearer access token. If given, the function will be called before each request to get the token.
7184
**kwargs: Additional arguments passed directly into REST client initialization
7285
7386
"""
@@ -87,6 +100,9 @@ def __init__(
87100
path: Optional[str] = None,
88101
force_disable_check_same_thread: bool = False,
89102
grpc_options: Optional[Dict[str, Any]] = None,
103+
auth_token_provider: Optional[
104+
Union[Callable[[], str], Callable[[], Awaitable[str]]]
105+
] = None,
90106
**kwargs: Any,
91107
):
92108
super().__init__(**kwargs)
@@ -117,6 +133,7 @@ def __init__(
117133
timeout=timeout,
118134
host=host,
119135
grpc_options=grpc_options,
136+
auth_token_provider=auth_token_provider,
120137
**kwargs,
121138
)
122139

qdrant_client/async_qdrant_remote.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from multiprocessing import get_all_start_methods
1616
from typing import (
1717
Any,
18+
Awaitable,
19+
Callable,
1820
Dict,
1921
Iterable,
2022
List,
@@ -35,6 +37,7 @@
3537
from qdrant_client import grpc as grpc
3638
from qdrant_client._pydantic_compat import construct
3739
from qdrant_client.async_client_base import AsyncQdrantBase
40+
from qdrant_client.auth import BearerAuth
3841
from qdrant_client.connection import get_async_channel as get_channel
3942
from qdrant_client.conversions import common_types as types
4043
from qdrant_client.conversions.common_types import get_args_subscribed
@@ -63,6 +66,9 @@ def __init__(
6366
timeout: Optional[int] = None,
6467
host: Optional[str] = None,
6568
grpc_options: Optional[Dict[str, Any]] = None,
69+
auth_token_provider: Optional[
70+
Union[Callable[[], str], Callable[[], Awaitable[str]]]
71+
] = None,
6672
**kwargs: Any,
6773
):
6874
super().__init__(**kwargs)
@@ -100,6 +106,7 @@ def __init__(
100106
self._port = port
101107
self._timeout = math.ceil(timeout) if timeout is not None else None
102108
self._api_key = api_key
109+
self._auth_token_provider = auth_token_provider
103110
limits = kwargs.pop("limits", None)
104111
if limits is None:
105112
if self._host in ["localhost", "127.0.0.1"]:
@@ -109,7 +116,7 @@ def __init__(
109116
self._rest_headers = kwargs.pop("metadata", {})
110117
if api_key is not None:
111118
if self._scheme == "http":
112-
warnings.warn("Api key is used with unsecure connection.")
119+
warnings.warn("Api key is used with an insecure connection.")
113120
self._rest_headers["api-key"] = api_key
114121
self._grpc_headers.append(("api-key", api_key))
115122
grpc_compression: Optional[Compression] = kwargs.pop("grpc_compression", None)
@@ -129,6 +136,11 @@ def __init__(
129136
self._rest_args["limits"] = limits
130137
if self._timeout is not None:
131138
self._rest_args["timeout"] = self._timeout
139+
if self._auth_token_provider is not None:
140+
if self._scheme == "http":
141+
warnings.warn("Auth token provider is used with an insecure connection.")
142+
bearer_auth = BearerAuth(self._auth_token_provider)
143+
self._rest_args["auth"] = bearer_auth
132144
self.openapi_client: AsyncApis[AsyncApiClient] = AsyncApis(
133145
host=self.rest_uri, **self._rest_args
134146
)
@@ -182,6 +194,7 @@ def _init_grpc_channel(self) -> None:
182194
metadata=self._grpc_headers,
183195
options=self._grpc_options,
184196
compression=self._grpc_compression,
197+
auth_token_provider=self._auth_token_provider,
185198
)
186199

187200
def _init_grpc_points_client(self) -> None:

qdrant_client/auth/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from qdrant_client.auth.bearer_auth import BearerAuth

qdrant_client/auth/bearer_auth.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import asyncio
2+
from typing import Awaitable, Callable, Optional, Union
3+
4+
import httpx
5+
6+
7+
class BearerAuth(httpx.Auth):
8+
def __init__(
9+
self,
10+
auth_token_provider: Union[Callable[[], str], Callable[[], Awaitable[str]]],
11+
):
12+
self.async_token: Optional[Callable[[], Awaitable[str]]] = None
13+
self.sync_token: Optional[Callable[[], str]] = None
14+
15+
if asyncio.iscoroutinefunction(auth_token_provider):
16+
self.async_token = auth_token_provider
17+
else:
18+
if callable(auth_token_provider):
19+
self.sync_token = auth_token_provider # type: ignore
20+
else:
21+
raise ValueError("auth_token_provider must be a callable or awaitable")
22+
23+
def _sync_get_token(self) -> str:
24+
if self.sync_token is None:
25+
raise ValueError("Synchronous token provider is not set.")
26+
return self.sync_token()
27+
28+
def sync_auth_flow(self, request: httpx.Request) -> httpx.Request:
29+
token = self._sync_get_token()
30+
request.headers["Authorization"] = f"Bearer {token}"
31+
yield request
32+
33+
async def _async_get_token(self) -> str:
34+
if self.async_token is not None:
35+
return await self.async_token() # type: ignore
36+
# Fallback to synchronous token if asynchronous token is not available
37+
return self._sync_get_token()
38+
39+
async def async_auth_flow(self, request: httpx.Request) -> httpx.Request:
40+
token = await self._async_get_token()
41+
request.headers["Authorization"] = f"Bearer {token}"
42+
yield request

qdrant_client/connection.py

Lines changed: 57 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import asyncio
12
import collections
2-
from typing import Any, Callable, Dict, List, Optional, Tuple
3+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
34

45
import grpc
56

@@ -64,7 +65,7 @@ def __init__(self, interceptor_function: Callable):
6465
async def intercept_unary_unary(
6566
self, continuation: Any, client_call_details: Any, request: Any
6667
) -> Any:
67-
new_details, new_request_iterator, postprocess = self._fn(
68+
new_details, new_request_iterator, postprocess = await self._fn(
6869
client_call_details, iter((request,)), False, False
6970
)
7071
next_request = next(new_request_iterator)
@@ -74,7 +75,7 @@ async def intercept_unary_unary(
7475
async def intercept_unary_stream(
7576
self, continuation: Any, client_call_details: Any, request: Any
7677
) -> Any:
77-
new_details, new_request_iterator, postprocess = self._fn(
78+
new_details, new_request_iterator, postprocess = await self._fn(
7879
client_call_details, iter((request,)), False, True
7980
)
8081
response_it = await continuation(new_details, next(new_request_iterator))
@@ -83,7 +84,7 @@ async def intercept_unary_stream(
8384
async def intercept_stream_unary(
8485
self, continuation: Any, client_call_details: Any, request_iterator: Any
8586
) -> Any:
86-
new_details, new_request_iterator, postprocess = self._fn(
87+
new_details, new_request_iterator, postprocess = await self._fn(
8788
client_call_details, request_iterator, True, False
8889
)
8990
response = await continuation(new_details, new_request_iterator)
@@ -92,7 +93,7 @@ async def intercept_stream_unary(
9293
async def intercept_stream_stream(
9394
self, continuation: Any, client_call_details: Any, request_iterator: Any
9495
) -> Any:
95-
new_details, new_request_iterator, postprocess = self._fn(
96+
new_details, new_request_iterator, postprocess = await self._fn(
9697
client_call_details, request_iterator, True, True
9798
)
9899
response_it = await continuation(new_details, new_request_iterator)
@@ -125,14 +126,18 @@ class _ClientAsyncCallDetails(
125126
pass
126127

127128

128-
def header_adder_interceptor(new_metadata: List[Tuple[str, str]]) -> _GenericClientInterceptor:
129+
def header_adder_interceptor(
130+
new_metadata: List[Tuple[str, str]],
131+
auth_token_provider: Optional[Callable[[], str]] = None,
132+
) -> _GenericClientInterceptor:
129133
def intercept_call(
130134
client_call_details: _ClientCallDetails,
131135
request_iterator: Any,
132136
_request_streaming: Any,
133137
_response_streaming: Any,
134138
) -> Tuple[_ClientCallDetails, Any, Any]:
135139
metadata = []
140+
136141
if client_call_details.metadata is not None:
137142
metadata = list(client_call_details.metadata)
138143
for header, value in new_metadata:
@@ -142,6 +147,13 @@ def intercept_call(
142147
value,
143148
)
144149
)
150+
151+
if auth_token_provider:
152+
if not asyncio.iscoroutinefunction(auth_token_provider):
153+
metadata.append(("authorization", f"Bearer {auth_token_provider()}"))
154+
else:
155+
raise ValueError("Synchronous channel requires synchronous auth token provider.")
156+
145157
client_call_details = _ClientCallDetails(
146158
client_call_details.method,
147159
client_call_details.timeout,
@@ -154,9 +166,10 @@ def intercept_call(
154166

155167

156168
def header_adder_async_interceptor(
157-
new_metadata: List[Tuple[str, str]]
169+
new_metadata: List[Tuple[str, str]],
170+
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
158171
) -> _GenericAsyncClientInterceptor:
159-
def intercept_call(
172+
async def intercept_call(
160173
client_call_details: grpc.aio.ClientCallDetails,
161174
request_iterator: Any,
162175
_request_streaming: Any,
@@ -172,6 +185,14 @@ def intercept_call(
172185
value,
173186
)
174187
)
188+
189+
if auth_token_provider:
190+
if asyncio.iscoroutinefunction(auth_token_provider):
191+
token = await auth_token_provider()
192+
else:
193+
token = auth_token_provider()
194+
metadata.append(("authorization", f"Bearer {token}"))
195+
175196
client_call_details = client_call_details._replace(metadata=metadata)
176197
return client_call_details, request_iterator, None
177198

@@ -200,38 +221,21 @@ def get_channel(
200221
metadata: Optional[List[Tuple[str, str]]] = None,
201222
options: Optional[Dict[str, Any]] = None,
202223
compression: Optional[grpc.Compression] = None,
224+
auth_token_provider: Optional[Callable[[], str]] = None,
203225
) -> grpc.Channel:
204-
# gRPC client options
226+
# Parse gRPC client options
205227
_options = parse_channel_options(options)
228+
metadata_interceptor = header_adder_interceptor(
229+
new_metadata=metadata or [], auth_token_provider=auth_token_provider
230+
)
206231

207232
if ssl:
208-
if metadata:
209-
210-
def metadata_callback(context: Any, callback: Any) -> None:
211-
# for more info see grpc docs
212-
callback(metadata, None)
213-
214-
# build ssl credentials using the cert the same as before
215-
cert_creds = grpc.ssl_channel_credentials()
216-
217-
# now build meta data credentials
218-
auth_creds = grpc.metadata_call_credentials(metadata_callback)
219-
220-
# combine the cert credentials and the macaroon auth credentials
221-
# such that every call is properly encrypted and authenticated
222-
creds = grpc.composite_channel_credentials(cert_creds, auth_creds)
223-
else:
224-
creds = grpc.ssl_channel_credentials()
225-
226-
# finally pass in the combined credentials when creating a channel
227-
return grpc.secure_channel(f"{host}:{port}", creds, _options, compression)
233+
ssl_creds = grpc.ssl_channel_credentials()
234+
channel = grpc.secure_channel(f"{host}:{port}", ssl_creds, _options, compression)
235+
return grpc.intercept_channel(channel, metadata_interceptor)
228236
else:
229-
if metadata:
230-
metadata_interceptor = header_adder_interceptor(metadata)
231-
channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
232-
return grpc.intercept_channel(channel, metadata_interceptor)
233-
else:
234-
return grpc.insecure_channel(f"{host}:{port}", _options, compression)
237+
channel = grpc.insecure_channel(f"{host}:{port}", _options, compression)
238+
return grpc.intercept_channel(channel, metadata_interceptor)
235239

236240

237241
def get_async_channel(
@@ -241,36 +245,26 @@ def get_async_channel(
241245
metadata: Optional[List[Tuple[str, str]]] = None,
242246
options: Optional[Dict[str, Any]] = None,
243247
compression: Optional[grpc.Compression] = None,
248+
auth_token_provider: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
244249
) -> grpc.aio.Channel:
245-
# gRPC client options
250+
# Parse gRPC client options
246251
_options = parse_channel_options(options)
247252

248-
if ssl:
249-
if metadata:
250-
251-
def metadata_callback(context: Any, callback: Any) -> None:
252-
# for more info see grpc docs
253-
callback(metadata, None)
254-
255-
# build ssl credentials using the cert the same as before
256-
cert_creds = grpc.ssl_channel_credentials()
253+
# Create metadata interceptor
254+
metadata_interceptor = header_adder_async_interceptor(
255+
new_metadata=metadata or [], auth_token_provider=auth_token_provider
256+
)
257257

258-
# now build meta data credentials
259-
auth_creds = grpc.metadata_call_credentials(metadata_callback)
260-
261-
# combine the cert credentials and the macaroon auth credentials
262-
# such that every call is properly encrypted and authenticated
263-
creds = grpc.composite_channel_credentials(cert_creds, auth_creds)
264-
else:
265-
creds = grpc.ssl_channel_credentials()
266-
267-
# finally pass in the combined credentials when creating a channel
268-
return grpc.aio.secure_channel(f"{host}:{port}", creds, _options, compression)
258+
if ssl:
259+
ssl_creds = grpc.ssl_channel_credentials()
260+
return grpc.aio.secure_channel(
261+
f"{host}:{port}",
262+
ssl_creds,
263+
_options,
264+
compression,
265+
interceptors=[metadata_interceptor],
266+
)
269267
else:
270-
if metadata:
271-
metadata_interceptor = header_adder_async_interceptor(metadata)
272-
return grpc.aio.insecure_channel(
273-
f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
274-
)
275-
else:
276-
return grpc.aio.insecure_channel(f"{host}:{port}", _options, compression)
268+
return grpc.aio.insecure_channel(
269+
f"{host}:{port}", _options, compression, interceptors=[metadata_interceptor]
270+
)

0 commit comments

Comments
 (0)