From 163eb7029cbf47c602172fb08c0975c67cfe63b0 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 19 Feb 2021 15:36:55 -0800 Subject: [PATCH] Refactor cloudkms signer to return an error on the constructor. --- kms/cloudkms/cloudkms.go | 14 +----- kms/cloudkms/cloudkms_test.go | 24 ++++++++++- kms/cloudkms/signer.go | 28 +++++++----- kms/cloudkms/signer_test.go | 80 ++++++++++++++++++++--------------- 4 files changed, 86 insertions(+), 60 deletions(-) diff --git a/kms/cloudkms/cloudkms.go b/kms/cloudkms/cloudkms.go index 83bd167c..cfbf8235 100644 --- a/kms/cloudkms/cloudkms.go +++ b/kms/cloudkms/cloudkms.go @@ -140,19 +140,7 @@ func (k *CloudKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, if req.SigningKey == "" { return nil, errors.New("signing key cannot be empty") } - - // Validate that the key exists - ctx, cancel := defaultContext() - defer cancel() - - _, err := k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{ - Name: req.SigningKey, - }) - if err != nil { - return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed") - } - - return NewSigner(k.client, req.SigningKey), nil + return NewSigner(k.client, req.SigningKey) } // CreateKey creates in Google's Cloud KMS a new asymmetric key for signing. diff --git a/kms/cloudkms/cloudkms_test.go b/kms/cloudkms/cloudkms_test.go index e04e0198..fefa6e2a 100644 --- a/kms/cloudkms/cloudkms_test.go +++ b/kms/cloudkms/cloudkms_test.go @@ -165,6 +165,15 @@ func TestCloudKMS_Close(t *testing.T) { func TestCloudKMS_CreateSigner(t *testing.T) { keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" + pemBytes, err := ioutil.ReadFile("testdata/pub.pem") + if err != nil { + t.Fatal(err) + } + pk, err := pemutil.ParseKey(pemBytes) + if err != nil { + t.Fatal(err) + } + type fields struct { client KeyManagementClient } @@ -178,8 +187,16 @@ func TestCloudKMS_CreateSigner(t *testing.T) { want crypto.Signer wantErr bool }{ - {"ok", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName}, false}, - {"fail", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true}, + {"ok", fields{&MockClient{ + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + return &kmspb.PublicKey{Pem: string(pemBytes)}, nil + }, + }}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false}, + {"fail", fields{&MockClient{ + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + return nil, fmt.Errorf("test error") + }, + }}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -191,6 +208,9 @@ func TestCloudKMS_CreateSigner(t *testing.T) { t.Errorf("CloudKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) return } + if signer, ok := got.(*Signer); ok { + signer.client = &MockClient{} + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CloudKMS.CreateSigner() = %v, want %v", got, tt.want) } diff --git a/kms/cloudkms/signer.go b/kms/cloudkms/signer.go index 303c2496..686aca25 100644 --- a/kms/cloudkms/signer.go +++ b/kms/cloudkms/signer.go @@ -13,33 +13,41 @@ import ( type Signer struct { client KeyManagementClient signingKey string + publicKey crypto.PublicKey } -func NewSigner(c KeyManagementClient, signingKey string) *Signer { - return &Signer{ +// NewSigner creates a new crypto.Signer the given CloudKMS signing key. +func NewSigner(c KeyManagementClient, signingKey string) (*Signer, error) { + // Make sure that the key exists. + signer := &Signer{ client: c, signingKey: signingKey, } + if err := signer.preloadKey(signingKey); err != nil { + return nil, err + } + + return signer, nil } -// Public returns the public key of this signer or an error. -func (s *Signer) Public() crypto.PublicKey { +func (s *Signer) preloadKey(signingKey string) error { ctx, cancel := defaultContext() defer cancel() response, err := s.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{ - Name: s.signingKey, + Name: signingKey, }) if err != nil { return errors.Wrap(err, "cloudKMS GetPublicKey failed") } - pk, err := pemutil.ParseKey([]byte(response.Pem)) - if err != nil { - return err - } + s.publicKey, err = pemutil.ParseKey([]byte(response.Pem)) + return err +} - return pk +// Public returns the public key of this signer or an error. +func (s *Signer) Public() crypto.PublicKey { + return s.publicKey } // Sign signs digest with the private key stored in Google's Cloud KMS. diff --git a/kms/cloudkms/signer_test.go b/kms/cloudkms/signer_test.go index dec176f4..fa730fe3 100644 --- a/kms/cloudkms/signer_test.go +++ b/kms/cloudkms/signer_test.go @@ -16,30 +16,59 @@ import ( ) func Test_newSigner(t *testing.T) { + pemBytes, err := ioutil.ReadFile("testdata/pub.pem") + if err != nil { + t.Fatal(err) + } + pk, err := pemutil.ParseKey(pemBytes) + if err != nil { + t.Fatal(err) + } + type args struct { c KeyManagementClient signingKey string } tests := []struct { - name string - args args - want *Signer + name string + args args + want *Signer + wantErr bool }{ - {"ok", args{&MockClient{}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey"}}, + {"ok", args{&MockClient{ + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + return &kmspb.PublicKey{Pem: string(pemBytes)}, nil + }, + }, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey", publicKey: pk}, false}, + {"fail get public key", args{&MockClient{ + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + return nil, fmt.Errorf("an error") + }, + }, "signingKey"}, nil, true}, + {"fail parse pem", args{&MockClient{ + getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { + return &kmspb.PublicKey{Pem: string("bad pem")}, nil + }, + }, "signingKey"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := NewSigner(tt.args.c, tt.args.signingKey); !reflect.DeepEqual(got, tt.want) { - t.Errorf("newSigner() = %v, want %v", got, tt.want) + got, err := NewSigner(tt.args.c, tt.args.signingKey) + if (err != nil) != tt.wantErr { + t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil { + got.client = &MockClient{} + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSigner() = %v, want %v", got, tt.want) } }) } } func Test_signer_Public(t *testing.T) { - keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" - testError := fmt.Errorf("an error") - pemBytes, err := ioutil.ReadFile("testdata/pub.pem") if err != nil { t.Fatal(err) @@ -52,42 +81,23 @@ func Test_signer_Public(t *testing.T) { type fields struct { client KeyManagementClient signingKey string + publicKey crypto.PublicKey } tests := []struct { - name string - fields fields - want crypto.PublicKey - wantErr bool + name string + fields fields + want crypto.PublicKey }{ - {"ok", fields{&MockClient{ - getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { - return &kmspb.PublicKey{Pem: string(pemBytes)}, nil - }, - }, keyName}, pk, false}, - {"fail get public key", fields{&MockClient{ - getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { - return nil, testError - }, - }, keyName}, nil, true}, - {"fail parse pem", fields{ - &MockClient{ - getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) { - return &kmspb.PublicKey{Pem: string("bad pem")}, nil - }, - }, keyName}, nil, true}, + {"ok", fields{&MockClient{}, "signingKey", pk}, pk}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Signer{ client: tt.fields.client, signingKey: tt.fields.signingKey, + publicKey: tt.fields.publicKey, } - got := s.Public() - if _, ok := got.(error); ok != tt.wantErr { - t.Errorf("signer.Public() error = %v, wantErr %v", got, tt.wantErr) - return - } - if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + if got := s.Public(); !reflect.DeepEqual(got, tt.want) { t.Errorf("signer.Public() = %v, want %v", got, tt.want) } })