diff --git a/go.mod b/go.mod index f3d2e358..791e0927 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.15 require ( cloud.google.com/go v0.83.0 github.com/Azure/azure-sdk-for-go v58.0.0+incompatible + github.com/Azure/go-autorest/autorest v0.11.17 github.com/Azure/go-autorest/autorest/azure/auth v0.5.8 github.com/Azure/go-autorest/autorest/date v0.3.0 github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect diff --git a/kms/azurekms/key_vault.go b/kms/azurekms/key_vault.go index 93be8241..34d9c3f1 100644 --- a/kms/azurekms/key_vault.go +++ b/kms/azurekms/key_vault.go @@ -7,10 +7,12 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/azure/auth" "github.com/Azure/go-autorest/autorest/date" "github.com/pkg/errors" "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/certificates/kms/uri" ) func init() { @@ -126,9 +128,60 @@ type KeyVaultClient interface { // functionality in /sdk/keyvault, we should migrate to that once available. type KeyVault struct { baseClient KeyVaultClient + defaults DefaultOptions +} + +// DefaultOptions are custom options that can be passed as defaults using the +// URI in apiv1.Options. +type DefaultOptions struct { + Vault string + ProtectionLevel apiv1.ProtectionLevel } var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + baseClient := keyvault.New() + + // With an URI, try to log in only using client credentials in the URI. + // Client credentials requires: + // - client-id + // - client-secret + // - tenant-id + // And optionally the aad-endpoint to support custom clouds: + // - aad-endpoint (defaults to https://login.microsoftonline.com/) + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, err + } + + // Required options + clientID := u.Get("client-id") + clientSecret := u.Get("client-secret") + tenantID := u.Get("tenant-id") + // optional + aadEndpoint := u.Get("aad-endpoint") + + if clientID != "" && clientSecret != "" && tenantID != "" { + s := auth.EnvironmentSettings{ + Values: map[string]string{ + auth.ClientID: clientID, + auth.ClientSecret: clientSecret, + auth.TenantID: tenantID, + auth.Resource: vaultResource, + }, + Environment: azure.PublicCloud, + } + if aadEndpoint != "" { + s.Environment.ActiveDirectoryEndpoint = aadEndpoint + } + baseClient.Authorizer, err = s.GetAuthorizer() + if err != nil { + return nil, err + } + return baseClient, nil + } + } + // Attempt to authorize with the following methods: // 1. Environment variables. // - Client credentials @@ -143,8 +196,6 @@ var createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient return nil, errors.Wrap(err, "error getting authorizer for key vault") } } - - baseClient := keyvault.New() baseClient.Authorizer = authorizer return &baseClient, nil } @@ -155,8 +206,24 @@ func New(ctx context.Context, opts apiv1.Options) (*KeyVault, error) { if err != nil { return nil, err } + + // step and step-ca do not need and URI, but having a default vault and + // protection level is useful if this package is used as an api + var defaults DefaultOptions + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, err + } + defaults.Vault = u.Get("vault") + if u.GetBool("hsm") { + defaults.ProtectionLevel = apiv1.HSM + } + } + return &KeyVault{ baseClient: baseClient, + defaults: defaults, }, nil } @@ -166,7 +233,7 @@ func (k *KeyVault) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKe return nil, errors.New("getPublicKeyRequest 'name' cannot be empty") } - vault, name, version, _, err := parseKeyName(req.Name) + vault, name, version, _, err := parseKeyName(req.Name, k.defaults) if err != nil { return nil, err } @@ -188,7 +255,7 @@ func (k *KeyVault) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo return nil, errors.New("createKeyRequest 'name' cannot be empty") } - vault, name, _, hsm, err := parseKeyName(req.Name) + vault, name, _, hsm, err := parseKeyName(req.Name, k.defaults) if err != nil { return nil, err } @@ -260,7 +327,7 @@ func (k *KeyVault) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, if req.SigningKey == "" { return nil, errors.New("createSignerRequest 'signingKey' cannot be empty") } - return NewSigner(k.baseClient, req.SigningKey) + return NewSigner(k.baseClient, req.SigningKey, k.defaults) } // Close closes the client connection to the Azure Key Vault. This is a noop. @@ -270,6 +337,6 @@ func (k *KeyVault) Close() error { // ValidateName validates that the given string is a valid URI. func (k *KeyVault) ValidateName(s string) error { - _, _, _, _, err := parseKeyName(s) + _, _, _, _, err := parseKeyName(s, k.defaults) return err } diff --git a/kms/azurekms/key_vault_test.go b/kms/azurekms/key_vault_test.go index 1f26e1ef..8f968189 100644 --- a/kms/azurekms/key_vault_test.go +++ b/kms/azurekms/key_vault_test.go @@ -89,11 +89,44 @@ func TestNew(t *testing.T) { }, args{context.Background(), apiv1.Options{}}, &KeyVault{ baseClient: client, }, false}, + {"ok with vault", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "azurekms:vault=my-vault", + }}, &KeyVault{ + baseClient: client, + defaults: DefaultOptions{ + Vault: "my-vault", + ProtectionLevel: apiv1.UnspecifiedProtectionLevel, + }, + }, false}, + {"ok with vault + hsm", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "azurekms:vault=my-vault;hsm=true", + }}, &KeyVault{ + baseClient: client, + defaults: DefaultOptions{ + Vault: "my-vault", + ProtectionLevel: apiv1.HSM, + }, + }, false}, {"fail", func() { createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { return nil, errTest } }, args{context.Background(), apiv1.Options{}}, nil, true}, + {"fail uri", func() { + createClient = func(ctx context.Context, opts apiv1.Options) (KeyVaultClient, error) { + return client, nil + } + }, args{context.Background(), apiv1.Options{ + URI: "kms:vault=my-vault;hsm=true", + }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -110,6 +143,45 @@ func TestNew(t *testing.T) { } } +func TestKeyVault_createClient(t *testing.T) { + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + args args + skip bool + wantErr bool + }{ + {"ok", args{context.Background(), apiv1.Options{}}, true, false}, + {"ok with uri", args{context.Background(), apiv1.Options{ + URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id", + }}, false, false}, + {"ok with uri+aad", args{context.Background(), apiv1.Options{ + URI: "azurekms:client-id=id;client-secret=secret;tenant-id=id;aad-enpoint=https%3A%2F%2Flogin.microsoftonline.us%2F", + }}, false, false}, + {"ok with uri no config", args{context.Background(), apiv1.Options{ + URI: "azurekms:", + }}, true, false}, + {"fail uri", args{context.Background(), apiv1.Options{ + URI: "kms:client-id=id;client-secret=secret;tenant-id=id", + }}, false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skip { + t.SkipNow() + } + _, err := createClient(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + func TestKeyVault_GetPublicKey(t *testing.T) { key, err := keyutil.GenerateDefaultSigner() if err != nil { diff --git a/kms/azurekms/signer.go b/kms/azurekms/signer.go index 405c625a..e3aca5fe 100644 --- a/kms/azurekms/signer.go +++ b/kms/azurekms/signer.go @@ -24,8 +24,8 @@ type Signer struct { } // NewSigner creates a new signer using a key in the AWS KMS. -func NewSigner(client KeyVaultClient, signingKey string) (crypto.Signer, error) { - vault, name, version, _, err := parseKeyName(signingKey) +func NewSigner(client KeyVaultClient, signingKey string, defaults DefaultOptions) (crypto.Signer, error) { + vault, name, version, _, err := parseKeyName(signingKey, defaults) if err != nil { return nil, err } diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go index 01921e2a..381c3577 100644 --- a/kms/azurekms/signer_test.go +++ b/kms/azurekms/signer_test.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/golang/mock/gomock" + "github.com/smallstep/certificates/kms/apiv1" "go.step.sm/crypto/keyutil" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" @@ -32,11 +33,16 @@ func TestNewSigner(t *testing.T) { client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ Key: jwk, }, nil) + client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", "my-version").Return(keyvault.KeyBundle{ + Key: jwk, + }, nil) client.EXPECT().GetKey(gomock.Any(), "https://my-vault.vault.azure.net/", "not-found", "my-version").Return(keyvault.KeyBundle{}, errTest) + var noOptions DefaultOptions type args struct { client KeyVaultClient signingKey string + defaults DefaultOptions } tests := []struct { name string @@ -44,28 +50,35 @@ func TestNewSigner(t *testing.T) { want crypto.Signer wantErr bool }{ - {"ok", args{client, "azurekms:vault=my-vault;name=my-key"}, &Signer{ + {"ok", args{client, "azurekms:vault=my-vault;name=my-key", noOptions}, &Signer{ client: client, vaultBaseURL: "https://my-vault.vault.azure.net/", name: "my-key", version: "", publicKey: pub, }, false}, - {"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version"}, &Signer{ + {"ok with version", args{client, "azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, &Signer{ client: client, vaultBaseURL: "https://my-vault.vault.azure.net/", name: "my-key", version: "my-version", publicKey: pub, }, false}, - {"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version"}, nil, true}, - {"fail vault", args{client, "azurekms:name=not-found;vault="}, nil, true}, - {"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version"}, nil, true}, - {"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version"}, nil, true}, + {"ok with options", args{client, "azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault", ProtectionLevel: apiv1.HSM}}, &Signer{ + client: client, + vaultBaseURL: "https://my-vault.vault.azure.net/", + name: "my-key", + version: "my-version", + publicKey: pub, + }, false}, + {"fail GetKey", args{client, "azurekms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true}, + {"fail vault", args{client, "azurekms:name=not-found;vault=", noOptions}, nil, true}, + {"fail id", args{client, "azurekms:name=;vault=my-vault?version=my-version", noOptions}, nil, true}, + {"fail scheme", args{client, "kms:name=not-found;vault=my-vault?version=my-version", noOptions}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewSigner(tt.args.client, tt.args.signingKey) + got, err := NewSigner(tt.args.client, tt.args.signingKey, tt.args.defaults) if (err != nil) != tt.wantErr { t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/kms/azurekms/utils.go b/kms/azurekms/utils.go index 52bed868..d4201907 100644 --- a/kms/azurekms/utils.go +++ b/kms/azurekms/utils.go @@ -9,6 +9,7 @@ import ( "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/pkg/errors" + "github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/certificates/kms/uri" "go.step.sm/crypto/jose" ) @@ -50,10 +51,10 @@ func getKeyName(vault, name string, bundle keyvault.KeyBundle) string { // // HSM can also be passed to define the protection level if this is not given in // CreateQuery. -func parseKeyName(rawURI string) (vault, name, version string, hsm bool, err error) { +func parseKeyName(rawURI string, defaults DefaultOptions) (vault, name, version string, hsm bool, err error) { var u *uri.URI - u, err = uri.ParseWithScheme("azurekms", rawURI) + u, err = uri.ParseWithScheme(Scheme, rawURI) if err != nil { return } @@ -62,12 +63,21 @@ func parseKeyName(rawURI string) (vault, name, version string, hsm bool, err err return } if vault = u.Get("vault"); vault == "" { - err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI) - name = "" - return + if defaults.Vault == "" { + name = "" + err = errors.Errorf("key uri %s is not valid: vault is missing", rawURI) + return + } + vault = defaults.Vault } + if u.Get("hsm") == "" { + hsm = (defaults.ProtectionLevel == apiv1.HSM) + } else { + hsm = u.GetBool("hsm") + } + version = u.Get("version") - hsm = u.GetBool("hsm") + return } diff --git a/kms/azurekms/utils_test.go b/kms/azurekms/utils_test.go index 000a9d6b..cded50ea 100644 --- a/kms/azurekms/utils_test.go +++ b/kms/azurekms/utils_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/smallstep/certificates/kms/apiv1" ) func Test_getKeyName(t *testing.T) { @@ -42,8 +43,10 @@ func Test_getKeyName(t *testing.T) { } func Test_parseKeyName(t *testing.T) { + var noOptions DefaultOptions type args struct { - rawURI string + rawURI string + defaults DefaultOptions } tests := []struct { name string @@ -54,22 +57,24 @@ func Test_parseKeyName(t *testing.T) { wantHsm bool wantErr bool }{ - {"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version"}, "my-vault", "my-key", "my-version", false, false}, - {"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version"}, "my-vault", "my-key", "my-version", false, false}, - {"ok no version", args{"azurekms:name=my-key;vault=my-vault"}, "my-vault", "my-key", "", false, false}, - {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true"}, "my-vault", "my-key", "", true, false}, - {"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false"}, "my-vault", "my-key", "", false, false}, - {"fail scheme", args{"azure:name=my-key;vault=my-vault"}, "", "", "", false, true}, - {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault"}, "", "", "", false, true}, - {"fail no name", args{"azurekms:vault=my-vault"}, "", "", "", false, true}, - {"fail empty name", args{"azurekms:name=;vault=my-vault"}, "", "", "", false, true}, - {"fail no vault", args{"azurekms:name=my-key"}, "", "", "", false, true}, - {"fail empty vault", args{"azurekms:name=my-key;vault="}, "", "", "", false, true}, - {"fail empty", args{""}, "", "", "", false, true}, + {"ok", args{"azurekms:name=my-key;vault=my-vault?version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false}, + {"ok opaque version", args{"azurekms:name=my-key;vault=my-vault;version=my-version", noOptions}, "my-vault", "my-key", "my-version", false, false}, + {"ok no version", args{"azurekms:name=my-key;vault=my-vault", noOptions}, "my-vault", "my-key", "", false, false}, + {"ok hsm", args{"azurekms:name=my-key;vault=my-vault?hsm=true", noOptions}, "my-vault", "my-key", "", true, false}, + {"ok hsm false", args{"azurekms:name=my-key;vault=my-vault?hsm=false", noOptions}, "my-vault", "my-key", "", false, false}, + {"ok default vault", args{"azurekms:name=my-key?version=my-version", DefaultOptions{Vault: "my-vault"}}, "my-vault", "my-key", "my-version", false, false}, + {"ok default hsm", args{"azurekms:name=my-key;vault=my-vault?version=my-version", DefaultOptions{Vault: "other-vault", ProtectionLevel: apiv1.HSM}}, "my-vault", "my-key", "my-version", true, false}, + {"fail scheme", args{"azure:name=my-key;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail parse uri", args{"azurekms:name=%ZZ;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail no name", args{"azurekms:vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail empty name", args{"azurekms:name=;vault=my-vault", noOptions}, "", "", "", false, true}, + {"fail no vault", args{"azurekms:name=my-key", noOptions}, "", "", "", false, true}, + {"fail empty vault", args{"azurekms:name=my-key;vault=", noOptions}, "", "", "", false, true}, + {"fail empty", args{"", noOptions}, "", "", "", false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI) + gotVault, gotName, gotVersion, gotHsm, err := parseKeyName(tt.args.rawURI, tt.args.defaults) if (err != nil) != tt.wantErr { t.Errorf("parseKeyName() error = %v, wantErr %v", err, tt.wantErr) return