Add some extra coverage.
This commit is contained in:
parent
52a18e0c2d
commit
2026787ce4
2 changed files with 64 additions and 11 deletions
|
@ -40,7 +40,6 @@ var now = func() time.Time {
|
|||
type keyType struct {
|
||||
Kty keyvault.JSONWebKeyType
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
KeySize int
|
||||
}
|
||||
|
||||
func (k keyType) KeyType(pl apiv1.ProtectionLevel) keyvault.JSONWebKeyType {
|
||||
|
|
|
@ -42,17 +42,19 @@ func mockClient(t *testing.T) *mock.KeyVaultClient {
|
|||
return mock.NewKeyVaultClient(ctrl)
|
||||
}
|
||||
|
||||
func mockCreateClient(t *testing.T, ctrl *gomock.Controller) {
|
||||
func mockCreateClient(t *testing.T) *mock.KeyVaultClient {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
client := mock.NewKeyVaultClient(ctrl)
|
||||
old := createClient
|
||||
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return mock.NewKeyVaultClient(ctrl), nil
|
||||
return client, nil
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
createClient = old
|
||||
ctrl.Finish()
|
||||
})
|
||||
return client
|
||||
}
|
||||
|
||||
func createJWK(t *testing.T, pub crypto.PublicKey) *keyvault.JSONWebKey {
|
||||
|
@ -78,8 +80,11 @@ func Test_now(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
mockCreateClient(t, ctrl)
|
||||
client := mockClient(t)
|
||||
old := createClient
|
||||
t.Cleanup(func() {
|
||||
createClient = old
|
||||
})
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
|
@ -87,16 +92,27 @@ func TestNew(t *testing.T) {
|
|||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func()
|
||||
args args
|
||||
want *KeyVault
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{context.Background(), apiv1.Options{}}, &KeyVault{
|
||||
baseClient: mock.NewKeyVaultClient(ctrl),
|
||||
{"ok", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{}}, &KeyVault{
|
||||
baseClient: client,
|
||||
}, false},
|
||||
{"fail", func() {
|
||||
createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) {
|
||||
return nil, errTest
|
||||
}
|
||||
}, args{context.Background(), apiv1.Options{}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.setup()
|
||||
got, err := New(tt.args.ctx, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@ -148,6 +164,9 @@ func TestKeyVault_GetPublicKey(t *testing.T) {
|
|||
{"fail GetKey", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=my-vault;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
{"fail empty", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "",
|
||||
}}, nil, true},
|
||||
{"fail vault", fields{client}, args{&apiv1.GetPublicKeyRequest{
|
||||
Name: "azurekms:vault=;name=not-found?version=my-version",
|
||||
}}, nil, true},
|
||||
|
@ -490,3 +509,38 @@ func TestKeyVault_Close(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keyType_KeyType(t *testing.T) {
|
||||
type fields struct {
|
||||
Kty keyvault.JSONWebKeyType
|
||||
Curve keyvault.JSONWebKeyCurveName
|
||||
}
|
||||
type args struct {
|
||||
pl apiv1.ProtectionLevel
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want keyvault.JSONWebKeyType
|
||||
}{
|
||||
{"ec", fields{keyvault.EC, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.EC},
|
||||
{"ec software", fields{keyvault.EC, keyvault.P384}, args{apiv1.Software}, keyvault.EC},
|
||||
{"ec hsm", fields{keyvault.EC, keyvault.P521}, args{apiv1.HSM}, keyvault.ECHSM},
|
||||
{"rsa", fields{keyvault.RSA, keyvault.P256}, args{apiv1.UnspecifiedProtectionLevel}, keyvault.RSA},
|
||||
{"rsa software", fields{keyvault.RSA, ""}, args{apiv1.Software}, keyvault.RSA},
|
||||
{"rsa hsm", fields{keyvault.RSA, ""}, args{apiv1.HSM}, keyvault.RSAHSM},
|
||||
{"empty", fields{"FOO", ""}, args{apiv1.UnspecifiedProtectionLevel}, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := keyType{
|
||||
Kty: tt.fields.Kty,
|
||||
Curve: tt.fields.Curve,
|
||||
}
|
||||
if got := k.KeyType(tt.args.pl); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("keyType.KeyType() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue