Add strategy to retry the sign operation if the key is not yet ready
This commit is contained in:
parent
48efd94994
commit
ead394fba7
2 changed files with 170 additions and 7 deletions
|
@ -7,8 +7,10 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||||
|
"github.com/Azure/go-autorest/autorest/azure"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
"golang.org/x/crypto/cryptobyte/asn1"
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
|
@ -69,15 +71,10 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := defaultContext()
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
b64 := base64.RawURLEncoding.EncodeToString(digest)
|
||||||
|
|
||||||
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
// Sign with retry if the key is not ready
|
||||||
Algorithm: alg,
|
resp, err := s.signWithRetry(alg, b64, 3)
|
||||||
Value: &b64,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "keyVault Sign failed")
|
return nil, errors.Wrap(err, "keyVault Sign failed")
|
||||||
}
|
}
|
||||||
|
@ -111,6 +108,31 @@ func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]
|
||||||
return b.Bytes()
|
return b.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Signer) signWithRetry(alg keyvault.JSONWebKeySignatureAlgorithm, b64 string, retryAttemps int) (keyvault.KeyOperationResult, error) {
|
||||||
|
retry:
|
||||||
|
ctx, cancel := defaultContext()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := s.client.Sign(ctx, s.vaultBaseURL, s.name, s.version, keyvault.KeySignParameters{
|
||||||
|
Algorithm: alg,
|
||||||
|
Value: &b64,
|
||||||
|
})
|
||||||
|
if err != nil && retryAttemps > 0 {
|
||||||
|
var requestError *azure.RequestError
|
||||||
|
if errors.As(err, &requestError) {
|
||||||
|
if se := requestError.ServiceError; se != nil && se.InnerError != nil {
|
||||||
|
code, ok := se.InnerError["code"].(string)
|
||||||
|
if ok && code == "KeyNotYetValid" {
|
||||||
|
time.Sleep(time.Second / time.Duration(retryAttemps))
|
||||||
|
retryAttemps--
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (keyvault.JSONWebKeySignatureAlgorithm, error) {
|
||||||
switch key.(type) {
|
switch key.(type) {
|
||||||
case *rsa.PublicKey:
|
case *rsa.PublicKey:
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
"github.com/Azure/azure-sdk-for-go/services/keyvault/v7.1/keyvault"
|
||||||
|
"github.com/Azure/go-autorest/autorest"
|
||||||
|
"github.com/Azure/go-autorest/autorest/azure"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/smallstep/certificates/kms/apiv1"
|
"github.com/smallstep/certificates/kms/apiv1"
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
|
@ -350,3 +352,142 @@ func TestSigner_Sign(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSigner_Sign_signWithRetry(t *testing.T) {
|
||||||
|
sign := func(kty, crv string, bits int, opts crypto.SignerOpts) (crypto.PublicKey, []byte, string, []byte) {
|
||||||
|
key, err := keyutil.GenerateSigner(kty, crv, bits)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
h := opts.HashFunc().New()
|
||||||
|
h.Write([]byte("random-data"))
|
||||||
|
sum := h.Sum(nil)
|
||||||
|
|
||||||
|
var sig, resultSig []byte
|
||||||
|
if priv, ok := key.(*ecdsa.PrivateKey); ok {
|
||||||
|
r, s, err := ecdsa.Sign(rand.Reader, priv, sum)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
curveBits := priv.Params().BitSize
|
||||||
|
keyBytes := curveBits / 8
|
||||||
|
if curveBits%8 > 0 {
|
||||||
|
keyBytes++
|
||||||
|
}
|
||||||
|
rBytes := r.Bytes()
|
||||||
|
rBytesPadded := make([]byte, keyBytes)
|
||||||
|
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||||
|
|
||||||
|
sBytes := s.Bytes()
|
||||||
|
sBytesPadded := make([]byte, keyBytes)
|
||||||
|
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||||
|
// nolint:gocritic
|
||||||
|
resultSig = append(rBytesPadded, sBytesPadded...)
|
||||||
|
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||||
|
b.AddASN1BigInt(r)
|
||||||
|
b.AddASN1BigInt(s)
|
||||||
|
})
|
||||||
|
sig, err = b.Bytes()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sig, err = key.Sign(rand.Reader, sum, opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resultSig = sig
|
||||||
|
}
|
||||||
|
|
||||||
|
return key.Public(), h.Sum(nil), base64.RawURLEncoding.EncodeToString(resultSig), sig
|
||||||
|
}
|
||||||
|
|
||||||
|
p256, p256Digest, p256ResultSig, p256Sig := sign("EC", "P-256", 0, crypto.SHA256)
|
||||||
|
okResult := keyvault.KeyOperationResult{
|
||||||
|
Result: &p256ResultSig,
|
||||||
|
}
|
||||||
|
failResult := keyvault.KeyOperationResult{}
|
||||||
|
retryError := autorest.DetailedError{
|
||||||
|
Original: &azure.RequestError{
|
||||||
|
ServiceError: &azure.ServiceError{
|
||||||
|
InnerError: map[string]interface{}{
|
||||||
|
"code": "KeyNotYetValid",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client := mockClient(t)
|
||||||
|
expects := []struct {
|
||||||
|
name string
|
||||||
|
keyVersion string
|
||||||
|
alg keyvault.JSONWebKeySignatureAlgorithm
|
||||||
|
digest []byte
|
||||||
|
result keyvault.KeyOperationResult
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"ok 1", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
{"ok 2", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
{"ok 3", "", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
{"ok 4", "", keyvault.ES256, p256Digest, okResult, nil},
|
||||||
|
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
{"fail", "fail-version", keyvault.ES256, p256Digest, failResult, retryError},
|
||||||
|
}
|
||||||
|
for _, e := range expects {
|
||||||
|
value := base64.RawURLEncoding.EncodeToString(e.digest)
|
||||||
|
client.EXPECT().Sign(gomock.Any(), "https://my-vault.vault.azure.net/", "my-key", e.keyVersion, keyvault.KeySignParameters{
|
||||||
|
Algorithm: e.alg,
|
||||||
|
Value: &value,
|
||||||
|
}).Return(e.result, e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
client KeyVaultClient
|
||||||
|
vaultBaseURL string
|
||||||
|
name string
|
||||||
|
version string
|
||||||
|
publicKey crypto.PublicKey
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
rand io.Reader
|
||||||
|
digest []byte
|
||||||
|
opts crypto.SignerOpts
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{client, "https://my-vault.vault.azure.net/", "my-key", "", p256}, args{
|
||||||
|
rand.Reader, p256Digest, crypto.SHA256,
|
||||||
|
}, p256Sig, false},
|
||||||
|
{"fail", fields{client, "https://my-vault.vault.azure.net/", "my-key", "fail-version", p256}, args{
|
||||||
|
rand.Reader, p256Digest, crypto.SHA256,
|
||||||
|
}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := &Signer{
|
||||||
|
client: tt.fields.client,
|
||||||
|
vaultBaseURL: tt.fields.vaultBaseURL,
|
||||||
|
name: tt.fields.name,
|
||||||
|
version: tt.fields.version,
|
||||||
|
publicKey: tt.fields.publicKey,
|
||||||
|
}
|
||||||
|
got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Signer.Sign() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue