forked from TrueCloudLab/certificates
change errnotfound type for getAccount
- more generalized NotFound type rather than the nosql one we were using - if the error is not recognized then the logic in create account will break.
This commit is contained in:
parent
1831920363
commit
80c8567d99
5 changed files with 24 additions and 55 deletions
|
@ -3,6 +3,7 @@ package api
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"errors"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
@ -243,15 +244,21 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
|
||||||
kid, err := acme.KeyToID(jwk)
|
// Overwrite KeyID with the JWK thumbprint.
|
||||||
|
jwk.KeyID, err = acme.KeyToID(jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
acc, err := h.db.GetAccountByKeyID(ctx, kid)
|
|
||||||
|
// Store the JWK in the context.
|
||||||
|
ctx = context.WithValue(ctx, jwkContextKey, jwk)
|
||||||
|
|
||||||
|
// Get Account or continue to generate a new one.
|
||||||
|
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
|
||||||
switch {
|
switch {
|
||||||
case nosql.IsErrNotFound(err):
|
case errors.Is(err, acme.ErrNotFound):
|
||||||
// For NewAccount requests ...
|
// For NewAccount requests ...
|
||||||
break
|
break
|
||||||
case err != nil:
|
case err != nil:
|
||||||
|
|
|
@ -1047,7 +1047,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||||
assert.Equals(t, kid, pub.KeyID)
|
assert.Equals(t, kid, pub.KeyID)
|
||||||
return nil, database.ErrNotFound
|
return nil, acme.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
next: func(w http.ResponseWriter, r *http.Request) {
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -2,8 +2,16 @@ package acme
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrNotFound is an error that should be used by the acme.DB interface to
|
||||||
|
// indicate that an entity does not exist. For example, in the new-account
|
||||||
|
// endpoint, if GetAccountByKeyID returns ErrNotFound we will create the new
|
||||||
|
// account.
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
// DB is the DB interface expected by the step-ca ACME API.
|
// DB is the DB interface expected by the step-ca ACME API.
|
||||||
type DB interface {
|
type DB interface {
|
||||||
CreateAccount(ctx context.Context, acc *Account) error
|
CreateAccount(ctx context.Context, acc *Account) error
|
||||||
|
|
|
@ -30,7 +30,7 @@ func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, erro
|
||||||
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if nosqlDB.IsErrNotFound(err) {
|
if nosqlDB.IsErrNotFound(err) {
|
||||||
return "", acme.NewError(acme.ErrorMalformedType, "account with key-id %s not found", kid)
|
return "", acme.ErrNotFound
|
||||||
}
|
}
|
||||||
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
|
return "", errors.Wrapf(err, "error loading key-account index for key %s", kid)
|
||||||
}
|
}
|
||||||
|
@ -42,7 +42,7 @@ func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
|
||||||
data, err := db.db.Get(accountTable, []byte(id))
|
data, err := db.db.Get(accountTable, []byte(id))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if nosqlDB.IsErrNotFound(err) {
|
if nosqlDB.IsErrNotFound(err) {
|
||||||
return nil, acme.NewError(acme.ErrorMalformedType, "account %s not found", id)
|
return nil, acme.ErrNotFound
|
||||||
}
|
}
|
||||||
return nil, errors.Wrapf(err, "error loading account %s", id)
|
return nil, errors.Wrapf(err, "error loading account %s", id)
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ func TestDB_getDBAccount(t *testing.T) {
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, nosqldb.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"),
|
err: acme.ErrNotFound,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/db.Get-error": func(t *testing.T) test {
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
@ -142,7 +142,7 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, nosqldb.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"),
|
err: acme.ErrNotFound,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/db.Get-error": func(t *testing.T) test {
|
"fail/db.Get-error": func(t *testing.T) test {
|
||||||
|
@ -221,19 +221,6 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
err: errors.New("error loading account accID: force"),
|
err: errors.New("error loading account accID: force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/forward-acme-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, bucket, accountTable)
|
|
||||||
assert.Equals(t, string(key), accID)
|
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
now := clock.Now()
|
now := clock.Now()
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
@ -314,19 +301,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
err: errors.New("error loading key-account index for key kid: force"),
|
err: errors.New("error loading key-account index for key kid: force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/db.getAccountIDByKeyID-forward-acme-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
assert.Equals(t, string(bucket), string(accountByKeyIDTable))
|
|
||||||
assert.Equals(t, string(key), kid)
|
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
|
||||||
},
|
|
||||||
},
|
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "account with key-id kid not found"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/db.GetAccount-error": func(t *testing.T) test {
|
"fail/db.GetAccount-error": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
db: &db.MockNoSQLDB{
|
db: &db.MockNoSQLDB{
|
||||||
|
@ -347,26 +321,6 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
err: errors.New("error loading account accID: force"),
|
err: errors.New("error loading account accID: force"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/db.GetAccount-forward-acme-error": func(t *testing.T) test {
|
|
||||||
return test{
|
|
||||||
db: &db.MockNoSQLDB{
|
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
||||||
switch string(bucket) {
|
|
||||||
case string(accountByKeyIDTable):
|
|
||||||
assert.Equals(t, string(key), kid)
|
|
||||||
return []byte(accID), nil
|
|
||||||
case string(accountTable):
|
|
||||||
assert.Equals(t, string(key), accID)
|
|
||||||
return nil, nosqldb.ErrNotFound
|
|
||||||
default:
|
|
||||||
assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket)))
|
|
||||||
return nil, errors.New("force")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "account accID not found"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
now := clock.Now()
|
now := clock.Now()
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
|
Loading…
Reference in a new issue