// Forked from TrueCloudLab/certificates (358 lines, 10 KiB, Go).

package awskms
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"github.com/aws/aws-sdk-go/aws"
|
|
"github.com/aws/aws-sdk-go/aws/request"
|
|
"github.com/aws/aws-sdk-go/aws/session"
|
|
"github.com/aws/aws-sdk-go/service/kms"
|
|
"github.com/smallstep/certificates/kms/apiv1"
|
|
"go.step.sm/crypto/pemutil"
|
|
)
|
|
|
|
func TestNew(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
sess, err := session.NewSessionWithOptions(session.Options{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
expected := &KMS{
|
|
session: sess,
|
|
service: kms.New(sess),
|
|
}
|
|
|
|
// This will force an error in the session creation.
|
|
// It does not fail with missing credentials.
|
|
forceError := func(t *testing.T) {
|
|
key := "AWS_CA_BUNDLE"
|
|
value := os.Getenv(key)
|
|
os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt"))
|
|
t.Cleanup(func() {
|
|
if value == "" {
|
|
os.Unsetenv(key)
|
|
} else {
|
|
os.Setenv(key, value)
|
|
}
|
|
})
|
|
}
|
|
|
|
type args struct {
|
|
ctx context.Context
|
|
opts apiv1.Options
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want *KMS
|
|
wantErr bool
|
|
}{
|
|
{"ok", args{ctx, apiv1.Options{}}, expected, false},
|
|
{"ok with options", args{ctx, apiv1.Options{
|
|
Region: "us-east-1",
|
|
Profile: "smallstep",
|
|
CredentialsFile: "~/aws/credentials",
|
|
}}, expected, false},
|
|
{"fail", args{ctx, apiv1.Options{}}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Force an error in the session loading
|
|
if tt.wantErr {
|
|
forceError(t)
|
|
}
|
|
|
|
got, err := New(tt.args.ctx, tt.args.opts)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err != nil {
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
|
}
|
|
} else {
|
|
if got.session == nil || got.service == nil {
|
|
t.Errorf("New() = %#v, want %#v", got, tt.want)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKMS_GetPublicKey(t *testing.T) {
|
|
okClient := getOKClient()
|
|
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type fields struct {
|
|
session *session.Session
|
|
service KeyManagementClient
|
|
}
|
|
type args struct {
|
|
req *apiv1.GetPublicKeyRequest
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want crypto.PublicKey
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
|
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
}}, key, false},
|
|
{"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
|
|
{"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{
|
|
Name: "awskms:key-id=",
|
|
}}, nil, true},
|
|
{"fail getPublicKey", fields{nil, &MockClient{
|
|
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
},
|
|
}}, args{&apiv1.GetPublicKeyRequest{
|
|
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
}}, nil, true},
|
|
{"fail not der", fields{nil, &MockClient{
|
|
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
return &kms.GetPublicKeyOutput{
|
|
KeyId: input.KeyId,
|
|
PublicKey: []byte(publicKey),
|
|
}, nil
|
|
},
|
|
}}, args{&apiv1.GetPublicKeyRequest{
|
|
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
}}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
k := &KMS{
|
|
session: tt.fields.session,
|
|
service: tt.fields.service,
|
|
}
|
|
got, err := k.GetPublicKey(tt.args.req)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKMS_CreateKey(t *testing.T) {
|
|
okClient := getOKClient()
|
|
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type fields struct {
|
|
session *session.Session
|
|
service KeyManagementClient
|
|
}
|
|
type args struct {
|
|
req *apiv1.CreateKeyRequest
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want *apiv1.CreateKeyResponse
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
}}, &apiv1.CreateKeyResponse{
|
|
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
PublicKey: key,
|
|
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
},
|
|
}, false},
|
|
{"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
Bits: 2048,
|
|
}}, &apiv1.CreateKeyResponse{
|
|
Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
PublicKey: key,
|
|
CreateSignerRequest: apiv1.CreateSignerRequest{
|
|
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
},
|
|
}, false},
|
|
{"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true},
|
|
{"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.PureEd25519,
|
|
}}, nil, true},
|
|
{"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.SHA256WithRSA,
|
|
Bits: 1234,
|
|
}}, nil, true},
|
|
{"fail createKey", fields{nil, &MockClient{
|
|
createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
},
|
|
createAliasWithContext: okClient.createAliasWithContext,
|
|
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
|
}}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
}}, nil, true},
|
|
{"fail createAlias", fields{nil, &MockClient{
|
|
createKeyWithContext: okClient.createKeyWithContext,
|
|
createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
},
|
|
getPublicKeyWithContext: okClient.getPublicKeyWithContext,
|
|
}}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
}}, nil, true},
|
|
{"fail getPublicKey", fields{nil, &MockClient{
|
|
createKeyWithContext: okClient.createKeyWithContext,
|
|
createAliasWithContext: okClient.createAliasWithContext,
|
|
getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
},
|
|
}}, args{&apiv1.CreateKeyRequest{
|
|
Name: "root",
|
|
SignatureAlgorithm: apiv1.ECDSAWithSHA256,
|
|
}}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
k := &KMS{
|
|
session: tt.fields.session,
|
|
service: tt.fields.service,
|
|
}
|
|
got, err := k.CreateKey(tt.args.req)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKMS_CreateSigner(t *testing.T) {
|
|
client := getOKClient()
|
|
key, err := pemutil.ParseKey([]byte(publicKey))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
type fields struct {
|
|
session *session.Session
|
|
service KeyManagementClient
|
|
}
|
|
type args struct {
|
|
req *apiv1.CreateSignerRequest
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want crypto.Signer
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{
|
|
SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
}}, &Signer{
|
|
service: client,
|
|
keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936",
|
|
publicKey: key,
|
|
}, false},
|
|
{"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
|
{"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
k := &KMS{
|
|
session: tt.fields.session,
|
|
service: tt.fields.service,
|
|
}
|
|
got, err := k.CreateSigner(tt.args.req)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestKMS_Close(t *testing.T) {
|
|
type fields struct {
|
|
session *session.Session
|
|
service KeyManagementClient
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{nil, getOKClient()}, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
k := &KMS{
|
|
session: tt.fields.session,
|
|
service: tt.fields.service,
|
|
}
|
|
if err := k.Close(); (err != nil) != tt.wantErr {
|
|
t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_parseKeyID(t *testing.T) {
|
|
type args struct {
|
|
name string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
|
{"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
|
{"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false},
|
|
{"fail parse", args{"awskms:key-id=%ZZ"}, "", true},
|
|
{"fail empty key", args{"awskms:key-id="}, "", true},
|
|
{"fail missing", args{"awskms:foo=bar"}, "", true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := parseKeyID(tt.args.name)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if got != tt.want {
|
|
t.Errorf("parseKeyID() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|