// Source: frostfs-s3-gw/api/auth/signer/v4asdk2/middleware_test.go
// Repository viewer listing: 150 lines, 4.2 KiB, Go.

package v4a
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"strings"
"testing"
"time"
awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/logging"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
// stubCredentialsProviderFunc adapts a plain function to the method set used
// for the CredentialsProvider values in the test table, so a test case can
// supply its credential lookup inline.
type stubCredentialsProviderFunc func(context.Context) (Credentials, error)

// RetrievePrivateKey invokes the wrapped function with ctx and returns its
// credentials and error unchanged.
func (fn stubCredentialsProviderFunc) RetrievePrivateKey(ctx context.Context) (Credentials, error) {
	creds, err := fn(ctx)
	return creds, err
}
// httpSignerFunc adapts a plain function to a type with a SignHTTP method, so
// a test can install a closure as the signer of SignHTTPRequestMiddleware.
type httpSignerFunc func(
	ctx context.Context,
	credentials Credentials,
	r *http.Request,
	payloadHash string,
	service string,
	regionSet []string,
	signingTime time.Time,
	optFns ...func(*SignerOptions),
) error

// SignHTTP forwards every argument, in order, to the wrapped function and
// returns whatever error it produces.
func (fn httpSignerFunc) SignHTTP(
	ctx context.Context,
	credentials Credentials,
	r *http.Request,
	payloadHash string,
	service string,
	regionSet []string,
	signingTime time.Time,
	optFns ...func(*SignerOptions),
) error {
	err := fn(ctx, credentials, r, payloadHash, service, regionSet, signingTime, optFns...)
	return err
}
// TestSignHTTPRequestMiddleware drives SignHTTPRequestMiddleware.HandleFinalize
// through a table of credential providers. A stub signer asserts what the
// middleware hands over: the credentials resolved from the provider, the
// payload hash stored in the context, the signing name and the signing region
// set. The test also verifies the returned error against the case's
// expectation and that something is logged only when logSigning is enabled.
func TestSignHTTPRequestMiddleware(t *testing.T) {
	cases := map[string]struct {
		creds      CredentialsProvider
		hash       string
		logSigning bool
		// expectedErr is nil when HandleFinalize must succeed; otherwise it
		// is a value whose dynamic type the returned error chain must contain.
		expectedErr interface{}
	}{
		"success": {
			creds: stubCredentials,
			hash:  "0123456789abcdef",
		},
		"error": {
			creds: stubCredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
				return Credentials{}, fmt.Errorf("credential error")
			}),
			hash:        "",
			expectedErr: &SigningError{},
		},
		// No provider at all: HandleFinalize is still expected to succeed.
		"nil creds": {
			creds: nil,
		},
		"with log signing": {
			creds:      stubCredentials,
			hash:       "0123456789abcdef",
			logSigning: true,
		},
	}

	const (
		signingName   = "serviceId"
		signingRegion = "regionName"
	)

	for name, tt := range cases {
		t.Run(name, func(t *testing.T) {
			c := &SignHTTPRequestMiddleware{
				credentials: tt.creds,
				signer: httpSignerFunc(
					func(ctx context.Context,
						credentials Credentials, r *http.Request, payloadHash string,
						service string, regionSet []string, signingTime time.Time,
						optFns ...func(*SignerOptions),
					) error {
						var options SignerOptions
						for _, fn := range optFns {
							fn(&options)
						}

						// Only touch the logger when it is present, so a
						// missing logger is reported as a failure instead of
						// a nil dereference.
						if options.Logger == nil {
							t.Errorf("expect logger, got none")
						} else if options.LogSigning {
							// Pass the name as an argument: it is data, not a
							// format string.
							options.Logger.Logf(logging.Debug, "%s", t.Name())
						}

						// Reaching the signer without a provider would make
						// the lookup below panic; fail with a message instead.
						if tt.creds == nil {
							t.Errorf("signer invoked without a credentials provider")
							return nil
						}

						// The provider error is deliberately ignored here:
						// provider failures are asserted through the error
						// returned by HandleFinalize below.
						expectCreds, _ := tt.creds.RetrievePrivateKey(ctx)
						if diff := cmpDiff(expectCreds, credentials); len(diff) > 0 {
							t.Error(diff)
						}
						if e, a := tt.hash, payloadHash; e != a {
							t.Errorf("expected %v, got %v", e, a)
						}
						if e, a := signingName, service; e != a {
							t.Errorf("expected %v, got %v", e, a)
						}
						if diff := cmpDiff([]string{signingRegion}, regionSet); len(diff) > 0 {
							t.Error(diff)
						}
						return nil
					}),
				logSigning: tt.logSigning,
			}

			// Terminal handler: returns its zero-valued named results.
			next := middleware.FinalizeHandlerFunc(func(ctx context.Context, in middleware.FinalizeInput) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
				return out, metadata, err
			})

			ctx := awsmiddleware.SetSigningRegion(
				awsmiddleware.SetSigningName(context.Background(), signingName),
				signingRegion)

			var loggerBuf bytes.Buffer
			logger := logging.NewStandardLogger(&loggerBuf)
			ctx = middleware.SetLogger(ctx, logger)

			// The payload hash is stored in the context only for cases that
			// define one.
			if len(tt.hash) != 0 {
				ctx = v4.SetPayloadHash(ctx, tt.hash)
			}

			_, _, err := c.HandleFinalize(ctx, middleware.FinalizeInput{
				Request: &smithyhttp.Request{Request: &http.Request{}},
			}, next)

			switch {
			case err != nil && tt.expectedErr == nil:
				t.Errorf("expected no error, got %v", err)
			case err == nil && tt.expectedErr != nil:
				t.Errorf("expected error, got nil")
			case err != nil:
				// errors.As with a *interface{} target matches every error,
				// so the target must be declared with the concrete type.
				switch want := tt.expectedErr.(type) {
				case *SigningError:
					var got *SigningError
					if !errors.As(err, &got) {
						t.Errorf("expected error type %T, got %T", want, err)
					}
				default:
					t.Fatalf("unsupported expected error type %T", want)
				}
			}

			if tt.logSigning {
				if e, a := t.Name(), loggerBuf.String(); !strings.Contains(a, e) {
					t.Errorf("expect %v logged in %v", e, a)
				}
			} else if loggerBuf.Len() != 0 {
				t.Errorf("expect no log, got %v", loggerBuf.String())
			}
		})
	}
}
// Compile-time check that *SignHTTPRequestMiddleware satisfies
// middleware.FinalizeMiddleware.
var _ middleware.FinalizeMiddleware = (*SignHTTPRequestMiddleware)(nil)