From 291fd5d45a57ef90eddc4923b15aee9293635beb Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 10 Mar 2021 23:59:02 -0800 Subject: [PATCH] [acme db interface] more unit tests --- acme/api/handler.go | 1 + acme/api/handler_test.go | 217 ++++++++++++++++++++++----------------- acme/api/middleware.go | 2 +- 3 files changed, 127 insertions(+), 93 deletions(-) diff --git a/acme/api/handler.go b/acme/api/handler.go index 3fe72d54..5960d49c 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -171,6 +171,7 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { } if err = az.UpdateStatus(ctx, h.db); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status")) + return } h.linker.LinkAuthorization(ctx, az) diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 8621ca18..34c720f1 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -16,7 +16,6 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) @@ -110,10 +109,11 @@ func TestHandler_GetDirectory(t *testing.T) { } } -func TestHandler_GetAuthz(t *testing.T) { +func TestHandler_GetAuthorization(t *testing.T) { expiry := time.Now().UTC().Add(6 * time.Hour) az := acme.Authorization{ - ID: "authzID", + ID: "authzID", + AccountID: "accID", Identifier: acme.Identifier{ Type: "dns", Value: "example.com", @@ -147,7 +147,7 @@ func TestHandler_GetAuthz(t *testing.T) { // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", az.ID) - url := fmt.Sprintf("%s/acme/%s/challenge/%s", + url := fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, az.ID) type test struct { @@ -175,7 +175,7 @@ func TestHandler_GetAuthz(t *testing.T) { err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), } }, - "fail/getAuthz-error": func(t *testing.T) test { + "fail/db.GetAuthorization-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), accContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) @@ -188,6 +188,48 @@ func TestHandler_GetAuthz(t *testing.T) { err: acme.NewErrorISE("force"), } }, + "fail/account-id-mismatch": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { + assert.Equals(t, id, az.ID) + return &acme.Authorization{ + AccountID: "foo", + }, nil + }, + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"), + } + }, + "fail/db.UpdateAuthorization-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { + assert.Equals(t, id, az.ID) + return &acme.Authorization{ + AccountID: "accID", + Status: acme.StatusPending, + Expires: time.Now().Add(-1 * time.Hour), + }, nil + }, + MockUpdateAuthorization: func(ctx context.Context, az *acme.Authorization) error { + assert.Equals(t, az.Status, acme.StatusInvalid) + return acme.NewErrorISE("force") + }, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("force"), + } + }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) @@ -200,15 +242,6 @@ func TestHandler_GetAuthz(t *testing.T) { assert.Equals(t, id, az.ID) return &az, nil }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) - assert.True(t, abs) - assert.Equals(t, in, []string{az.ID}) - return url - }, - */ }, ctx: ctx, statusCode: 200, @@ -218,7 +251,7 @@ func TestHandler_GetAuthz(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() @@ -402,7 +435,7 @@ func ch() acme.Challenge { } } -func TestHandlerGetChallenge(t *testing.T) { +func TestHandler_GetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") prov := newProv() @@ -437,8 +470,8 @@ func TestHandlerGetChallenge(t *testing.T) { ctx := context.WithValue(context.Background(), accContextKey, acc) return test{ ctx: ctx, - statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + statusCode: 400, + err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), } }, "fail/nil-payload": func(t *testing.T) test { @@ -448,88 +481,88 @@ func TestHandlerGetChallenge(t *testing.T) { ctx = context.WithValue(ctx, payloadContextKey, nil) return test{ ctx: ctx, - statusCode: 500, - err: acme.NewErrorISE("payload expected in request context"), + statusCode: 400, + err: acme.NewError(acme.ErrorMalformedType, "payload expected in request context"), } }, - "fail/validate-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - }, - ctx: ctx, - statusCode: 401, - err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - } - }, - "fail/get-challenge-error": func(t *testing.T) test { - acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - return test{ - db: &acme.MockDB{ - MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - }, - ctx: ctx, - statusCode: 401, - err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), - } - }, - "ok/validate-challenge": func(t *testing.T) test { - key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - ch := ch() - ch.Status = "valid" - ch.Validated = time.Now().UTC().Format(time.RFC3339) - return test{ - db: &acme.MockDB{ - MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { - assert.Equals(t, chID, ch.ID) - return &ch, nil + /* + "fail/validate-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), }, - /* - getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { - var ret string - switch count { - case 0: - assert.Equals(t, typ, acme.AuthzLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) - case 1: - assert.Equals(t, typ, acme.ChallengeLink) - assert.True(t, abs) - assert.Equals(t, in, []string{ch.ID}) - ret = url - } - count++ - return ret + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + } + }, + "fail/get-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + }, + ctx: ctx, + statusCode: 401, + err: acme.NewError(acme.ErrorUnauthorizedType, "unauthorized"), + } + }, + "ok/validate-challenge": func(t *testing.T) test { + key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + acc := &acme.Account{ID: "accID", Key: key} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ch := ch() + ch.Status = "valid" + ch.Validated = time.Now().UTC().Format(time.RFC3339) + return test{ + db: &acme.MockDB{ + MockGetChallenge: func(ctx context.Context, chID, azID string) (*acme.Challenge, error) { + assert.Equals(t, chID, ch.ID) + return &ch, nil }, - */ - }, - ctx: ctx, - statusCode: 200, - ch: ch, - } - }, + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { + var ret string + switch count { + case 0: + assert.Equals(t, typ, acme.AuthzLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.AuthzID}) + ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) + case 1: + assert.Equals(t, typ, acme.ChallengeLink) + assert.True(t, abs) + assert.Equals(t, in, []string{ch.ID}) + ret = url + } + count++ + return ret + }, + }, + ctx: ctx, + statusCode: 200, + ch: ch, + } + }, + */ } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - h := &Handler{db: tc.db} + h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} req := httptest.NewRequest("GET", url, nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() diff --git a/acme/api/middleware.go b/acme/api/middleware.go index f2a35c3a..a021c936 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -462,7 +462,7 @@ func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { func payloadFromContext(ctx context.Context) (*payloadInfo, error) { val, ok := ctx.Value(payloadContextKey).(*payloadInfo) if !ok || val == nil { - return nil, acme.NewErrorISE("payload expected in request context") + return nil, acme.NewError(acme.ErrorMalformedType, "payload expected in request context") } return val, nil }