-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathmiddleware_content_length.go
84 lines (69 loc) · 2.64 KB
/
middleware_content_length.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package http
import (
"context"
"fmt"
"github.com/aws/smithy-go/middleware"
)
// ComputeContentLength provides a middleware to set the content-length
// header for the length of a serialize request body.
type ComputeContentLength struct {
}
// AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
// stack's Build step.
func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
return stack.Build.Add(&ComputeContentLength{}, middleware.After)
}
// ID returns the identifier for the ComputeContentLength.
func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
// HandleBuild adds the length of the serialized request to the HTTP header
// if the length can be determined.
func (m *ComputeContentLength) HandleBuild(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
// do nothing if request content-length was set to 0 or above.
if req.ContentLength >= 0 {
return next.HandleBuild(ctx, in)
}
// attempt to compute stream length
if n, ok, err := req.StreamLength(); err != nil {
return out, metadata, fmt.Errorf(
"failed getting length of request stream, %w", err)
} else if ok {
req.ContentLength = n
}
return next.HandleBuild(ctx, in)
}
// validateContentLength provides a middleware to validate the content-length
// is valid (greater than zero), for the serialized request payload.
type validateContentLength struct{}
// ValidateContentLengthHeader adds middleware that validates request content-length
// is set to value greater than zero.
func ValidateContentLengthHeader(stack *middleware.Stack) error {
return stack.Build.Add(&validateContentLength{}, middleware.After)
}
// ID returns the identifier for the ComputeContentLength.
func (m *validateContentLength) ID() string { return "ValidateContentLength" }
// HandleBuild adds the length of the serialized request to the HTTP header
// if the length can be determined.
func (m *validateContentLength) HandleBuild(
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
) (
out middleware.BuildOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*Request)
if !ok {
return out, metadata, fmt.Errorf("unknown request type %T", req)
}
// if request content-length was set to less than 0, return an error
if req.ContentLength < 0 {
return out, metadata, fmt.Errorf(
"content length for payload is required and must be at least 0")
}
return next.HandleBuild(ctx, in)
}