223 lines
5.9 KiB
Go
223 lines
5.9 KiB
Go
|
package v4a
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||
|
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||
|
"github.com/aws/smithy-go/logging"
|
||
|
"github.com/aws/smithy-go/middleware"
|
||
|
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||
|
)
|
||
|
|
||
|
type httpPresignerFunc func(
|
||
|
ctx context.Context, credentials Credentials, r *http.Request,
|
||
|
payloadHash string, service string, regionSet []string, signingTime time.Time,
|
||
|
optFns ...func(*SignerOptions),
|
||
|
) (url string, signedHeader http.Header, err error)
|
||
|
|
||
|
func (f httpPresignerFunc) PresignHTTP(
|
||
|
ctx context.Context, credentials Credentials, r *http.Request,
|
||
|
payloadHash string, service string, regionSet []string, signingTime time.Time,
|
||
|
optFns ...func(*SignerOptions),
|
||
|
) (
|
||
|
url string, signedHeader http.Header, err error,
|
||
|
) {
|
||
|
return f(ctx, credentials, r, payloadHash, service, regionSet, signingTime, optFns...)
|
||
|
}
|
||
|
|
||
|
func TestPresignHTTPRequestMiddleware(t *testing.T) {
|
||
|
cases := map[string]struct {
|
||
|
Request *http.Request
|
||
|
Creds CredentialsProvider
|
||
|
PayloadHash string
|
||
|
LogSigning bool
|
||
|
ExpectResult *v4.PresignedHTTPRequest
|
||
|
ExpectErr string
|
||
|
}{
|
||
|
"success": {
|
||
|
Request: &http.Request{
|
||
|
URL: func() *url.URL {
|
||
|
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||
|
return u
|
||
|
}(),
|
||
|
Header: http.Header{},
|
||
|
},
|
||
|
Creds: stubCredentials,
|
||
|
PayloadHash: "0123456789abcdef",
|
||
|
ExpectResult: &v4.PresignedHTTPRequest{
|
||
|
URL: "https://example.aws/path?query=foo",
|
||
|
SignedHeader: http.Header{},
|
||
|
},
|
||
|
},
|
||
|
"error": {
|
||
|
Request: func() *http.Request {
|
||
|
return &http.Request{}
|
||
|
}(),
|
||
|
Creds: stubCredentials,
|
||
|
PayloadHash: "",
|
||
|
ExpectErr: "failed to sign request",
|
||
|
},
|
||
|
"anonymous creds": {
|
||
|
Request: &http.Request{
|
||
|
URL: func() *url.URL {
|
||
|
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||
|
return u
|
||
|
}(),
|
||
|
Header: http.Header{},
|
||
|
},
|
||
|
Creds: stubCredentials,
|
||
|
PayloadHash: "",
|
||
|
ExpectErr: "failed to sign request",
|
||
|
ExpectResult: &v4.PresignedHTTPRequest{
|
||
|
URL: "https://example.aws/path?query=foo",
|
||
|
SignedHeader: http.Header{},
|
||
|
},
|
||
|
},
|
||
|
"nil creds": {
|
||
|
Request: &http.Request{
|
||
|
URL: func() *url.URL {
|
||
|
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||
|
return u
|
||
|
}(),
|
||
|
Header: http.Header{},
|
||
|
},
|
||
|
Creds: nil,
|
||
|
ExpectResult: &v4.PresignedHTTPRequest{
|
||
|
URL: "https://example.aws/path?query=foo",
|
||
|
SignedHeader: http.Header{},
|
||
|
},
|
||
|
},
|
||
|
"with log signing": {
|
||
|
Request: &http.Request{
|
||
|
URL: func() *url.URL {
|
||
|
u, _ := url.Parse("https://example.aws/path?query=foo")
|
||
|
return u
|
||
|
}(),
|
||
|
Header: http.Header{},
|
||
|
},
|
||
|
Creds: stubCredentials,
|
||
|
PayloadHash: "0123456789abcdef",
|
||
|
ExpectResult: &v4.PresignedHTTPRequest{
|
||
|
URL: "https://example.aws/path?query=foo",
|
||
|
SignedHeader: http.Header{},
|
||
|
},
|
||
|
|
||
|
LogSigning: true,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
const (
|
||
|
signingName = "serviceId"
|
||
|
signingRegion = "regionName"
|
||
|
)
|
||
|
|
||
|
for name, tt := range cases {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
m := &PresignHTTPRequestMiddleware{
|
||
|
credentialsProvider: tt.Creds,
|
||
|
|
||
|
presigner: httpPresignerFunc(func(
|
||
|
ctx context.Context, credentials Credentials, r *http.Request,
|
||
|
payloadHash string, service string, regionSet []string, signingTime time.Time,
|
||
|
optFns ...func(*SignerOptions),
|
||
|
) (url string, signedHeader http.Header, err error) {
|
||
|
var options SignerOptions
|
||
|
for _, fn := range optFns {
|
||
|
fn(&options)
|
||
|
}
|
||
|
if options.Logger == nil {
|
||
|
t.Errorf("expect logger, got none")
|
||
|
}
|
||
|
if options.LogSigning {
|
||
|
options.Logger.Logf(logging.Debug, t.Name())
|
||
|
}
|
||
|
|
||
|
if !hasCredentialProvider(tt.Creds) {
|
||
|
t.Errorf("expect presigner not to be called for not credentials provider")
|
||
|
}
|
||
|
|
||
|
expectCreds, _ := tt.Creds.RetrievePrivateKey(context.Background())
|
||
|
if diff := cmpDiff(expectCreds, credentials); len(diff) > 0 {
|
||
|
t.Error(diff)
|
||
|
}
|
||
|
if e, a := tt.PayloadHash, payloadHash; e != a {
|
||
|
t.Errorf("expected %v, got %v", e, a)
|
||
|
}
|
||
|
if e, a := signingName, service; e != a {
|
||
|
t.Errorf("expected %v, got %v", e, a)
|
||
|
}
|
||
|
if diff := cmpDiff([]string{signingRegion}, regionSet); len(diff) > 0 {
|
||
|
t.Error(diff)
|
||
|
}
|
||
|
|
||
|
return tt.ExpectResult.URL, tt.ExpectResult.SignedHeader, nil
|
||
|
}),
|
||
|
logSigning: tt.LogSigning,
|
||
|
}
|
||
|
|
||
|
next := middleware.FinalizeHandlerFunc(
|
||
|
func(ctx context.Context, in middleware.FinalizeInput) (
|
||
|
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||
|
) {
|
||
|
t.Errorf("expect next handler not to be called")
|
||
|
return out, metadata, err
|
||
|
})
|
||
|
|
||
|
ctx := awsmiddleware.SetSigningRegion(
|
||
|
awsmiddleware.SetSigningName(context.Background(), signingName),
|
||
|
signingRegion)
|
||
|
|
||
|
var loggerBuf bytes.Buffer
|
||
|
logger := logging.NewStandardLogger(&loggerBuf)
|
||
|
ctx = middleware.SetLogger(ctx, logger)
|
||
|
|
||
|
if len(tt.PayloadHash) != 0 {
|
||
|
ctx = v4.SetPayloadHash(ctx, tt.PayloadHash)
|
||
|
}
|
||
|
|
||
|
result, _, err := m.HandleFinalize(ctx, middleware.FinalizeInput{
|
||
|
Request: &smithyhttp.Request{
|
||
|
Request: tt.Request,
|
||
|
},
|
||
|
}, next)
|
||
|
if len(tt.ExpectErr) != 0 {
|
||
|
if err == nil {
|
||
|
t.Fatalf("expect error, got none")
|
||
|
}
|
||
|
if e, a := tt.ExpectErr, err.Error(); !strings.Contains(a, e) {
|
||
|
t.Fatalf("expect error to contain %v, got %v", e, a)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
if err != nil {
|
||
|
t.Fatalf("expect no error, got %v", err)
|
||
|
}
|
||
|
|
||
|
if diff := cmpDiff(tt.ExpectResult, result.Result); len(diff) != 0 {
|
||
|
t.Errorf("expect result match\n%v", diff)
|
||
|
}
|
||
|
|
||
|
if tt.LogSigning {
|
||
|
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
|
||
|
t.Errorf("expect %v logged in %v", e, a)
|
||
|
}
|
||
|
} else {
|
||
|
if loggerBuf.Len() != 0 {
|
||
|
t.Errorf("expect no log, got %v", loggerBuf.String())
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
_ middleware.FinalizeMiddleware = &PresignHTTPRequestMiddleware{}
|
||
|
)
|