Merge branch 'master' into ssh-renew-provisioner

This commit is contained in:
Mariano Cano 2022-05-23 14:31:15 -07:00
commit 1be74eca62
62 changed files with 2601 additions and 1833 deletions

View file

@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error {
} }
// NewAccount is the handler resource for creating new ACME accounts. // NewAccount is the handler resource for creating new ACME accounts.
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { func NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
eak, err := h.validateExternalAccountBinding(ctx, &nar) eak, err := validateExternalAccountBinding(ctx, &nar)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -125,7 +128,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Contact: nar.Contact, Contact: nar.Contact,
Status: acme.StatusValid, Status: acme.StatusValid,
} }
if err := h.db.CreateAccount(ctx, acc); err != nil { if err := db.CreateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return return
} }
@ -135,7 +138,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return return
} }
@ -146,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK httpStatus = http.StatusOK
} }
h.linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
render.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, acc, httpStatus)
} }
// GetOrUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -186,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc.Contact = uar.Contact acc.Contact = uar.Contact
} }
if err := h.db.UpdateAccount(ctx, acc); err != nil { if err := db.UpdateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return return
} }
} }
} }
h.linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
render.JSON(w, acc) render.JSON(w, acc)
} }
@ -209,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
} }
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -221,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return return
} }
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
h.linker.LinkOrdersByAccountID(ctx, orders) linker.LinkOrdersByAccountID(ctx, orders)
render.JSON(w, orders) render.JSON(w, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)

View file

@ -31,6 +31,22 @@ var (
} }
) )
type fakeProvisioner struct{}
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
return nil
}
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
return nil, nil
}
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
func (*fakeProvisioner) GetID() string { return "" }
func (*fakeProvisioner) GetName() string { return "" }
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
func newProv() acme.Provisioner { func newProv() acme.Provisioner {
// Initialize provisioners // Initialize provisioners
p := &provisioner.ACME{ p := &provisioner.ACME{
@ -320,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID} acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
@ -339,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrdersByAccountID(w, req) GetOrdersByAccountID(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -387,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -395,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, nil) ctx := context.WithValue(context.Background(), payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -403,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to "+ err: acme.NewError(acme.ErrorMalformedType, "failed to "+
@ -417,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -429,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) {
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -442,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"), err: acme.NewErrorISE("jwk expected in request context"),
@ -456,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, jwkContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"), err: acme.NewErrorISE("jwk expected in request context"),
@ -478,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
@ -495,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -525,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
@ -575,8 +590,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
eak := &acme.ExternalAccountKey{ eak := &acme.ExternalAccountKey{
ID: "eakID", ID: "eakID",
@ -623,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -659,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
acc: acc, acc: acc,
statusCode: 200, statusCode: 200,
@ -688,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = false prov.RequireEAB = false
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -743,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -783,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewAccount(w, req) NewAccount(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -838,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -846,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, nil) ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -854,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -863,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -872,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
@ -886,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -918,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -938,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
uar := &UpdateAccountRequest{} uar := &UpdateAccountRequest{}
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
} }
@ -953,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -970,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
}, },
"ok/post-as-get": func(t *testing.T) test { "ok/post-as-get": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
} }
@ -983,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrUpdateAccount(w, req) GetOrUpdateAccount(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -17,7 +17,7 @@ type ExternalAccountBinding struct {
} }
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
@ -48,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
return nil, acmeErr return nil, acmeErr
} }
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) db := acme.MustDatabaseFromContext(ctx)
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
if err != nil { if err != nil {
if _, ok := err.(*acme.Error); ok { if _, ok := err.(*acme.Error); ok {
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
@ -111,7 +112,6 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
// o The "nonce" field MUST NOT be present // o The "nonce" field MUST NOT be present
// o The "url" field MUST be set to the same value as the outer JWS // o The "url" field MUST be set to the same value as the outer JWS
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
if jws == nil { if jws == nil {
return "", acme.NewErrorISE("no JWS provided") return "", acme.NewErrorISE("no JWS provided")
} }

View file

@ -14,7 +14,6 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
) )
func Test_keysAreEqual(t *testing.T) { func Test_keysAreEqual(t *testing.T) {
@ -100,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newACMEProv(t) prov := newACMEProv(t)
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
@ -145,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
return test{ return test{
@ -191,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
return test{ return test{
ctx: ctx, ctx: ctx,
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
@ -220,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
@ -266,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
@ -312,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -360,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -410,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -460,8 +445,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -510,8 +494,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
return test{ return test{
@ -568,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -616,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
boundAt := time.Now().Add(1 * time.Second) boundAt := time.Now().Add(1 * time.Second)
@ -676,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -734,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -789,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -845,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, nil) ctx := context.WithValue(context.Background(), jwkContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -873,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
db: tc.db, got, err := validateExternalAccountBinding(ctx, tc.nar)
}
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
wantErr := tc.err != nil wantErr := tc.err != nil
gotErr := err != nil gotErr := err != nil
if wantErr != gotErr { if wantErr != gotErr {

View file

@ -2,12 +2,10 @@ package api
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net"
"net/http" "net/http"
"time" "time"
@ -16,6 +14,7 @@ import (
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
@ -39,111 +38,152 @@ type payloadInfo struct {
isEmptyJSON bool isEmptyJSON bool
} }
// Handler is the ACME API request handler.
type Handler struct {
db acme.DB
backdate provisioner.Duration
ca acme.CertificateAuthority
linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions
prerequisitesChecker func(ctx context.Context) (bool, error)
}
// HandlerOptions required to create a new ACME API request handler. // HandlerOptions required to create a new ACME API request handler.
type HandlerOptions struct { type HandlerOptions struct {
Backdate provisioner.Duration // DB storage backend that implements the acme.DB interface.
// DB storage backend that impements the acme.DB interface. //
// Deprecated: use acme.NewContex(context.Context, acme.DB)
DB acme.DB DB acme.DB
// CA is the certificate authority interface.
//
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
CA acme.CertificateAuthority
// Backdate is the duration that the CA will subtract from the current time
// to set the NotBefore in the certificate.
Backdate provisioner.Duration
// DNS the host used to generate accurate ACME links. By default the authority // DNS the host used to generate accurate ACME links. By default the authority
// will use the Host from the request, so this value will only be used if // will use the Host from the request, so this value will only be used if
// request.Host is empty. // request.Host is empty.
DNS string DNS string
// Prefix is a URL path prefix under which the ACME api is served. This // Prefix is a URL path prefix under which the ACME api is served. This
// prefix is required to generate accurate ACME links. // prefix is required to generate accurate ACME links.
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
// "acme" is the prefix from which the ACME api is accessed. // "acme" is the prefix from which the ACME api is accessed.
Prefix string Prefix string
CA acme.CertificateAuthority
// PrerequisitesChecker checks if all prerequisites for serving ACME are // PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration. // met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error) PrerequisitesChecker func(ctx context.Context) (bool, error)
} }
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return authority.MustFromContext(ctx)
}
// handler is the ACME API request handler.
type handler struct {
opts *HandlerOptions
}
// Route traffic and implement the Router interface. For backward compatibility
// this route adds will add a new middleware that will set the ACME components
// on the context.
//
// Note: this method is deprecated in step-ca, other applications can still use
// this to support ACME, but the recommendation is to use use
// api.Route(api.Router) and acme.NewContext() instead.
func (h *handler) Route(r api.Router) {
client := acme.NewClient()
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
route(r, func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
}
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
next(w, r.WithContext(ctx))
}
})
}
// NewHandler returns a new ACME API handler. // NewHandler returns a new ACME API handler.
func NewHandler(ops HandlerOptions) api.RouterHandler { //
transport := &http.Transport{ // Note: this method is deprecated in step-ca, other applications can still use
TLSClientConfig: &tls.Config{ // this to support ACME, but the recommendation is to use use
InsecureSkipVerify: true, // api.Route(api.Router) and acme.NewContext() instead.
}, func NewHandler(opts HandlerOptions) api.RouterHandler {
} return &handler{
client := http.Client{ opts: &opts,
Timeout: 30 * time.Second,
Transport: transport,
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
prerequisitesChecker := func(ctx context.Context) (bool, error) {
// by default all prerequisites are met
return true, nil
}
if ops.PrerequisitesChecker != nil {
prerequisitesChecker = ops.PrerequisitesChecker
}
return &Handler{
ca: ops.CA,
db: ops.DB,
backdate: ops.Backdate,
linker: NewLinker(ops.DNS, ops.Prefix),
validateChallengeOptions: &acme.ValidateChallengeOptions{
HTTPGet: client.Get,
LookupTxt: net.LookupTXT,
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
},
prerequisitesChecker: prerequisitesChecker,
} }
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface. This method requires that
func (h *Handler) Route(r api.Router) { // all the acme components, authority, db, client, linker, and prerequisite
getPath := h.linker.GetUnescapedPathSuffix // checker to be present in the context.
// Standard ACME API func Route(r api.Router) {
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) route(r, nil)
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) }
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
commonMiddleware := func(next nextHTTP) nextHTTP {
handler := func(w http.ResponseWriter, r *http.Request) {
// Linker middleware gets the provisioner and current url from the
// request and sets them in the context.
linker := acme.MustLinkerFromContext(r.Context())
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
}
if middleware != nil {
handler = middleware(handler)
}
return handler
}
validatingMiddleware := func(next nextHTTP) nextHTTP { validatingMiddleware := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
} }
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKid := func(next nextHTTP) nextHTTP { extractPayloadByKid := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
} }
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) getPath := acme.GetUnescapedPathSuffix
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) // Standard ACME API
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
extractPayloadByJWK(NewAccount))
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(GetOrUpdateAccount))
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(NotImplemented))
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
extractPayloadByKid(NewOrder))
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(isPostAsGet(GetOrder)))
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(FinalizeOrder))
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
extractPayloadByKid(isPostAsGet(GetAuthorization)))
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
extractPayloadByKid(GetChallenge))
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
extractPayloadByKid(isPostAsGet(GetCertificate)))
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
extractPayloadByKidOrJWK(RevokeCert))
} }
// GetNonce just sets the right header since a Nonce is added to each response // GetNonce just sets the right header since a Nonce is added to each response
// by middleware by default. // by middleware by default.
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { func GetNonce(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" { if r.Method == "HEAD" {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) {
// GetDirectory is the ACME resource for returning a directory configuration // GetDirectory is the ACME resource for returning a directory configuration
// for client configuration. // for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
return return
} }
linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{ render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
Meta: Meta{ Meta: Meta{
ExternalAccountRequired: acmeProv.RequireEAB, ExternalAccountRequired: acmeProv.RequireEAB,
}, },
@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { func NotImplemented(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { func GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
"account '%s' does not own authorization '%s'", acc.ID, az.ID)) "account '%s' does not own authorization '%s'", acc.ID, az.ID))
return return
} }
if err = az.UpdateStatus(ctx, h.db); err != nil { if err = az.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return return
} }
h.linker.LinkAuthorization(ctx, az) linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
render.JSON(w, az) render.JSON(w, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { func GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
// we'll just ignore the body. // we'll just ignore the body.
azID := chi.URLParam(r, "authzID") azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return return
@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { if err = ch.Validate(ctx, db, jwk); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
h.linker.LinkChallenge(ctx, ch, azID) linker.LinkChallenge(ctx, ch, azID)
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
render.JSON(w, ch) render.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { func GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID) certID := chi.URLParam(r, "certID")
cert, err := db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return return

View file

@ -3,6 +3,7 @@ package api
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
@ -19,11 +20,33 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return a
}
}
func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetNonce(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
w := httptest.NewRecorder() w := httptest.NewRecorder()
req.Method = tt.name req.Method = tt.name
h.GetNonce(w, req) GetNonce(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) {
} }
func TestHandler_GetDirectory(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) {
linker := NewLinker("ca.smallstep.com", "acme") linker := acme.NewLinker("ca.smallstep.com", "acme")
_ = linker
type test struct { type test struct {
ctx context.Context ctx context.Context
statusCode int statusCode int
@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) {
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test { "fail/no-provisioner": func(t *testing.T) test {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
ctx: ctx, ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), err: acme.NewErrorISE("provisioner is not in context"),
} }
}, },
"fail/different-provisioner": func(t *testing.T) test { "fail/different-provisioner": func(t *testing.T) test {
prov := &provisioner.SCEP{ ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
expDir := Directory{ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
expDir := Directory{ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: linker} ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetDirectory(w, req) GetDirectory(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAuthorization(w, req) GetAuthorization(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetCertificate(w, req) GetCertificate(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) {
type test struct { type test struct {
db acme.DB db acme.DB
vco *acme.ValidateChallengeOptions vc acme.Client
ctx context.Context ctx context.Context
statusCode int statusCode int
ch *acme.Challenge ch *acme.Challenge
@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), accContextKey, nil), ctx: context.WithValue(context.Background(), accContextKey, nil),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/db.GetChallenge-error": func(t *testing.T) test { "fail/db.GetChallenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/no-jwk": func(t *testing.T) test { "fail/no-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-jwk": func(t *testing.T) test { "fail/nil-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, jwkContextKey, nil)
@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/validate-challenge-error": func(t *testing.T) test { "fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) {
return acme.NewErrorISE("force") return acme.NewErrorISE("force")
}, },
}, },
vco: &acme.ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(string) (*http.Response, error) { get: func(string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
_pub := _jwk.Public() _pub := _jwk.Public()
ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, jwkContextKey, &_pub)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) {
URL: u, URL: u,
Error: acme.NewError(acme.ErrorConnectionType, "force"), Error: acme.NewError(acme.ErrorConnectionType, "force"),
}, },
vco: &acme.ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(string) (*http.Response, error) { get: func(string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetChallenge(w, req) GetChallenge(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -9,7 +9,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/go-chi/chi"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) {
} }
} }
// baseURLFromRequest determines the base URL which should be used for
// constructing link URLs in e.g. the ACME directory result by taking the
// request Host into consideration.
//
// If the Request.Host is an empty string, we return an empty string, to
// indicate that the configured URL values should be used instead. If this
// function returns a non-empty result, then this should be used in
// constructing ACME link URLs.
func baseURLFromRequest(r *http.Request) *url.URL {
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
// for an implementation that allows HTTP requests using the x-forwarded-proto
// header.
if r.Host == "" {
return nil
}
return &url.URL{Scheme: "https", Host: r.Host}
}
// baseURLFromRequest is a middleware that extracts and caches the baseURL
// from the request.
// E.g. https://ca.smallstep.com/
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
// addNonce is a middleware that adds a nonce to the response header. // addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP { func addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context()) db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// addDirLink is a middleware that adds a 'Link' response reader with the // addDirLink is a middleware that adds a 'Link' response reader with the
// directory index url. // directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP { func addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) ctx := r.Context()
linker := acme.MustLinkerFromContext(ctx)
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
next(w, r) next(w, r)
} }
} }
// verifyContentType is a middleware that verifies that content type is // verifyContentType is a middleware that verifies that content type is
// application/jose+json. // application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var expected []string
p, err := provisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} u := &url.URL{
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
}
var expected []string
if strings.Contains(r.URL.String(), u.EscapedPath()) { if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
} }
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct. // parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP { func parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@ -149,10 +126,12 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
// * “nonce” (defined in Section 6.5) // * “nonce” (defined in Section 6.5)
// * “url” (defined in Section 6.4) // * “url” (defined in Section 6.4)
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste> // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
// extractJWK is a middleware that extracts the JWK from the JWS and saves it // extractJWK is a middleware that extracts the JWK from the JWS and saves it
// in the context. Make sure to parse and validate the JWS before running this // in the context. Make sure to parse and validate the JWS before running this
// middleware. // middleware.
func (h *Handler) extractJWK(next nextHTTP) nextHTTP { func extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key // Get Account OR continue to generate a new one OR continue Revoke with certificate private key
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
switch { switch {
case errors.Is(err, acme.ErrNotFound): case errors.Is(err, acme.ErrNotFound):
// For NewAccount and Revoke requests ... // For NewAccount and Revoke requests ...
@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
} }
} }
// lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := h.ca.LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
next(w, r.WithContext(ctx))
}
}
// checkPrerequisites checks if all prerequisites for serving ACME // checkPrerequisites checks if all prerequisites for serving ACME
// are met by the CA configuration. // are met by the CA configuration.
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { func checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx) // If the function is not set assume that all prerequisites are met.
if err != nil { checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) if ok {
return ok, err := checkFunc(ctx)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
return
}
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
} }
if !ok { next(w, r)
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
next(w, r.WithContext(ctx))
} }
} }
// lookupJWK loads the JWK associated with the acme account referenced by the // lookupJWK loads the JWK associated with the acme account referenced by the
// kid parameter of the signed payload. // kid parameter of the signed payload.
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { func lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
} }
accID := strings.TrimPrefix(kid, kidPrefix) accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := h.db.GetAccount(ctx, accID) acc, err := db.GetAccount(ctx, accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
// extractOrLookupJWK forwards handling to either extractJWK or // extractOrLookupJWK forwards handling to either extractJWK or
// lookupJWK based on the presence of a JWK or a KID, respectively. // lookupJWK based on the presence of a JWK or a KID, respectively.
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { func extractOrLookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
// and it can be used to check if a JWK exists. This flow is used when the ACME client // and it can be used to check if a JWK exists. This flow is used when the ACME client
// signed the payload with a certificate private key. // signed the payload with a certificate private key.
if canExtractJWKFrom(jws) { if canExtractJWKFrom(jws) {
h.extractJWK(next)(w, r) extractJWK(next)(w, r)
return return
} }
// default to looking up the JWK based on KeyID. This flow is used when the ACME client // default to looking up the JWK based on KeyID. This flow is used when the ACME client
// signed the payload with an account private key. // signed the payload with an account private key.
h.lookupJWK(next)(w, r) lookupJWK(next)(w, r)
} }
} }
@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
} }
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { func isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context()) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
@ -462,16 +424,12 @@ type ContextKey string
const ( const (
// accContextKey account key // accContextKey account key
accContextKey = ContextKey("acc") accContextKey = ContextKey("acc")
// baseURLContextKey baseURL key
baseURLContextKey = ContextKey("baseURL")
// jwsContextKey jws key // jwsContextKey jws key
jwsContextKey = ContextKey("jws") jwsContextKey = ContextKey("jws")
// jwkContextKey jwk key // jwkContextKey jwk key
jwkContextKey = ContextKey("jwk") jwkContextKey = ContextKey("jwk")
// payloadContextKey payload key // payloadContextKey payload key
payloadContextKey = ContextKey("payload") payloadContextKey = ContextKey("payload")
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
) )
// accountFromContext searches the context for an ACME account. Returns the // accountFromContext searches the context for an ACME account. Returns the
@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
return val, nil return val, nil
} }
// baseURLFromContext returns the baseURL if one is stored in the context.
func baseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// jwkFromContext searches the context for a JWK. Returns the JWK or an error. // jwkFromContext searches the context for a JWK. Returns the JWK or an error.
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
// provisionerFromContext searches the context for a provisioner. Returns the // provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error. // provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
val := ctx.Value(provisionerContextKey) p, ok := acme.ProvisionerFromContext(ctx)
if val == nil { if !ok || p == nil {
return nil, acme.NewErrorISE("provisioner expected in request context") return nil, acme.NewErrorISE("provisioner expected in request context")
} }
pval, ok := val.(acme.Provisioner) return p, nil
if !ok || pval == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return pval, nil
} }
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
// pointer to an ACME provisioner or an error. // pointer to an ACME provisioner or an error.
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
prov, err := provisionerFromContext(ctx) p, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
acmeProv, ok := prov.(*provisioner.ACME) ap, ok := p.(*provisioner.ACME)
if !ok || acmeProv == nil { if !ok {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
} }
return acmeProv, nil
return ap, nil
} }
// payloadFromContext searches the context for a payload. Returns the payload // payloadFromContext searches the context for a payload. Returns the payload

