From 1d1e09544785d07b9ff4abe95814b00a2a91786d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 8 Apr 2022 13:06:29 -0700 Subject: [PATCH] Add tests for LoadProvisionerByCertificate. --- authority/provisioners.go | 50 +++++------ authority/provisioners_test.go | 147 +++++++++++++++++++++++++++++++++ db/db.go | 12 +++ 3 files changed, 185 insertions(+), 24 deletions(-) diff --git a/authority/provisioners.go b/authority/provisioners.go index 0dab24b9..7ff080ed 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -45,41 +45,43 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List, // LoadProvisionerByCertificate returns an interface to the provisioner that // provisioned the certificate. func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { - // Default implementation looks at the provisioner extension. - loadProvisioner := func() (provisioner.Interface, error) { - p, ok := a.provisioners.LoadByCertificate(crt) - if !ok { - return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") - } + a.adminMutex.RLock() + defer a.adminMutex.RUnlock() + if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil { return p, nil } + return a.unsafeLoadProvisionerFromExtension(crt) +} +func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) { + p, ok := a.provisioners.LoadByCertificate(crt) + if !ok || p.GetType() == 0 { + return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") + } + return p, nil +} + +func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) { // certificateDataGetter is an interface that can be use to retrieve the // provisioner from a db or a linked ca. type certificateDataGetter interface { GetCertificateData(string) (*db.CertificateData, error) } - var cdg certificateDataGetter - if getter, ok := a.adminDB.(certificateDataGetter); ok { - cdg = getter - } else if getter, ok := a.db.(certificateDataGetter); ok { - cdg = getter + + var err error + var data *db.CertificateData + + if cdg, ok := a.adminDB.(certificateDataGetter); ok { + data, err = cdg.GetCertificateData(crt.SerialNumber.String()) + } else if cdg, ok := a.db.(certificateDataGetter); ok { + data, err = cdg.GetCertificateData(crt.SerialNumber.String()) } - if cdg != nil { - if data, err := cdg.GetCertificateData(crt.SerialNumber.String()); err == nil && data.Provisioner != nil { - loadProvisioner = func() (provisioner.Interface, error) { - p, ok := a.provisioners.Load(data.Provisioner.ID) - if !ok { - return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") - } - return p, nil - } + if err == nil && data != nil && data.Provisioner != nil { + if p, ok := a.provisioners.Load(data.Provisioner.ID); ok { + return p, nil } } - - a.adminMutex.RLock() - defer a.adminMutex.RUnlock() - return loadProvisioner() + return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate") } // LoadProvisionerByToken returns an interface to the provisioner that diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index 81dc38bf..56cd16b1 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -1,13 +1,21 @@ package authority import ( + "context" + "crypto/x509" "errors" "net/http" + "reflect" "testing" + "time" "github.com/smallstep/assert" "github.com/smallstep/certificates/api/render" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" ) func TestGetEncryptedKey(t *testing.T) { @@ -67,6 +75,15 @@ func TestGetEncryptedKey(t *testing.T) { } } +type mockAdminDB struct { + admin.MockDB + MGetCertificateData func(string) (*db.CertificateData, error) +} + +func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) { + return c.MGetCertificateData(sn) +} + func TestGetProvisioners(t *testing.T) { type gp struct { a *Authority @@ -104,3 +121,133 @@ func TestGetProvisioners(t *testing.T) { }) } } + +func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { + _, priv, err := keyutil.GenerateDefaultKeyPair() + assert.FatalError(t, err) + csr := getCSR(t, priv) + + sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate { + key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key) + assert.FatalError(t, err) + ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) + opts, err := a.Authorize(ctx, token) + assert.FatalError(t, err) + opts = append(opts, extraOpts...) + certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + assert.FatalError(t, err) + return certs[0] + } + getProvisioner := func(a *Authority, name string) provisioner.Interface { + p, ok := a.provisioners.LoadByName(name) + if !ok { + t.Fatalf("provisioner %s does not exists", name) + } + return p + } + removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { + for i, ext := range cert.ExtraExtensions { + if ext.Id.Equal(provisioner.StepOIDProvisioner) { + cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...) + break + } + } + return nil + }) + + a0 := testAuthority(t) + + a1 := testAuthority(t) + a1.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + p, err := a1.LoadProvisionerByName("dev") + if err != nil { + t.Fatal(err) + } + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: p.GetID(), + Name: p.GetName(), + Type: p.GetType().String(), + }, + }, nil + }, + } + + a2 := testAuthority(t) + a2.adminDB = &mockAdminDB{ + MGetCertificateData: (func(s string) (*db.CertificateData, error) { + p, err := a2.LoadProvisionerByName("dev") + if err != nil { + t.Fatal(err) + } + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: p.GetID(), + Name: p.GetName(), + Type: p.GetType().String(), + }, + }, nil + }), + } + + a3 := testAuthority(t) + a3.db = &db.MockAuthDB{ + MUseToken: func(id, tok string) (bool, error) { + return true, nil + }, + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: "foo", Name: "foo", Type: "foo", + }, + }, nil + }, + } + + a4 := testAuthority(t) + a4.adminDB = &mockAdminDB{ + MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) { + return &db.CertificateData{ + Provisioner: &db.ProvisionerData{ + ID: "foo", Name: "foo", Type: "foo", + }, + }, nil + }, + } + + type args struct { + crt *x509.Certificate + } + tests := []struct { + name string + authority *Authority + args args + want provisioner.Interface + wantErr bool + }{ + {"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false}, + {"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false}, + {"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false}, + {"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true}, + {"fail from db", a3, args{sign(a3, removeExtension)}, nil, true}, + {"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt) + if (err != nil) != tt.wantErr { + t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/db/db.go b/db/db.go index a3ebb19f..602e3623 100644 --- a/db/db.go +++ b/db/db.go @@ -359,6 +359,7 @@ type MockAuthDB struct { MRevoke func(rci *RevokedCertificateInfo) error MRevokeSSH func(rci *RevokedCertificateInfo) error MGetCertificate func(serialNumber string) (*x509.Certificate, error) + MGetCertificateData func(serialNumber string) (*CertificateData, error) MStoreCertificate func(crt *x509.Certificate) error MUseToken func(id, tok string) (bool, error) MIsSSHHost func(principal string) (bool, error) @@ -418,6 +419,17 @@ func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, err return m.Ret1.(*x509.Certificate), m.Err } +// GetCertificateData mock. +func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) { + if m.MGetCertificateData != nil { + return m.MGetCertificateData(serialNumber) + } + if cd, ok := m.Ret1.(*CertificateData); ok { + return cd, m.Err + } + return nil, m.Err +} + // StoreCertificate mock. func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { if m.MStoreCertificate != nil {