From aaf71ce66a7f52c87ca4c4e1cd530a0e66168eb2 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 20 May 2020 17:04:01 -0700 Subject: [PATCH] Add unit tests for awskms. --- kms/awskms/awskms_test.go | 358 ++++++++++++++++++++++++++++++++++++++ kms/awskms/mock_test.go | 72 ++++++++ kms/awskms/signer.go | 19 +- kms/awskms/signer_test.go | 191 ++++++++++++++++++++ 4 files changed, 631 insertions(+), 9 deletions(-) create mode 100644 kms/awskms/awskms_test.go create mode 100644 kms/awskms/mock_test.go create mode 100644 kms/awskms/signer_test.go diff --git a/kms/awskms/awskms_test.go b/kms/awskms/awskms_test.go new file mode 100644 index 00000000..f19e1c49 --- /dev/null +++ b/kms/awskms/awskms_test.go @@ -0,0 +1,358 @@ +package awskms + +import ( + "context" + "crypto" + "fmt" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kms" + "github.com/smallstep/certificates/kms/apiv1" + "github.com/smallstep/cli/crypto/pemutil" +) + +func TestNew(t *testing.T) { + ctx := context.Background() + + sess, err := session.NewSessionWithOptions(session.Options{}) + if err != nil { + t.Fatal(err) + } + expected := &KMS{ + session: sess, + service: kms.New(sess), + } + + // This will force an error in the session creation. + // It does not fail with missing credentials. + forceError := func(t *testing.T) { + key := "AWS_CA_BUNDLE" + value := os.Getenv(key) + os.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt")) + t.Cleanup(func() { + if value == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + }) + } + + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + args args + want *KMS + wantErr bool + }{ + {"ok", args{ctx, apiv1.Options{}}, expected, false}, + {"ok with options", args{ctx, apiv1.Options{ + Region: "us-east-1", + Profile: "smallstep", + CredentialsFile: "~/aws/credentials", + }}, expected, false}, + {"fail", args{ctx, apiv1.Options{}}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Force an error in the session loading + if tt.wantErr { + forceError(t) + } + + got, err := New(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err != nil { + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %#v, want %#v", got, tt.want) + } + } else { + if got.session == nil || got.service == nil { + t.Errorf("New() = %#v, want %#v", got, tt.want) + } + } + }) + } +} + +func TestKMS_GetPublicKey(t *testing.T) { + okClient := getOKClient() + key, err := pemutil.ParseKey([]byte(publicKey)) + if err != nil { + t.Fatal(err) + } + + type fields struct { + session *session.Session + service KeyManagementClient + } + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.PublicKey + wantErr bool + }{ + {"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{ + Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + }}, key, false}, + {"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true}, + {"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{ + Name: "awskms:key-id=", + }}, nil, true}, + {"fail getPublicKey", fields{nil, &MockClient{ + getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + return nil, fmt.Errorf("an error") + }, + }}, args{&apiv1.GetPublicKeyRequest{ + Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + }}, nil, true}, + {"fail not der", fields{nil, &MockClient{ + getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + return &kms.GetPublicKeyOutput{ + KeyId: input.KeyId, + PublicKey: []byte(publicKey), + }, nil + }, + }}, args{&apiv1.GetPublicKeyRequest{ + Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KMS{ + session: tt.fields.session, + service: tt.fields.service, + } + got, err := k.GetPublicKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KMS.GetPublicKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKMS_CreateKey(t *testing.T) { + okClient := getOKClient() + key, err := pemutil.ParseKey([]byte(publicKey)) + if err != nil { + t.Fatal(err) + } + + type fields struct { + session *session.Session + service KeyManagementClient + } + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + fields fields + args args + want *apiv1.CreateKeyResponse + wantErr bool + }{ + {"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, &apiv1.CreateKeyResponse{ + Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + PublicKey: key, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + }, + }, false}, + {"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }}, &apiv1.CreateKeyResponse{ + Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + PublicKey: key, + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + }, + }, false}, + {"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true}, + {"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.PureEd25519, + }}, nil, true}, + {"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 1234, + }}, nil, true}, + {"fail createKey", fields{nil, &MockClient{ + createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) { + return nil, fmt.Errorf("an error") + }, + createAliasWithContext: okClient.createAliasWithContext, + getPublicKeyWithContext: okClient.getPublicKeyWithContext, + }}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail createAlias", fields{nil, &MockClient{ + createKeyWithContext: okClient.createKeyWithContext, + createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) { + return nil, fmt.Errorf("an error") + }, + getPublicKeyWithContext: okClient.getPublicKeyWithContext, + }}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail getPublicKey", fields{nil, &MockClient{ + createKeyWithContext: okClient.createKeyWithContext, + createAliasWithContext: okClient.createAliasWithContext, + getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + return nil, fmt.Errorf("an error") + }, + }}, args{&apiv1.CreateKeyRequest{ + Name: "root", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KMS{ + session: tt.fields.session, + service: tt.fields.service, + } + got, err := k.CreateKey(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KMS.CreateKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKMS_CreateSigner(t *testing.T) { + client := getOKClient() + key, err := pemutil.ParseKey([]byte(publicKey)) + if err != nil { + t.Fatal(err) + } + + type fields struct { + session *session.Session + service KeyManagementClient + } + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.Signer + wantErr bool + }{ + {"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{ + SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", + }}, &Signer{ + service: client, + keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936", + publicKey: key, + }, false}, + {"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true}, + {"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KMS{ + session: tt.fields.session, + service: tt.fields.service, + } + got, err := k.CreateSigner(tt.args.req) + if (err != nil) != tt.wantErr { + t.Errorf("KMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("KMS.CreateSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKMS_Close(t *testing.T) { + type fields struct { + session *session.Session + service KeyManagementClient + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{nil, getOKClient()}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &KMS{ + session: tt.fields.session, + service: tt.fields.service, + } + if err := k.Close(); (err != nil) != tt.wantErr { + t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_parseKeyID(t *testing.T) { + type args struct { + name string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"ok uri", args{"awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false}, + {"ok key id", args{"be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", false}, + {"ok arn", args{"arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, "arn:aws:kms:us-east-1:123456789:key/be468355-ca7a-40d9-a28b-8ae1c4c7f936", false}, + {"fail parse", args{"awskms:key-id=%ZZ"}, "", true}, + {"fail empty key", args{"awskms:key-id="}, "", true}, + {"fail missing", args{"awskms:foo=bar"}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseKeyID(tt.args.name) + if (err != nil) != tt.wantErr { + t.Errorf("parseKeyID() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("parseKeyID() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kms/awskms/mock_test.go b/kms/awskms/mock_test.go new file mode 100644 index 00000000..ba35d87a --- /dev/null +++ b/kms/awskms/mock_test.go @@ -0,0 +1,72 @@ +package awskms + +import ( + "encoding/pem" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/kms" +) + +type MockClient struct { + getPublicKeyWithContext func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) + createKeyWithContext func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) + createAliasWithContext func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) + signWithContext func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) +} + +func (m *MockClient) GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + return m.getPublicKeyWithContext(ctx, input, opts...) +} + +func (m *MockClient) CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) { + return m.createKeyWithContext(ctx, input, opts...) +} + +func (m *MockClient) CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) { + return m.createAliasWithContext(ctx, input, opts...) +} + +func (m *MockClient) SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) { + return m.signWithContext(ctx, input, opts...) +} + +const ( + publicKey = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE8XWlIWkOThxNjGbZLYUgRHmsvCrW +KF+HLktPfPTIK3lGd1k4849WQs59XIN+LXZQ6b2eRBEBKAHEyQus8UU7gw== +-----END PUBLIC KEY-----` + keyId = "be468355-ca7a-40d9-a28b-8ae1c4c7f936" +) + +var signature = []byte{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, +} + +func getOKClient() *MockClient { + return &MockClient{ + getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + block, _ := pem.Decode([]byte(publicKey)) + return &kms.GetPublicKeyOutput{ + KeyId: input.KeyId, + PublicKey: block.Bytes, + }, nil + }, + createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) { + md := new(kms.KeyMetadata) + md.SetKeyId(keyId) + return &kms.CreateKeyOutput{ + KeyMetadata: md, + }, nil + }, + createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) { + return &kms.CreateAliasOutput{}, nil + }, + signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) { + return &kms.SignOutput{ + Signature: signature, + }, nil + }, + } +} diff --git a/kms/awskms/signer.go b/kms/awskms/signer.go index aa1eb26c..3d9767d0 100644 --- a/kms/awskms/signer.go +++ b/kms/awskms/signer.go @@ -11,6 +11,7 @@ import ( "github.com/smallstep/cli/crypto/pemutil" ) +// Signer implements a crypto.Signer using the AWS KMS. type Signer struct { service KeyManagementClient keyID string @@ -88,30 +89,30 @@ func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (string, switch h := opts.HashFunc(); h { case crypto.SHA256: if isPSS { - return "RSASSA_PSS_SHA_256", nil + return kms.SigningAlgorithmSpecRsassaPssSha256, nil } - return "RSASSA_PKCS1_V1_5_SHA_256", nil + return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil case crypto.SHA384: if isPSS { - return "RSASSA_PSS_SHA_384", nil + return kms.SigningAlgorithmSpecRsassaPssSha384, nil } - return "RSASSA_PKCS1_V1_5_SHA_384", nil + return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil case crypto.SHA512: if isPSS { - return "RSASSA_PSS_SHA_512", nil + return kms.SigningAlgorithmSpecRsassaPssSha512, nil } - return "RSASSA_PKCS1_V1_5_SHA_512", nil + return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil default: return "", errors.Errorf("unsupported hash function %v", h) } case *ecdsa.PublicKey: switch h := opts.HashFunc(); h { case crypto.SHA256: - return "ECDSA_SHA_256", nil + return kms.SigningAlgorithmSpecEcdsaSha256, nil case crypto.SHA384: - return "ECDSA_SHA_384", nil + return kms.SigningAlgorithmSpecEcdsaSha384, nil case crypto.SHA512: - return "ECDSA_SHA_512", nil + return kms.SigningAlgorithmSpecEcdsaSha512, nil default: return "", errors.Errorf("unsupported hash function %v", h) } diff --git a/kms/awskms/signer_test.go b/kms/awskms/signer_test.go new file mode 100644 index 00000000..51915174 --- /dev/null +++ b/kms/awskms/signer_test.go @@ -0,0 +1,191 @@ +package awskms + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "fmt" + "io" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/kms" + "github.com/smallstep/cli/crypto/pemutil" +) + +func TestNewSigner(t *testing.T) { + okClient := getOKClient() + key, err := pemutil.ParseKey([]byte(publicKey)) + if err != nil { + t.Fatal(err) + } + + type args struct { + svc KeyManagementClient + signingKey string + } + tests := []struct { + name string + args args + want *Signer + wantErr bool + }{ + {"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{ + service: okClient, + keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936", + publicKey: key, + }, false}, + {"fail parse", args{okClient, "awskms:key-id="}, nil, true}, + {"fail preload", args{&MockClient{ + getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + return nil, fmt.Errorf("an error") + }, + }, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true}, + {"fail preload not der", args{&MockClient{ + getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + return &kms.GetPublicKeyOutput{ + KeyId: input.KeyId, + PublicKey: []byte(publicKey), + }, nil + }, + }, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewSigner(tt.args.svc, tt.args.signingKey) + if (err != nil) != tt.wantErr { + t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSigner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Public(t *testing.T) { + okClient := getOKClient() + key, err := pemutil.ParseKey([]byte(publicKey)) + if err != nil { + t.Fatal(err) + } + + type fields struct { + service KeyManagementClient + keyID string + publicKey crypto.PublicKey + } + tests := []struct { + name string + fields fields + want crypto.PublicKey + }{ + {"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, key}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + service: tt.fields.service, + keyID: tt.fields.keyID, + publicKey: tt.fields.publicKey, + } + if got := s.Public(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Public() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSigner_Sign(t *testing.T) { + okClient := getOKClient() + key, err := pemutil.ParseKey([]byte(publicKey)) + if err != nil { + t.Fatal(err) + } + + type fields struct { + service KeyManagementClient + keyID string + publicKey crypto.PublicKey + } + type args struct { + rand io.Reader + digest []byte + opts crypto.SignerOpts + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"ok", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, signature, false}, + {"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true}, + {"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true}, + {"fail sign", fields{&MockClient{ + signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) { + return nil, fmt.Errorf("an error") + }, + }, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &Signer{ + service: tt.fields.service, + keyID: tt.fields.keyID, + publicKey: tt.fields.publicKey, + } + got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("Signer.Sign() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Signer.Sign() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getSigningAlgorithm(t *testing.T) { + type args struct { + key crypto.PublicKey + opts crypto.SignerOpts + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false}, + {"rsa+sha384", args{&rsa.PublicKey{}, crypto.SHA384}, "RSASSA_PKCS1_V1_5_SHA_384", false}, + {"rsa+sha512", args{&rsa.PublicKey{}, crypto.SHA512}, "RSASSA_PKCS1_V1_5_SHA_512", false}, + {"pssrsa+sha256", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA256.HashFunc()}}, "RSASSA_PSS_SHA_256", false}, + {"pssrsa+sha384", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA384.HashFunc()}}, "RSASSA_PSS_SHA_384", false}, + {"pssrsa+sha512", args{&rsa.PublicKey{}, &rsa.PSSOptions{Hash: crypto.SHA512.HashFunc()}}, "RSASSA_PSS_SHA_512", false}, + {"P256", args{&ecdsa.PublicKey{}, crypto.SHA256}, "ECDSA_SHA_256", false}, + {"P384", args{&ecdsa.PublicKey{}, crypto.SHA384}, "ECDSA_SHA_384", false}, + {"P521", args{&ecdsa.PublicKey{}, crypto.SHA512}, "ECDSA_SHA_512", false}, + {"fail type", args{[]byte("key"), crypto.SHA256}, "", true}, + {"fail rsa alg", args{&rsa.PublicKey{}, crypto.MD5}, "", true}, + {"fail ecdsa alg", args{&ecdsa.PublicKey{}, crypto.MD5}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getSigningAlgorithm(tt.args.key, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("getSigningAlgorithm() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getSigningAlgorithm() = %v, want %v", got, tt.want) + } + }) + } +}