1+ import asyncio
12import collections
2- from typing import Any , Callable , Dict , List , Optional , Tuple
3+ from typing import Any , Awaitable , Callable , Dict , List , Optional , Tuple , Union
34
45import grpc
56
@@ -64,7 +65,7 @@ def __init__(self, interceptor_function: Callable):
6465 async def intercept_unary_unary (
6566 self , continuation : Any , client_call_details : Any , request : Any
6667 ) -> Any :
67- new_details , new_request_iterator , postprocess = self ._fn (
68+ new_details , new_request_iterator , postprocess = await self ._fn (
6869 client_call_details , iter ((request ,)), False , False
6970 )
7071 next_request = next (new_request_iterator )
@@ -74,7 +75,7 @@ async def intercept_unary_unary(
7475 async def intercept_unary_stream (
7576 self , continuation : Any , client_call_details : Any , request : Any
7677 ) -> Any :
77- new_details , new_request_iterator , postprocess = self ._fn (
78+ new_details , new_request_iterator , postprocess = await self ._fn (
7879 client_call_details , iter ((request ,)), False , True
7980 )
8081 response_it = await continuation (new_details , next (new_request_iterator ))
@@ -83,7 +84,7 @@ async def intercept_unary_stream(
8384 async def intercept_stream_unary (
8485 self , continuation : Any , client_call_details : Any , request_iterator : Any
8586 ) -> Any :
86- new_details , new_request_iterator , postprocess = self ._fn (
87+ new_details , new_request_iterator , postprocess = await self ._fn (
8788 client_call_details , request_iterator , True , False
8889 )
8990 response = await continuation (new_details , new_request_iterator )
@@ -92,7 +93,7 @@ async def intercept_stream_unary(
9293 async def intercept_stream_stream (
9394 self , continuation : Any , client_call_details : Any , request_iterator : Any
9495 ) -> Any :
95- new_details , new_request_iterator , postprocess = self ._fn (
96+ new_details , new_request_iterator , postprocess = await self ._fn (
9697 client_call_details , request_iterator , True , True
9798 )
9899 response_it = await continuation (new_details , new_request_iterator )
@@ -125,14 +126,18 @@ class _ClientAsyncCallDetails(
125126 pass
126127
127128
128- def header_adder_interceptor (new_metadata : List [Tuple [str , str ]]) -> _GenericClientInterceptor :
129+ def header_adder_interceptor (
130+ new_metadata : List [Tuple [str , str ]],
131+ auth_token_provider : Optional [Callable [[], str ]] = None ,
132+ ) -> _GenericClientInterceptor :
129133 def intercept_call (
130134 client_call_details : _ClientCallDetails ,
131135 request_iterator : Any ,
132136 _request_streaming : Any ,
133137 _response_streaming : Any ,
134138 ) -> Tuple [_ClientCallDetails , Any , Any ]:
135139 metadata = []
140+
136141 if client_call_details .metadata is not None :
137142 metadata = list (client_call_details .metadata )
138143 for header , value in new_metadata :
@@ -142,6 +147,13 @@ def intercept_call(
142147 value ,
143148 )
144149 )
150+
151+ if auth_token_provider :
152+ if not asyncio .iscoroutinefunction (auth_token_provider ):
153+ metadata .append (("authorization" , f"Bearer { auth_token_provider ()} " ))
154+ else :
155+ raise ValueError ("Synchronous channel requires synchronous auth token provider." )
156+
145157 client_call_details = _ClientCallDetails (
146158 client_call_details .method ,
147159 client_call_details .timeout ,
@@ -154,9 +166,10 @@ def intercept_call(
154166
155167
156168def header_adder_async_interceptor (
157- new_metadata : List [Tuple [str , str ]]
169+ new_metadata : List [Tuple [str , str ]],
170+ auth_token_provider : Optional [Union [Callable [[], str ], Callable [[], Awaitable [str ]]]] = None ,
158171) -> _GenericAsyncClientInterceptor :
159- def intercept_call (
172+ async def intercept_call (
160173 client_call_details : grpc .aio .ClientCallDetails ,
161174 request_iterator : Any ,
162175 _request_streaming : Any ,
@@ -172,6 +185,14 @@ def intercept_call(
172185 value ,
173186 )
174187 )
188+
189+ if auth_token_provider :
190+ if asyncio .iscoroutinefunction (auth_token_provider ):
191+ token = await auth_token_provider ()
192+ else :
193+ token = auth_token_provider ()
194+ metadata .append (("authorization" , f"Bearer { token } " ))
195+
175196 client_call_details = client_call_details ._replace (metadata = metadata )
176197 return client_call_details , request_iterator , None
177198
@@ -200,38 +221,21 @@ def get_channel(
200221 metadata : Optional [List [Tuple [str , str ]]] = None ,
201222 options : Optional [Dict [str , Any ]] = None ,
202223 compression : Optional [grpc .Compression ] = None ,
224+ auth_token_provider : Optional [Callable [[], str ]] = None ,
203225) -> grpc .Channel :
204- # gRPC client options
226+ # Parse gRPC client options
205227 _options = parse_channel_options (options )
228+ metadata_interceptor = header_adder_interceptor (
229+ new_metadata = metadata or [], auth_token_provider = auth_token_provider
230+ )
206231
207232 if ssl :
208- if metadata :
209-
210- def metadata_callback (context : Any , callback : Any ) -> None :
211- # for more info see grpc docs
212- callback (metadata , None )
213-
214- # build ssl credentials using the cert the same as before
215- cert_creds = grpc .ssl_channel_credentials ()
216-
217- # now build meta data credentials
218- auth_creds = grpc .metadata_call_credentials (metadata_callback )
219-
220- # combine the cert credentials and the macaroon auth credentials
221- # such that every call is properly encrypted and authenticated
222- creds = grpc .composite_channel_credentials (cert_creds , auth_creds )
223- else :
224- creds = grpc .ssl_channel_credentials ()
225-
226- # finally pass in the combined credentials when creating a channel
227- return grpc .secure_channel (f"{ host } :{ port } " , creds , _options , compression )
233+ ssl_creds = grpc .ssl_channel_credentials ()
234+ channel = grpc .secure_channel (f"{ host } :{ port } " , ssl_creds , _options , compression )
235+ return grpc .intercept_channel (channel , metadata_interceptor )
228236 else :
229- if metadata :
230- metadata_interceptor = header_adder_interceptor (metadata )
231- channel = grpc .insecure_channel (f"{ host } :{ port } " , _options , compression )
232- return grpc .intercept_channel (channel , metadata_interceptor )
233- else :
234- return grpc .insecure_channel (f"{ host } :{ port } " , _options , compression )
237+ channel = grpc .insecure_channel (f"{ host } :{ port } " , _options , compression )
238+ return grpc .intercept_channel (channel , metadata_interceptor )
235239
236240
237241def get_async_channel (
@@ -241,36 +245,26 @@ def get_async_channel(
241245 metadata : Optional [List [Tuple [str , str ]]] = None ,
242246 options : Optional [Dict [str , Any ]] = None ,
243247 compression : Optional [grpc .Compression ] = None ,
248+ auth_token_provider : Optional [Union [Callable [[], str ], Callable [[], Awaitable [str ]]]] = None ,
244249) -> grpc .aio .Channel :
245- # gRPC client options
250+ # Parse gRPC client options
246251 _options = parse_channel_options (options )
247252
248- if ssl :
249- if metadata :
250-
251- def metadata_callback (context : Any , callback : Any ) -> None :
252- # for more info see grpc docs
253- callback (metadata , None )
254-
255- # build ssl credentials using the cert the same as before
256- cert_creds = grpc .ssl_channel_credentials ()
253+ # Create metadata interceptor
254+ metadata_interceptor = header_adder_async_interceptor (
255+ new_metadata = metadata or [], auth_token_provider = auth_token_provider
256+ )
257257
258- # now build meta data credentials
259- auth_creds = grpc .metadata_call_credentials (metadata_callback )
260-
261- # combine the cert credentials and the macaroon auth credentials
262- # such that every call is properly encrypted and authenticated
263- creds = grpc .composite_channel_credentials (cert_creds , auth_creds )
264- else :
265- creds = grpc .ssl_channel_credentials ()
266-
267- # finally pass in the combined credentials when creating a channel
268- return grpc .aio .secure_channel (f"{ host } :{ port } " , creds , _options , compression )
258+ if ssl :
259+ ssl_creds = grpc .ssl_channel_credentials ()
260+ return grpc .aio .secure_channel (
261+ f"{ host } :{ port } " ,
262+ ssl_creds ,
263+ _options ,
264+ compression ,
265+ interceptors = [metadata_interceptor ],
266+ )
269267 else :
270- if metadata :
271- metadata_interceptor = header_adder_async_interceptor (metadata )
272- return grpc .aio .insecure_channel (
273- f"{ host } :{ port } " , _options , compression , interceptors = [metadata_interceptor ]
274- )
275- else :
276- return grpc .aio .insecure_channel (f"{ host } :{ port } " , _options , compression )
268+ return grpc .aio .insecure_channel (
269+ f"{ host } :{ port } " , _options , compression , interceptors = [metadata_interceptor ]
270+ )
0 commit comments