Skip to content

Commit

Permalink
downloader: Big clean up
Browse files Browse the repository at this point in the history
The Download and Downloader were unnecessarily complicated with more
than one way to do things and confused aborting of downloads.

- Check req.Verify closer to where it's used to minimise the possibility
  of mistakes.
- Fixed/removed racy tests for stopping downloads. This also shaved 10s
  off the package test runtime.
- Removed the use of both a tomb and a channel to abort
  downloads. Downloads are now only aborted via a channel passed in at
  download start time.
- The downloaded file is now removed if verification fails.
- The download filename is passed around instead of using a *os.File
  which needs to be re-seeked back to the start of file all the time.
- Removed most Downloader tests, which were actually testing functionality
  that is already being tested on Download.
  • Loading branch information
Menno Smits committed Sep 26, 2016
1 parent 477fe7a commit 4a88d70
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 219 deletions.
161 changes: 65 additions & 96 deletions downloader/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/juju/errors"
"github.com/juju/utils"
"gopkg.in/tomb.v1"
)

// Request holds a single download request.
Expand All @@ -27,125 +26,88 @@ type Request struct {
// the download is invalid then the func must return errors.NotValid.
// If no func is provided then no verification happens.
Verify func(*os.File) error

// Abort is a channel that will cancel the download when it is closed.
Abort <-chan struct{}
}

// Status represents the status of a completed download.
type Status struct {
// File holds the downloaded data on success.
File *os.File
// Filename is the name of the file which holds the downloaded
// data on success.
Filename string

// Err describes any error encountered while downloading.
Err error
}

// Download can download a file from the network.
type Download struct {
tomb tomb.Tomb
done chan Status
openBlob func(*url.URL) (io.ReadCloser, error)
}

// StartDownload returns a new Download instance based on the provided
// request. openBlob is used to gain access to the blob, whether through
// an HTTP request or some other means.
// StartDownload starts a new download as specified by `req` using
// `openBlob` to actually pull the remote data.
func StartDownload(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
dl := newDownload(openBlob)
go dl.run(req)
return dl
}

