Skip to content

Commit

Permalink
Revert "Merge pull request juju#7029 from sinzui/revert-gorilla-sockets"
Browse files Browse the repository at this point in the history
This reverts commit ca2498e, reversing
changes made to d7f0644.
  • Loading branch information
howbazaar committed Feb 26, 2017
1 parent b328000 commit 043d6b7
Show file tree
Hide file tree
Showing 30 changed files with 698 additions and 534 deletions.
192 changes: 128 additions & 64 deletions api/apiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ import (
"sync/atomic"
"time"

"github.com/gorilla/websocket"
"github.com/juju/errors"
"github.com/juju/loggo"
"github.com/juju/retry"
"github.com/juju/utils"
"github.com/juju/utils/clock"
"github.com/juju/utils/parallel"
"github.com/juju/version"
"golang.org/x/net/websocket"
"gopkg.in/juju/names.v2"
"gopkg.in/macaroon-bakery.v1/httpbakery"
"gopkg.in/macaroon.v1"
Expand All @@ -47,6 +47,11 @@ const pingTimeout = 30 * time.Second
// modelRoot is the prefix that all model API paths begin with.
const modelRoot = "/model/"

// Use a 64k frame size for the websockets while we need to deal
// with x/net/websocket connections that don't deal with recieving
// fragmented messages.
const websocketFrameSize = 65536

var logger = loggo.GetLogger("juju.api")

type rpcConnection interface {
Expand Down Expand Up @@ -183,12 +188,12 @@ func open(
if clock == nil {
return nil, errors.NotValidf("nil clock")
}
conn, tlsConfig, err := dialAPI(info, opts)
dialResult, err := dialAPI(info, opts)
if err != nil {
return nil, errors.Trace(err)
}

client := rpc.NewConn(jsoncodec.NewWebsocket(conn), observer.None())
client := rpc.NewConn(jsoncodec.NewWebsocket(dialResult.conn), observer.None())
client.Start()

bakeryClient := opts.BakeryClient
Expand All @@ -201,29 +206,36 @@ func open(
httpc := *bakeryClient.Client
bakeryClient.Client = &httpc
}
apiHost := conn.Config().Location.Host
apiURL, err := url.Parse(dialResult.urlStr)
if err != nil {
// This should never happen as the url would have failed during dialAPI above.
// However the code paths don't allow capture of the url.URL used.
return nil, errors.Trace(err)
}
apiHost := apiURL.Host

// Technically when there's no CACert, we don't need this
// machinery, because we could just use http.DefaultTransport
// for everything, but it's easier just to leave it in place.
bakeryClient.Client.Transport = &hostSwitchingTransport{
primaryHost: apiHost,
primary: utils.NewHttpTLSTransport(tlsConfig),
primary: utils.NewHttpTLSTransport(dialResult.tlsConfig),
fallback: http.DefaultTransport,
}

st := &state{
client: client,
conn: conn,
conn: dialResult.conn,
clock: clock,
addr: apiHost,
cookieURL: &url.URL{
Scheme: "https",
Host: conn.Config().Location.Host,
Host: apiHost,
Path: "/",
},
pingerFacadeVersion: facadeVersions["Pinger"],
serverScheme: "https",
serverRootAddress: conn.Config().Location.Host,
serverRootAddress: apiHost,
// We populate the username and password before
// login because, when doing HTTP requests, we'll want
// to use the same username and password for authenticating
Expand All @@ -232,13 +244,13 @@ func open(
password: info.Password,
macaroons: info.Macaroons,
nonce: info.Nonce,
tlsConfig: tlsConfig,
tlsConfig: dialResult.tlsConfig,
bakeryClient: bakeryClient,
modelTag: info.ModelTag,
}
if !info.SkipLogin {
if err := st.Login(info.Tag, info.Password, info.Nonce, info.Macaroons); err != nil {
conn.Close()
dialResult.conn.Close()
return nil, errors.Trace(err)
}
}
Expand Down Expand Up @@ -367,49 +379,65 @@ func (st *state) connectStream(path string, attrs url.Values, extraHeaders http.
// TODO(macgreagoir) IPv6. Ubuntu still always provides IPv4 loopback,
// and when/if this changes localhost should resolve to IPv6 loopback
// in any case (lp:1644009). Review.
cfg, err := websocket.NewConfig(target.String(), "http://localhost/")
if err != nil {
return nil, errors.Trace(err)

dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: st.tlsConfig,
// In order to deal with the remote side not handling message
// fragmentation, we default to largeish frames.
ReadBufferSize: websocketFrameSize,
WriteBufferSize: websocketFrameSize,
}
var requestHeader http.Header
if st.tag != "" {
cfg.Header = utils.BasicAuthHeader(st.tag, st.password)
requestHeader = utils.BasicAuthHeader(st.tag, st.password)
} else {
requestHeader = make(http.Header)
}
requestHeader.Set("Origin", "http://localhost/")
if st.nonce != "" {
cfg.Header.Set(params.MachineNonceHeader, st.nonce)
requestHeader.Set(params.MachineNonceHeader, st.nonce)
}
// Add any cookies because they will not be sent to websocket
// connections by default.
err = st.addCookiesToHeader(cfg.Header)
err := st.addCookiesToHeader(requestHeader)
if err != nil {
return nil, errors.Trace(err)
}
for header, values := range extraHeaders {
for _, value := range values {
cfg.Header.Add(header, value)
requestHeader.Add(header, value)
}
}

cfg.TlsConfig = st.tlsConfig
connection, err := websocketDialConfig(cfg)
connection, err := websocketDial(dialer, target.String(), requestHeader)
if err != nil {
return nil, err
}
if err := readInitialStreamError(connection); err != nil {
connection.Close()
return nil, errors.Trace(err)
}
return connection, nil
}

// readInitialStreamError reads the initial error response
// from a stream connection and returns it.
func readInitialStreamError(conn io.Reader) error {
func readInitialStreamError(ws base.Stream) error {
// We can use bufio here because the websocket guarantees that a
// single read will not read more than a single frame; there is
// no guarantee that a single read might not read less than the
// whole frame though, so using a single Read call is not
// correct. By using ReadSlice rather than ReadBytes, we
// guarantee that the error can't be too big (>4096 bytes).
line, err := bufio.NewReader(conn).ReadSlice('\n')
messageType, reader, err := ws.NextReader()
if err != nil {
return errors.Annotate(err, "unable to get reader")
}
if messageType != websocket.TextMessage {
return errors.Errorf("unexpected message type %v", messageType)
}
line, err := bufio.NewReader(reader).ReadSlice('\n')
if err != nil {
return errors.Annotate(err, "unable to read initial response")
}
Expand Down Expand Up @@ -496,20 +524,21 @@ func tagToString(tag names.Tag) string {
return tag.String()
}

type dialResult struct {
conn *websocket.Conn
urlStr string
tlsConfig *tls.Config
}

// dialAPI establishes a websocket connection to the RPC
// API websocket on the API server using Info. If multiple API addresses
// are provided in Info they will be tried concurrently - the first successful
// connection wins.
//
// It also returns the TLS configuration that it has derived from the Info.
func dialAPI(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
// Set opts.DialWebsocket here rather than in open because
// some tests call dialAPI directly.
if opts.DialWebsocket == nil {
opts.DialWebsocket = websocket.DialConfig
}
func dialAPI(info *Info, opts DialOpts) (*dialResult, error) {
if len(info.Addrs) == 0 {
return nil, nil, errors.New("no API addresses to connect to")
return nil, errors.New("no API addresses to connect to")
}
tlsConfig := utils.SecureTLSConfig()
tlsConfig.InsecureSkipVerify = opts.InsecureSkipVerify
Expand All @@ -520,7 +549,7 @@ func dialAPI(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
tlsConfig.ServerName = "juju-apiserver"
certPool, err := CreateCertPool(info.CACert)
if err != nil {
return nil, nil, errors.Annotate(err, "cert pool creation failed")
return nil, errors.Annotate(err, "cert pool creation failed")
}
tlsConfig.RootCAs = certPool
} else {
Expand All @@ -529,33 +558,62 @@ func dialAPI(info *Info, opts DialOpts) (*websocket.Conn, *tls.Config, error) {
// name in the address will be used as usual).
tlsConfig.ServerName = info.SNIHostName
}

opts.tlsConfig = tlsConfig

// Set opts.DialWebsocket here rather than in open because
// some tests call dialAPI directly.
if opts.DialWebsocket == nil {
dialer := &websocketDialerAdapter{
&websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: tlsConfig,
// In order to deal with the remote side not handling message
// fragmentation, we default to largeish frames.
ReadBufferSize: websocketFrameSize,
WriteBufferSize: websocketFrameSize,
},
}
opts.DialWebsocket = dialer.Dial
}

path, err := apiPath(info.ModelTag, "/api")
if err != nil {
return nil, nil, errors.Trace(err)
return nil, errors.Trace(err)
}
conn, err := dialWebsocketMulti(info.Addrs, path, tlsConfig, opts)
conn, urlStr, err := dialWebsocketMulti(info.Addrs, path, opts)
if err != nil {
return nil, nil, errors.Trace(err)
return nil, errors.Trace(err)
}
logger.Infof("connection established to %q", conn.RemoteAddr())
return conn, tlsConfig, nil
logger.Infof("connection established to %q", urlStr)
return &dialResult{conn, urlStr, tlsConfig}, nil
}

type websocketDialerAdapter struct {
dialer *websocket.Dialer
}

func (a *websocketDialerAdapter) Dial(urlStr string, tlsConfig *tls.Config, requestHeader http.Header) (*websocket.Conn, *http.Response, error) {
// Ignore the tlsConfig because it is set on the dialer.
// The tls.Config is only passed through for the purpose of catpure in the tests.
return a.dialer.Dial(urlStr, requestHeader)
}

// dialWebsocketMulti dials a websocket with one of the provided addresses, the
// specified URL path, TLS configuration, and dial options. Each of the
// specified addresses will be attempted concurrently, and the first
// successful connection will be returned.
func dialWebsocketMulti(addrs []string, path string, tlsConfig *tls.Config, opts DialOpts) (*websocket.Conn, error) {
func dialWebsocketMulti(addrs []string, path string, opts DialOpts) (*websocket.Conn, string, error) {
// Dial all addresses at reasonable intervals.
try := parallel.NewTry(0, nil)
defer try.Kill()
for _, addr := range addrs {
err := startDialWebsocket(try, addr, path, opts, tlsConfig)
err := startDialWebsocket(try, addr, path, opts)
if err == parallel.ErrStopped {
break
}
if err != nil {
return nil, errors.Trace(err)
return nil, "", errors.Trace(err)
}
select {
case <-time.After(opts.DialAddressInterval):
Expand All @@ -565,30 +623,40 @@ func dialWebsocketMulti(addrs []string, path string, tlsConfig *tls.Config, opts
try.Close()
result, err := try.Result()
if err != nil {
return nil, errors.Trace(err)
return nil, "", errors.Trace(err)
}
return result.(*websocket.Conn), nil
wrapper := result.(*connWrapper)
return wrapper.conn, wrapper.urlStr, nil
}

// startDialWebsocket starts websocket connection to a single address
// on the given try instance.
func startDialWebsocket(try *parallel.Try, addr, path string, opts DialOpts, tlsConfig *tls.Config) error {
func startDialWebsocket(try *parallel.Try, addr, path string, opts DialOpts) error {
// origin is required by the WebSocket API, used for "origin policy"
// in websockets. We pass localhost to satisfy the API; it is
// inconsequential to us.
const origin = "http://localhost/"
cfg, err := websocket.NewConfig("wss://"+addr+path, origin)
if err != nil {
return errors.Trace(err)
}
cfg.TlsConfig = tlsConfig
return try.Start(newWebsocketDialer(cfg, opts))
urlStr := "wss://" + addr + path
return try.Start(newWebsocketDialer(urlStr, opts))
}

// connWrapper contains the *websocket.Conn and the urlStr that was used
// to connect to it. The gorilla/websocket code does not remember the URL
// that was used to connect to it, and many internal parts of Juju assume
// that it does.
type connWrapper struct {
conn *websocket.Conn
urlStr string
}

// This is defined for the parallel try to close other results.
func (c *connWrapper) Close() error {
return c.conn.Close()
}

// newWebsocketDialer0 returns a function that dials the websocket represented
// by the given configuration with the given dial options, suitable for passing
// to utils/parallel.Try.Start.
func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
func newWebsocketDialer(urlStr string, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
// TODO(katco): 2016-08-09: lp:1611427
openAttempt := utils.AttemptStrategy{
Total: opts.Timeout,
Expand All @@ -606,11 +674,12 @@ func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct
return nil, parallel.ErrStopped
default:
}
logger.Debugf("dialing %q", cfg.Location)
conn, err := opts.DialWebsocket(cfg)
logger.Debugf("dialing %q", urlStr)
// Not passing through any extra header information
conn, _, err := opts.DialWebsocket(urlStr, opts.tlsConfig, nil)
if err == nil {
logger.Debugf("successfully dialed %q", cfg.Location)
return conn, nil
logger.Debugf("successfully dialed %q", urlStr)
return &connWrapper{conn, urlStr}, nil
}
if isCertErr := isX509Error(err); !a.HasNext() || isCertErr {
// We won't reconnect when there's an X509
Expand All @@ -631,25 +700,20 @@ func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct
// isX509Error reports whether the given websocket error
// results from an X509 problem.
func isX509Error(err error) bool {
wsErr, ok := errors.Cause(err).(*websocket.DialError)
if !ok {
return false
}
switch wsErr.Err.(type) {
case x509.HostnameError,
switch errType := err.(type) {
case *websocket.CloseError:
return errType.Code == websocket.CloseTLSHandshake
case x509.CertificateInvalidError,
x509.HostnameError,
x509.InsecureAlgorithmError,
x509.UnhandledCriticalExtension,
x509.UnknownAuthorityError,
x509.ConstraintViolationError,
x509.SystemRootsError:
return true
default:
return false
}
switch err {
case x509.ErrUnsupportedAlgorithm,
x509.IncorrectPasswordError:
return true
}
return false
}

type hasErrorCode interface {
Expand Down
Loading

0 comments on commit 043d6b7

Please sign in to comment.