From 8957e5e5a24bc46b70d6b90120cebdd80dae1ef4 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 16 Sep 2020 12:34:42 -0700 Subject: [PATCH] Add missing tests --- cas/apiv1/extension_test.go | 95 +++++++++++++++++++++++++- cas/apiv1/options_test.go | 131 ++++++++++++++++++++++++++++++++++++ cas/apiv1/registry_test.go | 90 +++++++++++++++++++++++++ cas/apiv1/services_test.go | 23 +++++++ cas/cas_test.go | 60 +++++++++++++++++ 5 files changed, 397 insertions(+), 2 deletions(-) create mode 100644 cas/apiv1/options_test.go create mode 100644 cas/apiv1/registry_test.go create mode 100644 cas/apiv1/services_test.go create mode 100644 cas/cas_test.go diff --git a/cas/apiv1/extension_test.go b/cas/apiv1/extension_test.go index 113e3de1..7d6fe4dc 100644 --- a/cas/apiv1/extension_test.go +++ b/cas/apiv1/extension_test.go @@ -1,8 +1,8 @@ package apiv1 import ( + "crypto/x509" "crypto/x509/pkix" - "fmt" "reflect" "testing" ) @@ -49,7 +49,98 @@ func TestCreateCertificateAuthorityExtension(t *testing.T) { } if !reflect.DeepEqual(got, tt.want) { t.Errorf("CreateCertificateAuthorityExtension() = %v, want %v", got, tt.want) - fmt.Printf("%x\n", got.Value) + } + }) + } +} + +func TestFindCertificateAuthorityExtension(t *testing.T) { + expected := pkix.Extension{ + Id: oidStepCertificateAuthority, + Value: []byte("fake data"), + } + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want pkix.Extension + want1 bool + }{ + {"first", args{&x509.Certificate{Extensions: []pkix.Extension{ + expected, + {Id: []int{1, 2, 3, 4}}, + }}}, expected, true}, + {"last", args{&x509.Certificate{Extensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + {Id: []int{2, 3, 4, 5}}, + expected, + }}}, expected, true}, + {"fail", args{&x509.Certificate{Extensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + }}}, pkix.Extension{}, false}, + {"fail ExtraExtensions", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ + expected, + {Id: []int{1, 2, 3, 4}}, + }}}, pkix.Extension{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := FindCertificateAuthorityExtension(tt.args.cert) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("FindCertificateAuthorityExtension() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("FindCertificateAuthorityExtension() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestRemoveCertificateAuthorityExtension(t *testing.T) { + caExt := pkix.Extension{ + Id: oidStepCertificateAuthority, + Value: []byte("fake data"), + } + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + args args + want *x509.Certificate + }{ + {"first", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ + caExt, + {Id: []int{1, 2, 3, 4}}, + }}}, &x509.Certificate{ExtraExtensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + }}}, + {"last", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + caExt, + }}}, &x509.Certificate{ExtraExtensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + }}}, + {"missing", args{&x509.Certificate{ExtraExtensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + }}}, &x509.Certificate{ExtraExtensions: []pkix.Extension{ + {Id: []int{1, 2, 3, 4}}, + }}}, + {"extensions", args{&x509.Certificate{Extensions: []pkix.Extension{ + caExt, + {Id: []int{1, 2, 3, 4}}, + }}}, &x509.Certificate{Extensions: []pkix.Extension{ + caExt, + {Id: []int{1, 2, 3, 4}}, + }}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + RemoveCertificateAuthorityExtension(tt.args.cert) + if !reflect.DeepEqual(tt.args.cert, tt.want) { + t.Errorf("RemoveCertificateAuthorityExtension() cert = %v, want %v", tt.args.cert, tt.want) } }) } diff --git a/cas/apiv1/options_test.go b/cas/apiv1/options_test.go new file mode 100644 index 00000000..ddf26f7f --- /dev/null +++ b/cas/apiv1/options_test.go @@ -0,0 +1,131 @@ +package apiv1 + +import ( + "context" + "crypto" + "crypto/x509" + "sync" + "testing" +) + +type testCAS struct { + name string +} + +func (t *testCAS) CreateCertificate(req *CreateCertificateRequest) (*CreateCertificateResponse, error) { + return nil, nil +} + +func (t *testCAS) RenewCertificate(req *RenewCertificateRequest) (*RenewCertificateResponse, error) { + return nil, nil +} + +func (t *testCAS) RevokeCertificate(req *RevokeCertificateRequest) (*RevokeCertificateResponse, error) { + return nil, nil +} + +func mockRegister(t *testing.T) { + t.Helper() + Register(SoftCAS, func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { + return &testCAS{name: SoftCAS}, nil + }) + Register(CloudCAS, func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { + return &testCAS{name: CloudCAS}, nil + }) + t.Cleanup(func() { + registry = new(sync.Map) + }) +} + +func TestOptions_Validate(t *testing.T) { + mockRegister(t) + type fields struct { + Type string + CredentialsFile string + Certificateauthority string + Issuer *x509.Certificate + Signer crypto.Signer + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"empty", fields{}, false}, + {"SoftCAS", fields{SoftCAS, "", "", nil, nil}, false}, + {"CloudCAS", fields{CloudCAS, "", "", nil, nil}, false}, + {"softcas", fields{"softcas", "", "", nil, nil}, false}, + {"CLOUDCAS", fields{"CLOUDCAS", "", "", nil, nil}, false}, + {"fail", fields{"FailCAS", "", "", nil, nil}, true}, + } + t.Run("nil", func(t *testing.T) { + var o *Options + if err := o.Validate(); err != nil { + t.Errorf("Options.Validate() error = %v, wantErr %v", err, false) + } + }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &Options{ + Type: tt.fields.Type, + CredentialsFile: tt.fields.CredentialsFile, + Certificateauthority: tt.fields.Certificateauthority, + Issuer: tt.fields.Issuer, + Signer: tt.fields.Signer, + } + if err := o.Validate(); (err != nil) != tt.wantErr { + t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestOptions_HasType(t *testing.T) { + mockRegister(t) + + type fields struct { + Type string + CredentialsFile string + Certificateauthority string + Issuer *x509.Certificate + Signer crypto.Signer + } + type args struct { + t Type + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {"empty", fields{}, args{}, true}, + {"SoftCAS", fields{SoftCAS, "", "", nil, nil}, args{"SoftCAS"}, true}, + {"CloudCAS", fields{CloudCAS, "", "", nil, nil}, args{"CloudCAS"}, true}, + {"softcas", fields{"softcas", "", "", nil, nil}, args{SoftCAS}, true}, + {"CLOUDCAS", fields{"CLOUDCAS", "", "", nil, nil}, args{CloudCAS}, true}, + {"UnknownCAS", fields{"UnknownCAS", "", "", nil, nil}, args{"UnknownCAS"}, true}, + {"fail", fields{CloudCAS, "", "", nil, nil}, args{"SoftCAS"}, false}, + {"fail", fields{SoftCAS, "", "", nil, nil}, args{"CloudCAS"}, false}, + } + t.Run("nil", func(t *testing.T) { + var o *Options + if got := o.HasType(SoftCAS); got != true { + t.Errorf("Options.HasType() = %v, want %v", got, true) + } + }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &Options{ + Type: tt.fields.Type, + CredentialsFile: tt.fields.CredentialsFile, + Certificateauthority: tt.fields.Certificateauthority, + Issuer: tt.fields.Issuer, + Signer: tt.fields.Signer, + } + if got := o.HasType(tt.args.t); got != tt.want { + t.Errorf("Options.HasType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cas/apiv1/registry_test.go b/cas/apiv1/registry_test.go new file mode 100644 index 00000000..ce510d13 --- /dev/null +++ b/cas/apiv1/registry_test.go @@ -0,0 +1,90 @@ +package apiv1 + +import ( + "context" + "fmt" + "reflect" + "sync" + "testing" +) + +func TestRegister(t *testing.T) { + t.Cleanup(func() { + registry = new(sync.Map) + }) + type args struct { + t Type + fn CertificateAuthorityServiceNewFunc + } + tests := []struct { + name string + args args + want CertificateAuthorityService + wantErr bool + }{ + {"ok", args{"TestCAS", func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { + return &testCAS{}, nil + }}, &testCAS{}, false}, + {"error", args{"ErrorCAS", func(ctx context.Context, opts Options) (CertificateAuthorityService, error) { + return nil, fmt.Errorf("an error") + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Register(tt.args.t, tt.args.fn) + fmt.Println(registry) + fn, ok := registry.LoadAndDelete(tt.args.t.String()) + if !ok { + t.Errorf("Register() failed") + return + } + got, err := fn.(CertificateAuthorityServiceNewFunc)(context.Background(), Options{}) + if (err != nil) != tt.wantErr { + t.Errorf("CertificateAuthorityServiceNewFunc() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CertificateAuthorityServiceNewFunc() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLoadCertificateAuthorityServiceNewFunc(t *testing.T) { + mockRegister(t) + type args struct { + t Type + } + tests := []struct { + name string + args args + want CertificateAuthorityService + wantOk bool + }{ + {"default", args{""}, &testCAS{name: SoftCAS}, true}, + {"SoftCAS", args{"SoftCAS"}, &testCAS{name: SoftCAS}, true}, + {"CloudCAS", args{"CloudCAS"}, &testCAS{name: CloudCAS}, true}, + {"softcas", args{"softcas"}, &testCAS{name: SoftCAS}, true}, + {"cloudcas", args{"cloudcas"}, &testCAS{name: CloudCAS}, true}, + {"FailCAS", args{"FailCAS"}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn, ok := LoadCertificateAuthorityServiceNewFunc(tt.args.t) + if ok != tt.wantOk { + t.Errorf("LoadCertificateAuthorityServiceNewFunc() ok = %v, want %v", ok, tt.wantOk) + return + } + if ok { + got, err := fn(context.Background(), Options{}) + if err != nil { + t.Errorf("CertificateAuthorityServiceNewFunc() error = %v", err) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CertificateAuthorityServiceNewFunc() = %v, want %v", got, tt.want) + } + } + }) + } +} diff --git a/cas/apiv1/services_test.go b/cas/apiv1/services_test.go new file mode 100644 index 00000000..f9ab9042 --- /dev/null +++ b/cas/apiv1/services_test.go @@ -0,0 +1,23 @@ +package apiv1 + +import "testing" + +func TestType_String(t *testing.T) { + tests := []struct { + name string + t Type + want string + }{ + {"default", "", "softcas"}, + {"SoftCAS", SoftCAS, "softcas"}, + {"CloudCAS", CloudCAS, "cloudcas"}, + {"UnknownCAS", "UnknownCAS", "unknowncas"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.t.String(); got != tt.want { + t.Errorf("Type.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cas/cas_test.go b/cas/cas_test.go new file mode 100644 index 00000000..a01e8dab --- /dev/null +++ b/cas/cas_test.go @@ -0,0 +1,60 @@ +package cas + +import ( + "context" + "crypto/ed25519" + "crypto/x509" + "crypto/x509/pkix" + "reflect" + "testing" + + "github.com/smallstep/certificates/cas/softcas" + + "github.com/smallstep/certificates/cas/apiv1" +) + +func TestNew(t *testing.T) { + expected := &softcas.SoftCAS{ + Issuer: &x509.Certificate{Subject: pkix.Name{CommonName: "Test Issuer"}}, + Signer: ed25519.PrivateKey{}, + } + type args struct { + ctx context.Context + opts apiv1.Options + } + tests := []struct { + name string + args args + want CertificateAuthorityService + wantErr bool + }{ + {"ok default", args{context.Background(), apiv1.Options{ + Issuer: &x509.Certificate{Subject: pkix.Name{CommonName: "Test Issuer"}}, + Signer: ed25519.PrivateKey{}, + }}, expected, false}, + {"ok softcas", args{context.Background(), apiv1.Options{ + Type: "softcas", + Issuer: &x509.Certificate{Subject: pkix.Name{CommonName: "Test Issuer"}}, + Signer: ed25519.PrivateKey{}, + }}, expected, false}, + {"ok SoftCAS", args{context.Background(), apiv1.Options{ + Type: "SoftCAS", + Issuer: &x509.Certificate{Subject: pkix.Name{CommonName: "Test Issuer"}}, + Signer: ed25519.PrivateKey{}, + }}, expected, false}, + {"fail empty", args{context.Background(), apiv1.Options{}}, (*softcas.SoftCAS)(nil), true}, + {"fail type", args{context.Background(), apiv1.Options{Type: "FailCAS"}}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := New(tt.args.ctx, tt.args.opts) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %#v, want %v", got, tt.want) + } + }) + } +}