3333
3434import com .google .protobuf .nano .MessageNano ;
3535
36+ import android .net .SSLCertificateSocketFactory ;
3637import android .os .AsyncTask ;
3738import android .support .annotation .Nullable ;
3839import android .util .Log ;
4748import java .io .InputStream ;
4849import java .io .PrintWriter ;
4950import java .io .StringWriter ;
51+ import java .lang .reflect .Method ;
5052import java .security .KeyStore ;
5153import java .security .cert .CertificateFactory ;
5254import java .security .cert .X509Certificate ;
@@ -79,6 +81,7 @@ public InteropTester(String testCase,
7981 @ Nullable String serverHostOverride ,
8082 boolean useTls ,
8183 @ Nullable InputStream testCa ,
84+ @ Nullable String androidSocketFactoryTls ,
8285 TestListener listener ) {
8386 this .testCase = testCase ;
8487 this .listener = listener ;
@@ -90,7 +93,13 @@ public InteropTester(String testCase,
9093 }
9194 if (useTls ) {
9295 try {
93- channelBuilder .sslSocketFactory (getSslSocketFactory (testCa ));
96+ SSLSocketFactory factory ;
97+ if (androidSocketFactoryTls != null ) {
98+ factory = getSslCertificateSocketFactory (testCa , androidSocketFactoryTls );
99+ } else {
100+ factory = getSslSocketFactory (testCa );
101+ }
102+ channelBuilder .sslSocketFactory (factory );
94103 } catch (Exception e ) {
95104 throw new RuntimeException (e );
96105 }
@@ -327,6 +336,39 @@ private SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa) throw
327336 if (testCa == null ) {
328337 return (SSLSocketFactory ) SSLSocketFactory .getDefault ();
329338 }
339+
340+ SSLContext context = SSLContext .getInstance ("TLS" );
341+ context .init (null , getTrustManagers (testCa ) , null );
342+ return context .getSocketFactory ();
343+ }
344+
345+ private SSLCertificateSocketFactory getSslCertificateSocketFactory (
346+ @ Nullable InputStream testCa , String androidSocketFatoryTls ) throws Exception {
347+ SSLCertificateSocketFactory factory = (SSLCertificateSocketFactory )
348+ SSLCertificateSocketFactory .getDefault (5000 /* Timeout in ms*/ );
349+ // Use HTTP/2.0
350+ byte [] h2 = "h2" .getBytes ();
351+ byte [][] protocols = new byte [][]{h2 };
352+ if (androidSocketFatoryTls .equals ("alpn" )) {
353+ Method setAlpnProtocols =
354+ factory .getClass ().getDeclaredMethod ("setAlpnProtocols" , byte [][].class );
355+ setAlpnProtocols .invoke (factory , new Object [] { protocols });
356+ } else if (androidSocketFatoryTls .equals ("npn" )) {
357+ Method setNpnProtocols =
358+ factory .getClass ().getDeclaredMethod ("setNpnProtocols" , byte [][].class );
359+ setNpnProtocols .invoke (factory , new Object []{protocols });
360+ } else {
361+ throw new RuntimeException ("Unknown protocol: " + androidSocketFatoryTls );
362+ }
363+
364+ if (testCa != null ) {
365+ factory .setTrustManagers (getTrustManagers (testCa ));
366+ }
367+
368+ return factory ;
369+ }
370+
371+ private TrustManager [] getTrustManagers (InputStream testCa ) throws Exception {
330372 KeyStore ks = KeyStore .getInstance (KeyStore .getDefaultType ());
331373 ks .load (null );
332374 CertificateFactory cf = CertificateFactory .getInstance ("X.509" );
@@ -337,9 +379,7 @@ private SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa) throw
337379 TrustManagerFactory trustManagerFactory =
338380 TrustManagerFactory .getInstance (TrustManagerFactory .getDefaultAlgorithm ());
339381 trustManagerFactory .init (ks );
340- SSLContext context = SSLContext .getInstance ("TLS" );
341- context .init (null , trustManagerFactory .getTrustManagers () , null );
342- return context .getSocketFactory ();
382+ return trustManagerFactory .getTrustManagers ();
343383 }
344384
345385 public interface TestListener {
0 commit comments