Fix panic in acme/api tests.

This commit is contained in:
Mariano Cano 2022-05-02 17:35:35 -07:00
parent d1f75f1720
commit 6f9d847bc6
10 changed files with 333 additions and 393 deletions

View file

@ -296,10 +296,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID} acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
@ -315,9 +314,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetOrdersByAccountID(w, req) GetOrdersByAccountID(w, req)
res := w.Result() res := w.Result()
@ -363,6 +362,7 @@ func TestHandler_NewAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -371,6 +371,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, nil) ctx := context.WithValue(context.Background(), payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -379,6 +380,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to "+ err: acme.NewError(acme.ErrorMalformedType, "failed to "+
@ -393,6 +395,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -405,8 +408,9 @@ func TestHandler_NewAccount(t *testing.T) {
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -418,9 +422,10 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"), err: acme.NewErrorISE("jwk expected in request context"),
@ -432,10 +437,11 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, jwkContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"), err: acme.NewErrorISE("jwk expected in request context"),
@ -454,9 +460,9 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
@ -471,7 +477,7 @@ func TestHandler_NewAccount(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -510,9 +516,9 @@ func TestHandler_NewAccount(t *testing.T) {
} }
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, scepProvisioner)
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
@ -551,8 +557,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
eak := &acme.ExternalAccountKey{ eak := &acme.ExternalAccountKey{
ID: "eakID", ID: "eakID",
@ -599,8 +604,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -635,11 +639,11 @@ func TestHandler_NewAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
acc: acc, acc: acc,
statusCode: 200, statusCode: 200,
@ -664,8 +668,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = false prov.RequireEAB = false
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -719,8 +722,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -759,9 +761,9 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
NewAccount(w, req) NewAccount(w, req)
res := w.Result() res := w.Result()
@ -814,6 +816,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -822,6 +825,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, nil) ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -830,6 +834,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -839,6 +844,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -848,6 +854,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
@ -862,6 +869,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -894,10 +902,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -914,11 +921,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
uar := &UpdateAccountRequest{} uar := &UpdateAccountRequest{}
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
} }
@ -929,10 +936,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -946,11 +952,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
}, },
"ok/post-as-get": func(t *testing.T) test { "ok/post-as-get": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
} }
@ -959,9 +965,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetOrUpdateAccount(w, req) GetOrUpdateAccount(w, req)
res := w.Result() res := w.Result()

View file