View file

@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
} }
func Test_baseURLFromRequest(t *testing.T) { func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
tests := []struct { for _, a := range args {
name string switch v := a.(type) {
targetURL string case acme.DB:
expectedResult *url.URL ctx = acme.NewDatabaseContext(ctx, v)
requestPreparer func(*http.Request) case acme.Linker:
}{ ctx = acme.NewLinkerContext(ctx, v)
{ case acme.PrerequisitesChecker:
"HTTPS host pass-through failed.", ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
"https://my.dummy.host",
&url.URL{Scheme: "https", Host: "my.dummy.host"},
nil,
},
{
"Port pass-through failed",
"https://host.with.port:8080",
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
nil,
},
{
"Explicit host from Request.Host was not used.",
"https://some.target.host:8080",
&url.URL{Scheme: "https", Host: "proxied.host"},
func(r *http.Request) {
r.Host = "proxied.host"
},
},
{
"Missing Request.Host value did not result in empty string result.",
"https://some.host",
nil,
func(r *http.Request) {
r.Host = ""
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest("GET", tc.targetURL, nil)
if tc.requestPreparer != nil {
tc.requestPreparer(request)
}
result := baseURLFromRequest(request)
if result == nil || tc.expectedResult == nil {
assert.Equals(t, result, tc.expectedResult)
} else if result.String() != tc.expectedResult.String() {
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
}
})
}
}
func TestHandler_baseURLFromRequest(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
next := func(w http.ResponseWriter, r *http.Request) {
bu := baseURLFromContext(r.Context())
if assert.NotNil(t, bu) {
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
assert.Equals(t, bu.Scheme, "https")
} }
} }
return ctx
h.baseURLFromRequest(next)(w, req)
req = httptest.NewRequest("GET", "/foo", nil)
req.Host = ""
next = func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, baseURLFromContext(r.Context()), nil)
}
h.baseURLFromRequest(next)(w, req)
} }
func TestHandler_addNonce(t *testing.T) { func TestHandler_addNonce(t *testing.T) {
@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := newBaseContext(context.Background(), tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addNonce(testNext)(w, req) addNonce(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
link string link string
linker Linker
statusCode int statusCode int
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
return test{ return test{
linker: NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
statusCode: 200, statusCode: 200,
@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: tc.linker}
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addDirLink(testNext)(w, req) addDirLink(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct { type test struct {
h Handler
ctx context.Context ctx context.Context
contentType string contentType string
err *acme.Error err *acme.Error
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/provisioner-not-set": func(t *testing.T) test { "fail/provisioner-not-set": func(t *testing.T) test {
return test{ return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: u, url: u,
ctx: context.Background(), ctx: context.Background(),
contentType: "foo", contentType: "foo",
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"fail/general-bad-content-type": func(t *testing.T) test { "fail/general-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: u, url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"fail/certificate-bad-content-type": func(t *testing.T) test { "fail/certificate-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/pkix-cert": func(t *testing.T) test { "ok/certificate/pkix-cert": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkix-cert", contentType: "application/pkix-cert",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/jose+json": func(t *testing.T) test { "ok/certificate/jose+json": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/pkcs7-mime": func(t *testing.T) test { "ok/certificate/pkcs7-mime": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkcs7-mime", contentType: "application/pkcs7-mime",
statusCode: 200, statusCode: 200,
} }
@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) {
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
req.Header.Add("Content-Type", tc.contentType) req.Header.Add("Content-Type", tc.contentType)
w := httptest.NewRecorder() w := httptest.NewRecorder()
tc.h.verifyContentType(testNext)(w, req) verifyContentType(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.isPostAsGet(testNext)(w, req) isPostAsGet(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
req := httptest.NewRequest("GET", u, tc.body) req := httptest.NewRequest("GET", u, tc.body)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req) parseJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.verifyAndExtractJWSPayload(tc.next)(w, req) verifyAndExtractJWSPayload(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
type test struct { type test struct {
linker Linker linker acme.Linker
db acme.DB db acme.DB
ctx context.Context ctx context.Context
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz")) _jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _parsed) ctx = context.WithValue(ctx, jwsContextKey, _parsed)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
} }
}, },
"fail/account-not-found": func(t *testing.T) test { "fail/account-not-found": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
} }
}, },
"fail/GetAccount-error": func(t *testing.T) test { "fail/GetAccount-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk} acc := &acme.Account{Status: "valid", Key: jwk}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker} ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.lookupJWK(tc.next)(w, req) lookupJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
} }
}, },
"fail/GetAccountByKey-error": func(t *testing.T) test { "fail/GetAccountByKey-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid"} acc := &acme.Account{Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
} }
}, },
"ok/no-account": func(t *testing.T) test { "ok/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractJWK(tc.next)(w, req) extractJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, nil), ctx: context.WithValue(context.Background(), jwsContextKey, nil),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
"fail/no-signature": func(t *testing.T) test { "fail/no-signature": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.validateJWS(tc.next)(w, req) validateJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
u := "https://ca.smallstep.com/acme/account" u := "https://ca.smallstep.com/acme/account"
type test struct { type test struct {
db acme.DB db acme.DB
linker Linker linker acme.Linker
statusCode int statusCode int
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
assert.Equals(t, kid, pub.KeyID) assert.Equals(t, kid, pub.KeyID)
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
linker: NewLinker("test.ca.smallstep.com", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker} ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractOrLookupJWK(tc.next)(w, req) extractOrLookupJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
u := fmt.Sprintf("%s/acme/%s/account/1234", u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName) baseURL, provName)
type test struct { type test struct {
linker Linker linker acme.Linker
ctx context.Context ctx context.Context
prerequisitesChecker func(context.Context) (bool, error) prerequisitesChecker func(context.Context) (bool, error)
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/error": func(t *testing.T) test { "fail/error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
}, },
"fail/prerequisites-nok": func(t *testing.T) test { "fail/prerequisites-nok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.checkPrerequisites(tc.next)(w, req) checkPrerequisites(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -72,8 +72,12 @@ var defaultOrderExpiry = time.Hour * 24
var defaultOrderBackdate = time.Minute var defaultOrderBackdate = time.Minute
// NewOrder ACME api for creating a new order. // NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
ca := mustAuthority(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -113,7 +117,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
var eak *acme.ExternalAccountKey var eak *acme.ExternalAccountKey
if acmeProv.RequireEAB { if acmeProv.RequireEAB {
if eak, err = h.db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil { if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key")) render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
return return
} }
@ -138,7 +142,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
// evaluate the authority level policy // evaluate the authority level policy
if err = h.ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil { if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return return
} }
@ -164,7 +168,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ExpiresAt: o.ExpiresAt, ExpiresAt: o.ExpiresAt,
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.newAuthorization(ctx, az); err != nil { if err := newAuthorization(ctx, az); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
@ -183,14 +187,14 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
} }
if err := h.db.CreateOrder(ctx, o); err != nil { if err := db.CreateOrder(ctx, o); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return return
} }
h.linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSONStatus(w, o, http.StatusCreated) render.JSONStatus(w, o, http.StatusCreated)
} }
@ -208,7 +212,7 @@ func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error
return policy.NewX509PolicyEngine(eak.Policy) return policy.NewX509PolicyEngine(eak.Policy)
} }
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { func newAuthorization(ctx context.Context, az *acme.Authorization) error {
if strings.HasPrefix(az.Identifier.Value, "*.") { if strings.HasPrefix(az.Identifier.Value, "*.") {
az.Wildcard = true az.Wildcard = true
az.Identifier = acme.Identifier{ az.Identifier = acme.Identifier{
@ -224,6 +228,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
if err != nil { if err != nil {
return acme.WrapErrorISE(err, "error generating random alphanumeric ID") return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
} }
db := acme.MustDatabaseFromContext(ctx)
az.Challenges = make([]*acme.Challenge, len(chTypes)) az.Challenges = make([]*acme.Challenge, len(chTypes))
for i, typ := range chTypes { for i, typ := range chTypes {
ch := &acme.Challenge{ ch := &acme.Challenge{
@ -233,20 +239,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
Token: az.Token, Token: az.Token,
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.db.CreateChallenge(ctx, ch); err != nil { if err := db.CreateChallenge(ctx, ch); err != nil {
return acme.WrapErrorISE(err, "error creating challenge") return acme.WrapErrorISE(err, "error creating challenge")
} }
az.Challenges[i] = ch az.Challenges[i] = ch
} }
if err = h.db.CreateAuthorization(ctx, az); err != nil { if err = db.CreateAuthorization(ctx, az); err != nil {
return acme.WrapErrorISE(err, "error creating authorization") return acme.WrapErrorISE(err, "error creating authorization")
} }
return nil return nil
} }
// GetOrder ACME api for retrieving an order. // GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { func GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -257,7 +266,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
@ -272,20 +282,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.UpdateStatus(ctx, h.db); err != nil { if err = o.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating order status")) render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return return
} }
h.linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attempts to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -312,7 +325,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
@ -327,14 +340,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
ca := mustAuthority(ctx)
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }
h.linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, o)
} }

View file

@ -280,15 +280,17 @@ func TestHandler_GetOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -298,6 +300,7 @@ func TestHandler_GetOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -305,9 +308,10 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), nil)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -315,7 +319,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/db.GetOrder-error": func(t *testing.T) test { "fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -329,7 +333,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -345,7 +349,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/provisioner-id-mismatch": func(t *testing.T) test { "fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -361,7 +365,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/order-update-error": func(t *testing.T) test { "fail/order-update-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -385,10 +389,9 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
@ -425,11 +428,11 @@ func TestHandler_GetOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrder(w, req) GetOrder(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -640,8 +643,8 @@ func TestHandler_newAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
h := &Handler{db: tc.db} ctx := newBaseContext(context.Background(), tc.db)
if err := h.newAuthorization(context.Background(), tc.az); err != nil { if err := newAuthorization(ctx, tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
@ -682,15 +685,17 @@ func TestHandler_NewOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -700,6 +705,7 @@ func TestHandler_NewOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -707,9 +713,10 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -718,8 +725,9 @@ func TestHandler_NewOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -727,21 +735,23 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"), err: acme.NewErrorISE("payload does not exist"),
} }
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
@ -752,10 +762,11 @@ func TestHandler_NewOrder(t *testing.T) {
fr := &NewOrderRequest{} fr := &NewOrderRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
@ -770,7 +781,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, &acme.MockProvisioner{}) ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -798,7 +809,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -826,7 +837,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -862,7 +873,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, acmeProv) ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -905,7 +916,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -948,7 +959,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -986,7 +997,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -1020,7 +1031,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var ( var (
@ -1096,10 +1107,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3, ch4 **acme.Challenge ch1, ch2, ch3, ch4 **acme.Challenge
az1ID, az2ID *string az1ID, az2ID *string
@ -1217,10 +1227,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1315,10 +1324,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1412,10 +1420,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1510,10 +1517,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1611,10 +1617,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, provWithPolicy) ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1701,11 +1706,12 @@ func TestHandler_NewOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} mockMustAuthority(t, tc.ca)
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewOrder(w, req) NewOrder(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1738,6 +1744,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
func TestHandler_FinalizeOrder(t *testing.T) { func TestHandler_FinalizeOrder(t *testing.T) {
mockMustAuthority(t, &mockCA{})
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@ -1796,15 +1803,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -1814,6 +1823,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -1821,9 +1831,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -1832,8 +1843,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -1841,21 +1853,23 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"), err: acme.NewErrorISE("payload does not exist"),
} }
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
@ -1866,10 +1880,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
fr := &FinalizeRequest{} fr := &FinalizeRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
@ -1878,7 +1893,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/db.GetOrder-error": func(t *testing.T) test { "fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1893,7 +1908,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1910,7 +1925,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/provisioner-id-mismatch": func(t *testing.T) test { "fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1927,7 +1942,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/order-finalize-error": func(t *testing.T) test { "fail/order-finalize-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1952,10 +1967,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -1991,11 +2005,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.FinalizeOrder(w, req) FinalizeOrder(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -26,9 +26,11 @@ type revokePayload struct {
} }
// RevokeCert attempts to revoke a certificate. // RevokeCert attempts to revoke a certificate.
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial) dbCert, err := db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil { if acmeErr != nil {
render.Error(w, acmeErr) render.Error(w, acmeErr)
return return
@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
} }
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) ca := mustAuthority(ctx)
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return return
@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
options := revokeOptions(serial, certToBeRevoked, reasonCode) options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options) err = ca.Revoke(ctx, options)
if err != nil { if err != nil {
render.Error(w, wrapRevokeErr(err)) render.Error(w, wrapRevokeErr(err))
return return
} }
logRevoke(w, options) logRevoke(w, options)
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
w.Write(nil) w.Write(nil)
} }
@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations // the identifiers in the certificate are extracted and compared against the (valid) Authorizations
// that are stored for the ACME Account. If these sets match, the Account is considered authorized // that are stored for the ACME Account. If these sets match, the Account is considered authorized
// to revoke the certificate. If this check fails, the client will receive an unauthorized error. // to revoke the certificate. If this check fails, the client will receive an unauthorized error.
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
if !account.IsValid() { if !account.IsValid() {
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
} }

View file

@ -521,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
ctx := context.Background() ctx := context.Background()
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -529,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, nil) ctx := context.WithValue(context.Background(), jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -537,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-provisioner": func(t *testing.T) test { "fail/no-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -544,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, nil) ctx = acme.NewProvisionerContext(ctx, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -553,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -562,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -573,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/unmarshal-payload": func(t *testing.T) test { "fail/unmarshal-payload": func(t *testing.T) test {
malformedPayload := []byte(`{"payload":malformed?}`) malformedPayload := []byte(`{"payload":malformed?}`)
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("error unmarshaling payload"), err: acme.NewErrorISE("error unmarshaling payload"),
@ -587,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: &acme.Error{ err: &acme.Error{
@ -606,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
emptyPayloadBytes, err := json.Marshal(emptyPayload) emptyPayloadBytes, err := json.Marshal(emptyPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: &acme.Error{ err: &acme.Error{
@ -620,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/db.GetCertificateBySerial": func(t *testing.T) test { "fail/db.GetCertificateBySerial": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -638,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/different-certificate-contents": func(t *testing.T) test { "fail/different-certificate-contents": func(t *testing.T) test {
aDifferentCert, _, err := generateCertKeyPair() aDifferentCert, _, err := generateCertKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -657,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -676,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
@ -697,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -727,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/account-not-authorized": func(t *testing.T) test { "fail/account-not-authorized": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -781,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -808,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/certificate-revoked-check-fails": func(t *testing.T) test { "fail/certificate-revoked-check-fails": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -842,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/certificate-already-revoked": func(t *testing.T) test { "fail/certificate-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -880,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) {
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -918,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
} }
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -950,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/ca.Revoke": func(t *testing.T) test { "fail/ca.Revoke": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -982,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/ca.Revoke-already-revoked": func(t *testing.T) test { "fail/ca.Revoke-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -1013,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"ok/using-account-key": func(t *testing.T) test { "ok/using-account-key": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1041,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(jwsBytes)) jws, err := jose.ParseJWS(string(jwsBytes))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1067,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) {
for name, setup := range tests { for name, setup := range tests {
tc := setup(t) tc := setup(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
mockMustAuthority(t, tc.ca)
req := httptest.NewRequest("POST", revokeURL, nil) req := httptest.NewRequest("POST", revokeURL, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.RevokeCert(w, req) RevokeCert(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1208,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
for name, setup := range tests { for name, setup := range tests {
tc := setup(t) tc := setup(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} // h := &Handler{db: tc.db}
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
expectError := tc.err != nil expectError := tc.err != nil
gotError := acmeErr != nil gotError := acmeErr != nil

View file

@ -14,7 +14,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
// type using the DB interface. // type using the DB interface.
// satisfactorily validated, the 'status' and 'validated' attributes are // satisfactorily validated, the 'status' and 'validated' attributes are
// updated. // updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
// If already valid or invalid then return without performing validation. // If already valid or invalid then return without performing validation.
if ch.Status != StatusPending { if ch.Status != StatusPending {
return nil return nil
} }
switch ch.Type { switch ch.Type {
case HTTP01: case HTTP01:
return http01Validate(ctx, ch, db, jwk, vo) return http01Validate(ctx, ch, db, jwk)
case DNS01: case DNS01:
return dns01Validate(ctx, ch, db, jwk, vo) return dns01Validate(ctx, ch, db, jwk)
case TLSALPN01: case TLSALPN01:
return tlsalpn01Validate(ctx, ch, db, jwk, vo) return tlsalpn01Validate(ctx, ch, db, jwk)
default: default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type) return NewErrorISE("unexpected challenge type '%s'", ch.Type)
} }
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String()) vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String())
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", u)) "error doing http GET for url %s", u))
@ -141,7 +141,7 @@ func tlsAlert(err error) uint8 {
return 0 return 0
} }
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
config := &tls.Config{ config := &tls.Config{
NextProtos: []string{"acme-tls/1"}, NextProtos: []string{"acme-tls/1"},
// https://tools.ietf.org/html/rfc8737#section-4 // https://tools.ietf.org/html/rfc8737#section-4
@ -154,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
hostPort := net.JoinHostPort(ch.Value, "443") hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.TLSDial("tcp", hostPort, config) vc := MustClientFromContext(ctx)
conn, err := vc.TLSDial("tcp", hostPort, config)
if err != nil { if err != nil {
// With Go 1.17+ tls.Dial fails if there's no overlap between configured // With Go 1.17+ tls.Dial fails if there's no overlap between configured
// client and server protocols. When this happens the connection is // client and server protocols. When this happens the connection is
@ -253,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
} }
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
// Normalize domain for wildcard DNS names // Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like // This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com // _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com // Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(ch.Value, "*.") domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) vc := MustClientFromContext(ctx)
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
"error looking up TXT records for domain %s", domain)) "error looking up TXT records for domain %s", domain))
@ -376,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
} }
return nil return nil
} }
type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error)
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
// ValidateChallengeOptions are ACME challenge validator functions.
type ValidateChallengeOptions struct {
HTTPGet httpGetter
LookupTxt lookupTxt
TLSDial tlsDialer
}

View file

@ -29,6 +29,18 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
) )
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func Test_storeError(t *testing.T) { func Test_storeError(t *testing.T) {
type test struct { type test struct {
ch *Challenge ch *Challenge
@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
func TestChallenge_Validate(t *testing.T) { func TestChallenge_Validate(t *testing.T) {
type test struct { type test struct {
ch *Challenge ch *Challenge
vo *ValidateChallengeOptions vc Client
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
srv *httptest.Server srv *httptest.Server
@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
} }
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
defer tc.srv.Close() defer tc.srv.Close()
} }
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -524,7 +537,7 @@ func (errReader) Close() error {
func TestHTTP01Validate(t *testing.T) { func TestHTTP01Validate(t *testing.T) {
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
Body: errReader(0), Body: errReader(0),
@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
Body: errReader(0), Body: errReader(0),
@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: errReader(0), Body: errReader(0),
}, nil }, nil
@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
jwk.Key = "foo" jwk.Key = "foo"
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
fulldomain := "*.zap.internal" fulldomain := "*.zap.internal"
domain := strings.TrimPrefix(fulldomain, "*.") domain := strings.TrimPrefix(fulldomain, "*.")
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo"}, nil return []string{"foo"}, nil
}, },
}, },
@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil return []string{"foo", "bar"}, nil
}, },
}, },
@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil return []string{"foo", "bar"}, nil
}, },
}, },
@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil return []string{"foo", expected}, nil
}, },
}, },
@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil return []string{"foo", expected}, nil
}, },
}, },
@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
} }
} }
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
srv := httptest.NewUnstartedServer(http.NewServeMux()) srv := httptest.NewUnstartedServer(http.NewServeMux())
@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
} }
} }
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh() ch := makeTLSCh()
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh() ch := makeTLSCh()
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil return tls.Client(&noopConn{}, config), nil
}, },
}, },
@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil return tls.Client(&noopConn{}, config), nil
}, },
}, },
@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
}, },
}, },
@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
}, },
}, },
@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
srv: srv, srv: srv,
jwk: jwk, jwk: jwk,
@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
defer tc.srv.Close() defer tc.srv.Close()
} }
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:

