diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 14e00f12..67f7df30 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -15,9 +15,11 @@ import ( "time" "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) @@ -51,28 +53,76 @@ func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) { linker := NewLinker("ca.smallstep.com", "acme") - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - - expDir := Directory{ - NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), - NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), - RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), - KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), - } - type test struct { + ctx context.Context statusCode int + dir Directory err *acme.Error } var tests = map[string]func(t *testing.T) test{ - "ok": func(t *testing.T) test { + "fail/no-provisioner": func(t *testing.T) test { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + } + }, + "fail/different-provisioner": func(t *testing.T) test { + prov := &provisioner.SCEP{ + Type: "SCEP", + Name: "test@scep-provisioner.com", + } + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + } + }, + "ok": func(t *testing.T) test { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + expDir := Directory{ + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), + } + return test{ + ctx: ctx, + dir: expDir, + statusCode: 200, + } + }, + "ok/eab-required": func(t *testing.T) test { + prov := newACMEProv(t) + prov.RequireEAB = true + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + expDir := Directory{ + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), + Meta: Meta{ + ExternalAccountRequired: true, + }, + } + return test{ + ctx: ctx, + dir: expDir, statusCode: 200, } }, @@ -82,7 +132,7 @@ func TestHandler_GetDirectory(t *testing.T) { t.Run(name, func(t *testing.T) { h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(ctx) + req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetDirectory(w, req) res := w.Result() @@ -105,7 +155,9 @@ func TestHandler_GetDirectory(t *testing.T) { } else { var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) - assert.Equals(t, dir, expDir) + if !cmp.Equal(tc.dir, dir) { + t.Errorf("GetDirectory() diff =\n%s", cmp.Diff(tc.dir, dir)) + } assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/api/api.go b/api/api.go index e057caaa..468870b6 100644 --- a/api/api.go +++ b/api/api.go @@ -25,6 +25,9 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/templates" + "go.step.sm/linkedca" + "golang.org/x/crypto/ssh" ) // Authority is the interface implemented by a CA authority. @@ -48,6 +51,21 @@ type Authority interface { Version() authority.Version } +type LinkedAuthority interface { // TODO(hs): name is not great; it is related to LinkedCA, though + Authority + IsAdminAPIEnabled() bool + LoadAdminByID(id string) (*linkedca.Admin, bool) + GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) + StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + RemoveAdmin(ctx context.Context, id string) error + AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) + StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error + LoadProvisionerByID(id string) (provisioner.Interface, error) + UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error + RemoveProvisioner(ctx context.Context, id string) error +} + // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -457,3 +475,296 @@ func fmtPublicKey(cert *x509.Certificate) string { } return fmt.Sprintf("%s %s", cert.PublicKeyAlgorithm, params) } + +type MockAuthority struct { + ret1, ret2 interface{} + err error + authorizeSign func(ott string) ([]provisioner.SignOption, error) + getTLSOptions func() *authority.TLSOptions + root func(shasum string) (*x509.Certificate, error) + sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + renew func(cert *x509.Certificate) ([]*x509.Certificate, error) + rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) + loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) + MockLoadProvisionerByName func(name string) (provisioner.Interface, error) + getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + revoke func(context.Context, *authority.RevokeOptions) error + getEncryptedKey func(kid string) (string, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) + signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) + getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) + getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) + getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) + checkSSHHost func(ctx context.Context, principal, token string) (bool, error) + getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) + version func() authority.Version + + MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two + MockErr error + MockIsAdminAPIEnabled func() bool + MockLoadAdminByID func(id string) (*linkedca.Admin, bool) + MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error) + MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + MockRemoveAdmin func(ctx context.Context, id string) error + MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error) + MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error + MockLoadProvisionerByID func(id string) (provisioner.Interface, error) + MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error + MockRemoveProvisioner func(ctx context.Context, id string) error +} + +// TODO: remove once Authorize is deprecated. +func (m *MockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + return m.AuthorizeSign(ott) +} + +func (m *MockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { + if m.authorizeSign != nil { + return m.authorizeSign(ott) + } + return m.ret1.([]provisioner.SignOption), m.err +} + +func (m *MockAuthority) GetTLSOptions() *authority.TLSOptions { + if m.getTLSOptions != nil { + return m.getTLSOptions() + } + return m.ret1.(*authority.TLSOptions) +} + +func (m *MockAuthority) Root(shasum string) (*x509.Certificate, error) { + if m.root != nil { + return m.root(shasum) + } + return m.ret1.(*x509.Certificate), m.err +} + +func (m *MockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.sign != nil { + return m.sign(cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *MockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.renew != nil { + return m.renew(cert) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *MockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.rekey != nil { + return m.rekey(oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *MockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { + if m.getProvisioners != nil { + return m.getProvisioners(nextCursor, limit) + } + return m.ret1.(provisioner.List), m.ret2.(string), m.err +} + +func (m *MockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { + if m.loadProvisionerByCertificate != nil { + return m.loadProvisionerByCertificate(cert) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *MockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByName != nil { + return m.MockLoadProvisionerByName(name) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *MockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { + if m.revoke != nil { + return m.revoke(ctx, opts) + } + return m.err +} + +func (m *MockAuthority) GetEncryptedKey(kid string) (string, error) { + if m.getEncryptedKey != nil { + return m.getEncryptedKey(kid) + } + return m.ret1.(string), m.err +} + +func (m *MockAuthority) GetRoots() ([]*x509.Certificate, error) { + if m.getRoots != nil { + return m.getRoots() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *MockAuthority) GetFederation() ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getFederation() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *MockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(ctx, key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(ctx, key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.renewSSH != nil { + return m.renewSSH(ctx, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.rekeySSH != nil { + return m.rekeySSH(ctx, cert, key, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *MockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { + if m.getSSHHosts != nil { + return m.getSSHHosts(ctx, cert) + } + return m.ret1.([]authority.Host), m.err +} + +func (m *MockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHRoots != nil { + return m.getSSHRoots(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *MockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHFederation != nil { + return m.getSSHFederation(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *MockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { + if m.getSSHConfig != nil { + return m.getSSHConfig(ctx, typ, data) + } + return m.ret1.([]templates.Output), m.err +} + +func (m *MockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { + if m.checkSSHHost != nil { + return m.checkSSHHost(ctx, principal, token) + } + return m.ret1.(bool), m.err +} + +func (m *MockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { + if m.getSSHBastion != nil { + return m.getSSHBastion(ctx, user, hostname) + } + return m.ret1.(*authority.Bastion), m.err +} + +func (m *MockAuthority) Version() authority.Version { + if m.version != nil { + return m.version() + } + return m.ret1.(authority.Version) +} + +func (m *MockAuthority) IsAdminAPIEnabled() bool { + if m.MockIsAdminAPIEnabled != nil { + return m.MockIsAdminAPIEnabled() + } + return m.MockRet1.(bool) +} + +func (m *MockAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) { + if m.MockLoadAdminByID != nil { + return m.MockLoadAdminByID(id) + } + return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool) +} + +func (m *MockAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { + if m.MockGetAdmins != nil { + return m.MockGetAdmins(cursor, limit) + } + return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr +} + +func (m *MockAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + if m.MockStoreAdmin != nil { + return m.MockStoreAdmin(ctx, adm, prov) + } + return m.MockErr +} + +func (m *MockAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + if m.MockUpdateAdmin != nil { + return m.MockUpdateAdmin(ctx, id, nu) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *MockAuthority) RemoveAdmin(ctx context.Context, id string) error { + if m.MockRemoveAdmin != nil { + return m.MockRemoveAdmin(ctx, id) + } + return m.MockErr +} + +func (m *MockAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { + if m.MockAuthorizeAdminToken != nil { + return m.MockAuthorizeAdminToken(r, token) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *MockAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + if m.MockStoreProvisioner != nil { + return m.MockStoreProvisioner(ctx, prov) + } + return m.MockErr +} + +func (m *MockAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByID != nil { + return m.MockLoadProvisionerByID(id) + } + return m.MockRet1.(provisioner.Interface), m.MockErr +} + +func (m *MockAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { + if m.MockUpdateProvisioner != nil { + return m.MockUpdateProvisioner(ctx, nu) + } + return m.MockErr +} + +func (m *MockAuthority) RemoveProvisioner(ctx context.Context, id string) error { + if m.MockRemoveProvisioner != nil { + return m.MockRemoveProvisioner(ctx, id) + } + return m.MockErr +} diff --git a/api/api_test.go b/api/api_test.go index 5cbce8b3..6a845249 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,7 +3,6 @@ package api import ( "bytes" "context" - "crypto" "crypto/dsa" //nolint "crypto/ecdsa" "crypto/ed25519" @@ -32,7 +31,6 @@ import ( "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/templates" "go.step.sm/crypto/jose" "golang.org/x/crypto/ssh" ) @@ -551,208 +549,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) ( return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err } -type mockAuthority struct { - ret1, ret2 interface{} - err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) - getTLSOptions func() *authority.TLSOptions - root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - renew func(cert *x509.Certificate) ([]*x509.Certificate, error) - rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) - loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) - loadProvisionerByName func(name string) (provisioner.Interface, error) - getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) - revoke func(context.Context, *authority.RevokeOptions) error - getEncryptedKey func(kid string) (string, error) - getRoots func() ([]*x509.Certificate, error) - getFederation func() ([]*x509.Certificate, error) - signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) - rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) - getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) - getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) - getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) - checkSSHHost func(ctx context.Context, principal, token string) (bool, error) - getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) - version func() authority.Version -} - -// TODO: remove once Authorize is deprecated. -func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) - } - return m.ret1.([]provisioner.SignOption), m.err -} - -func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { - if m.getTLSOptions != nil { - return m.getTLSOptions() - } - return m.ret1.(*authority.TLSOptions) -} - -func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { - if m.root != nil { - return m.root(shasum) - } - return m.ret1.(*x509.Certificate), m.err -} - -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { - if m.renew != nil { - return m.renew(cert) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { - if m.rekey != nil { - return m.rekey(oldcert, pk) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { - if m.getProvisioners != nil { - return m.getProvisioners(nextCursor, limit) - } - return m.ret1.(provisioner.List), m.ret2.(string), m.err -} - -func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { - if m.loadProvisionerByCertificate != nil { - return m.loadProvisionerByCertificate(cert) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { - if m.loadProvisionerByName != nil { - return m.loadProvisionerByName(name) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { - if m.revoke != nil { - return m.revoke(ctx, opts) - } - return m.err -} - -func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { - if m.getEncryptedKey != nil { - return m.getEncryptedKey(kid) - } - return m.ret1.(string), m.err -} - -func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { - if m.getRoots != nil { - return m.getRoots() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { - if m.getFederation != nil { - return m.getFederation() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.signSSH != nil { - return m.signSSH(ctx, key, opts, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.signSSHAddUser != nil { - return m.signSSHAddUser(ctx, key, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.renewSSH != nil { - return m.renewSSH(ctx, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.rekeySSH != nil { - return m.rekeySSH(ctx, cert, key, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { - if m.getSSHHosts != nil { - return m.getSSHHosts(ctx, cert) - } - return m.ret1.([]authority.Host), m.err -} - -func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHRoots != nil { - return m.getSSHRoots(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHFederation != nil { - return m.getSSHFederation(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { - if m.getSSHConfig != nil { - return m.getSSHConfig(ctx, typ, data) - } - return m.ret1.([]templates.Output), m.err -} - -func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { - if m.checkSSHHost != nil { - return m.checkSSHHost(ctx, principal, token) - } - return m.ret1.(bool), m.err -} - -func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { - if m.getSSHBastion != nil { - return m.getSSHBastion(ctx, user, hostname) - } - return m.ret1.(*authority.Bastion), m.err -} - -func (m *mockAuthority) Version() authority.Version { - if m.version != nil { - return m.version() - } - return m.ret1.(authority.Version) -} - func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority @@ -765,7 +561,7 @@ func Test_caHandler_Route(t *testing.T) { fields fields args args }{ - {"ok", fields{&mockAuthority{}}, args{chi.NewRouter()}}, + {"ok", fields{&MockAuthority{}}, args{chi.NewRouter()}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -780,7 +576,7 @@ func Test_caHandler_Route(t *testing.T) { func Test_caHandler_Health(t *testing.T) { req := httptest.NewRequest("GET", "http://example.com/health", nil) w := httptest.NewRecorder() - h := New(&mockAuthority{}).(*caHandler) + h := New(&MockAuthority{}).(*caHandler) h.Health(w, req) res := w.Result() @@ -820,7 +616,7 @@ func Test_caHandler_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) + h := New(&MockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) w := httptest.NewRecorder() h.Root(w, req) res := w.Result() @@ -884,7 +680,7 @@ func Test_caHandler_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.signErr, authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return tt.certAttrOpts, tt.autherr @@ -938,7 +734,7 @@ func Test_caHandler_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil @@ -999,7 +795,7 @@ func Test_caHandler_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ ret1: tt.cert, ret2: tt.root, err: tt.err, getTLSOptions: func() *authority.TLSOptions { return nil @@ -1077,9 +873,9 @@ func Test_caHandler_Provisioners(t *testing.T) { args args statusCode int }{ - {"ok", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, - {"fail", fields{&mockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, - {"limit fail", fields{&mockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, + {"ok", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&MockAuthority{ret1: p, ret2: "", err: fmt.Errorf("the error")}}, args{httptest.NewRecorder(), req}, 500}, + {"limit fail", fields{&MockAuthority{ret1: p, ret2: ""}}, args{httptest.NewRecorder(), reqLimitFail}, 400}, } expected, err := json.Marshal(pr) @@ -1154,8 +950,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) { args args statusCode int }{ - {"ok", fields{&mockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, - {"fail", fields{&mockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, + {"ok", fields{&MockAuthority{ret1: privKey}}, args{httptest.NewRecorder(), req}, 200}, + {"fail", fields{&MockAuthority{ret1: "", err: fmt.Errorf("not found")}}, args{httptest.NewRecorder(), req}, 404}, } expected := []byte(`{"key":"` + privKey + `"}`) @@ -1214,7 +1010,7 @@ func Test_caHandler_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/roots", nil) req.TLS = tt.tls w := httptest.NewRecorder() @@ -1260,7 +1056,7 @@ func Test_caHandler_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) + h := New(&MockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/federation", nil) req.TLS = tt.tls w := httptest.NewRecorder() diff --git a/api/revoke_test.go b/api/revoke_test.go index 4ed4e3fe..b0eaef3d 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -106,7 +106,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusOK, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -150,7 +150,7 @@ func Test_caHandler_Revoke(t *testing.T) { input: string(input), statusCode: http.StatusOK, tls: cs, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -185,7 +185,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusInternalServerError, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, @@ -207,7 +207,7 @@ func Test_caHandler_Revoke(t *testing.T) { return test{ input: string(input), statusCode: http.StatusForbidden, - auth: &mockAuthority{ + auth: &MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return nil, nil }, diff --git a/api/ssh_test.go b/api/ssh_test.go index a3d7da0d..df9e2f45 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -314,7 +314,7 @@ func Test_caHandler_SSHSign(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, @@ -377,7 +377,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, @@ -431,7 +431,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, @@ -491,7 +491,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, @@ -538,7 +538,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { return tt.exists, tt.err }, @@ -589,7 +589,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { return tt.hosts, tt.err }, @@ -644,7 +644,7 @@ func Test_caHandler_SSHBastion(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(&mockAuthority{ + h := New(&MockAuthority{ getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go index 8cba39c4..88e76a09 100644 --- a/authority/admin/api/acme.go +++ b/authority/admin/api/acme.go @@ -49,7 +49,7 @@ func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { // provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME // provisioner is set to true and thus has EAB enabled. -func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, error) { +func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *admin.Error) { var ( p provisioner.Interface err error diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go new file mode 100644 index 00000000..84e8e9f5 --- /dev/null +++ b/authority/admin/api/acme_test.go @@ -0,0 +1,416 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" +) + +func TestHandler_requireEABEnabled(t *testing.T) { + type test struct { + ctx context.Context + db admin.DB + auth api.LinkedAuthority + next nextHTTP + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/h.provisionerHasEABEnabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("prov", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + err := admin.NewErrorISE("error loading provisioner provName: force") + err.Message = "error loading provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + err: err, + statusCode: 500, + } + }, + "ok/eab-disabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("prov", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, + }, + }, + }, + }, nil + }, + } + err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") + err.Message = "ACME EAB not enabled for provisioner provName" + return test{ + ctx: ctx, + auth: auth, + db: db, + err: err, + statusCode: 400, + } + }, + "ok/eab-enabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("prov", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, + }, + }, + }, + }, nil + }, + } + return test{ + ctx: ctx, + auth: auth, + db: db, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + auth: tc.auth, + acmeDB: nil, + } + + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.requireEABEnabled(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + // nothing to test when the requireEABEnabled middleware succeeds, currently + }) + } +} + +func TestHandler_provisionerHasEABEnabled(t *testing.T) { + type test struct { + db admin.DB + auth api.LinkedAuthority + provisionerName string + want bool + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + auth: auth, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/prov.GetDetails": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: nil, + }, nil + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/details.GetACME": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: nil, + }, + }, + }, nil + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "ok/eab-disabled": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "eab-disabled", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "eab-disabled", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, + }, + }, + }, + }, nil + }, + } + return test{ + db: db, + auth: auth, + provisionerName: "eab-disabled", + want: false, + } + }, + "ok/eab-enabled": func(t *testing.T) test { + auth := &api.MockAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "eab-enabled", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "eab-enabled", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, + }, + }, + }, + }, nil + }, + } + return test{ + db: db, + auth: auth, + provisionerName: "eab-enabled", + want: true, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + auth: tc.auth, + acmeDB: nil, + } + got, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) + if (err != nil) != (tc.err != nil) { + t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) + return + } + if tc.err != nil { + // TODO(hs): the output of the diff seems to be equal to each other; not sure why it's marked as different =/ + // opts := []cmp.Option{cmpopts.EquateErrors()} + // if !cmp.Equal(tc.err, err, opts...) { + // t.Errorf("Handler.provisionerHasEABEnabled() diff =\n%v", cmp.Diff(tc.err, err, opts...)) + // } + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Status, err.Status) + assert.Equals(t, tc.err.StatusCode(), err.StatusCode()) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.Detail, err.Detail) + return + } + if got != tc.want { + t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { + type fields struct { + Reference string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "ok/empty-reference", + fields: fields{ + Reference: "", + }, + wantErr: false, + }, + { + name: "ok", + fields: fields{ + Reference: "my-eab-reference", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &CreateExternalAccountKeyRequest{ + Reference: tt.fields.Reference, + } + if err := r.Validate(); (err != nil) != tt.wantErr { + t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index ba13407d..b3ed04bf 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -3,19 +3,18 @@ package api import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) -// Handler is the ACME API request handler. +// Handler is the Admin API request handler. type Handler struct { db admin.DB - auth *authority.Authority + auth api.LinkedAuthority // was: *authority.Authority acmeDB acme.DB } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth *authority.Authority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { +func NewHandler(auth api.LinkedAuthority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { return &Handler{ db: adminDB, auth: auth, diff --git a/go.mod b/go.mod index b2014bf4..edf70903 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 github.com/golang/mock v1.6.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.0.5 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect