Skip to content

Commit

Permalink
Add endpoint host validation
Browse files Browse the repository at this point in the history
  • Loading branch information
JordonPhillips committed Sep 23, 2020
1 parent 1fec7b6 commit b1bda39
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 1 deletion.
7 changes: 6 additions & 1 deletion transport/http/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ func (c ClientHandler) Handle(ctx context.Context, input interface{}) (
return nil, metadata, fmt.Errorf("expect Smithy http.Request value as input, got unsupported type %T", input)
}

resp, err := c.client.Do(req.Build(ctx))
builtRequest := req.Build(ctx)
if err := ValidateEndpointHost(builtRequest.Host); err != nil {
return nil, metadata, err
}

resp, err := c.client.Do(builtRequest)
if err != nil {
err = &RequestSendError{Err: err}

Expand Down
53 changes: 53 additions & 0 deletions transport/http/host.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package http

import (
"fmt"
"strings"
)

// ValidateEndpointHost validates that the host string passed in is a valid RFC
// 3986 host. Returns error if the host is not valid.
func ValidateEndpointHost(host string) error {
var errors strings.Builder
labels := strings.Split(host, ".")

for i, label := range labels {
if i == len(labels)-1 && len(label) == 0 {
// Allow trailing dot for FQDN hosts.
continue
}

if !ValidHostLabel(label) {
errors.WriteString("\nendpoint host domain labels must match \"[a-zA-Z0-9-]{1,63}\", but found: ")
errors.WriteString(label)
}
}

if len(host) > 255 {
errors.WriteString(fmt.Sprintf("\nendpoint host must be less than 255 characters, but was %d", len(host)))
}

if len(errors.String()) > 0 {
return fmt.Errorf("invalid endpoint host%s", errors.String())
}
return nil
}

// ValidHostLabel returns if the label is a valid RFC 3986 host label.
func ValidHostLabel(label string) bool {
if l := len(label); l == 0 || l > 63 {
return false
}
for _, r := range label {
switch {
case r >= '0' && r <= '9':
case r >= 'A' && r <= 'Z':
case r >= 'a' && r <= 'z':
case r == '-':
default:
return false
}
}

return true
}
61 changes: 61 additions & 0 deletions transport/http/host_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package http

import (
"strconv"
"testing"
)

func TestValidHostLabel(t *testing.T) {
cases := []struct {
Input string
Valid bool
}{
{Input: "abc123", Valid: true},
{Input: "123", Valid: true},
{Input: "abc", Valid: true},
{Input: "123-abc", Valid: true},
{Input: "{thing}-abc", Valid: false},
{Input: "abc.123", Valid: false},
{Input: "abc/123", Valid: false},
{Input: "012345678901234567890123456789012345678901234567890123456789123", Valid: true},
{Input: "0123456789012345678901234567890123456789012345678901234567891234", Valid: false},
{Input: "", Valid: false},
}

for i, c := range cases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
valid := ValidHostLabel(c.Input)
if e, a := c.Valid, valid; e != a {
t.Errorf("expect valid %v, got %v", e, a)
}
})
}
}

func TestValidateEndpointHostHandler(t *testing.T) {
cases := map[string]struct {
Input string
Valid bool
}{
"valid host": {Input: "abc.123", Valid: true},
"fqdn host": {Input: "abc.123.", Valid: true},
"empty label": {Input: "abc..", Valid: false},
"max host len": {
Input: "123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.12345",
Valid: true,
},
"too long host": {
Input: "123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456789.123456",
Valid: false,
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
err := ValidateEndpointHost(c.Input)
if e, a := c.Valid, err == nil; e != a {
t.Errorf("expect valid %v, got %v, %v", e, a, err)
}
})
}
}

0 comments on commit b1bda39

Please sign in to comment.