Skip to content

Commit 7c21f4f

Browse files
author
Jenkins bot
committed
Merge commit '809dc26b1ecb5727cdcfbfa3511eb67bf9f02db5' into HEAD
2 parents df0effe + 809dc26 commit 7c21f4f

19 files changed

+426
-328
lines changed

api/apiclient.go

Lines changed: 83 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func open(
180180
if clock == nil {
181181
return nil, errors.NotValidf("nil clock")
182182
}
183-
conn, tlsConfig, err := connectWebsocket(info, opts)
183+
conn, tlsConfig, err := dialAPI(info, opts)
184184
if err != nil {
185185
return nil, errors.Trace(err)
186186
}
@@ -275,70 +275,6 @@ func (t *hostSwitchingTransport) RoundTrip(req *http.Request) (*http.Response, e
275275
return t.fallback.RoundTrip(req)
276276
}
277277

278-
// connectWebsocket establishes a websocket connection to the RPC
279-
// API websocket on the API server using Info. If multiple API addresses
280-
// are provided in Info they will be tried concurrently - the first successful
281-
// connection wins.
282-
//
283-
// It also returns the TLS configuration that it has derived from the Info.
284-
func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
285-
if len(info.Addrs) == 0 {
286-
return nil, nil, errors.New("no API addresses to connect to")
287-
}
288-
tlsConfig := utils.SecureTLSConfig()
289-
tlsConfig.InsecureSkipVerify = opts.InsecureSkipVerify
290-
291-
if info.CACert != "" && !tlsConfig.InsecureSkipVerify {
292-
// We want to be specific here (rather than just using "anything".
293-
// See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899).
294-
tlsConfig.ServerName = "juju-apiserver"
295-
certPool, err := CreateCertPool(info.CACert)
296-
if err != nil {
297-
return nil, nil, errors.Annotate(err, "cert pool creation failed")
298-
}
299-
tlsConfig.RootCAs = certPool
300-
}
301-
path, err := apiPath(info.ModelTag, "/api")
302-
if err != nil {
303-
return nil, nil, errors.Trace(err)
304-
}
305-
conn, err := dialWebSocket(info.Addrs, path, tlsConfig, opts)
306-
if err != nil {
307-
return nil, nil, errors.Trace(err)
308-
}
309-
logger.Infof("connection established to %q", conn.RemoteAddr())
310-
return conn, tlsConfig, nil
311-
}
312-
313-
// dialWebSocket dials a websocket with one of the provided addresses, the
314-
// specified URL path, TLS configuration, and dial options. Each of the
315-
// specified addresses will be attempted concurrently, and the first
316-
// successful connection will be returned.
317-
func dialWebSocket(addrs []string, path string, tlsConfig *tls.Config, opts DialOpts) (*websocket.Conn, error) {
318-
// Dial all addresses at reasonable intervals.
319-
try := parallel.NewTry(0, nil)
320-
defer try.Kill()
321-
for _, addr := range addrs {
322-
err := dialWebsocket(addr, path, opts, tlsConfig, try)
323-
if err == parallel.ErrStopped {
324-
break
325-
}
326-
if err != nil {
327-
return nil, errors.Trace(err)
328-
}
329-
select {
330-
case <-time.After(opts.DialAddressInterval):
331-
case <-try.Dead():
332-
}
333-
}
334-
try.Close()
335-
result, err := try.Result()
336-
if err != nil {
337-
return nil, errors.Trace(err)
338-
}
339-
return result.(*websocket.Conn), nil
340-
}
341-
342278
// ConnectStream implements StreamConnector.ConnectStream.
343279
func (st *state) ConnectStream(path string, attrs url.Values) (base.Stream, error) {
344280
if !st.isLoggedIn() {
@@ -486,7 +422,83 @@ func tagToString(tag names.Tag) string {
486422
return tag.String()
487423
}
488424

489-
func dialWebsocket(addr, path string, opts DialOpts, tlsConfig *tls.Config, try *parallel.Try) error {
425+
// dialAPI establishes a websocket connection to the RPC
426+
// API websocket on the API server using Info. If multiple API addresses
427+
// are provided in Info they will be tried concurrently - the first successful
428+
// connection wins.
429+
//
430+
// It also returns the TLS configuration that it has derived from the Info.
431+
func dialAPI(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
432+
// Set opts.DialWebsocket here rather than in open because
433+
// some tests call dialAPI directly.
434+
if opts.DialWebsocket == nil {
435+
opts.DialWebsocket = websocket.DialConfig
436+
}
437+
if len(info.Addrs) == 0 {
438+
return nil, nil, errors.New("no API addresses to connect to")
439+
}
440+
tlsConfig := utils.SecureTLSConfig()
441+
tlsConfig.InsecureSkipVerify = opts.InsecureSkipVerify
442+
443+
if info.CACert != "" {
444+
// We want to be specific here (rather than just using "anything".
445+
// See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899).
446+
tlsConfig.ServerName = "juju-apiserver"
447+
certPool, err := CreateCertPool(info.CACert)
448+
if err != nil {
449+
return nil, nil, errors.Annotate(err, "cert pool creation failed")
450+
}
451+
tlsConfig.RootCAs = certPool
452+
} else {
453+
// No CA certificate so use the SNI host name for all
454+
// connections (if SNIHostName is empty, the host
455+
// name in the address will be used as usual).
456+
tlsConfig.ServerName = info.SNIHostName
457+
}
458+
path, err := apiPath(info.ModelTag, "/api")
459+
if err != nil {
460+
return nil, nil, errors.Trace(err)
461+
}
462+
conn, err := dialWebsocketMulti(info.Addrs, path, tlsConfig, opts)
463+
if err != nil {
464+
return nil, nil, errors.Trace(err)
465+
}
466+
logger.Infof("connection established to %q", conn.RemoteAddr())
467+
return conn, tlsConfig, nil
468+
}
469+
470+
// dialWebsocketMulti dials a websocket with one of the provided addresses, the
471+
// specified URL path, TLS configuration, and dial options. Each of the
472+
// specified addresses will be attempted concurrently, and the first
473+
// successful connection will be returned.
474+
func dialWebsocketMulti(addrs []string, path string, tlsConfig *tls.Config, opts DialOpts) (*websocket.Conn, error) {
475+
// Dial all addresses at reasonable intervals.
476+
try := parallel.NewTry(0, nil)
477+
defer try.Kill()
478+
for _, addr := range addrs {
479+
err := startDialWebsocket(try, addr, path, opts, tlsConfig)
480+
if err == parallel.ErrStopped {
481+
break
482+
}
483+
if err != nil {
484+
return nil, errors.Trace(err)
485+
}
486+
select {
487+
case <-time.After(opts.DialAddressInterval):
488+
case <-try.Dead():
489+
}
490+
}
491+
try.Close()
492+
result, err := try.Result()
493+
if err != nil {
494+
return nil, errors.Trace(err)
495+
}
496+
return result.(*websocket.Conn), nil
497+
}
498+
499+
// startDialWebsocket starts websocket connection to a single address
500+
// on the given try instance.
501+
func startDialWebsocket(try *parallel.Try, addr, path string, opts DialOpts, tlsConfig *tls.Config) error {
490502
// origin is required by the WebSocket API, used for "origin policy"
491503
// in websockets. We pass localhost to satisfy the API; it is
492504
// inconsequential to us.
@@ -499,11 +511,10 @@ func dialWebsocket(addr, path string, opts DialOpts, tlsConfig *tls.Config, try
499511
return try.Start(newWebsocketDialer(cfg, opts))
500512
}
501513

502-
// newWebsocketDialer returns a function that
503-
// can be passed to utils/parallel.Try.Start.
504-
var newWebsocketDialer = createWebsocketDialer
505-
506-
func createWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
514+
// newWebsocketDialer0 returns a function that dials the websocket represented
515+
// by the given configuration with the given dial options, suitable for passing
516+
// to utils/parallel.Try.Start.
517+
func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
507518
// TODO(katco): 2016-08-09: lp:1611427
508519
openAttempt := utils.AttemptStrategy{
509520
Total: opts.Timeout,
@@ -517,7 +528,7 @@ func createWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan str
517528
default:
518529
}
519530
logger.Infof("dialing %q", cfg.Location)
520-
conn, err := websocket.DialConfig(cfg)
531+
conn, err := opts.DialWebsocket(cfg)
521532
if err == nil {
522533
return conn, nil
523534
}

api/apiclient_test.go

Lines changed: 118 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,24 @@ type apiclientSuite struct {
3434

3535
var _ = gc.Suite(&apiclientSuite{})
3636

37-
func (s *apiclientSuite) TestConnectWebsocketToEnv(c *gc.C) {
37+
func (s *apiclientSuite) TestDialAPIToEnv(c *gc.C) {
3838
info := s.APIInfo(c)
39-
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
39+
conn, _, err := api.DialAPI(info, api.DialOpts{})
4040
c.Assert(err, jc.ErrorIsNil)
4141
defer conn.Close()
42-
assertConnAddrForEnv(c, conn, info.Addrs[0], s.State.ModelUUID(), "/api")
42+
assertConnAddrForModel(c, conn, info.Addrs[0], s.State.ModelUUID())
4343
}
4444

45-
func (s *apiclientSuite) TestConnectWebsocketToRoot(c *gc.C) {
45+
func (s *apiclientSuite) TestDialAPIToRoot(c *gc.C) {
4646
info := s.APIInfo(c)
4747
info.ModelTag = names.NewModelTag("")
48-
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
48+
conn, _, err := api.DialAPI(info, api.DialOpts{})
4949
c.Assert(err, jc.ErrorIsNil)
5050
defer conn.Close()
5151
assertConnAddrForRoot(c, conn, info.Addrs[0])
5252
}
5353

54-
func (s *apiclientSuite) TestConnectWebsocketMultiple(c *gc.C) {
54+
func (s *apiclientSuite) TestDialAPIMultiple(c *gc.C) {
5555
// Create a socket that proxies to the API server.
5656
info := s.APIInfo(c)
5757
serverAddr := info.Addrs[0]
@@ -60,22 +60,22 @@ func (s *apiclientSuite) TestConnectWebsocketMultiple(c *gc.C) {
6060

6161
// Check that we can use the proxy to connect.
6262
info.Addrs = []string{proxy.Addr()}
63-
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
63+
conn, _, err := api.DialAPI(info, api.DialOpts{})
6464
c.Assert(err, jc.ErrorIsNil)
6565
conn.Close()
66-
assertConnAddrForEnv(c, conn, proxy.Addr(), s.State.ModelUUID(), "/api")
66+
assertConnAddrForModel(c, conn, proxy.Addr(), s.State.ModelUUID())
6767

6868
// Now break Addrs[0], and ensure that Addrs[1]
6969
// is successfully connected to.
7070
proxy.Close()
7171
info.Addrs = []string{proxy.Addr(), serverAddr}
72-
conn, _, err = api.ConnectWebsocket(info, api.DialOpts{})
72+
conn, _, err = api.DialAPI(info, api.DialOpts{})
7373
c.Assert(err, jc.ErrorIsNil)
7474
conn.Close()
75-
assertConnAddrForEnv(c, conn, serverAddr, s.State.ModelUUID(), "/api")
75+
assertConnAddrForModel(c, conn, serverAddr, s.State.ModelUUID())
7676
}
7777

78-
func (s *apiclientSuite) TestConnectWebsocketMultipleError(c *gc.C) {
78+
func (s *apiclientSuite) TestDialAPIMultipleError(c *gc.C) {
7979
listener, err := net.Listen("tcp", "127.0.0.1:0")
8080
c.Assert(err, jc.ErrorIsNil)
8181
defer listener.Close()
@@ -94,7 +94,7 @@ func (s *apiclientSuite) TestConnectWebsocketMultipleError(c *gc.C) {
9494
info := s.APIInfo(c)
9595
addr := listener.Addr().String()
9696
info.Addrs = []string{addr, addr, addr}
97-
_, _, err = api.ConnectWebsocket(info, api.DialOpts{})
97+
_, _, err = api.DialAPI(info, api.DialOpts{})
9898
c.Assert(err, gc.ErrorMatches, `unable to connect to API: websocket.Dial wss://.*/model/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}/api: .*`)
9999
c.Assert(atomic.LoadInt32(&count), gc.Equals, int32(3))
100100
}
@@ -153,14 +153,116 @@ func (s *apiclientSuite) TestServerRoot(c *gc.C) {
153153
}
154154

155155
func (s *apiclientSuite) TestDialWebsocketStopped(c *gc.C) {
156-
stopped := make(chan struct{})
157156
f := api.NewWebsocketDialer(nil, api.DialOpts{})
157+
stopped := make(chan struct{})
158158
close(stopped)
159159
result, err := f(stopped)
160160
c.Assert(err, gc.Equals, parallel.ErrStopped)
161161
c.Assert(result, gc.IsNil)
162162
}
163163

164+
type apiDialInfo struct {
165+
location string
166+
hasRootCAs bool
167+
serverName string
168+
}
169+
170+
var openWithSNIHostnameTests = []struct {
171+
about string
172+
info *api.Info
173+
expectDial apiDialInfo
174+
}{{
175+
about: "no cert; DNS name - use SNI hostname",
176+
info: &api.Info{
177+
Addrs: []string{"foo.com:1234"},
178+
SNIHostName: "foo.com",
179+
SkipLogin: true,
180+
},
181+
expectDial: apiDialInfo{
182+
location: "wss://foo.com:1234/api",
183+
hasRootCAs: false,
184+
serverName: "foo.com",
185+
},
186+
}, {
187+
about: "no cert; numeric IP address - use SNI hostname",
188+
info: &api.Info{
189+
Addrs: []string{"0.1.2.3:1234"},
190+
SNIHostName: "foo.com",
191+
SkipLogin: true,
192+
},
193+
expectDial: apiDialInfo{
194+
location: "wss://0.1.2.3:1234/api",
195+
hasRootCAs: false,
196+
serverName: "foo.com",
197+
},
198+
}, {
199+
about: "with cert; DNS name - use cert",
200+
info: &api.Info{
201+
Addrs: []string{"foo.com:1234"},
202+
SNIHostName: "foo.com",
203+
SkipLogin: true,
204+
CACert: jtesting.CACert,
205+
},
206+
expectDial: apiDialInfo{
207+
location: "wss://foo.com:1234/api",
208+
hasRootCAs: true,
209+
serverName: "juju-apiserver",
210+
},
211+
}, {
212+
about: "with cert; numeric IP address - use cert",
213+
info: &api.Info{
214+
Addrs: []string{"0.1.2.3:1234"},
215+
SNIHostName: "foo.com",
216+
SkipLogin: true,
217+
CACert: jtesting.CACert,
218+
},
219+
expectDial: apiDialInfo{
220+
location: "wss://0.1.2.3:1234/api",
221+
hasRootCAs: true,
222+
serverName: "juju-apiserver",
223+
},
224+
}}
225+
226+
func (s *apiclientSuite) TestOpenWithSNIHostname(c *gc.C) {
227+
for i, test := range openWithSNIHostnameTests {
228+
c.Logf("test %d: %v", i, test.about)
229+
s.testSNIHostName(c, test.info, test.expectDial)
230+
}
231+
}
232+
233+
// testSNIHostName tests that when the API is dialed with the given info,
234+
// api.newWebsocketDialer is called with the expected information.
235+
func (s *apiclientSuite) testSNIHostName(c *gc.C, info *api.Info, expectDial apiDialInfo) {
236+
dialed := make(chan *websocket.Config)
237+
fakeDialer := func(cfg *websocket.Config) (*websocket.Conn, error) {
238+
dialed <- cfg
239+
return nil, errors.New("nope")
240+
}
241+
done := make(chan struct{})
242+
go func() {
243+
defer close(done)
244+
conn, err := api.Open(info, api.DialOpts{
245+
DialWebsocket: fakeDialer,
246+
})
247+
c.Check(conn, gc.Equals, nil)
248+
c.Check(err, gc.ErrorMatches, `unable to connect to API: nope`)
249+
}()
250+
select {
251+
case cfg := <-dialed:
252+
c.Check(cfg.Location.String(), gc.Equals, expectDial.location)
253+
c.Assert(cfg.TlsConfig, gc.NotNil)
254+
c.Check(cfg.TlsConfig.RootCAs != nil, gc.Equals, expectDial.hasRootCAs)
255+
c.Check(cfg.TlsConfig.ServerName, gc.Equals, expectDial.serverName)
256+
case <-time.After(jtesting.LongWait):
257+
c.Fatalf("timed out waiting for dial")
258+
}
259+
select {
260+
case <-done:
261+
case <-time.After(jtesting.LongWait):
262+
c.Fatalf("timed out waiting for API open")
263+
}
264+
}
265+
164266
func (s *apiclientSuite) TestOpenWithNoCACert(c *gc.C) {
165267
// This is hard to test as we have no way of affecting the system roots,
166268
// so instead we check that the error that we get implies that
@@ -405,10 +507,10 @@ func (a *redirectAPIAdmin) RedirectInfo() (params.RedirectInfoResult, error) {
405507
}, nil
406508
}
407509

408-
func assertConnAddrForEnv(c *gc.C, conn *websocket.Conn, addr, modelUUID, tail string) {
409-
c.Assert(conn.RemoteAddr(), gc.Matches, "^wss://"+addr+"/model/"+modelUUID+tail+"$")
510+
func assertConnAddrForModel(c *gc.C, conn *websocket.Conn, addr, modelUUID string) {
511+
c.Assert(conn.RemoteAddr().String(), gc.Equals, "wss://"+addr+"/model/"+modelUUID+"/api")
410512
}
411513

412514
func assertConnAddrForRoot(c *gc.C, conn *websocket.Conn, addr string) {
413-
c.Assert(conn.RemoteAddr(), gc.Matches, "^wss://"+addr+"/api$")
515+
c.Assert(conn.RemoteAddr().String(), gc.Matches, "wss://"+addr+"/api")
414516
}

api/certpool.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func CreateCertPool(caCert string) (*x509.CertPool, error) {
2727
if caCert != "" {
2828
xcert, err := cert.ParseCert(caCert)
2929
if err != nil {
30-
return nil, errors.Trace(err)
30+
return nil, errors.Annotatef(err, "cannot parse certificate %q", caCert)
3131
}
3232
pool.AddCert(xcert)
3333
}

0 commit comments

Comments
 (0)