From 80c8567d9977e2dadb8c035bce0af29ea7aee1de Mon Sep 17 00:00:00 2001 From: max furman Date: Thu, 25 Mar 2021 14:54:12 -0700 Subject: [PATCH] change errnotfound type for getAccount - more generalized NotFound type rather than the nosql one we were using - if the error is not recognized then the logic in create account will break. --- acme/api/middleware.go | 15 ++++++++--- acme/api/middleware_test.go | 2 +- acme/db.go | 8 ++++++ acme/db/nosql/account.go | 4 +-- acme/db/nosql/account_test.go | 50 ++--------------------------------- 5 files changed, 24 insertions(+), 55 deletions(-) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index f2a35c3a..e06e4736 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -3,6 +3,7 @@ package api import ( "context" "crypto/rsa" + "errors" "io/ioutil" "net/http" "net/url" @@ -243,15 +244,21 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header")) return } - ctx = context.WithValue(ctx, jwkContextKey, jwk) - kid, err := acme.KeyToID(jwk) + + // Overwrite KeyID with the JWK thumbprint. + jwk.KeyID, err = acme.KeyToID(jwk) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK")) return } - acc, err := h.db.GetAccountByKeyID(ctx, kid) + + // Store the JWK in the context. + ctx = context.WithValue(ctx, jwkContextKey, jwk) + + // Get Account or continue to generate a new one. + acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) switch { - case nosql.IsErrNotFound(err): + case errors.Is(err, acme.ErrNotFound): // For NewAccount requests ... break case err != nil: diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 4f2c4bcb..1c0f3689 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1047,7 +1047,7 @@ func TestHandler_extractJWK(t *testing.T) { db: &acme.MockDB{ MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { assert.Equals(t, kid, pub.KeyID) - return nil, database.ErrNotFound + return nil, acme.ErrNotFound }, }, next: func(w http.ResponseWriter, r *http.Request) { diff --git a/acme/db.go b/acme/db.go index dcc7846f..d678fef4 100644 --- a/acme/db.go +++ b/acme/db.go @@ -2,8 +2,16 @@ package acme import ( "context" + + "github.com/pkg/errors" ) +// ErrNotFound is an error that should be used by the acme.DB interface to +// indicate that an entity does not exist. For example, in the new-account +// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new +// account. +var ErrNotFound = errors.New("not found") + // DB is the DB interface expected by the step-ca ACME API. type DB interface { CreateAccount(ctx context.Context, acc *Account) error diff --git a/acme/db/nosql/account.go b/acme/db/nosql/account.go index 3115e8ab..d7ac9655 100644 --- a/acme/db/nosql/account.go +++ b/acme/db/nosql/account.go @@ -30,7 +30,7 @@ func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, erro id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return "", acme.NewError(acme.ErrorMalformedType, "account with key-id %s not found", kid) + return "", acme.ErrNotFound } return "", errors.Wrapf(err, "error loading key-account index for key %s", kid) } @@ -42,7 +42,7 @@ func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) { data, err := db.db.Get(accountTable, []byte(id)) if err != nil { if nosqlDB.IsErrNotFound(err) { - return nil, acme.NewError(acme.ErrorMalformedType, "account %s not found", id) + return nil, acme.ErrNotFound } return nil, errors.Wrapf(err, "error loading account %s", id) } diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go index 9f889e64..5ba99a73 100644 --- a/acme/db/nosql/account_test.go +++ b/acme/db/nosql/account_test.go @@ -34,7 +34,7 @@ func TestDB_getDBAccount(t *testing.T) { return nil, nosqldb.ErrNotFound }, }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), + err: acme.ErrNotFound, } }, "fail/db.Get-error": func(t *testing.T) test { @@ -142,7 +142,7 @@ func TestDB_getAccountIDByKeyID(t *testing.T) { return nil, nosqldb.ErrNotFound }, }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), + err: acme.ErrNotFound, } }, "fail/db.Get-error": func(t *testing.T) test { @@ -221,19 +221,6 @@ func TestDB_GetAccount(t *testing.T) { err: errors.New("error loading account accID: force"), } }, - "fail/forward-acme-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, accountTable) - assert.Equals(t, string(key), accID) - - return nil, nosqldb.ErrNotFound - }, - }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), - } - }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) @@ -314,19 +301,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) { err: errors.New("error loading key-account index for key kid: force"), } }, - "fail/db.getAccountIDByKeyID-forward-acme-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, string(bucket), string(accountByKeyIDTable)) - assert.Equals(t, string(key), kid) - - return nil, nosqldb.ErrNotFound - }, - }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"), - } - }, "fail/db.GetAccount-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ @@ -347,26 +321,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) { err: errors.New("error loading account accID: force"), } }, - "fail/db.GetAccount-forward-acme-error": func(t *testing.T) test { - return test{ - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - switch string(bucket) { - case string(accountByKeyIDTable): - assert.Equals(t, string(key), kid) - return []byte(accID), nil - case string(accountTable): - assert.Equals(t, string(key), accID) - return nil, nosqldb.ErrNotFound - default: - assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) - return nil, errors.New("force") - } - }, - }, - acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"), - } - }, "ok": func(t *testing.T) test { now := clock.Now() jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)