diff --git a/acme/account.go b/acme/account.go index 197a3400..027d7be1 100644 --- a/acme/account.go +++ b/acme/account.go @@ -4,6 +4,7 @@ import ( "crypto" "encoding/base64" "encoding/json" + "time" "go.step.sm/crypto/jose" ) @@ -11,11 +12,12 @@ import ( // Account is a subset of the internal account type containing only those // attributes required for responses in the ACME protocol. type Account struct { - ID string `json:"-"` - Key *jose.JSONWebKey `json:"-"` - Contact []string `json:"contact,omitempty"` - Status Status `json:"status"` - OrdersURL string `json:"orders"` + ID string `json:"-"` + Key *jose.JSONWebKey `json:"-"` + Contact []string `json:"contact,omitempty"` + Status Status `json:"status"` + OrdersURL string `json:"orders"` + ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"` } // ToLog enables response logging. @@ -40,3 +42,32 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) { } return base64.RawURLEncoding.EncodeToString(kid), nil } + +// ExternalAccountKey is an ACME External Account Binding key. +type ExternalAccountKey struct { + ID string `json:"id"` + ProvisionerID string `json:"provisionerID"` + Reference string `json:"reference"` + AccountID string `json:"-"` + KeyBytes []byte `json:"-"` + CreatedAt time.Time `json:"createdAt"` + BoundAt time.Time `json:"boundAt,omitempty"` +} + +// AlreadyBound returns whether this EAK is already bound to +// an ACME Account or not. +func (eak *ExternalAccountKey) AlreadyBound() bool { + return !eak.BoundAt.IsZero() +} + +// BindTo binds the EAK to an Account. +// It returns an error if it's already bound. +func (eak *ExternalAccountKey) BindTo(account *Account) error { + if eak.AlreadyBound() { + return NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", eak.ID, eak.AccountID, eak.BoundAt) + } + eak.AccountID = account.ID + eak.BoundAt = time.Now() + eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once + return nil +} diff --git a/acme/account_test.go b/acme/account_test.go index 5625c3dc..33524d87 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -4,6 +4,7 @@ import ( "crypto" "encoding/base64" "testing" + "time" "github.com/pkg/errors" "github.com/smallstep/assert" @@ -79,3 +80,67 @@ func TestAccount_IsValid(t *testing.T) { }) } } + +func TestExternalAccountKey_BindTo(t *testing.T) { + boundAt := time.Now() + tests := []struct { + name string + eak *ExternalAccountKey + acct *Account + err *Error + }{ + { + name: "ok", + eak: &ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "ref", + KeyBytes: []byte{1, 3, 3, 7}, + }, + acct: &Account{ + ID: "accountID", + }, + err: nil, + }, + { + name: "fail/already-bound", + eak: &ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "ref", + KeyBytes: []byte{1, 3, 3, 7}, + AccountID: "someAccountID", + BoundAt: boundAt, + }, + acct: &Account{ + ID: "accountID", + }, + err: NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", "eakID", "someAccountID", boundAt), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eak := tt.eak + acct := tt.acct + err := eak.BindTo(acct) + wantErr := tt.err != nil + gotErr := err != nil + if wantErr != gotErr { + t.Errorf("ExternalAccountKey.BindTo() error = %v, wantErr %v", err, tt.err) + } + if wantErr { + assert.NotNil(t, err) + assert.Type(t, &Error{}, err) + ae, _ := err.(*Error) + assert.Equals(t, ae.Type, tt.err.Type) + assert.Equals(t, ae.Detail, tt.err.Detail) + assert.Equals(t, ae.Identifier, tt.err.Identifier) + assert.Equals(t, ae.Subproblems, tt.err.Subproblems) + } else { + assert.Equals(t, eak.AccountID, acct.ID) + assert.Equals(t, eak.KeyBytes, []byte{}) + assert.NotNil(t, eak.BoundAt) + } + }) + } +} diff --git a/acme/api/account.go b/acme/api/account.go index 259cb2a2..0dc8ab40 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -12,9 +12,10 @@ import ( // NewAccountRequest represents the payload for a new account request. type NewAccountRequest struct { - Contact []string `json:"contact"` - OnlyReturnExisting bool `json:"onlyReturnExisting"` - TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"` + Contact []string `json:"contact"` + OnlyReturnExisting bool `json:"onlyReturnExisting"` + TermsOfServiceAgreed bool `json:"termsOfServiceAgreed"` + ExternalAccountBinding *ExternalAccountBinding `json:"externalAccountBinding,omitempty"` } func validateContacts(cs []string) error { @@ -83,8 +84,14 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } + prov, err := acmeProvisionerFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + httpStatus := http.StatusCreated - acc, err := accountFromContext(r.Context()) + acc, err := accountFromContext(ctx) if err != nil { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { @@ -99,12 +106,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { "account does not exist")) return } + jwk, err := jwkFromContext(ctx) if err != nil { api.WriteError(w, err) return } + eak, err := h.validateExternalAccountBinding(ctx, &nar) + if err != nil { + api.WriteError(w, err) + return + } + acc = &acme.Account{ Key: jwk, Contact: nar.Contact, @@ -114,8 +128,21 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { api.WriteError(w, acme.WrapErrorISE(err, "error creating account")) return } + + if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response + err := eak.BindTo(acc) + if err != nil { + api.WriteError(w, err) + return + } + if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) + return + } + acc.ExternalAccountBinding = nar.ExternalAccountBinding + } } else { - // Account exists // + // Account exists httpStatus = http.StatusOK } diff --git a/acme/api/account_test.go b/acme/api/account_test.go index abee97a2..4c3404ec 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/go-chi/chi" + "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner" @@ -40,6 +41,66 @@ func newProv() acme.Provisioner { return p } +func newACMEProv(t *testing.T) *provisioner.ACME { + p := newProv() + a, ok := p.(*provisioner.ACME) + if !ok { + t.Fatal("not a valid ACME provisioner") + } + return a +} + +func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) { + signer, err := jose.NewSigner( + jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm("HS256"), + Key: hmacKey, + }, + &jose.SignerOptions{ + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "kid": keyID, + "url": u, + }, + EmbedJWK: false, + }, + ) + if err != nil { + return nil, err + } + + jwkJSONBytes, err := jwk.Public().MarshalJSON() + if err != nil { + return nil, err + } + + jws, err := signer.Sign(jwkJSONBytes) + if err != nil { + return nil, err + } + + raw, err := jws.CompactSerialize() + if err != nil { + return nil, err + } + + parsedJWS, err := jose.ParseJWS(raw) + if err != nil { + return nil, err + } + + return parsedJWS, nil +} + +func createRawEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) ([]byte, error) { + jws, err := createEABJWS(jwk, hmacKey, keyID, u) + if err != nil { + return nil, err + } + + rawJWS := jws.FullSerialize() + return []byte(rawJWS), nil +} + func TestNewAccountRequest_Validate(t *testing.T) { type test struct { nar *NewAccountRequest @@ -290,6 +351,7 @@ func TestHandler_NewAccount(t *testing.T) { prov := newProv() escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + provID := prov.GetID() type test struct { db acme.DB @@ -343,6 +405,7 @@ func TestHandler_NewAccount(t *testing.T) { b, err := json.Marshal(nar) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ ctx: ctx, statusCode: 400, @@ -355,7 +418,8 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, @@ -368,7 +432,8 @@ func TestHandler_NewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, jwkContextKey, nil) return test{ ctx: ctx, @@ -376,6 +441,27 @@ func TestHandler_NewAccount(t *testing.T) { err: acme.NewErrorISE("jwk expected in request context"), } }, + "fail/new-account-no-eab-provided": func(t *testing.T) test { + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: nil, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + ctx: ctx, + statusCode: 400, + err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), + } + }, "fail/db.CreateAccount-error": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, @@ -385,6 +471,7 @@ func TestHandler_NewAccount(t *testing.T) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, jwkContextKey, jwk) return test{ db: &acme.MockDB{ @@ -399,6 +486,109 @@ func TestHandler_NewAccount(t *testing.T) { err: acme.NewErrorISE("force"), } }, + "fail/acmeProvisionerFromContext": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + scepProvisioner := &provisioner.SCEP{ + Type: "SCEP", + Name: "test@scep-provisioner.com", + } + if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { + assert.FatalError(t, err) + } + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), + } + }, + "fail/db.UpdateExternalAccountKey-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + eak := &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: time.Now(), + } + return test{ + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + acc.ID = "accountID" + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return nil + }, + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return eak, nil + }, + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerName string, eak *acme.ExternalAccountKey) error { + return errors.New("force") + }, + }, + acc: &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), + ExternalAccountBinding: eab, + }, + ctx: ctx, + statusCode: 500, + err: acme.NewError(acme.ErrorServerInternalType, "error updating external account binding key"), + } + }, "ok/new-account": func(t *testing.T) test { nar := &NewAccountRequest{ Contact: []string{"foo", "bar"}, @@ -455,6 +645,116 @@ func TestHandler_NewAccount(t *testing.T) { statusCode: 200, } }, + "ok/new-account-no-eab-required": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = false + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + acc.ID = "accountID" + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return nil + }, + }, + acc: &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), + }, + ctx: ctx, + statusCode: 201, + } + }, + "ok/new-account-with-eab": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { + acc.ID = "accountID" + assert.Equals(t, acc.Contact, nar.Contact) + assert.Equals(t, acc.Key, jwk) + return nil + }, + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: time.Now(), + }, nil + }, + MockUpdateExternalAccountKey: func(ctx context.Context, provisionerName string, eak *acme.ExternalAccountKey) error { + return nil + }, + }, + acc: &acme.Account{ + ID: "accountID", + Key: jwk, + Status: acme.StatusValid, + Contact: []string{"foo", "bar"}, + OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName), + ExternalAccountBinding: eab, + }, + ctx: ctx, + statusCode: 201, + } + }, } for name, run := range tests { tc := run(t) diff --git a/acme/api/eab.go b/acme/api/eab.go new file mode 100644 index 00000000..3660d066 --- /dev/null +++ b/acme/api/eab.go @@ -0,0 +1,155 @@ +package api + +import ( + "context" + "encoding/json" + + "github.com/smallstep/certificates/acme" + "go.step.sm/crypto/jose" +) + +// ExternalAccountBinding represents the ACME externalAccountBinding JWS +type ExternalAccountBinding struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Sig string `json:"signature"` +} + +// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. +func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { + acmeProv, err := acmeProvisionerFromContext(ctx) + if err != nil { + return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") + } + + if !acmeProv.RequireEAB { + return nil, nil + } + + if nar.ExternalAccountBinding == nil { + return nil, acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided") + } + + eabJSONBytes, err := json.Marshal(nar.ExternalAccountBinding) + if err != nil { + return nil, acme.WrapErrorISE(err, "error marshaling externalAccountBinding into bytes") + } + + eabJWS, err := jose.ParseJWS(string(eabJSONBytes)) + if err != nil { + return nil, acme.WrapErrorISE(err, "error parsing externalAccountBinding jws") + } + + // TODO(hs): implement strategy pattern to allow for different ways of verification (i.e. webhook call) based on configuration? + + keyID, acmeErr := validateEABJWS(ctx, eabJWS) + if acmeErr != nil { + return nil, acmeErr + } + + externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) + if err != nil { + if _, ok := err.(*acme.Error); ok { + return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") + } + return nil, acme.WrapErrorISE(err, "error retrieving external account key") + } + + if externalAccountKey.AlreadyBound() { + return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt) + } + + payload, err := eabJWS.Verify(externalAccountKey.KeyBytes) + if err != nil { + return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature") + } + + jwk, err := jwkFromContext(ctx) + if err != nil { + return nil, err + } + + var payloadJWK *jose.JSONWebKey + if err = json.Unmarshal(payload, &payloadJWK); err != nil { + return nil, acme.WrapError(acme.ErrorMalformedType, err, "error unmarshaling payload into jwk") + } + + if !keysAreEqual(jwk, payloadJWK) { + return nil, acme.NewError(acme.ErrorUnauthorizedType, "keys in jws and eab payload do not match") + } + + return externalAccountKey, nil +} + +// keysAreEqual performs an equality check on two JWKs by comparing +// the (base64 encoding) of the Key IDs. +func keysAreEqual(x, y *jose.JSONWebKey) bool { + if x == nil || y == nil { + return false + } + digestX, errX := acme.KeyToID(x) + digestY, errY := acme.KeyToID(y) + if errX != nil || errY != nil { + return false + } + return digestX == digestY +} + +// validateEABJWS verifies the contents of the External Account Binding JWS. +// The protected header of the JWS MUST meet the following criteria: +// o The "alg" field MUST indicate a MAC-based algorithm +// o The "kid" field MUST contain the key identifier provided by the CA +// o The "nonce" field MUST NOT be present +// o The "url" field MUST be set to the same value as the outer JWS +func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { + + if jws == nil { + return "", acme.NewErrorISE("no JWS provided") + } + + if len(jws.Signatures) != 1 { + return "", acme.NewError(acme.ErrorMalformedType, "JWS must have one signature") + } + + header := jws.Signatures[0].Protected + algorithm := header.Algorithm + keyID := header.KeyID + nonce := header.Nonce + + if !(algorithm == jose.HS256 || algorithm == jose.HS384 || algorithm == jose.HS512) { + return "", acme.NewError(acme.ErrorMalformedType, "'alg' field set to invalid algorithm '%s'", algorithm) + } + + if keyID == "" { + return "", acme.NewError(acme.ErrorMalformedType, "'kid' field is required") + } + + if nonce != "" { + return "", acme.NewError(acme.ErrorMalformedType, "'nonce' must not be present") + } + + jwsURL, ok := header.ExtraHeaders["url"] + if !ok { + return "", acme.NewError(acme.ErrorMalformedType, "'url' field is required") + } + + outerJWS, err := jwsFromContext(ctx) + if err != nil { + return "", acme.WrapErrorISE(err, "could not retrieve outer JWS from context") + } + + if len(outerJWS.Signatures) != 1 { + return "", acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature") + } + + outerJWSURL, ok := outerJWS.Signatures[0].Protected.ExtraHeaders["url"] + if !ok { + return "", acme.NewError(acme.ErrorMalformedType, "'url' field must be set in outer JWS") + } + + if jwsURL != outerJWSURL { + return "", acme.NewError(acme.ErrorMalformedType, "'url' field is not the same value as the outer JWS") + } + + return keyID, nil +} diff --git a/acme/api/eab_test.go b/acme/api/eab_test.go new file mode 100644 index 00000000..dce9f36d --- /dev/null +++ b/acme/api/eab_test.go @@ -0,0 +1,1068 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/crypto/jose" +) + +func Test_keysAreEqual(t *testing.T) { + jwkX, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + jwkY, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + wrongJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + wrongJWK.Key = struct{}{} + type args struct { + x *jose.JSONWebKey + y *jose.JSONWebKey + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "ok/nil", + args: args{ + x: jwkX, + y: nil, + }, + want: false, + }, + { + name: "ok/equal", + args: args{ + x: jwkX, + y: jwkX, + }, + want: true, + }, + { + name: "ok/not-equal", + args: args{ + x: jwkX, + y: jwkY, + }, + want: false, + }, + { + name: "ok/wrong-key-type", + args: args{ + x: wrongJWK, + y: jwkY, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := keysAreEqual(tt.args.x, tt.args.y); got != tt.want { + t.Errorf("keysAreEqual() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandler_validateExternalAccountBinding(t *testing.T) { + acmeProv := newACMEProv(t) + escProvName := url.PathEscape(acmeProv.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + provID := acmeProv.GetID() + type test struct { + db acme.DB + ctx context.Context + nar *NewAccountRequest + eak *acme.ExternalAccountKey + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "ok/no-eab-required-but-provided": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + prov := newACMEProv(t) + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + db: &acme.MockDB{}, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: nil, + } + }, + "ok/eab": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + createdAt := time.Now() + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: createdAt, + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: createdAt, + }, + err: nil, + } + }, + "fail/acmeProvisionerFromContext": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + b, err := json.Marshal(nar) + assert.FatalError(t, err) + scepProvisioner := &provisioner.SCEP{ + Type: "SCEP", + Name: "test@scep-provisioner.com", + } + if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { + assert.FatalError(t, err) + } + ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner) + return test{ + ctx: ctx, + err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), + } + }, + "fail/parse-eab-jose": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + eab.Payload += "{}" + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + db: &acme.MockDB{}, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewErrorISE("error parsing externalAccountBinding jws"), + } + }, + "fail/validate-eab-jws-no-signatures": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + parsedJWS.Signatures = []jose.Signature{} + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{}, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature"), + } + }, + "fail/retrieve-eab-key-db-failure": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockError: errors.New("db failure"), + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewErrorISE("error retrieving external account key"), + } + }, + "fail/db.GetExternalAccountKey-not-found": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return nil, acme.ErrNotFound + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewErrorISE("error retrieving external account key"), + } + }, + "fail/db.GetExternalAccountKey-error": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return nil, errors.New("force") + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewErrorISE("error retrieving external account key"), + } + }, + "fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockError: acme.NewError(acme.ErrorUnauthorizedType, "name of provisioner does not match provisioner for which the EAB key was created"), + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key: name of provisioner does not match provisioner for which the EAB key was created"), + } + }, + "fail/eab-already-bound": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + createdAt := time.Now() + boundAt := time.Now().Add(1 * time.Second) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + CreatedAt: createdAt, + AccountID: "some-account-id", + BoundAt: boundAt, + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", "eakID", "some-account-id", boundAt), + } + }, + "fail/eab-verify": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 2, 3, 4}, + CreatedAt: time.Now(), + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewErrorISE("error verifying externalAccountBinding signature"), + } + }, + "fail/eab-non-matching-keys": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + differentJWK, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(differentJWK, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, jwk) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: time.Now(), + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorUnauthorizedType, "keys in jws and eab payload do not match"), + } + }, + "fail/no-jwk": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: time.Now(), + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorServerInternalType, "jwk expected in request context"), + } + }, + "fail/nil-jwk": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal(rawEABJWS, &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + prov := newACMEProv(t) + prov.RequireEAB = true + ctx := context.WithValue(context.Background(), jwkContextKey, nil) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + db: &acme.MockDB{ + MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) { + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: provID, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: time.Now(), + }, nil + }, + }, + ctx: ctx, + nar: &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + }, + eak: nil, + err: acme.NewError(acme.ErrorServerInternalType, "jwk expected in request context"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + } + got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar) + wantErr := tc.err != nil + gotErr := err != nil + if wantErr != gotErr { + t.Errorf("Handler.validateExternalAccountBinding() error = %v, want %v", err, tc.err) + } + if wantErr { + assert.NotNil(t, err) + assert.Type(t, &acme.Error{}, err) + ae, _ := err.(*acme.Error) + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Status, tc.err.Status) + assert.HasPrefix(t, ae.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + } else { + if got == nil { + assert.Nil(t, tc.eak) + } else { + assert.NotNil(t, tc.eak) + assert.Equals(t, got.ID, tc.eak.ID) + assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID) + assert.Equals(t, got.Reference, tc.eak.Reference) + assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt) + assert.Equals(t, got.AccountID, tc.eak.AccountID) + assert.Equals(t, got.BoundAt, tc.eak.BoundAt) + } + } + }) + } +} + +func Test_validateEABJWS(t *testing.T) { + acmeProv := newACMEProv(t) + escProvName := url.PathEscape(acmeProv.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + type test struct { + ctx context.Context + jws *jose.JSONWebSignature + keyID string + err *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/nil-jws": func(t *testing.T) test { + return test{ + jws: nil, + err: acme.NewErrorISE("no JWS provided"), + } + }, + "fail/invalid-number-of-signatures": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eabJWS.Signatures = append(eabJWS.Signatures, jose.Signature{}) + return test{ + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "JWS must have one signature"), + } + }, + "fail/invalid-algorithm": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eabJWS.Signatures[0].Protected.Algorithm = "HS42" + return test{ + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "'alg' field set to invalid algorithm 'HS42'"), + } + }, + "fail/kid-not-set": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eabJWS.Signatures[0].Protected.KeyID = "" + return test{ + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "'kid' field is required"), + } + }, + "fail/nonce-not-empty": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + eabJWS.Signatures[0].Protected.Nonce = "some-bogus-nonce" + return test{ + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "'nonce' must not be present"), + } + }, + "fail/url-not-set": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + delete(eabJWS.Signatures[0].Protected.ExtraHeaders, "url") + return test{ + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "'url' field is required"), + } + }, + "fail/no-outer-jws": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + ctx := context.WithValue(context.TODO(), jwsContextKey, nil) + return test{ + ctx: ctx, + jws: eabJWS, + err: acme.NewErrorISE("could not retrieve outer JWS from context"), + } + }, + "fail/outer-jws-multiple-signatures": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + rawEABJWS := eabJWS.FullSerialize() + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal([]byte(rawEABJWS), &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + outerJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + outerJWS.Signatures = append(outerJWS.Signatures, jose.Signature{}) + ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS) + return test{ + ctx: ctx, + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "outer JWS must have one signature"), + } + }, + "fail/outer-jws-no-url": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + rawEABJWS := eabJWS.FullSerialize() + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal([]byte(rawEABJWS), &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + outerJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS) + return test{ + ctx: ctx, + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "'url' field must be set in outer JWS"), + } + }, + "fail/outer-jws-with-different-url": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + rawEABJWS := eabJWS.FullSerialize() + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal([]byte(rawEABJWS), &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", "this-is-not-the-same-url-as-in-the-eab-jws") + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + outerJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS) + return test{ + ctx: ctx, + jws: eabJWS, + err: acme.NewError(acme.ErrorMalformedType, "'url' field is not the same value as the outer JWS"), + } + }, + "ok": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName) + eabJWS, err := createEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url) + assert.FatalError(t, err) + rawEABJWS := eabJWS.FullSerialize() + assert.FatalError(t, err) + eab := &ExternalAccountBinding{} + err = json.Unmarshal([]byte(rawEABJWS), &eab) + assert.FatalError(t, err) + nar := &NewAccountRequest{ + Contact: []string{"foo", "bar"}, + ExternalAccountBinding: eab, + } + payloadBytes, err := json.Marshal(nar) + assert.FatalError(t, err) + so := new(jose.SignerOptions) + so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm)) + so.WithHeader("url", url) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign(payloadBytes) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + outerJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + ctx := context.WithValue(context.TODO(), jwsContextKey, outerJWS) + return test{ + ctx: ctx, + jws: eabJWS, + keyID: "eakID", + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + keyID, err := validateEABJWS(tc.ctx, tc.jws) + wantErr := tc.err != nil + gotErr := err != nil + if wantErr != gotErr { + t.Errorf("validateEABJWS() error = %v, want %v", err, tc.err) + } + if wantErr { + assert.NotNil(t, err) + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Status, err.Status) + assert.HasPrefix(t, err.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, tc.err.Identifier, err.Identifier) + assert.Equals(t, tc.err.Subproblems, err.Subproblems) + } else { + assert.Nil(t, err) + assert.Equals(t, tc.keyID, keyID) + } + }) + } +} diff --git a/acme/api/handler.go b/acme/api/handler.go index 09ca03a3..bd226e73 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -136,6 +136,13 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { } } +type Meta struct { + TermsOfService string `json:"termsOfService,omitempty"` + Website string `json:"website,omitempty"` + CaaIdentities []string `json:"caaIdentities,omitempty"` + ExternalAccountRequired bool `json:"externalAccountRequired,omitempty"` +} + // Directory represents an ACME directory for configuring clients. type Directory struct { NewNonce string `json:"newNonce"` @@ -143,6 +150,7 @@ type Directory struct { NewOrder string `json:"newOrder"` RevokeCert string `json:"revokeCert"` KeyChange string `json:"keyChange"` + Meta Meta `json:"meta"` } // ToLog enables response logging for the Directory type. @@ -158,12 +166,21 @@ func (d *Directory) ToLog() (interface{}, error) { // for client configuration. func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + acmeProv, err := acmeProvisionerFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + api.JSON(w, &Directory{ NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), + Meta: Meta{ + ExternalAccountRequired: acmeProv.RequireEAB, + }, }) } diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 14e00f12..67f7df30 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -15,9 +15,11 @@ import ( "time" "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" ) @@ -51,28 +53,76 @@ func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) { linker := NewLinker("ca.smallstep.com", "acme") - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) - - expDir := Directory{ - NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), - NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), - NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), - RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), - KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), - } - type test struct { + ctx context.Context statusCode int + dir Directory err *acme.Error } var tests = map[string]func(t *testing.T) test{ - "ok": func(t *testing.T) test { + "fail/no-provisioner": func(t *testing.T) test { + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, nil) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + } + }, + "fail/different-provisioner": func(t *testing.T) test { + prov := &provisioner.SCEP{ + Type: "SCEP", + Name: "test@scep-provisioner.com", + } + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), + } + }, + "ok": func(t *testing.T) test { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + expDir := Directory{ + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), + } + return test{ + ctx: ctx, + dir: expDir, + statusCode: 200, + } + }, + "ok/eab-required": func(t *testing.T) test { + prov := newACMEProv(t) + prov.RequireEAB = true + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + expDir := Directory{ + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), + Meta: Meta{ + ExternalAccountRequired: true, + }, + } + return test{ + ctx: ctx, + dir: expDir, statusCode: 200, } }, @@ -82,7 +132,7 @@ func TestHandler_GetDirectory(t *testing.T) { t.Run(name, func(t *testing.T) { h := &Handler{linker: linker} req := httptest.NewRequest("GET", "/foo/bar", nil) - req = req.WithContext(ctx) + req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetDirectory(w, req) res := w.Result() @@ -105,7 +155,9 @@ func TestHandler_GetDirectory(t *testing.T) { } else { var dir Directory json.Unmarshal(bytes.TrimSpace(body), &dir) - assert.Equals(t, dir, expDir) + if !cmp.Equal(tc.dir, dir) { + t.Errorf("GetDirectory() diff =\n%s", cmp.Diff(tc.dir, dir)) + } assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/api/middleware.go b/acme/api/middleware.go index d701f240..de8614ee 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -507,6 +507,20 @@ func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { return pval, nil } +// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns +// pointer to an ACME provisioner or an error. +func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { + prov, err := provisionerFromContext(ctx) + if err != nil { + return nil, err + } + acmeProv, ok := prov.(*provisioner.ACME) + if !ok || acmeProv == nil { + return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") + } + return acmeProv, nil +} + // payloadFromContext searches the context for a payload. Returns the payload // or an error. func payloadFromContext(ctx context.Context) (*payloadInfo, error) { diff --git a/acme/db.go b/acme/db.go index 1675c7e7..412276fd 100644 --- a/acme/db.go +++ b/acme/db.go @@ -19,6 +19,13 @@ type DB interface { GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) UpdateAccount(ctx context.Context, acc *Account) error + CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) + GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) + GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error + UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error + CreateNonce(ctx context.Context) (Nonce, error) DeleteNonce(ctx context.Context, nonce Nonce) error @@ -49,6 +56,13 @@ type MockDB struct { MockGetAccountByKeyID func(ctx context.Context, kid string) (*Account, error) MockUpdateAccount func(ctx context.Context, acc *Account) error + MockCreateExternalAccountKey func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) + MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) + MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) + MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error + MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error + MockCreateNonce func(ctx context.Context) (Nonce, error) MockDeleteNonce func(ctx context.Context, nonce Nonce) error @@ -114,6 +128,66 @@ func (m *MockDB) UpdateAccount(ctx context.Context, acc *Account) error { return m.MockError } +// CreateExternalAccountKey mock +func (m *MockDB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) { + if m.MockCreateExternalAccountKey != nil { + return m.MockCreateExternalAccountKey(ctx, provisionerID, reference) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*ExternalAccountKey), m.MockError +} + +// GetExternalAccountKey mock +func (m *MockDB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) { + if m.MockGetExternalAccountKey != nil { + return m.MockGetExternalAccountKey(ctx, provisionerID, keyID) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*ExternalAccountKey), m.MockError +} + +// GetExternalAccountKeys mock +func (m *MockDB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) { + if m.MockGetExternalAccountKeys != nil { + return m.MockGetExternalAccountKeys(ctx, provisionerID, cursor, limit) + } else if m.MockError != nil { + return nil, "", m.MockError + } + return m.MockRet1.([]*ExternalAccountKey), "", m.MockError +} + +// GetExternalAccountKeyByReference mock +func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) { + if m.MockGetExternalAccountKeyByReference != nil { + return m.MockGetExternalAccountKeyByReference(ctx, provisionerID, reference) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*ExternalAccountKey), m.MockError +} + +// DeleteExternalAccountKey mock +func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { + if m.MockDeleteExternalAccountKey != nil { + return m.MockDeleteExternalAccountKey(ctx, provisionerID, keyID) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + +// UpdateExternalAccountKey mock +func (m *MockDB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error { + if m.MockUpdateExternalAccountKey != nil { + return m.MockUpdateExternalAccountKey(ctx, provisionerID, eak) + } else if m.MockError != nil { + return m.MockError + } + return m.MockError +} + // CreateNonce mock func (m *MockDB) CreateNonce(ctx context.Context) (Nonce, error) { if m.MockCreateNonce != nil { diff --git a/acme/db/nosql/account_test.go b/acme/db/nosql/account_test.go index a02e93dc..83a23476 100644 --- a/acme/db/nosql/account_test.go +++ b/acme/db/nosql/account_test.go @@ -307,7 +307,7 @@ func TestDB_GetAccountByKeyID(t *testing.T) { assert.Equals(t, string(key), accID) return nil, errors.New("force") default: - assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, errors.New("force") } }, @@ -340,7 +340,7 @@ func TestDB_GetAccountByKeyID(t *testing.T) { assert.Equals(t, string(key), accID) return b, nil default: - assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, errors.New("force") } }, @@ -462,7 +462,7 @@ func TestDB_CreateAccount(t *testing.T) { assert.True(t, dbacc.DeactivatedAt.IsZero()) return nil, false, errors.New("force") default: - assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force") } }, @@ -506,7 +506,7 @@ func TestDB_CreateAccount(t *testing.T) { assert.True(t, dbacc.DeactivatedAt.IsZero()) return nu, true, nil default: - assert.FatalError(t, errors.Errorf("unrecognized bucket %s", string(bucket))) + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) return nil, false, errors.New("force") } }, diff --git a/acme/db/nosql/eab.go b/acme/db/nosql/eab.go new file mode 100644 index 00000000..f9a24daf --- /dev/null +++ b/acme/db/nosql/eab.go @@ -0,0 +1,380 @@ +package nosql + +import ( + "context" + "crypto/rand" + "encoding/json" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/acme" + nosqlDB "github.com/smallstep/nosql" +) + +// externalAccountKeyMutex for read/write locking of EAK operations. +var externalAccountKeyMutex sync.RWMutex + +// referencesByProvisionerIndexMutex for locking referencesByProvisioner index operations. +var referencesByProvisionerIndexMutex sync.Mutex + +type dbExternalAccountKey struct { + ID string `json:"id"` + ProvisionerID string `json:"provisionerID"` + Reference string `json:"reference"` + AccountID string `json:"accountID,omitempty"` + KeyBytes []byte `json:"key"` + CreatedAt time.Time `json:"createdAt"` + BoundAt time.Time `json:"boundAt"` +} + +type dbExternalAccountKeyReference struct { + Reference string `json:"reference"` + ExternalAccountKeyID string `json:"externalAccountKeyID"` +} + +// getDBExternalAccountKey retrieves and unmarshals dbExternalAccountKey. +func (db *DB) getDBExternalAccountKey(ctx context.Context, id string) (*dbExternalAccountKey, error) { + data, err := db.db.Get(externalAccountKeyTable, []byte(id)) + if err != nil { + if nosqlDB.IsErrNotFound(err) { + return nil, acme.ErrNotFound + } + return nil, errors.Wrapf(err, "error loading external account key %s", id) + } + + dbeak := new(dbExternalAccountKey) + if err = json.Unmarshal(data, dbeak); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling external account key %s into dbExternalAccountKey", id) + } + + return dbeak, nil +} + +// CreateExternalAccountKey creates a new External Account Binding key with a name +func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + + externalAccountKeyMutex.Lock() + defer externalAccountKeyMutex.Unlock() + + keyID, err := randID() + if err != nil { + return nil, err + } + + random := make([]byte, 32) + _, err = rand.Read(random) + if err != nil { + return nil, err + } + + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provisionerID, + Reference: reference, + KeyBytes: random, + CreatedAt: clock.Now(), + } + + if err := db.save(ctx, keyID, dbeak, nil, "external_account_key", externalAccountKeyTable); err != nil { + return nil, err + } + + if err := db.addEAKID(ctx, provisionerID, dbeak.ID); err != nil { + return nil, err + } + + if dbeak.Reference != "" { + dbExternalAccountKeyReference := &dbExternalAccountKeyReference{ + Reference: dbeak.Reference, + ExternalAccountKeyID: dbeak.ID, + } + if err := db.save(ctx, referenceKey(provisionerID, dbeak.Reference), dbExternalAccountKeyReference, nil, "external_account_key_reference", externalAccountKeyIDsByReferenceTable); err != nil { + return nil, err + } + } + + return &acme.ExternalAccountKey{ + ID: dbeak.ID, + ProvisionerID: dbeak.ProvisionerID, + Reference: dbeak.Reference, + AccountID: dbeak.AccountID, + KeyBytes: dbeak.KeyBytes, + CreatedAt: dbeak.CreatedAt, + BoundAt: dbeak.BoundAt, + }, nil +} + +// GetExternalAccountKey retrieves an External Account Binding key by KeyID +func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) { + externalAccountKeyMutex.RLock() + defer externalAccountKeyMutex.RUnlock() + + dbeak, err := db.getDBExternalAccountKey(ctx, keyID) + if err != nil { + return nil, err + } + + if dbeak.ProvisionerID != provisionerID { + return nil, acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created") + } + + return &acme.ExternalAccountKey{ + ID: dbeak.ID, + ProvisionerID: dbeak.ProvisionerID, + Reference: dbeak.Reference, + AccountID: dbeak.AccountID, + KeyBytes: dbeak.KeyBytes, + CreatedAt: dbeak.CreatedAt, + BoundAt: dbeak.BoundAt, + }, nil +} + +func (db *DB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { + externalAccountKeyMutex.Lock() + defer externalAccountKeyMutex.Unlock() + + dbeak, err := db.getDBExternalAccountKey(ctx, keyID) + if err != nil { + return errors.Wrapf(err, "error loading ACME EAB Key with Key ID %s", keyID) + } + + if dbeak.ProvisionerID != provisionerID { + return errors.New("provisioner does not match provisioner for which the EAB key was created") + } + + if dbeak.Reference != "" { + if err := db.db.Del(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, dbeak.Reference))); err != nil { + return errors.Wrapf(err, "error deleting ACME EAB Key reference with Key ID %s and reference %s", keyID, dbeak.Reference) + } + } + if err := db.db.Del(externalAccountKeyTable, []byte(keyID)); err != nil { + return errors.Wrapf(err, "error deleting ACME EAB Key with Key ID %s", keyID) + } + if err := db.deleteEAKID(ctx, provisionerID, keyID); err != nil { + return errors.Wrapf(err, "error removing ACME EAB Key ID %s", keyID) + } + + return nil +} + +// GetExternalAccountKeys retrieves all External Account Binding keys for a provisioner +func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) { + externalAccountKeyMutex.RLock() + defer externalAccountKeyMutex.RUnlock() + + // cursor and limit are ignored in open source, at least for now. + + var eakIDs []string + r, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID)) + if err != nil { + if !nosqlDB.IsErrNotFound(err) { + return nil, "", errors.Wrapf(err, "error loading ACME EAB Key IDs for provisioner %s", provisionerID) + } + // it may happen that no record is found; we'll continue with an empty slice + } else { + if err := json.Unmarshal(r, &eakIDs); err != nil { + return nil, "", errors.Wrapf(err, "error unmarshaling ACME EAB Key IDs for provisioner %s", provisionerID) + } + } + + keys := []*acme.ExternalAccountKey{} + for _, eakID := range eakIDs { + if eakID == "" { + continue // shouldn't happen; just in case + } + eak, err := db.getDBExternalAccountKey(ctx, eakID) + if err != nil { + if !nosqlDB.IsErrNotFound(err) { + return nil, "", errors.Wrapf(err, "error retrieving ACME EAB Key for provisioner %s and keyID %s", provisionerID, eakID) + } + } + keys = append(keys, &acme.ExternalAccountKey{ + ID: eak.ID, + KeyBytes: eak.KeyBytes, + ProvisionerID: eak.ProvisionerID, + Reference: eak.Reference, + AccountID: eak.AccountID, + CreatedAt: eak.CreatedAt, + BoundAt: eak.BoundAt, + }) + } + + return keys, "", nil +} + +// GetExternalAccountKeyByReference retrieves an External Account Binding key with unique reference +func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + externalAccountKeyMutex.RLock() + defer externalAccountKeyMutex.RUnlock() + + if reference == "" { + return nil, nil + } + + k, err := db.db.Get(externalAccountKeyIDsByReferenceTable, []byte(referenceKey(provisionerID, reference))) + if nosqlDB.IsErrNotFound(err) { + return nil, acme.ErrNotFound + } else if err != nil { + return nil, errors.Wrapf(err, "error loading ACME EAB key for reference %s", reference) + } + dbExternalAccountKeyReference := new(dbExternalAccountKeyReference) + if err := json.Unmarshal(k, dbExternalAccountKeyReference); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling ACME EAB key for reference %s", reference) + } + + return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID) +} + +func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { + externalAccountKeyMutex.Lock() + defer externalAccountKeyMutex.Unlock() + + old, err := db.getDBExternalAccountKey(ctx, eak.ID) + if err != nil { + return err + } + + if old.ProvisionerID != provisionerID { + return errors.New("provisioner does not match provisioner for which the EAB key was created") + } + + if old.ProvisionerID != eak.ProvisionerID { + return errors.New("cannot change provisioner for an existing ACME EAB Key") + } + + if old.Reference != eak.Reference { + return errors.New("cannot change reference for an existing ACME EAB Key") + } + + nu := dbExternalAccountKey{ + ID: eak.ID, + ProvisionerID: eak.ProvisionerID, + Reference: eak.Reference, + AccountID: eak.AccountID, + KeyBytes: eak.KeyBytes, + CreatedAt: eak.CreatedAt, + BoundAt: eak.BoundAt, + } + + return db.save(ctx, nu.ID, nu, old, "external_account_key", externalAccountKeyTable) +} + +func (db *DB) addEAKID(ctx context.Context, provisionerID, eakID string) error { + referencesByProvisionerIndexMutex.Lock() + defer referencesByProvisionerIndexMutex.Unlock() + + if eakID == "" { + return errors.Errorf("can't add empty eakID for provisioner %s", provisionerID) + } + + var eakIDs []string + b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID)) + if err != nil { + if !nosqlDB.IsErrNotFound(err) { + return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID) + } + // it may happen that no record is found; we'll continue with an empty slice + } else { + if err := json.Unmarshal(b, &eakIDs); err != nil { + return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID) + } + } + + for _, id := range eakIDs { + if id == eakID { + // return an error when a duplicate ID is found + return errors.Errorf("eakID %s already exists for provisioner %s", eakID, provisionerID) + } + } + + var newEAKIDs []string + newEAKIDs = append(newEAKIDs, eakIDs...) + newEAKIDs = append(newEAKIDs, eakID) + + var ( + _old interface{} = eakIDs + _new interface{} = newEAKIDs + ) + + // ensure that the DB gets the expected value when the slice is empty; otherwise + // it'll return with an error that indicates that the DBs view of the data is + // different from the last read (i.e. _old is different from what the DB has). + if len(eakIDs) == 0 { + _old = nil + } + + if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil { + return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID) + } + + return nil +} + +func (db *DB) deleteEAKID(ctx context.Context, provisionerID, eakID string) error { + referencesByProvisionerIndexMutex.Lock() + defer referencesByProvisionerIndexMutex.Unlock() + + var eakIDs []string + b, err := db.db.Get(externalAccountKeyIDsByProvisionerIDTable, []byte(provisionerID)) + if err != nil { + if !nosqlDB.IsErrNotFound(err) { + return errors.Wrapf(err, "error loading eakIDs for provisioner %s", provisionerID) + } + // it may happen that no record is found; we'll continue with an empty slice + } else { + if err := json.Unmarshal(b, &eakIDs); err != nil { + return errors.Wrapf(err, "error unmarshaling eakIDs for provisioner %s", provisionerID) + } + } + + newEAKIDs := removeElement(eakIDs, eakID) + var ( + _old interface{} = eakIDs + _new interface{} = newEAKIDs + ) + + // ensure that the DB gets the expected value when the slice is empty; otherwise + // it'll return with an error that indicates that the DBs view of the data is + // different from the last read (i.e. _old is different from what the DB has). + if len(eakIDs) == 0 { + _old = nil + } + + if err = db.save(ctx, provisionerID, _new, _old, "externalAccountKeyIDsByProvisionerID", externalAccountKeyIDsByProvisionerIDTable); err != nil { + return errors.Wrapf(err, "error saving eakIDs index for provisioner %s", provisionerID) + } + + return nil +} + +// referenceKey returns a unique key for a reference per provisioner +func referenceKey(provisionerID, reference string) string { + return provisionerID + "." + reference +} + +// sliceIndex finds the index of item in slice +func sliceIndex(slice []string, item string) int { + for i := range slice { + if slice[i] == item { + return i + } + } + return -1 +} + +// removeElement deletes the item if it exists in the +// slice. It returns a new slice, keeping the old one intact. +func removeElement(slice []string, item string) []string { + + newSlice := make([]string, 0) + index := sliceIndex(slice, item) + if index < 0 { + newSlice = append(newSlice, slice...) + return newSlice + } + + newSlice = append(newSlice, slice[:index]...) + + return append(newSlice, slice[index+1:]...) +} diff --git a/acme/db/nosql/eab_test.go b/acme/db/nosql/eab_test.go new file mode 100644 index 00000000..568500e9 --- /dev/null +++ b/acme/db/nosql/eab_test.go @@ -0,0 +1,1712 @@ +package nosql + +import ( + "context" + "encoding/json" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + certdb "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBExternalAccountKey(t *testing.T) { + keyID := "keyID" + provID := "provID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + dbeak *dbExternalAccountKey + } + var tests = map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: "ref", + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return b, nil + }, + }, + err: nil, + dbeak: dbeak, + } + }, + "fail/not-found": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return nil, nosqldb.ErrNotFound + }, + }, + err: acme.ErrNotFound, + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading external account key keyID: force"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling external account key keyID into dbExternalAccountKey"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if dbeak, err := d.getDBExternalAccountKey(context.Background(), keyID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, dbeak.ID, tc.dbeak.ID) + assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes) + assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID) + assert.Equals(t, dbeak.Reference, tc.dbeak.Reference) + assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt) + assert.Equals(t, dbeak.AccountID, tc.dbeak.AccountID) + assert.Equals(t, dbeak.BoundAt, tc.dbeak.BoundAt) + } + }) + } +} + +func TestDB_GetExternalAccountKey(t *testing.T) { + keyID := "keyID" + provID := "provID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + eak *acme.ExternalAccountKey + } + var tests = map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: "ref", + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return b, nil + }, + }, + eak: &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: "ref", + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading external account key keyID: force"), + } + }, + "fail/non-matching-provisioner": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: "aDifferentProvID", + Reference: "ref", + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return b, nil + }, + }, + eak: &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: "ref", + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, + acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if eak, err := d.GetExternalAccountKey(context.Background(), provID, keyID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, eak.ID, tc.eak.ID) + assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) + assert.Equals(t, eak.Reference, tc.eak.Reference) + assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) + assert.Equals(t, eak.AccountID, tc.eak.AccountID) + assert.Equals(t, eak.BoundAt, tc.eak.BoundAt) + } + }) + } +} + +func TestDB_GetExternalAccountKeyByReference(t *testing.T) { + keyID := "keyID" + provID := "provID" + ref := "ref" + type test struct { + db nosql.DB + err error + ref string + acmeErr *acme.Error + eak *acme.ExternalAccountKey + } + var tests = map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + dbref := &dbExternalAccountKeyReference{ + Reference: ref, + ExternalAccountKeyID: keyID, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + dbrefBytes, err := json.Marshal(dbref) + assert.FatalError(t, err) + return test{ + ref: ref, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return dbrefBytes, nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + eak: &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, + err: nil, + } + }, + "ok/no-reference": func(t *testing.T) test { + return test{ + ref: "", + eak: nil, + err: nil, + } + }, + "fail/reference-not-found": func(t *testing.T) test { + return test{ + ref: ref, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByReferenceTable)) + assert.Equals(t, string(key), provID+"."+ref) + return nil, nosqldb.ErrNotFound + }, + }, + err: errors.New("not found"), + } + }, + "fail/reference-load-error": func(t *testing.T) test { + return test{ + ref: ref, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByReferenceTable)) + assert.Equals(t, string(key), provID+"."+ref) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading ACME EAB key for reference ref: force"), + } + }, + "fail/reference-unmarshal-error": func(t *testing.T) test { + return test{ + ref: ref, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByReferenceTable)) + assert.Equals(t, string(key), provID+"."+ref) + return []byte{0}, nil + }, + }, + err: errors.New("error unmarshaling ACME EAB key for reference ref"), + } + }, + "fail/db.GetExternalAccountKey-error": func(t *testing.T) test { + dbref := &dbExternalAccountKeyReference{ + Reference: ref, + ExternalAccountKeyID: keyID, + } + dbrefBytes, err := json.Marshal(dbref) + assert.FatalError(t, err) + return test{ + ref: ref, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return dbrefBytes, nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + err: errors.New("error loading external account key keyID: force"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if eak, err := d.GetExternalAccountKeyByReference(context.Background(), provID, tc.ref); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && tc.eak != nil { + assert.Equals(t, eak.ID, tc.eak.ID) + assert.Equals(t, eak.AccountID, tc.eak.AccountID) + assert.Equals(t, eak.BoundAt, tc.eak.BoundAt) + assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) + assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) + assert.Equals(t, eak.Reference, tc.eak.Reference) + } + }) + } +} + +func TestDB_GetExternalAccountKeys(t *testing.T) { + keyID1 := "keyID1" + keyID2 := "keyID2" + keyID3 := "keyID3" + provID := "provID" + ref := "ref" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + eaks []*acme.ExternalAccountKey + } + var tests = map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + now := clock.Now() + dbeak1 := &dbExternalAccountKey{ + ID: keyID1, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b1, err := json.Marshal(dbeak1) + assert.FatalError(t, err) + dbeak2 := &dbExternalAccountKey{ + ID: keyID2, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b2, err := json.Marshal(dbeak2) + assert.FatalError(t, err) + dbeak3 := &dbExternalAccountKey{ + ID: keyID3, + ProvisionerID: "aDifferentProvID", + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b3, err := json.Marshal(dbeak3) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByProvisionerIDTable): + keys := []string{"", keyID1, keyID2} // includes an empty keyID + b, err := json.Marshal(keys) + assert.FatalError(t, err) + return b, nil + case string(externalAccountKeyTable): + switch string(key) { + case keyID1: + return b1, nil + case keyID2: + return b2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force default") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force default") + } + }, + // TODO: remove the MList + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + switch string(bucket) { + case string(externalAccountKeyTable): + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(keyID1), + Value: b1, + }, + { + Bucket: bucket, + Key: []byte(keyID2), + Value: b2, + }, + { + Bucket: bucket, + Key: []byte(keyID3), + Value: b3, + }, + }, nil + case string(externalAccountKeyIDsByProvisionerIDTable): + keys := []string{keyID1, keyID2} + b, err := json.Marshal(keys) + assert.FatalError(t, err) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(provID), + Value: b, + }, + }, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force default") + } + }, + }, + eaks: []*acme.ExternalAccountKey{ + { + ID: keyID1, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, + { + ID: keyID2, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, + }, + } + }, + "fail/db.Get-externalAccountKeysByProvisionerIDTable": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading ACME EAB Key IDs for provisioner provID: force"), + } + }, + "fail/db.Get-externalAccountKeysByProvisionerIDTable-unmarshal": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) + b, _ := json.Marshal(1) + return b, nil + }, + }, + err: errors.New("error unmarshaling ACME EAB Key IDs for provisioner provID: json: cannot unmarshal number into Go value of type []string"), + } + }, + "fail/db.getDBExternalAccountKey": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByProvisionerIDTable): + keys := []string{keyID1, keyID2} + b, err := json.Marshal(keys) + assert.FatalError(t, err) + return b, nil + case string(externalAccountKeyTable): + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force bucket") + } + }, + }, + err: errors.New("error retrieving ACME EAB Key for provisioner provID and keyID keyID1: error loading external account key keyID1: force"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + cursor, limit := "", 0 + if eaks, nextCursor, err := d.GetExternalAccountKeys(context.Background(), provID, cursor, limit); err != nil { + assert.Equals(t, "", nextCursor) + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.Equals(t, tc.err.Error(), err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, len(eaks), len(tc.eaks)) + assert.Equals(t, "", nextCursor) + for i, eak := range eaks { + assert.Equals(t, eak.ID, tc.eaks[i].ID) + assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes) + assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID) + assert.Equals(t, eak.Reference, tc.eaks[i].Reference) + assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt) + assert.Equals(t, eak.AccountID, tc.eaks[i].AccountID) + assert.Equals(t, eak.BoundAt, tc.eaks[i].BoundAt) + } + } + }) + } +} + +func TestDB_DeleteExternalAccountKey(t *testing.T) { + keyID := "keyID" + provID := "provID" + ref := "ref" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + dbref := &dbExternalAccountKeyReference{ + Reference: ref, + ExternalAccountKeyID: keyID, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + dbrefBytes, err := json.Marshal(dbref) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return dbrefBytes, nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return b, nil + case string(externalAccountKeyIDsByProvisionerIDTable): + assert.Equals(t, provID, string(key)) + b, err := json.Marshal([]string{keyID}) + assert.FatalError(t, err) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force default") + } + }, + MDel: func(bucket, key []byte) error { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return errors.New("force default") + } + }, + MCmpAndSwap: func(bucket, key, old, new []byte) ([]byte, bool, error) { + fmt.Println(string(bucket)) + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, provID+"."+ref, string(key)) + return nil, true, nil + case string(externalAccountKeyIDsByProvisionerIDTable): + assert.Equals(t, provID, string(key)) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force default") + } + }, + }, + } + }, + "fail/not-found": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyTable)) + assert.Equals(t, string(key), keyID) + return nil, nosqldb.ErrNotFound + }, + }, + err: errors.New("error loading ACME EAB Key with Key ID keyID: not found"), + } + }, + "fail/non-matching-provisioner": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: "aDifferentProvID", + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyTable)) + assert.Equals(t, string(key), keyID) + return b, nil + }, + }, + err: errors.New("provisioner does not match provisioner for which the EAB key was created"), + } + }, + "fail/delete-reference": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + dbref := &dbExternalAccountKeyReference{ + Reference: ref, + ExternalAccountKeyID: keyID, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + dbrefBytes, err := json.Marshal(dbref) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), ref) + return dbrefBytes, nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force default") + } + }, + MDel: func(bucket, key []byte) error { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return errors.New("force") + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return errors.New("force default") + } + }, + }, + err: errors.New("error deleting ACME EAB Key reference with Key ID keyID and reference ref: force"), + } + }, + "fail/delete-eak": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + dbref := &dbExternalAccountKeyReference{ + Reference: ref, + ExternalAccountKeyID: keyID, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + dbrefBytes, err := json.Marshal(dbref) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), ref) + return dbrefBytes, nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return b, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force default") + } + }, + MDel: func(bucket, key []byte) error { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return errors.New("force default") + } + }, + }, + err: errors.New("error deleting ACME EAB Key with Key ID keyID: force"), + } + }, + "fail/delete-eakID": func(t *testing.T) test { + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + dbref := &dbExternalAccountKeyReference{ + Reference: ref, + ExternalAccountKeyID: keyID, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + dbrefBytes, err := json.Marshal(dbref) + assert.FatalError(t, err) + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), ref) + return dbrefBytes, nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return b, nil + case string(externalAccountKeyIDsByProvisionerIDTable): + return b, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force default") + } + }, + MDel: func(bucket, key []byte) error { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), provID+"."+ref) + return nil + case string(externalAccountKeyTable): + assert.Equals(t, string(key), keyID) + return nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return errors.New("force default") + } + }, + }, + err: errors.New("error removing ACME EAB Key ID keyID: error loading eakIDs for provisioner provID: force"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if err := d.DeleteExternalAccountKey(context.Background(), provID, keyID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.Equals(t, err.Error(), tc.err.Error()) + } + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestDB_CreateExternalAccountKey(t *testing.T) { + keyID := "keyID" + provID := "provID" + ref := "ref" + type test struct { + db nosql.DB + err error + _id *string + eak *acme.ExternalAccountKey + } + var tests = map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + var ( + id string + idPtr = &id + ) + now := clock.Now() + eak := &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: "ref", + AccountID: "", + CreatedAt: now, + } + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) + assert.Equals(t, provID, string(key)) + b, _ := json.Marshal([]string{}) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByProvisionerIDTable): + assert.Equals(t, provID, string(key)) + return nu, true, nil + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, provID+"."+ref, string(key)) + assert.Equals(t, nil, old) + return nu, true, nil + case string(externalAccountKeyTable): + assert.Equals(t, nil, old) + + id = string(key) + + dbeak := new(dbExternalAccountKey) + assert.FatalError(t, json.Unmarshal(nu, dbeak)) + assert.Equals(t, string(key), dbeak.ID) + assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID) + assert.Equals(t, eak.Reference, dbeak.Reference) + assert.Equals(t, 32, len(dbeak.KeyBytes)) + assert.False(t, dbeak.CreatedAt.IsZero()) + assert.Equals(t, dbeak.AccountID, eak.AccountID) + assert.True(t, dbeak.BoundAt.IsZero()) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force default") + } + }, + }, + eak: eak, + _id: idPtr, + } + }, + "fail/externalAccountKeyID-cmpAndSwap-error": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), ref) + assert.Equals(t, old, nil) + return nu, true, nil + case string(externalAccountKeyTable): + assert.Equals(t, old, nil) + return nu, true, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force default") + } + }, + }, + err: errors.New("error saving acme external_account_key: force"), + } + }, + "fail/addEAKID-error": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) + assert.Equals(t, provID, string(key)) + return nil, errors.New("force") + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, string(key), ref) + assert.Equals(t, old, nil) + return nu, true, nil + case string(externalAccountKeyTable): + assert.Equals(t, old, nil) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force default") + } + }, + }, + err: errors.New("error loading eakIDs for provisioner provID: force"), + } + }, + "fail/externalAccountKeyReference-cmpAndSwap-error": func(t *testing.T) test { + return test{ + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, string(bucket), string(externalAccountKeyIDsByProvisionerIDTable)) + assert.Equals(t, provID, string(key)) + b, _ := json.Marshal([]string{}) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(externalAccountKeyIDsByProvisionerIDTable): + assert.Equals(t, provID, string(key)) + return nu, true, nil + case string(externalAccountKeyIDsByReferenceTable): + assert.Equals(t, provID+"."+ref, string(key)) + assert.Equals(t, old, nil) + return nu, true, errors.New("force") + case string(externalAccountKeyTable): + assert.Equals(t, old, nil) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force default") + } + }, + }, + err: errors.New("error saving acme external_account_key_reference: force"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + eak, err := d.CreateExternalAccountKey(context.Background(), provID, ref) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.Equals(t, err.Error(), tc.err.Error()) + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, *tc._id, eak.ID) + assert.Equals(t, provID, eak.ProvisionerID) + assert.Equals(t, ref, eak.Reference) + assert.Equals(t, "", eak.AccountID) + assert.False(t, eak.CreatedAt.IsZero()) + assert.False(t, eak.AlreadyBound()) + assert.True(t, eak.BoundAt.IsZero()) + } + }) + } +} + +func TestDB_UpdateExternalAccountKey(t *testing.T) { + keyID := "keyID" + provID := "provID" + ref := "ref" + now := clock.Now() + dbeak := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(dbeak) + assert.FatalError(t, err) + type test struct { + db nosql.DB + eak *acme.ExternalAccountKey + err error + } + var tests = map[string]func(t *testing.T) test{ + + "ok": func(t *testing.T) test { + eak := &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + return test{ + eak: eak, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, old, b) + + dbNew := new(dbExternalAccountKey) + assert.FatalError(t, json.Unmarshal(nu, dbNew)) + assert.Equals(t, dbNew.ID, dbeak.ID) + assert.Equals(t, dbNew.ProvisionerID, dbeak.ProvisionerID) + assert.Equals(t, dbNew.Reference, dbeak.Reference) + assert.Equals(t, dbNew.AccountID, dbeak.AccountID) + assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt) + assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt) + assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes) + return nu, true, nil + }, + }, + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + eak: &acme.ExternalAccountKey{ + ID: keyID, + }, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading external account key keyID: force"), + } + }, + "fail/provisioner-mismatch": func(t *testing.T) test { + newDBEAK := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: "aDifferentProvID", + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(newDBEAK) + assert.FatalError(t, err) + return test{ + eak: &acme.ExternalAccountKey{ + ID: keyID, + }, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + + return b, nil + }, + }, + err: errors.New("provisioner does not match provisioner for which the EAB key was created"), + } + }, + "fail/provisioner-change": func(t *testing.T) test { + newDBEAK := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(newDBEAK) + assert.FatalError(t, err) + return test{ + eak: &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: "aDifferentProvisionerID", + }, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return b, nil + }, + }, + err: errors.New("cannot change provisioner for an existing ACME EAB Key"), + } + }, + "fail/reference-change": func(t *testing.T) test { + newDBEAK := &dbExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: ref, + AccountID: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + } + b, err := json.Marshal(newDBEAK) + assert.FatalError(t, err) + return test{ + eak: &acme.ExternalAccountKey{ + ID: keyID, + ProvisionerID: provID, + Reference: "aDifferentReference", + }, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyTable) + assert.Equals(t, string(key), keyID) + return b, nil + }, + }, + err: errors.New("cannot change reference for an existing ACME EAB Key"), + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if err := d.UpdateExternalAccountKey(context.Background(), provID, tc.eak); err != nil { + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, dbeak.ID, tc.eak.ID) + assert.Equals(t, dbeak.ProvisionerID, tc.eak.ProvisionerID) + assert.Equals(t, dbeak.Reference, tc.eak.Reference) + assert.Equals(t, dbeak.AccountID, tc.eak.AccountID) + assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt) + assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt) + assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes) + } + }) + } +} + +func TestDB_addEAKID(t *testing.T) { + provID := "provID" + eakID := "eakID" + type test struct { + ctx context.Context + provisionerID string + eakID string + db nosql.DB + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/empty-eakID": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: "", + err: errors.New("can't add empty eakID for provisioner provID"), + } + }, + "fail/db.Get": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading eakIDs for provisioner provID: force"), + } + }, + "fail/unmarshal": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal(1) + return b, nil + }, + }, + err: errors.New("error unmarshaling eakIDs for provisioner provID: json: cannot unmarshal number into Go value of type []string"), + } + }, + "fail/eakID-already-exists": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal([]string{eakID}) + return b, nil + }, + }, + err: errors.New("eakID eakID already exists for provisioner provID"), + } + }, + "fail/db.save": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal([]string{"id1"}) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + oldB, _ := json.Marshal([]string{"id1"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{"id1", eakID}) + assert.Equals(t, nu, newB) + return newB, true, errors.New("force") + }, + }, + err: errors.New("error saving eakIDs index for provisioner provID: error saving acme externalAccountKeyIDsByProvisionerID: force"), + } + }, + "ok/db.Get-not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + assert.Equals(t, old, nil) + b, _ := json.Marshal([]string{eakID}) + assert.Equals(t, nu, b) + return b, true, nil + }, + }, + err: nil, + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal([]string{"id1", "id2"}) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + oldB, _ := json.Marshal([]string{"id1", "id2"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{"id1", "id2", eakID}) + assert.Equals(t, nu, newB) + return newB, true, nil + }, + }, + err: nil, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := &DB{ + db: tc.db, + } + wantErr := tc.err != nil + err := db.addEAKID(tc.ctx, tc.provisionerID, tc.eakID) + if (err != nil) != wantErr { + t.Errorf("DB.addEAKID() error = %v, wantErr %v", err, wantErr) + } + if err != nil { + assert.Equals(t, tc.err.Error(), err.Error()) + } + }) + } +} + +func TestDB_deleteEAKID(t *testing.T) { + provID := "provID" + eakID := "eakID" + type test struct { + ctx context.Context + provisionerID string + eakID string + db nosql.DB + err error + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading eakIDs for provisioner provID: force"), + } + }, + "fail/unmarshal": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal(1) + return b, nil + }, + }, + err: errors.New("error unmarshaling eakIDs for provisioner provID: json: cannot unmarshal number into Go value of type []string"), + } + }, + "fail/db.save": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal([]string{"id1", eakID}) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + oldB, _ := json.Marshal([]string{"id1", eakID}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{"id1"}) + assert.Equals(t, nu, newB) + return newB, true, errors.New("force") + }, + }, + err: errors.New("error saving eakIDs index for provisioner provID: error saving acme externalAccountKeyIDsByProvisionerID: force"), + } + }, + "ok/db.Get-not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + assert.Equals(t, old, nil) + b, _ := json.Marshal([]string{}) + assert.Equals(t, nu, b) + return b, true, nil + }, + }, + err: nil, + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.Background(), + provisionerID: provID, + eakID: eakID, + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + b, _ := json.Marshal([]string{"id1", eakID, "id2"}) + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + oldB, _ := json.Marshal([]string{"id1", eakID, "id2"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{"id1", "id2"}) + assert.Equals(t, nu, newB) + return newB, true, nil + }, + }, + err: nil, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := &DB{ + db: tc.db, + } + wantErr := tc.err != nil + err := db.deleteEAKID(tc.ctx, tc.provisionerID, tc.eakID) + if (err != nil) != wantErr { + t.Errorf("DB.deleteEAKID() error = %v, wantErr %v", err, wantErr) + } + if err != nil { + assert.Equals(t, tc.err.Error(), err.Error()) + } + }) + } +} + +func TestDB_addAndDeleteEAKID(t *testing.T) { + provID := "provID" + callCounter := 0 + type test struct { + ctx context.Context + db nosql.DB + err error + } + var tests = map[string]func(t *testing.T) test{ + "ok/multi": func(t *testing.T) test { + return test{ + ctx: context.Background(), + db: &certdb.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + switch callCounter { + case 0: + return nil, nosqldb.ErrNotFound + case 1: + b, _ := json.Marshal([]string{"eakID"}) + return b, nil + case 2: + b, _ := json.Marshal([]string{}) + return b, nil + case 3: + b, _ := json.Marshal([]string{"eakID1"}) + return b, nil + case 4: + b, _ := json.Marshal([]string{"eakID1", "eakID2"}) + return b, nil + case 5: + b, _ := json.Marshal([]string{"eakID2"}) + return b, nil + default: + assert.FatalError(t, errors.New("unexpected get iteration")) + return nil, errors.New("force get default") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, externalAccountKeyIDsByProvisionerIDTable) + assert.Equals(t, string(key), provID) + switch callCounter { + case 0: + assert.Equals(t, old, nil) + newB, _ := json.Marshal([]string{"eakID"}) + assert.Equals(t, nu, newB) + return newB, true, nil + case 1: + oldB, _ := json.Marshal([]string{"eakID"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{}) + return newB, true, nil + case 2: + assert.Equals(t, old, nil) + newB, _ := json.Marshal([]string{"eakID1"}) + assert.Equals(t, nu, newB) + return newB, true, nil + case 3: + oldB, _ := json.Marshal([]string{"eakID1"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{"eakID1", "eakID2"}) + assert.Equals(t, nu, newB) + return newB, true, nil + case 4: + oldB, _ := json.Marshal([]string{"eakID1", "eakID2"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{"eakID2"}) + assert.Equals(t, nu, newB) + return newB, true, nil + case 5: + oldB, _ := json.Marshal([]string{"eakID2"}) + assert.Equals(t, old, oldB) + newB, _ := json.Marshal([]string{}) + assert.Equals(t, nu, newB) + return newB, true, nil + default: + assert.FatalError(t, errors.New("unexpected get iteration")) + return nil, true, errors.New("force save default") + } + }, + }, + err: nil, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + + // goal of this test is to simulate multiple calls; no errors expected. + + db := &DB{ + db: tc.db, + } + + err := db.addEAKID(tc.ctx, provID, "eakID") + if err != nil { + t.Errorf("DB.addEAKID() error = %v", err) + } + + callCounter++ + err = db.deleteEAKID(tc.ctx, provID, "eakID") + if err != nil { + t.Errorf("DB.deleteEAKID() error = %v", err) + } + + callCounter++ + err = db.addEAKID(tc.ctx, provID, "eakID1") + if err != nil { + t.Errorf("DB.addEAKID() error = %v", err) + } + + callCounter++ + err = db.addEAKID(tc.ctx, provID, "eakID2") + if err != nil { + t.Errorf("DB.addEAKID() error = %v", err) + } + + callCounter++ + err = db.deleteEAKID(tc.ctx, provID, "eakID1") + if err != nil { + t.Errorf("DB.deleteEAKID() error = %v", err) + } + + callCounter++ + err = db.deleteEAKID(tc.ctx, provID, "eakID2") + if err != nil { + t.Errorf("DB.deleteAKID() error = %v", err) + } + }) + } +} + +func Test_removeElement(t *testing.T) { + tests := []struct { + name string + slice []string + item string + want []string + }{ + { + name: "remove-first", + slice: []string{"id1", "id2", "id3"}, + item: "id1", + want: []string{"id2", "id3"}, + }, + { + name: "remove-last", + slice: []string{"id1", "id2", "id3"}, + item: "id3", + want: []string{"id1", "id2"}, + }, + { + name: "remove-middle", + slice: []string{"id1", "id2", "id3"}, + item: "id2", + want: []string{"id1", "id3"}, + }, + { + name: "remove-non-existing", + slice: []string{"id1", "id2", "id3"}, + item: "none", + want: []string{"id1", "id2", "id3"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := removeElement(tt.slice, tt.item) + if !cmp.Equal(tt.want, got) { + t.Errorf("removeElement() diff =\n %s", cmp.Diff(tt.want, got)) + } + }) + } +} diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index 34932361..98f6a04d 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -11,15 +11,18 @@ import ( ) var ( - accountTable = []byte("acme_accounts") - accountByKeyIDTable = []byte("acme_keyID_accountID_index") - authzTable = []byte("acme_authzs") - challengeTable = []byte("acme_challenges") - nonceTable = []byte("nonces") - orderTable = []byte("acme_orders") - ordersByAccountIDTable = []byte("acme_account_orders_index") - certTable = []byte("acme_certs") - certBySerialTable = []byte("acme_serial_certs_index") + accountTable = []byte("acme_accounts") + accountByKeyIDTable = []byte("acme_keyID_accountID_index") + authzTable = []byte("acme_authzs") + challengeTable = []byte("acme_challenges") + nonceTable = []byte("nonces") + orderTable = []byte("acme_orders") + ordersByAccountIDTable = []byte("acme_account_orders_index") + certTable = []byte("acme_certs") + certBySerialTable = []byte("acme_serial_certs_index") + externalAccountKeyTable = []byte("acme_external_account_keys") + externalAccountKeyIDsByReferenceTable = []byte("acme_external_account_keyID_reference_index") + externalAccountKeyIDsByProvisionerIDTable = []byte("acme_external_account_keyID_provisionerID_index") ) // DB is a struct that implements the AcmeDB interface. @@ -30,7 +33,10 @@ type DB struct { // New configures and returns a new ACME DB backend implemented using a nosql DB. func New(db nosqlDB.DB) (*DB, error) { tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, - challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable, certBySerialTable} + challengeTable, nonceTable, orderTable, ordersByAccountIDTable, + certTable, certBySerialTable, externalAccountKeyTable, + externalAccountKeyIDsByReferenceTable, externalAccountKeyIDsByProvisionerIDTable, + } for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", diff --git a/api/api_test.go b/api/api_test.go index 5cbce8b3..c7528f9b 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -167,6 +167,208 @@ func parseCertificateRequest(data string) *x509.CertificateRequest { return csr } +type mockAuthority struct { + ret1, ret2 interface{} + err error + authorizeSign func(ott string) ([]provisioner.SignOption, error) + getTLSOptions func() *authority.TLSOptions + root func(shasum string) (*x509.Certificate, error) + sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + renew func(cert *x509.Certificate) ([]*x509.Certificate, error) + rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) + loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) + loadProvisionerByName func(name string) (provisioner.Interface, error) + getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + revoke func(context.Context, *authority.RevokeOptions) error + getEncryptedKey func(kid string) (string, error) + getRoots func() ([]*x509.Certificate, error) + getFederation func() ([]*x509.Certificate, error) + signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) + getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) + getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) + getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) + checkSSHHost func(ctx context.Context, principal, token string) (bool, error) + getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) + version func() authority.Version +} + +// TODO: remove once Authorize is deprecated. +func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { + return m.AuthorizeSign(ott) +} + +func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { + if m.authorizeSign != nil { + return m.authorizeSign(ott) + } + return m.ret1.([]provisioner.SignOption), m.err +} + +func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { + if m.getTLSOptions != nil { + return m.getTLSOptions() + } + return m.ret1.(*authority.TLSOptions) +} + +func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { + if m.root != nil { + return m.root(shasum) + } + return m.ret1.(*x509.Certificate), m.err +} + +func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.sign != nil { + return m.sign(cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { + if m.renew != nil { + return m.renew(cert) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.rekey != nil { + return m.rekey(oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + +func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { + if m.getProvisioners != nil { + return m.getProvisioners(nextCursor, limit) + } + return m.ret1.(provisioner.List), m.ret2.(string), m.err +} + +func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { + if m.loadProvisionerByCertificate != nil { + return m.loadProvisionerByCertificate(cert) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.loadProvisionerByName != nil { + return m.loadProvisionerByName(name) + } + return m.ret1.(provisioner.Interface), m.err +} + +func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { + if m.revoke != nil { + return m.revoke(ctx, opts) + } + return m.err +} + +func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { + if m.getEncryptedKey != nil { + return m.getEncryptedKey(kid) + } + return m.ret1.(string), m.err +} + +func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { + if m.getRoots != nil { + return m.getRoots() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { + if m.getFederation != nil { + return m.getFederation() + } + return m.ret1.([]*x509.Certificate), m.err +} + +func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(ctx, key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(ctx, key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.renewSSH != nil { + return m.renewSSH(ctx, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.rekeySSH != nil { + return m.rekeySSH(ctx, cert, key, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { + if m.getSSHHosts != nil { + return m.getSSHHosts(ctx, cert) + } + return m.ret1.([]authority.Host), m.err +} + +func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHRoots != nil { + return m.getSSHRoots(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { + if m.getSSHFederation != nil { + return m.getSSHFederation(ctx) + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { + if m.getSSHConfig != nil { + return m.getSSHConfig(ctx, typ, data) + } + return m.ret1.([]templates.Output), m.err +} + +func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { + if m.checkSSHHost != nil { + return m.checkSSHHost(ctx, principal, token) + } + return m.ret1.(bool), m.err +} + +func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { + if m.getSSHBastion != nil { + return m.getSSHBastion(ctx, user, hostname) + } + return m.ret1.(*authority.Bastion), m.err +} + +func (m *mockAuthority) Version() authority.Version { + if m.version != nil { + return m.version() + } + return m.ret1.(authority.Version) +} + func TestNewCertificate(t *testing.T) { cert := parseCertificate(rootPEM) if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) { @@ -551,208 +753,6 @@ func (m *mockProvisioner) AuthorizeSSHRekey(ctx context.Context, token string) ( return m.ret1.(*ssh.Certificate), m.ret2.([]provisioner.SignOption), m.err } -type mockAuthority struct { - ret1, ret2 interface{} - err error - authorizeSign func(ott string) ([]provisioner.SignOption, error) - getTLSOptions func() *authority.TLSOptions - root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - renew func(cert *x509.Certificate) ([]*x509.Certificate, error) - rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) - loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) - loadProvisionerByName func(name string) (provisioner.Interface, error) - getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) - revoke func(context.Context, *authority.RevokeOptions) error - getEncryptedKey func(kid string) (string, error) - getRoots func() ([]*x509.Certificate, error) - getFederation func() ([]*x509.Certificate, error) - signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) - rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) - getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) - getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) - getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) - checkSSHHost func(ctx context.Context, principal, token string) (bool, error) - getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) - version func() authority.Version -} - -// TODO: remove once Authorize is deprecated. -func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { - return m.AuthorizeSign(ott) -} - -func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) { - if m.authorizeSign != nil { - return m.authorizeSign(ott) - } - return m.ret1.([]provisioner.SignOption), m.err -} - -func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { - if m.getTLSOptions != nil { - return m.getTLSOptions() - } - return m.ret1.(*authority.TLSOptions) -} - -func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { - if m.root != nil { - return m.root(shasum) - } - return m.ret1.(*x509.Certificate), m.err -} - -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { - if m.renew != nil { - return m.renew(cert) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) Rekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { - if m.rekey != nil { - return m.rekey(oldcert, pk) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - -func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { - if m.getProvisioners != nil { - return m.getProvisioners(nextCursor, limit) - } - return m.ret1.(provisioner.List), m.ret2.(string), m.err -} - -func (m *mockAuthority) LoadProvisionerByCertificate(cert *x509.Certificate) (provisioner.Interface, error) { - if m.loadProvisionerByCertificate != nil { - return m.loadProvisionerByCertificate(cert) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { - if m.loadProvisionerByName != nil { - return m.loadProvisionerByName(name) - } - return m.ret1.(provisioner.Interface), m.err -} - -func (m *mockAuthority) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { - if m.revoke != nil { - return m.revoke(ctx, opts) - } - return m.err -} - -func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) { - if m.getEncryptedKey != nil { - return m.getEncryptedKey(kid) - } - return m.ret1.(string), m.err -} - -func (m *mockAuthority) GetRoots() ([]*x509.Certificate, error) { - if m.getRoots != nil { - return m.getRoots() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { - if m.getFederation != nil { - return m.getFederation() - } - return m.ret1.([]*x509.Certificate), m.err -} - -func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.signSSH != nil { - return m.signSSH(ctx, key, opts, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.signSSHAddUser != nil { - return m.signSSHAddUser(ctx, key, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.renewSSH != nil { - return m.renewSSH(ctx, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.rekeySSH != nil { - return m.rekeySSH(ctx, cert, key, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]authority.Host, error) { - if m.getSSHHosts != nil { - return m.getSSHHosts(ctx, cert) - } - return m.ret1.([]authority.Host), m.err -} - -func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHRoots != nil { - return m.getSSHRoots(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { - if m.getSSHFederation != nil { - return m.getSSHFederation(ctx) - } - return m.ret1.(*authority.SSHKeys), m.err -} - -func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { - if m.getSSHConfig != nil { - return m.getSSHConfig(ctx, typ, data) - } - return m.ret1.([]templates.Output), m.err -} - -func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) { - if m.checkSSHHost != nil { - return m.checkSSHHost(ctx, principal, token) - } - return m.ret1.(bool), m.err -} - -func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) { - if m.getSSHBastion != nil { - return m.getSSHBastion(ctx, user, hostname) - } - return m.ret1.(*authority.Bastion), m.err -} - -func (m *mockAuthority) Version() authority.Version { - if m.version != nil { - return m.version() - } - return m.ret1.(authority.Version) -} - func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority diff --git a/authority/admin/api/acme.go b/authority/admin/api/acme.go new file mode 100644 index 00000000..2cd75900 --- /dev/null +++ b/authority/admin/api/acme.go @@ -0,0 +1,246 @@ +package api + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/go-chi/chi" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + // provisionerContextKey provisioner key + provisionerContextKey = ContextKey("provisioner") +) + +// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests +type CreateExternalAccountKeyRequest struct { + Reference string `json:"reference"` +} + +// Validate validates a new ACME EAB Key request body. +func (r *CreateExternalAccountKeyRequest) Validate() error { + if len(r.Reference) > 256 { // an arbitrary, but sensible (IMO), limit + return fmt.Errorf("reference length %d exceeds the maximum (256)", len(r.Reference)) + } + return nil +} + +// GetExternalAccountKeysResponse is the type for GET /admin/acme/eab responses +type GetExternalAccountKeysResponse struct { + EAKs []*linkedca.EABKey `json:"eaks"` + NextCursor string `json:"nextCursor"` +} + +// requireEABEnabled is a middleware that ensures ACME EAB is enabled +// before serving requests that act on ACME EAB credentials. +func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + provName := chi.URLParam(r, "provisionerName") + eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName) + if err != nil { + api.WriteError(w, err) + return + } + if !eabEnabled { + api.WriteError(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) + return + } + ctx = context.WithValue(ctx, provisionerContextKey, prov) + next(w, r.WithContext(ctx)) + } +} + +// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME +// provisioner is set to true and thus has EAB enabled. +func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) { + var ( + p provisioner.Interface + err error + ) + if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil { + return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName) + } + + prov, err := h.db.GetProvisioner(ctx, p.GetID()) + if err != nil { + return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID()) + } + + details := prov.GetDetails() + if details == nil { + return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID()) + } + + acmeProvisioner := details.GetACME() + if acmeProvisioner == nil { + return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID()) + } + + return acmeProvisioner.GetRequireEab(), prov, nil +} + +// provisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func provisionerFromContext(ctx context.Context) (*linkedca.Provisioner, error) { + val := ctx.Value(provisionerContextKey) + if val == nil { + return nil, admin.NewErrorISE("provisioner expected in request context") + } + pval, ok := val.(*linkedca.Provisioner) + if !ok || pval == nil { + return nil, admin.NewErrorISE("provisioner in context is not a linkedca.Provisioner") + } + return pval, nil +} + +// CreateExternalAccountKey creates a new External Account Binding key +func (h *Handler) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { + var body CreateExternalAccountKeyRequest + if err := api.ReadJSON(r.Body, &body); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) + return + } + + if err := body.Validate(); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating request body")) + return + } + + ctx := r.Context() + prov, err := provisionerFromContext(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error getting provisioner from context")) + return + } + + // check if a key with the reference does not exist (only when a reference was in the request) + reference := body.Reference + if reference != "" { + k, err := h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference) + // retrieving an EAB key from DB results in an error if it doesn't exist, which is what we're looking for, + // but other errors can also happen. Return early if that happens; continuing if it was acme.ErrNotFound. + if shouldWriteError := err != nil && !errors.Is(err, acme.ErrNotFound); shouldWriteError { + api.WriteError(w, admin.WrapErrorISE(err, "could not lookup external account key by reference")) + return + } + // if a key was found, return HTTP 409 conflict + if k != nil { + err := admin.NewError(admin.ErrorBadRequestType, "an ACME EAB key for provisioner '%s' with reference '%s' already exists", prov.GetName(), reference) + err.Status = 409 + api.WriteError(w, err) + return + } + // continue execution if no key was found for the reference + } + + eak, err := h.acmeDB.CreateExternalAccountKey(ctx, prov.GetId(), reference) + if err != nil { + msg := fmt.Sprintf("error creating ACME EAB key for provisioner '%s'", prov.GetName()) + if reference != "" { + msg += fmt.Sprintf(" and reference '%s'", reference) + } + api.WriteError(w, admin.WrapErrorISE(err, msg)) + return + } + + response := &linkedca.EABKey{ + Id: eak.ID, + HmacKey: eak.KeyBytes, + Provisioner: prov.GetName(), + Reference: eak.Reference, + } + + api.ProtoJSONStatus(w, response, http.StatusCreated) +} + +// DeleteExternalAccountKey deletes an ACME External Account Key. +func (h *Handler) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { + + keyID := chi.URLParam(r, "id") + + ctx := r.Context() + prov, err := provisionerFromContext(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error getting provisioner from context")) + return + } + + if err := h.acmeDB.DeleteExternalAccountKey(ctx, prov.GetId(), keyID); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error deleting ACME EAB Key '%s'", keyID)) + return + } + + api.JSON(w, &DeleteResponse{Status: "ok"}) +} + +// GetExternalAccountKeys returns ACME EAB Keys. If a reference is specified, +// only the ExternalAccountKey with that reference is returned. Otherwise all +// ExternalAccountKeys in the system for a specific provisioner are returned. +func (h *Handler) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { + + var ( + key *acme.ExternalAccountKey + keys []*acme.ExternalAccountKey + err error + cursor string + nextCursor string + limit int + ) + + ctx := r.Context() + prov, err := provisionerFromContext(ctx) + if err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error getting provisioner from context")) + return + } + + if cursor, limit, err = api.ParseCursor(r); err != nil { + api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, + "error parsing cursor and limit from query params")) + return + } + + reference := chi.URLParam(r, "reference") + if reference != "" { + if key, err = h.acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving external account key with reference '%s'", reference)) + return + } + if key != nil { + keys = []*acme.ExternalAccountKey{key} + } + } else { + if keys, nextCursor, err = h.acmeDB.GetExternalAccountKeys(ctx, prov.GetId(), cursor, limit); err != nil { + api.WriteError(w, admin.WrapErrorISE(err, "error retrieving external account keys")) + return + } + } + + provisionerName := prov.GetName() + eaks := make([]*linkedca.EABKey, len(keys)) + for i, k := range keys { + eaks[i] = &linkedca.EABKey{ + Id: k.ID, + HmacKey: []byte{}, + Provisioner: provisionerName, + Reference: k.Reference, + Account: k.AccountID, + CreatedAt: timestamppb.New(k.CreatedAt), + BoundAt: timestamppb.New(k.BoundAt), + } + } + + api.JSON(w, &GetExternalAccountKeysResponse{ + EAKs: eaks, + NextCursor: nextCursor, + }) +} diff --git a/authority/admin/api/acme_test.go b/authority/admin/api/acme_test.go new file mode 100644 index 00000000..50086955 --- /dev/null +++ b/authority/admin/api/acme_test.go @@ -0,0 +1,1222 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func readProtoJSON(r io.ReadCloser, m proto.Message) error { + defer r.Close() + data, err := io.ReadAll(r) + if err != nil { + return err + } + return protojson.Unmarshal(data, m) +} + +func TestHandler_requireEABEnabled(t *testing.T) { + type test struct { + ctx context.Context + db admin.DB + auth adminAuthority + next nextHTTP + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/h.provisionerHasEABEnabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + err := admin.NewErrorISE("error loading provisioner provName: force") + err.Message = "error loading provisioner provName: force" + return test{ + ctx: ctx, + auth: auth, + err: err, + statusCode: 500, + } + }, + "ok/eab-disabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, + }, + }, + }, + }, nil + }, + } + err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") + err.Message = "ACME EAB not enabled for provisioner provName" + return test{ + ctx: ctx, + auth: auth, + db: db, + err: err, + statusCode: 400, + } + }, + "ok/eab-enabled": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, + }, + }, + }, + }, nil + }, + } + return test{ + ctx: ctx, + auth: auth, + db: db, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + }, + statusCode: 200, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + auth: tc.auth, + acmeDB: nil, + } + + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.requireEABEnabled(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} + +func TestHandler_provisionerHasEABEnabled(t *testing.T) { + type test struct { + db admin.DB + auth adminAuthority + provisionerName string + want bool + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + auth: auth, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/prov.GetDetails": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: nil, + }, nil + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "fail/details.GetACME": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: nil, + }, + }, + }, nil + }, + } + return test{ + auth: auth, + db: db, + provisionerName: "provName", + want: false, + err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"), + } + }, + "ok/eab-disabled": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "eab-disabled", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "eab-disabled", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: false, + }, + }, + }, + }, nil + }, + } + return test{ + db: db, + auth: auth, + provisionerName: "eab-disabled", + want: false, + } + }, + "ok/eab-enabled": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "eab-enabled", name) + return &provisioner.MockProvisioner{ + MgetID: func() string { + return "provID" + }, + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "eab-enabled", + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_ACME{ + ACME: &linkedca.ACMEProvisioner{ + RequireEab: true, + }, + }, + }, + }, nil + }, + } + return test{ + db: db, + auth: auth, + provisionerName: "eab-enabled", + want: true, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + db: tc.db, + auth: tc.auth, + acmeDB: nil, + } + got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName) + if (err != nil) != (tc.err != nil) { + t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err) + return + } + if tc.err != nil { + assert.Type(t, &linkedca.Provisioner{}, prov) + assert.Type(t, &admin.Error{}, err) + adminError, _ := err.(*admin.Error) + assert.Equals(t, tc.err.Type, adminError.Type) + assert.Equals(t, tc.err.Status, adminError.Status) + assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode()) + assert.Equals(t, tc.err.Message, adminError.Message) + assert.Equals(t, tc.err.Detail, adminError.Detail) + return + } + if got != tc.want { + t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want) + } + }) + } +} + +func Test_provisionerFromContext(t *testing.T) { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "acmeProv", + } + tests := []struct { + name string + ctx context.Context + want *linkedca.Provisioner + wantErr bool + }{ + { + name: "fail/no-provisioner", + ctx: context.Background(), + want: nil, + wantErr: true, + }, + { + name: "fail/wrong-type", + ctx: context.WithValue(context.Background(), provisionerContextKey, "prov"), + want: nil, + wantErr: true, + }, + { + name: "ok", + ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + want: &linkedca.Provisioner{ + Id: "provID", + Name: "acmeProv", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := provisionerFromContext(tt.ctx) + if (err != nil) != tt.wantErr { + t.Errorf("provisionerFromContext() error = %v, wantErr %v", err, tt.wantErr) + return + } + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{})} + if !cmp.Equal(tt.want, got, opts...) { + t.Errorf("provisionerFromContext() diff =\n %s", cmp.Diff(tt.want, got, opts...)) + } + }) + } +} + +func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { + type fields struct { + Reference string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "fail/reference-too-long", + fields: fields{ + Reference: strings.Repeat("A", 257), + }, + wantErr: true, + }, + { + name: "ok/empty-reference", + fields: fields{ + Reference: "", + }, + wantErr: false, + }, + { + name: "ok", + fields: fields{ + Reference: "my-eab-reference", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &CreateExternalAccountKeyRequest{ + Reference: tt.fields.Reference, + } + if err := r.Validate(); (err != nil) != tt.wantErr { + t.Errorf("CreateExternalAccountKeyRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestHandler_CreateExternalAccountKey(t *testing.T) { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + type test struct { + ctx context.Context + db acme.DB + body []byte + statusCode int + eak *linkedca.EABKey + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/ReadJSON": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + body := []byte("{!?}") + return test{ + ctx: ctx, + body: body, + statusCode: 400, + eak: nil, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 400, + Detail: "bad request", + Message: "error reading request body: error decoding json: invalid character '!' looking for beginning of object key string", + }, + } + }, + "fail/validate": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + req := CreateExternalAccountKeyRequest{ + Reference: strings.Repeat("A", 257), + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + return test{ + ctx: ctx, + body: body, + statusCode: 400, + eak: nil, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 400, + Detail: "bad request", + Message: "error validating request body: reference length 257 exceeds the maximum (256)", + }, + } + }, + "fail/no-provisioner-in-context": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + req := CreateExternalAccountKeyRequest{ + Reference: "aRef", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + return test{ + ctx: ctx, + body: body, + statusCode: 500, + eak: nil, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error getting provisioner from context: provisioner expected in request context", + }, + } + }, + "fail/acmeDB.GetExternalAccountKeyByReference": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + req := CreateExternalAccountKeyRequest{ + Reference: "an-external-key-reference", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + db: db, + body: body, + statusCode: 500, + eak: nil, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "could not lookup external account key by reference: force", + }, + } + }, + "fail/reference-conflict-409": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + req := CreateExternalAccountKeyRequest{ + Reference: "an-external-key-reference", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + past := time.Now().Add(-24 * time.Hour) + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "an-external-key-reference", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: past, + }, nil + }, + } + return test{ + ctx: ctx, + db: db, + body: body, + statusCode: 409, + eak: nil, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 409, + Detail: "bad request", + Message: "an ACME EAB key for provisioner 'provName' with reference 'an-external-key-reference' already exists", + }, + } + }, + "fail/acmeDB.CreateExternalAccountKey-no-reference": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + req := CreateExternalAccountKeyRequest{ + Reference: "", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + db := &acme.MockDB{ + MockCreateExternalAccountKey: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "", reference) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + db: db, + body: body, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error creating ACME EAB key for provisioner 'provName': force", + }, + } + }, + "fail/acmeDB.CreateExternalAccountKey-with-reference": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + req := CreateExternalAccountKeyRequest{ + Reference: "an-external-key-reference", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return nil, acme.ErrNotFound // simulating not found; skipping 409 conflict + }, + MockCreateExternalAccountKey: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + db: db, + body: body, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error creating ACME EAB key for provisioner 'provName' and reference 'an-external-key-reference': force", + }, + } + }, + "ok/no-reference": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + req := CreateExternalAccountKeyRequest{ + Reference: "", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + now := time.Now() + db := &acme.MockDB{ + MockCreateExternalAccountKey: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "", reference) + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, nil + }, + } + return test{ + ctx: ctx, + db: db, + body: body, + statusCode: 201, + eak: &linkedca.EABKey{ + Id: "eakID", + Provisioner: "provName", + Reference: "", + HmacKey: []byte{1, 3, 3, 7}, + }, + } + }, + "ok/with-reference": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + req := CreateExternalAccountKeyRequest{ + Reference: "an-external-key-reference", + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + now := time.Now() + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return nil, acme.ErrNotFound // simulating not found; skipping 409 conflict + }, + MockCreateExternalAccountKey: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "an-external-key-reference", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: now, + }, nil + }, + } + return test{ + ctx: ctx, + db: db, + body: body, + statusCode: 201, + eak: &linkedca.EABKey{ + Id: "eakID", + Provisioner: "provName", + Reference: "an-external-key-reference", + HmacKey: []byte{1, 3, 3, 7}, + }, + } + }, + } + + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + acmeDB: tc.db, + } + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.CreateExternalAccountKey(w, req) + res := w.Result() + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + eabKey := &linkedca.EABKey{} + err := readProtoJSON(res.Body, eabKey) + assert.FatalError(t, err) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.EABKey{})} + if !cmp.Equal(tc.eak, eabKey, opts...) { + t.Errorf("h.CreateExternalAccountKey diff =\n%s", cmp.Diff(tc.eak, eabKey, opts...)) + } + + }) + } +} + +func TestHandler_DeleteExternalAccountKey(t *testing.T) { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + type test struct { + ctx context.Context + db acme.DB + statusCode int + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner-in-context": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + return test{ + ctx: ctx, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error getting provisioner from context: provisioner expected in request context", + }, + } + }, + "fail/acmeDB.DeleteExternalAccountKey": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + chiCtx.URLParams.Add("id", "keyID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + db := &acme.MockDB{ + MockDeleteExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) error { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "keyID", keyID) + return errors.New("force") + }, + } + return test{ + ctx: ctx, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error deleting ACME EAB Key 'keyID': force", + }, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + chiCtx.URLParams.Add("id", "keyID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + db := &acme.MockDB{ + MockDeleteExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) error { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "keyID", keyID) + return nil + }, + } + return test{ + ctx: ctx, + db: db, + statusCode: 200, + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + acmeDB: tc.db, + } + req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.DeleteExternalAccountKey(w, req) + res := w.Result() + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := DeleteResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equals(t, "ok", response.Status) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func TestHandler_GetExternalAccountKeys(t *testing.T) { + prov := &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + } + type test struct { + ctx context.Context + db acme.DB + statusCode int + req *http.Request + resp GetExternalAccountKeysResponse + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/no-provisioner-in-context": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + req := httptest.NewRequest("GET", "/foo", nil) + return test{ + ctx: ctx, + statusCode: 500, + req: req, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error getting provisioner from context: provisioner expected in request context", + }, + } + }, + "fail/parse-cursor": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + req := httptest.NewRequest("GET", "/foo?limit=A", nil) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + ctx: ctx, + statusCode: 400, + req: req, + err: &admin.Error{ + Status: 400, + Type: admin.ErrorBadRequestType.String(), + Detail: "bad request", + Message: "error parsing cursor and limit from query params: limit 'A' is not an integer: strconv.Atoi: parsing \"A\": invalid syntax", + }, + } + }, + "fail/acmeDB.GetExternalAccountKeyByReference": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + chiCtx.URLParams.Add("reference", "an-external-key-reference") + req := httptest.NewRequest("GET", "/foo", nil) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + statusCode: 500, + req: req, + db: db, + err: &admin.Error{ + Status: 500, + Type: admin.ErrorServerInternalType.String(), + Detail: "the server experienced an internal error", + Message: "error retrieving external account key with reference 'an-external-key-reference': force", + }, + } + }, + "fail/acmeDB.GetExternalAccountKeys": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + req := httptest.NewRequest("GET", "/foo", nil) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + db := &acme.MockDB{ + MockGetExternalAccountKeys: func(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return nil, "", errors.New("force") + }, + } + return test{ + ctx: ctx, + statusCode: 500, + req: req, + db: db, + err: &admin.Error{ + Status: 500, + Type: admin.ErrorServerInternalType.String(), + Detail: "the server experienced an internal error", + Message: "error retrieving external account keys: force", + }, + } + }, + "ok/reference-not-found": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + chiCtx.URLParams.Add("reference", "an-external-key-reference") + req := httptest.NewRequest("GET", "/foo", nil) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return nil, nil // returning nil; no key found + }, + } + return test{ + ctx: ctx, + statusCode: 200, + req: req, + resp: GetExternalAccountKeysResponse{ + EAKs: []*linkedca.EABKey{}, + }, + db: db, + err: nil, + } + }, + "ok/reference-found": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + chiCtx.URLParams.Add("reference", "an-external-key-reference") + req := httptest.NewRequest("GET", "/foo", nil) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + createdAt := time.Now().Add(-24 * time.Hour) + var boundAt time.Time + db := &acme.MockDB{ + MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "an-external-key-reference", reference) + return &acme.ExternalAccountKey{ + ID: "eakID", + ProvisionerID: "provID", + Reference: "an-external-key-reference", + CreatedAt: createdAt, + }, nil + }, + } + return test{ + ctx: ctx, + statusCode: 200, + req: req, + resp: GetExternalAccountKeysResponse{ + EAKs: []*linkedca.EABKey{ + { + Id: "eakID", + Provisioner: "provName", + Reference: "an-external-key-reference", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAt), + }, + }, + }, + db: db, + err: nil, + } + }, + "ok/multiple-keys": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("provisionerName", "provName") + req := httptest.NewRequest("GET", "/foo", nil) + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + createdAt := time.Now().Add(-24 * time.Hour) + var boundAt time.Time + boundAtSet := time.Now().Add(-12 * time.Hour) + db := &acme.MockDB{ + MockGetExternalAccountKeys: func(ctx context.Context, provisionerID, cursor string, limit int) ([]*acme.ExternalAccountKey, string, error) { + assert.Equals(t, "provID", provisionerID) + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return []*acme.ExternalAccountKey{ + { + ID: "eakID1", + ProvisionerID: "provID", + Reference: "some-external-key-reference", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: createdAt, + }, + { + ID: "eakID2", + ProvisionerID: "provID", + Reference: "some-other-external-key-reference", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: createdAt.Add(1 * time.Hour), + }, + { + ID: "eakID3", + ProvisionerID: "provID", + Reference: "another-external-key-reference", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: createdAt, + BoundAt: boundAtSet, + AccountID: "accountID", + }, + }, "", nil + }, + } + return test{ + ctx: ctx, + statusCode: 200, + req: req, + resp: GetExternalAccountKeysResponse{ + EAKs: []*linkedca.EABKey{ + { + Id: "eakID1", + Provisioner: "provName", + Reference: "some-external-key-reference", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAt), + }, + { + Id: "eakID2", + Provisioner: "provName", + Reference: "some-other-external-key-reference", + CreatedAt: timestamppb.New(createdAt.Add(1 * time.Hour)), + BoundAt: timestamppb.New(boundAt), + }, + { + Id: "eakID3", + Provisioner: "provName", + Reference: "another-external-key-reference", + CreatedAt: timestamppb.New(createdAt), + BoundAt: timestamppb.New(boundAtSet), + Account: "accountID", + }, + }, + }, + db: db, + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + acmeDB: tc.db, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetExternalAccountKeys(w, req) + res := w.Result() + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := GetExternalAccountKeysResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.EABKey{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.resp, response, opts...) { + t.Errorf("h.GetExternalAccountKeys diff =\n%s", cmp.Diff(tc.resp, response, opts...)) + } + }) + } +} diff --git a/authority/admin/api/admin.go b/authority/admin/api/admin.go index bf79ebcf..7aa66d0f 100644 --- a/authority/admin/api/admin.go +++ b/authority/admin/api/admin.go @@ -1,14 +1,32 @@ package api import ( + "context" "net/http" "github.com/go-chi/chi" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/linkedca" ) +type adminAuthority interface { + LoadProvisionerByName(string) (provisioner.Interface, error) + GetProvisioners(cursor string, limit int) (provisioner.List, string, error) + IsAdminAPIEnabled() bool + LoadAdminByID(id string) (*linkedca.Admin, bool) + GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) + StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + RemoveAdmin(ctx context.Context, id string) error + AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) + StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error + LoadProvisionerByID(id string) (provisioner.Interface, error) + UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error + RemoveProvisioner(ctx context.Context, id string) error +} + // CreateAdminRequest represents the body for a CreateAdmin request. type CreateAdminRequest struct { Subject string `json:"subject"` diff --git a/authority/admin/api/admin_test.go b/authority/admin/api/admin_test.go new file mode 100644 index 00000000..8d223b52 --- /dev/null +++ b/authority/admin/api/admin_test.go @@ -0,0 +1,919 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type mockAdminAuthority struct { + MockLoadProvisionerByName func(name string) (provisioner.Interface, error) + MockGetProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) + MockRet1, MockRet2 interface{} // TODO: refactor the ret1/ret2 into those two + MockErr error + MockIsAdminAPIEnabled func() bool + MockLoadAdminByID func(id string) (*linkedca.Admin, bool) + MockGetAdmins func(cursor string, limit int) ([]*linkedca.Admin, string, error) + MockStoreAdmin func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error + MockUpdateAdmin func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) + MockRemoveAdmin func(ctx context.Context, id string) error + MockAuthorizeAdminToken func(r *http.Request, token string) (*linkedca.Admin, error) + MockStoreProvisioner func(ctx context.Context, prov *linkedca.Provisioner) error + MockLoadProvisionerByID func(id string) (provisioner.Interface, error) + MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error + MockRemoveProvisioner func(ctx context.Context, id string) error +} + +func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { + if m.MockIsAdminAPIEnabled != nil { + return m.MockIsAdminAPIEnabled() + } + return m.MockRet1.(bool) +} + +func (m *mockAdminAuthority) LoadProvisionerByName(name string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByName != nil { + return m.MockLoadProvisionerByName(name) + } + return m.MockRet1.(provisioner.Interface), m.MockErr +} + +func (m *mockAdminAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { + if m.MockGetProvisioners != nil { + return m.MockGetProvisioners(nextCursor, limit) + } + return m.MockRet1.(provisioner.List), m.MockRet2.(string), m.MockErr +} + +func (m *mockAdminAuthority) LoadAdminByID(id string) (*linkedca.Admin, bool) { + if m.MockLoadAdminByID != nil { + return m.MockLoadAdminByID(id) + } + return m.MockRet1.(*linkedca.Admin), m.MockRet2.(bool) +} + +func (m *mockAdminAuthority) GetAdmins(cursor string, limit int) ([]*linkedca.Admin, string, error) { + if m.MockGetAdmins != nil { + return m.MockGetAdmins(cursor, limit) + } + return m.MockRet1.([]*linkedca.Admin), m.MockRet2.(string), m.MockErr +} + +func (m *mockAdminAuthority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + if m.MockStoreAdmin != nil { + return m.MockStoreAdmin(ctx, adm, prov) + } + return m.MockErr +} + +func (m *mockAdminAuthority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + if m.MockUpdateAdmin != nil { + return m.MockUpdateAdmin(ctx, id, nu) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *mockAdminAuthority) RemoveAdmin(ctx context.Context, id string) error { + if m.MockRemoveAdmin != nil { + return m.MockRemoveAdmin(ctx, id) + } + return m.MockErr +} + +func (m *mockAdminAuthority) AuthorizeAdminToken(r *http.Request, token string) (*linkedca.Admin, error) { + if m.MockAuthorizeAdminToken != nil { + return m.MockAuthorizeAdminToken(r, token) + } + return m.MockRet1.(*linkedca.Admin), m.MockErr +} + +func (m *mockAdminAuthority) StoreProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { + if m.MockStoreProvisioner != nil { + return m.MockStoreProvisioner(ctx, prov) + } + return m.MockErr +} + +func (m *mockAdminAuthority) LoadProvisionerByID(id string) (provisioner.Interface, error) { + if m.MockLoadProvisionerByID != nil { + return m.MockLoadProvisionerByID(id) + } + return m.MockRet1.(provisioner.Interface), m.MockErr +} + +func (m *mockAdminAuthority) UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error { + if m.MockUpdateProvisioner != nil { + return m.MockUpdateProvisioner(ctx, nu) + } + return m.MockErr +} + +func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) error { + if m.MockRemoveProvisioner != nil { + return m.MockRemoveProvisioner(ctx, id) + } + return m.MockErr +} + +func TestCreateAdminRequest_Validate(t *testing.T) { + type fields struct { + Subject string + Provisioner string + Type linkedca.Admin_Type + } + tests := []struct { + name string + fields fields + err *admin.Error + }{ + { + name: "fail/subject-empty", + fields: fields{ + Subject: "", + Provisioner: "", + Type: 0, + }, + err: admin.NewError(admin.ErrorBadRequestType, "subject cannot be empty"), + }, + { + name: "fail/provisioner-empty", + fields: fields{ + Subject: "admin", + Provisioner: "", + Type: 0, + }, + err: admin.NewError(admin.ErrorBadRequestType, "provisioner cannot be empty"), + }, + { + name: "fail/invalid-type", + fields: fields{ + Subject: "admin", + Provisioner: "prov", + Type: -1, + }, + err: admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type"), + }, + { + name: "ok", + fields: fields{ + Subject: "admin", + Provisioner: "prov", + Type: linkedca.Admin_SUPER_ADMIN, + }, + err: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + car := &CreateAdminRequest{ + Subject: tt.fields.Subject, + Provisioner: tt.fields.Provisioner, + Type: tt.fields.Type, + } + err := car.Validate() + + if (err != nil) != (tt.err != nil) { + t.Errorf("CreateAdminRequest.Validate() error = %v, wantErr %v", err, (tt.err != nil)) + return + } + + if err != nil { + assert.Type(t, &admin.Error{}, err) + adminErr, _ := err.(*admin.Error) + assert.Equals(t, tt.err.Type, adminErr.Type) + assert.Equals(t, tt.err.Detail, adminErr.Detail) + assert.Equals(t, tt.err.Status, adminErr.Status) + assert.Equals(t, tt.err.Message, adminErr.Message) + } + }) + } +} + +func TestUpdateAdminRequest_Validate(t *testing.T) { + type fields struct { + Type linkedca.Admin_Type + } + tests := []struct { + name string + fields fields + err *admin.Error + }{ + { + name: "fail/invalid-type", + fields: fields{ + Type: -1, + }, + err: admin.NewError(admin.ErrorBadRequestType, "invalid value for admin type"), + }, + { + name: "ok", + fields: fields{ + Type: linkedca.Admin_SUPER_ADMIN, + }, + err: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uar := &UpdateAdminRequest{ + Type: tt.fields.Type, + } + + err := uar.Validate() + + if (err != nil) != (tt.err != nil) { + t.Errorf("CreateAdminRequest.Validate() error = %v, wantErr %v", err, (tt.err != nil)) + return + } + + if err != nil { + assert.Type(t, &admin.Error{}, err) + adminErr, _ := err.(*admin.Error) + assert.Equals(t, tt.err.Type, adminErr.Type) + assert.Equals(t, tt.err.Detail, adminErr.Detail) + assert.Equals(t, tt.err.Status, adminErr.Status) + assert.Equals(t, tt.err.Message, adminErr.Message) + } + }) + } +} + +func TestHandler_GetAdmin(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + statusCode int + err *admin.Error + adm *linkedca.Admin + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadAdminByID-not-found": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "adminID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) { + assert.Equals(t, "adminID", id) + return nil, false + }, + } + return test{ + ctx: ctx, + auth: auth, + statusCode: 404, + err: &admin.Error{ + Type: admin.ErrorNotFoundType.String(), + Status: 404, + Detail: "resource not found", + Message: "admin adminID not found", + }, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "adminID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + adm := &linkedca.Admin{ + Id: "adminID", + AuthorityId: "authorityID", + Subject: "admin", + ProvisionerId: "provID", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + } + auth := &mockAdminAuthority{ + MockLoadAdminByID: func(id string) (*linkedca.Admin, bool) { + assert.Equals(t, "adminID", id) + return adm, true + }, + } + return test{ + ctx: ctx, + auth: auth, + statusCode: 200, + err: nil, + adm: adm, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetAdmin(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + adm := &linkedca.Admin{} + err := readProtoJSON(res.Body, adm) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.adm, adm, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.adm, adm, opts...)) + } + }) + } +} + +func TestHandler_GetAdmins(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + req *http.Request + statusCode int + err *admin.Error + resp GetAdminsResponse + } + var tests = map[string]func(t *testing.T) test{ + "fail/parse-cursor": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo?limit=A", nil) + return test{ + ctx: context.Background(), + req: req, + statusCode: 400, + err: &admin.Error{ + Status: 400, + Type: admin.ErrorBadRequestType.String(), + Detail: "bad request", + Message: "error parsing cursor and limit from query params: limit 'A' is not an integer: strconv.Atoi: parsing \"A\": invalid syntax", + }, + } + }, + "fail/auth.GetAdmins": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + auth := &mockAdminAuthority{ + MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) { + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return nil, "", errors.New("force") + }, + } + return test{ + ctx: context.Background(), + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Status: 500, + Type: admin.ErrorServerInternalType.String(), + Detail: "the server experienced an internal error", + Message: "error retrieving paginated admins: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + createdAt := time.Now() + var deletedAt time.Time + adm1 := &linkedca.Admin{ + Id: "adminID1", + AuthorityId: "authorityID1", + Subject: "admin1", + ProvisionerId: "provID", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + } + adm2 := &linkedca.Admin{ + Id: "adminID2", + AuthorityId: "authorityID", + Subject: "admin2", + ProvisionerId: "provID", + Type: linkedca.Admin_ADMIN, + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + } + auth := &mockAdminAuthority{ + MockGetAdmins: func(cursor string, limit int) ([]*linkedca.Admin, string, error) { + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return []*linkedca.Admin{ + adm1, + adm2, + }, "nextCursorValue", nil + }, + } + return test{ + ctx: context.Background(), + req: req, + auth: auth, + statusCode: 200, + err: nil, + resp: GetAdminsResponse{ + Admins: []*linkedca.Admin{ + adm1, + adm2, + }, + NextCursor: "nextCursorValue", + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetAdmins(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + response := GetAdminsResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.resp, response, opts...) { + t.Errorf("GetAdmins diff =\n%s", cmp.Diff(tc.resp, response, opts...)) + } + }) + } +} + +func TestHandler_CreateAdmin(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + body []byte + statusCode int + err *admin.Error + adm *linkedca.Admin + } + var tests = map[string]func(t *testing.T) test{ + "fail/ReadJSON": func(t *testing.T) test { + body := []byte("{!?}") + return test{ + ctx: context.Background(), + body: body, + statusCode: 400, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 400, + Detail: "bad request", + Message: "error reading request body: error decoding json: invalid character '!' looking for beginning of object key string", + }, + } + }, + "fail/validate": func(t *testing.T) test { + req := CreateAdminRequest{ + Subject: "", + Provisioner: "", + Type: -1, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + body: body, + statusCode: 400, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 400, + Detail: "bad request", + Message: "subject cannot be empty", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + req := CreateAdminRequest{ + Subject: "admin", + Provisioner: "prov", + Type: linkedca.Admin_SUPER_ADMIN, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "prov", name) + return nil, errors.New("force") + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner prov: force", + }, + } + }, + "fail/auth.StoreAdmin": func(t *testing.T) test { + req := CreateAdminRequest{ + Subject: "admin", + Provisioner: "prov", + Type: linkedca.Admin_SUPER_ADMIN, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "prov", name) + return &provisioner.ACME{ + ID: "provID", + Name: "prov", + }, nil + }, + MockStoreAdmin: func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + assert.Equals(t, "admin", adm.Subject) + assert.Equals(t, "provID", prov.GetID()) + return errors.New("force") + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error storing admin: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := CreateAdminRequest{ + Subject: "admin", + Provisioner: "prov", + Type: linkedca.Admin_SUPER_ADMIN, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "prov", name) + return &provisioner.ACME{ + ID: "provID", + Name: "prov", + }, nil + }, + MockStoreAdmin: func(ctx context.Context, adm *linkedca.Admin, prov provisioner.Interface) error { + assert.Equals(t, "admin", adm.Subject) + assert.Equals(t, "provID", prov.GetID()) + return nil + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 201, + err: nil, + adm: &linkedca.Admin{ + ProvisionerId: "provID", + Subject: "admin", + Type: linkedca.Admin_SUPER_ADMIN, + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.CreateAdmin(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + adm := &linkedca.Admin{} + err := readProtoJSON(res.Body, adm) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.adm, adm, opts...) { + t.Errorf("h.CreateAdmin diff =\n%s", cmp.Diff(tc.adm, adm, opts...)) + } + }) + } +} + +func TestHandler_DeleteAdmin(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + statusCode int + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.RemoveAdmin": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "adminID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockRemoveAdmin: func(ctx context.Context, id string) error { + assert.Equals(t, "adminID", id) + return errors.New("force") + }, + } + return test{ + ctx: ctx, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error deleting admin adminID: force", + }, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "adminID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockRemoveAdmin: func(ctx context.Context, id string) error { + assert.Equals(t, "adminID", id) + return nil + }, + } + return test{ + ctx: ctx, + auth: auth, + statusCode: 200, + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.DeleteAdmin(w, req) + res := w.Result() + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := DeleteResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equals(t, "ok", response.Status) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + }) + } +} + +func TestHandler_UpdateAdmin(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + body []byte + statusCode int + err *admin.Error + adm *linkedca.Admin + } + var tests = map[string]func(t *testing.T) test{ + "fail/ReadJSON": func(t *testing.T) test { + body := []byte("{!?}") + return test{ + ctx: context.Background(), + body: body, + statusCode: 400, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 400, + Detail: "bad request", + Message: "error reading request body: error decoding json: invalid character '!' looking for beginning of object key string", + }, + } + }, + "fail/validate": func(t *testing.T) test { + req := UpdateAdminRequest{ + Type: -1, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + body: body, + statusCode: 400, + err: &admin.Error{ + Type: admin.ErrorBadRequestType.String(), + Status: 400, + Detail: "bad request", + Message: "invalid value for admin type", + }, + } + }, + "fail/auth.UpdateAdmin": func(t *testing.T) test { + req := UpdateAdminRequest{ + Type: linkedca.Admin_ADMIN, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "adminID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + assert.Equals(t, "adminID", id) + assert.Equals(t, linkedca.Admin_ADMIN, nu.Type) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error updating admin adminID: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := UpdateAdminRequest{ + Type: linkedca.Admin_ADMIN, + } + body, err := json.Marshal(req) + assert.FatalError(t, err) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "adminID") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + adm := &linkedca.Admin{ + Id: "adminID", + ProvisionerId: "provID", + Subject: "admin", + Type: linkedca.Admin_SUPER_ADMIN, + } + auth := &mockAdminAuthority{ + MockUpdateAdmin: func(ctx context.Context, id string, nu *linkedca.Admin) (*linkedca.Admin, error) { + assert.Equals(t, "adminID", id) + assert.Equals(t, linkedca.Admin_ADMIN, nu.Type) + return adm, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + statusCode: 200, + err: nil, + adm: adm, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.UpdateAdmin(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + adm := &linkedca.Admin{} + err := readProtoJSON(res.Body, adm) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.adm, adm, opts...) { + t.Errorf("h.UpdateAdmin diff =\n%s", cmp.Diff(tc.adm, adm, opts...)) + } + }) + } +} diff --git a/authority/admin/api/handler.go b/authority/admin/api/handler.go index d88edfa1..51751057 100644 --- a/authority/admin/api/handler.go +++ b/authority/admin/api/handler.go @@ -1,22 +1,25 @@ package api import ( + "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/admin" ) -// Handler is the ACME API request handler. +// Handler is the Admin API request handler. type Handler struct { - db admin.DB - auth *authority.Authority + db admin.DB + auth adminAuthority + acmeDB acme.DB } // NewHandler returns a new Authority Config Handler. -func NewHandler(auth *authority.Authority) api.RouterHandler { - h := &Handler{db: auth.GetAdminDatabase(), auth: auth} - - return h +func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB) api.RouterHandler { + return &Handler{ + db: adminDB, + auth: auth, + acmeDB: acmeDB, + } } // Route traffic and implement the Router interface. @@ -25,6 +28,10 @@ func (h *Handler) Route(r api.Router) { return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) } + requireEABEnabled := func(next nextHTTP) nextHTTP { + return h.requireEABEnabled(next) + } + // Provisioners r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) @@ -38,4 +45,10 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) + + // ACME External Account Binding Keys + r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.GetExternalAccountKeys))) + r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.GetExternalAccountKeys))) + r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.CreateExternalAccountKey))) + r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.DeleteExternalAccountKey))) } diff --git a/authority/admin/api/middleware_test.go b/authority/admin/api/middleware_test.go new file mode 100644 index 00000000..7fb4671a --- /dev/null +++ b/authority/admin/api/middleware_test.go @@ -0,0 +1,225 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "go.step.sm/linkedca" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestHandler_requireAPIEnabled(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + next nextHTTP + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.IsAdminAPIEnabled": func(t *testing.T) test { + return test{ + ctx: context.Background(), + auth: &mockAdminAuthority{ + MockIsAdminAPIEnabled: func() bool { + return false + }, + }, + err: &admin.Error{ + Type: admin.ErrorNotImplementedType.String(), + Status: 501, + Detail: "not implemented", + Message: "administration API not enabled", + }, + statusCode: 501, + } + }, + "ok": func(t *testing.T) test { + auth := &mockAdminAuthority{ + MockIsAdminAPIEnabled: func() bool { + return true + }, + } + next := func(w http.ResponseWriter, r *http.Request) { + w.Write(nil) // mock response with status 200 + } + return test{ + ctx: context.Background(), + auth: auth, + next: next, + statusCode: 200, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.requireAPIEnabled(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + // nothing to test when the requireAPIEnabled middleware succeeds, currently + + }) + } +} + +func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + req *http.Request + next nextHTTP + err *admin.Error + statusCode int + } + var tests = map[string]func(t *testing.T) test{ + "fail/missing-authorization-token": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + req.Header["Authorization"] = []string{""} + return test{ + ctx: context.Background(), + req: req, + statusCode: 401, + err: &admin.Error{ + Type: admin.ErrorUnauthorizedType.String(), + Status: 401, + Detail: "unauthorized", + Message: "missing authorization header token", + }, + } + }, + "fail/auth.AuthorizeAdminToken": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + req.Header["Authorization"] = []string{"token"} + auth := &mockAdminAuthority{ + MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { + assert.Equals(t, "token", token) + return nil, admin.NewError( + admin.ErrorUnauthorizedType, + "not authorized", + ) + }, + } + return test{ + ctx: context.Background(), + auth: auth, + req: req, + statusCode: 401, + err: &admin.Error{ + Type: admin.ErrorUnauthorizedType.String(), + Status: 401, + Detail: "unauthorized", + Message: "not authorized", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + req.Header["Authorization"] = []string{"token"} + createdAt := time.Now() + var deletedAt time.Time + admin := &linkedca.Admin{ + Id: "adminID", + AuthorityId: "authorityID", + Subject: "admin", + ProvisionerId: "provID", + Type: linkedca.Admin_SUPER_ADMIN, + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + } + auth := &mockAdminAuthority{ + MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { + assert.Equals(t, "token", token) + return admin, nil + }, + } + next := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin + adm, ok := a.(*linkedca.Admin) + if !ok { + t.Errorf("expected *linkedca.Admin; got %T", a) + return + } + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} + if !cmp.Equal(admin, adm, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...)) + } + w.Write(nil) // mock response with status 200 + } + return test{ + ctx: context.Background(), + auth: auth, + req: req, + next: next, + statusCode: 200, + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.extractAuthorizeTokenAdmin(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 { + err := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err)) + + assert.Equals(t, tc.err.Type, err.Type) + assert.Equals(t, tc.err.Message, err.Message) + assert.Equals(t, tc.err.StatusCode(), res.StatusCode) + assert.Equals(t, tc.err.Detail, err.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + }) + } +} diff --git a/authority/admin/api/provisioner.go b/authority/admin/api/provisioner.go index fd1a02d5..d111f1e6 100644 --- a/authority/admin/api/provisioner.go +++ b/authority/admin/api/provisioner.go @@ -54,7 +54,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := api.ParseCursor(r) if err != nil { api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, - "error parsing cursor & limit query params")) + "error parsing cursor and limit from query params")) return } diff --git a/authority/admin/api/provisioner_test.go b/authority/admin/api/provisioner_test.go new file mode 100644 index 00000000..6c463590 --- /dev/null +++ b/authority/admin/api/provisioner_test.go @@ -0,0 +1,1100 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/linkedca" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestHandler_GetProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + db admin.DB + req *http.Request + statusCode int + err *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByID": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo?id=provID", nil) + chiCtx := chi.NewRouteContext() + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provID: force", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provName: force", + }, + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.ACME{ + ID: "acmeID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "acmeID", id) + return nil, admin.NewErrorISE("error loading provisioner provName: force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provName: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.ACME{ + ID: "acmeID", + Name: "provName", + }, nil + }, + } + prov := &linkedca.Provisioner{ + Id: "acmeID", + Type: linkedca.Provisioner_ACME, + Name: "provName", // TODO(hs): other fields too? + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "acmeID", id) + return prov, nil + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + db: db, + statusCode: 200, + err: nil, + prov: prov, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + db: tc.db, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + prov := &linkedca.Provisioner{} + err := readProtoJSON(res.Body, prov) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.prov, prov, opts...) { + t.Errorf("h.GetProvisioner diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) + } + }) + } +} + +func TestHandler_GetProvisioners(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + req *http.Request + statusCode int + err *admin.Error + resp GetProvisionersResponse + } + var tests = map[string]func(t *testing.T) test{ + "fail/parse-cursor": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo?limit=X", nil) + return test{ + ctx: context.Background(), + statusCode: 400, + req: req, + err: &admin.Error{ + Status: 400, + Type: admin.ErrorBadRequestType.String(), + Detail: "bad request", + Message: "error parsing cursor and limit from query params: limit 'X' is not an integer: strconv.Atoi: parsing \"X\": invalid syntax", + }, + } + }, + "fail/auth.GetProvisioners": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + auth := &mockAdminAuthority{ + MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return nil, "", errors.New("force") + }, + } + return test{ + ctx: context.Background(), + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: "", + Status: 500, + Detail: "", + Message: "The certificate authority encountered an Internal Server Error. Please see the certificate authority logs for more info.", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("GET", "/foo", nil) + provisioners := provisioner.List{ + &provisioner.OIDC{ + Type: "OIDC", + Name: "oidcProv", + }, + &provisioner.ACME{ + Type: "ACME", + Name: "provName", + ForceCN: false, + RequireEAB: false, + }, + } + auth := &mockAdminAuthority{ + MockGetProvisioners: func(cursor string, limit int) (provisioner.List, string, error) { + assert.Equals(t, "", cursor) + assert.Equals(t, 0, limit) + return provisioners, "nextCursorValue", nil + }, + } + return test{ + ctx: context.Background(), + req: req, + auth: auth, + statusCode: 200, + err: nil, + resp: GetProvisionersResponse{ + Provisioners: provisioners, + NextCursor: "nextCursorValue", + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.GetProvisioners(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := GetProvisionersResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(provisioner.ACME{}, provisioner.OIDC{})} + if !cmp.Equal(tc.resp, response, opts...) { + t.Errorf("h.GetProvisioners diff =\n%s", cmp.Diff(tc.resp, response, opts...)) + } + }) + } +} + +func TestHandler_CreateProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + body []byte + statusCode int + err *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/readProtoJSON": func(t *testing.T) test { + body := []byte("{!?}") + return test{ + ctx: context.Background(), + body: body, + statusCode: 500, + err: &admin.Error{ // TODO(hs): this probably needs a better error + Type: "", + Status: 500, + Detail: "", + Message: "", + }, + } + }, + // TODO(hs): ValidateClaims can't be mocked atm + // "fail/authority.ValidateClaims": func(t *testing.T) test { + // return test{} + // }, + "fail/auth.StoreProvisioner": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { + assert.Equals(t, "provID", prov.Id) + return errors.New("force") + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error storing provisioner provName: force", + }, + } + }, + "ok": func(t *testing.T) test { + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockStoreProvisioner: func(ctx context.Context, prov *linkedca.Provisioner) error { + assert.Equals(t, "provID", prov.Id) + return nil + }, + } + return test{ + ctx: context.Background(), + body: body, + auth: auth, + statusCode: 201, + err: nil, + prov: prov, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.CreateProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + prov := &linkedca.Provisioner{} + err := readProtoJSON(res.Body, prov) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Provisioner{}, timestamppb.Timestamp{})} + if !cmp.Equal(tc.prov, prov, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) + } + }) + } +} + +func TestHandler_DeleteProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + req *http.Request + statusCode int + err *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/auth.LoadProvisionerByID": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo?id=provID", nil) + chiCtx := chi.NewRouteContext() + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByID: func(id string) (provisioner.Interface, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provID: force", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner provName: force", + }, + } + }, + "fail/auth.RemoveProvisioner": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + Type: "OIDC", + }, nil + }, + MockRemoveProvisioner: func(ctx context.Context, id string) error { + assert.Equals(t, "provID", id) + return errors.New("force") + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error removing provisioner provName: force", + }, + } + }, + "ok": func(t *testing.T) test { + req := httptest.NewRequest("DELETE", "/foo", nil) + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + Type: "OIDC", + }, nil + }, + MockRemoveProvisioner: func(ctx context.Context, id string) error { + assert.Equals(t, "provID", id) + return nil + }, + } + return test{ + ctx: ctx, + req: req, + auth: auth, + statusCode: 200, + err: nil, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + } + req := tc.req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.DeleteProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + response := DeleteResponse{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &response)) + assert.Equals(t, "ok", response.Status) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + }) + } +} + +func TestHandler_UpdateProvisioner(t *testing.T) { + type test struct { + ctx context.Context + auth adminAuthority + body []byte + db admin.DB + statusCode int + err *admin.Error + prov *linkedca.Provisioner + } + var tests = map[string]func(t *testing.T) test{ + "fail/readProtoJSON": func(t *testing.T) test { + body := []byte("{!?}") + return test{ + ctx: context.Background(), + body: body, + statusCode: 500, + err: &admin.Error{ // TODO(hs): this probably needs a better error + Type: "", + Status: 500, + Detail: "", + Message: "", + }, + } + }, + "fail/auth.LoadProvisionerByName": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner from cached configuration 'provName': force", + }, + } + }, + "fail/db.GetProvisioner": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return nil, errors.New("force") + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "error loading provisioner from db 'provID': force", + }, + } + }, + "fail/change-id-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "differentProvID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner ID", + }, + } + }, + "fail/change-type-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_JWK, + Name: "provName", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner type", + }, + } + }, + "fail/change-authority-id-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "differentAuthorityID", + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner authorityID", + }, + } + }, + "fail/change-createdAt-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner createdAt", + }, + } + }, + "fail/change-deletedAt-error": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(time.Now()), + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: admin.ErrorServerInternalType.String(), + Status: 500, + Detail: "the server experienced an internal error", + Message: "cannot change provisioner deletedAt", + }, + } + }, + // TODO(hs): ValidateClaims can't be mocked atm + //"fail/ValidateClaims": func(t *testing.T) test { return test{} }, + "fail/auth.UpdateProvisioner": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + assert.Equals(t, "provID", nu.Id) + assert.Equals(t, "provName", nu.Name) + return errors.New("force") + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 500, + err: &admin.Error{ + Type: "", // TODO(hs): this error can be improved + Status: 500, + Detail: "", + Message: "", + }, + } + }, + "ok": func(t *testing.T) test { + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("name", "provName") + ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) + createdAt := time.Now() + var deletedAt time.Time + prov := &linkedca.Provisioner{ + Id: "provID", + Type: linkedca.Provisioner_OIDC, + Name: "provName", + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + Details: &linkedca.ProvisionerDetails{ + Data: &linkedca.ProvisionerDetails_OIDC{ + OIDC: &linkedca.OIDCProvisioner{ + ClientId: "new-client-id", + ClientSecret: "new-client-secret", + }, + }, + }, + } + body, err := protojson.Marshal(prov) + assert.FatalError(t, err) + auth := &mockAdminAuthority{ + MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { + assert.Equals(t, "provName", name) + return &provisioner.OIDC{ + ID: "provID", + Name: "provName", + }, nil + }, + MockUpdateProvisioner: func(ctx context.Context, nu *linkedca.Provisioner) error { + assert.Equals(t, "provID", nu.Id) + assert.Equals(t, "provName", nu.Name) + return nil + }, + } + db := &admin.MockDB{ + MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { + assert.Equals(t, "provID", id) + return &linkedca.Provisioner{ + Id: "provID", + Name: "provName", + Type: linkedca.Provisioner_OIDC, + AuthorityId: "authorityID", + CreatedAt: timestamppb.New(createdAt), + DeletedAt: timestamppb.New(deletedAt), + }, nil + }, + } + return test{ + ctx: ctx, + body: body, + auth: auth, + db: db, + statusCode: 200, + prov: prov, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{ + auth: tc.auth, + db: tc.db, + } + req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.UpdateProvisioner(w, req) + res := w.Result() + + assert.Equals(t, tc.statusCode, res.StatusCode) + + if res.StatusCode >= 400 { + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + adminErr := admin.Error{} + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) + + assert.Equals(t, tc.err.Type, adminErr.Type) + assert.Equals(t, tc.err.Message, adminErr.Message) + assert.Equals(t, tc.err.Detail, adminErr.Detail) + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + return + } + + prov := &linkedca.Provisioner{} + err := readProtoJSON(res.Body, prov) + assert.FatalError(t, err) + + assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) + + opts := []cmp.Option{ + cmpopts.IgnoreUnexported( + linkedca.Provisioner{}, linkedca.ProvisionerDetails{}, linkedca.ProvisionerDetails_OIDC{}, + linkedca.OIDCProvisioner{}, timestamppb.Timestamp{}, + ), + } + if !cmp.Equal(tc.prov, prov, opts...) { + t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(tc.prov, prov, opts...)) + } + }) + } +} diff --git a/authority/admin/db/nosql/admin_test.go b/authority/admin/db/nosql/admin_test.go index 4234d526..2631b68c 100644 --- a/authority/admin/db/nosql/admin_test.go +++ b/authority/admin/db/nosql/admin_test.go @@ -11,7 +11,7 @@ import ( "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" + nosqldb "github.com/smallstep/nosql/database" "go.step.sm/linkedca" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -31,7 +31,7 @@ func TestDB_getDBAdminBytes(t *testing.T) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), @@ -105,7 +105,7 @@ func TestDB_getDBAdmin(t *testing.T) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), @@ -398,7 +398,7 @@ func TestDB_GetAdmin(t *testing.T) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), @@ -551,7 +551,7 @@ func TestDB_DeleteAdmin(t *testing.T) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), @@ -697,7 +697,7 @@ func TestDB_UpdateAdmin(t *testing.T) { assert.Equals(t, bucket, adminsTable) assert.Equals(t, string(key), adminID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), @@ -985,7 +985,7 @@ func TestDB_GetAdmins(t *testing.T) { "fail/db.List-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return nil, errors.New("force") @@ -995,14 +995,14 @@ func TestDB_GetAdmins(t *testing.T) { } }, "fail/unmarshal-error": func(t *testing.T) test { - ret := []*database.Entry{ + ret := []*nosqldb.Entry{ {Bucket: adminsTable, Key: []byte("foo"), Value: foob}, {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, {Bucket: adminsTable, Key: []byte("zap"), Value: []byte("zap")}, } return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil @@ -1012,10 +1012,10 @@ func TestDB_GetAdmins(t *testing.T) { } }, "ok/none": func(t *testing.T) test { - ret := []*database.Entry{} + ret := []*nosqldb.Entry{} return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil @@ -1027,13 +1027,13 @@ func TestDB_GetAdmins(t *testing.T) { } }, "ok/only-invalid": func(t *testing.T) test { - ret := []*database.Entry{ + ret := []*nosqldb.Entry{ {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, {Bucket: adminsTable, Key: []byte("baz"), Value: bazb}, } return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil @@ -1045,7 +1045,7 @@ func TestDB_GetAdmins(t *testing.T) { } }, "ok": func(t *testing.T) test { - ret := []*database.Entry{ + ret := []*nosqldb.Entry{ {Bucket: adminsTable, Key: []byte("foo"), Value: foob}, {Bucket: adminsTable, Key: []byte("bar"), Value: barb}, {Bucket: adminsTable, Key: []byte("baz"), Value: bazb}, @@ -1053,7 +1053,7 @@ func TestDB_GetAdmins(t *testing.T) { } return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, adminsTable) return ret, nil diff --git a/authority/admin/db/nosql/provisioner_test.go b/authority/admin/db/nosql/provisioner_test.go index e599ea04..a399558a 100644 --- a/authority/admin/db/nosql/provisioner_test.go +++ b/authority/admin/db/nosql/provisioner_test.go @@ -11,7 +11,7 @@ import ( "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" - "github.com/smallstep/nosql/database" + nosqldb "github.com/smallstep/nosql/database" "go.step.sm/linkedca" ) @@ -30,7 +30,7 @@ func TestDB_getDBProvisionerBytes(t *testing.T) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), @@ -104,7 +104,7 @@ func TestDB_getDBProvisioner(t *testing.T) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), @@ -444,7 +444,7 @@ func TestDB_GetProvisioner(t *testing.T) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), @@ -581,7 +581,7 @@ func TestDB_DeleteProvisioner(t *testing.T) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), @@ -735,7 +735,7 @@ func TestDB_GetProvisioners(t *testing.T) { "fail/db.List-error": func(t *testing.T) test { return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, provisionersTable) return nil, errors.New("force") @@ -745,14 +745,14 @@ func TestDB_GetProvisioners(t *testing.T) { } }, "fail/unmarshal-error": func(t *testing.T) test { - ret := []*database.Entry{ + ret := []*nosqldb.Entry{ {Bucket: provisionersTable, Key: []byte("foo"), Value: foob}, {Bucket: provisionersTable, Key: []byte("bar"), Value: barb}, {Bucket: provisionersTable, Key: []byte("zap"), Value: []byte("zap")}, } return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, provisionersTable) return ret, nil @@ -762,10 +762,10 @@ func TestDB_GetProvisioners(t *testing.T) { } }, "ok/none": func(t *testing.T) test { - ret := []*database.Entry{} + ret := []*nosqldb.Entry{} return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, provisionersTable) return ret, nil @@ -777,13 +777,13 @@ func TestDB_GetProvisioners(t *testing.T) { } }, "ok/only-invalid": func(t *testing.T) test { - ret := []*database.Entry{ + ret := []*nosqldb.Entry{ {Bucket: provisionersTable, Key: []byte("bar"), Value: barb}, {Bucket: provisionersTable, Key: []byte("baz"), Value: bazb}, } return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, provisionersTable) return ret, nil @@ -795,7 +795,7 @@ func TestDB_GetProvisioners(t *testing.T) { } }, "ok": func(t *testing.T) test { - ret := []*database.Entry{ + ret := []*nosqldb.Entry{ {Bucket: provisionersTable, Key: []byte("foo"), Value: foob}, {Bucket: provisionersTable, Key: []byte("bar"), Value: barb}, {Bucket: provisionersTable, Key: []byte("baz"), Value: bazb}, @@ -803,7 +803,7 @@ func TestDB_GetProvisioners(t *testing.T) { } return test{ db: &db.MockNoSQLDB{ - MList: func(bucket []byte) ([]*database.Entry, error) { + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { assert.Equals(t, bucket, provisionersTable) return ret, nil @@ -988,7 +988,7 @@ func TestDB_UpdateProvisioner(t *testing.T) { assert.Equals(t, bucket, provisionersTable) assert.Equals(t, string(key), provID) - return nil, database.ErrNotFound + return nil, nosqldb.ErrNotFound }, }, adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), diff --git a/authority/admin/errors.go b/authority/admin/errors.go index 607093b0..217227ca 100644 --- a/authority/admin/errors.go +++ b/authority/admin/errors.go @@ -104,7 +104,7 @@ var ( } ) -// Error represents an Admin +// Error represents an Admin error type Error struct { Type string `json:"type"` Detail string `json:"detail"` diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index c8950568..21958d36 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -13,13 +13,18 @@ import ( // provisioning flow. type ACME struct { *base - ID string `json:"-"` - Type string `json:"type"` - Name string `json:"name"` - ForceCN bool `json:"forceCN,omitempty"` - Claims *Claims `json:"claims,omitempty"` - Options *Options `json:"options,omitempty"` - claimer *Claimer + ID string `json:"-"` + Type string `json:"type"` + Name string `json:"name"` + ForceCN bool `json:"forceCN,omitempty"` + // RequireEAB makes the provisioner require ACME EAB to be provided + // by clients when creating a new Account. If set to true, the provided + // EAB will be verified. If set to false and an EAB is provided, it is + // not verified. Defaults to false. + RequireEAB bool `json:"requireEAB,omitempty"` + Claims *Claims `json:"claims,omitempty"` + Options *Options `json:"options,omitempty"` + claimer *Claimer } // GetID returns the provisioner unique identifier. diff --git a/authority/provisioners.go b/authority/provisioners.go index a98b78a6..7cf761cd 100644 --- a/authority/provisioners.go +++ b/authority/provisioners.go @@ -638,12 +638,13 @@ func ProvisionerToCertificates(p *linkedca.Provisioner) (provisioner.Interface, case *linkedca.ProvisionerDetails_ACME: cfg := d.ACME return &provisioner.ACME{ - ID: p.Id, - Type: p.Type.String(), - Name: p.Name, - ForceCN: cfg.ForceCn, - Claims: claims, - Options: options, + ID: p.Id, + Type: p.Type.String(), + Name: p.Name, + ForceCN: cfg.ForceCn, + RequireEAB: cfg.RequireEab, + Claims: claims, + Options: options, }, nil case *linkedca.ProvisionerDetails_OIDC: cfg := d.OIDC diff --git a/ca/adminClient.go b/ca/adminClient.go index 2e447f55..cfbf595a 100644 --- a/ca/adminClient.go +++ b/ca/adminClient.go @@ -560,7 +560,115 @@ retry: return nil } +// GetExternalAccountKeysPaginate returns a page from the GET /admin/acme/eab request to the CA. +func (c *AdminClient) GetExternalAccountKeysPaginate(provisionerName, reference string, opts ...AdminOption) (*adminAPI.GetExternalAccountKeysResponse, error) { + var retried bool + o := new(adminOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + p := path.Join(adminURLPrefix, "acme/eab", provisionerName) + if reference != "" { + p = path.Join(p, "/", reference) + } + u := c.endpoint.ResolveReference(&url.URL{ + Path: p, + RawQuery: o.rawQuery(), + }) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("GET", u.String(), http.NoBody) + if err != nil { + return nil, errors.Wrapf(err, "create GET %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var body = new(adminAPI.GetExternalAccountKeysResponse) + if err := readJSON(resp.Body, body); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return body, nil +} + +// CreateExternalAccountKey performs the POST /admin/acme/eab request to the CA. +func (c *AdminClient) CreateExternalAccountKey(provisionerName string, eakRequest *adminAPI.CreateExternalAccountKeyRequest) (*linkedca.EABKey, error) { + var retried bool + body, err := json.Marshal(eakRequest) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab/", provisionerName)}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("POST", u.String(), bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "create POST %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readAdminError(resp.Body) + } + var eabKey = new(linkedca.EABKey) + if err := readProtoJSON(resp.Body, eabKey); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return eabKey, nil +} + +// RemoveExternalAccountKey performs the DELETE /admin/acme/eab/{prov}/{key_id} request to the CA. +func (c *AdminClient) RemoveExternalAccountKey(provisionerName, keyID string) error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "acme/eab", provisionerName, "/", keyID)}) + tok, err := c.generateAdminToken(u.Path) + if err != nil { + return errors.Wrapf(err, "error generating admin token") + } + req, err := http.NewRequest("DELETE", u.String(), http.NoBody) + if err != nil { + return errors.Wrapf(err, "create DELETE %s request failed", u) + } + req.Header.Add("Authorization", tok) +retry: + resp, err := c.client.Do(req) + if err != nil { + return errors.Wrapf(err, "client DELETE %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readAdminError(resp.Body) + } + return nil +} + func readAdminError(r io.ReadCloser) error { + // TODO: not all errors can be read (i.e. 404); seems to be a bigger issue defer r.Close() adminErr := new(admin.Error) if err := json.NewDecoder(r).Decode(adminErr); err != nil { diff --git a/ca/ca.go b/ca/ca.go index da0fb874..51ca9273 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -207,7 +207,7 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { if cfg.AuthorityConfig.EnableAdmin { adminDB := auth.GetAdminDatabase() if adminDB != nil { - adminHandler := adminAPI.NewHandler(auth) + adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB) mux.Route("/admin", func(r chi.Router) { adminHandler.Route(r) }) diff --git a/docs/provisioners.md b/docs/provisioners.md index 18010f88..d45e7865 100644 --- a/docs/provisioners.md +++ b/docs/provisioners.md @@ -346,6 +346,7 @@ Below is an example of an ACME provisioner in the `ca.json`: "type": "ACME", "name": "my-acme-provisioner", "forceCN": true, + "requireEAB": false, "claims": { "maxTLSCertDuration": "8h", "defaultTLSCertDuration": "2h", @@ -361,6 +362,9 @@ Below is an example of an ACME provisioner in the `ca.json`: * `forceCN` (optional): force one of the SANs to become the Common Name, if a common name is not provided. +* `requireEAB` (optional): require clients to provide External Account Binding + credentials when creating an ACME Account. + * `claims` (optional): overwrites the default claims set in the authority, see the [top](#provisioners) section for all the options.