forked from TrueCloudLab/certificates
Add tests for ACME EAB Admin
Refactored some of the existing bits for testing the Authority API by creation of a new LinkedAuthority interface and changing visibility of the MockAuthority to be usable by other packages. At this time, not all of the functions of MockAuthority it usable yet. Will refactor when needed or requested.
This commit is contained in:
parent
9885d42711
commit
2215a05c28
9 changed files with 826 additions and 251 deletions
|
@ -15,9 +15,11 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
@ -51,28 +53,76 @@ func TestHandler_GetNonce(t *testing.T) {
|
|||
|
||||
func TestHandler_GetDirectory(t *testing.T) {
|
||||
linker := NewLinker("ca.smallstep.com", "acme")
|
||||
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
||||
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
|
||||
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
|
||||
}
|
||||
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
statusCode int
|
||||
dir Directory
|
||||
err *acme.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"ok": func(t *testing.T) test {
|
||||
"fail/no-provisioner": func(t *testing.T) test {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
|
||||
}
|
||||
},
|
||||
"fail/different-provisioner": func(t *testing.T) test {
|
||||
prov := &provisioner.SCEP{
|
||||
Type: "SCEP",
|
||||
Name: "test@scep-<test>provisioner.com",
|
||||
}
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
return test{
|
||||
ctx: ctx,
|
||||
statusCode: 500,
|
||||
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
||||
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
|
||||
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
dir: expDir,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
"ok/eab-required": func(t *testing.T) test {
|
||||
prov := newACMEProv(t)
|
||||
prov.RequireEAB = true
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
expDir := Directory{
|
||||
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
|
||||
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
|
||||
NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
|
||||
RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
|
||||
KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
|
||||
Meta: Meta{
|
||||
ExternalAccountRequired: true,
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
dir: expDir,
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
|
@ -82,7 +132,7 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{linker: linker}
|
||||
req := httptest.NewRequest("GET", "/foo/bar", nil)
|
||||
req = req.WithContext(ctx)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.GetDirectory(w, req)
|
||||
res := w.Result()
|
||||
|
@ -105,7 +155,9 @@ func TestHandler_GetDirectory(t *testing.T) {
|
|||
} else {
|
||||
var dir Directory
|
||||
json.Unmarshal(bytes.TrimSpace(body), &dir)
|
||||
assert.Equals(t, dir, expDir)
|
||||
if !cmp.Equal(tc.dir, dir) {
|
||||
t.Errorf("GetDirectory() diff =\n%s", cmp.Diff(tc.dir, dir))
|
||||
}
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||
}
|
||||
})
|
||||
|
|
311
api/api.go
311
api/api.go
|
@ -25,6 +25,9 @@ import (
|
|||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"go.step.sm/linkedca"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Authority is the interface implemented by a CA authority.
|
||||
|
@ -48,6 +51,21 @@ type Authority interface {
|
|||
Version() authority.Version
|
||||
}
|
||||
|
||||
type LinkedAuthority interface { // TODO(hs): name is not great; it is related to LinkedCA, though
|
||||
Authority
|
||||
IsAdminAPIEnabled() bool
|
||||
LoadAdminByID(id string) (*linkedca.Admin, bool)
|
||||
GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error)
|
||||
StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error
|
||||
UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error)
|
||||
RemoveAdmin(ctx context.Context, id string) error
|
||||
AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error)
|
||||
StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
LoadProvisionerByID(id string) (provisioner.Interface, error)
|
||||
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
RemoveProvisioner(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// TimeDuration is an alias of provisioner.TimeDuration
|
||||
type TimeDuration = provisioner.TimeDuration
|
||||
|
||||
|
@ -457,3 +475,296 @@ func fmtPublicKey(cert *x509.Certificate) string {
|
|||
}
|
||||
return fmt.Sprintf("%s %s", cert.PublicKeyAlgorithm, params)
|
||||
}
|
||||
|
||||
type MockAuthority struct {
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
getTLSOptions func() *authority.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
MockLoadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
revoke func(context.Context, *authority.RevokeOptions) error
|
||||
getEncryptedKey func(kid string) (string, error)
|
||||
getRoots func() ([]*x509.Certificate, error)
|
||||
getFederation func() ([]*x509.Certificate, error)
|
||||
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error)
|
||||
getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error)
|
||||
getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error)
|
||||
getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||
checkSSHHost func(ctx context.Context, principal, token string) (bool, error)
|
||||
getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
|
||||
version func() authority.Version
|
||||
|
||||
MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two
|
||||
MockErr error
|
||||
MockIsAdminAPIEnabled func() bool
|
||||
MockLoadAdminByID func(id string) (*linkedca.Admin, bool)
|
||||
MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error)
|
||||
MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error
|
||||
MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error)
|
||||
MockRemoveAdmin func(ctx context.Context, id string) error
|
||||
MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error)
|
||||
MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error
|
||||
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
MockRemoveProvisioner func(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// TODO: remove once Authorize is deprecated.
|
||||
func (m *MockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return m.AuthorizeSign(ott)
|
||||
}
|
||||
|
||||
func (m *MockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorizeSign != nil {
|
||||
return m.authorizeSign(ott)
|
||||
}
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetTLSOptions() *authority.TLSOptions {
|
||||
if m.getTLSOptions != nil {
|
||||
return m.getTLSOptions()
|
||||
}
|
||||
return m.ret1.(*authority.TLSOptions)
|
||||
}
|
||||
|
||||
func (m *MockAuthority) Root(shasum string) (*x509.Certificate, error) {
|
||||
if m.root != nil {
|
||||
return m.root(shasum)
|
||||
}
|
||||
return m.ret1.(*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
if m.sign != nil {
|
||||
return m.sign(cr, opts, signOpts...)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
if m.renew != nil {
|
||||
return m.renew(cert)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.rekey != nil {
|
||||
return m.rekey(oldcert, pk)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
|
||||
if m.getProvisioners != nil {
|
||||
return m.getProvisioners(nextCursor, limit)
|
||||
}
|
||||
return m.ret1.(provisioner.List), m.ret2.(string), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByCertificate != nil {
|
||||
return m.loadProvisionerByCertificate(cert)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.MockLoadProvisionerByName != nil {
|
||||
return m.MockLoadProvisionerByName(name)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
if m.revoke != nil {
|
||||
return m.revoke(ctx, opts)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetEncryptedKey(kid string) (string, error) {
|
||||
if m.getEncryptedKey != nil {
|
||||
return m.getEncryptedKey(kid)
|
||||
}
|
||||
return m.ret1.(string), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetRoots() ([]*x509.Certificate, error) {
|
||||
if m.getRoots != nil {
|
||||
return m.getRoots()
|
||||
}
|
||||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetFederation() ([]*x509.Certificate, error) {
|
||||
if m.getFederation != nil {
|
||||
return m.getFederation()
|
||||
}
|
||||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.signSSH != nil {
|
||||
return m.signSSH(ctx, key, opts, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.signSSHAddUser != nil {
|
||||
return m.signSSHAddUser(ctx, key, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.renewSSH != nil {
|
||||
return m.renewSSH(ctx, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.rekeySSH != nil {
|
||||
return m.rekeySSH(ctx, cert, key, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) {
|
||||
if m.getSSHHosts != nil {
|
||||
return m.getSSHHosts(ctx, cert)
|
||||
}
|
||||
return m.ret1.([]authority.Host), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
if m.getSSHRoots != nil {
|
||||
return m.getSSHRoots(ctx)
|
||||
}
|
||||
return m.ret1.(*authority.SSHKeys), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
if m.getSSHFederation != nil {
|
||||
return m.getSSHFederation(ctx)
|
||||
}
|
||||
return m.ret1.(*authority.SSHKeys), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
if m.getSSHConfig != nil {
|
||||
return m.getSSHConfig(ctx, typ, data)
|
||||
}
|
||||
return m.ret1.([]templates.Output), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
|
||||
if m.checkSSHHost != nil {
|
||||
return m.checkSSHHost(ctx, principal, token)
|
||||
}
|
||||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
if m.getSSHBastion != nil {
|
||||
return m.getSSHBastion(ctx, user, hostname)
|
||||
}
|
||||
return m.ret1.(*authority.Bastion), m.err
|
||||
}
|
||||
|
||||
func (m *MockAuthority) Version() authority.Version {
|
||||
if m.version != nil {
|
||||
return m.version()
|
||||
}
|
||||
return m.ret1.(authority.Version)
|
||||
}
|
||||
|
||||
func (m *MockAuthority) IsAdminAPIEnabled() bool {
|
||||
if m.MockIsAdminAPIEnabled != nil {
|
||||
return m.MockIsAdminAPIEnabled()
|
||||
}
|
||||
return m.MockRet1.(bool)
|
||||
}
|
||||
|
||||
func (m *MockAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||
if m.MockLoadAdminByID != nil {
|
||||
return m.MockLoadAdminByID(id)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool)
|
||||
}
|
||||
|
||||
func (m *MockAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) {
|
||||
if m.MockGetAdmins != nil {
|
||||
return m.MockGetAdmins(cursor, limit)
|
||||
}
|
||||
return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
if m.MockStoreAdmin != nil {
|
||||
return m.MockStoreAdmin(ctx, adm, prov)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) {
|
||||
if m.MockUpdateAdmin != nil {
|
||||
return m.MockUpdateAdmin(ctx, id, nu)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) RemoveAdmin(ctx context.Context, id string) error {
|
||||
if m.MockRemoveAdmin != nil {
|
||||
return m.MockRemoveAdmin(ctx, id)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
if m.MockAuthorizeAdminToken != nil {
|
||||
return m.MockAuthorizeAdminToken(r, token)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Admin), m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
|
||||
if m.MockStoreProvisioner != nil {
|
||||
return m.MockStoreProvisioner(ctx, prov)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
|
||||
if m.MockLoadProvisionerByID != nil {
|
||||
return m.MockLoadProvisionerByID(id)
|
||||
}
|
||||
return m.MockRet1.(provisioner.Interface), m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error {
|
||||
if m.MockUpdateProvisioner != nil {
|
||||
return m.MockUpdateProvisioner(ctx, nu)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *MockAuthority) RemoveProvisioner(ctx context.Context, id string) error {
|
||||
if m.MockRemoveProvisioner != nil {
|
||||
return m.MockRemoveProvisioner(ctx, id)
|
||||
}
|
||||
return m.MockErr
|
||||
}
|
||||
|
|
230
api/api_test.go
230
api/api_test.go
|
@ -3,7 +3,6 @@ package api
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/dsa" //nolint
|
||||
"crypto/ecdsa"
|
||||
"crypto/ed25519"
|
||||
|
@ -32,7 +31,6 @@ import (
|
|||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
@ -551,208 +549,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) (
|
|||
return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
||||
type mockAuthority struct {
|
||||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
getTLSOptions func() *authority.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByName func(name string) (provisioner.Interface, error)
|
||||
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
|
||||
revoke func(context.Context, *authority.RevokeOptions) error
|
||||
getEncryptedKey func(kid string) (string, error)
|
||||
getRoots func() ([]*x509.Certificate, error)
|
||||
getFederation func() ([]*x509.Certificate, error)
|
||||
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error)
|
||||
getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error)
|
||||
getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error)
|
||||
getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||
checkSSHHost func(ctx context.Context, principal, token string) (bool, error)
|
||||
getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
|
||||
version func() authority.Version
|
||||
}
|
||||
|
||||
// TODO: remove once Authorize is deprecated.
|
||||
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
|
||||
return m.AuthorizeSign(ott)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
|
||||
if m.authorizeSign != nil {
|
||||
return m.authorizeSign(ott)
|
||||
}
|
||||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
|
||||
if m.getTLSOptions != nil {
|
||||
return m.getTLSOptions()
|
||||
}
|
||||
return m.ret1.(*authority.TLSOptions)
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) {
|
||||
if m.root != nil {
|
||||
return m.root(shasum)
|
||||
}
|
||||
return m.ret1.(*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
if m.sign != nil {
|
||||
return m.sign(cr, opts, signOpts...)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
if m.renew != nil {
|
||||
return m.renew(cert)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
|
||||
if m.rekey != nil {
|
||||
return m.rekey(oldcert, pk)
|
||||
}
|
||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
|
||||
if m.getProvisioners != nil {
|
||||
return m.getProvisioners(nextCursor, limit)
|
||||
}
|
||||
return m.ret1.(provisioner.List), m.ret2.(string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByCertificate != nil {
|
||||
return m.loadProvisionerByCertificate(cert)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) {
|
||||
if m.loadProvisionerByName != nil {
|
||||
return m.loadProvisionerByName(name)
|
||||
}
|
||||
return m.ret1.(provisioner.Interface), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error {
|
||||
if m.revoke != nil {
|
||||
return m.revoke(ctx, opts)
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
|
||||
if m.getEncryptedKey != nil {
|
||||
return m.getEncryptedKey(kid)
|
||||
}
|
||||
return m.ret1.(string), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) {
|
||||
if m.getRoots != nil {
|
||||
return m.getRoots()
|
||||
}
|
||||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
|
||||
if m.getFederation != nil {
|
||||
return m.getFederation()
|
||||
}
|
||||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.signSSH != nil {
|
||||
return m.signSSH(ctx, key, opts, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.signSSHAddUser != nil {
|
||||
return m.signSSHAddUser(ctx, key, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.renewSSH != nil {
|
||||
return m.renewSSH(ctx, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.rekeySSH != nil {
|
||||
return m.rekeySSH(ctx, cert, key, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) {
|
||||
if m.getSSHHosts != nil {
|
||||
return m.getSSHHosts(ctx, cert)
|
||||
}
|
||||
return m.ret1.([]authority.Host), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
if m.getSSHRoots != nil {
|
||||
return m.getSSHRoots(ctx)
|
||||
}
|
||||
return m.ret1.(*authority.SSHKeys), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
if m.getSSHFederation != nil {
|
||||
return m.getSSHFederation(ctx)
|
||||
}
|
||||
return m.ret1.(*authority.SSHKeys), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
if m.getSSHConfig != nil {
|
||||
return m.getSSHConfig(ctx, typ, data)
|
||||
}
|
||||
return m.ret1.([]templates.Output), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
|
||||
if m.checkSSHHost != nil {
|
||||
return m.checkSSHHost(ctx, principal, token)
|
||||
}
|
||||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
if m.getSSHBastion != nil {
|
||||
return m.getSSHBastion(ctx, user, hostname)
|
||||
}
|
||||
return m.ret1.(*authority.Bastion), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Version() authority.Version {
|
||||
if m.version != nil {
|
||||
return m.version()
|
||||
}
|
||||
return m.ret1.(authority.Version)
|
||||
}
|
||||
|
||||
func Test_caHandler_Route(t *testing.T) {
|
||||
type fields struct {
|
||||
Authority Authority
|
||||
|
@ -765,7 +561,7 @@ func Test_caHandler_Route(t *testing.T) {
|
|||
fields fields
|
||||
args args
|
||||
}{
|
||||
{"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}},
|
||||
{"ok", fields{&MockAuthority{}}, args{chi.NewRouter()}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -780,7 +576,7 @@ func Test_caHandler_Route(t *testing.T) {
|
|||
func Test_caHandler_Health(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||||
w := httptest.NewRecorder()
|
||||
h := New(&mockAuthority{}).(*caHandler)
|
||||
h := New(&MockAuthority{}).(*caHandler)
|
||||
h.Health(w, req)
|
||||
|
||||
res := w.Result()
|
||||
|
@ -820,7 +616,7 @@ func Test_caHandler_Root(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
||||
h := New(&MockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
||||
w := httptest.NewRecorder()
|
||||
h.Root(w, req)
|
||||
res := w.Result()
|
||||
|
@ -884,7 +680,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return tt.certAttrOpts, tt.autherr
|
||||
|
@ -938,7 +734,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
|
@ -999,7 +795,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
|
@ -1077,9 +873,9 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
args args
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200},
|
||||
{"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500},
|
||||
{"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400},
|
||||
{"ok", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200},
|
||||
{"fail", fields{&MockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500},
|
||||
{"limit fail", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400},
|
||||
}
|
||||
|
||||
expected, err := json.Marshal(pr)
|
||||
|
@ -1154,8 +950,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
|||
args args
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200},
|
||||
{"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404},
|
||||
{"ok", fields{&MockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200},
|
||||
{"fail", fields{&MockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404},
|
||||
}
|
||||
|
||||
expected := []byte(`{"key":"` + privKey + `"}`)
|
||||
|
@ -1214,7 +1010,7 @@ func Test_caHandler_Roots(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -1260,7 +1056,7 @@ func Test_caHandler_Federation(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||
req.TLS = tt.tls
|
||||
w := httptest.NewRecorder()
|
||||
|
|
|
@ -106,7 +106,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
return test{
|
||||
input: string(input),
|
||||
statusCode: http.StatusOK,
|
||||
auth: &mockAuthority{
|
||||
auth: &MockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
|
@ -150,7 +150,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
input: string(input),
|
||||
statusCode: http.StatusOK,
|
||||
tls: cs,
|
||||
auth: &mockAuthority{
|
||||
auth: &MockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
|
@ -185,7 +185,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
return test{
|
||||
input: string(input),
|
||||
statusCode: http.StatusInternalServerError,
|
||||
auth: &mockAuthority{
|
||||
auth: &MockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
|
@ -207,7 +207,7 @@ func Test_caHandler_Revoke(t *testing.T) {
|
|||
return test{
|
||||
input: string(input),
|
||||
statusCode: http.StatusForbidden,
|
||||
auth: &mockAuthority{
|
||||
auth: &MockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return nil, nil
|
||||
},
|
||||
|
|
|
@ -314,7 +314,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
|
@ -377,7 +377,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
|
@ -431,7 +431,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
|
@ -491,7 +491,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
return tt.output, tt.err
|
||||
},
|
||||
|
@ -538,7 +538,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||
return tt.exists, tt.err
|
||||
},
|
||||
|
@ -589,7 +589,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||
return tt.hosts, tt.err
|
||||
},
|
||||
|
@ -644,7 +644,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
h := New(&MockAuthority{
|
||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
return tt.bastion, tt.bastionErr
|
||||
},
|
||||
|
|
|
@ -49,7 +49,7 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP {
|
|||
|
||||
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
|
||||
// provisioner is set to true and thus has EAB enabled.
|
||||
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, error) {
|
||||
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *admin.Error) {
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
|
|
416
authority/admin/api/acme_test.go
Normal file
416
authority/admin/api/acme_test.go
Normal file
|
@ -0,0 +1,416 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
ctx context.Context
|
||||
db admin.DB
|
||||
auth api.LinkedAuthority
|
||||
next nextHTTP
|
||||
err *admin.Error
|
||||
statusCode int
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("prov", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
err := admin.NewErrorISE("error loading provisioner provName: force")
|
||||
err.Message = "error loading provisioner provName: force"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
err: err,
|
||||
statusCode: 500,
|
||||
}
|
||||
},
|
||||
"ok/eab-disabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("prov", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
|
||||
err.Message = "ACME EAB not enabled for provisioner provName"
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
db: db,
|
||||
err: err,
|
||||
statusCode: 400,
|
||||
}
|
||||
},
|
||||
"ok/eab-enabled": func(t *testing.T) test {
|
||||
chiCtx := chi.NewRouteContext()
|
||||
chiCtx.URLParams.Add("prov", "provName")
|
||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
db: db,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(nil) // mock response with status 200
|
||||
},
|
||||
statusCode: 200,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
db: tc.db,
|
||||
auth: tc.auth,
|
||||
acmeDB: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.requireEABEnabled(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
err := admin.Error{}
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||
return
|
||||
}
|
||||
|
||||
// nothing to test when the requireEABEnabled middleware succeeds, currently
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
||||
type test struct {
|
||||
db admin.DB
|
||||
auth api.LinkedAuthority
|
||||
provisionerName string
|
||||
want bool
|
||||
err *admin.Error
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/db.GetProvisioner": func(t *testing.T) test {
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
db: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/prov.GetDetails": func(t *testing.T) test {
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: nil,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
db: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"fail/details.GetACME": func(t *testing.T) test {
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "provName", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "provName",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: nil,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
auth: auth,
|
||||
db: db,
|
||||
provisionerName: "provName",
|
||||
want: false,
|
||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
||||
}
|
||||
},
|
||||
"ok/eab-disabled": func(t *testing.T) test {
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "eab-disabled", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "eab-disabled",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
db: db,
|
||||
auth: auth,
|
||||
provisionerName: "eab-disabled",
|
||||
want: false,
|
||||
}
|
||||
},
|
||||
"ok/eab-enabled": func(t *testing.T) test {
|
||||
auth := &api.MockAuthority{
|
||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||
assert.Equals(t, "eab-enabled", name)
|
||||
return &provisioner.MockProvisioner{
|
||||
MgetID: func() string {
|
||||
return "provID"
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
db := &admin.MockDB{
|
||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||
assert.Equals(t, "provID", id)
|
||||
return &linkedca.Provisioner{
|
||||
Id: "provID",
|
||||
Name: "eab-enabled",
|
||||
Details: &linkedca.ProvisionerDetails{
|
||||
Data: &linkedca.ProvisionerDetails_ACME{
|
||||
ACME: &linkedca.ACMEProvisioner{
|
||||
RequireEab: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
return test{
|
||||
db: db,
|
||||
auth: auth,
|
||||
provisionerName: "eab-enabled",
|
||||
want: true,
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{
|
||||
db: tc.db,
|
||||
auth: tc.auth,
|
||||
acmeDB: nil,
|
||||
}
|
||||
got, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
|
||||
if (err != nil) != (tc.err != nil) {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
||||
return
|
||||
}
|
||||
if tc.err != nil {
|
||||
// TODO(hs): the output of the diff seems to be equal to each other; not sure why it's marked as different =/
|
||||
// opts := []cmp.Option{cmpopts.EquateErrors()}
|
||||
// if !cmp.Equal(tc.err, err, opts...) {
|
||||
// t.Errorf("Handler.provisionerHasEABEnabled() diff =\n%v", cmp.Diff(tc.err, err, opts...))
|
||||
// }
|
||||
assert.Equals(t, tc.err.Type, err.Type)
|
||||
assert.Equals(t, tc.err.Status, err.Status)
|
||||
assert.Equals(t, tc.err.StatusCode(), err.StatusCode())
|
||||
assert.Equals(t, tc.err.Message, err.Message)
|
||||
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||
return
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Reference string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "ok/empty-reference",
|
||||
fields: fields{
|
||||
Reference: "",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
Reference: "my-eab-reference",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &CreateExternalAccountKeyRequest{
|
||||
Reference: tt.fields.Reference,
|
||||
}
|
||||
if err := r.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,19 +3,18 @@ package api
|
|||
import (
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
)
|
||||
|
||||
// Handler is the ACME API request handler.
|
||||
// Handler is the Admin API request handler.
|
||||
type Handler struct {
|
||||
db admin.DB
|
||||
auth *authority.Authority
|
||||
auth api.LinkedAuthority // was: *authority.Authority
|
||||
acmeDB acme.DB
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth *authority.Authority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler {
|
||||
func NewHandler(auth api.LinkedAuthority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler {
|
||||
return &Handler{
|
||||
db: adminDB,
|
||||
auth: auth,
|
||||
|
|
1
go.mod
1
go.mod
|
@ -18,6 +18,7 @@ require (
|
|||
github.com/go-kit/kit v0.10.0 // indirect
|
||||
github.com/go-piv/piv-go v1.7.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.5.6
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/googleapis/gax-go/v2 v2.0.5
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
|
|
Loading…
Reference in a new issue