Skip to content

Commit

Permalink
api: refactor client HTTP interface
Browse files Browse the repository at this point in the history
  • Loading branch information
rogpeppe committed Oct 8, 2015
1 parent 4fcb9d1 commit 91f1a4b
Show file tree
Hide file tree
Showing 45 changed files with 810 additions and 1,188 deletions.
119 changes: 71 additions & 48 deletions api/apiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ type state struct {
// serverScheme is the URI scheme of the API Server
serverScheme string

// tlsConfig holds the TLS config appropriate for making SSL
// connections to the API endpoints.
tlsConfig *tls.Config

// certPool holds the cert pool that is used to authenticate the tls
// connections to the API.
certPool *x509.CertPool
Expand Down Expand Up @@ -119,22 +123,36 @@ func open(info *Info, opts DialOpts, loginFunc func(st *state, tag names.Tag, pw
return nil, errors.New("open should specifiy UseMacaroons or a username & password. Not both")
}
}
conn, err := connectWebsocket(info, opts)
conn, tlsConfig, err := connectWebsocket(info, opts)
if err != nil {
return nil, errors.Trace(err)
}

client := rpc.NewConn(jsoncodec.NewWebsocket(conn), nil)
client.Start()

bakeryClient := opts.BakeryClient
if bakeryClient == nil {
bakeryClient = httpbakery.NewClient()
} else {
// Make a copy of the bakery client and its
// HTTP client
c := *opts.BakeryClient
bakeryClient = &c
httpc := *bakeryClient.Client
bakeryClient.Client = &httpc
}
apiHost := conn.Config().Location.Host
bakeryClient.Client.Transport = &hostSwitchingTransport{
primaryHost: apiHost,
primary: utils.NewHttpTLSTransport(tlsConfig),
fallback: http.DefaultTransport,
}

client := rpc.NewConn(jsoncodec.NewWebsocket(conn), nil)
client.Start()
st := &state{
client: client,
conn: conn,
addr: conn.Config().Location.Host,
addr: apiHost,
cookieURL: &url.URL{
Scheme: "https",
Host: conn.Config().Location.Host,
Expand All @@ -147,7 +165,7 @@ func open(info *Info, opts DialOpts, loginFunc func(st *state, tag names.Tag, pw
tag: tagToString(info.Tag),
password: info.Password,
nonce: info.Nonce,
certPool: conn.Config().TlsConfig.RootCAs,
tlsConfig: tlsConfig,
bakeryClient: bakeryClient,
}
if info.Tag != nil || info.Password != "" || info.UseMacaroons {
Expand All @@ -162,6 +180,26 @@ func open(info *Info, opts DialOpts, loginFunc func(st *state, tag names.Tag, pw
return st, nil
}

// hostSwitchingTransport provides an http.RoundTripper
// that chooses an actual RoundTripper to use
// depending on the destination host.
//
// This makes it possible to use a different set of root
// CAs for the API and all other hosts.
type hostSwitchingTransport struct {
primaryHost string
primary http.RoundTripper
fallback http.RoundTripper
}

// RoundTrip implements http.RoundTripper.RoundTrip.
func (t *hostSwitchingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Host == t.primaryHost {
return t.primary.RoundTrip(req)
}
return t.fallback.RoundTrip(req)
}

// OpenWithVersion uses an explicit version of the Admin facade to call Login
// on. This allows the caller to pretend to be an older client, and is used
// only in testing.
Expand All @@ -184,15 +222,16 @@ func OpenWithVersion(info *Info, opts DialOpts, loginVersion int) (Connection, e
// API websocket on the API server using Info. If multiple API addresses
// are provided in Info they will be tried concurrently - the first successful
// connection wins.
func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, error) {
//
// It also returns the TLS configuration that it has derived from the Info.
func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
if len(info.Addrs) == 0 {
return nil, errors.New("no API addresses to connect to")
return nil, nil, errors.New("no API addresses to connect to")
}
pool, err := CreateCertPool(info.CACert)
tlsConfig, err := tlsConfigForCACert(info.CACert)
if err != nil {
return nil, errors.Annotate(err, "cert pool creation failed")
return nil, nil, errors.Annotatef(err, "cannot make TLS configuration")
}

path := "/"
if info.EnvironTag.Id() != "" {
path = apiPath(info.EnvironTag, "/api")
Expand All @@ -202,12 +241,12 @@ func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, error) {
try := parallel.NewTry(0, nil)
defer try.Kill()
for _, addr := range info.Addrs {
err := dialWebsocket(addr, path, opts, pool, try)
err := dialWebsocket(addr, path, opts, tlsConfig, try)
if err == parallel.ErrStopped {
break
}
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
select {
case <-time.After(opts.DialAddressInterval):
Expand All @@ -217,11 +256,24 @@ func connectWebsocket(info *Info, opts DialOpts) (*websocket.Conn, error) {
try.Close()
result, err := try.Result()
if err != nil {
return nil, errors.Trace(err)
return nil, nil, errors.Trace(err)
}
conn := result.(*websocket.Conn)
logger.Infof("connection established to %q", conn.RemoteAddr())
return conn, nil
return conn, tlsConfig, nil
}

func tlsConfigForCACert(caCert string) (*tls.Config, error) {
certPool, err := CreateCertPool(caCert)
if err != nil {
return nil, errors.Annotate(err, "cert pool creation failed")
}
return &tls.Config{
RootCAs: certPool,
// We want to be specific here (rather than just using "anything".
// See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899).
ServerName: "juju-apiserver",
}, nil
}

// ConnectStream implements Connection.ConnectStream.
Expand Down Expand Up @@ -284,10 +336,7 @@ func (st *state) connectStream(path string, attrs url.Values) (base.Stream, erro
// connections by default.
st.addCookiesToHeader(cfg.Header)

cfg.TlsConfig = &tls.Config{
RootCAs: st.certPool,
ServerName: "juju-apiserver",
}
cfg.TlsConfig = st.tlsConfig
connection, err := websocketDialConfig(cfg)
if err != nil {
return nil, err
Expand All @@ -298,29 +347,6 @@ func (st *state) connectStream(path string, attrs url.Values) (base.Stream, erro
return connection, nil
}

// bakeryError translates any discharge-required error into
// an error value that the httpbakery package will recognize.
// Other errors are returned unchanged.
func bakeryError(err error) error {
if params.ErrCode(err) != params.CodeDischargeRequired {
return err
}
errResp := errors.Cause(err).(*params.Error)
if errResp.Info == nil {
return errors.Annotatef(err, "no error info found in discharge-required response error")
}
// It's a discharge-required error, so make an appropriate httpbakery
// error from it.
return &httpbakery.Error{
Message: err.Error(),
Code: httpbakery.ErrDischargeRequired,
Info: &httpbakery.ErrorInfo{
Macaroon: errResp.Info.Macaroon,
MacaroonPath: errResp.Info.MacaroonPath,
},
}
}

// readInitialStreamError reads the initial error response
// from a stream connection and returns it.
func readInitialStreamError(conn io.Reader) error {
Expand Down Expand Up @@ -408,15 +434,15 @@ func tagToString(tag names.Tag) string {
return tag.String()
}

func dialWebsocket(addr, path string, opts DialOpts, rootCAs *x509.CertPool, try *parallel.Try) error {
cfg, err := setUpWebsocket(addr, path, rootCAs)
func dialWebsocket(addr, path string, opts DialOpts, tlsConfig *tls.Config, try *parallel.Try) error {
cfg, err := setUpWebsocket(addr, path, tlsConfig)
if err != nil {
return err
}
return try.Start(newWebsocketDialer(cfg, opts))
}

func setUpWebsocket(addr, path string, rootCAs *x509.CertPool) (*websocket.Config, error) {
func setUpWebsocket(addr, path string, tlsConfig *tls.Config) (*websocket.Config, error) {
// origin is required by the WebSocket API, used for "origin policy"
// in websockets. We pass localhost to satisfy the API; it is
// inconsequential to us.
Expand All @@ -425,10 +451,7 @@ func setUpWebsocket(addr, path string, rootCAs *x509.CertPool) (*websocket.Confi
if err != nil {
return nil, errors.Trace(err)
}
cfg.TlsConfig = &tls.Config{
RootCAs: rootCAs,
ServerName: "juju-apiserver",
}
cfg.TlsConfig = tlsConfig
return cfg, nil
}

Expand Down
12 changes: 6 additions & 6 deletions api/apiclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (s *apiclientSuite) TestOpenFailsIfUsernameAndUseMacaroon(c *gc.C) {

func (s *apiclientSuite) TestConnectWebsocketToEnv(c *gc.C) {
info := s.APIInfo(c)
conn, err := api.ConnectWebsocket(info, api.DialOpts{})
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
c.Assert(err, jc.ErrorIsNil)
defer conn.Close()
assertConnAddrForEnv(c, conn, info.Addrs[0], s.State.EnvironUUID(), "/api")
Expand All @@ -46,7 +46,7 @@ func (s *apiclientSuite) TestConnectWebsocketToEnv(c *gc.C) {
func (s *apiclientSuite) TestConnectWebsocketToRoot(c *gc.C) {
info := s.APIInfo(c)
info.EnvironTag = names.NewEnvironTag("")
conn, err := api.ConnectWebsocket(info, api.DialOpts{})
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
c.Assert(err, jc.ErrorIsNil)
defer conn.Close()
assertConnAddrForRoot(c, conn, info.Addrs[0])
Expand Down Expand Up @@ -83,7 +83,7 @@ func (s *apiclientSuite) TestConnectWebsocketPrefersLocalhostIfPresent(c *gc.C)
c.Check(err, jc.ErrorIsNil)
expectedHostPort := fmt.Sprintf("localhost:%d", portNum)
info.Addrs = []string{"fakeAddress:1", "fakeAddress:1", expectedHostPort}
conn, err := api.ConnectWebsocket(info, api.DialOpts{})
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
c.Assert(err, jc.ErrorIsNil)
defer conn.Close()
assertConnAddrForEnv(c, conn, expectedHostPort, s.State.EnvironUUID(), "/api")
Expand Down Expand Up @@ -113,7 +113,7 @@ func (s *apiclientSuite) TestConnectWebsocketMultiple(c *gc.C) {
// Check that we can use the proxy to connect.
proxyAddr := listener.Addr().String()
info.Addrs = []string{proxyAddr}
conn, err := api.ConnectWebsocket(info, api.DialOpts{})
conn, _, err := api.ConnectWebsocket(info, api.DialOpts{})
c.Assert(err, jc.ErrorIsNil)
conn.Close()
assertConnAddrForEnv(c, conn, proxyAddr, s.State.EnvironUUID(), "/api")
Expand All @@ -122,7 +122,7 @@ func (s *apiclientSuite) TestConnectWebsocketMultiple(c *gc.C) {
// is successfully connected to.
info.Addrs = []string{proxyAddr, serverAddr}
listener.Close()
conn, err = api.ConnectWebsocket(info, api.DialOpts{})
conn, _, err = api.ConnectWebsocket(info, api.DialOpts{})
c.Assert(err, jc.ErrorIsNil)
conn.Close()
assertConnAddrForEnv(c, conn, serverAddr, s.State.EnvironUUID(), "/api")
Expand All @@ -144,7 +144,7 @@ func (s *apiclientSuite) TestConnectWebsocketMultipleError(c *gc.C) {
info := s.APIInfo(c)
addr := listener.Addr().String()
info.Addrs = []string{addr, addr, addr}
_, err = api.ConnectWebsocket(info, api.DialOpts{})
_, _, err = api.ConnectWebsocket(info, api.DialOpts{})
c.Assert(err, gc.ErrorMatches, `unable to connect to API: websocket.Dial wss://.*/environment/[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}/api: .*`)
}

Expand Down
60 changes: 60 additions & 0 deletions api/backups/base_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package backups_test

import (
"time"

gc "gopkg.in/check.v1"

"github.com/juju/juju/api/backups"
apiserverbackups "github.com/juju/juju/apiserver/backups"
"github.com/juju/juju/apiserver/params"
jujutesting "github.com/juju/juju/juju/testing"
stbackups "github.com/juju/juju/state/backups"
backupstesting "github.com/juju/juju/state/backups/testing"
)

type baseSuite struct {
jujutesting.JujuConnSuite
backupstesting.BaseSuite
client *backups.Client
}

func (s *baseSuite) SetUpTest(c *gc.C) {
s.BaseSuite.SetUpTest(c)
s.JujuConnSuite.SetUpTest(c)
client, err := backups.NewClient(s.APIState)
c.Assert(err, gc.IsNil)
s.client = client
}

func (s *baseSuite) metadataResult() *params.BackupsMetadataResult {
result := apiserverbackups.ResultFromMetadata(s.Meta)
return &result
}

func (s *baseSuite) checkMetadataResult(c *gc.C, result *params.BackupsMetadataResult, meta *stbackups.Metadata) {
var finished, stored time.Time
if meta.Finished != nil {
finished = *meta.Finished
}
if meta.Stored() != nil {
stored = *(meta.Stored())
}

c.Check(result.ID, gc.Equals, meta.ID())
c.Check(result.Started, gc.Equals, meta.Started)
c.Check(result.Finished, gc.Equals, finished)
c.Check(result.Checksum, gc.Equals, meta.Checksum())
c.Check(result.ChecksumFormat, gc.Equals, meta.ChecksumFormat())
c.Check(result.Size, gc.Equals, meta.Size())
c.Check(result.Stored, gc.Equals, stored)
c.Check(result.Notes, gc.Equals, meta.Notes)

c.Check(result.Environment, gc.Equals, meta.Origin.Environment)
c.Check(result.Machine, gc.Equals, meta.Origin.Machine)
c.Check(result.Hostname, gc.Equals, meta.Origin.Hostname)
c.Check(result.Version, gc.Equals, meta.Origin.Version)
}
Loading

0 comments on commit 91f1a4b

Please sign in to comment.