Skip to content

Commit

Permalink
Check ctx.Done() during BootstrapInstance and WaitSSH
Browse files Browse the repository at this point in the history
  • Loading branch information
benhoyt committed Dec 17, 2020
1 parent c0ded3e commit d566acd
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 85 deletions.
43 changes: 26 additions & 17 deletions provider/common/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package common

import (
"bufio"
"context"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -34,7 +35,7 @@ import (
"github.com/juju/juju/core/status"
"github.com/juju/juju/environs"
"github.com/juju/juju/environs/config"
"github.com/juju/juju/environs/context"
envcontext "github.com/juju/juju/environs/context"
"github.com/juju/juju/environs/imagemetadata"
"github.com/juju/juju/environs/instances"
"github.com/juju/juju/environs/simplestreams"
Expand All @@ -49,7 +50,7 @@ var logger = loggo.GetLogger("juju.provider.common")
func Bootstrap(
ctx environs.BootstrapContext,
env environs.Environ,
callCtx context.ProviderCallContext,
callCtx envcontext.ProviderCallContext,
args environs.BootstrapParams,
) (*environs.BootstrapResult, error) {
result, series, finalizer, err := BootstrapInstance(ctx, env, callCtx, args)
Expand All @@ -75,7 +76,7 @@ func Bootstrap(
func BootstrapInstance(
ctx environs.BootstrapContext,
env environs.Environ,
callCtx context.ProviderCallContext,
callCtx envcontext.ProviderCallContext,
args environs.BootstrapParams,
) (_ *environs.StartInstanceResult, selectedSeries string, _ environs.CloudBootstrapFinalizer, err error) {
// TODO make safe in the case of racing Bootstraps
Expand Down Expand Up @@ -250,6 +251,13 @@ func BootstrapInstance(
if zone == "" || environs.IsAvailabilityZoneIndependent(err) {
return nil, "", nil, errors.Annotate(err, "cannot start bootstrap instance")
}

select {
case <-ctx.Context().Done():
return nil, "", nil, errors.Annotatef(err, "starting controller (cancelled)")
default:
}

if i < len(zones)-1 {
// Try the next zone.
logger.Debugf("failed to start instance in availability zone %q: %s", zone, err)
Expand Down Expand Up @@ -294,7 +302,7 @@ func BootstrapInstance(
return result, selectedSeries, finalizer, nil
}

func startInstanceZones(env environs.Environ, ctx context.ProviderCallContext, args environs.StartInstanceParams) ([]string, error) {
func startInstanceZones(env environs.Environ, ctx envcontext.ProviderCallContext, args environs.StartInstanceParams) ([]string, error) {
zonedEnviron, ok := env.(ZonedEnviron)
if !ok {
return nil, errors.NotImplementedf("ZonedEnviron")
Expand Down Expand Up @@ -358,7 +366,7 @@ var FinishBootstrap = func(
ctx environs.BootstrapContext,
client ssh.Client,
env environs.Environ,
callCtx context.ProviderCallContext,
callCtx envcontext.ProviderCallContext,
inst instances.Instance,
instanceConfig *instancecfg.InstanceConfig,
opts environs.BootstrapDialOpts,
Expand All @@ -369,8 +377,8 @@ var FinishBootstrap = func(

hostSSHOptions := bootstrapSSHOptionsFunc(instanceConfig)
addr, err := WaitSSH(
ctx.Context(),
ctx.GetStderr(),
interrupted,
client,
GetCheckNonceCommand(instanceConfig),
&RefreshableInstance{inst, env},
Expand Down Expand Up @@ -452,6 +460,7 @@ func ConfigureMachine(
}
script := shell.DumpFileOnErrorScript(instanceConfig.CloudInitOutputLog) + configScript
ctx.Infof("Running machine configuration script...")
// TODO(benhoyt) - plumb context through juju/utils/ssh?
return sshinit.RunConfigureScript(script, sshinit.ConfigureParams{
Host: "ubuntu@" + host,
Client: client,
Expand Down Expand Up @@ -534,16 +543,16 @@ func hostBootstrapSSHOptions(
// for waiting for SSH access to become available.
type InstanceRefresher interface {
// Refresh refreshes the addresses for the instance.
Refresh(ctx context.ProviderCallContext) error
Refresh(ctx envcontext.ProviderCallContext) error

// Addresses returns the addresses for the instance.
// To ensure that the results are up to date, call
// Refresh first.
Addresses(ctx context.ProviderCallContext) (network.ProviderAddresses, error)
Addresses(ctx envcontext.ProviderCallContext) (network.ProviderAddresses, error)

// Status returns the provider-specific status for the
// instance.
Status(ctx context.ProviderCallContext) instance.Status
Status(ctx envcontext.ProviderCallContext) instance.Status
}

type RefreshableInstance struct {
Expand All @@ -552,7 +561,7 @@ type RefreshableInstance struct {
}

// Refresh refreshes the addresses for the instance.
func (i *RefreshableInstance) Refresh(ctx context.ProviderCallContext) error {
func (i *RefreshableInstance) Refresh(ctx envcontext.ProviderCallContext) error {
instances, err := i.Env.Instances(ctx, []instance.Id{i.Id()})
if err != nil {
return errors.Trace(err)
Expand Down Expand Up @@ -702,12 +711,12 @@ var connectSSH = func(client ssh.Client, host, checkHostScript string, options *
// machine's nonce. The "checkHostScript" is a bash script
// that performs this file check.
func WaitSSH(
ctx context.Context,
stdErr io.Writer,
interrupted <-chan os.Signal,
client ssh.Client,
checkHostScript string,
inst InstanceRefresher,
ctx context.ProviderCallContext,
callCtx envcontext.ProviderCallContext,
opts environs.BootstrapDialOpts,
hostSSHOptions HostSSHOptionsFunc,
) (addr string, err error) {
Expand All @@ -734,17 +743,17 @@ func WaitSSH(
select {
case <-pollAddresses.C:
pollAddresses.Reset(opts.AddressesDelay)
if err := inst.Refresh(ctx); err != nil {
if err := inst.Refresh(callCtx); err != nil {
return "", fmt.Errorf("refreshing addresses: %v", err)
}
instanceStatus := inst.Status(ctx)
instanceStatus := inst.Status(callCtx)
if instanceStatus.Status == status.ProvisioningError {
if instanceStatus.Message != "" {
return "", errors.Errorf("instance provisioning failed (%v)", instanceStatus.Message)
}
return "", errors.Errorf("instance provisioning failed")
}
addresses, err := inst.Addresses(ctx)
addresses, err := inst.Addresses(callCtx)
if err != nil {
return "", fmt.Errorf("getting addresses: %v", err)
}
Expand All @@ -764,8 +773,8 @@ func WaitSSH(
args = append(args, lastErr)
}
return "", fmt.Errorf(format, args...)
case <-interrupted:
return "", fmt.Errorf("interrupted")
case <-ctx.Done():
return "", fmt.Errorf("cancelled")
case <-checker.Dead():
result, err := checker.Result()
if err != nil {
Expand Down
57 changes: 28 additions & 29 deletions provider/common/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"crypto/rsa"
"fmt"
"io/ioutil"
"os"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -719,34 +718,34 @@ var testSSHTimeout = environs.BootstrapDialOpts{
func (s *BootstrapSuite) TestWaitSSHTimesOutWaitingForAddresses(c *gc.C) {
ctx := cmdtesting.Context(c)
_, err := common.WaitSSH(
ctx.Stderr, nil, ssh.DefaultClient, "/bin/true", neverAddresses{}, s.callCtx, testSSHTimeout,
context.Background(), ctx.Stderr, ssh.DefaultClient, "/bin/true", neverAddresses{}, s.callCtx, testSSHTimeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches, `waited for `+testSSHTimeout.Timeout.String()+` without getting any addresses`)
c.Check(cmdtesting.Stderr(ctx), gc.Matches, "Waiting for address\n")
}

func (s *BootstrapSuite) TestWaitSSHKilledWaitingForAddresses(c *gc.C) {
ctx := cmdtesting.Context(c)
interrupted := make(chan os.Signal)
close(interrupted)
cmdCtx := cmdtesting.Context(c)
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := common.WaitSSH(
ctx.Stderr, interrupted, ssh.DefaultClient, "/bin/true", neverAddresses{}, s.callCtx, testSSHTimeout,
ctx, cmdCtx.Stderr, ssh.DefaultClient, "/bin/true", neverAddresses{}, s.callCtx, testSSHTimeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches, "interrupted")
c.Check(cmdtesting.Stderr(ctx), gc.Matches, "Waiting for address\n")
c.Check(err, gc.ErrorMatches, "cancelled")
c.Check(cmdtesting.Stderr(cmdCtx), gc.Matches, "Waiting for address\n")
}

func (s *BootstrapSuite) TestWaitSSHNoticesProvisioningFailures(c *gc.C) {
ctx := cmdtesting.Context(c)
_, err := common.WaitSSH(
ctx.Stderr, nil, ssh.DefaultClient, "/bin/true", failsProvisioning{}, s.callCtx, testSSHTimeout,
context.Background(), ctx.Stderr, ssh.DefaultClient, "/bin/true", failsProvisioning{}, s.callCtx, testSSHTimeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches, `instance provisioning failed`)
_, err = common.WaitSSH(
ctx.Stderr, nil, ssh.DefaultClient, "/bin/true", failsProvisioning{message: "blargh"}, s.callCtx, testSSHTimeout,
context.Background(), ctx.Stderr, ssh.DefaultClient, "/bin/true", failsProvisioning{message: "blargh"}, s.callCtx, testSSHTimeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches, `instance provisioning failed \(blargh\)`)
Expand All @@ -763,7 +762,7 @@ func (brokenAddresses) Addresses(ctx envcontext.ProviderCallContext) (corenetwor
func (s *BootstrapSuite) TestWaitSSHStopsOnBadError(c *gc.C) {
ctx := cmdtesting.Context(c)
_, err := common.WaitSSH(
ctx.Stderr, nil, ssh.DefaultClient, "/bin/true", brokenAddresses{}, s.callCtx, testSSHTimeout,
context.Background(), ctx.Stderr, ssh.DefaultClient, "/bin/true", brokenAddresses{}, s.callCtx, testSSHTimeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches, "getting addresses: Addresses will never work")
Expand All @@ -783,7 +782,7 @@ func (s *BootstrapSuite) TestWaitSSHTimesOutWaitingForDial(c *gc.C) {
ctx := cmdtesting.Context(c)
// 0.x.y.z addresses are always invalid
_, err := common.WaitSSH(
ctx.Stderr, nil, ssh.DefaultClient, "/bin/true", &neverOpensPort{addr: "0.1.2.3"}, s.callCtx, testSSHTimeout,
context.Background(), ctx.Stderr, ssh.DefaultClient, "/bin/true", &neverOpensPort{addr: "0.1.2.3"}, s.callCtx, testSSHTimeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches,
Expand All @@ -793,38 +792,38 @@ func (s *BootstrapSuite) TestWaitSSHTimesOutWaitingForDial(c *gc.C) {
"(Attempting to connect to 0.1.2.3:22\n)+")
}

type interruptOnDial struct {
type cancelOnDial struct {
neverRefreshes
name string
interrupted chan os.Signal
returned bool
name string
cancel context.CancelFunc
returned bool
}

func (i *interruptOnDial) Addresses(ctx envcontext.ProviderCallContext) (corenetwork.ProviderAddresses, error) {
func (c *cancelOnDial) Addresses(ctx envcontext.ProviderCallContext) (corenetwork.ProviderAddresses, error) {
// signal cancellation the second time Addresses is called
if !i.returned {
i.returned = true
if !c.returned {
c.returned = true
} else {
if i.interrupted != nil {
close(i.interrupted)
i.interrupted = nil
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
}
return corenetwork.NewProviderAddresses(i.name), nil
return corenetwork.NewProviderAddresses(c.name), nil
}

func (s *BootstrapSuite) TestWaitSSHKilledWaitingForDial(c *gc.C) {
ctx := cmdtesting.Context(c)
cmdCtx := cmdtesting.Context(c)
timeout := testSSHTimeout
timeout.Timeout = 1 * time.Minute
interrupted := make(chan os.Signal)
ctx, cancel := context.WithCancel(context.Background())
_, err := common.WaitSSH(
ctx.Stderr, interrupted, ssh.DefaultClient, "", &interruptOnDial{name: "0.1.2.3", interrupted: interrupted}, s.callCtx, timeout,
ctx, cmdCtx.Stderr, ssh.DefaultClient, "", &cancelOnDial{name: "0.1.2.3", cancel: cancel}, s.callCtx, timeout,
common.DefaultHostSSHOptions,
)
c.Check(err, gc.ErrorMatches, "interrupted")
c.Check(err, gc.ErrorMatches, "cancelled")
// Exact timing is imprecise but it should have tried a few times before being killed
c.Check(cmdtesting.Stderr(ctx), gc.Matches,
c.Check(cmdtesting.Stderr(cmdCtx), gc.Matches,
"Waiting for address\n"+
"(Attempting to connect to 0.1.2.3:22\n)+")
}
Expand All @@ -850,7 +849,7 @@ func (ac *addressesChange) Addresses(ctx envcontext.ProviderCallContext) (corene

func (s *BootstrapSuite) TestWaitSSHRefreshAddresses(c *gc.C) {
ctx := cmdtesting.Context(c)
_, err := common.WaitSSH(ctx.Stderr, nil, ssh.DefaultClient, "", &addressesChange{addrs: [][]string{
_, err := common.WaitSSH(context.Background(), ctx.Stderr, ssh.DefaultClient, "", &addressesChange{addrs: [][]string{
nil,
nil,
{"0.1.2.3"},
Expand Down
11 changes: 5 additions & 6 deletions provider/rackspace/environ.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
package rackspace

import (
"context"
"io/ioutil"
"os"
"time"

"github.com/juju/errors"
Expand All @@ -15,7 +15,7 @@ import (

"github.com/juju/juju/environs"
"github.com/juju/juju/environs/config"
"github.com/juju/juju/environs/context"
envcontext "github.com/juju/juju/environs/context"
"github.com/juju/juju/provider/common"
)

Expand All @@ -26,15 +26,15 @@ type environ struct {
var bootstrap = common.Bootstrap

// Bootstrap implements environs.Environ.
func (e environ) Bootstrap(ctx environs.BootstrapContext, callCtx context.ProviderCallContext, params environs.BootstrapParams) (*environs.BootstrapResult, error) {
func (e environ) Bootstrap(ctx environs.BootstrapContext, callCtx envcontext.ProviderCallContext, params environs.BootstrapParams) (*environs.BootstrapResult, error) {
// Can't delegate to the openstack provider as usual, because the correct environ must be passed to common.Bootstrap.
return bootstrap(ctx, e, callCtx, params)
}

var waitSSH = common.WaitSSH

// StartInstance implements environs.Environ.
func (e environ) StartInstance(ctx context.ProviderCallContext, args environs.StartInstanceParams) (*environs.StartInstanceResult, error) {
func (e environ) StartInstance(ctx envcontext.ProviderCallContext, args environs.StartInstanceParams) (*environs.StartInstanceResult, error) {
osString, err := series.GetOSFromSeries(args.Tools.OneSeries())
if err != nil {
return nil, errors.Trace(err)
Expand All @@ -49,15 +49,14 @@ func (e environ) StartInstance(ctx context.ProviderCallContext, args environs.St
return nil, errors.Trace(err)
}
if fwmode != config.FwNone {
interrupted := make(chan os.Signal, 1)
timeout := environs.BootstrapDialOpts{
Timeout: time.Minute * 5,
RetryDelay: time.Second * 5,
AddressesDelay: time.Second * 20,
}
addr, err := waitSSH(
context.Background(),
ioutil.Discard,
interrupted,
ssh.DefaultClient,
common.GetCheckNonceCommand(args.InstanceConfig),
&common.RefreshableInstance{r.Instance, e},
Expand Down
Loading

0 comments on commit d566acd

Please sign in to comment.