diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index a2332d5f..d434ce48 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -3,6 +3,7 @@ package cloudkms import ( "context" "crypto" + "log" "strings" "time" @@ -18,6 +19,8 @@ import ( kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1" ) +const pendingGenerationRetries = 10 + // protectionLevelMapping maps step protection levels with cloud kms ones. var protectionLevelMapping = map[apiv1.ProtectionLevel]kmspb.ProtectionLevel{ apiv1.UnspecifiedProtectionLevel: kmspb.ProtectionLevel_PROTECTION_LEVEL_UNSPECIFIED, @@ -189,6 +192,12 @@ func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo crytoKeyName = response.Name + "/cryptoKeyVersions/1" } + // Sleep deterministically to avoid retries because of PENDING_GENERATING. + // One second is often enough. + if protectionLevel == kmspb.ProtectionLevel_HSM { + time.Sleep(1 * time.Second) + } + // Retrieve public key to add it to the response. pk, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{ Name: crytoKeyName, @@ -237,12 +246,7 @@ func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKe return nil, errors.New("createKeyRequest 'name' cannot be empty") } - ctx, cancel := defaultContext() - defer cancel() - - response, err := k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{ - Name: req.Name, - }) + response, err := k.getPublicKeyWithRetries(req.Name, pendingGenerationRetries) if err != nil { return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed") } @@ -255,6 +259,30 @@ func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKe return pk, nil } +// getPublicKeyWithRetries retries the request if the error is +// FailedPrecondition, caused because the key is in the PENDING_GENERATION +// status. +func (k *CloudKMS) getPublicKeyWithRetries(name string, retries int) (response *kmspb.PublicKey, err error) { + workFn := func() (*kmspb.PublicKey, error) { + ctx, cancel := defaultContext() + defer cancel() + return k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{ + Name: name, + }) + } + for i := 0; i < retries; i++ { + if response, err = workFn(); err == nil { + return + } + if status.Code(err) == codes.FailedPrecondition { + log.Println("Waiting for key generation ...") + time.Sleep(time.Duration(i+1) * time.Second) + continue + } + } + return +} + func defaultContext() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 15*time.Second) } diff --git a/kms/cloudkms/cloudkms_test.go b/kms/cloudkms/cloudkms_test.go index b4f92fa6..c5eba318 100644 --- a/kms/cloudkms/cloudkms_test.go +++ b/kms/cloudkms/cloudkms_test.go @@ -175,6 +175,7 @@ func TestCloudKMS_CreateKey(t *testing.T) { t.Fatal(err) } + var retries int type fields struct { client KeyManagementClient } @@ -236,6 +237,24 @@ func TestCloudKMS_CreateKey(t *testing.T) { }}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, &apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/2", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/2"}}, false}, + {"ok with retries", fields{ + &MockClient{ + getKeyRing: func(_ context.Context, _ *kmspb.GetKeyRingRequest, _ ...gax.CallOption) (*kmspb.KeyRing, error) { + return &kmspb.KeyRing{}, nil + }, + createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) { + return &kmspb.CryptoKey{Name: keyName}, nil + }, + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + if retries != 2 { + retries++ + return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION") + } + return &kmspb.PublicKey{Pem: string(pemBytes)}, nil + }, + }}, + args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.HSM, SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, + &apiv1.CreateKeyResponse{Name: keyName + "/cryptoKeyVersions/1", PublicKey: pk, CreateSignerRequest: apiv1.CreateSignerRequest{SigningKey: keyName + "/cryptoKeyVersions/1"}}, false}, {"fail name", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{}}, nil, true}, {"fail protection level", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.ProtectionLevel(100)}}, nil, true}, {"fail signature algorithm", fields{&MockClient{}}, args{&apiv1.CreateKeyRequest{Name: keyName, ProtectionLevel: apiv1.Software, SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, nil, true}, @@ -322,6 +341,7 @@ func TestCloudKMS_GetPublicKey(t *testing.T) { t.Fatal(err) } + var retries int type fields struct { client KeyManagementClient } @@ -342,6 +362,17 @@ func TestCloudKMS_GetPublicKey(t *testing.T) { }, }}, args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false}, + {"ok with retries", fields{ + &MockClient{ + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + if retries != 2 { + retries++ + return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION") + } + return &kmspb.PublicKey{Pem: string(pemBytes)}, nil + }, + }}, + args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false}, {"fail name", fields{&MockClient{}}, args{&apiv1.GetPublicKeyRequest{}}, nil, true}, {"fail get public key", fields{ &MockClient{