func newDownload(openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
if openBlob == nil {
openBlob = NewHTTPBlobOpener(utils.NoVerifySSLHostnames)
}
return &Download{
done: make(chan Status),
dl := &Download{
done: make(chan Status, 1),
openBlob: openBlob,
}
go dl.run(req)
return dl
}

// Stop stops any download that's in progress.
func (dl *Download) Stop() {
dl.tomb.Kill(nil)
dl.tomb.Wait()
// Download can download a file from the network.
type Download struct {
done chan Status
openBlob func(*url.URL) (io.ReadCloser, error)
}

// Done returns a channel that receives a status when the download has
// completed. It is the receiver's responsibility to close and remove
// the received file.
// completed or is aborted. Exactly one Status value will be sent for
// each download once it finishes (successfully or otherwise) or is
// aborted.
//
// It is the receiver's responsibility to handle and remove the
// downloaded file.
func (dl *Download) Done() <-chan Status {
return dl.done
}

// Wait blocks until the download completes or the abort channel receives.
func (dl *Download) Wait(abort <-chan struct{}) (*os.File, error) {
defer dl.Stop()

select {
case <-abort:
logger.Infof("download aborted")
return nil, errors.New("aborted")
case status := <-dl.Done():
if status.Err != nil {
if status.File != nil {
if err := status.File.Close(); err != nil {
logger.Errorf("failed to close file: %v", err)
}
}
return nil, errors.Trace(status.Err)
}
return status.File, nil
}
// Wait blocks until the download finishes (successfully or
// otherwise), or the download is aborted. There will only be a
// filename if err is nil.
func (dl *Download) Wait() (string, error) {
// No select is required here because each download will always
// send a value once it completes. Downloads can be aborted via
// the Abort channel provided at creation time.
status := <-dl.Done()
return status.Filename, errors.Trace(status.Err)
}

func (dl *Download) run(req Request) {
defer dl.tomb.Done()

// TODO(dimitern) 2013-10-03 bug #1234715
// Add a testing HTTPS storage to verify the
// disableSSLHostnameVerification behavior here.
file, err := download(req, dl.openBlob)
filename, err := dl.download(req)
if err != nil {
err = errors.Annotatef(err, "cannot download %q", req.URL)
}

if err == nil {
} else {
logger.Infof("download complete (%q)", req.URL)
if req.Verify != nil {
err = verifyDownload(file, req)
err = verifyDownload(filename, req)
if err != nil {
os.Remove(filename)
filename = ""
}
}

status := Status{
File: file,
Err: err,
}
select {
case dl.done <- status:
// no-op
case <-dl.tomb.Dying():
cleanTempFile(file)
// No select needed here because the channel has a size of 1 and
// will only be written to once.
dl.done <- Status{
Filename: filename,
Err: err,
}
}

func verifyDownload(file *os.File, req Request) error {
err := req.Verify(file)
if err != nil {
if errors.IsNotValid(err) {
logger.Errorf("download of %s invalid: %v", req.URL, err)
}
return errors.Trace(err)
}
logger.Infof("download verified (%q)", req.URL)

if _, err := file.Seek(0, os.SEEK_SET); err != nil {
logger.Errorf("failed to seek to beginning of file: %v", err)
return errors.Trace(err)
}
return nil
}

func download(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) (file *os.File, err error) {
func (dl *Download) download(req Request) (filename string, err error) {
logger.Infof("downloading from %s", req.URL)

dir := req.TargetDir
Expand All @@ -154,37 +116,44 @@ func download(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) (file
}
tempFile, err := ioutil.TempFile(dir, "inprogress-")
if err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
defer func() {
tempFile.Close()
if err != nil {
cleanTempFile(tempFile)
os.Remove(tempFile.Name())
}
}()

reader, err := openBlob(req.URL)
reader, err := dl.openBlob(req.URL)
if err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
defer reader.Close()

// XXX this should honor the Abort channel
_, err = io.Copy(tempFile, reader)
if err != nil {
return nil, errors.Trace(err)
}
if _, err := tempFile.Seek(0, 0); err != nil {
return nil, errors.Trace(err)
return "", errors.Trace(err)
}
return tempFile, nil

return tempFile.Name(), nil
}

func cleanTempFile(f *os.File) {
if f == nil {
return
func verifyDownload(filename string, req Request) error {
if req.Verify == nil {
return nil
}

file, err := os.Open(filename)
if err != nil {
return errors.Annotate(err, "opening for verify")
}
defer file.Close()

f.Close()
if err := os.Remove(f.Name()); err != nil {
logger.Errorf("cannot remove temp file %q: %v", f.Name(), err)
if err := req.Verify(file); err != nil {
return errors.Trace(err)
}
logger.Infof("download verified (%q)", req.URL)
return nil
}
66 changes: 23 additions & 43 deletions downloader/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"net/url"
"os"
"path/filepath"
"time"

"github.com/juju/errors"
gitjujutesting "github.com/juju/testing"
Expand Down Expand Up @@ -65,13 +64,11 @@ func (s *DownloadSuite) testDownload(c *gc.C, hostnameVerification utils.SSLHost
downloader.NewHTTPBlobOpener(hostnameVerification),
)
status := <-d.Done()
defer status.File.Close()
c.Assert(status.Err, gc.IsNil)
c.Assert(status.File, gc.NotNil)

dir, _ := filepath.Split(status.File.Name())
dir, _ := filepath.Split(status.Filename)
c.Assert(filepath.Clean(dir), gc.Equals, tmp)
assertFileContents(c, status.File, "archive")
assertFileContents(c, status.Filename, "archive")
}

func (s *DownloadSuite) TestDownloadWithoutDisablingSSLHostnameVerification(c *gc.C) {
Expand All @@ -84,36 +81,18 @@ func (s *DownloadSuite) TestDownloadWithDisablingSSLHostnameVerification(c *gc.C

func (s *DownloadSuite) TestDownloadError(c *gc.C) {
gitjujutesting.Server.Response(404, nil, nil)
d := downloader.StartDownload(
downloader.Request{
URL: s.URL(c, "/archive.tgz"),
TargetDir: c.MkDir(),
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
status := <-d.Done()
c.Assert(status.File, gc.IsNil)
c.Assert(status.Err, gc.ErrorMatches, `cannot download ".*": bad http response: 404 Not Found`)
}

func (s *DownloadSuite) TestStop(c *gc.C) {
tmp := c.MkDir()
d := downloader.StartDownload(
downloader.Request{
URL: s.URL(c, "/x.tgz"),
URL: s.URL(c, "/archive.tgz"),
TargetDir: tmp,
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
d.Stop()
select {
case status := <-d.Done():
c.Fatalf("received status %#v after stop", status)
case <-time.After(testing.ShortWait):
}
infos, err := ioutil.ReadDir(tmp)
c.Assert(err, jc.ErrorIsNil)
c.Assert(infos, gc.HasLen, 0)
filename, err := d.Wait()
c.Assert(filename, gc.Equals, "")
c.Assert(err, gc.ErrorMatches, `cannot download ".*": bad http response: 404 Not Found`)
checkDirEmpty(c, tmp)
}

func (s *DownloadSuite) TestVerifyValid(c *gc.C) {
Expand All @@ -131,11 +110,10 @@ func (s *DownloadSuite) TestVerifyValid(c *gc.C) {
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
status := <-dl.Done()
c.Assert(status.Err, jc.ErrorIsNil)

filename, err := dl.Wait()
c.Assert(err, jc.ErrorIsNil)
c.Check(filename, gc.Not(gc.Equals), "")
stub.CheckCallNames(c, "Verify")
stub.CheckCall(c, 0, "Verify", status.File)
}

func (s *DownloadSuite) TestVerifyInvalid(c *gc.C) {
Expand All @@ -154,19 +132,21 @@ func (s *DownloadSuite) TestVerifyInvalid(c *gc.C) {
},
downloader.NewHTTPBlobOpener(utils.VerifySSLHostnames),
)
status := <-dl.Done()

c.Check(errors.Cause(status.Err), gc.Equals, invalid)
filename, err := dl.Wait()
c.Check(filename, gc.Equals, "")
c.Check(errors.Cause(err), gc.Equals, invalid)
stub.CheckCallNames(c, "Verify")
stub.CheckCall(c, 0, "Verify", status.File)
checkDirEmpty(c, tmp)
}

// assertFileContents reads the file named by filename, asserts that the
// read succeeded, and then checks that the file's contents are equal to
// expect.
func assertFileContents(c *gc.C, filename, expect string) {
	data, readErr := ioutil.ReadFile(filename)
	// A failed read aborts the test here; there is nothing to compare.
	c.Assert(readErr, jc.ErrorIsNil)

	contents := string(data)
	c.Check(contents, gc.Equals, expect)
}

func assertFileContents(c *gc.C, f *os.File, expect string) {
got, err := ioutil.ReadAll(f)
func checkDirEmpty(c *gc.C, dir string) {
files, err := ioutil.ReadDir(dir)
c.Assert(err, jc.ErrorIsNil)
if !c.Check(string(got), gc.Equals, expect) {
info, err := f.Stat()
c.Assert(err, jc.ErrorIsNil)
c.Logf("info %#v", info)
}
c.Check(files, gc.HasLen, 0)
}
Loading

0 comments on commit 4a88d70

Please sign in to comment.