Improve middleware test coverage
This commit is contained in:
parent
6da243c34d
commit
bfa4d809fd
5 changed files with 206 additions and 341 deletions
|
@ -1,17 +1,13 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/go-chi/chi"
|
|
||||||
|
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/api/render"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
|
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
|
||||||
|
@ -38,48 +34,27 @@ type GetExternalAccountKeysResponse struct {
|
||||||
func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
provName := chi.URLParam(r, "provisionerName")
|
prov := linkedca.ProvisionerFromContext(ctx)
|
||||||
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
|
|
||||||
if err != nil {
|
|
||||||
render.Error(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !eabEnabled {
|
|
||||||
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
|
|
||||||
next(w, r.WithContext(ctx))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
|
|
||||||
// provisioner is set to true and thus has EAB enabled.
|
|
||||||
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
|
|
||||||
var (
|
|
||||||
p provisioner.Interface
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil {
|
|
||||||
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
|
|
||||||
}
|
|
||||||
|
|
||||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
|
|
||||||
}
|
|
||||||
|
|
||||||
details := prov.GetDetails()
|
details := prov.GetDetails()
|
||||||
if details == nil {
|
if details == nil {
|
||||||
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
|
render.Error(w, admin.NewErrorISE("error getting details for provisioner '%s'", prov.GetName()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
acmeProvisioner := details.GetACME()
|
acmeProvisioner := details.GetACME()
|
||||||
if acmeProvisioner == nil {
|
if acmeProvisioner == nil {
|
||||||
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
|
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
return acmeProvisioner.GetRequireEab(), prov, nil
|
if !acmeProvisioner.RequireEab {
|
||||||
|
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next(w, r.WithContext(ctx))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type acmeAdminResponderInterface interface {
|
type acmeAdminResponderInterface interface {
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -20,7 +19,6 @@ import (
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
||||||
|
@ -35,50 +33,42 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
|
||||||
func TestHandler_requireEABEnabled(t *testing.T) {
|
func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
adminDB admin.DB
|
|
||||||
auth adminAuthority
|
|
||||||
next http.HandlerFunc
|
next http.HandlerFunc
|
||||||
err *admin.Error
|
err *admin.Error
|
||||||
statusCode int
|
statusCode int
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
var tests = map[string]func(t *testing.T) test{
|
||||||
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test {
|
"fail/prov.GetDetails": func(t *testing.T) test {
|
||||||
chiCtx := chi.NewRouteContext()
|
prov := &linkedca.Provisioner{
|
||||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
Id: "provID",
|
||||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
Name: "provName",
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
err := admin.NewErrorISE("error loading provisioner provName: force")
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
err.Message = "error loading provisioner provName: force"
|
err := admin.NewErrorISE("error getting details for provisioner 'provName'")
|
||||||
|
err.Message = "error getting details for provisioner 'provName'"
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
err: err,
|
||||||
|
statusCode: 500,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/details.GetACME": func(t *testing.T) test {
|
||||||
|
prov := &linkedca.Provisioner{
|
||||||
|
Id: "provID",
|
||||||
|
Name: "provName",
|
||||||
|
Details: &linkedca.ProvisionerDetails{},
|
||||||
|
}
|
||||||
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
|
err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
|
||||||
|
err.Message = "error getting ACME details for provisioner 'provName'"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
auth: auth,
|
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/eab-disabled": func(t *testing.T) test {
|
"ok/eab-disabled": func(t *testing.T) test {
|
||||||
chiCtx := chi.NewRouteContext()
|
prov := &linkedca.Provisioner{
|
||||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
|
||||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return &linkedca.Provisioner{
|
|
||||||
Id: "provID",
|
Id: "provID",
|
||||||
Name: "provName",
|
Name: "provName",
|
||||||
Details: &linkedca.ProvisionerDetails{
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
@ -88,37 +78,18 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
|
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
|
||||||
err.Message = "ACME EAB not enabled for provisioner provName"
|
err.Message = "ACME EAB not enabled for provisioner 'provName'"
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
auth: auth,
|
|
||||||
adminDB: db,
|
|
||||||
err: err,
|
err: err,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok/eab-enabled": func(t *testing.T) test {
|
"ok/eab-enabled": func(t *testing.T) test {
|
||||||
chiCtx := chi.NewRouteContext()
|
prov := &linkedca.Provisioner{
|
||||||
chiCtx.URLParams.Add("provisionerName", "provName")
|
|
||||||
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return &linkedca.Provisioner{
|
|
||||||
Id: "provID",
|
Id: "provID",
|
||||||
Name: "provName",
|
Name: "provName",
|
||||||
Details: &linkedca.ProvisionerDetails{
|
Details: &linkedca.ProvisionerDetails{
|
||||||
|
@ -128,13 +99,10 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
|
||||||
return test{
|
return test{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
auth: auth,
|
|
||||||
adminDB: db,
|
|
||||||
next: func(w http.ResponseWriter, r *http.Request) {
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write(nil) // mock response with status 200
|
w.Write(nil) // mock response with status 200
|
||||||
},
|
},
|
||||||
|
@ -146,13 +114,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
for name, prep := range tests {
|
for name, prep := range tests {
|
||||||
tc := prep(t)
|
tc := prep(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{
|
h := &Handler{}
|
||||||
auth: tc.auth,
|
|
||||||
adminDB: tc.adminDB,
|
|
||||||
acmeDB: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
req := httptest.NewRequest("GET", "/foo", nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.requireEABEnabled(tc.next)(w, req)
|
h.requireEABEnabled(tc.next)(w, req)
|
||||||
|
@ -179,216 +143,6 @@ func TestHandler_requireEABEnabled(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
|
|
||||||
type test struct {
|
|
||||||
adminDB admin.DB
|
|
||||||
auth adminAuthority
|
|
||||||
provisionerName string
|
|
||||||
want bool
|
|
||||||
err *admin.Error
|
|
||||||
}
|
|
||||||
var tests = map[string]func(t *testing.T) test{
|
|
||||||
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
auth: auth,
|
|
||||||
provisionerName: "provName",
|
|
||||||
want: false,
|
|
||||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/db.GetProvisioner": func(t *testing.T) test {
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return nil, errors.New("force")
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
auth: auth,
|
|
||||||
adminDB: db,
|
|
||||||
provisionerName: "provName",
|
|
||||||
want: false,
|
|
||||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/prov.GetDetails": func(t *testing.T) test {
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return &linkedca.Provisioner{
|
|
||||||
Id: "provID",
|
|
||||||
Name: "provName",
|
|
||||||
Details: nil,
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
auth: auth,
|
|
||||||
adminDB: db,
|
|
||||||
provisionerName: "provName",
|
|
||||||
want: false,
|
|
||||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/details.GetACME": func(t *testing.T) test {
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "provName", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return &linkedca.Provisioner{
|
|
||||||
Id: "provID",
|
|
||||||
Name: "provName",
|
|
||||||
Details: &linkedca.ProvisionerDetails{
|
|
||||||
Data: &linkedca.ProvisionerDetails_ACME{
|
|
||||||
ACME: nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
auth: auth,
|
|
||||||
adminDB: db,
|
|
||||||
provisionerName: "provName",
|
|
||||||
want: false,
|
|
||||||
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/eab-disabled": func(t *testing.T) test {
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "eab-disabled", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return &linkedca.Provisioner{
|
|
||||||
Id: "provID",
|
|
||||||
Name: "eab-disabled",
|
|
||||||
Details: &linkedca.ProvisionerDetails{
|
|
||||||
Data: &linkedca.ProvisionerDetails_ACME{
|
|
||||||
ACME: &linkedca.ACMEProvisioner{
|
|
||||||
RequireEab: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
adminDB: db,
|
|
||||||
auth: auth,
|
|
||||||
provisionerName: "eab-disabled",
|
|
||||||
want: false,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok/eab-enabled": func(t *testing.T) test {
|
|
||||||
auth := &mockAdminAuthority{
|
|
||||||
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
|
||||||
assert.Equals(t, "eab-enabled", name)
|
|
||||||
return &provisioner.MockProvisioner{
|
|
||||||
MgetID: func() string {
|
|
||||||
return "provID"
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
db := &admin.MockDB{
|
|
||||||
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
|
||||||
assert.Equals(t, "provID", id)
|
|
||||||
return &linkedca.Provisioner{
|
|
||||||
Id: "provID",
|
|
||||||
Name: "eab-enabled",
|
|
||||||
Details: &linkedca.ProvisionerDetails{
|
|
||||||
Data: &linkedca.ProvisionerDetails_ACME{
|
|
||||||
ACME: &linkedca.ACMEProvisioner{
|
|
||||||
RequireEab: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return test{
|
|
||||||
adminDB: db,
|
|
||||||
auth: auth,
|
|
||||||
provisionerName: "eab-enabled",
|
|
||||||
want: true,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, prep := range tests {
|
|
||||||
tc := prep(t)
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
h := &Handler{
|
|
||||||
auth: tc.auth,
|
|
||||||
adminDB: tc.adminDB,
|
|
||||||
acmeDB: nil,
|
|
||||||
}
|
|
||||||
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
|
|
||||||
if (err != nil) != (tc.err != nil) {
|
|
||||||
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tc.err != nil {
|
|
||||||
assert.Type(t, &linkedca.Provisioner{}, prov)
|
|
||||||
assert.Type(t, &admin.Error{}, err)
|
|
||||||
adminError, _ := err.(*admin.Error)
|
|
||||||
assert.Equals(t, tc.err.Type, adminError.Type)
|
|
||||||
assert.Equals(t, tc.err.Status, adminError.Status)
|
|
||||||
assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
|
|
||||||
assert.Equals(t, tc.err.Message, adminError.Message)
|
|
||||||
assert.Equals(t, tc.err.Detail, adminError.Detail)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if got != tc.want {
|
|
||||||
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
|
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Reference string
|
Reference string
|
||||||
|
|
|
@ -62,10 +62,10 @@ func (h *Handler) Route(r api.Router) {
|
||||||
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin))
|
||||||
|
|
||||||
// ACME External Account Binding Keys
|
// ACME External Account Binding Keys
|
||||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(h.loadProvisionerByName(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))))
|
||||||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(h.loadProvisionerByName(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))))
|
||||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey)))
|
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(h.loadProvisionerByName(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))))
|
||||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey)))
|
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(h.loadProvisionerByName(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))))
|
||||||
|
|
||||||
// Policy - Authority
|
// Policy - Authority
|
||||||
r.MethodFunc("GET", "/policy", authnz(enabledInStandalone(h.policyResponder.GetAuthorityPolicy)))
|
r.MethodFunc("GET", "/policy", authnz(enabledInStandalone(h.policyResponder.GetAuthorityPolicy)))
|
||||||
|
@ -74,16 +74,14 @@ func (h *Handler) Route(r api.Router) {
|
||||||
r.MethodFunc("DELETE", "/policy", authnz(enabledInStandalone(h.policyResponder.DeleteAuthorityPolicy)))
|
r.MethodFunc("DELETE", "/policy", authnz(enabledInStandalone(h.policyResponder.DeleteAuthorityPolicy)))
|
||||||
|
|
||||||
// Policy - Provisioner
|
// Policy - Provisioner
|
||||||
//r.MethodFunc("GET", "/provisioners/{name}/policy", noauth(h.policyResponder.GetProvisionerPolicy))
|
|
||||||
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.GetProvisionerPolicy))))
|
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.GetProvisionerPolicy))))
|
||||||
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.CreateProvisionerPolicy))))
|
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.CreateProvisionerPolicy))))
|
||||||
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.UpdateProvisionerPolicy))))
|
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.UpdateProvisionerPolicy))))
|
||||||
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.DeleteProvisionerPolicy))))
|
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", authnz(disabledInStandalone(h.loadProvisionerByName(h.policyResponder.DeleteProvisionerPolicy))))
|
||||||
|
|
||||||
// Policy - ACME Account
|
// Policy - ACME Account
|
||||||
// TODO: ensure we don't clash with eab; might want to change eab paths slightly (as long as we don't have it released completely; needs changes in adminClient too)
|
r.MethodFunc("GET", "/acme/policy/{provisionerName}/{accountID}", authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.policyResponder.GetACMEAccountPolicy)))))
|
||||||
r.MethodFunc("GET", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.GetACMEAccountPolicy)))
|
r.MethodFunc("POST", "/acme/policy/{provisionerName}/{accountID}", authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.policyResponder.CreateACMEAccountPolicy)))))
|
||||||
r.MethodFunc("POST", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.CreateACMEAccountPolicy)))
|
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/{accountID}", authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.policyResponder.UpdateACMEAccountPolicy)))))
|
||||||
r.MethodFunc("PUT", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateACMEAccountPolicy)))
|
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/{accountID}", authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.policyResponder.DeleteACMEAccountPolicy)))))
|
||||||
r.MethodFunc("DELETE", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteACMEAccountPolicy)))
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,6 +59,8 @@ func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc
|
||||||
p provisioner.Interface
|
p provisioner.Interface
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(hs): distinguish 404 vs. 500
|
||||||
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
if p, err = h.auth.LoadProvisionerByName(name); err != nil {
|
||||||
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||||
return
|
return
|
||||||
|
@ -66,7 +68,7 @@ func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc
|
||||||
|
|
||||||
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
render.Error(w, err)
|
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,14 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
@ -18,6 +20,7 @@ import (
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandler_requireAPIEnabled(t *testing.T) {
|
func TestHandler_requireAPIEnabled(t *testing.T) {
|
||||||
|
@ -220,3 +223,136 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandler_loadProvisionerByName(t *testing.T) {
|
||||||
|
type test struct {
|
||||||
|
adminDB admin.DB
|
||||||
|
auth adminAuthority
|
||||||
|
ctx context.Context
|
||||||
|
next http.HandlerFunc
|
||||||
|
err *admin.Error
|
||||||
|
statusCode int
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||||
|
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||||
|
auth := &mockAdminAuthority{
|
||||||
|
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||||
|
assert.Equals(t, "provName", name)
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName")
|
||||||
|
err.Message = "error loading provisioner provName: force"
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: auth,
|
||||||
|
statusCode: 500,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/db.GetProvisioner": func(t *testing.T) test {
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||||
|
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||||
|
auth := &mockAdminAuthority{
|
||||||
|
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||||
|
assert.Equals(t, "provName", name)
|
||||||
|
return &provisioner.MockProvisioner{
|
||||||
|
MgetID: func() string {
|
||||||
|
return "provID"
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
db := &admin.MockDB{
|
||||||
|
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||||
|
assert.Equals(t, "provID", id)
|
||||||
|
return nil, errors.New("force")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName")
|
||||||
|
err.Message = "error retrieving provisioner provName: force"
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: auth,
|
||||||
|
adminDB: db,
|
||||||
|
statusCode: 500,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok": func(t *testing.T) test {
|
||||||
|
chiCtx := chi.NewRouteContext()
|
||||||
|
chiCtx.URLParams.Add("provisionerName", "provName")
|
||||||
|
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
|
||||||
|
auth := &mockAdminAuthority{
|
||||||
|
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
|
||||||
|
assert.Equals(t, "provName", name)
|
||||||
|
return &provisioner.MockProvisioner{
|
||||||
|
MgetID: func() string {
|
||||||
|
return "provID"
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
db := &admin.MockDB{
|
||||||
|
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
|
||||||
|
assert.Equals(t, "provID", id)
|
||||||
|
return &linkedca.Provisioner{
|
||||||
|
Id: "provID",
|
||||||
|
Name: "provName",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return test{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: auth,
|
||||||
|
adminDB: db,
|
||||||
|
statusCode: 200,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
prov := linkedca.ProvisionerFromContext(r.Context())
|
||||||
|
assert.NotNil(t, prov)
|
||||||
|
assert.Equals(t, "provID", prov.GetId())
|
||||||
|
assert.Equals(t, "provName", prov.GetName())
|
||||||
|
w.Write(nil) // mock response with status 200
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, prep := range tests {
|
||||||
|
tc := prep(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{
|
||||||
|
auth: tc.auth,
|
||||||
|
adminDB: tc.adminDB,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.loadProvisionerByName(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
err := admin.Error{}
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
|
||||||
|
|
||||||
|
assert.Equals(t, tc.err.Type, err.Type)
|
||||||
|
assert.Equals(t, tc.err.Message, err.Message)
|
||||||
|
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
|
||||||
|
assert.Equals(t, tc.err.Detail, err.Detail)
|
||||||
|
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue