Add strategy to retry the sign operation if the key is not yet ready
This commit is contained in:
parent
48efd94994
commit
ead394fba7
2 changed files with 170 additions and 7 deletions
|
@ -7,8 +7,10 @@ import (
|
|||
"encoding/base64"
|
||||
"io"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/cryptobyte"
|
||||
"golang.org/x/crypto/cryptobyte/asn1"
|
||||
|
@ -69,15 +71,10 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
||||
|
||||
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
||||
Algorithm: alg,
|
||||
Value: &b64,
|
||||
})
|
||||
// Sign with retry if the key is not ready
|
||||
resp, err := s.signWithRetry(alg, b64, 3)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "keyVault Sign failed")
|
||||
}
|
||||
|
@ -111,6 +108,31 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]
|
|||
return b.Bytes()
|
||||
}
|
||||
|
||||
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttemps int) (keyvault.KeyOperationResult, error) {
|
||||
retry:
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
||||
Algorithm: alg,
|
||||
Value: &b64,
|
||||
})
|
||||
if err != nil && retryAttemps > 0 {
|
||||
var requestError *azure.RequestError
|
||||
if errors.As(err, &requestError) {
|
||||
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
|
||||
code, ok := se.InnerError["code"].(string)
|
||||
if ok && code == "KeyNotYetValid" {
|
||||
time.Sleep(time.Second / time.Duration(retryAttemps))
|
||||
retryAttemps--
|
||||
goto retry
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
||||
switch key.(type) {
|
||||
case *rsa.PublicKey:
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/smallstep/certificates/kms/apiv1"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
|
@ -350,3 +352,142 @@ func TestSigner_Sign(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSigner_Sign_signWithRetry(t *testing.T) {
|
||||
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
||||
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := opts.HashFunc().New()
|
||||
h.Write([]byte("random-data"))
|
||||
sum := h.Sum(nil)
|
||||
|
||||
var sig, resultSig []byte
|
||||
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
||||
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
curveBits := priv.Params().BitSize
|
||||
keyBytes := curveBits / 8
|
||||
if curveBits%8 > 0 {
|
||||
keyBytes++
|
||||
}
|
||||
rBytes := r.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := s.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
// nolint:gocritic
|
||||
resultSig = append(rBytesPadded, sBytesPadded...)
|
||||
|
||||
var b cryptobyte.Builder
|
||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||
b.AddASN1BigInt(r)
|
||||
b.AddASN1BigInt(s)
|
||||
})
|
||||
sig, err = b.Bytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
sig, err = key.Sign(rand.Reader, sum, opts)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resultSig = sig
|
||||
}
|
||||
|
||||
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
||||
}
|
||||
|
||||
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
||||
okResult := keyvault.KeyOperationResult{
|
||||
Result: &p256ResultSig,
|
||||
}
|
||||
failResult := keyvault.KeyOperationResult{}
|
||||
retryError := autorest.DetailedError{
|
||||
Original: &azure.RequestError{
|
||||
ServiceError: &azure.ServiceError{
|
||||
InnerError: map[string]interface{}{
|
||||
"code": "KeyNotYetValid",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := mockClient(t)
|
||||
expects := []struct {
|
||||
name string
|
||||
keyVersion string
|
||||
alg keyvault.JSONWebKeySignatureAlgorithm
|
||||
digest []byte
|
||||
result keyvault.KeyOperationResult
|
||||
err error
|
||||
}{
|
||||
{"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"ok 4", "", keyvault.ES256, p256Digest, okResult, nil},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||
}
|
||||
for _, e := range expects {
|
||||
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
||||
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
||||
Algorithm: e.alg,
|
||||
Value: &value,
|
||||
}).Return(e.result, e.err)
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
client KeyVaultClient
|
||||
vaultBaseURL string
|
||||
name string
|
||||
version string
|
||||
publicKey crypto.PublicKey
|
||||
}
|
||||
type args struct {
|
||||
rand io.Reader
|
||||
digest []byte
|
||||
opts crypto.SignerOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, p256Sig, false},
|
||||
{"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{
|
||||
rand.Reader, p256Digest, crypto.SHA256,
|
||||
}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &Signer{
|
||||
client: tt.fields.client,
|
||||
vaultBaseURL: tt.fields.vaultBaseURL,
|
||||
name: tt.fields.name,
|
||||
version: tt.fields.version,
|
||||
publicKey: tt.fields.publicKey,
|
||||
}
|
||||
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue