Make Signer public and add contructor NewCloudKMS.

This commit is contained in:
Mariano Cano 2020-01-21 19:09:21 -08:00
parent 5d5ee68d88
commit fa8116497c
4 changed files with 64 additions and 34 deletions

View file

@ -56,7 +56,9 @@ var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]interface{}{
apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384, apiv1.ECDSAWithSHA384: kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384,
} }
type keyManagementClient interface { // KeyManagementClient defines the methods on KeyManagementClient that this
// package will use. This interface will be used for unit testing.
type KeyManagementClient interface {
Close() error Close() error
GetPublicKey(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error) GetPublicKey(context.Context, *kmspb.GetPublicKeyRequest, ...gax.CallOption) (*kmspb.PublicKey, error)
AsymmetricSign(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) AsymmetricSign(context.Context, *kmspb.AsymmetricSignRequest, ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error)
@ -68,9 +70,10 @@ type keyManagementClient interface {
// CloudKMS implements a KMS using Google's Cloud apiv1. // CloudKMS implements a KMS using Google's Cloud apiv1.
type CloudKMS struct { type CloudKMS struct {
Client keyManagementClient client KeyManagementClient
} }
// New creates a new CloudKMS configured with a new client.
func New(ctx context.Context, opts apiv1.Options) (*CloudKMS, error) { func New(ctx context.Context, opts apiv1.Options) (*CloudKMS, error) {
var cloudOpts []option.ClientOption var cloudOpts []option.ClientOption
if opts.CredentialsFile != "" { if opts.CredentialsFile != "" {
@ -83,13 +86,20 @@ func New(ctx context.Context, opts apiv1.Options) (*CloudKMS, error) {
} }
return &CloudKMS{ return &CloudKMS{
Client: client, client: client,
}, nil }, nil
} }
// NewCloudKMS creates a CloudKMS with a given client.
func NewCloudKMS(client KeyManagementClient) *CloudKMS {
return &CloudKMS{
client: client,
}
}
// Close closes the connection of the Cloud KMS client. // Close closes the connection of the Cloud KMS client.
func (k *CloudKMS) Close() error { func (k *CloudKMS) Close() error {
if err := k.Client.Close(); err != nil { if err := k.client.Close(); err != nil {
return errors.Wrap(err, "cloudKMS Close failed") return errors.Wrap(err, "cloudKMS Close failed")
} }
return nil return nil
@ -102,7 +112,7 @@ func (k *CloudKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer,
return nil, errors.New("signing key cannot be empty") return nil, errors.New("signing key cannot be empty")
} }
return newSigner(k.Client, req.SigningKey), nil return NewSigner(k.client, req.SigningKey), nil
} }
// CreateKey creates in Google's Cloud KMS a new asymmetric key for signing. // CreateKey creates in Google's Cloud KMS a new asymmetric key for signing.
@ -145,7 +155,7 @@ func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo
defer cancel() defer cancel()
// Create private key in CloudKMS. // Create private key in CloudKMS.
response, err := k.Client.CreateCryptoKey(ctx, &kmspb.CreateCryptoKeyRequest{ response, err := k.client.CreateCryptoKey(ctx, &kmspb.CreateCryptoKeyRequest{
Parent: keyRing, Parent: keyRing,
CryptoKeyId: keyID, CryptoKeyId: keyID,
CryptoKey: &kmspb.CryptoKey{ CryptoKey: &kmspb.CryptoKey{
@ -170,7 +180,7 @@ func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo
State: kmspb.CryptoKeyVersion_ENABLED, State: kmspb.CryptoKeyVersion_ENABLED,
}, },
} }
response, err := k.Client.CreateCryptoKeyVersion(ctx, req) response, err := k.client.CreateCryptoKeyVersion(ctx, req)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKeyVersion failed") return nil, errors.Wrap(err, "cloudKMS CreateCryptoKeyVersion failed")
} }
@ -200,7 +210,7 @@ func (k *CloudKMS) createKeyRingIfNeeded(name string) error {
ctx, cancel := defaultContext() ctx, cancel := defaultContext()
defer cancel() defer cancel()
_, err := k.Client.GetKeyRing(ctx, &kmspb.GetKeyRingRequest{ _, err := k.client.GetKeyRing(ctx, &kmspb.GetKeyRingRequest{
Name: name, Name: name,
}) })
if err == nil { if err == nil {
@ -208,7 +218,7 @@ func (k *CloudKMS) createKeyRingIfNeeded(name string) error {
} }
parent, child := Parent(name) parent, child := Parent(name)
_, err = k.Client.CreateKeyRing(ctx, &kmspb.CreateKeyRingRequest{ _, err = k.client.CreateKeyRing(ctx, &kmspb.CreateKeyRingRequest{
Parent: parent, Parent: parent,
KeyRingId: child, KeyRingId: child,
}) })
@ -230,7 +240,7 @@ func (k *CloudKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKe
ctx, cancel := defaultContext() ctx, cancel := defaultContext()
defer cancel() defer cancel()
response, err := k.Client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{ response, err := k.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: req.Name, Name: req.Name,
}) })
if err != nil { if err != nil {

View file

@ -76,9 +76,29 @@ func TestNew(t *testing.T) {
} }
} }
func TestNewCloudKMS(t *testing.T) {
type args struct {
client KeyManagementClient
}
tests := []struct {
name string
args args
want *CloudKMS
}{
{"ok", args{&MockClient{}}, &CloudKMS{&MockClient{}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewCloudKMS(tt.args.client); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCloudKMS() = %v, want %v", got, tt.want)
}
})
}
}
func TestCloudKMS_Close(t *testing.T) { func TestCloudKMS_Close(t *testing.T) {
type fields struct { type fields struct {
client keyManagementClient client KeyManagementClient
} }
tests := []struct { tests := []struct {
name string name string
@ -91,7 +111,7 @@ func TestCloudKMS_Close(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{ k := &CloudKMS{
Client: tt.fields.client, client: tt.fields.client,
} }
if err := k.Close(); (err != nil) != tt.wantErr { if err := k.Close(); (err != nil) != tt.wantErr {
t.Errorf("CloudKMS.Close() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("CloudKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
@ -103,7 +123,7 @@ func TestCloudKMS_Close(t *testing.T) {
func TestCloudKMS_CreateSigner(t *testing.T) { func TestCloudKMS_CreateSigner(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
type fields struct { type fields struct {
client keyManagementClient client KeyManagementClient
} }
type args struct { type args struct {
req *apiv1.CreateSignerRequest req *apiv1.CreateSignerRequest
@ -115,13 +135,13 @@ func TestCloudKMS_CreateSigner(t *testing.T) {
want crypto.Signer want crypto.Signer
wantErr bool wantErr bool
}{ }{
{"ok", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &signer{client: &MockClient{}, signingKey: keyName}, false}, {"ok", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName}, false},
{"fail", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true}, {"fail", fields{&MockClient{}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{ k := &CloudKMS{
Client: tt.fields.client, client: tt.fields.client,
} }
got, err := k.CreateSigner(tt.args.req) got, err := k.CreateSigner(tt.args.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -150,7 +170,7 @@ func TestCloudKMS_CreateKey(t *testing.T) {
} }
type fields struct { type fields struct {
client keyManagementClient client KeyManagementClient
} }
type args struct { type args struct {
req *apiv1.CreateKeyRequest req *apiv1.CreateKeyRequest
@ -269,7 +289,7 @@ func TestCloudKMS_CreateKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{ k := &CloudKMS{
Client: tt.fields.client, client: tt.fields.client,
} }
got, err := k.CreateKey(tt.args.req) got, err := k.CreateKey(tt.args.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@ -297,7 +317,7 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
} }
type fields struct { type fields struct {
client keyManagementClient client KeyManagementClient
} }
type args struct { type args struct {
req *apiv1.GetPublicKeyRequest req *apiv1.GetPublicKeyRequest
@ -335,7 +355,7 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
k := &CloudKMS{ k := &CloudKMS{
Client: tt.fields.client, client: tt.fields.client,
} }
got, err := k.GetPublicKey(tt.args.req) got, err := k.GetPublicKey(tt.args.req)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {

View file

@ -9,21 +9,21 @@ import (
kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1" kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
) )
// signer implements a crypto.Signer using Google's Cloud KMS. // Signer implements a crypto.Signer using Google's Cloud KMS.
type signer struct { type Signer struct {
client keyManagementClient client KeyManagementClient
signingKey string signingKey string
} }
func newSigner(c keyManagementClient, signingKey string) *signer { func NewSigner(c KeyManagementClient, signingKey string) *Signer {
return &signer{ return &Signer{
client: c, client: c,
signingKey: signingKey, signingKey: signingKey,
} }
} }
// Public returns the public key of this signer or an error. // Public returns the public key of this signer or an error.
func (s *signer) Public() crypto.PublicKey { func (s *Signer) Public() crypto.PublicKey {
ctx, cancel := defaultContext() ctx, cancel := defaultContext()
defer cancel() defer cancel()
@ -43,7 +43,7 @@ func (s *signer) Public() crypto.PublicKey {
} }
// Sign signs digest with the private key stored in Google's Cloud KMS. // Sign signs digest with the private key stored in Google's Cloud KMS.
func (s *signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { func (s *Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
req := &kmspb.AsymmetricSignRequest{ req := &kmspb.AsymmetricSignRequest{
Name: s.signingKey, Name: s.signingKey,
Digest: &kmspb.Digest{}, Digest: &kmspb.Digest{},

View file

@ -17,19 +17,19 @@ import (
func Test_newSigner(t *testing.T) { func Test_newSigner(t *testing.T) {
type args struct { type args struct {
c keyManagementClient c KeyManagementClient
signingKey string signingKey string
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *signer want *Signer
}{ }{
{"ok", args{&MockClient{}, "signingKey"}, &signer{client: &MockClient{}, signingKey: "signingKey"}}, {"ok", args{&MockClient{}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey"}},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := newSigner(tt.args.c, tt.args.signingKey); !reflect.DeepEqual(got, tt.want) { if got := NewSigner(tt.args.c, tt.args.signingKey); !reflect.DeepEqual(got, tt.want) {
t.Errorf("newSigner() = %v, want %v", got, tt.want) t.Errorf("newSigner() = %v, want %v", got, tt.want)
} }
}) })
@ -50,7 +50,7 @@ func Test_signer_Public(t *testing.T) {
} }
type fields struct { type fields struct {
client keyManagementClient client KeyManagementClient
signingKey string signingKey string
} }
tests := []struct { tests := []struct {
@ -78,7 +78,7 @@ func Test_signer_Public(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &signer{ s := &Signer{
client: tt.fields.client, client: tt.fields.client,
signingKey: tt.fields.signingKey, signingKey: tt.fields.signingKey,
} }
@ -108,7 +108,7 @@ func Test_signer_Sign(t *testing.T) {
} }
type fields struct { type fields struct {
client keyManagementClient client KeyManagementClient
signingKey string signingKey string
} }
type args struct { type args struct {
@ -131,7 +131,7 @@ func Test_signer_Sign(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &signer{ s := &Signer{
client: tt.fields.client, client: tt.fields.client,
signingKey: tt.fields.signingKey, signingKey: tt.fields.signingKey,
} }