diff --git a/acme/api/account.go b/acme/api/account.go index c06c034a..30d406e4 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -180,8 +180,8 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { } } -// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. -func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { +// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. +func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/account_test.go b/acme/api/account_test.go index bdd61c59..d94819c7 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -12,7 +12,6 @@ import ( "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" @@ -53,7 +52,7 @@ func TestNewAccountRequestValidate(t *testing.T) { OnlyReturnExisting: true, Contact: []string{"foo", "bar"}, }, - err: acme.MalformedErr(errors.Errorf("incompatible input; onlyReturnExisting must be alone")), + err: acme.NewError(acme.ErrorMalformedType, "incompatible input; onlyReturnExisting must be alone"), } }, "fail/bad-contact": func(t *testing.T) test { @@ -61,7 +60,7 @@ func TestNewAccountRequestValidate(t *testing.T) { nar: &NewAccountRequest{ Contact: []string{"foo", ""}, }, - err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "ok": func(t *testing.T) test { @@ -109,8 +108,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) { Contact: []string{"foo", "bar"}, Status: "foo", }, - err: acme.MalformedErr(errors.Errorf("incompatible input; " + - "contact and status updates are mutually exclusive")), + err: acme.NewError(acme.ErrorMalformedType, "incompatible input; "+ + "contact and status updates are mutually exclusive"), } }, "fail/bad-contact": func(t *testing.T) test { @@ -118,7 +117,7 @@ func TestUpdateAccountRequestValidate(t *testing.T) { uar: &UpdateAccountRequest{ Contact: []string{"foo", ""}, }, - err: acme.MalformedErr(errors.Errorf("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/bad-status": func(t *testing.T) test { @@ -126,8 +125,8 @@ func TestUpdateAccountRequestValidate(t *testing.T) { uar: &UpdateAccountRequest{ Status: "foo", }, - err: acme.MalformedErr(errors.Errorf("cannot update account " + - "status to foo, only deactivated")), + err: acme.NewError(acme.ErrorMalformedType, "cannot update account "+ + "status to foo, only deactivated"), } }, "ok/contact": func(t *testing.T) test { @@ -168,13 +167,12 @@ func TestUpdateAccountRequestValidate(t *testing.T) { } } -func TestHandlerGetOrdersByAccount(t *testing.T) { +func TestHandler_GetOrdersByAccountID(t *testing.T) { oids := []string{ "https://ca.smallstep.com/acme/order/foo", "https://ca.smallstep.com/acme/order/bar", } accID := "account-id" - prov := newProv() // Request with chi context chiCtx := chi.NewRouteContext() @@ -182,67 +180,59 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s/orders", accID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "foo"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account ID does not match url param")), + err: acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"), } }, "fail/getOrdersByAccount-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) + db: &acme.MockDB{ + MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { assert.Equals(t, id, acc.ID) return oids, nil }, @@ -255,11 +245,11 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetOrdersByAccount(w, req) + h.GetOrdersByAccountID(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -268,15 +258,14 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(oids) @@ -288,7 +277,7 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { } } -func TestHandlerNewAccount(t *testing.T) { +func TestHandler_NewAccount(t *testing.T) { accID := "accountID" acc := acme.Account{ ID: accID, @@ -300,35 +289,34 @@ func TestHandlerNewAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to "+ + "unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -337,12 +325,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, "fail/no-existing-account": func(t *testing.T) test { @@ -351,12 +338,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-jwk": func(t *testing.T) test { @@ -365,12 +351,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { @@ -379,13 +364,12 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, nil) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/NewAccount-error": func(t *testing.T) test { @@ -396,23 +380,19 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ - auth: &mockAcmeAuthority{ - newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.Contact, nar.Contact) - assert.Equals(t, ops.Key, jwk) - return nil, acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok/new-account": func(t *testing.T) test { @@ -423,28 +403,27 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ - auth: &mockAcmeAuthority{ - newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, ops.Contact, nar.Contact) - assert.Equals(t, ops.Key, jwk) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return nil }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.True(t, abs) + assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + */ }, ctx: ctx, statusCode: 201, @@ -456,21 +435,11 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - }, ctx: ctx, statusCode: 200, } @@ -479,7 +448,7 @@ func TestHandlerNewAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -492,15 +461,14 @@ func TestHandlerNewAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(acc) @@ -527,55 +495,51 @@ func TestHandlerGetUpdateAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx := context.WithValue(context.Background(), accContextKey, &acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to unmarshal new-account request payload: unexpected end of JSON input")), + err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), } }, "fail/malformed-payload-error": func(t *testing.T) test { @@ -584,62 +548,33 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("contact cannot be empty string")), + err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), } }, - "fail/Deactivate-error": func(t *testing.T) test { + "fail/update-error": func(t *testing.T) test { uar := &UpdateAccountRequest{ Status: "deactivated", } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ - auth: &mockAcmeAuthority{ - deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - return nil, acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Status, acme.StatusDeactivated) + assert.Equals(t, upd.ID, acc.ID) + return acme.NewErrorISE("force") }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), - } - }, - "fail/UpdateAccount-error": func(t *testing.T) test { - uar := &UpdateAccountRequest{ - Contact: []string{"foo", "bar"}, - } - b, err := json.Marshal(uar) - assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - return test{ - auth: &mockAcmeAuthority{ - updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - assert.Equals(t, contacts, uar.Contact) - return nil, acme.ServerInternalErr(errors.New("force")) - }, - }, - ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok/deactivate": func(t *testing.T) test { @@ -648,27 +583,27 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Status, acme.StatusDeactivated) + assert.Equals(t, upd.ID, acc.ID) + return nil }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + */ }, ctx: ctx, statusCode: 200, @@ -678,21 +613,11 @@ func TestHandlerGetUpdateAccount(t *testing.T) { uar := &UpdateAccountRequest{} b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) - }, - }, ctx: ctx, statusCode: 200, } @@ -703,49 +628,50 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, id, accID) - assert.Equals(t, contacts, uar.Contact) - return &acc, nil - }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL.String(), provName, accID) + db: &acme.MockDB{ + MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { + assert.Equals(t, upd.Contact, uar.Contact) + assert.Equals(t, upd.ID, acc.ID) + return nil }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + */ }, ctx: ctx, statusCode: 200, } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, &acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, &acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, typ, acme.AccountLink) - assert.True(t, abs) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.Equals(t, ins, []string{accID}) - return fmt.Sprintf("%s/acme/%s/account/%s", - baseURL, provName, accID) + /* + auth: &mockAcmeAuthority{ + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL, provName, accID) + }, }, - }, + */ ctx: ctx, statusCode: 200, } @@ -754,7 +680,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -767,15 +693,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(acc) diff --git a/acme/api/handler.go b/acme/api/handler.go index 997456a7..31466c6c 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -90,9 +90,9 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) + r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) + r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } @@ -149,8 +149,8 @@ func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) } -// GetAuthz ACME api for retrieving an Authz. -func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { +// GetAuthorization ACME api for retrieving an Authz. +func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { ctx := r.Context() acc, err := accountFromContext(ctx) if err != nil { diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 7e19ea75..8a5ac694 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -3,7 +3,6 @@ package api import ( "bytes" "context" - "crypto/x509" "encoding/json" "encoding/pem" "fmt" @@ -14,209 +13,13 @@ import ( "time" "github.com/go-chi/chi" - "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) -type mockAcmeAuthority struct { - getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string - getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string - - deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error) - getAccount func(ctx context.Context, accID string) (*acme.Account, error) - getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error) - newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error) - updateAccount func(context.Context, string, []string) (*acme.Account, error) - - getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error) - validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error) - getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error) - getDirectory func(ctx context.Context) (*acme.Directory, error) - getCertificate func(string, string) ([]byte, error) - - finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error) - getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error) - getOrdersByAccount func(ctx context.Context, accID string) ([]string, error) - newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error) - - loadProvisionerByID func(string) (provisioner.Interface, error) - newNonce func() (string, error) - useNonce func(string) error - ret1 interface{} - err error -} - -func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) { - if m.deactivateAccount != nil { - return m.deactivateAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { - if m.finalizeOrder != nil { - return m.finalizeOrder(ctx, accID, id, csr) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) { - if m.getAccount != nil { - return m.getAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { - if m.getAccountByKey != nil { - return m.getAccountByKey(ctx, jwk) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) { - if m.getAuthz != nil { - return m.getAuthz(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Authz), m.err -} - -func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) { - if m.getCertificate != nil { - return m.getCertificate(accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.([]byte), m.err -} - -func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) { - if m.getChallenge != nil { - return m.getChallenge(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Challenge), m.err -} - -func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) { - if m.getDirectory != nil { - return m.getDirectory(ctx) - } - return m.ret1.(*acme.Directory), m.err -} - -func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - if m.getLink != nil { - return m.getLink(ctx, typ, abs, ins...) - } - return m.ret1.(string) -} - -func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string { - if m.getLinkExplicit != nil { - return m.getLinkExplicit(typ, provID, abs, baseURL, ins...) - } - return m.ret1.(string) -} - -func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) { - if m.getOrder != nil { - return m.getOrder(ctx, accID, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { - if m.getOrdersByAccount != nil { - return m.getOrdersByAccount(ctx, id) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.([]string), m.err -} - -func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Interface, error) { - if m.loadProvisionerByID != nil { - return m.loadProvisionerByID(provID) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { - if m.newAccount != nil { - return m.newAccount(ctx, ops) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) NewNonce() (string, error) { - if m.newNonce != nil { - return m.newNonce() - } else if m.err != nil { - return "", m.err - } - return m.ret1.(string), m.err -} - -func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { - if m.newOrder != nil { - return m.newOrder(ctx, ops) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Order), m.err -} - -func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) { - if m.updateAccount != nil { - return m.updateAccount(ctx, id, contact) - } else if m.err != nil { - return nil, m.err - } - return m.ret1.(*acme.Account), m.err -} - -func (m *mockAcmeAuthority) UseNonce(nonce string) error { - if m.useNonce != nil { - return m.useNonce(nonce) - } - return m.err -} - -func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { - switch { - case m.validateChallenge != nil: - return m.validateChallenge(ctx, accID, id, jwk) - case m.err != nil: - return nil, m.err - default: - return m.ret1.(*acme.Challenge), m.err - } -} - -func TestHandlerGetNonce(t *testing.T) { +func TestHandler_GetNonce(t *testing.T) { tests := []struct { name string statusCode int @@ -230,7 +33,7 @@ func TestHandlerGetNonce(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} w := httptest.NewRecorder() req.Method = tt.name h.GetNonce(w, req) @@ -243,21 +46,16 @@ func TestHandlerGetNonce(t *testing.T) { } } -func TestHandlerGetDirectory(t *testing.T) { - auth, err := acme.New(nil, acme.AuthorityOptions{ - DB: new(db.MockNoSQLDB), - DNS: "ca.smallstep.com", - Prefix: "acme", - }) - assert.FatalError(t, err) +func TestHandler_GetDirectory(t *testing.T) { + linker := NewLinker("acme", "ca.smallstep.com") prov := newProv() provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - expDir := acme.Directory{ + expDir := Directory{ NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), @@ -267,7 +65,7 @@ func TestHandlerGetDirectory(t *testing.T) { type test struct { statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { @@ -279,7 +77,7 @@ func TestHandlerGetDirectory(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(auth).(*Handler) + h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(ctx) w := httptest.NewRecorder() @@ -292,18 +90,17 @@ func TestHandlerGetDirectory(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { - var dir acme.Directory + var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) assert.Equals(t, dir, expDir) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) @@ -312,16 +109,16 @@ func TestHandlerGetDirectory(t *testing.T) { } } -func TestHandlerGetAuthz(t *testing.T) { +func TestHandler_GetAuthz(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) - az := acme.Authz{ + az := acme.Authorization{ ID: "authzID", Identifier: acme.Identifier{ Type: "dns", Value: "example.com", }, Status: "pending", - Expires: expiry.Format(time.RFC3339), + Expires: expiry, Wildcard: false, Challenges: []*acme.Challenge{ { @@ -353,67 +150,64 @@ func TestHandlerGetAuthz(t *testing.T) { baseURL.String(), provName, az.ID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/getAuthz-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { assert.Equals(t, id, az.ID) return &az, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.True(t, abs) - assert.Equals(t, in, []string{az.ID}) - return url - }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + assert.Equals(t, typ, acme.AuthzLink) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.True(t, abs) + assert.Equals(t, in, []string{az.ID}) + return url + }, + */ }, ctx: ctx, statusCode: 200, @@ -423,11 +217,11 @@ func TestHandlerGetAuthz(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() - h.GetAuthz(w, req) + h.GetAuthorization(w, req) res := w.Result() assert.Equals(t, res.StatusCode, tc.statusCode) @@ -436,15 +230,14 @@ func TestHandlerGetAuthz(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { //var gotAz acme.Authz @@ -459,7 +252,7 @@ func TestHandlerGetAuthz(t *testing.T) { } } -func TestHandlerGetCertificate(t *testing.T) { +func TestHandler_GetCertificate(t *testing.T) { leaf, err := pemutil.ReadCertificate("../../authority/testdata/certs/foo.crt") assert.FatalError(t, err) inter, err := pemutil.ReadCertificate("../../authority/testdata/certs/intermediate_ca.crt") @@ -490,89 +283,83 @@ func TestHandlerGetCertificate(t *testing.T) { baseURL.String(), provName, certID) type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + db: &acme.MockDB{}, + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.AccContextKey, nil) + ctx := context.WithValue(context.Background(), accContextKey, nil) return test{ - auth: &mockAcmeAuthority{}, + db: &acme.MockDB{}, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/getCertificate-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.ServerInternalErr(errors.New("force")), + db: &acme.MockDB{ + MockError: acme.NewErrorISE("force"), }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/decode-leaf-for-loggger": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return []byte("foo"), nil + return &acme.Certificate{}, nil }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")), + err: acme.NewErrorISE("failed to decode any certificates from generated certBytes"), } }, "fail/parse-x509-leaf-for-logger": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: []byte("foo"), - }), nil + return &acme.Certificate{}, nil }, }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to parse generated leaf certificate")), + err: acme.NewErrorISE("failed to parse generated leaf certificate"), } }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - getCertificate: func(accID, id string) ([]byte, error) { - assert.Equals(t, accID, acc.ID) + db: &acme.MockDB{ + MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { assert.Equals(t, id, certID) - return certBytes, nil + return &acme.Certificate{}, nil }, }, ctx: ctx, @@ -583,7 +370,7 @@ func TestHandlerGetCertificate(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -596,15 +383,14 @@ func TestHandlerGetCertificate(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.HasPrefix(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.HasPrefix(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), bytes.TrimSpace(certBytes)) @@ -634,121 +420,115 @@ func TestHandlerGetChallenge(t *testing.T) { url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID") type test struct { - auth acme.Interface + db acme.DB ctx context.Context statusCode int ch acme.Challenge - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.Background(), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ - ctx: ctx, + ctx: context.WithValue(context.Background(), accContextKey, nil), statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.UnauthorizedErr(nil), + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(nil), + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), } }, "fail/get-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ - auth: &mockAcmeAuthority{ - err: acme.UnauthorizedErr(nil), + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(nil), + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), } }, "ok/validate-challenge": func(t *testing.T) test { key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.AccContextKey, acc) - ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ch := ch() ch.Status = "valid" ch.Validated = time.Now().UTC().Format(time.RFC3339) count := 0 return test{ - auth: &mockAcmeAuthority{ - validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { - p, err := acme.ProvisionerFromContext(ctx) - assert.FatalError(t, err) - assert.Equals(t, p, prov) - assert.Equals(t, accID, acc.ID) - assert.Equals(t, id, ch.ID) - assert.Equals(t, jwk.KeyID, key.KeyID) + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, ch.ID) return &ch, nil }, - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - var ret string - switch count { - case 0: - assert.Equals(t, typ, acme.AuthzLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) - case 1: - assert.Equals(t, typ, acme.ChallengeLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.ID}) - ret = url - } - count++ - return ret - }, + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + var ret string + switch count { + case 0: + assert.Equals(t, typ, acme.AuthzLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.AuthzID}) + ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) + case 1: + assert.Equals(t, typ, acme.ChallengeLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.ID}) + ret = url + } + count++ + return ret + }, + */ }, ctx: ctx, statusCode: 200, @@ -759,7 +539,7 @@ func TestHandlerGetChallenge(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -772,15 +552,14 @@ func TestHandlerGetChallenge(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { expB, err := json.Marshal(tc.ch) diff --git a/acme/api/linker_test.go b/acme/api/linker_test.go new file mode 100644 index 00000000..ab1ad3ba --- /dev/null +++ b/acme/api/linker_test.go @@ -0,0 +1,99 @@ +package api + +import ( + "context" + "fmt" + "net/url" + "testing" + + "github.com/smallstep/assert" +) + +func TestLinkerGetLink(t *testing.T) { + dns := "ca.smallstep.com" + prefix := "acme" + linker := NewLinker(dns, prefix) + id := "1234" + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, true), + fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) + assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) + + // No provisioner + ctxNoProv := context.WithValue(context.Background(), baseURLContextKey, baseURL) + assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, true), + fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) + assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, false), "//new-nonce") + + // No baseURL + ctxNoBaseURL := context.WithValue(context.Background(), provisionerContextKey, prov) + assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, true), + fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) + assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) + + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, true, id), + fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) + assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName)) +} + +func TestLinkerGetLinkExplicit(t *testing.T) { + dns := "ca.smallstep.com" + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prefix := "acme" + linker := NewLinker(dns, prefix) + id := "1234" + + prov := newProv() + provID := url.PathEscape(prov.GetName()) + + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) + + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) + + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) + + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) + + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) + + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) + + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) + + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) + + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) + + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) + + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) + + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) + + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) + assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) + + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) + assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) +} diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index d2a9cdc0..750b019d 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -82,13 +82,13 @@ func Test_baseURLFromRequest(t *testing.T) { } func TestHandlerBaseURLFromRequest(t *testing.T) { - h := New(&mockAcmeAuthority{}).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req.Host = "test.ca.smallstep.com:8080" w := httptest.NewRecorder() next := func(w http.ResponseWriter, r *http.Request) { - bu := acme.BaseURLFromContext(r.Context()) + bu := baseURLFromContext(r.Context()) if assert.NotNil(t, bu) { assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") assert.Equals(t, bu.Scheme, "https") @@ -101,35 +101,35 @@ func TestHandlerBaseURLFromRequest(t *testing.T) { req.Host = "" next = func(w http.ResponseWriter, r *http.Request) { - assert.Equals(t, acme.BaseURLFromContext(r.Context()), nil) + assert.Equals(t, baseURLFromContext(r.Context()), nil) } h.baseURLFromRequest(next)(w, req) } -func TestHandlerAddNonce(t *testing.T) { +func TestHandler_AddNonce(t *testing.T) { url := "https://ca.smallstep.com/acme/new-nonce" type test struct { - auth acme.Interface - problem *acme.Error + db acme.DB + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/AddNonce-error": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{ - newNonce: func() (string, error) { - return "", acme.ServerInternalErr(errors.New("force")) + db: &acme.MockDB{ + MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { + return acme.Nonce(""), acme.NewErrorISE("force") }, }, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "ok": func(t *testing.T) test { return test{ - auth: &mockAcmeAuthority{ - newNonce: func() (string, error) { + db: &acme.MockDB{ + MockCreateNonce: func(ctx context.Context) (acme.Nonce, error) { return "bar", nil }, }, @@ -140,7 +140,7 @@ func TestHandlerAddNonce(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{db: tc.db} req := httptest.NewRequest("GET", url, nil) w := httptest.NewRecorder() h.addNonce(testNext)(w, req) @@ -152,15 +152,14 @@ func TestHandlerAddNonce(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Replay-Nonce"], []string{"bar"}) @@ -176,22 +175,24 @@ func TestHandlerAddDirLink(t *testing.T) { provName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { - auth acme.Interface + db acme.DB link string statusCode int ctx context.Context - problem *acme.Error + err *acme.Error } var tests = map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ - getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) - }, + db: &acme.MockDB{ + /* + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, baseURLFromContext(ctx), baseURL) + return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) + }, + */ }, ctx: ctx, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), @@ -202,7 +203,7 @@ func TestHandlerAddDirLink(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(tc.auth).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -215,15 +216,14 @@ func TestHandlerAddDirLink(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s>;rel=\"index\"", tc.link)}) @@ -242,7 +242,7 @@ func TestHandlerVerifyContentType(t *testing.T) { h Handler ctx context.Context contentType string - problem *acme.Error + err *acme.Error statusCode int url string } @@ -250,7 +250,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -260,16 +260,16 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, - problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")), + err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), } }, "fail/certificate-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -278,16 +278,16 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "foo", statusCode: 400, - problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")), + err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), } }, "ok": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -296,7 +296,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -304,7 +304,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkix-cert": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -313,7 +313,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkix-cert", statusCode: 200, } @@ -321,7 +321,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/jose+json": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -330,7 +330,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -338,7 +338,7 @@ func TestHandlerVerifyContentType(t *testing.T) { "ok/certificate/pkcs7-mime": func(t *testing.T) test { return test{ h: Handler{ - Auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, abs, false) @@ -347,7 +347,7 @@ func TestHandlerVerifyContentType(t *testing.T) { }, }, }, - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -373,15 +373,14 @@ func TestHandlerVerifyContentType(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -394,7 +393,7 @@ func TestHandlerIsPostAsGet(t *testing.T) { url := "https://ca.smallstep.com/acme/new-account" type test struct { ctx context.Context - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -402,26 +401,26 @@ func TestHandlerIsPostAsGet(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, nil), + ctx: context.WithValue(context.Background(), payloadContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("payload expected in request context")), + err: acme.NewErrorISE("payload expected in request context"), } }, "fail/not-post-as-get": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{}), + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), statusCode: 400, - problem: acme.MalformedErr(errors.New("expected POST-as-GET")), + err: acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"), } }, "ok": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}), + ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), statusCode: 200, } }, @@ -429,7 +428,7 @@ func TestHandlerIsPostAsGet(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -442,15 +441,14 @@ func TestHandlerIsPostAsGet(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -473,7 +471,7 @@ func TestHandlerParseJWS(t *testing.T) { type test struct { next nextHTTP body io.Reader - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -481,14 +479,14 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: errReader(0), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("failed to read request body: force")), + err: acme.NewErrorISE("failed to read request body: force"), } }, "fail/parse-jws-error": func(t *testing.T) test { return test{ body: strings.NewReader("foo"), statusCode: 400, - problem: acme.MalformedErr(errors.New("failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts")), + err: acme.NewError(acme.ErrorMalformedType, "failed to parse JWS from request body: square/go-jose: compact JWS format must have three parts"), } }, "ok": func(t *testing.T) test { @@ -507,7 +505,7 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: strings.NewReader(expRaw), next: func(w http.ResponseWriter, r *http.Request) { - jws, err := acme.JwsFromContext(r.Context()) + jws, err := jwsFromContext(r.Context()) assert.FatalError(t, err) gotRaw, err := jws.CompactSerialize() assert.FatalError(t, err) @@ -521,7 +519,7 @@ func TestHandlerParseJWS(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, tc.body) w := httptest.NewRecorder() h.parseJWS(tc.next)(w, req) @@ -533,15 +531,14 @@ func TestHandlerParseJWS(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -572,7 +569,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { type test struct { ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -580,58 +577,58 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), + ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-jwk": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) return test{ - ctx: context.WithValue(ctx, acme.JwkContextKey, nil), + ctx: context.WithValue(ctx, jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), + err: acme.NewErrorISE("jwk expected in request context"), } }, "fail/verify-jws-failure": func(t *testing.T) test { _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, &_pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("error verifying jws: square/go-jose: error in cryptographic primitive")), + err: acme.NewError(acme.ErrorMalformedType, "error verifying jws: square/go-jose: error in cryptographic primitive"), } }, "fail/algorithm-mismatch": func(t *testing.T) test { _pub := *pub clone := &_pub clone.Algorithm = jose.HS256 - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, clone) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, clone) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("verifier and signature algorithm do not match")), + err: acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"), } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -651,8 +648,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _pub := *pub clone := &_pub clone.Algorithm = "" - ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -675,8 +672,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -699,8 +696,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.JwkContextKey, pub) + ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) + ctx = context.WithValue(ctx, jwsContextKey, pub) return test{ ctx: ctx, statusCode: 200, @@ -720,7 +717,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := New(nil).(*Handler) + h := &Handler{} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -733,15 +730,14 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { - var ae acme.AError + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -775,27 +771,27 @@ func TestHandlerLookupJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-kid": func(t *testing.T) test { @@ -806,11 +802,11 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) assert.True(t, abs) @@ -820,7 +816,7 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got ", prefix)), + err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), } }, "fail/bad-kid-prefix": func(t *testing.T) test { @@ -837,11 +833,11 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _parsed) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _parsed) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) assert.True(t, abs) @@ -851,15 +847,15 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("kid does not have required prefix; expected %s, but got foo", prefix)), + err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -876,21 +872,21 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 400, - problem: acme.AccountDoesNotExistErr(nil), + err: acme.AccountDoesNotExistErr(nil), } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, accID) - return nil, acme.ServerInternalErr(errors.New("force")) + return nil, acme.NewErrorISE("force") }, getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) @@ -901,16 +897,16 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -927,16 +923,16 @@ func TestHandlerLookupJWK(t *testing.T) { }, ctx: ctx, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.UnauthorizedErr(errors.New("account is not active")), } }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -981,15 +977,14 @@ func TestHandlerLookupJWK(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.AError assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -1024,27 +1019,27 @@ func TestHandlerExtractJWK(t *testing.T) { url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provName) type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, nil) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jwk": func(t *testing.T) test { @@ -1057,8 +1052,8 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, @@ -1075,39 +1070,39 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, - problem: acme.MalformedErr(errors.New("invalid jwk in protected header")), + err: acme.MalformedErr(errors.New("invalid jwk in protected header")), } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, jwk.KeyID, pub.KeyID) - return nil, acme.ServerInternalErr(errors.New("force")) + return nil, acme.NewErrorISE("force") }, }, statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -1117,16 +1112,16 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, statusCode: 401, - problem: acme.UnauthorizedErr(errors.New("account is not active")), + err: acme.UnauthorizedErr(errors.New("account is not active")), } }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -1148,11 +1143,11 @@ func TestHandlerExtractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) return test{ ctx: ctx, - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { p, err := acme.ProvisionerFromContext(ctx) assert.FatalError(t, err) @@ -1190,15 +1185,14 @@ func TestHandlerExtractJWK(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.AError assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) @@ -1210,10 +1204,10 @@ func TestHandlerExtractJWK(t *testing.T) { func TestHandlerValidateJWS(t *testing.T) { url := "https://ca.smallstep.com/acme/account/1234" type test struct { - auth acme.Interface + db acme.DB ctx context.Context next func(http.ResponseWriter, *http.Request) - problem *acme.Error + err *acme.Error statusCode int } var tests = map[string]func(t *testing.T) test{ @@ -1221,21 +1215,21 @@ func TestHandlerValidateJWS(t *testing.T) { return test{ ctx: context.Background(), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), + ctx: context.WithValue(context.Background(), jwsContextKey, nil), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("jws expected in request context")), + err: acme.NewErrorISE("jws expected in request context"), } }, "fail/no-signature": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, &jose.JSONWebSignature{}), + ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, - problem: acme.MalformedErr(errors.New("request body does not contain a signature")), + err: acme.MalformedErr(errors.New("request body does not contain a signature")), } }, "fail/more-than-one-signature": func(t *testing.T) test { @@ -1246,9 +1240,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("request body contains more than one signature")), + err: acme.MalformedErr(errors.New("request body contains more than one signature")), } }, "fail/unprotected-header-not-empty": func(t *testing.T) test { @@ -1258,9 +1252,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("unprotected header must not be used")), + err: acme.MalformedErr(errors.New("unprotected header must not be used")), } }, "fail/unsuitable-algorithm-none": func(t *testing.T) test { @@ -1270,9 +1264,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")), + err: acme.MalformedErr(errors.New("unsuitable algorithm: none")), } }, "fail/unsuitable-algorithm-mac": func(t *testing.T) test { @@ -1282,9 +1276,9 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), + err: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), } }, "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { @@ -1305,14 +1299,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), + err: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), } }, "fail/rsa-key-too-small": func(t *testing.T) test { @@ -1333,14 +1327,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), + err: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), } }, "fail/UseNonce-error": func(t *testing.T) test { @@ -1350,14 +1344,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { - return acme.ServerInternalErr(errors.New("force")) + return acme.NewErrorISE("force") }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 500, - problem: acme.ServerInternalErr(errors.New("force")), + err: acme.NewErrorISE("force"), } }, "fail/no-url-header": func(t *testing.T) test { @@ -1367,14 +1361,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.New("jws missing url protected header")), + err: acme.MalformedErr(errors.New("jws missing url protected header")), } }, "fail/url-mismatch": func(t *testing.T) test { @@ -1391,14 +1385,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), + err: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), } }, "fail/both-jwk-kid": func(t *testing.T) test { @@ -1420,14 +1414,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), + err: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), } }, "fail/no-jwk-kid": func(t *testing.T) test { @@ -1444,14 +1438,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), statusCode: 400, - problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), + err: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), } }, "ok/kid": func(t *testing.T) test { @@ -1469,12 +1463,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1499,12 +1493,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1529,12 +1523,12 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - auth: &mockAcmeAuthority{ + db: &acme.MockDB{ useNonce: func(n string) error { return nil }, }, - ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), + ctx: context.WithValue(context.Background(), jwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1558,15 +1552,14 @@ func TestHandlerValidateJWS(t *testing.T) { res.Body.Close() assert.FatalError(t, err) - if res.StatusCode >= 400 && assert.NotNil(t, tc.problem) { + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { var ae acme.AError assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) - prob := tc.problem.ToACME() - assert.Equals(t, ae.Type, prob.Type) - assert.Equals(t, ae.Detail, prob.Detail) - assert.Equals(t, ae.Identifier, prob.Identifier) - assert.Equals(t, ae.Subproblems, prob.Subproblems) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) } else { assert.Equals(t, bytes.TrimSpace(body), testBody) diff --git a/acme/db.go b/acme/db.go index a19621c0..dcc7846f 100644 --- a/acme/db.go +++ b/acme/db.go @@ -1,6 +1,8 @@ package acme -import "context" +import ( + "context" +) // DB is the DB interface expected by the step-ca ACME API. type DB interface { @@ -28,3 +30,214 @@ type DB interface { GetOrdersByAccountID(ctx context.Context, accountID string) ([]string, error) UpdateOrder(ctx context.Context, o *Order) error } + +// MockDB is an implementation of the DB interface that should only be used as +// a mock in tests. +type MockDB struct { + MockCreateAccount func(ctx context.Context, acc *Account) error + MockGetAccount func(ctx context.Context, id string) (*Account, error) + MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error) + MockUpdateAccount func(ctx context.Context, acc *Account) error + + MockCreateNonce func(ctx context.Context) (Nonce, error) + MockDeleteNonce func(ctx context.Context, nonce Nonce) error + + MockCreateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) + MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + + MockCreateCertificate func(ctx context.Context, cert *Certificate) error + MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) + + MockCreateChallenge func(ctx context.Context, ch *Challenge) error + MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error) + MockUpdateChallenge func(ctx context.Context, ch *Challenge) error + + MockCreateOrder func(ctx context.Context, o *Order) error + MockGetOrder func(ctx context.Context, id string) (*Order, error) + MockGetOrdersByAccountID func(ctx context.Context, accountID string) ([]string, error) + MockUpdateOrder func(ctx context.Context, o *Order) error + + MockRet1 interface{} + MockError error +} + +// CreateAccount mock. +func (m *MockDB) CreateAccount(ctx context.Context, acc *Account) error { + if m.MockCreateAccount != nil { + return m.MockCreateAccount(ctx, acc) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetAccount mock. +func (m *MockDB) GetAccount(ctx context.Context, id string) (*Account, error) { + if m.MockGetAccount != nil { + return m.MockGetAccount(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Account), m.MockError +} + +// GetAccountByKeyID mock +func (m *MockDB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { + if m.MockGetAccountByKeyID != nil { + return m.MockGetAccountByKeyID(ctx, kid) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Account), m.MockError +} + +// UpdateAccount mock +func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error { + if m.MockUpdateAccount != nil { + return m.MockUpdateAccount(ctx, acc) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateNonce mock +func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) { + if m.MockCreateNonce != nil { + return m.MockCreateNonce(ctx) + } else if m.MockError != nil { + return Nonce(""), m.MockError + } + return m.MockRet1.(Nonce), m.MockError +} + +// DeleteNonce mock +func (m *MockDB) DeleteNonce(ctx context.Context, nonce Nonce) error { + if m.MockDeleteNonce != nil { + return m.MockDeleteNonce(ctx, nonce) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateAuthorization mock +func (m *MockDB) CreateAuthorization(ctx context.Context, az *Authorization) error { + if m.MockCreateAuthorization != nil { + return m.MockCreateAuthorization(ctx, az) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetAuthorization mock +func (m *MockDB) GetAuthorization(ctx context.Context, id string) (*Authorization, error) { + if m.MockGetAuthorization != nil { + return m.MockGetAuthorization(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Authorization), m.MockError +} + +// UpdateAuthorization mock +func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) error { + if m.MockUpdateAuthorization != nil { + return m.MockUpdateAuthorization(ctx, az) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateCertificate mock +func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { + if m.MockCreateCertificate != nil { + return m.MockCreateCertificate(ctx, cert) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetCertificate mock +func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { + if m.MockGetCertificate != nil { + return m.MockGetCertificate(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Certificate), m.MockError +} + +// CreateChallenge mock +func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error { + if m.MockCreateChallenge != nil { + return m.MockCreateChallenge(ctx, ch) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetChallenge mock +func (m *MockDB) GetChallenge(ctx context.Context, chID, azID string) (*Challenge, error) { + if m.MockGetChallenge != nil { + return m.MockGetChallenge(ctx, chID, azID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Challenge), m.MockError +} + +// UpdateChallenge mock +func (m *MockDB) UpdateChallenge(ctx context.Context, ch *Challenge) error { + if m.MockUpdateChallenge != nil { + return m.MockUpdateChallenge(ctx, ch) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// CreateOrder mock +func (m *MockDB) CreateOrder(ctx context.Context, o *Order) error { + if m.MockCreateOrder != nil { + return m.MockCreateOrder(ctx, o) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetOrder mock +func (m *MockDB) GetOrder(ctx context.Context, id string) (*Order, error) { + if m.MockGetOrder != nil { + return m.MockGetOrder(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Order), m.MockError +} + +// UpdateOrder mock +func (m *MockDB) UpdateOrder(ctx context.Context, o *Order) error { + if m.MockUpdateOrder != nil { + return m.MockUpdateOrder(ctx, o) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// GetOrdersByAccountID mock +func (m *MockDB) GetOrdersByAccountID(ctx context.Context, accID string) ([]string, error) { + if m.MockGetOrdersByAccountID != nil { + return m.MockGetOrdersByAccountID(ctx, accID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.([]string), m.MockError +} diff --git a/acme/directory_test.go b/acme/directory_test.go deleted file mode 100644 index dd4c534c..00000000 --- a/acme/directory_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package acme - -import ( - "context" - "fmt" - "net/url" - "testing" - - "github.com/smallstep/assert" -) - -func TestDirectoryGetLink(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - dir := newDirectory(dns, prefix) - id := "1234" - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - - assert.Equals(t, dir.getLink(ctx, NewNonceLink, true), - fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) - assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - - // No provisioner - ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL) - assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true), - fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) - assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce") - - // No baseURL - ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov) - assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true), - fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) - assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - - assert.Equals(t, dir.getLink(ctx, OrderLink, true, id), - fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) - assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName)) -} - -func TestDirectoryGetLinkExplicit(t *testing.T) { - dns := "ca.smallstep.com" - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prefix := "acme" - dir := newDirectory(dns, prefix) - id := "1234" - - prov := newProv() - provID := url.PathEscape(prov.GetName()) - - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) - - assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) - - assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) - - assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) - - assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) - - assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) - - assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) - - assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) - - assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID)) - - assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) - assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) -}