Add strategy to retry the sign operation if the key is not yet ready

This commit is contained in:
Mariano Cano 2021-10-20 18:09:50 -07:00
parent 48efd94994
commit ead394fba7
2 changed files with 170 additions and 7 deletions

View file

@ -7,8 +7,10 @@ import (
"encoding/base64" "encoding/base64"
"io" "io"
"math/big" "math/big"
"time"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1" "golang.org/x/crypto/cryptobyte/asn1"
@ -69,15 +71,10 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]
return nil, err return nil, err
} }
ctx, cancel := defaultContext()
defer cancel()
b64 := base64.RawURLEncoding.EncodeToString(digest) b64 := base64.RawURLEncoding.EncodeToString(digest)
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{ // Sign with retry if the key is not ready
Algorithm: alg, resp, err := s.signWithRetry(alg, b64, 3)
Value: &b64,
})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "keyVault Sign failed") return nil, errors.Wrap(err, "keyVault Sign failed")
} }
@ -111,6 +108,31 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]
return b.Bytes() return b.Bytes()
} }
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttemps int) (keyvault.KeyOperationResult, error) {
retry:
ctx, cancel := defaultContext()
defer cancel()
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
Algorithm: alg,
Value: &b64,
})
if err != nil && retryAttemps > 0 {
var requestError *azure.RequestError
if errors.As(err, &requestError) {
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
code, ok := se.InnerError["code"].(string)
if ok && code == "KeyNotYetValid" {
time.Sleep(time.Second / time.Duration(retryAttemps))
retryAttemps--
goto retry
}
}
}
}
return resp, err
}
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) { func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
switch key.(type) { switch key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:

View file

@ -11,6 +11,8 @@ import (
"testing" "testing"
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault" "github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/smallstep/certificates/kms/apiv1" "github.com/smallstep/certificates/kms/apiv1"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
@ -350,3 +352,142 @@ func TestSigner_Sign(t *testing.T) {
}) })
} }
} }
func TestSigner_Sign_signWithRetry(t *testing.T) {
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
key, err := keyutil.GenerateSigner(kty, crv, bits)
if err != nil {
t.Fatal(err)
}
h := opts.HashFunc().New()
h.Write([]byte("random-data"))
sum := h.Sum(nil)
var sig, resultSig []byte
if priv, ok := key.(*ecdsa.PrivateKey); ok {
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
if err != nil {
t.Fatal(err)
}
curveBits := priv.Params().BitSize
keyBytes := curveBits / 8
if curveBits%8 > 0 {
keyBytes++
}
rBytes := r.Bytes()
rBytesPadded := make([]byte, keyBytes)
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
sBytes := s.Bytes()
sBytesPadded := make([]byte, keyBytes)
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
// nolint:gocritic
resultSig = append(rBytesPadded, sBytesPadded...)
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1BigInt(r)
b.AddASN1BigInt(s)
})
sig, err = b.Bytes()
if err != nil {
t.Fatal(err)
}
} else {
sig, err = key.Sign(rand.Reader, sum, opts)
if err != nil {
t.Fatal(err)
}
resultSig = sig
}
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
}
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
okResult := keyvault.KeyOperationResult{
Result: &p256ResultSig,
}
failResult := keyvault.KeyOperationResult{}
retryError := autorest.DetailedError{
Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
InnerError: map[string]interface{}{
"code": "KeyNotYetValid",
},
},
},
}
client := mockClient(t)
expects := []struct {
name string
keyVersion string
alg keyvault.JSONWebKeySignatureAlgorithm
digest []byte
result keyvault.KeyOperationResult
err error
}{
{"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError},
{"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError},
{"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError},
{"ok 4", "", keyvault.ES256, p256Digest, okResult, nil},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
}
for _, e := range expects {
value := base64.RawURLEncoding.EncodeToString(e.digest)
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
Algorithm: e.alg,
Value: &value,
}).Return(e.result, e.err)
}
type fields struct {
client KeyVaultClient
vaultBaseURL string
name string
version string
publicKey crypto.PublicKey
}
type args struct {
rand io.Reader
digest []byte
opts crypto.SignerOpts
}
tests := []struct {
name string
fields fields
args args
want []byte
wantErr bool
}{
{"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, p256Sig, false},
{"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{
rand.Reader, p256Digest, crypto.SHA256,
}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &Signer{
client: tt.fields.client,
vaultBaseURL: tt.fields.vaultBaseURL,
name: tt.fields.name,
version: tt.fields.version,
publicKey: tt.fields.publicKey,
}
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
if (err != nil) != tt.wantErr {
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
}
})
}
}