@ -98,8 +98,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newACMEProv(t) prov := newACMEProv(t)
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
@ -143,8 +142,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
return test{ return test{
@ -198,8 +196,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
} }
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, scepProvisioner)
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
return test{ return test{
ctx: ctx, ctx: ctx,
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
@ -218,8 +215,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
@ -264,8 +260,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
@ -310,8 +305,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -358,8 +352,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -408,8 +401,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -458,8 +450,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -506,8 +497,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
boundAt := time.Now().Add(1 * time.Second) boundAt := time.Now().Add(1 * time.Second)
@ -565,8 +555,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -623,8 +612,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -678,8 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -734,8 +721,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, nil) ctx := context.WithValue(context.Background(), jwkContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -762,10 +748,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{ ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
// db: tc.db, got, err := validateExternalAccountBinding(ctx, tc.nar)
// }
got, err := validateExternalAccountBinding(tc.ctx, tc.nar)
wantErr := tc.err != nil wantErr := tc.err != nil
gotErr := err != nil gotErr := err != nil
if wantErr != gotErr { if wantErr != gotErr {

View file

@ -223,7 +223,6 @@ func (d *Directory) ToLog() (interface{}, error) {
func GetDirectory(w http.ResponseWriter, r *http.Request) { func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
fmt.Println(acmeProv, err)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return

View file

@ -3,6 +3,7 @@ package api
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
@ -24,6 +25,29 @@ import (
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return a
}
}
func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetNonce(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -52,7 +76,7 @@ func TestHandler_GetNonce(t *testing.T) {
} }
func TestHandler_GetDirectory(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) {
linker := NewLinker("ca.smallstep.com", "acme") linker := acme.NewLinker("ca.smallstep.com", "acme")
_ = linker _ = linker
type test struct { type test struct {
ctx context.Context ctx context.Context
@ -62,13 +86,10 @@ func TestHandler_GetDirectory(t *testing.T) {
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test { "fail/no-provisioner": func(t *testing.T) test {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
ctx: ctx, ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), err: acme.NewErrorISE("provisioner is not in context"),
} }
}, },
"fail/different-provisioner": func(t *testing.T) test { "fail/different-provisioner": func(t *testing.T) test {
@ -76,9 +97,7 @@ func TestHandler_GetDirectory(t *testing.T) {
Type: "SCEP", Type: "SCEP",
Name: "test@scep-<test>provisioner.com", Name: "test@scep-<test>provisioner.com",
} }
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -89,8 +108,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
expDir := Directory{ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -109,8 +127,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
expDir := Directory{ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -131,9 +148,9 @@ func TestHandler_GetDirectory(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{linker: linker} ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetDirectory(w, req) GetDirectory(w, req)
res := w.Result() res := w.Result()
@ -220,7 +237,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
@ -286,10 +303,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
@ -305,9 +321,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetAuthorization(w, req) GetAuthorization(w, req)
res := w.Result() res := w.Result()
@ -448,9 +464,9 @@ func TestHandler_GetCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db} ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetCertificate(w, req) GetCertificate(w, req)
res := w.Result() res := w.Result()
@ -492,7 +508,7 @@ func TestHandler_GetChallenge(t *testing.T) {
type test struct { type test struct {
db acme.DB db acme.DB
vco *acme.ValidateChallengeOptions vc acme.Client
ctx context.Context ctx context.Context
statusCode int statusCode int
ch *acme.Challenge ch *acme.Challenge
@ -501,6 +517,7 @@ func TestHandler_GetChallenge(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -508,6 +525,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), accContextKey, nil), ctx: context.WithValue(context.Background(), accContextKey, nil),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -517,6 +535,7 @@ func TestHandler_GetChallenge(t *testing.T) {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -524,10 +543,11 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -535,7 +555,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/db.GetChallenge-error": func(t *testing.T) test { "fail/db.GetChallenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -554,7 +574,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -573,7 +593,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/no-jwk": func(t *testing.T) test { "fail/no-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -592,7 +612,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-jwk": func(t *testing.T) test { "fail/nil-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, jwkContextKey, nil)
@ -612,7 +632,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/validate-challenge-error": func(t *testing.T) test { "fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@ -640,8 +660,8 @@ func TestHandler_GetChallenge(t *testing.T) {
return acme.NewErrorISE("force") return acme.NewErrorISE("force")
}, },
}, },
vco: &acme.ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(string) (*http.Response, error) { get: func(string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -652,14 +672,13 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
_pub := _jwk.Public() _pub := _jwk.Public()
ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, jwkContextKey, &_pub)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -691,8 +710,8 @@ func TestHandler_GetChallenge(t *testing.T) {
URL: u, URL: u,
Error: acme.NewError(acme.ErrorConnectionType, "force"), Error: acme.NewError(acme.ErrorConnectionType, "force"),
}, },
vco: &acme.ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(string) (*http.Response, error) { get: func(string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -704,9 +723,9 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetChallenge(w, req) GetChallenge(w, req)
res := w.Result() res := w.Result()

View file

@ -9,7 +9,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/go-chi/chi"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
@ -63,7 +62,12 @@ func addDirLink(next nextHTTP) nextHTTP {
// application/jose+json. // application/jose+json.
func verifyContentType(next nextHTTP) nextHTTP { func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
p := acme.MustProvisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil {
render.Error(w, err)
return
}
u := &url.URL{ u := &url.URL{
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""), Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
} }
@ -260,32 +264,6 @@ func extractJWK(next nextHTTP) nextHTTP {
} }
} }
// lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist.
func lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := mustAuthority(r.Context()).LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
next(w, r.WithContext(ctx))
}
}
// checkPrerequisites checks if all prerequisites for serving ACME // checkPrerequisites checks if all prerequisites for serving ACME
// are met by the CA configuration. // are met by the CA configuration.
func checkPrerequisites(next nextHTTP) nextHTTP { func checkPrerequisites(next nextHTTP) nextHTTP {
@ -446,16 +424,12 @@ type ContextKey string
const ( const (
// accContextKey account key // accContextKey account key
accContextKey = ContextKey("acc") accContextKey = ContextKey("acc")
// baseURLContextKey baseURL key
baseURLContextKey = ContextKey("baseURL")
// jwsContextKey jws key // jwsContextKey jws key
jwsContextKey = ContextKey("jws") jwsContextKey = ContextKey("jws")
// jwkContextKey jwk key // jwkContextKey jwk key
jwkContextKey = ContextKey("jwk") jwkContextKey = ContextKey("jwk")
// payloadContextKey payload key // payloadContextKey payload key
payloadContextKey = ContextKey("payload") payloadContextKey = ContextKey("payload")
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
) )
// accountFromContext searches the context for an ACME account. Returns the // accountFromContext searches the context for an ACME account. Returns the
@ -468,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
return val, nil return val, nil
} }
// baseURLFromContext returns the baseURL if one is stored in the context.
func baseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// jwkFromContext searches the context for a JWK. Returns the JWK or an error. // jwkFromContext searches the context for a JWK. Returns the JWK or an error.
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
@ -495,14 +460,29 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
return val, nil return val, nil
} }
// provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
p, ok := acme.ProvisionerFromContext(ctx)
if !ok || p == nil {
return nil, acme.NewErrorISE("provisioner expected in request context")
}
return p, nil
}
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
// pointer to an ACME provisioner or an error. // pointer to an ACME provisioner or an error.
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
p, ok := acme.MustProvisionerFromContext(ctx).(*provisioner.ACME) p, err := provisionerFromContext(ctx)
if err != nil {
return nil, err
}
ap, ok := p.(*provisioner.ACME)
if !ok { if !ok {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
} }
return p, nil
return ap, nil
} }
// payloadFromContext searches the context for a payload. Returns the payload // payloadFromContext searches the context for a payload. Returns the payload

View file

@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
} }
func Test_baseURLFromRequest(t *testing.T) { func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
tests := []struct { for _, a := range args {
name string switch v := a.(type) {
targetURL string case acme.DB:
expectedResult *url.URL ctx = acme.NewDatabaseContext(ctx, v)
requestPreparer func(*http.Request) case acme.Linker:
}{ ctx = acme.NewLinkerContext(ctx, v)
{ case acme.PrerequisitesChecker:
"HTTPS host pass-through failed.", ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
"https://my.dummy.host",
&url.URL{Scheme: "https", Host: "my.dummy.host"},
nil,
},
{
"Port pass-through failed",
"https://host.with.port:8080",
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
nil,
},
{
"Explicit host from Request.Host was not used.",
"https://some.target.host:8080",
&url.URL{Scheme: "https", Host: "proxied.host"},
func(r *http.Request) {
r.Host = "proxied.host"
},
},
{
"Missing Request.Host value did not result in empty string result.",
"https://some.host",
nil,
func(r *http.Request) {
r.Host = ""
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest("GET", tc.targetURL, nil)
if tc.requestPreparer != nil {
tc.requestPreparer(request)
}
result := getBaseURLFromRequest(request)
if result == nil || tc.expectedResult == nil {
assert.Equals(t, result, tc.expectedResult)
} else if result.String() != tc.expectedResult.String() {
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
}
})
} }
} }
return ctx
func TestHandler_baseURLFromRequest(t *testing.T) {
// h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
next := func(w http.ResponseWriter, r *http.Request) {
bu := baseURLFromContext(r.Context())
if assert.NotNil(t, bu) {
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
assert.Equals(t, bu.Scheme, "https")
}
}
baseURLFromRequest(next)(w, req)
req = httptest.NewRequest("GET", "/foo", nil)
req.Host = ""
next = func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, baseURLFromContext(r.Context()), nil)
}
baseURLFromRequest(next)(w, req)
} }
func TestHandler_addNonce(t *testing.T) { func TestHandler_addNonce(t *testing.T) {
@ -139,8 +74,8 @@ func TestHandler_addNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db} ctx := newBaseContext(context.Background(), tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
addNonce(testNext)(w, req) addNonce(testNext)(w, req)
res := w.Result() res := w.Result()
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
link string link string
linker Linker
statusCode int statusCode int
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
return test{ return test{
linker: NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
statusCode: 200, statusCode: 200,
@ -195,7 +128,6 @@ func TestHandler_addDirLink(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{linker: tc.linker}
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct { type test struct {
h Handler
ctx context.Context ctx context.Context
contentType string contentType string
err *acme.Error err *acme.Error
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/provisioner-not-set": func(t *testing.T) test { "fail/provisioner-not-set": func(t *testing.T) test {
return test{ return test{
h: Handler{
// linker: NewLinker("dns", "acme"),
},
url: u, url: u,
ctx: context.Background(), ctx: context.Background(),
contentType: "foo", contentType: "foo",
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"fail/general-bad-content-type": func(t *testing.T) test { "fail/general-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{
// linker: NewLinker("dns", "acme"),
},
url: u, url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"fail/certificate-bad-content-type": func(t *testing.T) test { "fail/certificate-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/pkix-cert": func(t *testing.T) test { "ok/certificate/pkix-cert": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkix-cert", contentType: "application/pkix-cert",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/jose+json": func(t *testing.T) test { "ok/certificate/jose+json": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/pkcs7-mime": func(t *testing.T) test { "ok/certificate/pkcs7-mime": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkcs7-mime", contentType: "application/pkcs7-mime",
statusCode: 200, statusCode: 200,
} }
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
type test struct { type test struct {
linker Linker linker acme.Linker
db acme.DB db acme.DB
ctx context.Context ctx context.Context
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz")) _jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _parsed) ctx = context.WithValue(ctx, jwsContextKey, _parsed)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
} }
}, },
"fail/account-not-found": func(t *testing.T) test { "fail/account-not-found": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
} }
}, },
"fail/GetAccount-error": func(t *testing.T) test { "fail/GetAccount-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk} acc := &acme.Account{Status: "valid", Key: jwk}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -881,9 +791,9 @@ func TestHandler_lookupJWK(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: tc.linker} ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
lookupJWK(tc.next)(w, req) lookupJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
} }
}, },
"fail/GetAccountByKey-error": func(t *testing.T) test { "fail/GetAccountByKey-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid"} acc := &acme.Account{Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
} }
}, },
"ok/no-account": func(t *testing.T) test { "ok/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1077,9 +991,9 @@ func TestHandler_extractJWK(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db} ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
extractJWK(tc.next)(w, req) extractJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, nil), ctx: context.WithValue(context.Background(), jwsContextKey, nil),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
"fail/no-signature": func(t *testing.T) test { "fail/no-signature": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
@ -1444,9 +1365,9 @@ func TestHandler_validateJWS(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db} ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
validateJWS(tc.next)(w, req) validateJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
u := "https://ca.smallstep.com/acme/account" u := "https://ca.smallstep.com/acme/account"
type test struct { type test struct {
db acme.DB db acme.DB
linker Linker linker acme.Linker
statusCode int statusCode int
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
assert.Equals(t, kid, pub.KeyID) assert.Equals(t, kid, pub.KeyID)
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
linker: NewLinker("test.ca.smallstep.com", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
@ -1628,9 +1548,9 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{db: tc.db, linker: tc.linker} ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
extractOrLookupJWK(tc.next)(w, req) extractOrLookupJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
u := fmt.Sprintf("%s/acme/%s/account/1234", u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName) baseURL, provName)
type test struct { type test struct {
linker Linker linker acme.Linker
ctx context.Context ctx context.Context
prerequisitesChecker func(context.Context) (bool, error) prerequisitesChecker func(context.Context) (bool, error)
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/error": func(t *testing.T) test { "fail/error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
}, },
"fail/prerequisites-nok": func(t *testing.T) test { "fail/prerequisites-nok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {

View file

@ -72,13 +72,17 @@ func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -189,13 +193,17 @@ func GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
@ -228,13 +236,17 @@ func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)

View file

@ -276,15 +276,17 @@ func TestHandler_GetOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -294,6 +296,7 @@ func TestHandler_GetOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -301,9 +304,10 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), nil)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -311,7 +315,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/db.GetOrder-error": func(t *testing.T) test { "fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -325,7 +329,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -341,7 +345,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/provisioner-id-mismatch": func(t *testing.T) test { "fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -357,7 +361,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/order-update-error": func(t *testing.T) test { "fail/order-update-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -381,10 +385,9 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
@ -421,9 +424,9 @@ func TestHandler_GetOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
GetOrder(w, req) GetOrder(w, req)
res := w.Result() res := w.Result()
@ -636,8 +639,8 @@ func TestHandler_newAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
// h := &Handler{db: tc.db} ctx := newBaseContext(context.Background(), tc.db)
if err := newAuthorization(context.Background(), tc.az); err != nil { if err := newAuthorization(ctx, tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
@ -677,15 +680,17 @@ func TestHandler_NewOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -695,6 +700,7 @@ func TestHandler_NewOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -702,9 +708,10 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -713,8 +720,9 @@ func TestHandler_NewOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -722,10 +730,11 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"), err: acme.NewErrorISE("paylod does not exist"),
@ -733,10 +742,11 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
@ -747,10 +757,11 @@ func TestHandler_NewOrder(t *testing.T) {
fr := &NewOrderRequest{} fr := &NewOrderRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
@ -765,7 +776,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
@ -793,7 +804,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var ( var (
@ -863,10 +874,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3, ch4 **acme.Challenge ch1, ch2, ch3, ch4 **acme.Challenge
az1ID, az2ID *string az1ID, az2ID *string
@ -978,10 +988,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1070,10 +1079,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1161,10 +1169,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1253,10 +1260,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1334,9 +1340,9 @@ func TestHandler_NewOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
NewOrder(w, req) NewOrder(w, req)
res := w.Result() res := w.Result()
@ -1371,6 +1377,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
func TestHandler_FinalizeOrder(t *testing.T) { func TestHandler_FinalizeOrder(t *testing.T) {
mockMustAuthority(t, &mockCA{})
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@ -1429,15 +1436,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -1447,6 +1456,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -1454,9 +1464,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -1465,8 +1476,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -1474,10 +1486,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"), err: acme.NewErrorISE("paylod does not exist"),
@ -1485,10 +1498,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
@ -1499,10 +1513,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
fr := &FinalizeRequest{} fr := &FinalizeRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
@ -1511,7 +1526,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/db.GetOrder-error": func(t *testing.T) test { "fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1526,7 +1541,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1543,7 +1558,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/provisioner-id-mismatch": func(t *testing.T) test { "fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1560,7 +1575,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/order-finalize-error": func(t *testing.T) test { "fail/order-finalize-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1585,10 +1600,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -1624,9 +1638,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
FinalizeOrder(w, req) FinalizeOrder(w, req)
res := w.Result() res := w.Result()

View file

@ -30,7 +30,6 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx) db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx) linker := acme.MustLinkerFromContext(ctx)
prov := acme.MustProvisionerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
@ -38,6 +37,12 @@ func RevokeCert(w http.ResponseWriter, r *http.Request) {
return return
} }
prov, err := provisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)

View file

@ -511,6 +511,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
ctx := context.Background() ctx := context.Background()
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -519,6 +520,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, nil) ctx := context.WithValue(context.Background(), jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -527,6 +529,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-provisioner": func(t *testing.T) test { "fail/no-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -534,8 +537,9 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, nil) ctx = acme.NewProvisionerContext(ctx, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -543,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -552,9 +557,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -563,9 +569,10 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/unmarshal-payload": func(t *testing.T) test { "fail/unmarshal-payload": func(t *testing.T) test {
malformedPayload := []byte(`{"payload":malformed?}`) malformedPayload := []byte(`{"payload":malformed?}`)
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("error unmarshaling payload"), err: acme.NewErrorISE("error unmarshaling payload"),
@ -577,10 +584,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: &acme.Error{ err: &acme.Error{
@ -596,10 +604,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
emptyPayloadBytes, err := json.Marshal(emptyPayload) emptyPayloadBytes, err := json.Marshal(emptyPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: &acme.Error{ err: &acme.Error{
@ -610,7 +619,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/db.GetCertificateBySerial": func(t *testing.T) test { "fail/db.GetCertificateBySerial": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -628,7 +637,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/different-certificate-contents": func(t *testing.T) test { "fail/different-certificate-contents": func(t *testing.T) test {
aDifferentCert, _, err := generateCertKeyPair() aDifferentCert, _, err := generateCertKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -647,7 +656,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -666,7 +675,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
@ -687,11 +696,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -717,11 +725,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/account-not-authorized": func(t *testing.T) test { "fail/account-not-authorized": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -771,10 +778,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -798,11 +804,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/certificate-revoked-check-fails": func(t *testing.T) test { "fail/certificate-revoked-check-fails": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -832,7 +837,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/certificate-already-revoked": func(t *testing.T) test { "fail/certificate-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -870,7 +875,7 @@ func TestHandler_RevokeCert(t *testing.T) {
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -908,7 +913,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
} }
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -940,7 +945,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/ca.Revoke": func(t *testing.T) test { "fail/ca.Revoke": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -972,7 +977,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/ca.Revoke-already-revoked": func(t *testing.T) test { "fail/ca.Revoke-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -1003,11 +1008,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"ok/using-account-key": func(t *testing.T) test { "ok/using-account-key": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1031,10 +1035,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(jwsBytes)) jws, err := jose.ParseJWS(string(jwsBytes))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1057,9 +1060,10 @@ func TestHandler_RevokeCert(t *testing.T) {
for name, setup := range tests { for name, setup := range tests {
tc := setup(t) tc := setup(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
mockMustAuthority(t, tc.ca)
req := httptest.NewRequest("POST", revokeURL, nil) req := httptest.NewRequest("POST", revokeURL, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
RevokeCert(w, req) RevokeCert(w, req)
res := w.Result() res := w.Result()