79
acme/client.go Normal file
View file

@ -0,0 +1,79 @@
package acme
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// Client is the interface used to verify ACME challenges.
type Client interface {
// Get issues an HTTP GET to the specified URL.
Get(url string) (*http.Response, error)
// LookupTXT returns the DNS TXT records for the given domain name.
LookupTxt(name string) ([]string, error)
// TLSDial connects to the given network address using net.Dialer and then
// initiates a TLS handshake, returning the resulting TLS connection.
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
}
type clientKey struct{}
// NewClientContext adds the given client to the context.
func NewClientContext(ctx context.Context, c Client) context.Context {
return context.WithValue(ctx, clientKey{}, c)
}
// ClientFromContext returns the current client from the given context.
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
c, ok = ctx.Value(clientKey{}).(Client)
return
}
// MustClientFromContext returns the current client from the given context. It will
// return a new instance of the client if it does not exist.
func MustClientFromContext(ctx context.Context) Client {
c, ok := ClientFromContext(ctx)
if !ok {
return NewClient()
}
return c
}
type client struct {
http *http.Client
dialer *net.Dialer
}
// NewClient returns an implementation of Client for verifying ACME challenges.
func NewClient() Client {
return &client{
http: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
dialer: &net.Dialer{
Timeout: 30 * time.Second,
},
}
}
func (c *client) Get(url string) (*http.Response, error) {
return c.http.Get(url)
}
func (c *client) LookupTxt(name string) ([]string, error) {
return net.LookupTXT(name)
}
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(c.dialer, network, addr, config)
}

View file

@ -9,15 +9,6 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// Clock that returns time in UTC rounded to seconds. // Clock that returns time in UTC rounded to seconds.
type Clock struct{} type Clock struct{}
@ -28,6 +19,52 @@ func (c *Clock) Now() time.Time {
var clock Clock var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error)
}
// NewContext adds the given acme components to the context.
func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker)
// Prerequisite checker is optional.
if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn)
}
return ctx
}
// PrerequisitesChecker is a function that checks if all prerequisites for
// serving ACME are met by the CA configuration.
type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true.
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
return true, nil
}
type prerequisitesKey struct{}
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
// context.
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
return context.WithValue(ctx, prerequisitesKey{}, fn)
}
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
// context.
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
return fn, ok && fn != nil
}
// Provisioner is an interface that implements a subset of the provisioner.Interface -- // Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority. // only those methods required by the ACME api/authority.
type Provisioner interface { type Provisioner interface {
@ -40,6 +77,29 @@ type Provisioner interface {
GetOptions() *provisioner.Options GetOptions() *provisioner.Options
} }
type provisionerKey struct{}
// NewProvisionerContext adds the given provisioner to the context.
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, v)
}
// ProvisionerFromContext returns the current provisioner from the given context.
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
return
}
// MustLinkerFromContext returns the current provisioner from the given context.
// It will panic if it's not in the context.
func MustProvisionerFromContext(ctx context.Context) Provisioner {
if v, ok := ProvisionerFromContext(ctx); !ok {
panic("acme provisioner is not the context")
} else {
return v
}
}
// MockProvisioner for testing // MockProvisioner for testing
type MockProvisioner struct { type MockProvisioner struct {
Mret1 interface{} Mret1 interface{}

View file

@ -49,6 +49,29 @@ type DB interface {
UpdateOrder(ctx context.Context, o *Order) error UpdateOrder(ctx context.Context, o *Order) error
} }
type dbKey struct{}
// NewDatabaseContext adds the given acme database to the context.
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// DatabaseFromContext returns the current acme database from the given context.
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustDatabaseFromContext returns the current database from the given context.
// It will panic if it's not in the context.
func MustDatabaseFromContext(ctx context.Context) DB {
if db, ok := DatabaseFromContext(ctx); !ok {
panic("acme database is not in the context")
} else {
return db
}
}
// MockDB is an implementation of the DB interface that should only be used as // MockDB is an implementation of the DB interface that should only be used as
// a mock in tests. // a mock in tests.
type MockDB struct { type MockDB struct {

View file

@ -1,100 +1,19 @@
package api package acme
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/smallstep/certificates/acme" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
) )
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
LinkOrder(ctx context.Context, o *acme.Order)
LinkAccount(ctx context.Context, o *acme.Account)
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
LinkAuthorization(ctx context.Context, o *acme.Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// GetLink is a helper for GetLinkExplicit
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var (
provName string
baseURL = baseURLFromContext(ctx)
u = url.URL{}
)
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil {
u = *baseURL
}
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + u.Path
return u.String()
}
// LinkType captures the link type. // LinkType captures the link type.
type LinkType int type LinkType int
@ -160,8 +79,155 @@ func (l LinkType) String() string {
} }
} }
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
Middleware(http.Handler) http.Handler
LinkOrder(ctx context.Context, o *Order)
LinkAccount(ctx context.Context, o *Account)
LinkChallenge(ctx context.Context, o *Challenge, azID string)
LinkAuthorization(ctx context.Context, o *Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
type linkerKey struct{}
// NewLinkerContext adds the given linker to the context.
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
return context.WithValue(ctx, linkerKey{}, v)
}
// LinkerFromContext returns the current linker from the given context.
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
v, ok = ctx.Value(linkerKey{}).(Linker)
return
}
// MustLinkerFromContext returns the current linker from the given context. It
// will panic if it's not in the context.
func MustLinkerFromContext(ctx context.Context) Linker {
if v, ok := LinkerFromContext(ctx); !ok {
panic("acme linker is not the context")
} else {
return v
}
}
type baseURLKey struct{}
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
var u *url.URL
if r.Host != "" {
u = &url.URL{Scheme: "https", Host: r.Host}
}
return context.WithValue(ctx, baseURLKey{}, u)
}
func baseURLFromContext(ctx context.Context) *url.URL {
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
return u
}
return nil
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
// Middleware gets the provisioner and current url from the request and sets
// them in the context so we can use the linker to create ACME links.
func (l *linker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add base url to the context.
ctx := newBaseURLContext(r.Context(), r)
// Add provisioner to the context.
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetLink is a helper for GetLinkExplicit.
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var name string
if p, ok := ProvisionerFromContext(ctx); ok {
name = p.GetName()
}
var u url.URL
if baseURL := baseURLFromContext(ctx); baseURL != nil {
u = *baseURL
}
if u.Scheme == "" {
u.Scheme = "https"
}
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
return u.String()
}
// LinkOrder sets the ACME links required by an ACME order. // LinkOrder sets the ACME links required by an ACME order.
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { func (l *linker) LinkOrder(ctx context.Context, o *Order) {
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AuthorizationIDs { for i, azID := range o.AuthorizationIDs {
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
} }
// LinkAccount sets the ACME links required by an ACME account. // LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
} }
// LinkChallenge sets the ACME links required by an ACME challenge. // LinkChallenge sets the ACME links required by an ACME challenge.
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
} }
// LinkAuthorization sets the ACME links required by an ACME authorization. // LinkAuthorization sets the ACME links required by an ACME authorization.
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
for _, ch := range az.Challenges { for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch, az.ID) l.LinkChallenge(ctx, ch, az.ID)
} }

View file

@ -1,21 +1,38 @@
package api package acme
import ( import (
"context" "context"
"fmt" "fmt"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner"
) )
func TestLinker_GetUnescapedPathSuffix(t *testing.T) { func mockProvisioner(t *testing.T) Provisioner {
dns := "ca.smallstep.com" t.Helper()
prefix := "acme" var defaultDisableRenewal = false
linker := NewLinker(dns, prefix)
getPath := linker.GetUnescapedPathSuffix // Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func TestGetUnescapedPathSuffix(t *testing.T) {
getPath := GetUnescapedPathSuffix
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
} }
func TestLinker_DNS(t *testing.T) { func TestLinker_DNS(t *testing.T) {
prov := newProv() prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := NewProvisionerContext(context.Background(), prov)
type test struct { type test struct {
name string name string
dns string dns string
@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
linker := NewLinker(dns, prefix) linker := NewLinker(dns, prefix)
id := "1234" id := "1234"
prov := newProv() prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
// No provisioner and no BaseURL from request // No provisioner and no BaseURL from request
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
// Provisioner: yes, BaseURL: no // Provisioner: yes, BaseURL: no
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
// Provisioner: no, BaseURL: yes // Provisioner: no, BaseURL: yes
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) {
func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
oid := "orderID" oid := "orderID"
certID := "certID" certID := "certID"
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
o *acme.Order o *Order
validate func(o *acme.Order) validate func(o *Order)
} }
var tests = map[string]test{ var tests = map[string]test{
"no-authz-and-no-cert": { "no-authz-and-no-cert": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.AuthorizationURLs, []string{})
assert.Equals(t, o.CertificateURL, "") assert.Equals(t, o.CertificateURL, "")
}, },
}, },
"one-authz-and-cert": { "one-authz-and-cert": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
CertificateID: certID, CertificateID: certID,
AuthorizationIDs: []string{"foo"}, AuthorizationIDs: []string{"foo"},
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) {
}, },
}, },
"many-authz": { "many-authz": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
CertificateID: certID, CertificateID: certID,
AuthorizationIDs: []string{"foo", "bar", "zap"}, AuthorizationIDs: []string{"foo", "bar", "zap"},
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) {
func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
accID := "accountID" accID := "accountID"
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
a *acme.Account a *Account
validate func(o *acme.Account) validate func(o *Account)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
a: &acme.Account{ a: &Account{
ID: accID, ID: accID,
}, },
validate: func(a *acme.Account) { validate: func(a *Account) {
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
}, },
}, },
@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) {
func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID := "chID" chID := "chID"
azID := "azID" azID := "azID"
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
ch *acme.Challenge ch *Challenge
validate func(o *acme.Challenge) validate func(o *Challenge)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
ch: &acme.Challenge{ ch: &Challenge{
ID: chID, ID: chID,
}, },
validate: func(ch *acme.Challenge) { validate: func(ch *Challenge) {
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
}, },
}, },
@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID0 := "chID-0" chID0 := "chID-0"
chID1 := "chID-1" chID1 := "chID-1"
@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
az *acme.Authorization az *Authorization
validate func(o *acme.Authorization) validate func(o *Authorization)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
az: &acme.Authorization{ az: &Authorization{
ID: azID, ID: azID,
Challenges: []*acme.Challenge{ Challenges: []*Challenge{
{ID: chID0}, {ID: chID0},
{ID: chID1}, {ID: chID1},
{ID: chID2}, {ID: chID2},
}, },
}, },
validate: func(az *acme.Authorization) { validate: func(az *Authorization) {
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
func TestLinker_LinkOrdersByAccountID(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)

View file

@ -35,7 +35,6 @@ type Authority interface {
SSHAuthority SSHAuthority
// context specifies the Authorize[Sign|Revoke|etc.] method. // context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
@ -52,6 +51,11 @@ type Authority interface {
Version() authority.Version Version() authority.Version
} }
// mustAuthority will be replaced on unit tests.
var mustAuthority = func(ctx context.Context) Authority {
return authority.MustFromContext(ctx)
}
// TimeDuration is an alias of provisioner.TimeDuration // TimeDuration is an alias of provisioner.TimeDuration
type TimeDuration = provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration
@ -243,48 +247,53 @@ type caHandler struct {
Authority Authority Authority Authority
} }
// New creates a new RouterHandler with the CA endpoints. // Route configures the http request router.
func New(auth Authority) RouterHandler { func (h *caHandler) Route(r Router) {
return &caHandler{ Route(r)
Authority: auth,
}
} }
func (h *caHandler) Route(r Router) { // New creates a new RouterHandler with the CA endpoints.
r.MethodFunc("GET", "/version", h.Version) //
r.MethodFunc("GET", "/health", h.Health) // Deprecated: Use api.Route(r Router)
r.MethodFunc("GET", "/root/{sha}", h.Root) func New(auth Authority) RouterHandler {
r.MethodFunc("POST", "/sign", h.Sign) return &caHandler{}
r.MethodFunc("POST", "/renew", h.Renew) }
r.MethodFunc("POST", "/rekey", h.Rekey)
r.MethodFunc("POST", "/revoke", h.Revoke) func Route(r Router) {
r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/version", Version)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) r.MethodFunc("GET", "/health", Health)
r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/root/{sha}", Root)
r.MethodFunc("GET", "/roots.pem", h.RootsPEM) r.MethodFunc("POST", "/sign", Sign)
r.MethodFunc("GET", "/federation", h.Federation) r.MethodFunc("POST", "/renew", Renew)
r.MethodFunc("POST", "/rekey", Rekey)
r.MethodFunc("POST", "/revoke", Revoke)
r.MethodFunc("GET", "/provisioners", Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
r.MethodFunc("GET", "/roots", Roots)
r.MethodFunc("GET", "/roots.pem", RootsPEM)
r.MethodFunc("GET", "/federation", Federation)
// SSH CA // SSH CA
r.MethodFunc("POST", "/ssh/sign", h.SSHSign) r.MethodFunc("POST", "/ssh/sign", SSHSign)
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew) r.MethodFunc("POST", "/ssh/renew", SSHRenew)
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke) r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey) r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) r.MethodFunc("GET", "/ssh/roots", SSHRoots)
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) r.MethodFunc("GET", "/ssh/federation", SSHFederation)
r.MethodFunc("POST", "/ssh/config", h.SSHConfig) r.MethodFunc("POST", "/ssh/config", SSHConfig)
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts) r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
// For compatibility with old code: // For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew) r.MethodFunc("POST", "/re-sign", Renew)
r.MethodFunc("POST", "/sign-ssh", h.SSHSign) r.MethodFunc("POST", "/sign-ssh", SSHSign)
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
} }
// Version is an HTTP handler that returns the version of the server. // Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { func Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version() v := mustAuthority(r.Context()).Version()
render.JSON(w, VersionResponse{ render.JSON(w, VersionResponse{
Version: v.Version, Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication, RequireClientAuthentication: v.RequireClientAuthentication,
@ -292,17 +301,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
} }
// Health is an HTTP handler that returns the status of the server. // Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
render.JSON(w, HealthResponse{Status: "ok"}) render.JSON(w, HealthResponse{Status: "ok"})
} }
// Root is an HTTP handler that using the SHA256 from the URL, returns the root // Root is an HTTP handler that using the SHA256 from the URL, returns the root
// certificate for the given SHA256. // certificate for the given SHA256.
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { func Root(w http.ResponseWriter, r *http.Request) {
sha := chi.URLParam(r, "sha") sha := chi.URLParam(r, "sha")
sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := mustAuthority(r.Context()).Root(sum)
if err != nil { if err != nil {
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
@ -320,18 +329,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
} }
// Provisioners returns the list of provisioners configured in the authority. // Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
p, next, err := h.Authority.GetProvisioners(cursor, limit) p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &ProvisionersResponse{ render.JSON(w, &ProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
@ -339,19 +349,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
} }
// ProvisionerKey returns the encrypted key of a provisioner by it's key id. // ProvisionerKey returns the encrypted key of a provisioner by it's key id.
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid) key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
if err != nil { if err != nil {
render.Error(w, errs.NotFoundErr(err)) render.Error(w, errs.NotFoundErr(err))
return return
} }
render.JSON(w, &ProvisionerKeyResponse{key}) render.JSON(w, &ProvisionerKeyResponse{key})
} }
// Roots returns all the root certificates for the CA. // Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting roots")) render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return return
@ -368,8 +379,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
} }
// RootsPEM returns all the root certificates for the CA in PEM format. // RootsPEM returns all the root certificates for the CA in PEM format.
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { func RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -391,8 +402,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
} }
// Federation returns all the public certificates in the federation. // Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := mustAuthority(r.Context()).GetFederation()
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return return

View file

@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
return csr return csr
} }
func mockMustAuthority(t *testing.T, a Authority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) Authority {
return a
}
}
type mockAuthority struct { type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
authorizeSign func(ott string) ([]provisioner.SignOption, error) authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
@ -203,12 +214,8 @@ type mockAuthority struct {
// TODO: remove once Authorize is deprecated. // TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return m.AuthorizeSign(ott) if m.authorize != nil {
} return m.authorize(ctx, ott)
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
if m.authorizeSign != nil {
return m.authorizeSign(ott)
} }
return m.ret1.([]provisioner.SignOption), m.err return m.ret1.([]provisioner.SignOption), m.err
} }
@ -789,11 +796,10 @@ func Test_caHandler_Route(t *testing.T) {
} }
} }
func Test_caHandler_Health(t *testing.T) { func Test_Health(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/health", nil) req := httptest.NewRequest("GET", "http://example.com/health", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := New(&mockAuthority{}).(*caHandler) Health(w, req)
h.Health(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != 200 { if res.StatusCode != 200 {
@ -811,7 +817,7 @@ func Test_caHandler_Health(t *testing.T) {
} }
} }
func Test_caHandler_Root(t *testing.T) { func Test_Root(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
root *x509.Certificate root *x509.Certificate
@ -832,9 +838,9 @@ func Test_caHandler_Root(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Root(w, req) Root(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -855,7 +861,7 @@ func Test_caHandler_Root(t *testing.T) {
} }
} }
func Test_caHandler_Sign(t *testing.T) { func Test_Sign(t *testing.T) {
csr := parseCertificateRequest(csrPEM) csr := parseCertificateRequest(csrPEM)
valid, err := json.Marshal(SignRequest{ valid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr}, CsrPEM: CertificateRequest{csr},
@ -896,18 +902,18 @@ func Test_caHandler_Sign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr, ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return tt.certAttrOpts, tt.autherr return tt.certAttrOpts, tt.autherr
}, },
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Sign(logging.NewResponseLogger(w), req) Sign(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -928,7 +934,7 @@ func Test_caHandler_Sign(t *testing.T) {
} }
} }
func Test_caHandler_Renew(t *testing.T) { func Test_Renew(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1018,7 +1024,7 @@ func Test_caHandler_Renew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
@ -1039,12 +1045,12 @@ func Test_caHandler_Renew(t *testing.T) {
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/renew", nil) req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls req.TLS = tt.tls
req.Header = tt.header req.Header = tt.header
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req) Renew(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
@ -1073,7 +1079,7 @@ func Test_caHandler_Renew(t *testing.T) {
} }
} }
func Test_caHandler_Rekey(t *testing.T) { func Test_Rekey(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1104,16 +1110,16 @@ func Test_caHandler_Rekey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Rekey(logging.NewResponseLogger(w), req) Rekey(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1134,7 +1140,7 @@ func Test_caHandler_Rekey(t *testing.T) {
} }
} }
func Test_caHandler_Provisioners(t *testing.T) { func Test_Provisioners(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
} }
@ -1200,10 +1206,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ mockMustAuthority(t, tt.fields.Authority)
Authority: tt.fields.Authority, Provisioners(tt.args.w, tt.args.r)
}
h.Provisioners(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result() res := rec.Result()
@ -1238,7 +1242,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
} }
} }
func Test_caHandler_ProvisionerKey(t *testing.T) { func Test_ProvisionerKey(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
} }
@ -1270,10 +1274,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ mockMustAuthority(t, tt.fields.Authority)
Authority: tt.fields.Authority, ProvisionerKey(tt.args.w, tt.args.r)
}
h.ProvisionerKey(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result() res := rec.Result()
@ -1298,7 +1300,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
} }
} }
func Test_caHandler_Roots(t *testing.T) { func Test_Roots(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1319,11 +1321,11 @@ func Test_caHandler_Roots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/roots", nil) req := httptest.NewRequest("GET", "http://example.com/roots", nil)
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Roots(w, req) Roots(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1360,10 +1362,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
req := httptest.NewRequest("GET", "https://example.com/roots", nil) req := httptest.NewRequest("GET", "https://example.com/roots", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.RootsPEM(w, req) RootsPEM(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1384,7 +1386,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
} }
} }
func Test_caHandler_Federation(t *testing.T) { func Test_Federation(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1405,11 +1407,11 @@ func Test_caHandler_Federation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/federation", nil) req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Federation(w, req) Federation(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {

View file

@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
} }
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { func Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
render.Error(w, errs.BadRequest("missing client certificate")) render.Error(w, errs.BadRequest("missing client certificate"))
return return
@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
return return
} }
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) a := mustAuthority(r.Context())
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return return
@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }

View file

@ -16,14 +16,15 @@ const (
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func Renew(w http.ResponseWriter, r *http.Request) {
cert, err := h.getPeerCertificate(r) cert, err := getPeerCertificate(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
certChain, err := h.Authority.Renew(cert) a := mustAuthority(r.Context())
certChain, err := a.Renew(cert)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil return r.TLS.PeerCertificates[0], nil
} }
if s := r.Header.Get(authorizationHeader); s != "" { if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) ctx := r.Context()
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
} }
} }
return nil, errs.BadRequest("missing client certificate") return nil, errs.BadRequest("missing client certificate")

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
// //
// TODO: Add CRL and OCSP support. // TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
PassiveOnly: body.Passive, PassiveOnly: body.Passive,
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token, // A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
if len(body.OTT) > 0 { if len(body.OTT) > 0 {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
opts.MTLS = true opts.MTLS = true
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return return
} }

View file

@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input), input: string(input),
statusCode: http.StatusOK, statusCode: http.StatusOK,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
statusCode: http.StatusOK, statusCode: http.StatusOK,
tls: cs, tls: cs,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error { revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input), input: string(input),
statusCode: http.StatusInternalServerError, statusCode: http.StatusInternalServerError,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input), input: string(input),
statusCode: http.StatusForbidden, statusCode: http.StatusForbidden,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
for name, _tc := range tests { for name, _tc := range tests {
tc := _tc(t) tc := _tc(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*caHandler) mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input)) req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
if tc.tls != nil { if tc.tls != nil {
req.TLS = tc.tls req.TLS = tc.tls
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Revoke(logging.NewResponseLogger(w), req) Revoke(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

View file

@ -49,7 +49,7 @@ type SignResponse struct {
// Sign is an HTTP handler that reads a certificate request and an // Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the // one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request. // information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
TemplateData: body.TemplateData, TemplateData: body.TemplateData,
} }
signOpts, err := h.Authority.AuthorizeSign(body.OTT) ctx := r.Context()
a := mustAuthority(ctx)
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return return
@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }

View file

@ -250,7 +250,7 @@ type SSHBastionResponse struct {
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in // (ott) from the body and creates a new SSH certificate with the information in
// the request. // the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -289,13 +289,15 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
ctx = provisioner.NewContextWithToken(ctx, body.OTT) ctx = provisioner.NewContextWithToken(ctx, body.OTT)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
@ -303,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var addUserCertificate *SSHCertificate var addUserCertificate *SSHCertificate
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
@ -316,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if cr := body.IdentityCSR.CertificateRequest; cr != nil { if cr := body.IdentityCSR.CertificateRequest; cr != nil {
ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx := authority.NewContextWithSkipTokenReuse(r.Context())
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
@ -328,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0), NotAfter: time.Unix(int64(cert.ValidBefore), 0),
}) })
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return
@ -345,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host // SSHRoots is an HTTP handler that returns the SSH public keys for user and host
// certificates. // certificates.
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context()) ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -370,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
// SSHFederation is an HTTP handler that returns the federated SSH public keys // SSHFederation is an HTTP handler that returns the federated SSH public keys
// for user and host certificates. // for user and host certificates.
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context()) ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -395,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients // SSHConfig is an HTTP handler that returns rendered templates for ssh clients
// and servers. // and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -406,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) ctx := r.Context()
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -427,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -438,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) ctx := r.Context()
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -449,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
} }
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
var cert *x509.Certificate var cert *x509.Certificate
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
cert = r.TLS.PeerCertificates[0] cert = r.TLS.PeerCertificates[0]
} }
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) ctx := r.Context()
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -466,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
} }
// SSHBastion provides returns the bastion configured if any. // SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -477,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return return
} }
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) ctx := r.Context()
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return

View file

@ -39,7 +39,7 @@ type SSHRekeyResponse struct {
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in // (ott) from the body and creates a new SSH certificate with the information in
// the request. // the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -60,7 +60,9 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
ctx = provisioner.NewContextWithToken(ctx, body.OTT) ctx = provisioner.NewContextWithToken(ctx, body.OTT)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
@ -71,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
return return
} }
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return return
@ -81,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return

View file

@ -37,7 +37,7 @@ type SSHRenewResponse struct {
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in // (ott) from the body and creates a new SSH certificate with the information in
// the request. // the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -52,7 +52,9 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
ctx = provisioner.NewContextWithToken(ctx, body.OTT) ctx = provisioner.NewContextWithToken(ctx, body.OTT)
_, err := h.Authority.Authorize(ctx, body.OTT)
a := mustAuthority(ctx)
_, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
@ -63,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
return return
} }
newCert, err := h.Authority.RenewSSH(ctx, oldCert) newCert, err := a.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return return
@ -73,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
@ -86,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
} }
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
return nil, nil return nil, nil
} }
@ -106,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte
cert.NotAfter = notAfter cert.NotAfter = notAfter
} }
certChain, err := h.Authority.Renew(cert) certChain, err := mustAuthority(r.Context()).Renew(cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// Revoke supports handful of different methods that revoke a Certificate. // Revoke supports handful of different methods that revoke a Certificate.
// //
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token, // A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return return
} }

View file

@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
} }
} }
func Test_caHandler_SSHSign(t *testing.T) { func Test_SSHSign(t *testing.T) {
user, err := getSignedUserCertificate() user, err := getSignedUserCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
host, err := getSignedHostCertificate() host, err := getSignedHostCertificate()
@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return []provisioner.SignOption{}, tt.authErr return []provisioner.SignOption{}, tt.authErr
}, },
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr return tt.tlsSignCerts, tt.tlsSignErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHSign(logging.NewResponseLogger(w), req) SSHSign(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
} }
} }
func Test_caHandler_SSHRoots(t *testing.T) { func Test_SSHRoots(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public()) user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err) assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr return tt.keys, tt.keysErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHRoots(logging.NewResponseLogger(w), req) SSHRoots(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
} }
} }
func Test_caHandler_SSHFederation(t *testing.T) { func Test_SSHFederation(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public()) user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err) assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr return tt.keys, tt.keysErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHFederation(logging.NewResponseLogger(w), req) SSHFederation(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
} }
} }
func Test_caHandler_SSHConfig(t *testing.T) { func Test_SSHConfig(t *testing.T) {
userOutput := []templates.Output{ userOutput := []templates.Output{
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
return tt.output, tt.err return tt.output, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHConfig(logging.NewResponseLogger(w), req) SSHConfig(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
} }
} }
func Test_caHandler_SSHCheckHost(t *testing.T) { func Test_SSHCheckHost(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
req string req string
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
return tt.exists, tt.err return tt.exists, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHCheckHost(logging.NewResponseLogger(w), req) SSHCheckHost(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
} }
} }
func Test_caHandler_SSHGetHosts(t *testing.T) { func Test_SSHGetHosts(t *testing.T) {
hosts := []authority.Host{ hosts := []authority.Host{
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"}, {HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
return tt.hosts, tt.err return tt.hosts, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHGetHosts(logging.NewResponseLogger(w), req) SSHGetHosts(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
} }
} }
func Test_caHandler_SSHBastion(t *testing.T) { func Test_SSHBastion(t *testing.T) {
bastion := &authority.Bastion{ bastion := &authority.Bastion{
Hostname: "bastion.local", Hostname: "bastion.local",
} }
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
return tt.bastion, tt.bastionErr return tt.bastion, tt.bastionErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHBastion(logging.NewResponseLogger(w), req) SSHBastion(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {

View file

@ -33,7 +33,7 @@ type GetExternalAccountKeysResponse struct {
// requireEABEnabled is a middleware that ensures ACME EAB is enabled // requireEABEnabled is a middleware that ensures ACME EAB is enabled
// before serving requests that act on ACME EAB credentials. // before serving requests that act on ACME EAB credentials.
func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc { func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
@ -53,32 +53,33 @@ func (h *Handler) requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
} }
} }
type acmeAdminResponderInterface interface { // ACMEAdminResponder is responsible for writing ACME admin responses
type ACMEAdminResponder interface {
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
} }
// ACMEAdminResponder is responsible for writing ACME admin responses // acmeAdminResponder implements ACMEAdminResponder.
type ACMEAdminResponder struct{} type acmeAdminResponder struct{}
// NewACMEAdminResponder returns a new ACMEAdminResponder // NewACMEAdminResponder returns a new ACMEAdminResponder
func NewACMEAdminResponder() *ACMEAdminResponder { func NewACMEAdminResponder() ACMEAdminResponder {
return &ACMEAdminResponder{} return &acmeAdminResponder{}
} }
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint // GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// CreateExternalAccountKey writes the response for the EAB key POST endpoint // CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }

View file

@ -33,6 +33,17 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
return protojson.Unmarshal(data, m) return protojson.Unmarshal(data, m)
} }
func mockMustAuthority(t *testing.T, a adminAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) adminAuthority {
return a
}
}
func TestHandler_requireEABEnabled(t *testing.T) { func TestHandler_requireEABEnabled(t *testing.T) {
type test struct { type test struct {
ctx context.Context ctx context.Context
@ -117,12 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.requireEABEnabled(tc.next)(w, req) requireEABEnabled(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

View file

@ -85,10 +85,10 @@ type DeleteResponse struct {
} }
// GetAdmin returns the requested admin, or an error. // GetAdmin returns the requested admin, or an error.
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { func GetAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
adm, ok := h.auth.LoadAdminByID(id) adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
if !ok { if !ok {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id)) "admin %s not found", id))
@ -98,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
} }
// GetAdmins returns a segment of admins associated with the authority. // GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
@ -106,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
return return
} }
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return return
@ -118,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
} }
// CreateAdmin creates a new admin. // CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { func CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
@ -130,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return return
} }
p, err := h.auth.LoadProvisionerByName(body.Provisioner) auth := mustAuthority(r.Context())
p, err := auth.LoadProvisionerByName(body.Provisioner)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return return
@ -141,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
Type: body.Type, Type: body.Type,
} }
// Store to authority collection. // Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing admin")) render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return return
} }
@ -150,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
} }
// DeleteAdmin deletes admin. // DeleteAdmin deletes admin.
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return return
} }
@ -162,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
} }
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
@ -175,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
} }
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
auth := mustAuthority(r.Context())
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return return

View file

@ -352,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAdmin(w, req) GetAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -491,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAdmins(w, req) GetAdmins(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -675,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.CreateAdmin(w, req) CreateAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -767,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.DeleteAdmin(w, req) DeleteAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -912,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.UpdateAdmin(w, req) UpdateAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

View file

@ -1,50 +1,58 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
) )
// Handler is the Admin API request handler. // Handler is the Admin API request handler.
type Handler struct { type Handler struct {
adminDB admin.DB acmeResponder ACMEAdminResponder
auth adminAuthority policyResponder PolicyAdminResponder
acmeDB acme.DB }
acmeResponder acmeAdminResponderInterface
policyResponder policyAdminResponderInterface // Route traffic and implement the Router interface.
//
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
func (h *Handler) Route(r api.Router) {
Route(r, h.acmeResponder, h.policyResponder)
} }
// NewHandler returns a new Authority Config Handler. // NewHandler returns a new Authority Config Handler.
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler { //
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler {
return &Handler{ return &Handler{
auth: auth,
adminDB: adminDB,
acmeDB: acmeDB,
acmeResponder: acmeResponder, acmeResponder: acmeResponder,
policyResponder: policyResponder, policyResponder: policyResponder,
} }
} }
// Route traffic and implement the Router interface. var mustAuthority = func(ctx context.Context) adminAuthority {
func (h *Handler) Route(r api.Router) { return authority.MustFromContext(ctx)
}
// Route traffic and implement the Router interface.
func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) {
authnz := func(next http.HandlerFunc) http.HandlerFunc { authnz := func(next http.HandlerFunc) http.HandlerFunc {
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
} }
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return h.checkAction(next, true) return checkAction(next, true)
} }
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc { disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return h.checkAction(next, false) return checkAction(next, false)
} }
acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc { acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(h.loadProvisionerByName(h.requireEABEnabled(next))) return authnz(loadProvisionerByName(requireEABEnabled(next)))
} }
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
@ -52,53 +60,58 @@ func (h *Handler) Route(r api.Router) {
} }
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(h.loadProvisionerByName(next))) return authnz(disabledInStandalone(loadProvisionerByName(next)))
} }
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc { acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(h.loadProvisionerByName(h.requireEABEnabled(h.loadExternalAccountKey(next))))) return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
} }
// Provisioners // Provisioners
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
// Admins // Admins
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) r.MethodFunc("GET", "/admins", authnz(GetAdmins))
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
// ACME External Account Binding Keys // ACME responder
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys)) if acmeResponder != nil {
r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.GetExternalAccountKeys)) // ACME External Account Binding Keys
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(h.acmeResponder.CreateExternalAccountKey)) r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(h.acmeResponder.DeleteExternalAccountKey)) r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
}
// Policy - Authority // Policy responder
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(h.policyResponder.GetAuthorityPolicy)) if policyResponder != nil {
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(h.policyResponder.CreateAuthorityPolicy)) // Policy - Authority
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(h.policyResponder.UpdateAuthorityPolicy)) r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(h.policyResponder.DeleteAuthorityPolicy)) r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
// Policy - Provisioner // Policy - Provisioner
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.GetProvisionerPolicy)) r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.CreateProvisionerPolicy)) r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.UpdateProvisionerPolicy)) r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(h.policyResponder.DeleteProvisionerPolicy)) r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
// Policy - ACME Account
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy))
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.GetACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(h.policyResponder.DeleteACMEAccountPolicy))
// Policy - ACME Account
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
}
} }

View file

@ -17,11 +17,10 @@ import (
// requireAPIEnabled is a middleware that ensures the Administration API // requireAPIEnabled is a middleware that ensures the Administration API
// is enabled before servicing requests. // is enabled before servicing requests.
func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc { func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() { if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
"administration API not enabled"))
return return
} }
next(w, r) next(w, r)
@ -29,7 +28,7 @@ func (h *Handler) requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
} }
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc { func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
@ -39,36 +38,39 @@ func (h *Handler) extractAuthorizeTokenAdmin(next http.HandlerFunc) http.Handler
return return
} }
adm, err := h.auth.AuthorizeAdminToken(r, tok) ctx := r.Context()
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := linkedca.NewContextWithAdmin(r.Context(), adm) ctx = linkedca.NewContextWithAdmin(ctx, adm)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
// loadProvisionerByName is a middleware that searches for a provisioner // loadProvisionerByName is a middleware that searches for a provisioner
// by name and stores it in the context. // by name and stores it in the context.
func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc { func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := chi.URLParam(r, "provisionerName")
var ( var (
p provisioner.Interface p provisioner.Interface
err error err error
) )
ctx := r.Context()
auth := mustAuthority(ctx)
adminDB := admin.MustFromContext(ctx)
name := chi.URLParam(r, "provisionerName")
// TODO(hs): distinguish 404 vs. 500 // TODO(hs): distinguish 404 vs. 500
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) prov, err := adminDB.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
return return
@ -80,9 +82,8 @@ func (h *Handler) loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc
} }
// checkAction checks if an action is supported in standalone or not // checkAction checks if an action is supported in standalone or not
func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc { func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// actions allowed in standalone mode are always supported // actions allowed in standalone mode are always supported
if supportedInStandalone { if supportedInStandalone {
next(w, r) next(w, r)
@ -91,7 +92,7 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool)
// when an action is not supported in standalone mode and when // when an action is not supported in standalone mode and when
// using a nosql.DB backend, actions are not supported // using a nosql.DB backend, actions are not supported
if _, ok := h.adminDB.(*nosql.DB); ok { if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"operation not supported in standalone mode")) "operation not supported in standalone mode"))
return return
@ -104,10 +105,11 @@ func (h *Handler) checkAction(next http.HandlerFunc, supportedInStandalone bool)
// loadExternalAccountKey is a middleware that searches for an ACME // loadExternalAccountKey is a middleware that searches for an ACME
// External Account Key by reference or keyID and stores it in the context. // External Account Key by reference or keyID and stores it in the context.
func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc { func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
acmeDB := acme.MustDatabaseFromContext(ctx)
reference := chi.URLParam(r, "reference") reference := chi.URLParam(r, "reference")
keyID := chi.URLParam(r, "keyID") keyID := chi.URLParam(r, "keyID")
@ -118,9 +120,9 @@ func (h *Handler) loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc
) )
if keyID != "" { if keyID != "" {
eak, err = h.acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID) eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
} else { } else {
eak, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
} }
if err != nil { if err != nil {

View file

@ -71,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.requireAPIEnabled(tc.next)(w, req) requireAPIEnabled(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -196,13 +194,10 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractAuthorizeTokenAdmin(tc.next)(w, req) extractAuthorizeTokenAdmin(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -251,6 +246,7 @@ func TestHandler_loadProvisionerByName(t *testing.T) {
return test{ return test{
ctx: ctx, ctx: ctx,
auth: auth, auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500, statusCode: 500,
err: err, err: err,
} }
@ -326,16 +322,13 @@ func TestHandler_loadProvisionerByName(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth, ctx := admin.NewContext(tc.ctx, tc.adminDB)
adminDB: tc.adminDB,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.loadProvisionerByName(tc.next)(w, req) loadProvisionerByName(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -405,14 +398,10 @@ func TestHandler_checkAction(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ ctx := admin.NewContext(context.Background(), tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
adminDB: tc.adminDB,
}
req := httptest.NewRequest("GET", "/foo", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.checkAction(tc.next, tc.supportedInStandalone)(w, req) checkAction(tc.next, tc.supportedInStandalone)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -653,14 +642,11 @@ func TestHandler_loadExternalAccountKey(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
acmeDB: tc.acmeDB,
}
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.loadExternalAccountKey(tc.next)(w, req) loadExternalAccountKey(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
@ -14,7 +15,9 @@ import (
"github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/policy"
) )
type policyAdminResponderInterface interface { // PolicyAdminResponder is the interface responsible for writing ACME admin
// responses.
type PolicyAdminResponder interface {
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
@ -29,39 +32,24 @@ type policyAdminResponderInterface interface {
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
} }
// PolicyAdminResponder is responsible for writing ACME admin responses // policyAdminResponder implements PolicyAdminResponder.
type PolicyAdminResponder struct { type policyAdminResponder struct{}
auth adminAuthority
adminDB admin.DB
acmeDB acme.DB
isLinkedCA bool
}
// NewACMEAdminResponder returns a new ACMEAdminResponder // NewACMEAdminResponder returns a new PolicyAdminResponder.
func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) *PolicyAdminResponder { func NewPolicyAdminResponder() PolicyAdminResponder {
return &policyAdminResponder{}
var isLinkedCA bool
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok {
isLinkedCA = a.IsLinkedCA()
}
return &PolicyAdminResponder{
auth: auth,
adminDB: adminDB,
acmeDB: acmeDB,
isLinkedCA: isLinkedCA,
}
} }
// GetAuthorityPolicy handles the GET /admin/authority/policy request // GetAuthorityPolicy handles the GET /admin/authority/policy request
func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
authorityPolicy, err := par.auth.GetAuthorityPolicy(r.Context()) auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return return
@ -76,15 +64,15 @@ func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *ht
} }
// CreateAuthorityPolicy handles the POST /admin/authority/policy request // CreateAuthorityPolicy handles the POST /admin/authority/policy request
func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context() auth := mustAuthority(ctx)
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
@ -113,7 +101,7 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
adm := linkedca.MustAdminFromContext(ctx) adm := linkedca.MustAdminFromContext(ctx)
var createdPolicy *linkedca.Policy var createdPolicy *linkedca.Policy
if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
return return
@ -127,15 +115,15 @@ func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r
} }
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request // UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context() auth := mustAuthority(ctx)
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy")) render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
@ -163,7 +151,7 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
adm := linkedca.MustAdminFromContext(ctx) adm := linkedca.MustAdminFromContext(ctx)
var updatedPolicy *linkedca.Policy var updatedPolicy *linkedca.Policy
if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil { if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
return return
@ -177,15 +165,15 @@ func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r
} }
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request // DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context() auth := mustAuthority(ctx)
authorityPolicy, err := par.auth.GetAuthorityPolicy(ctx) authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) { if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy")) render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
@ -197,7 +185,7 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
return return
} }
if err := par.auth.RemoveAuthorityPolicy(ctx); err != nil { if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy")) render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
return return
} }
@ -206,15 +194,14 @@ func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r
} }
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request // GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov := linkedca.MustProvisionerFromContext(r.Context()) prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy() provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil { if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
@ -225,16 +212,14 @@ func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *
} }
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request // CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy() provisionerPolicy := prov.GetPolicy()
if provisionerPolicy != nil { if provisionerPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name) adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
@ -256,8 +241,8 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
} }
prov.Policy = newPolicy prov.Policy = newPolicy
auth := mustAuthority(ctx)
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
return return
@ -271,16 +256,14 @@ func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter,
} }
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request // UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy() provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil { if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
@ -301,7 +284,8 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
} }
prov.Policy = newPolicy prov.Policy = newPolicy
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) { if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
return return
@ -315,16 +299,14 @@ func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter,
} }
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request // DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
if prov.Policy == nil { if prov.Policy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist")) render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return return
@ -333,7 +315,8 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
// remove the policy // remove the policy
prov.Policy = nil prov.Policy = nil
if err := par.auth.UpdateProvisioner(ctx, prov); err != nil { auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy")) render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
return return
} }
@ -341,16 +324,14 @@ func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter,
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK) render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
} }
func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy == nil { if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
@ -360,17 +341,15 @@ func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK) render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
} }
func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy != nil { if eakPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id) adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
@ -394,7 +373,8 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
eak.Policy = newPolicy eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak) acmeEAK := linkedEAKToCertificates(eak)
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy")) render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
return return
} }
@ -402,17 +382,15 @@ func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter,
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated) render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
} }
func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy == nil { if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
@ -434,7 +412,8 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
eak.Policy = newPolicy eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak) acmeEAK := linkedEAKToCertificates(eak)
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy")) render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
return return
} }
@ -442,17 +421,15 @@ func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter,
render.ProtoJSONStatus(w, newPolicy, http.StatusOK) render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
} }
func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) { func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := par.blockLinkedCA(); err != nil { if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx) prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx) eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy() eakPolicy := eak.GetPolicy()
if eakPolicy == nil { if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist")) render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
@ -463,7 +440,8 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
eak.Policy = nil eak.Policy = nil
acmeEAK := linkedEAKToCertificates(eak) acmeEAK := linkedEAKToCertificates(eak)
if err := par.acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil { acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy")) render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
return return
} }
@ -472,9 +450,10 @@ func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter,
} }
// blockLinkedCA blocks all API operations on linked deployments // blockLinkedCA blocks all API operations on linked deployments
func (par *PolicyAdminResponder) blockLinkedCA() error { func blockLinkedCA(ctx context.Context) error {
// temporary blocking linked deployments // temporary blocking linked deployments
if par.isLinkedCA { adminDB := admin.MustFromContext(ctx)
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments") return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
} }
return nil return nil

View file

@ -109,7 +109,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
err.Message = "error retrieving authority policy: force" err.Message = "error retrieving authority policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorServerInternalType, "force") return nil, admin.NewError(admin.ErrorServerInternalType, "force")
@ -124,7 +125,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist") err := admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist")
err.Message = "authority policy does not exist" err.Message = "authority policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorNotFoundType, "not found") return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
@ -179,7 +181,8 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
}, },
} }
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return policy, nil return policy, nil
@ -234,11 +237,12 @@ func TestPolicyAdminResponder_GetAuthorityPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.GetAuthorityPolicy(w, req) par.GetAuthorityPolicy(w, req)
@ -301,7 +305,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
err.Message = "error retrieving authority policy: force" err.Message = "error retrieving authority policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorServerInternalType, "force") return nil, admin.NewError(admin.ErrorServerInternalType, "force")
@ -316,7 +321,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
err := admin.NewError(admin.ErrorConflictType, "authority already has a policy") err := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
err.Message = "authority already has a policy" err.Message = "authority already has a policy"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return &linkedca.Policy{}, nil return &linkedca.Policy{}, nil
@ -332,7 +338,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
body := []byte("{?}") body := []byte("{?}")
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorNotFoundType, "not found") return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
@ -358,7 +365,8 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
} }
}`) }`)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorNotFoundType, "not found") return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
@ -509,11 +517,13 @@ func TestPolicyAdminResponder_CreateAuthorityPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) ctx := admin.NewContext(tc.ctx, tc.adminDB)
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.CreateAuthorityPolicy(w, req) par.CreateAuthorityPolicy(w, req)
@ -586,7 +596,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
err.Message = "error retrieving authority policy: force" err.Message = "error retrieving authority policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorServerInternalType, "force") return nil, admin.NewError(admin.ErrorServerInternalType, "force")
@ -602,7 +613,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
err.Message = "authority policy does not exist" err.Message = "authority policy does not exist"
err.Status = http.StatusNotFound err.Status = http.StatusNotFound
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, nil return nil, nil
@ -625,7 +637,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
adminErr.Message = "proto: syntax error (line 1:2): invalid value ?" adminErr.Message = "proto: syntax error (line 1:2): invalid value ?"
body := []byte("{?}") body := []byte("{?}")
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return policy, nil return policy, nil
@ -658,7 +671,8 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
} }
}`) }`)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return policy, nil return policy, nil
@ -809,11 +823,13 @@ func TestPolicyAdminResponder_UpdateAuthorityPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) ctx := admin.NewContext(tc.ctx, tc.adminDB)
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.UpdateAuthorityPolicy(w, req) par.UpdateAuthorityPolicy(w, req)
@ -886,7 +902,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy") err := admin.WrapErrorISE(errors.New("force"), "error retrieving authority policy")
err.Message = "error retrieving authority policy: force" err.Message = "error retrieving authority policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorServerInternalType, "force") return nil, admin.NewError(admin.ErrorServerInternalType, "force")
@ -902,7 +919,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
err.Message = "authority policy does not exist" err.Message = "authority policy does not exist"
err.Status = http.StatusNotFound err.Status = http.StatusNotFound
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, nil return nil, nil
@ -924,7 +942,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
err := admin.NewErrorISE("error deleting authority policy: force") err := admin.NewErrorISE("error deleting authority policy: force")
err.Message = "error deleting authority policy: force" err.Message = "error deleting authority policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return policy, nil return policy, nil
@ -947,7 +966,8 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
} }
ctx := context.Background() ctx := context.Background()
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return policy, nil return policy, nil
@ -963,11 +983,13 @@ func TestPolicyAdminResponder_DeleteAuthorityPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) ctx := admin.NewContext(tc.ctx, tc.adminDB)
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.DeleteAuthorityPolicy(w, req) par.DeleteAuthorityPolicy(w, req)
@ -1033,6 +1055,7 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
err.Message = "provisioner policy does not exist" err.Message = "provisioner policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 404, statusCode: 404,
} }
@ -1085,7 +1108,8 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
} }
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
response: &testPolicyResponse{ response: &testPolicyResponse{
X509: &testX509Policy{ X509: &testX509Policy{
Allow: &testX509Names{ Allow: &testX509Names{
@ -1135,11 +1159,13 @@ func TestPolicyAdminResponder_GetProvisionerPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) ctx := admin.NewContext(tc.ctx, tc.adminDB)
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.GetProvisionerPolicy(w, req) par.GetProvisionerPolicy(w, req)
@ -1214,6 +1240,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
err.Message = "provisioner provName already has a policy" err.Message = "provisioner provName already has a policy"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 409, statusCode: 409,
} }
@ -1228,6 +1255,7 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
body := []byte("{?}") body := []byte("{?}")
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
body: body, body: body,
err: adminErr, err: adminErr,
statusCode: 400, statusCode: 400,
@ -1251,7 +1279,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
} }
}`) }`)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorNotFoundType, "not found") return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
@ -1283,7 +1312,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return &authority.PolicyError{ return &authority.PolicyError{
@ -1318,7 +1348,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return &authority.PolicyError{ return &authority.PolicyError{
@ -1351,7 +1382,8 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return nil return nil
@ -1372,11 +1404,12 @@ func TestPolicyAdminResponder_CreateProvisionerPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.CreateProvisionerPolicy(w, req) par.CreateProvisionerPolicy(w, req)
@ -1452,6 +1485,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
err.Message = "provisioner policy does not exist" err.Message = "provisioner policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 404, statusCode: 404,
} }
@ -1474,6 +1508,7 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
body := []byte("{?}") body := []byte("{?}")
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
body: body, body: body,
err: adminErr, err: adminErr,
statusCode: 400, statusCode: 400,
@ -1505,7 +1540,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
} }
}`) }`)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) {
return nil, admin.NewError(admin.ErrorNotFoundType, "not found") return nil, admin.NewError(admin.ErrorNotFoundType, "not found")
@ -1538,7 +1574,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return &authority.PolicyError{ return &authority.PolicyError{
@ -1574,7 +1611,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return &authority.PolicyError{ return &authority.PolicyError{
@ -1608,7 +1646,8 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return nil return nil
@ -1629,11 +1668,12 @@ func TestPolicyAdminResponder_UpdateProvisionerPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, nil) ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.UpdateProvisionerPolicy(w, req) par.UpdateProvisionerPolicy(w, req)
@ -1710,6 +1750,7 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
err.Message = "provisioner policy does not exist" err.Message = "provisioner policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 404, statusCode: 404,
} }
@ -1723,7 +1764,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
err := admin.NewErrorISE("error deleting provisioner policy: force") err := admin.NewErrorISE("error deleting provisioner policy: force")
err.Message = "error deleting provisioner policy: force" err.Message = "error deleting provisioner policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return errors.New("force") return errors.New("force")
@ -1740,7 +1782,8 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
} }
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
auth: &mockAdminAuthority{ auth: &mockAdminAuthority{
MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error {
return nil return nil
@ -1753,11 +1796,13 @@ func TestPolicyAdminResponder_DeleteProvisionerPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
par := NewPolicyAdminResponder(tc.auth, tc.adminDB, tc.acmeDB) ctx := admin.NewContext(tc.ctx, tc.adminDB)
ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.DeleteProvisionerPolicy(w, req) par.DeleteProvisionerPolicy(w, req)
@ -1828,6 +1873,7 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
err.Message = "ACME EAK policy does not exist" err.Message = "ACME EAK policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 404, statusCode: 404,
} }
@ -1885,7 +1931,8 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
response: &testPolicyResponse{ response: &testPolicyResponse{
X509: &testX509Policy{ X509: &testX509Policy{
Allow: &testX509Names{ Allow: &testX509Names{
@ -1935,11 +1982,12 @@ func TestPolicyAdminResponder_GetACMEAccountPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.GetACMEAccountPolicy(w, req) par.GetACMEAccountPolicy(w, req)
@ -2018,6 +2066,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
err.Message = "ACME EAK eakID already has a policy" err.Message = "ACME EAK eakID already has a policy"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 409, statusCode: 409,
} }
@ -2036,6 +2085,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
body := []byte("{?}") body := []byte("{?}")
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
body: body, body: body,
err: adminErr, err: adminErr,
statusCode: 400, statusCode: 400,
@ -2064,6 +2114,7 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
}`) }`)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
body: body, body: body,
err: adminErr, err: adminErr,
statusCode: 400, statusCode: 400,
@ -2091,7 +2142,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
acmeDB: &acme.MockDB{ acmeDB: &acme.MockDB{
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
assert.Equal(t, "provID", provisionerID) assert.Equal(t, "provID", provisionerID)
@ -2124,7 +2176,8 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
acmeDB: &acme.MockDB{ acmeDB: &acme.MockDB{
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
assert.Equal(t, "provID", provisionerID) assert.Equal(t, "provID", provisionerID)
@ -2147,11 +2200,12 @@ func TestPolicyAdminResponder_CreateACMEAccountPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.CreateACMEAccountPolicy(w, req) par.CreateACMEAccountPolicy(w, req)
@ -2231,6 +2285,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
err.Message = "ACME EAK policy does not exist" err.Message = "ACME EAK policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 404, statusCode: 404,
} }
@ -2257,6 +2312,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
body := []byte("{?}") body := []byte("{?}")
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
body: body, body: body,
err: adminErr, err: adminErr,
statusCode: 400, statusCode: 400,
@ -2293,6 +2349,7 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
}`) }`)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
body: body, body: body,
err: adminErr, err: adminErr,
statusCode: 400, statusCode: 400,
@ -2321,7 +2378,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
acmeDB: &acme.MockDB{ acmeDB: &acme.MockDB{
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
assert.Equal(t, "provID", provisionerID) assert.Equal(t, "provID", provisionerID)
@ -2355,7 +2413,8 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
body, err := protojson.Marshal(policy) body, err := protojson.Marshal(policy)
assert.NoError(t, err) assert.NoError(t, err)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
acmeDB: &acme.MockDB{ acmeDB: &acme.MockDB{
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
assert.Equal(t, "provID", provisionerID) assert.Equal(t, "provID", provisionerID)
@ -2378,11 +2437,12 @@ func TestPolicyAdminResponder_UpdateACMEAccountPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.UpdateACMEAccountPolicy(w, req) par.UpdateACMEAccountPolicy(w, req)
@ -2462,6 +2522,7 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
err.Message = "ACME EAK policy does not exist" err.Message = "ACME EAK policy does not exist"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
err: err, err: err,
statusCode: 404, statusCode: 404,
} }
@ -2487,7 +2548,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
err := admin.NewErrorISE("error deleting ACME EAK policy: force") err := admin.NewErrorISE("error deleting ACME EAK policy: force")
err.Message = "error deleting ACME EAK policy: force" err.Message = "error deleting ACME EAK policy: force"
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
acmeDB: &acme.MockDB{ acmeDB: &acme.MockDB{
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
assert.Equal(t, "provID", provisionerID) assert.Equal(t, "provID", provisionerID)
@ -2518,7 +2580,8 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov) ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak) ctx = linkedca.NewContextWithExternalAccountKey(ctx, eak)
return test{ return test{
ctx: ctx, ctx: ctx,
adminDB: &admin.MockDB{},
acmeDB: &acme.MockDB{ acmeDB: &acme.MockDB{
MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { MockUpdateExternalAccountKey: func(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
assert.Equal(t, "provID", provisionerID) assert.Equal(t, "provID", provisionerID)
@ -2533,11 +2596,12 @@ func TestPolicyAdminResponder_DeleteACMEAccountPolicy(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(tc.ctx, tc.adminDB)
par := NewPolicyAdminResponder(nil, tc.adminDB, tc.acmeDB) ctx = acme.NewDatabaseContext(ctx, tc.acmeDB)
par := NewPolicyAdminResponder()
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
par.DeleteACMEAccountPolicy(w, req) par.DeleteACMEAccountPolicy(w, req)

View file

@ -23,29 +23,31 @@ type GetProvisionersResponse struct {
} }
// GetProvisioner returns the requested provisioner, or an error. // GetProvisioner returns the requested provisioner, or an error.
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { func GetProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
var ( var (
p provisioner.Interface p provisioner.Interface
err error err error
) )
ctx := r.Context()
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) prov, err := db.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
} }
// GetProvisioners returns the given segment of provisioners associated with the authority. // GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
return return
} }
p, next, err := h.auth.GetProvisioners(cursor, limit) p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
} }
// CreateProvisioner creates a new prov. // CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
render.Error(w, err) render.Error(w, err)
@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return return
} }
@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
} }
// DeleteProvisioner deletes a provisioner. // DeleteProvisioner deletes a provisioner.
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
var ( var (
p provisioner.Interface p provisioner.Interface
err error err error
) )
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
auth := mustAuthority(r.Context())
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return return
} }
@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
} }
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name) auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
p, err := auth.LoadProvisionerByName(name)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return return
} }
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) old, err := db.GetProvisioner(r.Context(), p.GetID())
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
return return
} }
@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }

View file

@ -50,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx, ctx: ctx,
req: req, req: req,
auth: auth, auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500, statusCode: 500,
err: &admin.Error{ err: &admin.Error{
Type: admin.ErrorServerInternalType.String(), Type: admin.ErrorServerInternalType.String(),
@ -74,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx, ctx: ctx,
req: req, req: req,
auth: auth, auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500, statusCode: 500,
err: &admin.Error{ err: &admin.Error{
Type: admin.ErrorServerInternalType.String(), Type: admin.ErrorServerInternalType.String(),
@ -156,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth, ctx := admin.NewContext(tc.ctx, tc.adminDB)
adminDB: tc.adminDB, req := tc.req.WithContext(ctx)
}
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetProvisioner(w, req) GetProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -280,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetProvisioners(w, req) GetProvisioners(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -405,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.CreateProvisioner(w, req) CreateProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -571,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.DeleteProvisioner(w, req) DeleteProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -625,6 +619,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{ return test{
ctx: context.Background(), ctx: context.Background(),
body: body, body: body,
adminDB: &admin.MockDB{},
statusCode: 400, statusCode: 400,
err: &admin.Error{ err: &admin.Error{
Type: "badRequest", Type: "badRequest",
@ -654,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{ return test{
ctx: ctx, ctx: ctx,
body: body, body: body,
adminDB: &admin.MockDB{},
auth: auth, auth: auth,
statusCode: 500, statusCode: 500,
err: &admin.Error{ err: &admin.Error{
@ -1061,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth, ctx := admin.NewContext(tc.ctx, tc.adminDB)
adminDB: tc.adminDB,
}
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.UpdateProvisioner(w, req) UpdateProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

View file

@ -76,6 +76,29 @@ type DB interface {
DeleteAuthorityPolicy(ctx context.Context) error DeleteAuthorityPolicy(ctx context.Context) error
} }
type dbKey struct{}
// NewContext adds the given admin database to the context.
func NewContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// FromContext returns the current admin database from the given context.
func FromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustFromContext returns the current admin database from the given context. It
// will panic if it's not in the context.
func MustFromContext(ctx context.Context) DB {
if db, ok := FromContext(ctx); !ok {
panic("admin database is not in the context")
} else {
return db
}
}
// MockDB is an implementation of the DB interface that should only be used as // MockDB is an implementation of the DB interface that should only be used as
// a mock in tests. // a mock in tests.
type MockDB struct { type MockDB struct {

View file

@ -167,6 +167,29 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
return a, nil return a, nil
} }
type authorityKey struct{}
// NewContext adds the given authority to the context.
func NewContext(ctx context.Context, a *Authority) context.Context {
return context.WithValue(ctx, authorityKey{}, a)
}
// FromContext returns the current authority from the given context.
func FromContext(ctx context.Context) (a *Authority, ok bool) {
a, ok = ctx.Value(authorityKey{}).(*Authority)
return
}
// MustFromContext returns the current authority from the given context. It will
// panic if the authority is not in the context.
func MustFromContext(ctx context.Context) *Authority {
if a, ok := FromContext(ctx); !ok {
panic("authority is not in the context")
} else {
return a
}
}
// ReloadAdminResources reloads admins and provisioners from the DB. // ReloadAdminResources reloads admins and provisioners from the DB.
func (a *Authority) ReloadAdminResources(ctx context.Context) error { func (a *Authority) ReloadAdminResources(ctx context.Context) error {
var ( var (
@ -235,6 +258,7 @@ func (a *Authority) init() error {
} }
var err error var err error
ctx := NewContext(context.Background(), a)
// Set password if they are not set. // Set password if they are not set.
var configPassword []byte var configPassword []byte
@ -270,7 +294,7 @@ func (a *Authority) init() error {
if a.config.KMS != nil { if a.config.KMS != nil {
options = *a.config.KMS options = *a.config.KMS
} }
a.keyManager, err = kms.New(context.Background(), options) a.keyManager, err = kms.New(ctx, options)
if err != nil { if err != nil {
return err return err
} }
@ -300,7 +324,7 @@ func (a *Authority) init() error {
// Configure linked RA // Configure linked RA
if linkedcaClient != nil && options.CertificateAuthority == "" { if linkedcaClient != nil && options.CertificateAuthority == "" {
conf, err := linkedcaClient.GetConfiguration(context.Background()) conf, err := linkedcaClient.GetConfiguration(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -334,7 +358,7 @@ func (a *Authority) init() error {
} }
} }
a.x509CAService, err = cas.New(context.Background(), options) a.x509CAService, err = cas.New(ctx, options)
if err != nil { if err != nil {
return err return err
} }
@ -521,7 +545,7 @@ func (a *Authority) init() error {
} }
} }
a.scepService, err = scep.NewService(context.Background(), options) a.scepService, err = scep.NewService(ctx, options)
if err != nil { if err != nil {
return err return err
} }
@ -543,19 +567,19 @@ func (a *Authority) init() error {
} }
} }
provs, err := a.adminDB.GetProvisioners(context.Background()) provs, err := a.adminDB.GetProvisioners(ctx)
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
} }
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
// Create First Provisioner // Create First Provisioner
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error creating first provisioner") return admin.WrapErrorISE(err, "error creating first provisioner")
} }
// Create first admin // Create first admin
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
ProvisionerId: prov.Id, ProvisionerId: prov.Id,
Subject: "step", Subject: "step",
Type: linkedca.Admin_SUPER_ADMIN, Type: linkedca.Admin_SUPER_ADMIN,
@ -571,7 +595,7 @@ func (a *Authority) init() error {
} }
// Load x509 and SSH Policy Engines // Load x509 and SSH Policy Engines
if err := a.reloadPolicyEngines(context.Background()); err != nil { if err := a.reloadPolicyEngines(ctx); err != nil {
return err return err
} }
@ -596,6 +620,15 @@ func (a *Authority) init() error {
return nil return nil
} }
// GetID returns the define authority id or a zero uuid.
func (a *Authority) GetID() string {
const zeroUUID = "00000000-0000-0000-0000-000000000000"
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
return id
}
return zeroUUID
}
// GetDatabase returns the authority database. If the configuration does not // GetDatabase returns the authority database. If the configuration does not
// define a database, GetDatabase will return a db.SimpleDB instance. // define a database, GetDatabase will return a db.SimpleDB instance.
func (a *Authority) GetDatabase() db.AuthDB { func (a *Authority) GetDatabase() db.AuthDB {

View file

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
}) })
} }
} }
func TestAuthority_GetID(t *testing.T) {
type fields struct {
authorityID string
}
tests := []struct {
name string
fields fields
want string
}{
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authority{
config: &config.Config{
AuthorityConfig: &config.AuthConfig{
AuthorityID: tt.fields.authorityID,
},
},
}
if got := a.GetID(); got != tt.want {
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -260,8 +260,7 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio
// AuthorizeSign authorizes a signature request by validating and authenticating // AuthorizeSign authorizes a signature request by validating and authenticating
// a token that must be sent w/ the request. // a token that must be sent w/ the request.
// //
// NOTE: This method is deprecated and should not be used. We make it available // Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error).
// in the short term os as not to break existing clients.
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod)
return a.Authorize(ctx, token) return a.Authorize(ctx, token)

View file

@ -54,7 +54,11 @@ func startCABootstrapServer() *httptest.Server {
if err != nil { if err != nil {
panic(err) panic(err)
} }
baseContext := buildContext(ca.auth, nil, nil, nil)
srv.Config.Handler = ca.srv.Handler srv.Config.Handler = ca.srv.Handler
srv.Config.BaseContext = func(net.Listener) context.Context {
return baseContext
}
srv.TLS = ca.srv.TLSConfig srv.TLS = ca.srv.TLSConfig
srv.StartTLS() srv.StartTLS()
// Force the use of GetCertificate on IPs // Force the use of GetCertificate on IPs

View file

@ -1,10 +1,12 @@
package ca package ca
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
@ -18,6 +20,7 @@ import (
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
adminAPI "github.com/smallstep/certificates/authority/admin/api" adminAPI "github.com/smallstep/certificates/authority/admin/api"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
@ -170,10 +173,9 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
insecureHandler := http.Handler(insecureMux) insecureHandler := http.Handler(insecureMux)
// Add regular CA api endpoints in / and /1.0 // Add regular CA api endpoints in / and /1.0
routerHandler := api.New(auth) api.Route(mux)
routerHandler.Route(mux)
mux.Route("/1.0", func(r chi.Router) { mux.Route("/1.0", func(r chi.Router) {
routerHandler.Route(r) api.Route(r)
}) })
//Add ACME api endpoints in /acme and /1.0/acme //Add ACME api endpoints in /acme and /1.0/acme
@ -187,49 +189,41 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
dns = fmt.Sprintf("%s:%s", dns, port) dns = fmt.Sprintf("%s:%s", dns, port)
} }
// ACME Router // ACME Router is only available if we have a database.
prefix := "acme"
var acmeDB acme.DB var acmeDB acme.DB
if cfg.DB == nil { var acmeLinker acme.Linker
acmeDB = nil if cfg.DB != nil {
} else {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error configuring ACME DB interface") return nil, errors.Wrap(err, "error configuring ACME DB interface")
} }
acmeLinker = acme.NewLinker(dns, "acme")
mux.Route("/acme", func(r chi.Router) {
acmeAPI.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/acme", func(r chi.Router) {
acmeAPI.Route(r)
})
} }
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
Backdate: *cfg.AuthorityConfig.Backdate,
DB: acmeDB,
DNS: dns,
Prefix: prefix,
CA: auth,
})
mux.Route("/"+prefix, func(r chi.Router) {
acmeHandler.Route(r)
})
// Use 2.0 because, at the moment, our ACME api is only compatible with v2.0
// of the ACME spec.
mux.Route("/2.0/"+prefix, func(r chi.Router) {
acmeHandler.Route(r)
})
// Admin API Router // Admin API Router
if cfg.AuthorityConfig.EnableAdmin { if cfg.AuthorityConfig.EnableAdmin {
adminDB := auth.GetAdminDatabase() adminDB := auth.GetAdminDatabase()
if adminDB != nil { if adminDB != nil {
acmeAdminResponder := adminAPI.NewACMEAdminResponder() acmeAdminResponder := adminAPI.NewACMEAdminResponder()
policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB, acmeDB) policyAdminResponder := adminAPI.NewPolicyAdminResponder()
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder)
mux.Route("/admin", func(r chi.Router) { mux.Route("/admin", func(r chi.Router) {
adminHandler.Route(r) adminAPI.Route(r, acmeAdminResponder, policyAdminResponder)
}) })
} }
} }
var scepAuthority *scep.Authority
if ca.shouldServeSCEPEndpoints() { if ca.shouldServeSCEPEndpoints() {
scepPrefix := "scep" scepPrefix := "scep"
scepAuthority, err := scep.New(auth, scep.AuthorityOptions{ scepAuthority, err = scep.New(auth, scep.AuthorityOptions{
Service: auth.GetSCEPService(), Service: auth.GetSCEPService(),
DNS: dns, DNS: dns,
Prefix: scepPrefix, Prefix: scepPrefix,
@ -237,13 +231,12 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error creating SCEP authority") return nil, errors.Wrap(err, "error creating SCEP authority")
} }
scepRouterHandler := scepAPI.New(scepAuthority)
// According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10), // According to the RFC (https://tools.ietf.org/html/rfc8894#section-7.10),
// SCEP operations are performed using HTTP, so that's why the API is mounted // SCEP operations are performed using HTTP, so that's why the API is mounted
// to the insecure mux. // to the insecure mux.
insecureMux.Route("/"+scepPrefix, func(r chi.Router) { insecureMux.Route("/"+scepPrefix, func(r chi.Router) {
scepRouterHandler.Route(r) scepAPI.Route(r)
}) })
// The RFC also mentions usage of HTTPS, but seems to advise // The RFC also mentions usage of HTTPS, but seems to advise
@ -253,7 +246,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
// as well as HTTPS can be used to request certificates // as well as HTTPS can be used to request certificates
// using SCEP. // using SCEP.
mux.Route("/"+scepPrefix, func(r chi.Router) { mux.Route("/"+scepPrefix, func(r chi.Router) {
scepRouterHandler.Route(r) scepAPI.Route(r)
}) })
} }
@ -280,7 +273,13 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
insecureHandler = logger.Middleware(insecureHandler) insecureHandler = logger.Middleware(insecureHandler)
} }
// Create context with all the necessary values.
baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker)
ca.srv = server.New(cfg.Address, handler, tlsConfig) ca.srv = server.New(cfg.Address, handler, tlsConfig)
ca.srv.BaseContext = func(net.Listener) context.Context {
return baseContext
}
// only start the insecure server if the insecure address is configured // only start the insecure server if the insecure address is configured
// and, currently, also only when it should serve SCEP endpoints. // and, currently, also only when it should serve SCEP endpoints.
@ -290,11 +289,32 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
// will probably introduce more complexity in terms of graceful // will probably introduce more complexity in terms of graceful
// reload. // reload.
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil) ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
ca.insecureSrv.BaseContext = func(net.Listener) context.Context {
return baseContext
}
} }
return ca, nil return ca, nil
} }
// buildContext builds the server base context.
func buildContext(a *authority.Authority, scepAuthority *scep.Authority, acmeDB acme.DB, acmeLinker acme.Linker) context.Context {
ctx := authority.NewContext(context.Background(), a)
if authDB := a.GetDatabase(); authDB != nil {
ctx = db.NewContext(ctx, authDB)
}
if adminDB := a.GetAdminDatabase(); adminDB != nil {
ctx = admin.NewContext(ctx, adminDB)
}
if scepAuthority != nil {
ctx = scep.NewContext(ctx, scepAuthority)
}
if acmeDB != nil {
ctx = acme.NewContext(ctx, acmeDB, acme.NewClient(), acmeLinker, nil)
}
return ctx
}
// Run starts the CA calling to the server ListenAndServe method. // Run starts the CA calling to the server ListenAndServe method.
func (ca *CA) Run() error { func (ca *CA) Run() error {
var wg sync.WaitGroup var wg sync.WaitGroup

View file

@ -2,6 +2,7 @@ package ca
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
@ -281,7 +282,8 @@ ZEp7knvU2psWRw==
assert.FatalError(t, err) assert.FatalError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq) ctx := authority.NewContext(context.Background(), tc.ca.auth)
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
@ -360,7 +362,8 @@ func TestCAProvisioners(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq) ctx := authority.NewContext(context.Background(), tc.ca.auth)
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
@ -426,7 +429,8 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq) ctx := authority.NewContext(context.Background(), tc.ca.auth)
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
@ -487,7 +491,8 @@ func TestCARoot(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq) ctx := authority.NewContext(context.Background(), tc.ca.auth)
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
@ -534,7 +539,8 @@ func TestCAHealth(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq) ctx := authority.NewContext(context.Background(), tc.ca.auth)
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}
@ -628,7 +634,8 @@ func TestCARenew(t *testing.T) {
rq.TLS = tc.tlsConnState rq.TLS = tc.tlsConnState
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
tc.ca.srv.Handler.ServeHTTP(rr, rq) ctx := authority.NewContext(context.Background(), tc.ca.auth)
tc.ca.srv.Handler.ServeHTTP(rr, rq.WithContext(ctx))
if assert.Equals(t, rr.Code, tc.status) { if assert.Equals(t, rr.Code, tc.status) {
body := &ClosingBuffer{rr.Body} body := &ClosingBuffer{rr.Body}

View file

@ -10,6 +10,7 @@ import (
"encoding/hex" "encoding/hex"
"io" "io"
"log" "log"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -77,7 +78,12 @@ func startCATestServer() *httptest.Server {
panic(err) panic(err)
} }
// Use a httptest.Server instead // Use a httptest.Server instead
return startTestServer(ca.srv.TLSConfig, ca.srv.Handler) srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler)
baseContext := buildContext(ca.auth, nil, nil, nil)
srv.Config.BaseContext = func(net.Listener) context.Context {
return baseContext
}
return srv
} }
func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) {

View file

@ -0,0 +1,67 @@
package approle
import (
"encoding/json"
"errors"
"fmt"
"github.com/hashicorp/vault/api/auth/approle"
)
// AuthOptions defines the configuration options added using the
// VaultOptions.AuthOptions field when AuthType is approle
type AuthOptions struct {
RoleID string `json:"roleID,omitempty"`
SecretID string `json:"secretID,omitempty"`
SecretIDFile string `json:"secretIDFile,omitempty"`
SecretIDEnv string `json:"secretIDEnv,omitempty"`
IsWrappingToken bool `json:"isWrappingToken,omitempty"`
}
func NewApproleAuthMethod(mountPath string, options json.RawMessage) (*approle.AppRoleAuth, error) {
var opts *AuthOptions
err := json.Unmarshal(options, &opts)
if err != nil {
return nil, fmt.Errorf("error decoding AppRole auth options: %w", err)
}
var approleAuth *approle.AppRoleAuth
var loginOptions []approle.LoginOption
if mountPath != "" {
loginOptions = append(loginOptions, approle.WithMountPath(mountPath))
}
if opts.IsWrappingToken {
loginOptions = append(loginOptions, approle.WithWrappingToken())
}
if opts.RoleID == "" {
return nil, errors.New("you must set roleID")
}
var sid approle.SecretID
switch {
case opts.SecretID != "" && opts.SecretIDFile == "" && opts.SecretIDEnv == "":
sid = approle.SecretID{
FromString: opts.SecretID,
}
case opts.SecretIDFile != "" && opts.SecretID == "" && opts.SecretIDEnv == "":
sid = approle.SecretID{
FromFile: opts.SecretIDFile,
}
case opts.SecretIDEnv != "" && opts.SecretIDFile == "" && opts.SecretID == "":
sid = approle.SecretID{
FromEnv: opts.SecretIDEnv,
}
default:
return nil, errors.New("you must set one of secretID, secretIDFile or secretIDEnv")
}
approleAuth, err = approle.NewAppRoleAuth(opts.RoleID, &sid, loginOptions...)
if err != nil {
return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err)
}
return approleAuth, nil
}

View file

