Add tests for LoadProvisionerByCertificate.
This commit is contained in:
parent
e53bd64861
commit
1d1e095447
3 changed files with 185 additions and 24 deletions
|
@ -45,41 +45,43 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
|
||||||
// LoadProvisionerByCertificate returns an interface to the provisioner that
|
// LoadProvisionerByCertificate returns an interface to the provisioner that
|
||||||
// provisioned the certificate.
|
// provisioned the certificate.
|
||||||
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
|
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||||
// Default implementation looks at the provisioner extension.
|
a.adminMutex.RLock()
|
||||||
loadProvisioner := func() (provisioner.Interface, error) {
|
defer a.adminMutex.RUnlock()
|
||||||
|
if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
return a.unsafeLoadProvisionerFromExtension(crt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||||
if !ok {
|
if !ok || p.GetType() == 0 {
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||||
// certificateDataGetter is an interface that can be use to retrieve the
|
// certificateDataGetter is an interface that can be use to retrieve the
|
||||||
// provisioner from a db or a linked ca.
|
// provisioner from a db or a linked ca.
|
||||||
type certificateDataGetter interface {
|
type certificateDataGetter interface {
|
||||||
GetCertificateData(string) (*db.CertificateData, error)
|
GetCertificateData(string) (*db.CertificateData, error)
|
||||||
}
|
}
|
||||||
var cdg certificateDataGetter
|
|
||||||
if getter, ok := a.adminDB.(certificateDataGetter); ok {
|
var err error
|
||||||
cdg = getter
|
var data *db.CertificateData
|
||||||
} else if getter, ok := a.db.(certificateDataGetter); ok {
|
|
||||||
cdg = getter
|
if cdg, ok := a.adminDB.(certificateDataGetter); ok {
|
||||||
}
|
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
|
||||||
if cdg != nil {
|
} else if cdg, ok := a.db.(certificateDataGetter); ok {
|
||||||
if data, err := cdg.GetCertificateData(crt.SerialNumber.String()); err == nil && data.Provisioner != nil {
|
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
|
||||||
loadProvisioner = func() (provisioner.Interface, error) {
|
|
||||||
p, ok := a.provisioners.Load(data.Provisioner.ID)
|
|
||||||
if !ok {
|
|
||||||
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
|
||||||
}
|
}
|
||||||
|
if err == nil && data != nil && data.Provisioner != nil {
|
||||||
|
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
||||||
|
|
||||||
a.adminMutex.RLock()
|
|
||||||
defer a.adminMutex.RUnlock()
|
|
||||||
return loadProvisioner()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadProvisionerByToken returns an interface to the provisioner that
|
// LoadProvisionerByToken returns an interface to the provisioner that
|
||||||
|
|
|
@ -1,13 +1,21 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/api/render"
|
"github.com/smallstep/certificates/api/render"
|
||||||
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/certificates/db"
|
||||||
|
"go.step.sm/crypto/jose"
|
||||||
|
"go.step.sm/crypto/keyutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetEncryptedKey(t *testing.T) {
|
func TestGetEncryptedKey(t *testing.T) {
|
||||||
|
@ -67,6 +75,15 @@ func TestGetEncryptedKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockAdminDB struct {
|
||||||
|
admin.MockDB
|
||||||
|
MGetCertificateData func(string) (*db.CertificateData, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) {
|
||||||
|
return c.MGetCertificateData(sn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetProvisioners(t *testing.T) {
|
func TestGetProvisioners(t *testing.T) {
|
||||||
type gp struct {
|
type gp struct {
|
||||||
a *Authority
|
a *Authority
|
||||||
|
@ -104,3 +121,133 @@ func TestGetProvisioners(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
|
||||||
|
_, priv, err := keyutil.GenerateDefaultKeyPair()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
csr := getCSR(t, priv)
|
||||||
|
|
||||||
|
sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate {
|
||||||
|
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||||
|
opts, err := a.Authorize(ctx, token)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
opts = append(opts, extraOpts...)
|
||||||
|
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return certs[0]
|
||||||
|
}
|
||||||
|
getProvisioner := func(a *Authority, name string) provisioner.Interface {
|
||||||
|
p, ok := a.provisioners.LoadByName(name)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("provisioner %s does not exists", name)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||||
|
for i, ext := range cert.ExtraExtensions {
|
||||||
|
if ext.Id.Equal(provisioner.StepOIDProvisioner) {
|
||||||
|
cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
a0 := testAuthority(t)
|
||||||
|
|
||||||
|
a1 := testAuthority(t)
|
||||||
|
a1.db = &db.MockAuthDB{
|
||||||
|
MUseToken: func(id, tok string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||||
|
p, err := a1.LoadProvisionerByName("dev")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return &db.CertificateData{
|
||||||
|
Provisioner: &db.ProvisionerData{
|
||||||
|
ID: p.GetID(),
|
||||||
|
Name: p.GetName(),
|
||||||
|
Type: p.GetType().String(),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a2 := testAuthority(t)
|
||||||
|
a2.adminDB = &mockAdminDB{
|
||||||
|
MGetCertificateData: (func(s string) (*db.CertificateData, error) {
|
||||||
|
p, err := a2.LoadProvisionerByName("dev")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return &db.CertificateData{
|
||||||
|
Provisioner: &db.ProvisionerData{
|
||||||
|
ID: p.GetID(),
|
||||||
|
Name: p.GetName(),
|
||||||
|
Type: p.GetType().String(),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
a3 := testAuthority(t)
|
||||||
|
a3.db = &db.MockAuthDB{
|
||||||
|
MUseToken: func(id, tok string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||||
|
return &db.CertificateData{
|
||||||
|
Provisioner: &db.ProvisionerData{
|
||||||
|
ID: "foo", Name: "foo", Type: "foo",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
a4 := testAuthority(t)
|
||||||
|
a4.adminDB = &mockAdminDB{
|
||||||
|
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||||
|
return &db.CertificateData{
|
||||||
|
Provisioner: &db.ProvisionerData{
|
||||||
|
ID: "foo", Name: "foo", Type: "foo",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
crt *x509.Certificate
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
authority *Authority
|
||||||
|
args args
|
||||||
|
want provisioner.Interface
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false},
|
||||||
|
{"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false},
|
||||||
|
{"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false},
|
||||||
|
{"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true},
|
||||||
|
{"fail from db", a3, args{sign(a3, removeExtension)}, nil, true},
|
||||||
|
{"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
12
db/db.go
12
db/db.go
|
@ -359,6 +359,7 @@ type MockAuthDB struct {
|
||||||
MRevoke func(rci *RevokedCertificateInfo) error
|
MRevoke func(rci *RevokedCertificateInfo) error
|
||||||
MRevokeSSH func(rci *RevokedCertificateInfo) error
|
MRevokeSSH func(rci *RevokedCertificateInfo) error
|
||||||
MGetCertificate func(serialNumber string) (*x509.Certificate, error)
|
MGetCertificate func(serialNumber string) (*x509.Certificate, error)
|
||||||
|
MGetCertificateData func(serialNumber string) (*CertificateData, error)
|
||||||
MStoreCertificate func(crt *x509.Certificate) error
|
MStoreCertificate func(crt *x509.Certificate) error
|
||||||
MUseToken func(id, tok string) (bool, error)
|
MUseToken func(id, tok string) (bool, error)
|
||||||
MIsSSHHost func(principal string) (bool, error)
|
MIsSSHHost func(principal string) (bool, error)
|
||||||
|
@ -418,6 +419,17 @@ func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, err
|
||||||
return m.Ret1.(*x509.Certificate), m.Err
|
return m.Ret1.(*x509.Certificate), m.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateData mock.
|
||||||
|
func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) {
|
||||||
|
if m.MGetCertificateData != nil {
|
||||||
|
return m.MGetCertificateData(serialNumber)
|
||||||
|
}
|
||||||
|
if cd, ok := m.Ret1.(*CertificateData); ok {
|
||||||
|
return cd, m.Err
|
||||||
|
}
|
||||||
|
return nil, m.Err
|
||||||
|
}
|
||||||
|
|
||||||
// StoreCertificate mock.
|
// StoreCertificate mock.
|
||||||
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
|
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
|
||||||
if m.MStoreCertificate != nil {
|
if m.MStoreCertificate != nil {
|
||||||
|
|
Loading…
Reference in a new issue