frostfs-s3-gw/api/auth/signer/v4asdk2/presign_middleware_test.go

223 lines
5.9 KiB
Go
Raw Normal View History

package v4a
import (
"bytes"
"context"
"net/http"
"net/url"
"strings"
"testing"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/logging"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// httpPresignerFunc is a function type carrying the same parameters and
// results as its PresignHTTP method. It lets a test supply a plain function
// literal as the presigner of a PresignHTTPRequestMiddleware.
type httpPresignerFunc func(
	ctx context.Context,
	credentials Credentials,
	r *http.Request,
	payloadHash, service string,
	regionSet []string,
	signingTime time.Time,
	optFns ...func(*SignerOptions),
) (presignedURL string, signedHeader http.Header, err error)
// PresignHTTP invokes the underlying function f with exactly the arguments it
// was given and returns f's results unchanged.
func (f httpPresignerFunc) PresignHTTP(
	ctx context.Context,
	credentials Credentials,
	r *http.Request,
	payloadHash string,
	service string,
	regionSet []string,
	signingTime time.Time,
	optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error) {
	url, signedHeader, err = f(
		ctx, credentials, r,
		payloadHash, service, regionSet, signingTime,
		optFns...,
	)
	return url, signedHeader, err
}
func TestPresignHTTPRequestMiddleware(t *testing.T) {
cases := map[string]struct {
Request *http.Request
Creds CredentialsProvider
PayloadHash string
LogSigning bool
ExpectResult *v4.PresignedHTTPRequest
ExpectErr string
}{
"success": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: stubCredentials,
PayloadHash: "0123456789abcdef",
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
},
"error": {
Request: func() *http.Request {
return &http.Request{}
}(),
Creds: stubCredentials,
PayloadHash: "",
ExpectErr: "failed to sign request",
},
"anonymous creds": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: stubCredentials,
PayloadHash: "",
ExpectErr: "failed to sign request",
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
},
"nil creds": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: nil,
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
},
"with log signing": {
Request: &http.Request{
URL: func() *url.URL {
u, _ := url.Parse("https://example.aws/path?query=foo")
return u
}(),
Header: http.Header{},
},
Creds: stubCredentials,
PayloadHash: "0123456789abcdef",
ExpectResult: &v4.PresignedHTTPRequest{
URL: "https://example.aws/path?query=foo",
SignedHeader: http.Header{},
},
LogSigning: true,
},
}
const (
signingName = "serviceId"
signingRegion = "regionName"
)
for name, tt := range cases {
t.Run(name, func(t *testing.T) {
m := &PresignHTTPRequestMiddleware{
credentialsProvider: tt.Creds,
presigner: httpPresignerFunc(func(
ctx context.Context, credentials Credentials, r *http.Request,
payloadHash string, service string, regionSet []string, signingTime time.Time,
optFns ...func(*SignerOptions),
) (url string, signedHeader http.Header, err error) {
var options SignerOptions
for _, fn := range optFns {
fn(&options)
}
if options.Logger == nil {
t.Errorf("expect logger, got none")
}
if options.LogSigning {
options.Logger.Logf(logging.Debug, t.Name())
}
if !hasCredentialProvider(tt.Creds) {
t.Errorf("expect presigner not to be called for not credentials provider")
}
expectCreds, _ := tt.Creds.RetrievePrivateKey(context.Background())
if diff := cmpDiff(expectCreds, credentials); len(diff) > 0 {
t.Error(diff)
}
if e, a := tt.PayloadHash, payloadHash; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if e, a := signingName, service; e != a {
t.Errorf("expected %v, got %v", e, a)
}
if diff := cmpDiff([]string{signingRegion}, regionSet); len(diff) > 0 {
t.Error(diff)
}
return tt.ExpectResult.URL, tt.ExpectResult.SignedHeader, nil
}),
logSigning: tt.LogSigning,
}
next := middleware.FinalizeHandlerFunc(
func(ctx context.Context, in middleware.FinalizeInput) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
t.Errorf("expect next handler not to be called")
return out, metadata, err
})
ctx := awsmiddleware.SetSigningRegion(
awsmiddleware.SetSigningName(context.Background(), signingName),
signingRegion)
var loggerBuf bytes.Buffer
logger := logging.NewStandardLogger(&loggerBuf)
ctx = middleware.SetLogger(ctx, logger)
if len(tt.PayloadHash) != 0 {
ctx = v4.SetPayloadHash(ctx, tt.PayloadHash)
}
result, _, err := m.HandleFinalize(ctx, middleware.FinalizeInput{
Request: &smithyhttp.Request{
Request: tt.Request,
},
}, next)
if len(tt.ExpectErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := tt.ExpectErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %v, got %v", e, a)
}
return
}
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if diff := cmpDiff(tt.ExpectResult, result.Result); len(diff) != 0 {
t.Errorf("expect result match\n%v", diff)
}
if tt.LogSigning {
if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
t.Errorf("expect %v logged in %v", e, a)
}
} else {
if loggerBuf.Len() != 0 {
t.Errorf("expect no log, got %v", loggerBuf.String())
}
}
})
}
}
// Compile-time assertion that *PresignHTTPRequestMiddleware implements
// middleware.FinalizeMiddleware.
var _ middleware.FinalizeMiddleware = (*PresignHTTPRequestMiddleware)(nil)