Skip to content

Commit 30d9f27

Browse files
authored
Merge pull request juju#6407 from rogpeppe/109-apiclient-snihostname
api: add SNIHostName to Info This means that we can have an arbitrary API addresses in api.Info, including resolved IP addresses, but still connect using the officially signed certificate to check. If there's a private Juju CA cert available, we use that by preference to make it possible to connect even if the server cannot obtain an officially signed certificate.
2 parents 6e02cec + f2fe2d0 commit 30d9f27

File tree

7 files changed

+206
-104
lines changed

7 files changed

+206
-104
lines changed

api/apiclient.go

Lines changed: 83 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func open(
180180
if clock == nil {
181181
return nil, errors.NotValidf("nil clock")
182182
}
183-
conn, tlsConfig, err := connectWebsocket(info, opts)
183+
conn, tlsConfig, err := dialAPI(info, opts)
184184
if err != nil {
185185
return nil, errors.Trace(err)
186186
}
@@ -275,70 +275,6 @@ func (t *hostSwitchingTransport) RoundTrip(req *http.Request) (*http.Response, e
275275
return t.fallback.RoundTrip(req)
276276
}
277277

278-
// connectWebsocket establishes a websocket connection to the RPC
279-
// API websocket on the API server using Info. If multiple API addresses
280-
// are provided in Info they will be tried concurrently - the first successful
281-
// connection wins.
282-
//
283-
// It also returns the TLS configuration that it has derived from the Info.
284-
func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
285-
if len(info.Addrs) == 0 {
286-
return nil, nil, errors.New("no API addresses to connect to")
287-
}
288-
tlsConfig := utils.SecureTLSConfig()
289-
tlsConfig.InsecureSkipVerify = opts.InsecureSkipVerify
290-
291-
if info.CACert != "" && !tlsConfig.InsecureSkipVerify {
292-
// We want to be specific here (rather than just using "anything".
293-
// See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899).
294-
tlsConfig.ServerName = "juju-apiserver"
295-
certPool, err := CreateCertPool(info.CACert)
296-
if err != nil {
297-
return nil, nil, errors.Annotate(err, "cert pool creation failed")
298-
}
299-
tlsConfig.RootCAs = certPool
300-
}
301-
path, err := apiPath(info.ModelTag, "/api")
302-
if err != nil {
303-
return nil, nil, errors.Trace(err)
304-
}
305-
conn, err := dialWebSocket(info.Addrs, path, tlsConfig, opts)
306-
if err != nil {
307-
return nil, nil, errors.Trace(err)
308-
}
309-
logger.Infof("connection established to %q", conn.RemoteAddr())
310-
return conn, tlsConfig, nil
311-
}
312-
313-
// dialWebSocket dials a websocket with one of the provided addresses, the
314-
// specified URL path, TLS configuration, and dial options. Each of the
315-
// specified addresses will be attempted concurrently, and the first
316-
// successful connection will be returned.
317-
func dialWebSocket(addrs []string, path string, tlsConfig *tls.Config, opts DialOpts) (*websocket.Conn, error) {
318-
// Dial all addresses at reasonable intervals.
319-
try := parallel.NewTry(0, nil)
320-
defer try.Kill()
321-
for _, addr := range addrs {
322-
err := dialWebsocket(addr, path, opts, tlsConfig, try)
323-
if err == parallel.ErrStopped {
324-
break
325-
}
326-
if err != nil {
327-
return nil, errors.Trace(err)
328-
}
329-
select {
330-
case <-time.After(opts.DialAddressInterval):
331-
case <-try.Dead():
332-
}
333-
}
334-
try.Close()
335-
result, err := try.Result()
336-
if err != nil {
337-
return nil, errors.Trace(err)
338-
}
339-
return result.(*websocket.Conn), nil
340-
}
341-
342278
// ConnectStream implements StreamConnector.ConnectStream.
343279
func (st *state) ConnectStream(path string, attrs url.Values) (base.Stream, error) {
344280
if !st.isLoggedIn() {
@@ -486,7 +422,83 @@ func tagToString(tag names.Tag) string {
486422
return tag.String()
487423
}
488424

489-
func dialWebsocket(addr, path string, opts DialOpts, tlsConfig *tls.Config, try *parallel.Try) error {
425+
// dialAPI establishes a websocket connection to the RPC
426+
// API websocket on the API server using Info. If multiple API addresses
427+
// are provided in Info they will be tried concurrently - the first successful
428+
// connection wins.
429+
//
430+
// It also returns the TLS configuration that it has derived from the Info.
431+
func dialAPI(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
432+
// Set opts.DialWebsocket here rather than in open because
433+
// some tests call dialAPI directly.
434+
if opts.DialWebsocket == nil {
435+
opts.DialWebsocket = websocket.DialConfig
436+
}
437+
if len(info.Addrs) == 0 {
438+
return nil, nil, errors.New("no API addresses to connect to")
439+
}
440+
tlsConfig := utils.SecureTLSConfig()
441+
tlsConfig.InsecureSkipVerify = opts.InsecureSkipVerify
442+
443+
if info.CACert != "" {
444+
// We want to be specific here (rather than just using "anything".
445+
// See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899).
446+
tlsConfig.ServerName = "juju-apiserver"
447+
certPool, err := CreateCertPool(info.CACert)
448+
if err != nil {
449+
return nil, nil, errors.Annotate(err, "cert pool creation failed")
450+
}
451+
tlsConfig.RootCAs = certPool
452+
} else {
453+
// No CA certificate so use the SNI host name for all
454+
// connections (if SNIHostName is empty, the host
455+
// name in the address will be used as usual).
456+
tlsConfig.ServerName = info.SNIHostName
457+
}
458+
path, err := apiPath(info.ModelTag, "/api")
459+
if err != nil {
460+
return nil, nil, errors.Trace(err)
461+
}
462+
conn, err := dialWebsocketMulti(info.Addrs, path, tlsConfig, opts)
463+
if err != nil {
464+
return nil, nil, errors.Trace(err)
465+
}
466+
logger.Infof("connection established to %q", conn.RemoteAddr())
467+
return conn, tlsConfig, nil
468+
}
469+
470+
// dialWebsocketMulti dials a websocket with one of the provided addresses, the
471+
// specified URL path, TLS configuration, and dial options. Each of the
472+
// specified addresses will be attempted concurrently, and the first
473+
// successful connection will be returned.
474+
func dialWebsocketMulti(addrs []string, path string, tlsConfig *tls.Config, opts DialOpts) (*websocket.Conn, error) {
475+
// Dial all addresses at reasonable intervals.
476+
try := parallel.NewTry(0, nil)
477+
defer try.Kill()
478+
for _, addr := range addrs {
479+
err := startDialWebsocket(try, addr, path, opts, tlsConfig)
480+
if err == parallel.ErrStopped {
481+
break
482+
}
483+
if err != nil {
484+
return nil, errors.Trace(err)
485+
}
486+
select {
487+
case <-time.After(opts.DialAddressInterval):
488+
case <-try.Dead():
489+
}
490+
}
491+
try.Close()
492+
result, err := try.Result()
493+
if err != nil {
494+
return nil, errors.Trace(err)
495+
}
496+
return result.(*websocket.Conn), nil
497+
}
498+
499+
// startDialWebsocket starts websocket connection to a single address
500+
// on the given try instance.
501+
func startDialWebsocket(try *parallel.Try, addr, path string, opts DialOpts, tlsConfig *tls.Config) error {
490502
// origin is required by the WebSocket API, used for "origin policy"
491503
// in websockets. We pass localhost to satisfy the API; it is
492504
// inconsequential to us.
@@ -499,11 +511,10 @@ func dialWebsocket(addr, path string, opts DialOpts, tlsConfig *tls.Config, try
499511
return try.Start(newWebsocketDialer(cfg, opts))
500512
}
501513

502-
// newWebsocketDialer returns a function that
503-
// can be passed to utils/parallel.Try.Start.
504-
var newWebsocketDialer = createWebsocketDialer
505-
506-
func createWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
514+
// newWebsocketDialer0 returns a function that dials the websocket represented
515+
// by the given configuration with the given dial options, suitable for passing
516+
// to utils/parallel.Try.Start.
517+
func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
507518
// TODO(katco): 2016-08-09: lp:1611427
508519
openAttempt := utils.AttemptStrategy{
509520
Total: opts.Timeout,
@@ -517,7 +528,7 @@ func createWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan str
517528
default:
518529
}
519530
logger.Infof("dialing %q", cfg.Location)
520-
conn, err := websocket.DialConfig(cfg)
531+
conn, err := opts.DialWebsocket(cfg)
521532
if err == nil {
522533
return conn, nil
523534
}

api/apiclient_test.go

Lines changed: 97 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,24 @@ type apiclientSuite struct {
3434

3535
var _ = gc.Suite(&apiclientSuite{})
3636

37-
func (s *apiclientSuite) TestConnectWebsocketToEnv(c *gc.C) {
37+
func (s *apiclientSuite) TestDialAPIToEnv(c *gc.C) {
3838
info := s.APIInfo(c)
39-
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
39+
conn, _, err := api.DialAPI(info, api.DialOpts{})
4040
c.Assert(err, jc.ErrorIsNil)
4141
defer conn.Close()
42-
assertConnAddrForEnv(c, conn, info.Addrs[0], s.State.ModelUUID(), "/api")
42+
assertConnAddrForModel(c, conn, info.Addrs[0], s.State.ModelUUID())
4343
}
4444

45-
func (s *apiclientSuite) TestConnectWebsocketToRoot(c *gc.C) {
45+
func (s *apiclientSuite) TestDialAPIToRoot(c *gc.C) {
4646
info := s.APIInfo(c)
4747
info.ModelTag = names.NewModelTag("")
48-
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
48+
conn, _, err := api.DialAPI(info, api.DialOpts{})
4949
c.Assert(err, jc.ErrorIsNil)
5050
defer conn.Close()
5151
assertConnAddrForRoot(c, conn, info.Addrs[0])
5252
}
5353

54-
func (s *apiclientSuite) TestConnectWebsocketMultiple(c *gc.C) {
54+
func (s *apiclientSuite) TestDialAPIMultiple(c *gc.C) {
5555
// Create a socket that proxies to the API server.
5656
info := s.APIInfo(c)
5757
serverAddr := info.Addrs[0]
@@ -60,22 +60,22 @@ func (s *apiclientSuite) TestConnectWebsocketMultiple(c *gc.C) {
6060

6161
// Check that we can use the proxy to connect.
6262
info.Addrs = []string{proxy.Addr()}
63-
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
63+
conn, _, err := api.DialAPI(info, api.DialOpts{})
6464
c.Assert(err, jc.ErrorIsNil)
6565
conn.Close()
66-
assertConnAddrForEnv(c, conn, proxy.Addr(), s.State.ModelUUID(), "/api")
66+
assertConnAddrForModel(c, conn, proxy.Addr(), s.State.ModelUUID())
6767

6868
// Now break Addrs[0], and ensure that Addrs[1]
6969
// is successfully connected to.
7070
proxy.Close()
7171
info.Addrs = []string{proxy.Addr(), serverAddr}
72-
conn, _, err = api.ConnectWebsocket(info, api.DialOpts{})
72+
conn, _, err = api.DialAPI(info, api.DialOpts{})
7373
c.Assert(err, jc.ErrorIsNil)
7474
conn.Close()
75-
assertConnAddrForEnv(c, conn, serverAddr, s.State.ModelUUID(), "/api")
75+
assertConnAddrForModel(c, conn, serverAddr, s.State.ModelUUID())
7676
}
7777

78-
func (s *apiclientSuite) TestConnectWebsocketMultipleError(c *gc.C) {
78+
func (s *apiclientSuite) TestDialAPIMultipleError(c *gc.C) {
7979
listener, err := net.Listen("tcp", "127.0.0.1:0")
8080
c.Assert(err, jc.ErrorIsNil)
8181
defer listener.Close()
@@ -94,7 +94,7 @@ func (s *apiclientSuite) TestConnectWebsocketMultipleError(c *gc.C) {
9494
info := s.APIInfo(c)
9595
addr := listener.Addr().String()
9696
info.Addrs = []string{addr, addr, addr}
97-
_, _, err = api.ConnectWebsocket(info, api.DialOpts{})
97+
_, _, err = api.DialAPI(info, api.DialOpts{})
9898
c.Assert(err, gc.ErrorMatches, `unable to connect to API: websocket.Dial wss://.*/model/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}/api: .*`)
9999
c.Assert(atomic.LoadInt32(&count), gc.Equals, int32(3))
100100
}
@@ -153,14 +153,95 @@ func (s *apiclientSuite) TestServerRoot(c *gc.C) {
153153
}
154154

155155
func (s *apiclientSuite) TestDialWebsocketStopped(c *gc.C) {
156-
stopped := make(chan struct{})
157156
f := api.NewWebsocketDialer(nil, api.DialOpts{})
157+
stopped := make(chan struct{})
158158
close(stopped)
159159
result, err := f(stopped)
160160
c.Assert(err, gc.Equals, parallel.ErrStopped)
161161
c.Assert(result, gc.IsNil)
162162
}
163163

164+
func (s *apiclientSuite) TestOpenWithSNIHostNameEmptyCert(c *gc.C) {
165+
info := &api.Info{
166+
Addrs: []string{"foo.com:1234", "0.1.2.3:1234"},
167+
SNIHostName: "foo.com",
168+
SkipLogin: true,
169+
}
170+
// When we dial with an SNI name and no CA certificate, the connection should
171+
// always be made with the provided SNI name.
172+
s.testSNIHostName(c, info, []apiDialInfo{{
173+
location: "wss://foo.com:1234/api",
174+
hasRootCAs: false,
175+
serverName: "foo.com",
176+
}, {
177+
location: "wss://0.1.2.3:1234/api",
178+
hasRootCAs: false,
179+
serverName: "foo.com",
180+
}})
181+
}
182+
183+
func (s *apiclientSuite) TestOpenWithSNIHostNameWithCACert(c *gc.C) {
184+
info := &api.Info{
185+
Addrs: []string{"foo.com:1234", "0.1.2.3:1234"},
186+
SNIHostName: "foo.com",
187+
SkipLogin: true,
188+
CACert: jtesting.CACert,
189+
}
190+
// When we dial with an SNI name and a CA cert, the SNI name
191+
// should be ignored.
192+
s.testSNIHostName(c, info, []apiDialInfo{{
193+
location: "wss://foo.com:1234/api",
194+
hasRootCAs: true,
195+
serverName: "juju-apiserver",
196+
}, {
197+
location: "wss://0.1.2.3:1234/api",
198+
hasRootCAs: true,
199+
serverName: "juju-apiserver",
200+
}})
201+
}
202+
203+
type apiDialInfo struct {
204+
location string
205+
hasRootCAs bool
206+
serverName string
207+
}
208+
209+
// testSNIHostName tests that when the API is dialed with the given info,
210+
// api.newWebsocketDialer is called with the expected information
211+
// (one element for each call to newWebsocketDialer)
212+
func (s *apiclientSuite) testSNIHostName(c *gc.C, info *api.Info, expectDials []apiDialInfo) {
213+
dialed := make(chan *websocket.Config)
214+
fakeDialer := func(cfg *websocket.Config) (*websocket.Conn, error) {
215+
dialed <- cfg
216+
return nil, errors.New("nope")
217+
}
218+
done := make(chan struct{})
219+
go func() {
220+
defer close(done)
221+
conn, err := api.Open(info, api.DialOpts{
222+
DialWebsocket: fakeDialer,
223+
})
224+
c.Check(conn, gc.Equals, nil)
225+
c.Check(err, gc.ErrorMatches, `unable to connect to API: nope`)
226+
}()
227+
for _, expect := range expectDials {
228+
select {
229+
case cfg := <-dialed:
230+
c.Check(cfg.Location.String(), gc.Equals, expect.location)
231+
c.Assert(cfg.TlsConfig, gc.NotNil)
232+
c.Check(cfg.TlsConfig.RootCAs != nil, gc.Equals, expect.hasRootCAs)
233+
c.Check(cfg.TlsConfig.ServerName, gc.Equals, expect.serverName)
234+
case <-time.After(jtesting.LongWait):
235+
c.Fatalf("timed out waiting for dial")
236+
}
237+
}
238+
select {
239+
case <-done:
240+
case <-time.After(jtesting.LongWait):
241+
c.Fatalf("timed out waiting for API open")
242+
}
243+
}
244+
164245
func (s *apiclientSuite) TestOpenWithNoCACert(c *gc.C) {
165246
// This is hard to test as we have no way of affecting the system roots,
166247
// so instead we check that the error that we get implies that
@@ -405,10 +486,10 @@ func (a *redirectAPIAdmin) RedirectInfo() (params.RedirectInfoResult, error) {
405486
}, nil
406487
}
407488

408-
func assertConnAddrForEnv(c *gc.C, conn *websocket.Conn, addr, modelUUID, tail string) {
409-
c.Assert(conn.RemoteAddr(), gc.Matches, "^wss://"+addr+"/model/"+modelUUID+tail+"$")
489+
func assertConnAddrForModel(c *gc.C, conn *websocket.Conn, addr, modelUUID string) {
490+
c.Assert(conn.RemoteAddr().String(), gc.Equals, "wss://"+addr+"/model/"+modelUUID+"/api")
410491
}
411492

412493
func assertConnAddrForRoot(c *gc.C, conn *websocket.Conn, addr string) {
413-
c.Assert(conn.RemoteAddr(), gc.Matches, "^wss://"+addr+"/api$")
494+
c.Assert(conn.RemoteAddr().String(), gc.Matches, "wss://"+addr+"/api")
414495
}

api/certpool.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func CreateCertPool(caCert string) (*x509.CertPool, error) {
2727
if caCert != "" {
2828
xcert, err := cert.ParseCert(caCert)
2929
if err != nil {
30-
return nil, errors.Trace(err)
30+
return nil, errors.Annotatef(err, "cannot parse certificate %q", caCert)
3131
}
3232
pool.AddCert(xcert)
3333
}

api/export_test.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,13 @@ import (
1212
)
1313

1414
var (
15-
CertDir = &certDir
16-
NewWebsocketDialer = newWebsocketDialer
17-
NewWebsocketDialerPtr = &newWebsocketDialer
18-
WebsocketDialConfig = &websocketDialConfig
19-
SlideAddressToFront = slideAddressToFront
20-
BestVersion = bestVersion
21-
FacadeVersions = &facadeVersions
22-
ConnectWebsocket = connectWebsocket
15+
CertDir = &certDir
16+
NewWebsocketDialer = newWebsocketDialer
17+
WebsocketDialConfig = &websocketDialConfig
18+
SlideAddressToFront = slideAddressToFront
19+
BestVersion = bestVersion
20+
FacadeVersions = &facadeVersions
21+
DialAPI = dialAPI
2322
)
2423

2524
// RPCConnection defines the methods that are called on the rpc.Conn instance.

0 commit comments

Comments
 (0)