@ -0,0 +1,195 @@
package approle
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
vault "github.com/hashicorp/vault/api"
)
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.RequestURI == "/v1/auth/approle/login":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{
"auth": {
"client_token": "hvs.0000"
}
}`)
case r.RequestURI == "/v1/auth/custom-approle/login":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{
"auth": {
"client_token": "hvs.9999"
}
}`)
default:
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"not found"}`)
}
}))
t.Cleanup(func() {
srv.Close()
})
u, err := url.Parse(srv.URL)
if err != nil {
srv.Close()
t.Fatal(err)
}
config := vault.DefaultConfig()
config.Address = srv.URL
client, err := vault.NewClient(config)
if err != nil {
srv.Close()
t.Fatal(err)
}
return u, client
}
func TestApprole_LoginMountPaths(t *testing.T) {
caURL, _ := testCAHelper(t)
config := vault.DefaultConfig()
config.Address = caURL.String()
client, _ := vault.NewClient(config)
tests := []struct {
name string
mountPath string
token string
}{
{
name: "ok default mount path",
mountPath: "",
token: "hvs.0000",
},
{
name: "ok explicit mount path",
mountPath: "approle",
token: "hvs.0000",
},
{
name: "ok custom mount path",
mountPath: "custom-approle",
token: "hvs.9999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
method, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(`{"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}`))
if err != nil {
t.Errorf("NewApproleAuthMethod() error = %v", err)
return
}
secret, err := client.Auth().Login(context.Background(), method)
if err != nil {
t.Errorf("Login() error = %v", err)
return
}
token, _ := secret.TokenID()
if token != tt.token {
t.Errorf("Token error got %v, expected %v", token, tt.token)
return
}
})
}
}
func TestApprole_NewApproleAuthMethod(t *testing.T) {
tests := []struct {
name string
mountPath string
raw string
wantErr bool
}{
{
"ok secret-id string",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000"}`,
false,
},
{
"ok secret-id string and wrapped",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`,
false,
},
{
"ok secret-id string and wrapped with custom mountPath",
"approle2",
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "isWrappedToken": true}`,
false,
},
{
"ok secret-id file",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`,
false,
},
{
"ok secret-id env",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
false,
},
{
"fail mandatory role-id",
"",
`{}`,
true,
},
{
"fail mandatory secret-id any",
"",
`{"RoleID": "0000-0000-0000-0000"}`,
true,
},
{
"fail multiple secret-id types id and env",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
true,
},
{
"fail multiple secret-id types id and file",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id"}`,
true,
},
{
"fail multiple secret-id types env and file",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
true,
},
{
"fail multiple secret-id types all",
"",
`{"RoleID": "0000-0000-0000-0000", "SecretID": "0000-0000-0000-0000", "SecretIDFile": "./secret-id", "SecretIDEnv": "VAULT_APPROLE_SECRETID"}`,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewApproleAuthMethod(tt.mountPath, json.RawMessage(tt.raw))
if (err != nil) != tt.wantErr {
t.Errorf("Approle.NewApproleAuthMethod() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -0,0 +1,49 @@
package kubernetes
import (
"encoding/json"
"errors"
"fmt"
"github.com/hashicorp/vault/api/auth/kubernetes"
)
// AuthOptions defines the configuration options added using the
// VaultOptions.AuthOptions field when AuthType is kubernetes
type AuthOptions struct {
Role string `json:"role,omitempty"`
TokenPath string `json:"tokenPath,omitempty"`
}
func NewKubernetesAuthMethod(mountPath string, options json.RawMessage) (*kubernetes.KubernetesAuth, error) {
var opts *AuthOptions
err := json.Unmarshal(options, &opts)
if err != nil {
return nil, fmt.Errorf("error decoding Kubernetes auth options: %w", err)
}
var kubernetesAuth *kubernetes.KubernetesAuth
var loginOptions []kubernetes.LoginOption
if mountPath != "" {
loginOptions = append(loginOptions, kubernetes.WithMountPath(mountPath))
}
if opts.TokenPath != "" {
loginOptions = append(loginOptions, kubernetes.WithServiceAccountTokenPath(opts.TokenPath))
}
if opts.Role == "" {
return nil, errors.New("you must set role")
}
kubernetesAuth, err = kubernetes.NewKubernetesAuth(
opts.Role,
loginOptions...,
)
if err != nil {
return nil, fmt.Errorf("unable to initialize Kubernetes auth method: %w", err)
}
return kubernetesAuth, nil
}

View file

@ -0,0 +1,149 @@
package kubernetes
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"path"
"path/filepath"
"runtime"
"testing"
vault "github.com/hashicorp/vault/api"
)
func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.RequestURI == "/v1/auth/kubernetes/login":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{
"auth": {
"client_token": "hvs.0000"
}
}`)
case r.RequestURI == "/v1/auth/custom-kubernetes/login":
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{
"auth": {
"client_token": "hvs.9999"
}
}`)
default:
w.WriteHeader(http.StatusNotFound)
fmt.Fprintf(w, `{"error":"not found"}`)
}
}))
t.Cleanup(func() {
srv.Close()
})
u, err := url.Parse(srv.URL)
if err != nil {
srv.Close()
t.Fatal(err)
}
config := vault.DefaultConfig()
config.Address = srv.URL
client, err := vault.NewClient(config)
if err != nil {
srv.Close()
t.Fatal(err)
}
return u, client
}
func TestApprole_LoginMountPaths(t *testing.T) {
caURL, _ := testCAHelper(t)
_, filename, _, _ := runtime.Caller(0)
tokenPath := filepath.Join(path.Dir(filename), "token")
config := vault.DefaultConfig()
config.Address = caURL.String()
client, _ := vault.NewClient(config)
tests := []struct {
name string
mountPath string
token string
}{
{
name: "ok default mount path",
mountPath: "",
token: "hvs.0000",
},
{
name: "ok explicit mount path",
mountPath: "kubernetes",
token: "hvs.0000",
},
{
name: "ok custom mount path",
mountPath: "custom-kubernetes",
token: "hvs.9999",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
method, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(`{"role": "SomeRoleName", "tokenPath": "`+tokenPath+`"}`))
if err != nil {
t.Errorf("NewApproleAuthMethod() error = %v", err)
return
}
secret, err := client.Auth().Login(context.Background(), method)
if err != nil {
t.Errorf("Login() error = %v", err)
return
}
token, _ := secret.TokenID()
if token != tt.token {
t.Errorf("Token error got %v, expected %v", token, tt.token)
return
}
})
}
}
func TestApprole_NewApproleAuthMethod(t *testing.T) {
_, filename, _, _ := runtime.Caller(0)
tokenPath := filepath.Join(path.Dir(filename), "token")
tests := []struct {
name string
mountPath string
raw string
wantErr bool
}{
{
"ok secret-id string",
"",
`{"role": "SomeRoleName", "tokenPath": "` + tokenPath + `"}`,
false,
},
{
"fail mandatory role",
"",
`{}`,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewKubernetesAuthMethod(tt.mountPath, json.RawMessage(tt.raw))
if (err != nil) != tt.wantErr {
t.Errorf("Kubernetes.NewKubernetesAuthMethod() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -0,0 +1 @@
token

View file

@ -15,9 +15,10 @@ import (
"time" "time"
"github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/cas/vaultcas/auth/approle"
"github.com/smallstep/certificates/cas/vaultcas/auth/kubernetes"
vault "github.com/hashicorp/vault/api" vault "github.com/hashicorp/vault/api"
auth "github.com/hashicorp/vault/api/auth/approle"
) )
func init() { func init() {
@ -29,15 +30,14 @@ func init() {
// VaultOptions defines the configuration options added using the // VaultOptions defines the configuration options added using the
// apiv1.Options.Config field. // apiv1.Options.Config field.
type VaultOptions struct { type VaultOptions struct {
PKI string `json:"pki,omitempty"` PKIMountPath string `json:"pkiMountPath,omitempty"`
PKIRoleDefault string `json:"pkiRoleDefault,omitempty"` PKIRoleDefault string `json:"pkiRoleDefault,omitempty"`
PKIRoleRSA string `json:"pkiRoleRSA,omitempty"` PKIRoleRSA string `json:"pkiRoleRSA,omitempty"`
PKIRoleEC string `json:"pkiRoleEC,omitempty"` PKIRoleEC string `json:"pkiRoleEC,omitempty"`
PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"` PKIRoleEd25519 string `json:"pkiRoleEd25519,omitempty"`
RoleID string `json:"roleID,omitempty"` AuthType string `json:"authType,omitempty"`
SecretID auth.SecretID `json:"secretID,omitempty"` AuthMountPath string `json:"authMountPath,omitempty"`
AppRole string `json:"appRole,omitempty"` AuthOptions json.RawMessage `json:"authOptions,omitempty"`
IsWrappingToken bool `json:"isWrappingToken,omitempty"`
} }
// VaultCAS implements a Certificate Authority Service using Hashicorp Vault. // VaultCAS implements a Certificate Authority Service using Hashicorp Vault.
@ -77,28 +77,22 @@ func New(ctx context.Context, opts apiv1.Options) (*VaultCAS, error) {
return nil, fmt.Errorf("unable to initialize vault client: %w", err) return nil, fmt.Errorf("unable to initialize vault client: %w", err)
} }
var appRoleAuth *auth.AppRoleAuth var method vault.AuthMethod
if vc.IsWrappingToken { switch vc.AuthType {
appRoleAuth, err = auth.NewAppRoleAuth( case "kubernetes":
vc.RoleID, method, err = kubernetes.NewKubernetesAuthMethod(vc.AuthMountPath, vc.AuthOptions)
&vc.SecretID, case "approle":
auth.WithWrappingToken(), method, err = approle.NewApproleAuthMethod(vc.AuthMountPath, vc.AuthOptions)
auth.WithMountPath(vc.AppRole), default:
) return nil, fmt.Errorf("unknown auth type: %s, only 'kubernetes' and 'approle' currently supported", vc.AuthType)
} else {
appRoleAuth, err = auth.NewAppRoleAuth(
vc.RoleID,
&vc.SecretID,
auth.WithMountPath(vc.AppRole),
)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to initialize AppRole auth method: %w", err) return nil, fmt.Errorf("unable to configure %s auth method: %w", vc.AuthType, err)
} }
authInfo, err := client.Auth().Login(ctx, appRoleAuth) authInfo, err := client.Auth().Login(ctx, method)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to login to AppRole auth method: %w", err) return nil, fmt.Errorf("unable to login to %s auth method: %w", vc.AuthType, err)
} }
if authInfo == nil { if authInfo == nil {
return nil, errors.New("no auth info was returned after login") return nil, errors.New("no auth info was returned after login")
@ -134,7 +128,7 @@ func (v *VaultCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv
// GetCertificateAuthority returns the root certificate of the certificate // GetCertificateAuthority returns the root certificate of the certificate
// authority using the configured fingerprint. // authority using the configured fingerprint.
func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) { func (v *VaultCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequest) (*apiv1.GetCertificateAuthorityResponse, error) {
secret, err := v.client.Logical().Read(v.config.PKI + "/cert/ca_chain") secret, err := v.client.Logical().Read(v.config.PKIMountPath + "/cert/ca_chain")
if err != nil { if err != nil {
return nil, fmt.Errorf("error reading ca chain: %w", err) return nil, fmt.Errorf("error reading ca chain: %w", err)
} }
@ -190,7 +184,7 @@ func (v *VaultCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv
vaultReq := map[string]interface{}{ vaultReq := map[string]interface{}{
"serial_number": formatSerialNumber(sn), "serial_number": formatSerialNumber(sn),
} }
_, err := v.client.Logical().Write(v.config.PKI+"/revoke/", vaultReq) _, err := v.client.Logical().Write(v.config.PKIMountPath+"/revoke/", vaultReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("error revoking certificate: %w", err) return nil, fmt.Errorf("error revoking certificate: %w", err)
} }
@ -224,7 +218,7 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.
"ttl": lifetime.Seconds(), "ttl": lifetime.Seconds(),
} }
secret, err := v.client.Logical().Write(v.config.PKI+"/sign/"+vaultPKIRole, vaultReq) secret, err := v.client.Logical().Write(v.config.PKIMountPath+"/sign/"+vaultPKIRole, vaultReq)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error signing certificate: %w", err) return nil, nil, fmt.Errorf("error signing certificate: %w", err)
} }
@ -247,21 +241,17 @@ func (v *VaultCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.
} }
func loadOptions(config json.RawMessage) (*VaultOptions, error) { func loadOptions(config json.RawMessage) (*VaultOptions, error) {
var vc *VaultOptions // setup default values
vc := VaultOptions{
PKIMountPath: "pki",
PKIRoleDefault: "default",
}
err := json.Unmarshal(config, &vc) err := json.Unmarshal(config, &vc)
if err != nil { if err != nil {
return nil, fmt.Errorf("error decoding vaultCAS config: %w", err) return nil, fmt.Errorf("error decoding vaultCAS config: %w", err)
} }
if vc.PKI == "" {
vc.PKI = "pki" // use default pki vault name
}
if vc.PKIRoleDefault == "" {
vc.PKIRoleDefault = "default" // use default pki role name
}
if vc.PKIRoleRSA == "" { if vc.PKIRoleRSA == "" {
vc.PKIRoleRSA = vc.PKIRoleDefault vc.PKIRoleRSA = vc.PKIRoleDefault
} }
@ -272,23 +262,7 @@ func loadOptions(config json.RawMessage) (*VaultOptions, error) {
vc.PKIRoleEd25519 = vc.PKIRoleDefault vc.PKIRoleEd25519 = vc.PKIRoleDefault
} }
if vc.RoleID == "" { return &vc, nil
return nil, errors.New("vaultCAS config options must define `roleID`")
}
if vc.SecretID.FromEnv == "" && vc.SecretID.FromFile == "" && vc.SecretID.FromString == "" {
return nil, errors.New("vaultCAS config options must define `secretID` object with one of `FromEnv`, `FromFile` or `FromString`")
}
if vc.PKI == "" {
vc.PKI = "pki" // use default pki vault name
}
if vc.AppRole == "" {
vc.AppRole = "auth/approle"
}
return vc, nil
} }
func parseCertificates(pemCert string) []*x509.Certificate { func parseCertificates(pemCert string) []*x509.Certificate {

View file

@ -14,7 +14,6 @@ import (
"time" "time"
vault "github.com/hashicorp/vault/api" vault "github.com/hashicorp/vault/api"
auth "github.com/hashicorp/vault/api/auth/approle"
"github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/cas/apiv1"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -99,7 +98,7 @@ func testCAHelper(t *testing.T) (*url.URL, *vault.Client) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.RequestURI == "/v1/auth/auth/approle/login": case r.RequestURI == "/v1/auth/approle/login":
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{ fmt.Fprintf(w, `{
"auth": { "auth": {
@ -183,11 +182,8 @@ func TestNew_register(t *testing.T) {
CertificateAuthority: caURL.String(), CertificateAuthority: caURL.String(),
CertificateAuthorityFingerprint: testRootFingerprint, CertificateAuthorityFingerprint: testRootFingerprint,
Config: json.RawMessage(`{ Config: json.RawMessage(`{
"PKI": "pki", "AuthType": "approle",
"PKIRoleDefault": "pki-role", "AuthOptions": {"RoleID":"roleID","SecretID":"secretID","IsWrappingToken":false}
"RoleID": "roleID",
"SecretID": {"FromString": "secretID"},
"IsWrappingToken": false
}`), }`),
}) })
@ -201,15 +197,11 @@ func TestVaultCAS_CreateCertificate(t *testing.T) {
_, client := testCAHelper(t) _, client := testCAHelper(t)
options := VaultOptions{ options := VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "rsa", PKIRoleRSA: "rsa",
PKIRoleEC: "ec", PKIRoleEC: "ec",
PKIRoleEd25519: "ed25519", PKIRoleEd25519: "ed25519",
RoleID: "roleID",
SecretID: auth.SecretID{FromString: "secretID"},
AppRole: "approle",
IsWrappingToken: false,
} }
type fields struct { type fields struct {
@ -291,7 +283,7 @@ func TestVaultCAS_GetCertificateAuthority(t *testing.T) {
} }
options := VaultOptions{ options := VaultOptions{
PKI: "pki", PKIMountPath: "pki",
} }
rootCert := parseCertificates(testRootCertificate)[0] rootCert := parseCertificates(testRootCertificate)[0]
@ -335,15 +327,11 @@ func TestVaultCAS_RevokeCertificate(t *testing.T) {
_, client := testCAHelper(t) _, client := testCAHelper(t)
options := VaultOptions{ options := VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "rsa", PKIRoleRSA: "rsa",
PKIRoleEC: "ec", PKIRoleEC: "ec",
PKIRoleEd25519: "ed25519", PKIRoleEd25519: "ed25519",
RoleID: "roleID",
SecretID: auth.SecretID{FromString: "secretID"},
AppRole: "approle",
IsWrappingToken: false,
} }
type fields struct { type fields struct {
@ -407,15 +395,11 @@ func TestVaultCAS_RenewCertificate(t *testing.T) {
_, client := testCAHelper(t) _, client := testCAHelper(t)
options := VaultOptions{ options := VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "rsa", PKIRoleRSA: "rsa",
PKIRoleEC: "ec", PKIRoleEC: "ec",
PKIRoleEd25519: "ed25519", PKIRoleEd25519: "ed25519",
RoleID: "roleID",
SecretID: auth.SecretID{FromString: "secretID"},
AppRole: "approle",
IsWrappingToken: false,
} }
type fields struct { type fields struct {
@ -464,202 +448,66 @@ func TestVaultCAS_loadOptions(t *testing.T) {
want *VaultOptions want *VaultOptions
wantErr bool wantErr bool
}{ }{
{
"ok mandatory with SecretID FromString",
`{"RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`,
&VaultOptions{
PKI: "pki",
PKIRoleDefault: "default",
PKIRoleRSA: "default",
PKIRoleEC: "default",
PKIRoleEd25519: "default",
RoleID: "roleID",
SecretID: auth.SecretID{FromString: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
},
false,
},
{
"ok mandatory with SecretID FromFile",
`{"RoleID": "roleID", "SecretID": {"FromFile": "secretID"}}`,
&VaultOptions{
PKI: "pki",
PKIRoleDefault: "default",
PKIRoleRSA: "default",
PKIRoleEC: "default",
PKIRoleEd25519: "default",
RoleID: "roleID",
SecretID: auth.SecretID{FromFile: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
},
false,
},
{
"ok mandatory with SecretID FromEnv",
`{"RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`,
&VaultOptions{
PKI: "pki",
PKIRoleDefault: "default",
PKIRoleRSA: "default",
PKIRoleEC: "default",
PKIRoleEd25519: "default",
RoleID: "roleID",
SecretID: auth.SecretID{FromEnv: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
},
false,
},
{ {
"ok mandatory PKIRole PKIRoleEd25519", "ok mandatory PKIRole PKIRoleEd25519",
`{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, `{"PKIRoleDefault": "role", "PKIRoleEd25519": "ed25519"}`,
&VaultOptions{ &VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "role", PKIRoleRSA: "role",
PKIRoleEC: "role", PKIRoleEC: "role",
PKIRoleEd25519: "ed25519", PKIRoleEd25519: "ed25519",
RoleID: "roleID",
SecretID: auth.SecretID{FromEnv: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
}, },
false, false,
}, },
{ {
"ok mandatory PKIRole PKIRoleEC", "ok mandatory PKIRole PKIRoleEC",
`{"PKIRoleDefault": "role", "PKIRoleEC": "ec" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, `{"PKIRoleDefault": "role", "PKIRoleEC": "ec"}`,
&VaultOptions{ &VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "role", PKIRoleRSA: "role",
PKIRoleEC: "ec", PKIRoleEC: "ec",
PKIRoleEd25519: "role", PKIRoleEd25519: "role",
RoleID: "roleID",
SecretID: auth.SecretID{FromEnv: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
}, },
false, false,
}, },
{ {
"ok mandatory PKIRole PKIRoleRSA", "ok mandatory PKIRole PKIRoleRSA",
`{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa" , "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa"}`,
&VaultOptions{ &VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "rsa", PKIRoleRSA: "rsa",
PKIRoleEC: "role", PKIRoleEC: "role",
PKIRoleEd25519: "role", PKIRoleEd25519: "role",
RoleID: "roleID",
SecretID: auth.SecretID{FromEnv: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
}, },
false, false,
}, },
{ {
"ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519", "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519",
`{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, `{"PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`,
&VaultOptions{ &VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "default", PKIRoleDefault: "default",
PKIRoleRSA: "rsa", PKIRoleRSA: "rsa",
PKIRoleEC: "ec", PKIRoleEC: "ec",
PKIRoleEd25519: "ed25519", PKIRoleEd25519: "ed25519",
RoleID: "roleID",
SecretID: auth.SecretID{FromEnv: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
}, },
false, false,
}, },
{ {
"ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault", "ok mandatory PKIRoleRSA PKIRoleEC PKIRoleEd25519 with useless PKIRoleDefault",
`{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519", "RoleID": "roleID", "SecretID": {"FromEnv": "secretID"}}`, `{"PKIRoleDefault": "role", "PKIRoleRSA": "rsa", "PKIRoleEC": "ec", "PKIRoleEd25519": "ed25519"}`,
&VaultOptions{ &VaultOptions{
PKI: "pki", PKIMountPath: "pki",
PKIRoleDefault: "role", PKIRoleDefault: "role",
PKIRoleRSA: "rsa", PKIRoleRSA: "rsa",
PKIRoleEC: "ec", PKIRoleEC: "ec",
PKIRoleEd25519: "ed25519", PKIRoleEd25519: "ed25519",
RoleID: "roleID",
SecretID: auth.SecretID{FromEnv: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: false,
}, },
false, false,
}, },
{
"ok mandatory with AppRole",
`{"AppRole": "test", "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`,
&VaultOptions{
PKI: "pki",
PKIRoleDefault: "default",
PKIRoleRSA: "default",
PKIRoleEC: "default",
PKIRoleEd25519: "default",
RoleID: "roleID",
SecretID: auth.SecretID{FromString: "secretID"},
AppRole: "test",
IsWrappingToken: false,
},
false,
},
{
"ok mandatory with IsWrappingToken",
`{"IsWrappingToken": true, "RoleID": "roleID", "SecretID": {"FromString": "secretID"}}`,
&VaultOptions{
PKI: "pki",
PKIRoleDefault: "default",
PKIRoleRSA: "default",
PKIRoleEC: "default",
PKIRoleEd25519: "default",
RoleID: "roleID",
SecretID: auth.SecretID{FromString: "secretID"},
AppRole: "auth/approle",
IsWrappingToken: true,
},
false,
},
{
"fail with SecretID FromFail",
`{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`,
nil,
true,
},
{
"fail with SecretID empty FromEnv",
`{"RoleID": "roleID", "SecretID": {"FromEnv": ""}}`,
nil,
true,
},
{
"fail with SecretID empty FromFile",
`{"RoleID": "roleID", "SecretID": {"FromFile": ""}}`,
nil,
true,
},
{
"fail with SecretID empty FromString",
`{"RoleID": "roleID", "SecretID": {"FromString": ""}}`,
nil,
true,
},
{
"fail mandatory with SecretID FromFail",
`{"RoleID": "roleID", "SecretID": {"FromFail": "secretID"}}`,
nil,
true,
},
{
"fail missing RoleID",
`{"SecretID": {"FromString": "secretID"}}`,
nil,
true,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"strconv" "strconv"
@ -56,6 +57,29 @@ type AuthDB interface {
Shutdown() error Shutdown() error
} }
type dbKey struct{}
// NewContext adds the given authority database to the context.
func NewContext(ctx context.Context, db AuthDB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// FromContext returns the current authority database from the given context.
func FromContext(ctx context.Context) (db AuthDB, ok bool) {
db, ok = ctx.Value(dbKey{}).(AuthDB)
return
}
// MustFromContext returns the current database from the given context. It
// will panic if it's not in the context.
func MustFromContext(ctx context.Context) AuthDB {
if db, ok := FromContext(ctx); !ok {
panic("authority database is not in the context")
} else {
return db
}
}
// CertificateStorer is an extension of AuthDB that allows to store // CertificateStorer is an extension of AuthDB that allows to store
// certificates. // certificates.
type CertificateStorer interface { type CertificateStorer interface {

1
go.mod
View file

@ -29,6 +29,7 @@ require (
github.com/googleapis/gax-go/v2 v2.1.1 github.com/googleapis/gax-go/v2 v2.1.1
github.com/hashicorp/vault/api v1.3.1 github.com/hashicorp/vault/api v1.3.1
github.com/hashicorp/vault/api/auth/approle v0.1.1 github.com/hashicorp/vault/api/auth/approle v0.1.1
github.com/hashicorp/vault/api/auth/kubernetes v0.1.0
github.com/jhump/protoreflect v1.9.0 // indirect github.com/jhump/protoreflect v1.9.0 // indirect
github.com/mattn/go-colorable v0.1.8 // indirect github.com/mattn/go-colorable v0.1.8 // indirect
github.com/mattn/go-isatty v0.0.13 // indirect github.com/mattn/go-isatty v0.0.13 // indirect

2
go.sum
View file

@ -449,6 +449,8 @@ github.com/hashicorp/vault/api v1.3.1 h1:pkDkcgTh47PRjY1NEFeofqR4W/HkNUi9qIakESO
github.com/hashicorp/vault/api v1.3.1/go.mod h1:QeJoWxMFt+MsuWcYhmwRLwKEXrjwAFFywzhptMsTIUw= github.com/hashicorp/vault/api v1.3.1/go.mod h1:QeJoWxMFt+MsuWcYhmwRLwKEXrjwAFFywzhptMsTIUw=
github.com/hashicorp/vault/api/auth/approle v0.1.1 h1:R5yA+xcNvw1ix6bDuWOaLOq2L4L77zDCVsethNw97xQ= github.com/hashicorp/vault/api/auth/approle v0.1.1 h1:R5yA+xcNvw1ix6bDuWOaLOq2L4L77zDCVsethNw97xQ=
github.com/hashicorp/vault/api/auth/approle v0.1.1/go.mod h1:mHOLgh//xDx4dpqXoq6tS8Ob0FoCFWLU2ibJ26Lfmag= github.com/hashicorp/vault/api/auth/approle v0.1.1/go.mod h1:mHOLgh//xDx4dpqXoq6tS8Ob0FoCFWLU2ibJ26Lfmag=
github.com/hashicorp/vault/api/auth/kubernetes v0.1.0 h1:6BtyahbF4aQp8gg3ww0A/oIoqzbhpNP1spXU3nHE0n0=
github.com/hashicorp/vault/api/auth/kubernetes v0.1.0/go.mod h1:Pdgk78uIs0mgDOLvc3a+h/vYIT9rznw2sz+ucuH9024=
github.com/hashicorp/vault/sdk v0.3.0 h1:kR3dpxNkhh/wr6ycaJYqp6AFT/i2xaftbfnwZduTKEY= github.com/hashicorp/vault/sdk v0.3.0 h1:kR3dpxNkhh/wr6ycaJYqp6AFT/i2xaftbfnwZduTKEY=
github.com/hashicorp/vault/sdk v0.3.0/go.mod h1:aZ3fNuL5VNydQk8GcLJ2TV8YCRVvyaakYkhZRoVuhj0= github.com/hashicorp/vault/sdk v0.3.0/go.mod h1:aZ3fNuL5VNydQk8GcLJ2TV8YCRVvyaakYkhZRoVuhj0=
github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M= github.com/hashicorp/yamux v0.0.0-20180604194846-3520598351bb h1:b5rjCoWHc7eqmAS4/qyk21ZsHyb6Mxv/jykxvNTkU4M=

View file

@ -38,8 +38,8 @@ type request struct {
Message []byte Message []byte
} }
// response is a SCEP server response. // Response is a SCEP server Response.
type response struct { type Response struct {
Operation string Operation string
CACertNum int CACertNum int
Data []byte Data []byte
@ -52,25 +52,48 @@ type handler struct {
auth *scep.Authority auth *scep.Authority
} }
// Route traffic and implement the Router interface.
//
// Deprecated: use scep.Route(r api.Router)
func (h *handler) Route(r api.Router) {
route(r, func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := scep.NewContext(r.Context(), h.auth)
next(w, r.WithContext(ctx))
}
})
}
// New returns a new SCEP API router. // New returns a new SCEP API router.
//
// Deprecated: use scep.Route(r api.Router)
func New(auth *scep.Authority) api.RouterHandler { func New(auth *scep.Authority) api.RouterHandler {
return &handler{ return &handler{auth: auth}
auth: auth,
}
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *handler) Route(r api.Router) { func Route(r api.Router) {
getLink := h.auth.GetLinkExplicit route(r, nil)
r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) }
r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get))
r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post)) func route(r api.Router, middleware func(next http.HandlerFunc) http.HandlerFunc) {
r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) getHandler := lookupProvisioner(Get)
postHandler := lookupProvisioner(Post)
// For backward compatibility.
if middleware != nil {
getHandler = middleware(getHandler)
postHandler = middleware(postHandler)
}
r.MethodFunc(http.MethodGet, "/{provisionerName}/*", getHandler)
r.MethodFunc(http.MethodGet, "/{provisionerName}", getHandler)
r.MethodFunc(http.MethodPost, "/{provisionerName}/*", postHandler)
r.MethodFunc(http.MethodPost, "/{provisionerName}", postHandler)
} }
// Get handles all SCEP GET requests // Get handles all SCEP GET requests
func (h *handler) Get(w http.ResponseWriter, r *http.Request) { func Get(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r) req, err := decodeRequest(r)
if err != nil { if err != nil {
fail(w, fmt.Errorf("invalid scep get request: %w", err)) fail(w, fmt.Errorf("invalid scep get request: %w", err))
@ -78,15 +101,15 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
var res response var res Response
switch req.Operation { switch req.Operation {
case opnGetCACert: case opnGetCACert:
res, err = h.GetCACert(ctx) res, err = GetCACert(ctx)
case opnGetCACaps: case opnGetCACaps:
res, err = h.GetCACaps(ctx) res, err = GetCACaps(ctx)
case opnPKIOperation: case opnPKIOperation:
res, err = h.PKIOperation(ctx, req) res, err = PKIOperation(ctx, req)
default: default:
err = fmt.Errorf("unknown operation: %s", req.Operation) err = fmt.Errorf("unknown operation: %s", req.Operation)
} }
@ -100,20 +123,17 @@ func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
} }
// Post handles all SCEP POST requests // Post handles all SCEP POST requests
func (h *handler) Post(w http.ResponseWriter, r *http.Request) { func Post(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r) req, err := decodeRequest(r)
if err != nil { if err != nil {
fail(w, fmt.Errorf("invalid scep post request: %w", err)) fail(w, fmt.Errorf("invalid scep post request: %w", err))
return return
} }
ctx := r.Context() var res Response
var res response
switch req.Operation { switch req.Operation {
case opnPKIOperation: case opnPKIOperation:
res, err = h.PKIOperation(ctx, req) res, err = PKIOperation(r.Context(), req)
default: default:
err = fmt.Errorf("unknown operation: %s", req.Operation) err = fmt.Errorf("unknown operation: %s", req.Operation)
} }
@ -127,7 +147,6 @@ func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
} }
func decodeRequest(r *http.Request) (request, error) { func decodeRequest(r *http.Request) (request, error) {
defer r.Body.Close() defer r.Body.Close()
method := r.Method method := r.Method
@ -179,9 +198,8 @@ func decodeRequest(r *http.Request) (request, error) {
// lookupProvisioner loads the provisioner associated with the request. // lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist. // Responds 404 if the provisioner does not exist.
func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "provisionerName") name := chi.URLParam(r, "provisionerName")
provisionerName, err := url.PathUnescape(name) provisionerName, err := url.PathUnescape(name)
if err != nil { if err != nil {
@ -189,7 +207,9 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
return return
} }
p, err := h.auth.LoadProvisionerByName(provisionerName) ctx := r.Context()
auth := scep.MustFromContext(ctx)
p, err := auth.LoadProvisionerByName(provisionerName)
if err != nil { if err != nil {
fail(w, err) fail(w, err)
return return
@ -201,25 +221,24 @@ func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
return return
} }
ctx := r.Context()
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
// GetCACert returns the CA certificates in a SCEP response // GetCACert returns the CA certificates in a SCEP response
func (h *handler) GetCACert(ctx context.Context) (response, error) { func GetCACert(ctx context.Context) (Response, error) {
auth := scep.MustFromContext(ctx)
certs, err := h.auth.GetCACertificates(ctx) certs, err := auth.GetCACertificates(ctx)
if err != nil { if err != nil {
return response{}, err return Response{}, err
} }
if len(certs) == 0 { if len(certs) == 0 {
return response{}, errors.New("missing CA cert") return Response{}, errors.New("missing CA cert")
} }
res := response{ res := Response{
Operation: opnGetCACert, Operation: opnGetCACert,
CACertNum: len(certs), CACertNum: len(certs),
} }
@ -232,7 +251,7 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) {
// not signed or encrypted data has to be returned. // not signed or encrypted data has to be returned.
data, err := microscep.DegenerateCertificates(certs) data, err := microscep.DegenerateCertificates(certs)
if err != nil { if err != nil {
return response{}, err return Response{}, err
} }
res.Data = data res.Data = data
} }
@ -241,11 +260,11 @@ func (h *handler) GetCACert(ctx context.Context) (response, error) {
} }
// GetCACaps returns the CA capabilities in a SCEP response // GetCACaps returns the CA capabilities in a SCEP response
func (h *handler) GetCACaps(ctx context.Context) (response, error) { func GetCACaps(ctx context.Context) (Response, error) {
auth := scep.MustFromContext(ctx)
caps := auth.GetCACaps(ctx)
caps := h.auth.GetCACaps(ctx) res := Response{
res := response{
Operation: opnGetCACaps, Operation: opnGetCACaps,
Data: formatCapabilities(caps), Data: formatCapabilities(caps),
} }
@ -254,13 +273,12 @@ func (h *handler) GetCACaps(ctx context.Context) (response, error) {
} }
// PKIOperation performs PKI operations and returns a SCEP response // PKIOperation performs PKI operations and returns a SCEP response
func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) { func PKIOperation(ctx context.Context, req request) (Response, error) {
// parse the message using microscep implementation // parse the message using microscep implementation
microMsg, err := microscep.ParsePKIMessage(req.Message) microMsg, err := microscep.ParsePKIMessage(req.Message)
if err != nil { if err != nil {
// return the error, because we can't use the msg for creating a CertRep // return the error, because we can't use the msg for creating a CertRep
return response{}, err return Response{}, err
} }
// this is essentially doing the same as microscep.ParsePKIMessage, but // this is essentially doing the same as microscep.ParsePKIMessage, but
@ -268,7 +286,7 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
// wrapper for the microscep implementation. // wrapper for the microscep implementation.
p7, err := pkcs7.Parse(microMsg.Raw) p7, err := pkcs7.Parse(microMsg.Raw)
if err != nil { if err != nil {
return response{}, err return Response{}, err
} }
// copy over properties to our internal PKIMessage // copy over properties to our internal PKIMessage
@ -280,8 +298,9 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
P7: p7, P7: p7,
} }
if err := h.auth.DecryptPKIEnvelope(ctx, msg); err != nil { auth := scep.MustFromContext(ctx)
return response{}, err if err := auth.DecryptPKIEnvelope(ctx, msg); err != nil {
return Response{}, err
} }
// NOTE: at this point we have sufficient information for returning nicely signed CertReps // NOTE: at this point we have sufficient information for returning nicely signed CertReps
@ -293,13 +312,13 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
// a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients. // a certificate exists; then it will use RenewalReq. Adding the challenge check here may be a small breaking change for clients.
// We'll have to see how it works out. // We'll have to see how it works out.
if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq { if msg.MessageType == microscep.PKCSReq || msg.MessageType == microscep.RenewalReq {
challengeMatches, err := h.auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword) challengeMatches, err := auth.MatchChallengePassword(ctx, msg.CSRReqMessage.ChallengePassword)
if err != nil { if err != nil {
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password")) return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("error when checking password"))
} }
if !challengeMatches { if !challengeMatches {
// TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too. // TODO: can this be returned safely to the client? In the end, if the password was correct, that gains a bit of info too.
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided")) return createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.New("wrong password provided"))
} }
} }
@ -311,12 +330,12 @@ func (h *handler) PKIOperation(ctx context.Context, req request) (response, erro
// Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification // Authentication by the (self-signed) certificate with an optional challenge is required; supporting renewals incl. verification
// of the client cert is not. // of the client cert is not.
certRep, err := h.auth.SignCSR(ctx, csr, msg) certRep, err := auth.SignCSR(ctx, csr, msg)
if err != nil { if err != nil {
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) return createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
} }
res := response{ res := Response{
Operation: opnPKIOperation, Operation: opnPKIOperation,
Data: certRep.Raw, Data: certRep.Raw,
Certificate: certRep.Certificate, Certificate: certRep.Certificate,
@ -330,7 +349,7 @@ func formatCapabilities(caps []string) []byte {
} }
// writeResponse writes a SCEP response back to the SCEP client. // writeResponse writes a SCEP response back to the SCEP client.
func writeResponse(w http.ResponseWriter, res response) { func writeResponse(w http.ResponseWriter, res Response) {
if res.Error != nil { if res.Error != nil {
log.Error(w, res.Error) log.Error(w, res.Error)
@ -350,19 +369,20 @@ func fail(w http.ResponseWriter, err error) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} }
func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { func createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (Response, error) {
certRepMsg, err := h.auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) auth := scep.MustFromContext(ctx)
certRepMsg, err := auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
if err != nil { if err != nil {
return response{}, err return Response{}, err
} }
return response{ return Response{
Operation: opnPKIOperation, Operation: opnPKIOperation,
Data: certRepMsg.Raw, Data: certRepMsg.Raw,
Error: failError, Error: failError,
}, nil }, nil
} }
func contentHeader(r response) string { func contentHeader(r Response) string {
switch r.Operation { switch r.Operation {
default: default:
return "text/plain" return "text/plain"

View file

@ -27,6 +27,29 @@ type Authority struct {
signAuth SignAuthority signAuth SignAuthority
} }
type authorityKey struct{}
// NewContext adds the given authority to the context.
func NewContext(ctx context.Context, a *Authority) context.Context {
return context.WithValue(ctx, authorityKey{}, a)
}
// FromContext returns the current authority from the given context.
func FromContext(ctx context.Context) (a *Authority, ok bool) {
a, ok = ctx.Value(authorityKey{}).(*Authority)
return
}
// MustFromContext returns the current authority from the given context. It will
// panic if the authority is not in the context.
func MustFromContext(ctx context.Context) *Authority {
if a, ok := FromContext(ctx); !ok {
panic("scep authority is not in the context")
} else {
return a
}
}
// AuthorityOptions required to create a new SCEP Authority. // AuthorityOptions required to create a new SCEP Authority.
type AuthorityOptions struct { type AuthorityOptions struct {
// Service provides the certificate chain, the signer and the decrypter to the Authority // Service provides the certificate chain, the signer and the decrypter to the Authority
@ -163,7 +186,6 @@ func (a *Authority) GetCACertificates(ctx context.Context) ([]*x509.Certificate,
// DecryptPKIEnvelope decrypts an enveloped message // DecryptPKIEnvelope decrypts an enveloped message
func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error { func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) error {
p7c, err := pkcs7.Parse(msg.P7.Content) p7c, err := pkcs7.Parse(msg.P7.Content)
if err != nil { if err != nil {
return fmt.Errorf("error parsing pkcs7 content: %w", err) return fmt.Errorf("error parsing pkcs7 content: %w", err)
@ -210,7 +232,6 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err
// SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials // SignCSR creates an x509.Certificate based on a CSR template and Cert Authority credentials
// returns a new PKIMessage with CertRep data // returns a new PKIMessage with CertRep data
func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) { func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, msg *PKIMessage) (*PKIMessage, error) {
// TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate // TODO: intermediate storage of the request? In SCEP it's possible to request a csr/certificate
// to be signed, which can be performed asynchronously / out-of-band. In that case a client can // to be signed, which can be performed asynchronously / out-of-band. In that case a client can
// poll for the status. It seems to be similar as what can happen in ACME, so might want to model // poll for the status. It seems to be similar as what can happen in ACME, so might want to model
@ -432,7 +453,6 @@ func (a *Authority) CreateFailureResponse(ctx context.Context, csr *x509.Certifi
// MatchChallengePassword verifies a SCEP challenge password // MatchChallengePassword verifies a SCEP challenge password
func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) { func (a *Authority) MatchChallengePassword(ctx context.Context, password string) (bool, error) {
p, err := provisionerFromContext(ctx) p, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
return false, err return false, err