From 2026787ce4a8360d8b3e19a08edfd756e90d2d10 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 7 Oct 2021 15:01:11 -0700 Subject: [PATCH] Add some extra coverage. --- kms/azurekms/key_vault.go | 5 +-- kms/azurekms/key_vault_test.go | 70 ++++++++++++++++++++++++++++++---- 2 files changed, 64 insertions(+), 11 deletions(-) diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go index f7b00ec7..c5dc56bf 100644 --- a/kms/azurekms/key_vault.go +++ b/kms/azurekms/key_vault.go @@ -38,9 +38,8 @@ var now = func() time.Time { } type keyType struct { - Kty keyvault.JSONWebKeyType - Curve keyvault.JSONWebKeyCurveName - KeySize int + Kty keyvault.JSONWebKeyType + Curve keyvault.JSONWebKeyCurveName } func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType { diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go index 1f09b2d5..4f98d274 100644 --- a/kms/azurekms/key_vault_test.go +++ b/kms/azurekms/key_vault_test.go @@ -42,17 +42,19 @@ func mockClient(t *testing.T) *mock.KeyVaultClient { return mock.NewKeyVaultClient(ctrl) } -func mockCreateClient(t *testing.T, ctrl *gomock.Controller) { +func mockCreateClient(t *testing.T) *mock.KeyVaultClient { t.Helper() + ctrl := gomock.NewController(t) + client := mock.NewKeyVaultClient(ctrl) old := createClient - createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { - return mock.NewKeyVaultClient(ctrl), nil + return client, nil } - t.Cleanup(func() { createClient = old + ctrl.Finish() }) + return client } func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey { @@ -78,8 +80,11 @@ func Test_now(t *testing.T) { } func TestNew(t *testing.T) { - ctrl := gomock.NewController(t) - mockCreateClient(t, ctrl) + client := mockClient(t) + old := createClient + t.Cleanup(func() { + createClient = old + }) type args struct { ctx context.Context @@ -87,16 +92,27 @@ func TestNew(t *testing.T) { } tests := []struct { name string + setup func() args args want *KeyVault wantErr bool }{ - {"ok", args{context.Background(), apiv1.Options{}}, &KeyVault{ - baseClient: mock.NewKeyVaultClient(ctrl), + {"ok", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{}}, &KeyVault{ + baseClient: client, }, false}, + {"fail", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return nil, errTest + } + }, args{context.Background(), apiv1.Options{}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + tt.setup() got, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) @@ -148,6 +164,9 @@ func TestKeyVault_GetPublicKey(t *testing.T) { {"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{ Name: "azurekms:vault=my-vault;name=not-found?version=my-version", }}, nil, true}, + {"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{ + Name: "", + }}, nil, true}, {"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{ Name: "azurekms:vault=;name=not-found?version=my-version", }}, nil, true}, @@ -490,3 +509,38 @@ func TestKeyVault_Close(t *testing.T) { }) } } + +func Test_keyType_KeyType(t *testing.T) { + type fields struct { + Kty keyvault.JSONWebKeyType + Curve keyvault.JSONWebKeyCurveName + } + type args struct { + pl apiv1.ProtectionLevel + } + tests := []struct { + name string + fields fields + args args + want keyvault.JSONWebKeyType + }{ + {"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC}, + {"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC}, + {"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM}, + {"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA}, + {"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA}, + {"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM}, + {"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := keyType{ + Kty: tt.fields.Kty, + Curve: tt.fields.Curve, + } + if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) { + t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want) + } + }) + } +}