From abdb56065d646dc77f5b45f406a309dc83f3c22a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 7 Oct 2021 16:18:36 -0700 Subject: [PATCH] Allow o specify an hsm using the uri. --- kms/azurekms/key_vault.go | 17 +++++++++++++---- kms/azurekms/key_vault_test.go | 23 +++++++++++++++++++++++ kms/azurekms/signer.go | 2 +- kms/azurekms/signer_test.go | 9 +++++++++ kms/azurekms/utils.go | 9 +++++++-- kms/azurekms/utils_test.go | 26 ++++++++++++++++---------- 6 files changed, 69 insertions(+), 17 deletions(-) diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go index c5dc56bf..4a927d4f 100644 --- a/kms/azurekms/key_vault.go +++ b/kms/azurekms/key_vault.go @@ -113,11 +113,13 @@ type KeyVaultClient interface { // // - azurekms:name=key-name;vault=vault-name // - azurekms:name=key-name;vault=vault-name?version=key-version +// - azurekms:name=key-name;vault=vault-name?hsm=true // // The scheme is "azurekms"; "name" is the key name; "vault" is the key vault // name where the key is located; "version" is an optional parameter that // defines the version of they key, if version is not given, the latest one will -// be used. +// be used; "hsm" defines if an HSM want to be used for this key, this is +// specially useful when this is used from `step`. // // TODO(mariano): The implementation is using /services/keyvault/v7.1/keyvault // package, at some point Azure might create a keyvault client with all the @@ -165,7 +167,7 @@ func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKe return nil, errors.New("getPublicKeyRequest 'name' cannot be empty") } - vault, name, version, err := parseKeyName(req.Name) + vault, name, version, _, err := parseKeyName(req.Name) if err != nil { return nil, err } @@ -187,11 +189,18 @@ func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo return nil, errors.New("createKeyRequest 'name' cannot be empty") } - vault, name, _, err := parseKeyName(req.Name) + vault, name, _, hsm, err := parseKeyName(req.Name) if err != nil { return nil, err } + // Override protection level to HSM only if it's not specified, and is given + // in the uri. + protectionLevel := req.ProtectionLevel + if protectionLevel == apiv1.UnspecifiedProtectionLevel && hsm { + protectionLevel = apiv1.HSM + } + kt, ok := signatureAlgorithmMapping[req.SignatureAlgorithm] if !ok { return nil, errors.Errorf("keyVault does not support signature algorithm '%s'", req.SignatureAlgorithm) @@ -216,7 +225,7 @@ func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo defer cancel() resp, err := k.baseClient.CreateKey(ctx, vaultBaseURL(vault), name, keyvault.KeyCreateParameters{ - Kty: kt.KeyType(req.ProtectionLevel), + Kty: kt.KeyType(protectionLevel), KeySize: keySize, Curve: kt.Curve, KeyOps: &[]keyvault.JSONWebKeyOperation{ diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go index f9446b0f..0f6d7e0e 100644 --- a/kms/azurekms/key_vault_test.go +++ b/kms/azurekms/key_vault_test.go @@ -202,11 +202,13 @@ func TestKeyVault_CreateKey(t *testing.T) { }{ {"P-256", keyvault.EC, nil, keyvault.P256, ecJWK}, {"P-256 HSM", keyvault.ECHSM, nil, keyvault.P256, ecJWK}, + {"P-256 HSM (uri)", keyvault.ECHSM, nil, keyvault.P256, ecJWK}, {"P-256 Default", keyvault.EC, nil, keyvault.P256, ecJWK}, {"P-384", keyvault.EC, nil, keyvault.P384, ecJWK}, {"P-521", keyvault.EC, nil, keyvault.P521, ecJWK}, {"RSA 0", keyvault.RSA, &value3072, "", rsaJWK}, {"RSA 0 HSM", keyvault.RSAHSM, &value3072, "", rsaJWK}, + {"RSA 0 HSM (uri)", keyvault.RSAHSM, &value3072, "", rsaJWK}, {"RSA 2048", keyvault.RSA, &value2048, "", rsaJWK}, {"RSA 3072", keyvault.RSA, &value3072, "", rsaJWK}, {"RSA 4096", keyvault.RSA, &value4096, "", rsaJWK}, @@ -269,6 +271,16 @@ func TestKeyVault_CreateKey(t *testing.T) { SigningKey: "azurekms:name=my-key;vault=my-vault", }, }, false}, + {"ok P-256 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key?hsm=true", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: ecPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, {"ok P-256 Default", fields{client}, args{&apiv1.CreateKeyRequest{ Name: "azurekms:vault=my-vault;name=my-key", }}, &apiv1.CreateKeyResponse{ @@ -322,6 +334,17 @@ func TestKeyVault_CreateKey(t *testing.T) { SigningKey: "azurekms:name=my-key;vault=my-vault", }, }, false}, + {"ok RSA 0 HSM (uri)", fields{client}, args{&apiv1.CreateKeyRequest{ + Name: "azurekms:vault=my-vault;name=my-key;hsm=true", + Bits: 0, + SignatureAlgorithm: apiv1.SHA256WithRSAPSS, + }}, &apiv1.CreateKeyResponse{ + Name: "azurekms:name=my-key;vault=my-vault", + PublicKey: rsaPub, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "azurekms:name=my-key;vault=my-vault", + }, + }, false}, {"ok RSA 2048", fields{client}, args{&apiv1.CreateKeyRequest{ Name: "azurekms:vault=my-vault;name=my-key", Bits: 2048, diff --git a/kms/azurekms/signer.go b/kms/azurekms/signer.go index cb844bdf..cf0197fb 100644 --- a/kms/azurekms/signer.go +++ b/kms/azurekms/signer.go @@ -25,7 +25,7 @@ type Signer struct { // NewSigner creates a new signer using a key in the AWS KMS. func NewSigner(client KeyVaultClient, signingKey string) (crypto.Signer, error) { - vault, name, version, err := parseKeyName(signingKey) + vault, name, version, _, err := parseKeyName(signingKey) if err != nil { return nil, err } diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go index 389f65b3..90740b9f 100644 --- a/kms/azurekms/signer_test.go +++ b/kms/azurekms/signer_test.go @@ -221,6 +221,12 @@ func TestSigner_Sign(t *testing.T) { {"fail sign length", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ Result: &rsaSHA256ResultSig, }, nil}, + {"fail base64", "", keyvault.ES256, p256Digest, keyvault.KeyOperationResult{ + Result: func() *string { + v := "😎" + return &v + }(), + }, nil}, } for _, e := range expects { value := base64.RawURLEncoding.EncodeToString(e.digest) @@ -291,6 +297,9 @@ func TestSigner_Sign(t *testing.T) { {"fail sign length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ rand.Reader, p256Digest[:], crypto.SHA256, }, nil, true}, + {"fail base64", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest[:], crypto.SHA256, + }, nil, true}, {"fail RSA-PSS salt length", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", rsaPSSSHA256}, args{ rand.Reader, rsaPSSSHA256Digest[:], &rsa.PSSOptions{ SaltLength: 64, diff --git a/kms/azurekms/utils.go b/kms/azurekms/utils.go index 6b6d1511..52bed868 100644 --- a/kms/azurekms/utils.go +++ b/kms/azurekms/utils.go @@ -42,11 +42,15 @@ func getKeyName(vault, name string, bundle keyvault.KeyBundle) string { // parseKeyName returns the key vault, name and version from URIs like: // // - azurekms:vault=key-vault;name=key-name -// - azurekms:vault=key-vault;name=key-name;id=key-id +// - azurekms:vault=key-vault;name=key-name?version=key-id +// - azurekms:vault=key-vault;name=key-name?version=key-id&hsm=true // // The key-id defines the version of the key, if it is not passed the latest // version will be used. -func parseKeyName(rawURI string) (vault, name, version string, err error) { +// +// HSM can also be passed to define the protection level if this is not given in +// CreateQuery. +func parseKeyName(rawURI string) (vault, name, version string, hsm bool, err error) { var u *uri.URI u, err = uri.ParseWithScheme("azurekms", rawURI) @@ -63,6 +67,7 @@ func parseKeyName(rawURI string) (vault, name, version string, err error) { return } version = u.Get("version") + hsm = u.GetBool("hsm") return } diff --git a/kms/azurekms/utils_test.go b/kms/azurekms/utils_test.go index 915ee74d..03d3f6e2 100644 --- a/kms/azurekms/utils_test.go +++ b/kms/azurekms/utils_test.go @@ -51,21 +51,24 @@ func Test_parseKeyName(t *testing.T) { wantVault string wantName string wantVersion string + wantHsm bool wantErr bool }{ - {"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version"}, "my-vault", "my-key", "my-version", false}, - {"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version"}, "my-vault", "my-key", "my-version", false}, - {"ok no version", args{"azurekms:name=my-key;vault=my-vault"}, "my-vault", "my-key", "", false}, - {"fail scheme", args{"azure:name=my-key;vault=my-vault"}, "", "", "", true}, - {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, "", "", "", true}, - {"fail no name", args{"azurekms:vault=my-vault"}, "", "", "", true}, - {"fail empty name", args{"azurekms:name=;vault=my-vault"}, "", "", "", true}, - {"fail no vault", args{"azurekms:name=my-key"}, "", "", "", true}, - {"fail empty vault", args{"azurekms:name=my-key;vault="}, "", "", "", true}, + {"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version"}, "my-vault", "my-key", "my-version", false, false}, + {"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version"}, "my-vault", "my-key", "my-version", false, false}, + {"ok no version", args{"azurekms:name=my-key;vault=my-vault"}, "my-vault", "my-key", "", false, false}, + {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, "my-vault", "my-key", "", true, false}, + {"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false"}, "my-vault", "my-key", "", false, false}, + {"fail scheme", args{"azure:name=my-key;vault=my-vault"}, "", "", "", false, true}, + {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, "", "", "", false, true}, + {"fail no name", args{"azurekms:vault=my-vault"}, "", "", "", false, true}, + {"fail empty name", args{"azurekms:name=;vault=my-vault"}, "", "", "", false, true}, + {"fail no vault", args{"azurekms:name=my-key"}, "", "", "", false, true}, + {"fail empty vault", args{"azurekms:name=my-key;vault="}, "", "", "", false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotVault, gotName, gotVersion, err := parseKeyName(tt.args.rawURI) + gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI) if (err != nil) != tt.wantErr { t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr) return @@ -79,6 +82,9 @@ func Test_parseKeyName(t *testing.T) { if gotVersion != tt.wantVersion { t.Errorf("parseKeyName() gotVersion = %v, want %v", gotVersion, tt.wantVersion) } + if gotHsm != tt.wantHsm { + t.Errorf("parseKeyName() gotHsm = %v, want %v", gotHsm, tt.wantHsm) + } }) } }