From 2215a05c280e6e166dd6cbb68ab4b78c9d62b17f Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Wed, 8 Dec 2021 15:19:38 +0100 Subject: [PATCH] Add tests for ACME EAB Admin Refactored some of the existing bits for testing the Authority API by creation of a new LinkedAuthority interface and changing visibility of the MockAuthority to be usable by other packages. At this time, not all of the functions of MockAuthority it usable yet. Will refactor when needed or requested. --- acme/api/handler_test.go | 88 +++++-- api/api.go | 311 +++++++++++++++++++++++ api/api_test.go | 230 +---------------- api/revoke_test.go | 8 +- api/ssh_test.go | 14 +- authority/admin/api/acme.go | 2 +- authority/admin/api/acme_test.go | 416 +++++++++++++++++++++++++++++++ authority/admin/api/handler.go | 7 +- go.mod | 1 + 9 files changed, 826 insertions(+), 251 deletions(-) create mode 100644 authority/admin/api/acme_test.go diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 14e00f12..67f7df30 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -15,9 +15,11 @@ import ( "time" "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) @@ -51,28 +53,76 @@ func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) { linker := NewLinker("ca.smallstep.com", "acme") - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - - expDir := Directory{ - NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), - NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), - RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), - KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), - } - type test struct { + ctx context.Context statusCode int + dir Directory err *acme.Error } var tests = map[string]func(t *testing.T) test{ - "ok": func(t *testing.T) test { + "fail/no-provisioner": func(t *testing.T) test { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + } + }, + "fail/different-provisioner": func(t *testing.T) test { + prov := &provisioner.SCEP{ + Type: "SCEP", + Name: "test@scep-provisioner.com", + } + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + } + }, + "ok": func(t *testing.T) test { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + expDir := Directory{ + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), + } + return test{ + ctx: ctx, + dir: expDir, + statusCode: 200, + } + }, + "ok/eab-required": func(t *testing.T) test { + prov := newACMEProv(t) + prov.RequireEAB = true + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + expDir := Directory{ + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), + Meta: Meta{ + ExternalAccountRequired: true, + }, + } + return test{ + ctx: ctx, + dir: expDir, statusCode: 200, } }, @@ -82,7 +132,7 @@ func TestHandler_GetDirectory(t *testing.T) { t.Run(name, func(t *testing.T) { h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(ctx) + req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetDirectory(w, req) res := w.Result() @@ -105,7 +155,9 @@ func TestHandler_GetDirectory(t *testing.T) { } else { var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) - assert.Equals(t, dir, expDir) + if !cmp.Equal(tc.dir, dir) { + t.Errorf("GetDirectory() diff =\n%s", cmp.Diff(tc.dir, dir)) + } assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/api/api.go b/api/api.go index e057caaa..468870b6 100644 --- a/api/api.go +++ b/api/api.go @@ -25,6 +25,9 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/templates" + "go.step.sm/linkedca" + "golang.org/x/crypto/ssh" ) // Authority is the interface implemented by a CA authority. @@ -48,6 +51,21 @@ type Authority interface { Version() authority.Version } +type LinkedAuthority interface { // TODO(hs): name is not great; it is related to LinkedCA, though + Authority + IsAdminAPIEnabled() bool + LoadAdminByID(id string) (*linkedca.Admin, bool) + GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) + StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + RemoveAdmin(ctx context.Context, id string) error + AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) + StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error + LoadProvisionerByID(id string) (provisioner.Interface, error) + UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error + RemoveProvisioner(ctx context.Context, id string) error +} + // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -457,3 +475,296 @@ func fmtPublicKey(cert *x509.Certificate) string { } return fmt.Sprintf("%s %s", cert.PublicKeyAlgorithm, params) } + +type MockAuthority struct { + ret1, ret2 interface{} + err error + authorizeSign func(ott string) ([]provisioner.SignOption, error) + getTLSOptions func() *authority.TLSOptions + root func(shasum string) (*x509.Certificate, error) + sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + renew func(cert *x509.Certificate) ([]*x509.Certificate, error) + rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) + loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) + MockLoadProvisionerByName func(name string) (provisioner.Interface, error) + getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + revoke func(context.Context, *authority.RevokeOptions) error + getEncryptedKey func(kid string) (string, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) + signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) + getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) + getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) + getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) + checkSSHHost func(ctx context.Context, principal, token string) (bool, error) + getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) + version func() authority.Version + + MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two + MockErr error + MockIsAdminAPIEnabled func() bool + MockLoadAdminByID func(id string) (*linkedca.Admin, bool) + MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error) + MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + MockRemoveAdmin func(ctx context.Context, id string) error + MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error) + MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error + MockLoadProvisionerByID func(id string) (provisioner.Interface, error) + MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error + MockRemoveProvisioner func(ctx context.Context, id string) error +} + +// TODO: remove once Authorize is deprecated. +func (m *MockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + return m.AuthorizeSign(ott) +} + +func (m *MockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { + if m.authorizeSign != nil { + return m.authorizeSign(ott) + } + return m.ret1.([]provisioner.SignOption), m.err +} + +func (m *MockAuthority) GetTLSOptions() *authority.TLSOptions { + if m.getTLSOptions != nil { + return m.getTLSOptions() + } + return m.ret1.(*authority.TLSOptions) +} + +func (m *MockAuthority) Root(shasum string) (*x509.Certificate, error) { + if m.root != nil { + return m.root(shasum) + } + return m.ret1.(*x509.Certificate), m.err +} + +func (m *MockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.sign != nil { + return m.sign(cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *MockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.renew != nil { + return m.renew(cert) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *MockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.rekey != nil { + return m.rekey(oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *MockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { + if m.getProvisioners != nil { + return m.getProvisioners(nextCursor, limit) + } + return m.ret1.(provisioner.List), m.ret2.(string), m.err +} + +func (m *MockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { + if m.loadProvisionerByCertificate != nil { + return m.loadProvisionerByCertificate(cert) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *MockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByName != nil { + return m.MockLoadProvisionerByName(name) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *MockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { + if m.revoke != nil { + return m.revoke(ctx, opts) + } + return m.err +} + +func (m *MockAuthority) GetEncryptedKey(kid string) (string, error) { + if m.getEncryptedKey != nil { + return m.getEncryptedKey(kid) + } + return m.ret1.(string), m.err +} + +func (m *MockAuthority) GetRoots() ([]*x509.Certificate, error) { + if m.getRoots != nil { + return m.getRoots() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *MockAuthority) GetFederation() ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getFederation() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *MockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(ctx, key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(ctx, key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.renewSSH != nil { + return m.renewSSH(ctx, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.rekeySSH != nil { + return m.rekeySSH(ctx, cert, key, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { + if m.getSSHHosts != nil { + return m.getSSHHosts(ctx, cert) + } + return m.ret1.([]authority.Host), m.err +} + +func (m *MockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHRoots != nil { + return m.getSSHRoots(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *MockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHFederation != nil { + return m.getSSHFederation(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *MockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { + if m.getSSHConfig != nil { + return m.getSSHConfig(ctx, typ, data) + } + return m.ret1.([]templates.Output), m.err +} + +func (m *MockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { + if m.checkSSHHost != nil { + return m.checkSSHHost(ctx, principal, token) + } + return m.ret1.(bool), m.err +} + +func (m *MockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { + if m.getSSHBastion != nil { + return m.getSSHBastion(ctx, user, hostname) + } + return m.ret1.(*authority.Bastion), m.err +} + +func (m *MockAuthority) Version() authority.Version { + if m.version != nil { + return m.version() + } + return m.ret1.(authority.Version) +} + +func (m *MockAuthority) IsAdminAPIEnabled() bool { + if m.MockIsAdminAPIEnabled != nil { + return m.MockIsAdminAPIEnabled() + } + return m.MockRet1.(bool) +} + +func (m *MockAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) { + if m.MockLoadAdminByID != nil { + return m.MockLoadAdminByID(id) + } + return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool) +} + +func (m *MockAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { + if m.MockGetAdmins != nil { + return m.MockGetAdmins(cursor, limit) + } + return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr +} + +func (m *MockAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + if m.MockStoreAdmin != nil { + return m.MockStoreAdmin(ctx, adm, prov) + } + return m.MockErr +} + +func (m *MockAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + if m.MockUpdateAdmin != nil { + return m.MockUpdateAdmin(ctx, id, nu) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *MockAuthority) RemoveAdmin(ctx context.Context, id string) error { + if m.MockRemoveAdmin != nil { + return m.MockRemoveAdmin(ctx, id) + } + return m.MockErr +} + +func (m *MockAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { + if m.MockAuthorizeAdminToken != nil { + return m.MockAuthorizeAdminToken(r, token) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *MockAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + if m.MockStoreProvisioner != nil { + return m.MockStoreProvisioner(ctx, prov) + } + return m.MockErr +} + +func (m *MockAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByID != nil { + return m.MockLoadProvisionerByID(id) + } + return m.MockRet1.(provisioner.Interface), m.MockErr +} + +func (m *MockAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { + if m.MockUpdateProvisioner != nil { + return m.MockUpdateProvisioner(ctx, nu) + } + return m.MockErr +} + +func (m *MockAuthority) RemoveProvisioner(ctx context.Context, id string) error { + if m.MockRemoveProvisioner != nil { + return m.MockRemoveProvisioner(ctx, id) + } + return m.MockErr +} diff --git a/api/api_test.go b/api/api_test.go index 5cbce8b3..6a845249 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,7 +3,6 @@ package api import ( "bytes" "context" - "crypto" "crypto/dsa" //nolint "crypto/ecdsa" "crypto/ed25519" @@ -32,7 +31,6 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) @@ -551,208 +549,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) ( return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err } -type mockAuthority struct { - ret1, ret2 interface{} - err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) - getTLSOptions func() *authority.TLSOptions - root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - renew func(cert *x509.Certificate) ([]*x509.Certificate, error) - rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) - loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) - loadProvisionerByName func(name string) (provisioner.Interface, error) - getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) - revoke func(context.Context, *authority.RevokeOptions) error - getEncryptedKey func(kid string) (string, error) - getRoots func() ([]*x509.Certificate, error) - getFederation func() ([]*x509.Certificate, error) - signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) - rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) - getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) - getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) - getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) - checkSSHHost func(ctx context.Context, principal, token string) (bool, error) - getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) - version func() authority.Version -} - -// TODO: remove once Authorize is deprecated. -func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) - } - return m.ret1.([]provisioner.SignOption), m.err -} - -func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { - if m.getTLSOptions != nil { - return m.getTLSOptions() - } - return m.ret1.(*authority.TLSOptions) -} - -func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { - if m.root != nil { - return m.root(shasum) - } - return m.ret1.(*x509.Certificate), m.err -} - -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { - if m.renew != nil { - return m.renew(cert) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { - if m.rekey != nil { - return m.rekey(oldcert, pk) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { - if m.getProvisioners != nil { - return m.getProvisioners(nextCursor, limit) - } - return m.ret1.(provisioner.List), m.ret2.(string), m.err -} - -func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { - if m.loadProvisionerByCertificate != nil { - return m.loadProvisionerByCertificate(cert) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { - if m.loadProvisionerByName != nil { - return m.loadProvisionerByName(name) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { - if m.revoke != nil { - return m.revoke(ctx, opts) - } - return m.err -} - -func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { - if m.getEncryptedKey != nil { - return m.getEncryptedKey(kid) - } - return m.ret1.(string), m.err -} - -func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { - if m.getRoots != nil { - return m.getRoots() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { - if m.getFederation != nil { - return m.getFederation() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.signSSH != nil { - return m.signSSH(ctx, key, opts, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.signSSHAddUser != nil { - return m.signSSHAddUser(ctx, key, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.renewSSH != nil { - return m.renewSSH(ctx, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.rekeySSH != nil { - return m.rekeySSH(ctx, cert, key, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { - if m.getSSHHosts != nil { - return m.getSSHHosts(ctx, cert) - } - return m.ret1.([]authority.Host), m.err -} - -func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHRoots != nil { - return m.getSSHRoots(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHFederation != nil { - return m.getSSHFederation(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { - if m.getSSHConfig != nil { - return m.getSSHConfig(ctx, typ, data) - } - return m.ret1.([]templates.Output), m.err -} - -func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { - if m.checkSSHHost != nil { - return m.checkSSHHost(ctx, principal, token) - } - return m.ret1.(bool), m.err -} - -func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { - if m.getSSHBastion != nil { - return m.getSSHBastion(ctx, user, hostname) - } - return m.ret1.(*authority.Bastion), m.err -} - -func (m *mockAuthority) Version() authority.Version { - if m.version != nil { - return m.version() - } - return m.ret1.(authority.Version) -} - func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority @@ -765,7 +561,7 @@ func Test_caHandler_Route(t *testing.T) { fields fields args args }{ - {"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}}, + {"ok", fields{&MockAuthority{}}, args{chi.NewRouter()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -780,7 +576,7 @@ func Test_caHandler_Route(t *testing.T) { func Test_caHandler_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() - h := New(&mockAuthority{}).(*caHandler) + h := New(&MockAuthority{}).(*caHandler) h.Health(w, req) res := w.Result() @@ -820,7 +616,7 @@ func Test_caHandler_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) + h := New(&MockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) w := httptest.NewRecorder() h.Root(w, req) res := w.Result() @@ -884,7 +680,7 @@ func Test_caHandler_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr @@ -938,7 +734,7 @@ func Test_caHandler_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil @@ -999,7 +795,7 @@ func Test_caHandler_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil @@ -1077,9 +873,9 @@ func Test_caHandler_Provisioners(t *testing.T) { args args statusCode int }{ - {"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, - {"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, - {"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, + {"ok", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&MockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, + {"limit fail", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, } expected, err := json.Marshal(pr) @@ -1154,8 +950,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { args args statusCode int }{ - {"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, - {"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, + {"ok", fields{&MockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&MockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, } expected := []byte(`{"key":"` + privKey + `"}`) @@ -1214,7 +1010,7 @@ func Test_caHandler_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() @@ -1260,7 +1056,7 @@ func Test_caHandler_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() diff --git a/api/revoke_test.go b/api/revoke_test.go index 4ed4e3fe..b0eaef3d 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -106,7 +106,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusOK, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -150,7 +150,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, tls: cs, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -185,7 +185,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusInternalServerError, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -207,7 +207,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusForbidden, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, diff --git a/api/ssh_test.go b/api/ssh_test.go index a3d7da0d..df9e2f45 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -314,7 +314,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, @@ -377,7 +377,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, @@ -431,7 +431,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, @@ -491,7 +491,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, @@ -538,7 +538,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, @@ -589,7 +589,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { return tt.hosts, tt.err }, @@ -644,7 +644,7 @@ func Test_caHandler_SSHBastion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 8cba39c4..88e76a09 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -49,7 +49,7 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { // provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME // provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, error) { +func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *admin.Error) { var ( p provisioner.Interface err error diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go new file mode 100644 index 00000000..84e8e9f5 --- /dev/null +++ b/authority/admin/api/acme_test.go @@ -0,0 +1,416 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" +) + +func TestHandler_requireEABEnabled(t *testing.T) { + type test struct { + ctx context.Context + db admin.DB + auth api.LinkedAuthority + next nextHTTP + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/h.provisionerHasEABEnabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("prov", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + err := admin.NewErrorISE("error loading provisioner provName: force") + err.Message = "error loading provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + err: err, + statusCode: 500, + } + }, + "ok/eab-disabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("prov", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, + }, + }, + }, + }, nil + }, + } + err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") + err.Message = "ACME EAB not enabled for provisioner provName" + return test{ + ctx: ctx, + auth: auth, + db: db, + err: err, + statusCode: 400, + } + }, + "ok/eab-enabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("prov", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, + }, + }, + }, + }, nil + }, + } + return test{ + ctx: ctx, + auth: auth, + db: db, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + auth: tc.auth, + acmeDB: nil, + } + + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.requireEABEnabled(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + // nothing to test when the requireEABEnabled middleware succeeds, currently + }) + } +} + +func TestHandler_provisionerHasEABEnabled(t *testing.T) { + type test struct { + db admin.DB + auth api.LinkedAuthority + provisionerName string + want bool + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + auth: auth, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/prov.GetDetails": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: nil, + }, nil + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/details.GetACME": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: nil, + }, + }, + }, nil + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "ok/eab-disabled": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "eab-disabled", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "eab-disabled", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, + }, + }, + }, + }, nil + }, + } + return test{ + db: db, + auth: auth, + provisionerName: "eab-disabled", + want: false, + } + }, + "ok/eab-enabled": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "eab-enabled", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "eab-enabled", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, + }, + }, + }, + }, nil + }, + } + return test{ + db: db, + auth: auth, + provisionerName: "eab-enabled", + want: true, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + auth: tc.auth, + acmeDB: nil, + } + got, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) + if (err != nil) != (tc.err != nil) { + t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) + return + } + if tc.err != nil { + // TODO(hs): the output of the diff seems to be equal to each other; not sure why it's marked as different =/ + // opts := []cmp.Option{cmpopts.EquateErrors()} + // if !cmp.Equal(tc.err, err, opts...) { + // t.Errorf("Handler.provisionerHasEABEnabled() diff =\n%v", cmp.Diff(tc.err, err, opts...)) + // } + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Status, err.Status) + assert.Equals(t, tc.err.StatusCode(), err.StatusCode()) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.Detail, err.Detail) + return + } + if got != tc.want { + t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { + type fields struct { + Reference string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "ok/empty-reference", + fields: fields{ + Reference: "", + }, + wantErr: false, + }, + { + name: "ok", + fields: fields{ + Reference: "my-eab-reference", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &CreateExternalAccountKeyRequest{ + Reference: tt.fields.Reference, + } + if err := r.Validate(); (err != nil) != tt.wantErr { + t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index ba13407d..b3ed04bf 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -3,19 +3,18 @@ package api import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) -// Handler is the ACME API request handler. +// Handler is the Admin API request handler. type Handler struct { db admin.DB - auth *authority.Authority + auth api.LinkedAuthority // was: *authority.Authority acmeDB acme.DB } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth *authority.Authority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { +func NewHandler(auth api.LinkedAuthority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { return &Handler{ db: adminDB, auth: auth, diff --git a/go.mod b/go.mod index b2014bf4..edf70903 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 github.com/golang/mock v1.6.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.0.5 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect