150 lines
4.2 KiB
Go
150 lines
4.2 KiB
Go
|
package v4a
|
||
|
|
||
|
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net/http"
	"reflect"
	"strings"
	"testing"
	"time"

	awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	"github.com/aws/smithy-go/logging"
	"github.com/aws/smithy-go/middleware"
	smithyhttp "github.com/aws/smithy-go/transport/http"
)
|
||
|
|
||
|
type stubCredentialsProviderFunc func(context.Context) (Credentials, error)
|
||
|
|
||
|
func (f stubCredentialsProviderFunc) RetrievePrivateKey(ctx context.Context) (Credentials, error) {
|
||
|
return f(ctx)
|
||
|
}
|
||
|
|
||
|
type httpSignerFunc func(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) error
|
||
|
|
||
|
func (f httpSignerFunc) SignHTTP(ctx context.Context, credentials Credentials, r *http.Request, payloadHash string, service string, regionSet []string, signingTime time.Time, optFns ...func(*SignerOptions)) error {
|
||
|
return f(ctx, credentials, r, payloadHash, service, regionSet, signingTime, optFns...)
|
||
|
}
|
||
|
|
||
|
func TestSignHTTPRequestMiddleware(t *testing.T) {
|
||
|
cases := map[string]struct {
|
||
|
creds CredentialsProvider
|
||
|
hash string
|
||
|
logSigning bool
|
||
|
expectedErr interface{}
|
||
|
}{
|
||
|
"success": {
|
||
|
creds: stubCredentials,
|
||
|
hash: "0123456789abcdef",
|
||
|
},
|
||
|
"error": {
|
||
|
creds: stubCredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
|
||
|
return Credentials{}, fmt.Errorf("credential error")
|
||
|
}),
|
||
|
hash: "",
|
||
|
expectedErr: &SigningError{},
|
||
|
},
|
||
|
"nil creds": {
|
||
|
creds: nil,
|
||
|
},
|
||
|
"with log signing": {
|
||
|
creds: stubCredentials,
|
||
|
hash: "0123456789abcdef",
|
||
|
logSigning: true,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
const (
|
||
|
signingName = "serviceId"
|
||
|
signingRegion = "regionName"
|
||
|
)
|
||
|
|
||
|
for name, tt := range cases {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
c := &SignHTTPRequestMiddleware{
|
||
|
credentials: tt.creds,
|
||
|
signer: httpSignerFunc(
|
||
|
func(ctx context.Context,
|
||
|
credentials Credentials, r *http.Request, payloadHash string,
|
||
|
service string, regionSet []string, signingTime time.Time,
|
||
|
optFns ...func(*SignerOptions),
|
||
|
) error {
|
||
|
var options SignerOptions
|
||
|
for _, fn := range optFns {
|
||
|
fn(&options)
|
||
|
}
|
||
|
if options.Logger == nil {
|
||
|
t.Errorf("expect logger, got none")
|
||
|
}
|
||
|
if options.LogSigning {
|
||
|
options.Logger.Logf(logging.Debug, t.Name())
|
||
|
}
|
||
|
|
||
|
expectCreds, _ := tt.creds.RetrievePrivateKey(ctx)
|
||
|
if diff := cmpDiff(expectCreds, credentials); len(diff) > 0 {
|
||
|
t.Error(diff)
|
||
|
}
|
||
|
if e, a := tt.hash, payloadHash; e != a {
|
||
|
t.Errorf("expected %v, got %v", e, a)
|
||
|
}
|
||
|
if e, a := signingName, service; e != a {
|
||
|
t.Errorf("expected %v, got %v", e, a)
|
||
|
}
|
||
|
if diff := cmpDiff([]string{signingRegion}, regionSet); len(diff) > 0 {
|
||
|
t.Error(diff)
|
||
|
}
|
||
|
return nil
|
||
|
}),
|
||
|
logSigning: tt.logSigning,
|
||
|
}
|
||
|
|
||
|
next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
|
||
|
return out, metadata, err
|
||
|
})
|
||
|
|
||
|
ctx := awsmiddleware.SetSigningRegion(
|
||
|
awsmiddleware.SetSigningName(context.Background(), signingName),
|
||
|
signingRegion)
|
||
|
|
||
|
var loggerBuf bytes.Buffer
|
||
|
logger := logging.NewStandardLogger(&loggerBuf)
|
||
|
ctx = middleware.SetLogger(ctx, logger)
|
||
|
|
||
|
if len(tt.hash) != 0 {
|
||
|
ctx = v4.SetPayloadHash(ctx, tt.hash)
|
||
|
}
|
||
|
|
||
|
_, _, err := c.HandleFinalize(ctx, middleware.FinalizeInput{
|
||
|
Request: &smithyhttp.Request{Request: &http.Request{}},
|
||
|
}, next)
|
||
|
if err != nil && tt.expectedErr == nil {
|
||
|
t.Errorf("expected no error, got %v", err)
|
||
|
} else if err != nil && tt.expectedErr != nil {
|
||
|
e, a := tt.expectedErr, err
|
||
|
if !errors.As(a, &e) {
|
||
|
t.Errorf("expected error type %T, got %T", e, a)
|
||
|
}
|
||
|
} else if err == nil && tt.expectedErr != nil {
|
||
|
t.Errorf("expected error, got nil")
|
||
|
}
|
||
|
|
||
|
if tt.logSigning {
|
||
|
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
|
||
|
t.Errorf("expect %v logged in %v", e, a)
|
||
|
}
|
||
|
} else {
|
||
|
if loggerBuf.Len() != 0 {
|
||
|
t.Errorf("expect no log, got %v", loggerBuf.String())
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
_ middleware.FinalizeMiddleware = &SignHTTPRequestMiddleware{}
|
||
|
)
|