diff --git a/kms/azurekms/signer.go b/kms/azurekms/signer.go index e3aca5fe..2fb5951a 100644 --- a/kms/azurekms/signer.go +++ b/kms/azurekms/signer.go @@ -7,8 +7,10 @@ import ( "encoding/base64" "io" "math/big" + "time" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest/azure" "github.com/pkg/errors" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" @@ -69,15 +71,10 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([] return nil, err } - ctx, cancel := defaultContext() - defer cancel() - b64 := base64.RawURLEncoding.EncodeToString(digest) - resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{ - Algorithm: alg, - Value: &b64, - }) + // Sign with retry if the key is not ready + resp, err := s.signWithRetry(alg, b64, 3) if err != nil { return nil, errors.Wrap(err, "keyVault Sign failed") } @@ -111,6 +108,31 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([] return b.Bytes() } +func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttemps int) (keyvault.KeyOperationResult, error) { +retry: + ctx, cancel := defaultContext() + defer cancel() + + resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{ + Algorithm: alg, + Value: &b64, + }) + if err != nil && retryAttemps > 0 { + var requestError *azure.RequestError + if errors.As(err, &requestError) { + if se := requestError.ServiceError; se != nil && se.InnerError != nil { + code, ok := se.InnerError["code"].(string) + if ok && code == "KeyNotYetValid" { + time.Sleep(time.Second / time.Duration(retryAttemps)) + retryAttemps-- + goto retry + } + } + } + } + return resp, err +} + func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) { switch key.(type) { case *rsa.PublicKey: diff --git a/kms/azurekms/signer_test.go b/kms/azurekms/signer_test.go index 381c3577..bd072b25 100644 --- a/kms/azurekms/signer_test.go +++ b/kms/azurekms/signer_test.go @@ -11,6 +11,8 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" + "github.com/Azure/go-autorest/autorest" + "github.com/Azure/go-autorest/autorest/azure" "github.com/golang/mock/gomock" "github.com/smallstep/certificates/kms/apiv1" "go.step.sm/crypto/keyutil" @@ -350,3 +352,142 @@ func TestSigner_Sign(t *testing.T) { }) } } + +func TestSigner_Sign_signWithRetry(t *testing.T) { + sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) { + key, err := keyutil.GenerateSigner(kty, crv, bits) + if err != nil { + t.Fatal(err) + } + h := opts.HashFunc().New() + h.Write([]byte("random-data")) + sum := h.Sum(nil) + + var sig, resultSig []byte + if priv, ok := key.(*ecdsa.PrivateKey); ok { + r, s, err := ecdsa.Sign(rand.Reader, priv, sum) + if err != nil { + t.Fatal(err) + } + curveBits := priv.Params().BitSize + keyBytes := curveBits / 8 + if curveBits%8 > 0 { + keyBytes++ + } + rBytes := r.Bytes() + rBytesPadded := make([]byte, keyBytes) + copy(rBytesPadded[keyBytes-len(rBytes):], rBytes) + + sBytes := s.Bytes() + sBytesPadded := make([]byte, keyBytes) + copy(sBytesPadded[keyBytes-len(sBytes):], sBytes) + // nolint:gocritic + resultSig = append(rBytesPadded, sBytesPadded...) + + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(r) + b.AddASN1BigInt(s) + }) + sig, err = b.Bytes() + if err != nil { + t.Fatal(err) + } + } else { + sig, err = key.Sign(rand.Reader, sum, opts) + if err != nil { + t.Fatal(err) + } + resultSig = sig + } + + return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig + } + + p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256) + okResult := keyvault.KeyOperationResult{ + Result: &p256ResultSig, + } + failResult := keyvault.KeyOperationResult{} + retryError := autorest.DetailedError{ + Original: &azure.RequestError{ + ServiceError: &azure.ServiceError{ + InnerError: map[string]interface{}{ + "code": "KeyNotYetValid", + }, + }, + }, + } + + client := mockClient(t) + expects := []struct { + name string + keyVersion string + alg keyvault.JSONWebKeySignatureAlgorithm + digest []byte + result keyvault.KeyOperationResult + err error + }{ + {"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError}, + {"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError}, + {"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError}, + {"ok 4", "", keyvault.ES256, p256Digest, okResult, nil}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + {"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError}, + } + for _, e := range expects { + value := base64.RawURLEncoding.EncodeToString(e.digest) + client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{ + Algorithm: e.alg, + Value: &value, + }).Return(e.result, e.err) + } + + type fields struct { + client KeyVaultClient + vaultBaseURL string + name string + version string + publicKey crypto.PublicKey + } + type args struct { + rand io.Reader + digest []byte + opts crypto.SignerOpts + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, p256Sig, false}, + {"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{ + rand.Reader, p256Digest, crypto.SHA256, + }, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + client: tt.fields.client, + vaultBaseURL: tt.fields.vaultBaseURL, + name: tt.fields.name, + version: tt.fields.version, + publicKey: tt.fields.publicKey, + } + got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Sign() = %v, want %v", got, tt.want) + } + }) + } +}