Add tests for LoadProvisionerByCertificate.

This commit is contained in:
Mariano Cano 2022-04-08 13:06:29 -07:00
parent e53bd64861
commit 1d1e095447
3 changed files with 185 additions and 24 deletions

View file

@ -45,41 +45,43 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
// LoadProvisionerByCertificate returns an interface to the provisioner that // LoadProvisionerByCertificate returns an interface to the provisioner that
// provisioned the certificate. // provisioned the certificate.
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) { func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
// Default implementation looks at the provisioner extension. a.adminMutex.RLock()
loadProvisioner := func() (provisioner.Interface, error) { defer a.adminMutex.RUnlock()
p, ok := a.provisioners.LoadByCertificate(crt) if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil {
if !ok {
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
}
return p, nil return p, nil
} }
return a.unsafeLoadProvisionerFromExtension(crt)
}
func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) {
p, ok := a.provisioners.LoadByCertificate(crt)
if !ok || p.GetType() == 0 {
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
}
return p, nil
}
func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) {
// certificateDataGetter is an interface that can be use to retrieve the // certificateDataGetter is an interface that can be use to retrieve the
// provisioner from a db or a linked ca. // provisioner from a db or a linked ca.
type certificateDataGetter interface { type certificateDataGetter interface {
GetCertificateData(string) (*db.CertificateData, error) GetCertificateData(string) (*db.CertificateData, error)
} }
var cdg certificateDataGetter
if getter, ok := a.adminDB.(certificateDataGetter); ok { var err error
cdg = getter var data *db.CertificateData
} else if getter, ok := a.db.(certificateDataGetter); ok {
cdg = getter if cdg, ok := a.adminDB.(certificateDataGetter); ok {
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
} else if cdg, ok := a.db.(certificateDataGetter); ok {
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
} }
if cdg != nil { if err == nil && data != nil && data.Provisioner != nil {
if data, err := cdg.GetCertificateData(crt.SerialNumber.String()); err == nil && data.Provisioner != nil { if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
loadProvisioner = func() (provisioner.Interface, error) { return p, nil
p, ok := a.provisioners.Load(data.Provisioner.ID)
if !ok {
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
}
return p, nil
}
} }
} }
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
a.adminMutex.RLock()
defer a.adminMutex.RUnlock()
return loadProvisioner()
} }
// LoadProvisionerByToken returns an interface to the provisioner that // LoadProvisionerByToken returns an interface to the provisioner that

View file

@ -1,13 +1,21 @@
package authority package authority
import ( import (
"context"
"crypto/x509"
"errors" "errors"
"net/http" "net/http"
"reflect"
"testing" "testing"
"time"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
) )
func TestGetEncryptedKey(t *testing.T) { func TestGetEncryptedKey(t *testing.T) {
@ -67,6 +75,15 @@ func TestGetEncryptedKey(t *testing.T) {
} }
} }
type mockAdminDB struct {
admin.MockDB
MGetCertificateData func(string) (*db.CertificateData, error)
}
func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) {
return c.MGetCertificateData(sn)
}
func TestGetProvisioners(t *testing.T) { func TestGetProvisioners(t *testing.T) {
type gp struct { type gp struct {
a *Authority a *Authority
@ -104,3 +121,133 @@ func TestGetProvisioners(t *testing.T) {
}) })
} }
} }
func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
_, priv, err := keyutil.GenerateDefaultKeyPair()
assert.FatalError(t, err)
csr := getCSR(t, priv)
sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate {
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
assert.FatalError(t, err)
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
assert.FatalError(t, err)
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
opts, err := a.Authorize(ctx, token)
assert.FatalError(t, err)
opts = append(opts, extraOpts...)
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
assert.FatalError(t, err)
return certs[0]
}
getProvisioner := func(a *Authority, name string) provisioner.Interface {
p, ok := a.provisioners.LoadByName(name)
if !ok {
t.Fatalf("provisioner %s does not exists", name)
}
return p
}
removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
for i, ext := range cert.ExtraExtensions {
if ext.Id.Equal(provisioner.StepOIDProvisioner) {
cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...)
break
}
}
return nil
})
a0 := testAuthority(t)
a1 := testAuthority(t)
a1.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
p, err := a1.LoadProvisionerByName("dev")
if err != nil {
t.Fatal(err)
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
Name: p.GetName(),
Type: p.GetType().String(),
},
}, nil
},
}
a2 := testAuthority(t)
a2.adminDB = &mockAdminDB{
MGetCertificateData: (func(s string) (*db.CertificateData, error) {
p, err := a2.LoadProvisionerByName("dev")
if err != nil {
t.Fatal(err)
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
Name: p.GetName(),
Type: p.GetType().String(),
},
}, nil
}),
}
a3 := testAuthority(t)
a3.db = &db.MockAuthDB{
MUseToken: func(id, tok string) (bool, error) {
return true, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: "foo", Name: "foo", Type: "foo",
},
}, nil
},
}
a4 := testAuthority(t)
a4.adminDB = &mockAdminDB{
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: "foo", Name: "foo", Type: "foo",
},
}, nil
},
}
type args struct {
crt *x509.Certificate
}
tests := []struct {
name string
authority *Authority
args args
want provisioner.Interface
wantErr bool
}{
{"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false},
{"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false},
{"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false},
{"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true},
{"fail from db", a3, args{sign(a3, removeExtension)}, nil, true},
{"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -359,6 +359,7 @@ type MockAuthDB struct {
MRevoke func(rci *RevokedCertificateInfo) error MRevoke func(rci *RevokedCertificateInfo) error
MRevokeSSH func(rci *RevokedCertificateInfo) error MRevokeSSH func(rci *RevokedCertificateInfo) error
MGetCertificate func(serialNumber string) (*x509.Certificate, error) MGetCertificate func(serialNumber string) (*x509.Certificate, error)
MGetCertificateData func(serialNumber string) (*CertificateData, error)
MStoreCertificate func(crt *x509.Certificate) error MStoreCertificate func(crt *x509.Certificate) error
MUseToken func(id, tok string) (bool, error) MUseToken func(id, tok string) (bool, error)
MIsSSHHost func(principal string) (bool, error) MIsSSHHost func(principal string) (bool, error)
@ -418,6 +419,17 @@ func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, err
return m.Ret1.(*x509.Certificate), m.Err return m.Ret1.(*x509.Certificate), m.Err
} }
// GetCertificateData mock.
func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) {
if m.MGetCertificateData != nil {
return m.MGetCertificateData(serialNumber)
}
if cd, ok := m.Ret1.(*CertificateData); ok {
return cd, m.Err
}
return nil, m.Err
}
// StoreCertificate mock. // StoreCertificate mock.
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
if m.MStoreCertificate != nil { if m.MStoreCertificate != nil {