Add tests for LoadProvisionerByCertificate.
This commit is contained in:
parent
e53bd64861
commit
1d1e095447
3 changed files with 185 additions and 24 deletions
|
@ -45,41 +45,43 @@ func (a *Authority) GetProvisioners(cursor string, limit int) (provisioner.List,
|
|||
// LoadProvisionerByCertificate returns an interface to the provisioner that
|
||||
// provisioned the certificate.
|
||||
func (a *Authority) LoadProvisionerByCertificate(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||
// Default implementation looks at the provisioner extension.
|
||||
loadProvisioner := func() (provisioner.Interface, error) {
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
if p, err := a.unsafeLoadProvisionerFromDatabase(crt); err == nil {
|
||||
return p, nil
|
||||
}
|
||||
return a.unsafeLoadProvisionerFromExtension(crt)
|
||||
}
|
||||
|
||||
func (a *Authority) unsafeLoadProvisionerFromExtension(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.LoadByCertificate(crt)
|
||||
if !ok {
|
||||
if !ok || p.GetType() == 0 {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (a *Authority) unsafeLoadProvisionerFromDatabase(crt *x509.Certificate) (provisioner.Interface, error) {
|
||||
// certificateDataGetter is an interface that can be use to retrieve the
|
||||
// provisioner from a db or a linked ca.
|
||||
type certificateDataGetter interface {
|
||||
GetCertificateData(string) (*db.CertificateData, error)
|
||||
}
|
||||
var cdg certificateDataGetter
|
||||
if getter, ok := a.adminDB.(certificateDataGetter); ok {
|
||||
cdg = getter
|
||||
} else if getter, ok := a.db.(certificateDataGetter); ok {
|
||||
cdg = getter
|
||||
}
|
||||
if cdg != nil {
|
||||
if data, err := cdg.GetCertificateData(crt.SerialNumber.String()); err == nil && data.Provisioner != nil {
|
||||
loadProvisioner = func() (provisioner.Interface, error) {
|
||||
p, ok := a.provisioners.Load(data.Provisioner.ID)
|
||||
if !ok {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
||||
|
||||
var err error
|
||||
var data *db.CertificateData
|
||||
|
||||
if cdg, ok := a.adminDB.(certificateDataGetter); ok {
|
||||
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
|
||||
} else if cdg, ok := a.db.(certificateDataGetter); ok {
|
||||
data, err = cdg.GetCertificateData(crt.SerialNumber.String())
|
||||
}
|
||||
if err == nil && data != nil && data.Provisioner != nil {
|
||||
if p, ok := a.provisioners.Load(data.Provisioner.ID); ok {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.adminMutex.RLock()
|
||||
defer a.adminMutex.RUnlock()
|
||||
return loadProvisioner()
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "unable to load provisioner from certificate")
|
||||
}
|
||||
|
||||
// LoadProvisionerByToken returns an interface to the provisioner that
|
||||
|
|
|
@ -1,13 +1,21 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/api/render"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
)
|
||||
|
||||
func TestGetEncryptedKey(t *testing.T) {
|
||||
|
@ -67,6 +75,15 @@ func TestGetEncryptedKey(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type mockAdminDB struct {
|
||||
admin.MockDB
|
||||
MGetCertificateData func(string) (*db.CertificateData, error)
|
||||
}
|
||||
|
||||
func (c *mockAdminDB) GetCertificateData(sn string) (*db.CertificateData, error) {
|
||||
return c.MGetCertificateData(sn)
|
||||
}
|
||||
|
||||
func TestGetProvisioners(t *testing.T) {
|
||||
type gp struct {
|
||||
a *Authority
|
||||
|
@ -104,3 +121,133 @@ func TestGetProvisioners(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_LoadProvisionerByCertificate(t *testing.T) {
|
||||
_, priv, err := keyutil.GenerateDefaultKeyPair()
|
||||
assert.FatalError(t, err)
|
||||
csr := getCSR(t, priv)
|
||||
|
||||
sign := func(a *Authority, extraOpts ...provisioner.SignOption) *x509.Certificate {
|
||||
key, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateToken("smallstep test", "step-cli", testAudiences.Sign[0], []string{"test.smallstep.com"}, time.Now(), key)
|
||||
assert.FatalError(t, err)
|
||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
|
||||
opts, err := a.Authorize(ctx, token)
|
||||
assert.FatalError(t, err)
|
||||
opts = append(opts, extraOpts...)
|
||||
certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
|
||||
assert.FatalError(t, err)
|
||||
return certs[0]
|
||||
}
|
||||
getProvisioner := func(a *Authority, name string) provisioner.Interface {
|
||||
p, ok := a.provisioners.LoadByName(name)
|
||||
if !ok {
|
||||
t.Fatalf("provisioner %s does not exists", name)
|
||||
}
|
||||
return p
|
||||
}
|
||||
removeExtension := provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
for i, ext := range cert.ExtraExtensions {
|
||||
if ext.Id.Equal(provisioner.StepOIDProvisioner) {
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions[:i], cert.ExtraExtensions[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
a0 := testAuthority(t)
|
||||
|
||||
a1 := testAuthority(t)
|
||||
a1.db = &db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||
p, err := a1.LoadProvisionerByName("dev")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{
|
||||
ID: p.GetID(),
|
||||
Name: p.GetName(),
|
||||
Type: p.GetType().String(),
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a2 := testAuthority(t)
|
||||
a2.adminDB = &mockAdminDB{
|
||||
MGetCertificateData: (func(s string) (*db.CertificateData, error) {
|
||||
p, err := a2.LoadProvisionerByName("dev")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{
|
||||
ID: p.GetID(),
|
||||
Name: p.GetName(),
|
||||
Type: p.GetType().String(),
|
||||
},
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
a3 := testAuthority(t)
|
||||
a3.db = &db.MockAuthDB{
|
||||
MUseToken: func(id, tok string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{
|
||||
ID: "foo", Name: "foo", Type: "foo",
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
a4 := testAuthority(t)
|
||||
a4.adminDB = &mockAdminDB{
|
||||
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
|
||||
return &db.CertificateData{
|
||||
Provisioner: &db.ProvisionerData{
|
||||
ID: "foo", Name: "foo", Type: "foo",
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
type args struct {
|
||||
crt *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
authority *Authority
|
||||
args args
|
||||
want provisioner.Interface
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok from certificate", a0, args{sign(a0)}, getProvisioner(a0, "step-cli"), false},
|
||||
{"ok from db", a1, args{sign(a1)}, getProvisioner(a1, "dev"), false},
|
||||
{"ok from admindb", a2, args{sign(a2)}, getProvisioner(a2, "dev"), false},
|
||||
{"fail from certificate", a0, args{sign(a0, removeExtension)}, nil, true},
|
||||
{"fail from db", a3, args{sign(a3, removeExtension)}, nil, true},
|
||||
{"fail from admindb", a4, args{sign(a4, removeExtension)}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.authority.LoadProvisionerByCertificate(tt.args.crt)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.LoadProvisionerByCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authority.LoadProvisionerByCertificate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
12
db/db.go
12
db/db.go
|
@ -359,6 +359,7 @@ type MockAuthDB struct {
|
|||
MRevoke func(rci *RevokedCertificateInfo) error
|
||||
MRevokeSSH func(rci *RevokedCertificateInfo) error
|
||||
MGetCertificate func(serialNumber string) (*x509.Certificate, error)
|
||||
MGetCertificateData func(serialNumber string) (*CertificateData, error)
|
||||
MStoreCertificate func(crt *x509.Certificate) error
|
||||
MUseToken func(id, tok string) (bool, error)
|
||||
MIsSSHHost func(principal string) (bool, error)
|
||||
|
@ -418,6 +419,17 @@ func (m *MockAuthDB) GetCertificate(serialNumber string) (*x509.Certificate, err
|
|||
return m.Ret1.(*x509.Certificate), m.Err
|
||||
}
|
||||
|
||||
// GetCertificateData mock.
|
||||
func (m *MockAuthDB) GetCertificateData(serialNumber string) (*CertificateData, error) {
|
||||
if m.MGetCertificateData != nil {
|
||||
return m.MGetCertificateData(serialNumber)
|
||||
}
|
||||
if cd, ok := m.Ret1.(*CertificateData); ok {
|
||||
return cd, m.Err
|
||||
}
|
||||
return nil, m.Err
|
||||
}
|
||||
|
||||
// StoreCertificate mock.
|
||||
func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error {
|
||||
if m.MStoreCertificate != nil {
|
||||
|
|
Loading…
Reference in a new issue