Add unit tests for awskms.
This commit is contained in:
parent
d4cb9f4ac7
commit
aaf71ce66a
4 changed files with 631 additions and 9 deletions
358
kms/awskms/awskms_test.go
Normal file
358
kms/awskms/awskms_test.go
Normal file
|
@ -0,0 +1,358 @@
|
|||
package awskms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
sess, err := session.NewSessionWithOptions(session.Options{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := &KMS{
|
||||
session: sess,
|
||||
service: kms.New(sess),
|
||||
}
|
||||
|
||||
// This will force an error in the session creation.
|
||||
// It does not fail with missing credentials.
|
||||
forceError := func(t *testing.T) {
|
||||
key := "AWS_CA_BUNDLE"
|
||||
value := os.Getenv(key)
|
||||
os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt"))
|
||||
t.Cleanup(func() {
|
||||
if value == "" {
|
||||
os.Unsetenv(key)
|
||||
} else {
|
||||
os.Setenv(key, value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
opts apiv1.Options
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *KMS
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{ctx, apiv1.Options{}}, expected, false},
|
||||
{"ok with options", args{ctx, apiv1.Options{
|
||||
Region: "us-east-1",
|
||||
Profile: "smallstep",
|
||||
CredentialsFile: "~/aws/credentials",
|
||||
}}, expected, false},
|
||||
{"fail", args{ctx, apiv1.Options{}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Force an error in the session loading
|
||||
if tt.wantErr {
|
||||
forceError(t)
|
||||
}
|
||||
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
} else {
|
||||
if got.session == nil || got.service == nil {
|
||||
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_GetPublicKey(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.GetPublicKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.PublicKey
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, key, false},
|
||||
{"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
|
||||
{"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=",
|
||||
}}, nil, true},
|
||||
{"fail getPublicKey", fields{nil, &MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, nil, true},
|
||||
{"fail not der", fields{nil, &MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return &kms.GetPublicKeyOutput{
|
||||
KeyId: input.KeyId,
|
||||
PublicKey: []byte(publicKey),
|
||||
}, nil
|
||||
},
|
||||
}}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
got, err := k.GetPublicKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_CreateKey(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateKeyRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *apiv1.CreateKeyResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
PublicKey: key,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
},
|
||||
}, false},
|
||||
{"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
||||
Bits: 2048,
|
||||
}}, &apiv1.CreateKeyResponse{
|
||||
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
PublicKey: key,
|
||||
CreateSignerRequest: apiv1.CreateSignerRequest{
|
||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
},
|
||||
}, false},
|
||||
{"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true},
|
||||
{"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.PureEd25519,
|
||||
}}, nil, true},
|
||||
{"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
||||
Bits: 1234,
|
||||
}}, nil, true},
|
||||
{"fail createKey", fields{nil, &MockClient{
|
||||
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
createAliasWithContext: okClient.createAliasWithContext,
|
||||
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
||||
}}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail createAlias", fields{nil, &MockClient{
|
||||
createKeyWithContext: okClient.createKeyWithContext,
|
||||
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
||||
}}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
{"fail getPublicKey", fields{nil, &MockClient{
|
||||
createKeyWithContext: okClient.createKeyWithContext,
|
||||
createAliasWithContext: okClient.createAliasWithContext,
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}}, args{&apiv1.CreateKeyRequest{
|
||||
Name: "root",
|
||||
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
got, err := k.CreateKey(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_CreateSigner(t *testing.T) {
|
||||
client := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
type args struct {
|
||||
req *apiv1.CreateSignerRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want crypto.Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{
|
||||
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
}}, &Signer{
|
||||
service: client,
|
||||
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
publicKey: key,
|
||||
}, false},
|
||||
{"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
||||
{"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
got, err := k.CreateSigner(tt.args.req)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKMS_Close(t *testing.T) {
|
||||
type fields struct {
|
||||
session *session.Session
|
||||
service KeyManagementClient
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{nil, getOKClient()}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &KMS{
|
||||
session: tt.fields.session,
|
||||
service: tt.fields.service,
|
||||
}
|
||||
if err := k.Close(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_parseKeyID(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
||||
{"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
||||
{"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
||||
{"fail parse", args{"awskms:key-id=%ZZ"}, "", true},
|
||||
{"fail empty key", args{"awskms:key-id="}, "", true},
|
||||
{"fail missing", args{"awskms:foo=bar"}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseKeyID(tt.args.name)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("parseKeyID() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
72
kms/awskms/mock_test.go
Normal file
72
kms/awskms/mock_test.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package awskms
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
)
|
||||
|
||||
type MockClient struct {
|
||||
getPublicKeyWithContext func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error)
|
||||
createKeyWithContext func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error)
|
||||
createAliasWithContext func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error)
|
||||
signWithContext func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return m.getPublicKeyWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *MockClient) CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
||||
return m.createKeyWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *MockClient) CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
||||
return m.createAliasWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
func (m *MockClient) SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
||||
return m.signWithContext(ctx, input, opts...)
|
||||
}
|
||||
|
||||
const (
|
||||
publicKey = `-----BEGIN PUBLIC KEY-----
|
||||
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE8XWlIWkOThxNjGbZLYUgRHmsvCrW
|
||||
KF+HLktPfPTIK3lGd1k4849WQs59XIN+LXZQ6b2eRBEBKAHEyQus8UU7gw==
|
||||
-----END PUBLIC KEY-----`
|
||||
keyId = "be468355-ca7a-40d9-a28b-8ae1c4c7f936"
|
||||
)
|
||||
|
||||
var signature = []byte{
|
||||
0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
|
||||
0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
|
||||
}
|
||||
|
||||
func getOKClient() *MockClient {
|
||||
return &MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
block, _ := pem.Decode([]byte(publicKey))
|
||||
return &kms.GetPublicKeyOutput{
|
||||
KeyId: input.KeyId,
|
||||
PublicKey: block.Bytes,
|
||||
}, nil
|
||||
},
|
||||
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
||||
md := new(kms.KeyMetadata)
|
||||
md.SetKeyId(keyId)
|
||||
return &kms.CreateKeyOutput{
|
||||
KeyMetadata: md,
|
||||
}, nil
|
||||
},
|
||||
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
||||
return &kms.CreateAliasOutput{}, nil
|
||||
},
|
||||
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
||||
return &kms.SignOutput{
|
||||
Signature: signature,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
// Signer implements a crypto.Signer using the AWS KMS.
|
||||
type Signer struct {
|
||||
service KeyManagementClient
|
||||
keyID string
|
||||
|
@ -88,30 +89,30 @@ func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (string,
|
|||
switch h := opts.HashFunc(); h {
|
||||
case crypto.SHA256:
|
||||
if isPSS {
|
||||
return "RSASSA_PSS_SHA_256", nil
|
||||
return kms.SigningAlgorithmSpecRsassaPssSha256, nil
|
||||
}
|
||||
return "RSASSA_PKCS1_V1_5_SHA_256", nil
|
||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil
|
||||
case crypto.SHA384:
|
||||
if isPSS {
|
||||
return "RSASSA_PSS_SHA_384", nil
|
||||
return kms.SigningAlgorithmSpecRsassaPssSha384, nil
|
||||
}
|
||||
return "RSASSA_PKCS1_V1_5_SHA_384", nil
|
||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil
|
||||
case crypto.SHA512:
|
||||
if isPSS {
|
||||
return "RSASSA_PSS_SHA_512", nil
|
||||
return kms.SigningAlgorithmSpecRsassaPssSha512, nil
|
||||
}
|
||||
return "RSASSA_PKCS1_V1_5_SHA_512", nil
|
||||
return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
switch h := opts.HashFunc(); h {
|
||||
case crypto.SHA256:
|
||||
return "ECDSA_SHA_256", nil
|
||||
return kms.SigningAlgorithmSpecEcdsaSha256, nil
|
||||
case crypto.SHA384:
|
||||
return "ECDSA_SHA_384", nil
|
||||
return kms.SigningAlgorithmSpecEcdsaSha384, nil
|
||||
case crypto.SHA512:
|
||||
return "ECDSA_SHA_512", nil
|
||||
return kms.SigningAlgorithmSpecEcdsaSha512, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported hash function %v", h)
|
||||
}
|
||||
|
|
191
kms/awskms/signer_test.go
Normal file
191
kms/awskms/signer_test.go
Normal file
|
@ -0,0 +1,191 @@
|
|||
package awskms
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/request"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestNewSigner(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type args struct {
|
||||
svc KeyManagementClient
|
||||
signingKey string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Signer
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{
|
||||
service: okClient,
|
||||
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
||||
publicKey: key,
|
||||
}, false},
|
||||
{"fail parse", args{okClient, "awskms:key-id="}, nil, true},
|
||||
{"fail preload", args{&MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
||||
{"fail preload not der", args{&MockClient{
|
||||
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
||||
return &kms.GetPublicKeyOutput{
|
||||
KeyId: input.KeyId,
|
||||
PublicKey: []byte(publicKey),
|
||||
}, nil
|
||||
},
|
||||
}, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewSigner(tt.args.svc, tt.args.signingKey)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewSigner() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Public(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
service KeyManagementClient
|
||||
keyID string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want crypto.PublicKey
|
||||
}{
|
||||
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, key},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
service: tt.fields.service,
|
||||
keyID: tt.fields.keyID,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Public() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Sign(t *testing.T) {
|
||||
okClient := getOKClient()
|
||||
key, err := pemutil.ParseKey([]byte(publicKey))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
service KeyManagementClient
|
||||
keyID string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
type args struct {
|
||||
rand io.Reader
|
||||
digest []byte
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false},
|
||||
{"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
|
||||
{"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
||||
{"fail sign", fields{&MockClient{
|
||||
signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
},
|
||||
}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
service: tt.fields.service,
|
||||
keyID: tt.fields.keyID,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getSigningAlgorithm(t *testing.T) {
|
||||
type args struct {
|
||||
key crypto.PublicKey
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},
|
||||
{"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false},
|
||||
{"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false},
|
||||
{"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false},
|
||||
{"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false},
|
||||
{"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false},
|
||||
{"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false},
|
||||
{"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false},
|
||||
{"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false},
|
||||
{"fail type", args{[]byte("key"), crypto.SHA256}, "", true},
|
||||
{"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true},
|
||||
{"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := getSigningAlgorithm(tt.args.key, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue