Skip to content

Commit 210114d

Browse files
committed
Ease use of JWT by passing URI to auth library
The URI no longer needs to be provided to the Credential explicitly, which prevents needing to know a magic string and allows using the same Credential with multiple services.
1 parent db0423c commit 210114d

12 files changed

Lines changed: 137 additions & 23 deletions

File tree

auth/src/main/java/io/grpc/auth/ClientAuthInterceptor.java

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@
4242
import io.grpc.Metadata;
4343
import io.grpc.MethodDescriptor;
4444
import io.grpc.Status;
45+
import io.grpc.StatusException;
4546

4647
import java.io.IOException;
48+
import java.net.URI;
49+
import java.net.URISyntaxException;
4750
import java.util.List;
4851
import java.util.Map;
4952
import java.util.concurrent.Executor;
@@ -70,22 +73,27 @@ public ClientAuthInterceptor(Credentials credentials, Executor executor) {
7073
}
7174

7275
@Override
73-
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
74-
CallOptions callOptions, Channel next) {
76+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
77+
final MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, final Channel next) {
7578
// TODO(ejona86): If the call fails for Auth reasons, this does not properly propagate info that
7679
// would be in WWW-Authenticate, because it does not yet have access to the header.
7780
return new CheckedForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
7881
@Override
7982
protected void checkedStart(Listener<RespT> responseListener, Metadata headers)
80-
throws Exception {
83+
throws StatusException {
8184
Metadata cachedSaved;
85+
URI uri = serviceUri(next, method);
8286
synchronized (ClientAuthInterceptor.this) {
8387
// TODO(louiscryan): This is icky but the current auth library stores the same
8488
// metadata map until the next refresh cycle. This will be fixed once
8589
// https://github.com/google/google-auth-library-java/issues/3
8690
// is resolved.
87-
if (lastMetadata == null || lastMetadata != getRequestMetadata()) {
88-
lastMetadata = getRequestMetadata();
91+
// getRequestMetadata() may return a different map based on the provided URI, i.e., for
92+
// JWT. However, today it does not cache JWT and so we won't bother tring to cache its
93+
// return value based on the URI.
94+
Map<String, List<String>> latestMetadata = getRequestMetadata(uri);
95+
if (lastMetadata == null || lastMetadata != latestMetadata) {
96+
lastMetadata = latestMetadata;
8997
cached = toHeaders(lastMetadata);
9098
}
9199
cachedSaved = cached;
@@ -96,11 +104,37 @@ protected void checkedStart(Listener<RespT> responseListener, Metadata headers)
96104
};
97105
}
98106

99-
private Map<String, List<String>> getRequestMetadata() {
107+
/**
108+
* Generate a JWT-specific service URI. The URI is simply an identifier with enough information
109+
* for a service to know that the JWT was intended for it. The URI will commonly be verified with
110+
* a simple string equality check.
111+
*/
112+
private URI serviceUri(Channel channel, MethodDescriptor<?, ?> method) throws StatusException {
113+
String authority = channel.authority();
114+
if (authority == null) {
115+
throw Status.UNAUTHENTICATED.withDescription("Channel has no authority").asException();
116+
}
117+
// Always use HTTPS, by definition.
118+
final String scheme = "https";
119+
// The default port must not be present. Alternative ports should be present.
120+
final String suffixToStrip = ":443";
121+
if (authority.endsWith(suffixToStrip)) {
122+
authority = authority.substring(0, authority.length() - suffixToStrip.length());
123+
}
124+
String path = "/" + MethodDescriptor.extractFullServiceName(method.getFullMethodName());
125+
try {
126+
return new URI(scheme, authority, path, null, null);
127+
} catch (URISyntaxException e) {
128+
throw Status.UNAUTHENTICATED.withDescription("Unable to construct service URI for auth")
129+
.withCause(e).asException();
130+
}
131+
}
132+
133+
private Map<String, List<String>> getRequestMetadata(URI uri) throws StatusException {
100134
try {
101-
return credentials.getRequestMetadata();
135+
return credentials.getRequestMetadata(uri);
102136
} catch (IOException e) {
103-
throw Status.UNAUTHENTICATED.withCause(e).asRuntimeException();
137+
throw Status.UNAUTHENTICATED.withCause(e).asException();
104138
}
105139
}
106140

auth/src/test/java/io/grpc/auth/ClientAuthInterceptorTests.java

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import static org.mockito.Matchers.any;
3535
import static org.mockito.Matchers.isA;
3636
import static org.mockito.Matchers.same;
37+
import static org.mockito.Mockito.doReturn;
3738
import static org.mockito.Mockito.never;
3839
import static org.mockito.Mockito.verify;
3940
import static org.mockito.Mockito.when;
@@ -51,6 +52,7 @@
5152
import io.grpc.ClientCall;
5253
import io.grpc.Metadata;
5354
import io.grpc.MethodDescriptor;
55+
import io.grpc.MethodDescriptor.Marshaller;
5456
import io.grpc.Status;
5557

5658
import org.junit.Assert;
@@ -64,6 +66,7 @@
6466
import org.mockito.MockitoAnnotations;
6567

6668
import java.io.IOException;
69+
import java.net.URI;
6770
import java.util.Date;
6871
import java.util.concurrent.Executors;
6972

@@ -82,6 +85,11 @@ public class ClientAuthInterceptorTests {
8285
Credentials credentials;
8386

8487
@Mock
88+
Marshaller<String> stringMarshaller;
89+
90+
@Mock
91+
Marshaller<Integer> intMarshaller;
92+
8593
MethodDescriptor<String, Integer> descriptor;
8694

8795
@Mock
@@ -99,7 +107,10 @@ public class ClientAuthInterceptorTests {
99107
@Before
100108
public void startUp() throws IOException {
101109
MockitoAnnotations.initMocks(this);
110+
descriptor = MethodDescriptor.create(
111+
MethodDescriptor.MethodType.UNKNOWN, "a.service/method", stringMarshaller, intMarshaller);
102112
when(channel.newCall(same(descriptor), any(CallOptions.class))).thenReturn(call);
113+
doReturn("localhost:443").when(channel).authority();
103114
interceptor = new ClientAuthInterceptor(credentials,
104115
Executors.newSingleThreadExecutor());
105116
}
@@ -111,7 +122,7 @@ public void testCopyCredentialToHeaders() throws IOException {
111122
values.put("Authorization", "token2");
112123
values.put("Extra-Authorization", "token3");
113124
values.put("Extra-Authorization", "token4");
114-
when(credentials.getRequestMetadata()).thenReturn(Multimaps.asMap(values));
125+
when(credentials.getRequestMetadata(any(URI.class))).thenReturn(Multimaps.asMap(values));
115126
ClientCall<String, Integer> interceptedCall =
116127
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
117128
Metadata headers = new Metadata();
@@ -128,7 +139,7 @@ public void testCopyCredentialToHeaders() throws IOException {
128139

129140
@Test
130141
public void testCredentialsThrows() throws IOException {
131-
when(credentials.getRequestMetadata()).thenThrow(new IOException("Broken"));
142+
when(credentials.getRequestMetadata(any(URI.class))).thenThrow(new IOException("Broken"));
132143
ClientCall<String, Integer> interceptedCall =
133144
interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
134145
Metadata headers = new Metadata();
@@ -160,4 +171,21 @@ public AccessToken refreshAccessToken() throws IOException {
160171
Assert.assertArrayEquals(new String[]{"Bearer allyourbase"},
161172
Iterables.toArray(authorization, String.class));
162173
}
174+
175+
@Test
176+
public void verifyServiceUri() throws IOException {
177+
ClientCall<String, Integer> interceptedCall;
178+
179+
doReturn("example.com:443").when(channel).authority();
180+
interceptedCall = interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
181+
interceptedCall.start(listener, new Metadata());
182+
verify(credentials).getRequestMetadata(URI.create("https://example.com/a.service"));
183+
interceptedCall.cancel();
184+
185+
doReturn("example.com:123").when(channel).authority();
186+
interceptedCall = interceptor.interceptCall(descriptor, CallOptions.DEFAULT, channel);
187+
interceptedCall.start(listener, new Metadata());
188+
verify(credentials).getRequestMetadata(URI.create("https://example.com:123/a.service"));
189+
interceptedCall.cancel();
190+
}
163191
}

core/src/main/java/io/grpc/Channel.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,12 @@ public abstract class Channel {
5959
*/
6060
public abstract <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
6161
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions);
62+
63+
/**
64+
* The authority of the destination this channel connects to. Typically this is in the format
65+
* {@code host:port}.
66+
*
67+
* @return authority of remote, or {@code null}
68+
*/
69+
public abstract String authority();
6270
}

core/src/main/java/io/grpc/ChannelImpl.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(MethodDescriptor<ReqT, Resp
283283
return interceptorChannel.newCall(method, callOptions);
284284
}
285285

286+
@Override
287+
public String authority() {
288+
return interceptorChannel.authority();
289+
}
290+
286291
private ClientTransport obtainActiveTransport() {
287292
ClientTransport savedActiveTransport = activeTransport;
288293
// If we know there is an active transport and we are not in backoff mode, return quickly.
@@ -344,6 +349,11 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(MethodDescriptor<ReqT, Resp
344349
scheduledExecutor)
345350
.setUserAgent(userAgent);
346351
}
352+
353+
@Override
354+
public String authority() {
355+
return transportFactory.authority();
356+
}
347357
}
348358

349359
private class TransportListener implements ClientTransport.Listener {

core/src/main/java/io/grpc/ClientInterceptors.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
8888
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
8989
return interceptor.interceptCall(method, callOptions, channel);
9090
}
91+
92+
@Override
93+
public String authority() {
94+
return channel.authority();
95+
}
9196
}
9297

9398
private static final ClientCall<Object, Object> NOOP_CALL = new ClientCall<Object, Object>() {

core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ public ClientTransport newClientTransport() {
7979
return new InProcessTransport(name);
8080
}
8181

82+
@Override
83+
public String authority() {
84+
return null;
85+
}
86+
8287
@Override
8388
protected void deallocate() {
8489
// Do nothing.

core/src/main/java/io/grpc/internal/ClientTransportFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,10 @@
3535
public interface ClientTransportFactory extends ReferenceCounted {
3636
/** Creates an unstarted transport for exclusive use. */
3737
ClientTransport newClientTransport();
38+
39+
/**
40+
* Returns the authority of the channel. Typically, this should be in the form {@code host:port}.
41+
* Note that since there is not a scheme, there can't be a default port.
42+
*/
43+
String authority();
3844
}

core/src/test/java/io/grpc/ClientInterceptorsTest.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
168168
order.add("channel");
169169
return (ClientCall<ReqT, RespT>) call;
170170
}
171+
172+
@Override
173+
public String authority() {
174+
return null;
175+
}
171176
};
172177
ClientInterceptor interceptor1 = new ClientInterceptor() {
173178
@Override

netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ private static class NettyTransportFactory extends AbstractReferenceCounted
186186
private final int flowControlWindow;
187187
private final ProtocolNegotiator negotiator;
188188
private final int maxMessageSize;
189+
private final String authority;
189190

190191
private NettyTransportFactory(SocketAddress serverAddress,
191192
Class<? extends Channel> channelType,
@@ -198,6 +199,14 @@ private NettyTransportFactory(SocketAddress serverAddress,
198199
this.flowControlWindow = flowControlWindow;
199200
this.negotiator = negotiator;
200201
this.maxMessageSize = maxMessageSize;
202+
if (serverAddress instanceof InetSocketAddress) {
203+
InetSocketAddress address = (InetSocketAddress) serverAddress;
204+
this.authority = address.getHostString() + ":" + address.getPort();
205+
} else {
206+
// Specialized address types are allowed to support custom Channel types so just assume
207+
// their toString() values are valid :authority values
208+
this.authority = serverAddress.toString();
209+
}
201210

202211
usingSharedGroup = group == null;
203212
if (usingSharedGroup) {
@@ -211,7 +220,12 @@ private NettyTransportFactory(SocketAddress serverAddress,
211220
@Override
212221
public ClientTransport newClientTransport() {
213222
return new NettyClientTransport(serverAddress, channelType, group, negotiator,
214-
flowControlWindow, maxMessageSize);
223+
flowControlWindow, maxMessageSize, authority);
224+
}
225+
226+
@Override
227+
public String authority() {
228+
return authority;
215229
}
216230

217231
@Override

netty/src/main/java/io/grpc/netty/NettyClientTransport.java

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@
6464
import io.netty.handler.logging.LogLevel;
6565
import io.netty.util.AsciiString;
6666

67-
import java.net.InetSocketAddress;
6867
import java.net.SocketAddress;
6968
import java.util.concurrent.Executor;
7069

@@ -95,22 +94,14 @@ class NettyClientTransport implements ClientTransport {
9594

9695
NettyClientTransport(SocketAddress address, Class<? extends Channel> channelType,
9796
EventLoopGroup group, ProtocolNegotiator negotiator,
98-
int flowControlWindow, int maxMessageSize) {
97+
int flowControlWindow, int maxMessageSize, String authority) {
9998
Preconditions.checkNotNull(negotiator, "negotiator");
10099
this.address = Preconditions.checkNotNull(address, "address");
101100
this.group = Preconditions.checkNotNull(group, "group");
102101
this.channelType = Preconditions.checkNotNull(channelType, "channelType");
103102
this.flowControlWindow = flowControlWindow;
104103
this.maxMessageSize = maxMessageSize;
105-
106-
if (address instanceof InetSocketAddress) {
107-
InetSocketAddress inetAddress = (InetSocketAddress) address;
108-
authority = new AsciiString(inetAddress.getHostString() + ":" + inetAddress.getPort());
109-
} else {
110-
// Specialized address types are allowed to support custom Channel types so just assume their
111-
// toString() values are valid :authority values
112-
authority = new AsciiString(address.toString());
113-
}
104+
this.authority = new AsciiString(authority);
114105

115106
handler = newHandler();
116107
negotiationHandler = negotiator.newHandler(handler);

0 commit comments

Comments
 (0)