diff --git a/acme/account.go b/acme/account.go index 3167dd09..e340bfa8 100644 --- a/acme/account.go +++ b/acme/account.go @@ -1,11 +1,11 @@ package acme import ( + "context" "encoding/json" "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/jose" "github.com/smallstep/nosql" ) @@ -79,11 +79,11 @@ func newAccount(db nosql.DB, ops AccountOptions) (*account, error) { // toACME converts the internal Account type into the public acmeAccount // type for presentation in the ACME protocol. -func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) { +func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) { return &Account{ Status: a.Status, Contact: a.Contact, - Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID), + Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), Key: a.Key, ID: a.ID, }, nil diff --git a/acme/account_test.go b/acme/account_test.go index 37af69dc..25600028 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -1,8 +1,10 @@ package acme import ( + "context" "encoding/json" "fmt" + "net/url" "testing" "time" @@ -332,6 +334,10 @@ func TestGetAccountIDsByAccount(t *testing.T) { func TestAccountToACME(t *testing.T) { dir := newDirectory("ca.smallstep.com", "acme") prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) type test struct { acc *account @@ -347,7 +353,7 @@ func TestAccountToACME(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - acmeAccount, err := tc.acc.toACME(nil, dir, prov) + acmeAccount, err := tc.acc.toACME(ctx, nil, dir) if err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) @@ -363,7 +369,7 @@ func TestAccountToACME(t *testing.T) { assert.Equals(t, acmeAccount.Contact, tc.acc.Contact) assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID) assert.Equals(t, acmeAccount.Orders, - fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID)) + fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID)) } } }) diff --git a/acme/api/account.go b/acme/api/account.go index bb6c92d6..93f46651 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -65,18 +65,15 @@ func (u *UpdateAccountRequest) Validate() error { } return nil default: - return acme.MalformedErr(errors.Errorf("empty update request")) + // According to the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2) + // accountUpdate should ignore any fields not recognized by the server. + return nil } } // NewAccount is the handler resource for creating new ACME accounts. func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - payload, err := payloadFromContext(r) + payload, err := payloadFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -93,7 +90,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { } httpStatus := http.StatusCreated - acc, err := accountFromContext(r) + acc, err := acme.AccountFromContext(r.Context()) if err != nil { acmeErr, ok := err.(*acme.Error) if !ok || acmeErr.Status != http.StatusBadRequest { @@ -107,13 +104,13 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { api.WriteError(w, acme.AccountDoesNotExistErr(nil)) return } - jwk, err := jwkFromContext(r) + jwk, err := acme.JwkFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{ + if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{ Key: jwk, Contact: nar.Contact, }); err != nil { @@ -125,29 +122,26 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { httpStatus = http.StatusOK } - w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, - acme.URLSafeProvisionerName(prov), true, acc.GetID())) + w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, + true, acc.GetID())) api.JSONStatus(w, acc, httpStatus) } // GetUpdateAccount is the api for updating an ACME account. func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) + acc, err := acme.AccountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - acc, err := accountFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - payload, err := payloadFromContext(r) + payload, err := payloadFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } + // If PostAsGet just respond with the account, otherwise process like a + // normal Post request. if !payload.isPostAsGet { var uar UpdateAccountRequest if err := json.Unmarshal(payload.value, &uar); err != nil { @@ -159,17 +153,21 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { return } var err error + // If neither the status nor the contacts are being updated then ignore + // the updates and return 200. This conforms with the behavior detailed + // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). if uar.IsDeactivateRequest() { - acc, err = h.Auth.DeactivateAccount(prov, acc.GetID()) - } else { - acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact) + acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID()) + } else if len(uar.Contact) > 0 { + acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact) } if err != nil { api.WriteError(w, err) return } } - w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID())) + w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, + true, acc.GetID())) api.JSON(w, acc) } @@ -184,23 +182,17 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) { // GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) + acc, err := acme.AccountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - acc, err := accountFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - accID := chi.URLParam(r, "accID") if acc.ID != accID { api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) return } - orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID()) + orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) if err != nil { api.WriteError(w, err) return diff --git a/acme/api/account_test.go b/acme/api/account_test.go index a3ebf55c..0e34f980 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -7,6 +7,7 @@ import ( "fmt" "io/ioutil" "net/http/httptest" + "net/url" "testing" "time" @@ -143,6 +144,11 @@ func TestUpdateAccountRequestValidate(t *testing.T) { }, } }, + "ok/accept-empty": func(t *testing.T) test { + return test{ + uar: &UpdateAccountRequest{}, + } + }, } for name, run := range tests { tc := run(t) @@ -182,33 +188,17 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")), - } - }, "fail/no-account": func(t *testing.T) test { return test{ auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ auth: &mockAcmeAuthority{}, ctx: ctx, @@ -218,8 +208,8 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { }, "fail/account-id-mismatch": func(t *testing.T) test { acc := &acme.Account{ID: "foo"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{}, @@ -230,8 +220,8 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { }, "fail/getOrdersByAccount-error": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ @@ -244,12 +234,14 @@ func TestHandlerGetOrdersByAccount(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: accID} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ - getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) { + getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, id, acc.ID) return oids, nil @@ -304,8 +296,8 @@ func TestHandlerNewAccount(t *testing.T) { Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), } prov := newProv() - - url := "https://ca.smallstep.com/acme/new-account" + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { auth acme.Interface @@ -314,31 +306,16 @@ func TestHandlerNewAccount(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) - return test{ - ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 500, problem: acme.ServerInternalErr(errors.New("payload expected in request context")), } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -346,8 +323,8 @@ func TestHandlerNewAccount(t *testing.T) { } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, @@ -360,8 +337,8 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, @@ -374,8 +351,8 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, @@ -388,8 +365,8 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 500, @@ -402,9 +379,9 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, jwkContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.JwkContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -419,12 +396,14 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) return test{ auth: &mockAcmeAuthority{ - newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { + newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, ops.Contact, nar.Contact) assert.Equals(t, ops.Key, jwk) @@ -444,24 +423,27 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, jwkContextKey, jwk) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { + newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, ops.Contact, nar.Contact) assert.Equals(t, ops.Key, jwk) return &acc, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) assert.True(t, abs) - assert.Equals(t, in, []string{accID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID) + assert.True(t, abs) + assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) }, }, ctx: ctx, @@ -474,18 +456,19 @@ func TestHandlerNewAccount(t *testing.T) { } b, err := json.Marshal(nar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) - ctx = context.WithValue(ctx, accContextKey, &acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { assert.Equals(t, typ, acme.AccountLink) assert.True(t, abs) - assert.Equals(t, in, []string{accID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID) + assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx)) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) }, }, ctx: ctx, @@ -497,7 +480,7 @@ func TestHandlerNewAccount(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := New(tc.auth).(*Handler) - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.NewAccount(w, req) @@ -524,8 +507,8 @@ func TestHandlerNewAccount(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID)}) + []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), + provName, accID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) @@ -540,9 +523,8 @@ func TestHandlerGetUpdateAccount(t *testing.T) { Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), } prov := newProv() - - // Request with chi context - url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { auth acme.Interface @@ -551,31 +533,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, nil) - return test{ - ctx: ctx, - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ ctx: ctx, statusCode: 400, @@ -583,8 +550,8 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } }, "fail/no-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) return test{ ctx: ctx, statusCode: 500, @@ -592,9 +559,9 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } }, "fail/nil-payload": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -602,9 +569,9 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } }, "fail/unmarshal-payload-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, @@ -617,9 +584,9 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, @@ -632,12 +599,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ auth: &mockAcmeAuthority{ - deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) { + deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, id, accID) return nil, acme.ServerInternalErr(errors.New("force")) @@ -654,12 +623,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ auth: &mockAcmeAuthority{ - updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) { + updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, id, accID) assert.Equals(t, contacts, uar.Contact) @@ -677,53 +648,82 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) { + deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, id, accID) return &acc, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) - assert.Equals(t, in, []string{accID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) }, }, ctx: ctx, statusCode: 200, } }, - "ok/new-account": func(t *testing.T) test { + "ok/update-empty": func(t *testing.T) test { + uar := &UpdateAccountRequest{} + b, err := json.Marshal(uar) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) + return test{ + auth: &mockAcmeAuthority{ + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, typ, acme.AccountLink) + assert.True(t, abs) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) + }, + }, + ctx: ctx, + statusCode: 200, + } + }, + "ok/update-contacts": func(t *testing.T) test { uar := &UpdateAccountRequest{ Contact: []string{"foo", "bar"}, } b, err := json.Marshal(uar) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) { + updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, id, accID) assert.Equals(t, contacts, uar.Contact) return &acc, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) - assert.Equals(t, in, []string{accID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL.String(), provName, accID) }, }, ctx: ctx, @@ -731,18 +731,19 @@ func TestHandlerGetUpdateAccount(t *testing.T) { } }, "ok/post-as-get": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, &acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, &acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) - assert.Equals(t, in, []string{accID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + assert.Equals(t, ins, []string{accID}) + return fmt.Sprintf("%s/acme/%s/account/%s", + baseURL, provName, accID) }, }, ctx: ctx, @@ -754,7 +755,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := New(tc.auth).(*Handler) - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetUpdateAccount(w, req) @@ -781,8 +782,8 @@ func TestHandlerGetUpdateAccount(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", - acme.URLSafeProvisionerName(prov), accID)}) + []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), + provName, accID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/api/handler.go b/acme/api/handler.go index b204e256..c7f05d89 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "net/http" @@ -8,65 +9,27 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/cli/jose" ) func link(url, typ string) string { return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) } -type contextKey string - -const ( - accContextKey = contextKey("acc") - jwsContextKey = contextKey("jws") - jwkContextKey = contextKey("jwk") - payloadContextKey = contextKey("payload") - provisionerContextKey = contextKey("provisioner") -) - type payloadInfo struct { value []byte isPostAsGet bool isEmptyJSON bool } -func accountFromContext(r *http.Request) (*acme.Account, error) { - val, ok := r.Context().Value(accContextKey).(*acme.Account) - if !ok || val == nil { - return nil, acme.AccountDoesNotExistErr(nil) - } - return val, nil -} -func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) { - val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey) - if !ok || val == nil { - return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context")) - } - return val, nil -} -func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) { - val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature) - if !ok || val == nil { - return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context")) - } - return val, nil -} -func payloadFromContext(r *http.Request) (*payloadInfo, error) { - val, ok := r.Context().Value(payloadContextKey).(*payloadInfo) +// payloadFromContext searches the context for a payload. Returns the payload +// or an error. +func payloadFromContext(ctx context.Context) (*payloadInfo, error) { + val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo) if !ok || val == nil { return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context")) } return val, nil } -func provisionerFromContext(r *http.Request) (provisioner.Interface, error) { - val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface) - if !ok || val == nil { - return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")) - } - return val, nil -} // New returns a new ACME API router. func New(acmeAuth acme.Interface) api.RouterHandler { @@ -80,29 +43,29 @@ type Handler struct { // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { - getLink := h.Auth.GetLink + getLink := h.Auth.GetLinkExplicit // Standard ACME API - r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce))) - r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce))) - r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory))) - r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory))) + r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) + r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) + r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))) + return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))) + return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) } - r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount)) - r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) - r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder)) - r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) - r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) - r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) - r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) - r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge)) - r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) + r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) + r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) + r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) + r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) + r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) + r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) + r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) + r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge)) + r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) } // GetNonce just sets the right header since a Nonce is added to each response @@ -118,34 +81,28 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { // GetDirectory is the ACME resource for returning a directory configuration // for client configuration. func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) + dir, err := h.Auth.GetDirectory(r.Context()) if err != nil { api.WriteError(w, err) return } - dir := h.Auth.GetDirectory(prov) api.JSON(w, dir) } // GetAuthz ACME api for retrieving an Authz. func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) + acc, err := acme.AccountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - acc, err := accountFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID")) + authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID")) if err != nil { api.WriteError(w, err) return } - w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID())) + w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID())) api.JSON(w, authz) } @@ -186,13 +143,7 @@ func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { // https://tools.ietf.org/html/rfc8555#section-7.5.1 // func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - - acc, err := accountFromContext(r) + acc, err := acme.AccountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -200,7 +151,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // Just verify that the payload was set since the client is required // to send _something_. - _, err = payloadFromContext(r) + _, err = payloadFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -210,15 +161,14 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { ch *acme.Challenge chID = chi.URLParam(r, "chID") ) - ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey()) + ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey()) if err != nil { api.WriteError(w, err) return } - getLink := h.Auth.GetLink - w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up")) - w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID())) + w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up")) + w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID())) if ch.Status == acme.StatusProcessing { w.Header().Add("Retry-After", ch.RetryAfter) @@ -231,7 +181,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { // GetCertificate ACME api for retrieving a Certificate. func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { - acc, err := accountFromContext(r) + acc, err := acme.AccountFromContext(r.Context()) if err != nil { api.WriteError(w, err) return diff --git a/acme/api/handler_test.go b/acme/api/handler_test.go index 70c366d7..8cbed7af 100644 --- a/acme/api/handler_test.go +++ b/acme/api/handler_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io/ioutil" "net/http/httptest" + "net/url" "testing" "time" @@ -23,74 +24,79 @@ import ( ) type mockAcmeAuthority struct { - deactivateAccount func(provisioner.Interface, string) (*acme.Account, error) - finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error) - getAccount func(p provisioner.Interface, id string) (*acme.Account, error) - getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error) - getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error) - getCertificate func(accID string, id string) ([]byte, error) - getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error) - getDirectory func(provisioner.Interface) *acme.Directory - getLink func(acme.Link, string, bool, ...string) string - getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error) - getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error) + getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string + getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string + + deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error) + getAccount func(ctx context.Context, accID string) (*acme.Account, error) + getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error) + newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error) + updateAccount func(context.Context, string, []string) (*acme.Account, error) + + getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error) + validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error) + getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error) + getDirectory func(ctx context.Context) (*acme.Directory, error) + getCertificate func(string, string) ([]byte, error) + + finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error) + getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error) + getOrdersByAccount func(ctx context.Context, accID string) ([]string, error) + newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error) + loadProvisionerByID func(string) (provisioner.Interface, error) - newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error) newNonce func() (string, error) - newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error) - updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error) useNonce func(string) error - validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) ret1 interface{} err error } -func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) { +func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) { if m.deactivateAccount != nil { - return m.deactivateAccount(p, id) + return m.deactivateAccount(ctx, id) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Account), m.err } -func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { +func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { if m.finalizeOrder != nil { - return m.finalizeOrder(p, accID, id, csr) + return m.finalizeOrder(ctx, accID, id, csr) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Order), m.err } -func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) { +func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) { if m.getAccount != nil { - return m.getAccount(p, id) + return m.getAccount(ctx, id) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Account), m.err } -func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { +func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { if m.getAccountByKey != nil { - return m.getAccountByKey(p, jwk) + return m.getAccountByKey(ctx, jwk) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Account), m.err } -func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) { +func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) { if m.getAuthz != nil { - return m.getAuthz(p, accID, id) + return m.getAuthz(ctx, accID, id) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Authz), m.err } -func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) { +func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) { if m.getCertificate != nil { return m.getCertificate(accID, id) } else if m.err != nil { @@ -99,41 +105,48 @@ func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) { return m.ret1.([]byte), m.err } -func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) { +func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) { if m.getChallenge != nil { - return m.getChallenge(p, accID, id) + return m.getChallenge(ctx, accID, id) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Challenge), m.err } -func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface) *acme.Directory { +func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) { if m.getDirectory != nil { - return m.getDirectory(p) + return m.getDirectory(ctx) } - return m.ret1.(*acme.Directory) + return m.ret1.(*acme.Directory), m.err } -func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string { +func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { if m.getLink != nil { - return m.getLink(typ, provID, abs, in...) + return m.getLink(ctx, typ, abs, ins...) } return m.ret1.(string) } -func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) { +func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string { + if m.getLinkExplicit != nil { + return m.getLinkExplicit(typ, provID, abs, baseURL, ins...) + } + return m.ret1.(string) +} + +func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) { if m.getOrder != nil { - return m.getOrder(p, accID, id) + return m.getOrder(ctx, accID, id) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Order), m.err } -func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) { +func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { if m.getOrdersByAccount != nil { - return m.getOrdersByAccount(p, id) + return m.getOrdersByAccount(ctx, id) } else if m.err != nil { return nil, m.err } @@ -149,9 +162,9 @@ func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Inte return m.ret1.(provisioner.Interface), m.err } -func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { +func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) { if m.newAccount != nil { - return m.newAccount(p, ops) + return m.newAccount(ctx, ops) } else if m.err != nil { return nil, m.err } @@ -167,18 +180,18 @@ func (m *mockAcmeAuthority) NewNonce() (string, error) { return m.ret1.(string), m.err } -func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { +func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { if m.newOrder != nil { - return m.newOrder(p, ops) + return m.newOrder(ctx, ops) } else if m.err != nil { return nil, m.err } return m.ret1.(*acme.Order), m.err } -func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) { +func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) { if m.updateAccount != nil { - return m.updateAccount(p, id, contact) + return m.updateAccount(ctx, id, contact) } else if m.err != nil { return nil, m.err } @@ -192,10 +205,10 @@ func (m *mockAcmeAuthority) UseNonce(nonce string) error { return m.err } -func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { +func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { switch { case m.validateChallenge != nil: - return m.validateChallenge(p, accID, id, jwk) + return m.validateChallenge(ctx, accID, id, jwk) case m.err != nil: return nil, m.err default: @@ -233,40 +246,28 @@ func TestHandlerGetNonce(t *testing.T) { func TestHandlerGetDirectory(t *testing.T) { auth, err := acme.NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil, 0) assert.FatalError(t, err) + prov := newProv() - url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov)) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) expDir := acme.Directory{ - NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)), - NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)), - NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)), - RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)), - KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)), + NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), + NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName), + RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName), + KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName), } type test struct { - ctx context.Context statusCode int problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "ok": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), statusCode: 200, } }, @@ -275,8 +276,8 @@ func TestHandlerGetDirectory(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := New(auth).(*Handler) - req := httptest.NewRequest("GET", url, nil) - req = req.WithContext(tc.ctx) + req := httptest.NewRequest("GET", "/foo/bar", nil) + req = req.WithContext(ctx) w := httptest.NewRecorder() h.GetDirectory(w, req) res := w.Result() @@ -338,12 +339,14 @@ func TestHandlerGetAuthz(t *testing.T) { }, } prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("authzID", az.ID) - url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s", - acme.URLSafeProvisionerName(prov), az.ID) + url := fmt.Sprintf("%s/acme/%s/challenge/%s", + baseURL.String(), provName, az.ID) type test struct { auth acme.Interface @@ -352,33 +355,17 @@ func TestHandlerGetAuthz(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-account": func(t *testing.T) test { return test{ auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ auth: &mockAcmeAuthority{}, ctx: ctx, @@ -388,8 +375,8 @@ func TestHandlerGetAuthz(t *testing.T) { }, "fail/getAuthz-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ @@ -402,20 +389,23 @@ func TestHandlerGetAuthz(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) { + getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, acc.ID) assert.Equals(t, id, az.ID) return &az, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AuthzLink) + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) assert.True(t, abs) assert.Equals(t, in, []string{az.ID}) return url @@ -430,7 +420,7 @@ func TestHandlerGetAuthz(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := New(tc.auth).(*Handler) - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", "/foo/bar", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.GetAuthz(w, req) @@ -487,11 +477,13 @@ func TestHandlerGetCertificate(t *testing.T) { certID := "certID" prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} // Request with chi context chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("certID", certID) - url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s", - acme.URLSafeProvisionerName(prov), certID) + url := fmt.Sprintf("%s/acme/%s/certificate/%s", + baseURL.String(), provName, certID) type test struct { auth acme.Interface @@ -503,13 +495,13 @@ func TestHandlerGetCertificate(t *testing.T) { "fail/no-account": func(t *testing.T) test { return test{ auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.AccContextKey, nil) return test{ auth: &mockAcmeAuthority{}, ctx: ctx, @@ -519,7 +511,7 @@ func TestHandlerGetCertificate(t *testing.T) { }, "fail/getCertificate-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ @@ -532,7 +524,7 @@ func TestHandlerGetCertificate(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ @@ -595,8 +587,10 @@ func ch() acme.Challenge { func TestHandlerGetChallenge(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("chID", "chID") - url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID") prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID") type test struct { auth acme.Interface @@ -607,33 +601,17 @@ func TestHandlerGetChallenge(t *testing.T) { } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ ctx: ctx, statusCode: 400, @@ -643,8 +621,8 @@ func TestHandlerGetChallenge(t *testing.T) { "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) return test{ ctx: ctx, statusCode: 500, @@ -654,9 +632,9 @@ func TestHandlerGetChallenge(t *testing.T) { "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -666,9 +644,9 @@ func TestHandlerGetChallenge(t *testing.T) { "fail/validate-challenge-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ @@ -680,39 +658,56 @@ func TestHandlerGetChallenge(t *testing.T) { } }, + "fail/get-challenge-error": func(t *testing.T) test { + acc := &acme.Account{ID: "accID"} + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + return test{ + auth: &mockAcmeAuthority{ + err: acme.UnauthorizedErr(nil), + }, + ctx: ctx, + statusCode: 401, + problem: acme.UnauthorizedErr(nil), + } + }, + "ok/validate-challenge": func(t *testing.T) test { key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) ch := ch() ch.Status = "valid" ch.Validated = time.Now().UTC().Format(time.RFC3339) count := 0 return test{ auth: &mockAcmeAuthority{ - validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { + validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, acc.ID) assert.Equals(t, id, ch.ID) assert.Equals(t, jwk.KeyID, key.KeyID) return &ch, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { var ret string switch count { case 0: assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID) + ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) case 1: assert.Equals(t, typ, acme.ChallengeLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{ch.ID}) ret = url @@ -731,39 +726,39 @@ func TestHandlerGetChallenge(t *testing.T) { key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) acc := &acme.Account{ID: "accID", Key: key} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - chiCtxInactive := chi.NewRouteContext() - chiCtxInactive.URLParams.Add("chID", "chID") - ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtxInactive) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true}) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) ch := ch() ch.Status = "processing" ch.RetryAfter = time.Now().Add(1 * time.Minute).UTC().Format(time.RFC3339) chJSON, err := json.Marshal(ch) assert.FatalError(t, err) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: chJSON}) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: chJSON}) count := 0 return test{ auth: &mockAcmeAuthority{ - validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { + validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, acc.ID) assert.Equals(t, id, ch.ID) assert.Equals(t, jwk.KeyID, key.KeyID) return &ch, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { var ret string switch count { case 0: assert.Equals(t, typ, acme.AuthzLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{ch.AuthzID}) - ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID) + ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID) case 1: assert.Equals(t, typ, acme.ChallengeLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{ch.ID}) ret = url @@ -772,7 +767,6 @@ func TestHandlerGetChallenge(t *testing.T) { return ret }, }, - ctx: ctx, statusCode: 200, ch: ch, @@ -811,14 +805,15 @@ func TestHandlerGetChallenge(t *testing.T) { expB, err := json.Marshal(tc.ch) assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) + assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, tc.ch.AuthzID)}) + assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) switch tc.ch.Status { case "processing": assert.Equals(t, res.Header["Cache-Control"], []string{"no-cache"}) assert.Equals(t, res.Header["Retry-After"], []string{tc.ch.RetryAfter}) case "valid", "invalid": - assert.Equals(t, res.Header["Location"], []string{url}) - assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf(";rel=\"up\"", tc.ch.AuthzID)}) + // } } else { assert.Fatal(t, false, "Unexpected Status Code") diff --git a/acme/api/middleware.go b/acme/api/middleware.go index af2618bf..93a85a7f 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -30,6 +30,35 @@ func logNonce(w http.ResponseWriter, nonce string) { } } +// baseURLFromRequest determines the base URL which should be used for +// constructing link URLs in e.g. the ACME directory result by taking the +// request Host into consideration. +// +// If the Request.Host is an empty string, we return an empty string, to +// indicate that the configured URL values should be used instead. If this +// function returns a non-empty result, then this should be used in +// constructing ACME link URLs. +func baseURLFromRequest(r *http.Request) *url.URL { + // NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go + // for an implementation that allows HTTP requests using the x-forwarded-proto + // header. + + if r.Host == "" { + return nil + } + return &url.URL{Scheme: "https", Host: r.Host} +} + +// baseURLFromRequest is a middleware that extracts and caches the baseURL +// from the request. +// E.g. https://ca.smallstep.com/ +func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r)) + next(w, r.WithContext(ctx)) + } +} + // addNonce is a middleware that adds a nonce to the response header. func (h *Handler) addNonce(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { @@ -49,12 +78,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP { // directory index url. func (h *Handler) addDirLink(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index")) + w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), + acme.DirectoryLink, true), "index")) next(w, r) } } @@ -63,14 +88,9 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP { // application/jose+json. func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } ct := r.Header.Get("Content-Type") var expected []string - if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) { + if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) { // GET /certificate requests allow a greater range of content types. expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} } else { @@ -101,7 +121,7 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) return } - ctx := context.WithValue(r.Context(), jwsContextKey, jws) + ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws) next(w, r.WithContext(ctx)) } } @@ -123,7 +143,7 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP { // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below func (h *Handler) validateJWS(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := jwsFromContext(r) + jws, err := acme.JwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -207,12 +227,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func (h *Handler) extractJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - jws, err := jwsFromContext(r) + jws, err := acme.JwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -226,8 +241,8 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) return } - ctx = context.WithValue(ctx, jwkContextKey, jwk) - acc, err := h.Auth.GetAccountByKey(prov, jwk) + ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) + acc, err := h.Auth.GetAccountByKey(ctx, jwk) switch { case nosql.IsErrNotFound(err): // For NewAccount requests ... @@ -240,7 +255,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) return } - ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) } next(w, r.WithContext(ctx)) } @@ -267,7 +282,7 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) return } - ctx = context.WithValue(ctx, provisionerContextKey, p) + ctx = context.WithValue(ctx, acme.ProvisionerContextKey, p) next(w, r.WithContext(ctx)) } } @@ -278,18 +293,13 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - jws, err := jwsFromContext(r) + jws, err := acme.JwsFromContext(ctx) if err != nil { api.WriteError(w, err) return } - kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "") + kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "") kid := jws.Signatures[0].Protected.KeyID if !strings.HasPrefix(kid, kidPrefix) { api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ @@ -298,7 +308,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { } accID := strings.TrimPrefix(kid, kidPrefix) - acc, err := h.Auth.GetAccount(prov, accID) + acc, err := h.Auth.GetAccount(r.Context(), accID) switch { case nosql.IsErrNotFound(err): api.WriteError(w, acme.AccountDoesNotExistErr(nil)) @@ -311,8 +321,8 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) return } - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, jwkContextKey, acc.Key) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key) next(w, r.WithContext(ctx)) return } @@ -323,12 +333,12 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - jws, err := jwsFromContext(r) + jws, err := acme.JwsFromContext(r.Context()) if err != nil { api.WriteError(w, err) return } - jwk, err := jwkFromContext(r) + jwk, err := acme.JwkFromContext(r.Context()) if err != nil { api.WriteError(w, err) return @@ -342,7 +352,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) return } - ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{ + ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{ value: payload, isPostAsGet: string(payload) == "", isEmptyJSON: string(payload) == "{}", @@ -354,7 +364,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { return func(w http.ResponseWriter, r *http.Request) { - payload, err := payloadFromContext(r) + payload, err := payloadFromContext(r.Context()) if err != nil { api.WriteError(w, err) return diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index e617e5bd..916d84f0 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -11,13 +11,13 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/jose" "github.com/smallstep/nosql/database" ) @@ -28,6 +28,85 @@ func testNext(w http.ResponseWriter, r *http.Request) { w.Write(testBody) } +func Test_baseURLFromRequest(t *testing.T) { + tests := []struct { + name string + targetURL string + expectedResult *url.URL + requestPreparer func(*http.Request) + }{ + { + "HTTPS host pass-through failed.", + "https://my.dummy.host", + &url.URL{Scheme: "https", Host: "my.dummy.host"}, + nil, + }, + { + "Port pass-through failed", + "https://host.with.port:8080", + &url.URL{Scheme: "https", Host: "host.with.port:8080"}, + nil, + }, + { + "Explicit host from Request.Host was not used.", + "https://some.target.host:8080", + &url.URL{Scheme: "https", Host: "proxied.host"}, + func(r *http.Request) { + r.Host = "proxied.host" + }, + }, + { + "Missing Request.Host value did not result in empty string result.", + "https://some.host", + nil, + func(r *http.Request) { + r.Host = "" + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := httptest.NewRequest("GET", tc.targetURL, nil) + if tc.requestPreparer != nil { + tc.requestPreparer(request) + } + result := baseURLFromRequest(request) + if result == nil || tc.expectedResult == nil { + assert.Equals(t, result, tc.expectedResult) + } else if result.String() != tc.expectedResult.String() { + t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String()) + } + }) + } +} + +func TestHandlerBaseURLFromRequest(t *testing.T) { + h := New(&mockAcmeAuthority{}).(*Handler) + req := httptest.NewRequest("GET", "/foo", nil) + req.Host = "test.ca.smallstep.com:8080" + w := httptest.NewRecorder() + + next := func(w http.ResponseWriter, r *http.Request) { + bu := acme.BaseURLFromContext(r.Context()) + if assert.NotNil(t, bu) { + assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080") + assert.Equals(t, bu.Scheme, "https") + } + } + + h.baseURLFromRequest(next)(w, req) + + req = httptest.NewRequest("GET", "/foo", nil) + req.Host = "" + + next = func(w http.ResponseWriter, r *http.Request) { + assert.Equals(t, acme.BaseURLFromContext(r.Context()), nil) + } + + h.baseURLFromRequest(next)(w, req) +} + func TestHandlerAddNonce(t *testing.T) { url := "https://ca.smallstep.com/acme/new-nonce" type test struct { @@ -93,8 +172,9 @@ func TestHandlerAddNonce(t *testing.T) { } func TestHandlerAddDirLink(t *testing.T) { - url := "https://ca.smallstep.com/acme/new-nonce" prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} type test struct { auth acme.Interface link string @@ -103,33 +183,18 @@ func TestHandlerAddDirLink(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "ok": func(t *testing.T) test { - link := "https://ca.smallstep.com/acme/directory" + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) - return link + getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string { + assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL) + return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName) }, }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), - link: link, + ctx: ctx, + link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), statusCode: 200, } }, @@ -138,7 +203,7 @@ func TestHandlerAddDirLink(t *testing.T) { tc := run(t) t.Run(name, func(t *testing.T) { h := New(tc.auth).(*Handler) - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest("GET", "/foo", nil) req = req.WithContext(tc.ctx) w := httptest.NewRecorder() h.addDirLink(testNext)(w, req) @@ -170,8 +235,9 @@ func TestHandlerAddDirLink(t *testing.T) { func TestHandlerVerifyContentType(t *testing.T) { prov := newProv() - url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/abc123", - acme.URLSafeProvisionerName(prov)) + provName := prov.GetName() + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), provName) type test struct { h Handler ctx context.Context @@ -181,38 +247,20 @@ func TestHandlerVerifyContentType(t *testing.T) { url string } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - h: Handler{Auth: &mockAcmeAuthority{}}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - h: Handler{Auth: &mockAcmeAuthority{}}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/general-bad-content-type": func(t *testing.T) test { return test{ h: Handler{ Auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, abs, false) assert.Equals(t, in, []string{""}) - return "/certificate/" + return fmt.Sprintf("/acme/%s/certificate/", provName) }, }, }, - url: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", - acme.URLSafeProvisionerName(prov)), - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), contentType: "foo", statusCode: 400, problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")), @@ -222,16 +270,15 @@ func TestHandlerVerifyContentType(t *testing.T) { return test{ h: Handler{ Auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, abs, false) assert.Equals(t, in, []string{""}) return "/certificate/" }, }, }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), contentType: "foo", statusCode: 400, problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")), @@ -241,16 +288,15 @@ func TestHandlerVerifyContentType(t *testing.T) { return test{ h: Handler{ Auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, abs, false) assert.Equals(t, in, []string{""}) return "/certificate/" }, }, }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -259,16 +305,15 @@ func TestHandlerVerifyContentType(t *testing.T) { return test{ h: Handler{ Auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, abs, false) assert.Equals(t, in, []string{""}) return "/certificate/" }, }, }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), contentType: "application/pkix-cert", statusCode: 200, } @@ -277,16 +322,15 @@ func TestHandlerVerifyContentType(t *testing.T) { return test{ h: Handler{ Auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, abs, false) assert.Equals(t, in, []string{""}) return "/certificate/" }, }, }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), contentType: "application/jose+json", statusCode: 200, } @@ -295,16 +339,15 @@ func TestHandlerVerifyContentType(t *testing.T) { return test{ h: Handler{ Auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.CertificateLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, abs, false) assert.Equals(t, in, []string{""}) return "/certificate/" }, }, }, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), contentType: "application/pkcs7-mime", statusCode: 200, } @@ -364,21 +407,21 @@ func TestHandlerIsPostAsGet(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), payloadContextKey, nil), + ctx: context.WithValue(context.Background(), acme.PayloadContextKey, nil), statusCode: 500, problem: acme.ServerInternalErr(errors.New("payload expected in request context")), } }, "fail/not-post-as-get": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), + ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{}), statusCode: 400, problem: acme.MalformedErr(errors.New("expected POST-as-GET")), } }, "ok": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), + ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}), statusCode: 200, } }, @@ -464,7 +507,7 @@ func TestHandlerParseJWS(t *testing.T) { return test{ body: strings.NewReader(expRaw), next: func(w http.ResponseWriter, r *http.Request) { - jws, err := jwsFromContext(r) + jws, err := acme.JwsFromContext(r.Context()) assert.FatalError(t, err) gotRaw, err := jws.CompactSerialize() assert.FatalError(t, err) @@ -542,22 +585,22 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, nil), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), statusCode: 500, problem: acme.ServerInternalErr(errors.New("jws expected in request context")), } }, "fail/no-jwk": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 500, problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), } }, "fail/nil-jwk": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) return test{ - ctx: context.WithValue(ctx, jwkContextKey, nil), + ctx: context.WithValue(ctx, acme.JwkContextKey, nil), statusCode: 500, problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), } @@ -566,8 +609,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) _pub := _jwk.Public() - ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwkContextKey, &_pub) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.JwkContextKey, &_pub) return test{ ctx: ctx, statusCode: 400, @@ -578,8 +621,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _pub := *pub clone := &_pub clone.Algorithm = jose.HS256 - ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwkContextKey, clone) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.JwkContextKey, clone) return test{ ctx: ctx, statusCode: 400, @@ -587,13 +630,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { } }, "ok": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwkContextKey, pub) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.JwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, next: func(w http.ResponseWriter, r *http.Request) { - p, err := payloadFromContext(r) + p, err := payloadFromContext(r.Context()) assert.FatalError(t, err) if assert.NotNil(t, p) { assert.Equals(t, p.value, []byte("baz")) @@ -608,13 +651,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { _pub := *pub clone := &_pub clone.Algorithm = "" - ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) - ctx = context.WithValue(ctx, jwkContextKey, pub) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.JwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, next: func(w http.ResponseWriter, r *http.Request) { - p, err := payloadFromContext(r) + p, err := payloadFromContext(r.Context()) assert.FatalError(t, err) if assert.NotNil(t, p) { assert.Equals(t, p.value, []byte("baz")) @@ -632,13 +675,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) - ctx = context.WithValue(ctx, jwkContextKey, pub) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) + ctx = context.WithValue(ctx, acme.JwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, next: func(w http.ResponseWriter, r *http.Request) { - p, err := payloadFromContext(r) + p, err := payloadFromContext(r.Context()) assert.FatalError(t, err) if assert.NotNil(t, p) { assert.Equals(t, p.value, []byte{}) @@ -656,13 +699,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) - ctx = context.WithValue(ctx, jwkContextKey, pub) + ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed) + ctx = context.WithValue(ctx, acme.JwkContextKey, pub) return test{ ctx: ctx, statusCode: 200, next: func(w http.ResponseWriter, r *http.Request) { - p, err := payloadFromContext(r) + p, err := payloadFromContext(r.Context()) assert.FatalError(t, err) if assert.NotNil(t, p) { assert.Equals(t, p.value, []byte("{}")) @@ -709,13 +752,15 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { func TestHandlerLookupJWK(t *testing.T) { prov := newProv() - url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", - acme.URLSafeProvisionerName(prov)) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + url := fmt.Sprintf("%s/acme/%s/account/1234", + baseURL, provName) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) accID := "account-id" - prefix := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", - acme.URLSafeProvisionerName(prov)) + prefix := fmt.Sprintf("%s/acme/%s/account/", + baseURL, provName) so := new(jose.SignerOptions) so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) signer, err := jose.NewSigner(jose.SigningKey{ @@ -737,30 +782,16 @@ func TestHandlerLookupJWK(t *testing.T) { statusCode int } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 500, problem: acme.ServerInternalErr(errors.New("jws expected in request context")), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -775,13 +806,13 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _jws, err := _signer.Sign([]byte("baz")) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, _jws) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{""}) return prefix @@ -806,16 +837,16 @@ func TestHandlerLookupJWK(t *testing.T) { assert.FatalError(t, err) _parsed, err := jose.ParseJWS(_raw) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, _parsed) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, _parsed) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{""}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) }, }, ctx: ctx, @@ -824,21 +855,23 @@ func TestHandlerLookupJWK(t *testing.T) { } }, "fail/account-not-found": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, accID) return nil, database.ErrNotFound }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{""}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) }, }, ctx: ctx, @@ -847,21 +880,23 @@ func TestHandlerLookupJWK(t *testing.T) { } }, "fail/GetAccount-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, accID) return nil, acme.ServerInternalErr(errors.New("force")) }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{""}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) }, }, ctx: ctx, @@ -871,21 +906,23 @@ func TestHandlerLookupJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, accID) return acc, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{""}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) }, }, ctx: ctx, @@ -895,29 +932,31 @@ func TestHandlerLookupJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid", Key: jwk} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { + getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, accID) return acc, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.AccountLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{""}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) + return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName) }, }, ctx: ctx, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := accountFromContext(r) + _acc, err := acme.AccountFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _acc, acc) - _jwk, err := jwkFromContext(r) + _jwk, err := acme.JwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk, jwk) w.Write(testBody) @@ -961,6 +1000,7 @@ func TestHandlerLookupJWK(t *testing.T) { func TestHandlerExtractJWK(t *testing.T) { prov := newProv() + provName := url.PathEscape(prov.GetName()) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) kid, err := jwk.Thumbprint(crypto.SHA256) @@ -982,7 +1022,7 @@ func TestHandlerExtractJWK(t *testing.T) { parsedJWS, err := jose.ParseJWS(raw) assert.FatalError(t, err) url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", - acme.URLSafeProvisionerName(prov)) + provName) type test struct { auth acme.Interface ctx context.Context @@ -991,30 +1031,16 @@ func TestHandlerExtractJWK(t *testing.T) { statusCode int } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 500, problem: acme.ServerInternalErr(errors.New("jws expected in request context")), } }, "fail/nil-jws": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -1031,8 +1057,8 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, _jws) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, @@ -1049,8 +1075,8 @@ func TestHandlerExtractJWK(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, _jws) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, _jws) return test{ ctx: ctx, statusCode: 400, @@ -1058,12 +1084,14 @@ func TestHandlerExtractJWK(t *testing.T) { } }, "fail/GetAccountByKey-error": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) return test{ ctx: ctx, auth: &mockAcmeAuthority{ - getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, jwk.KeyID, pub.KeyID) return nil, acme.ServerInternalErr(errors.New("force")) @@ -1075,12 +1103,14 @@ func TestHandlerExtractJWK(t *testing.T) { }, "fail/account-not-valid": func(t *testing.T) test { acc := &acme.Account{Status: "deactivated"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) return test{ ctx: ctx, auth: &mockAcmeAuthority{ - getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, jwk.KeyID, pub.KeyID) return acc, nil @@ -1092,22 +1122,24 @@ func TestHandlerExtractJWK(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{Status: "valid"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) return test{ ctx: ctx, auth: &mockAcmeAuthority{ - getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, jwk.KeyID, pub.KeyID) return acc, nil }, }, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := accountFromContext(r) + _acc, err := acme.AccountFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _acc, acc) - _jwk, err := jwkFromContext(r) + _jwk, err := acme.JwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk.KeyID, pub.KeyID) w.Write(testBody) @@ -1116,22 +1148,24 @@ func TestHandlerExtractJWK(t *testing.T) { } }, "ok/no-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS) return test{ ctx: ctx, auth: &mockAcmeAuthority{ - getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { + getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, jwk.KeyID, pub.KeyID) return nil, database.ErrNotFound }, }, next: func(w http.ResponseWriter, r *http.Request) { - _acc, err := accountFromContext(r) + _acc, err := acme.AccountFromContext(r.Context()) assert.NotNil(t, err) assert.Nil(t, _acc) - _jwk, err := jwkFromContext(r) + _jwk, err := acme.JwkFromContext(r.Context()) assert.FatalError(t, err) assert.Equals(t, _jwk.KeyID, pub.KeyID) w.Write(testBody) @@ -1192,14 +1226,14 @@ func TestHandlerValidateJWS(t *testing.T) { }, "fail/nil-jws": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, nil), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil), statusCode: 500, problem: acme.ServerInternalErr(errors.New("jws expected in request context")), } }, "fail/no-signature": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, &jose.JSONWebSignature{}), statusCode: 400, problem: acme.MalformedErr(errors.New("request body does not contain a signature")), } @@ -1212,7 +1246,7 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.New("request body contains more than one signature")), } @@ -1224,7 +1258,7 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.New("unprotected header must not be used")), } @@ -1236,7 +1270,7 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")), } @@ -1248,7 +1282,7 @@ func TestHandlerValidateJWS(t *testing.T) { }, } return test{ - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), } @@ -1276,7 +1310,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), } @@ -1304,7 +1338,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), } @@ -1321,7 +1355,7 @@ func TestHandlerValidateJWS(t *testing.T) { return acme.ServerInternalErr(errors.New("force")) }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 500, problem: acme.ServerInternalErr(errors.New("force")), } @@ -1338,7 +1372,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.New("jws missing url protected header")), } @@ -1362,7 +1396,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), } @@ -1391,7 +1425,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), } @@ -1415,7 +1449,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), statusCode: 400, problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), } @@ -1440,7 +1474,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1470,7 +1504,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, @@ -1500,7 +1534,7 @@ func TestHandlerValidateJWS(t *testing.T) { return nil }, }, - ctx: context.WithValue(context.Background(), jwsContextKey, jws), + ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws), next: func(w http.ResponseWriter, r *http.Request) { w.Write(testBody) }, diff --git a/acme/api/order.go b/acme/api/order.go index 1d491102..5c62cb52 100644 --- a/acme/api/order.go +++ b/acme/api/order.go @@ -58,17 +58,13 @@ func (f *FinalizeRequest) Validate() error { // NewOrder ACME api for creating a new order. func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) + ctx := r.Context() + acc, err := acme.AccountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - acc, err := accountFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - payload, err := payloadFromContext(r) + payload, err := payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -84,7 +80,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } - o, err := h.Auth.NewOrder(prov, acme.OrderOptions{ + o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{ AccountID: acc.GetID(), Identifiers: nor.Identifiers, NotBefore: nor.NotBefore, @@ -95,46 +91,38 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { return } - w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID())) + w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) api.JSONStatus(w, o, http.StatusCreated) } // GetOrder ACME api for retrieving an order. func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - acc, err := accountFromContext(r) + ctx := r.Context() + acc, err := acme.AccountFromContext(ctx) if err != nil { api.WriteError(w, err) return } oid := chi.URLParam(r, "ordID") - o, err := h.Auth.GetOrder(prov, acc.GetID(), oid) + o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid) if err != nil { api.WriteError(w, err) return } - w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID())) + w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) api.JSON(w, o) } // FinalizeOrder attemptst to finalize an order and create a certificate. func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { - prov, err := provisionerFromContext(r) + ctx := r.Context() + acc, err := acme.AccountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - acc, err := accountFromContext(r) - if err != nil { - api.WriteError(w, err) - return - } - payload, err := payloadFromContext(r) + payload, err := payloadFromContext(ctx) if err != nil { api.WriteError(w, err) return @@ -150,12 +138,12 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { } oid := chi.URLParam(r, "ordID") - o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr) + o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr) if err != nil { api.WriteError(w, err) return } - w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID)) + w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID)) api.JSON(w, o) } diff --git a/acme/api/order_test.go b/acme/api/order_test.go index 0931d832..487b8669 100644 --- a/acme/api/order_test.go +++ b/acme/api/order_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io/ioutil" "net/http/httptest" + "net/url" "testing" "time" @@ -16,7 +17,6 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" - "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/pemutil" ) @@ -175,8 +175,10 @@ func TestHandlerGetOrder(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) prov := newProv() - url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s", - acme.URLSafeProvisionerName(prov), o.ID) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + url := fmt.Sprintf("%s/acme/%s/order/%s", + baseURL.String(), provName, o.ID) type test struct { auth acme.Interface @@ -185,33 +187,17 @@ func TestHandlerGetOrder(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-account": func(t *testing.T) test { return test{ auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ auth: &mockAcmeAuthority{}, ctx: ctx, @@ -221,8 +207,8 @@ func TestHandlerGetOrder(t *testing.T) { }, "fail/getOrder-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ @@ -235,20 +221,22 @@ func TestHandlerGetOrder(t *testing.T) { }, "ok": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) { + getOrder: func(ctx context.Context, accID, id string) (*acme.Order, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, acc.ID) assert.Equals(t, id, o.ID) return &o, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.OrderLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{o.ID}) return url @@ -314,8 +302,10 @@ func TestHandlerNewOrder(t *testing.T) { } prov := newProv() - url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", - acme.URLSafeProvisionerName(prov)) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + url := fmt.Sprintf("%s/acme/%s/new-order", + baseURL.String(), provName) type test struct { auth acme.Interface @@ -324,32 +314,16 @@ func TestHandlerNewOrder(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-account": func(t *testing.T) test { return test{ - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ ctx: ctx, statusCode: 400, @@ -358,8 +332,8 @@ func TestHandlerNewOrder(t *testing.T) { }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) return test{ ctx: ctx, statusCode: 500, @@ -368,9 +342,9 @@ func TestHandlerNewOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -379,9 +353,9 @@ func TestHandlerNewOrder(t *testing.T) { }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, @@ -393,9 +367,9 @@ func TestHandlerNewOrder(t *testing.T) { nor := &NewOrderRequest{} b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, @@ -412,12 +386,14 @@ func TestHandlerNewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ auth: &mockAcmeAuthority{ - newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, ops.AccountID, acc.ID) assert.Equals(t, ops.Identifiers, nor.Identifiers) @@ -441,12 +417,15 @@ func TestHandlerNewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, ops.AccountID, acc.ID) assert.Equals(t, ops.Identifiers, nor.Identifiers) @@ -454,12 +433,11 @@ func TestHandlerNewOrder(t *testing.T) { assert.Equals(t, ops.NotAfter, naf) return &o, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.OrderLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID) + return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) }, }, ctx: ctx, @@ -476,12 +454,15 @@ func TestHandlerNewOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { + newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, ops.AccountID, acc.ID) assert.Equals(t, ops.Identifiers, nor.Identifiers) @@ -490,12 +471,11 @@ func TestHandlerNewOrder(t *testing.T) { assert.True(t, ops.NotAfter.IsZero()) return &o, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.OrderLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID) + return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID) }, }, ctx: ctx, @@ -534,7 +514,8 @@ func TestHandlerNewOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)}) + []string{fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), + provName, o.ID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) @@ -567,8 +548,10 @@ func TestHandlerFinalizeOrder(t *testing.T) { chiCtx := chi.NewRouteContext() chiCtx.URLParams.Add("ordID", o.ID) prov := newProv() - url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize", - acme.URLSafeProvisionerName(prov), o.ID) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + url := fmt.Sprintf("%s/acme/%s/order/%s/finalize", + baseURL.String(), provName, o.ID) type test struct { auth acme.Interface @@ -577,33 +560,17 @@ func TestHandlerFinalizeOrder(t *testing.T) { problem *acme.Error } var tests = map[string]func(t *testing.T) test{ - "fail/no-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.Background(), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, - "fail/nil-provisioner": func(t *testing.T) test { - return test{ - auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, nil), - statusCode: 500, - problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")), - } - }, "fail/no-account": func(t *testing.T) test { return test{ auth: &mockAcmeAuthority{}, - ctx: context.WithValue(context.Background(), provisionerContextKey, prov), + ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov), statusCode: 400, problem: acme.AccountDoesNotExistErr(nil), } }, "fail/nil-account": func(t *testing.T) test { - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, nil) return test{ auth: &mockAcmeAuthority{}, ctx: ctx, @@ -613,8 +580,8 @@ func TestHandlerFinalizeOrder(t *testing.T) { }, "fail/no-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) return test{ ctx: ctx, statusCode: 500, @@ -623,9 +590,9 @@ func TestHandlerFinalizeOrder(t *testing.T) { }, "fail/nil-payload": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, nil) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, nil) return test{ ctx: ctx, statusCode: 500, @@ -634,9 +601,9 @@ func TestHandlerFinalizeOrder(t *testing.T) { }, "fail/unmarshal-payload-error": func(t *testing.T) test { acc := &acme.Account{ID: "accID"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{}) return test{ ctx: ctx, statusCode: 400, @@ -648,9 +615,9 @@ func TestHandlerFinalizeOrder(t *testing.T) { fr := &FinalizeRequest{} b, err := json.Marshal(fr) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) return test{ ctx: ctx, statusCode: 400, @@ -664,13 +631,15 @@ func TestHandlerFinalizeOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) return test{ auth: &mockAcmeAuthority{ - finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { + finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, acc.ID) assert.Equals(t, id, o.ID) @@ -690,26 +659,28 @@ func TestHandlerFinalizeOrder(t *testing.T) { } b, err := json.Marshal(nor) assert.FatalError(t, err) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) + ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, acme.AccContextKey, acc) + ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL) return test{ auth: &mockAcmeAuthority{ - finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { + finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { + p, err := acme.ProvisionerFromContext(ctx) + assert.FatalError(t, err) assert.Equals(t, p, prov) assert.Equals(t, accID, acc.ID) assert.Equals(t, id, o.ID) assert.Equals(t, incsr.Raw, csr.Raw) return &o, nil }, - getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { + getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string { assert.Equals(t, typ, acme.OrderLink) - assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.True(t, abs) assert.Equals(t, in, []string{o.ID}) - return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", - acme.URLSafeProvisionerName(prov), o.ID) + return fmt.Sprintf("%s/acme/%s/order/%s", + baseURL.String(), provName, o.ID) }, }, ctx: ctx, @@ -748,8 +719,8 @@ func TestHandlerFinalizeOrder(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, res.Header["Location"], - []string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", - acme.URLSafeProvisionerName(prov), o.ID)}) + []string{fmt.Sprintf("%s/acme/%s/order/%s", + baseURL, provName, o.ID)}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) } }) diff --git a/acme/authority.go b/acme/authority.go index 1863b48f..8677274e 100644 --- a/acme/authority.go +++ b/acme/authority.go @@ -1,6 +1,7 @@ package acme import ( + "context" "crypto" "crypto/tls" "crypto/x509" @@ -20,23 +21,29 @@ import ( // Interface is the acme authority interface. type Interface interface { - DeactivateAccount(provisioner.Interface, string) (*Account, error) - FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error) - GetAccount(provisioner.Interface, string) (*Account, error) - GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error) - GetAuthz(provisioner.Interface, string, string) (*Authz, error) - GetCertificate(string, string) ([]byte, error) - GetDirectory(provisioner.Interface) *Directory - GetLink(Link, string, bool, ...string) string - GetOrder(provisioner.Interface, string, string) (*Order, error) - GetOrdersByAccount(provisioner.Interface, string) ([]string, error) - LoadProvisionerByID(string) (provisioner.Interface, error) - NewAccount(provisioner.Interface, AccountOptions) (*Account, error) + GetDirectory(ctx context.Context) (*Directory, error) NewNonce() (string, error) - NewOrder(provisioner.Interface, OrderOptions) (*Order, error) - UpdateAccount(provisioner.Interface, string, []string) (*Account, error) UseNonce(string) error - ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error) + + DeactivateAccount(ctx context.Context, accID string) (*Account, error) + GetAccount(ctx context.Context, accID string) (*Account, error) + GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error) + NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) + UpdateAccount(context.Context, string, []string) (*Account, error) + + GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error) + ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error) + + FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error) + GetOrder(ctx context.Context, accID string, orderID string) (*Order, error) + GetOrdersByAccount(ctx context.Context, accID string) ([]string, error) + NewOrder(ctx context.Context, oo OrderOptions) (*Order, error) + + GetCertificate(string, string) ([]byte, error) + + LoadProvisionerByID(string) (provisioner.Interface, error) + GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string + GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string } // Authority is the layer that handles all ACME interactions. @@ -79,20 +86,24 @@ func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority, ordin } // GetLink returns the requested link from the directory. -func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string { - return a.dir.getLink(typ, provID, abs, inputs...) +func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { + return a.dir.getLink(ctx, typ, abs, inputs...) +} + +// GetLinkExplicit returns the requested link from the directory. +func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string { + return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...) } // GetDirectory returns the ACME directory object. -func (a *Authority) GetDirectory(p provisioner.Interface) *Directory { - name := url.PathEscape(p.GetName()) +func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) { return &Directory{ - NewNonce: a.dir.getLink(NewNonceLink, name, true), - NewAccount: a.dir.getLink(NewAccountLink, name, true), - NewOrder: a.dir.getLink(NewOrderLink, name, true), - RevokeCert: a.dir.getLink(RevokeCertLink, name, true), - KeyChange: a.dir.getLink(KeyChangeLink, name, true), - } + NewNonce: a.dir.getLink(ctx, NewNonceLink, true), + NewAccount: a.dir.getLink(ctx, NewAccountLink, true), + NewOrder: a.dir.getLink(ctx, NewOrderLink, true), + RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true), + KeyChange: a.dir.getLink(ctx, KeyChangeLink, true), + }, nil } // LoadProvisionerByID calls out to the SignAuthority interface to load a @@ -116,16 +127,16 @@ func (a *Authority) UseNonce(nonce string) error { } // NewAccount creates, stores, and returns a new ACME account. -func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) { +func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) { acc, err := newAccount(a.db, ao) if err != nil { return nil, err } - return acc.toACME(a.db, a.dir, p) + return acc.toACME(ctx, a.db, a.dir) } // UpdateAccount updates an ACME account. -func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) { +func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) { acc, err := getAccountByID(a.db, id) if err != nil { return nil, ServerInternalErr(err) @@ -133,20 +144,20 @@ func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact [] if acc, err = acc.update(a.db, contact); err != nil { return nil, err } - return acc.toACME(a.db, a.dir, p) + return acc.toACME(ctx, a.db, a.dir) } // GetAccount returns an ACME account. -func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) { +func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) { acc, err := getAccountByID(a.db, id) if err != nil { return nil, err } - return acc.toACME(a.db, a.dir, p) + return acc.toACME(ctx, a.db, a.dir) } // DeactivateAccount deactivates an ACME account. -func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) { +func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) { acc, err := getAccountByID(a.db, id) if err != nil { return nil, err @@ -154,7 +165,7 @@ func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Acco if acc, err = acc.deactivate(a.db); err != nil { return nil, err } - return acc.toACME(a.db, a.dir, p) + return acc.toACME(ctx, a.db, a.dir) } func keyToID(jwk *jose.JSONWebKey) (string, error) { @@ -166,7 +177,7 @@ func keyToID(jwk *jose.JSONWebKey) (string, error) { } // GetAccountByKey returns the ACME associated with the jwk id. -func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) { +func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) { kid, err := keyToID(jwk) if err != nil { return nil, err @@ -175,11 +186,11 @@ func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKe if err != nil { return nil, err } - return acc.toACME(a.db, a.dir, p) + return acc.toACME(ctx, a.db, a.dir) } // GetOrder returns an ACME order. -func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) { +func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) { o, err := getOrder(a.db, orderID) if err != nil { return nil, err @@ -190,11 +201,11 @@ func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*O if o, err = o.updateStatus(a.db); err != nil { return nil, err } - return o.toACME(a.db, a.dir, p) + return o.toACME(ctx, a.db, a.dir) } // GetOrdersByAccount returns the list of order urls owned by the account. -func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) { +func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { oids, err := getOrderIDsByAccount(a.db, id) if err != nil { return nil, err @@ -209,22 +220,26 @@ func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]st if o.Status == StatusInvalid { continue } - ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID)) + ret = append(ret, a.dir.getLink(ctx, OrderLink, true, o.ID)) } return ret, nil } // NewOrder generates, stores, and returns a new ACME order. -func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) { +func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) { order, err := newOrder(a.db, ops) if err != nil { return nil, Wrap(err, "error creating order") } - return order.toACME(a.db, a.dir, p) + return order.toACME(ctx, a.db, a.dir) } // FinalizeOrder attempts to finalize an order and generate a new certificate. -func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { +func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { + prov, err := ProvisionerFromContext(ctx) + if err != nil { + return nil, err + } o, err := getOrder(a.db, orderID) if err != nil { return nil, err @@ -232,16 +247,16 @@ func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string if accID != o.AccountID { return nil, UnauthorizedErr(errors.New("account does not own order")) } - o, err = o.finalize(a.db, csr, a.signAuth, p) + o, err = o.finalize(a.db, csr, a.signAuth, prov) if err != nil { return nil, Wrap(err, "error finalizing order") } - return o.toACME(a.db, a.dir, p) + return o.toACME(ctx, a.db, a.dir) } // GetAuthz retrieves and attempts to update the status on an ACME authz // before returning. -func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) { +func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) { az, err := getAuthz(a.db, authzID) if err != nil { return nil, err @@ -253,7 +268,7 @@ func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*A if err != nil { return nil, Wrap(err, "error updating authz status") } - return az.toACME(a.db, a.dir, p) + return az.toACME(ctx, a.db, a.dir) } // ValidateChallenge loads a challenge resource and then begins the validation process if the challenge @@ -296,7 +311,7 @@ func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*A // // Note: the default ordinal does not need to be changed unless step-ca is running in a replicated scenario. // -func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { +func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { ch, err := getChallenge(a.db, chID) if err != nil { return nil, err @@ -305,7 +320,7 @@ func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID strin case StatusPending, StatusProcessing: break case StatusInvalid, StatusValid: - return ch.toACME(a.dir, p) + return ch.toACME(ctx, a.dir) default: e:= errors.Errorf("unknown challenge state: %s", ch.getStatus()) return nil, ServerInternalErr(e) @@ -316,6 +331,11 @@ func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID strin return nil, UnauthorizedErr(errors.New("account does not own challenge")) } + p, err := ProvisionerFromContext(ctx) + if err != nil { + return nil, err + } + // Take ownership of the challenge status and retry state. The values must be reset. up := ch.clone() up.Status = StatusProcessing @@ -357,7 +377,7 @@ func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID strin e := errors.Errorf("post-validation challenge in unexpected state, %s", ch.getStatus()) return nil, ServerInternalErr(e) } - return ch.toACME(a.dir, p) + return ch.toACME(ctx, a.dir) } // The challenge validation process is specific to the type of challenge (dns-01, http-01, tls-alpn-01). @@ -434,9 +454,14 @@ func (a *Authority) RetryChallenge(chID string) { ch = up p, err := a.LoadProvisionerByID(retry.ProvisionerID) - acc, err := a.GetAccount(p, ch.getAccountID()) + if p.GetType() != provisioner.TypeACME { + log.Printf("%v", AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) + return + } + ctx := context.WithValue(context.Background(), ProvisionerContextKey, p) + acc, err := a.GetAccount(ctx, ch.getAccountID()) - v, err := a.validate(up, acc.Key) + v, err := a.validate(ch, acc.Key) if err != nil { return } diff --git a/acme/authority_test.go b/acme/authority_test.go index 9171af08..e1210acb 100644 --- a/acme/authority_test.go +++ b/acme/authority_test.go @@ -1,8 +1,10 @@ package acme import ( + "context" "encoding/json" "fmt" + "net/url" "testing" "time" @@ -16,7 +18,11 @@ import ( func TestAuthorityGetLink(t *testing.T) { auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil, 0) assert.FatalError(t, err) - provID := "acme-test-provisioner" + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) type test struct { auth *Authority typ Link @@ -30,7 +36,7 @@ func TestAuthorityGetLink(t *testing.T) { auth: auth, typ: NewAccountLink, abs: true, - res: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID), + res: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), } }, "ok/new-account/no-abs": func(t *testing.T) test { @@ -38,7 +44,7 @@ func TestAuthorityGetLink(t *testing.T) { auth: auth, typ: NewAccountLink, abs: false, - res: fmt.Sprintf("/%s/new-account", provID), + res: fmt.Sprintf("/%s/new-account", provName), } }, "ok/order/abs": func(t *testing.T) test { @@ -47,7 +53,7 @@ func TestAuthorityGetLink(t *testing.T) { typ: OrderLink, abs: true, inputs: []string{"foo"}, - res: fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/foo", provID), + res: fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName), } }, "ok/order/no-abs": func(t *testing.T) test { @@ -56,14 +62,14 @@ func TestAuthorityGetLink(t *testing.T) { typ: OrderLink, abs: false, inputs: []string{"foo"}, - res: fmt.Sprintf("/%s/order/foo", provID), + res: fmt.Sprintf("/%s/order/foo", provName), } }, } for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - link := tc.auth.GetLink(tc.typ, provID, tc.abs, tc.inputs...) + link := tc.auth.GetLink(ctx, tc.typ, tc.abs, tc.inputs...) assert.Equals(t, tc.res, link) }) } @@ -72,14 +78,68 @@ func TestAuthorityGetLink(t *testing.T) { func TestAuthorityGetDirectory(t *testing.T) { auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil, 0) assert.FatalError(t, err) + prov := newProv() - acmeDir := auth.GetDirectory(prov) - assert.Equals(t, acmeDir.NewNonce, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", URLSafeProvisionerName(prov))) - assert.Equals(t, acmeDir.NewAccount, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", URLSafeProvisionerName(prov))) - assert.Equals(t, acmeDir.NewOrder, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", URLSafeProvisionerName(prov))) - //assert.Equals(t, acmeDir.NewOrder, "httsp://ca.smallstep.com/acme/new-authz") - assert.Equals(t, acmeDir.RevokeCert, fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", URLSafeProvisionerName(prov))) - assert.Equals(t, acmeDir.KeyChange, fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", URLSafeProvisionerName(prov))) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) + + type test struct { + ctx context.Context + err *Error + } + tests := map[string]func(t *testing.T) test{ + "ok/empty-provisioner": func(t *testing.T) test { + return test{ + ctx: context.Background(), + } + }, + "ok/no-baseURL": func(t *testing.T) test { + return test{ + ctx: context.WithValue(context.Background(), ProvisionerContextKey, prov), + } + }, + "ok/baseURL": func(t *testing.T) test { + return test{ + ctx: ctx, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if dir, err := auth.GetDirectory(tc.ctx); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + bu := BaseURLFromContext(tc.ctx) + if bu == nil { + bu = &url.URL{Scheme: "https", Host: "ca.smallstep.com"} + } + + var provName string + prov, err := ProvisionerFromContext(tc.ctx) + if err != nil { + provName = "" + } else { + provName = url.PathEscape(prov.GetName()) + } + + assert.Equals(t, dir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", bu.String(), provName)) + assert.Equals(t, dir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", bu.String(), provName)) + assert.Equals(t, dir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", bu.String(), provName)) + assert.Equals(t, dir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", bu.String(), provName)) + assert.Equals(t, dir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", bu.String(), provName)) + } + } + }) + } } func TestAuthorityNewNonce(t *testing.T) { @@ -193,6 +253,8 @@ func TestAuthorityNewAccount(t *testing.T) { Key: jwk, Contact: []string{"foo", "bar"}, } prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority ops AccountOptions @@ -225,7 +287,7 @@ func TestAuthorityNewAccount(t *testing.T) { if count == 1 { var acc *account assert.FatalError(t, json.Unmarshal(newval, &acc)) - *acmeacc, err = acc.toACME(nil, dir, prov) + *acmeacc, err = acc.toACME(ctx, nil, dir) return nil, true, nil } count++ @@ -243,7 +305,7 @@ func TestAuthorityNewAccount(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeAcc, err := tc.auth.NewAccount(prov, tc.ops); err != nil { + if acmeAcc, err := tc.auth.NewAccount(ctx, tc.ops); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -266,6 +328,8 @@ func TestAuthorityNewAccount(t *testing.T) { func TestAuthorityGetAccount(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id string @@ -310,7 +374,7 @@ func TestAuthorityGetAccount(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeAcc, err := tc.auth.GetAccount(prov, tc.id); err != nil { + if acmeAcc, err := tc.auth.GetAccount(ctx, tc.id); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -323,7 +387,7 @@ func TestAuthorityGetAccount(t *testing.T) { gotb, err := json.Marshal(acmeAcc) assert.FatalError(t, err) - acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) @@ -337,6 +401,8 @@ func TestAuthorityGetAccount(t *testing.T) { func TestAuthorityGetAccountByKey(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority jwk *jose.JSONWebKey @@ -411,7 +477,7 @@ func TestAuthorityGetAccountByKey(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeAcc, err := tc.auth.GetAccountByKey(prov, tc.jwk); err != nil { + if acmeAcc, err := tc.auth.GetAccountByKey(ctx, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -424,7 +490,7 @@ func TestAuthorityGetAccountByKey(t *testing.T) { gotb, err := json.Marshal(acmeAcc) assert.FatalError(t, err) - acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) @@ -438,6 +504,8 @@ func TestAuthorityGetAccountByKey(t *testing.T) { func TestAuthorityGetOrder(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id, accID string @@ -535,7 +603,7 @@ func TestAuthorityGetOrder(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeO, err := tc.auth.GetOrder(prov, tc.accID, tc.id); err != nil { + if acmeO, err := tc.auth.GetOrder(ctx, tc.accID, tc.id); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -548,7 +616,7 @@ func TestAuthorityGetOrder(t *testing.T) { gotb, err := json.Marshal(acmeO) assert.FatalError(t, err) - acmeExp, err := tc.o.toACME(nil, tc.auth.dir, prov) + acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) @@ -655,6 +723,8 @@ func TestAuthorityGetCertificate(t *testing.T) { func TestAuthorityGetAuthz(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id, accID string @@ -784,7 +854,7 @@ func TestAuthorityGetAuthz(t *testing.T) { return ret, nil }, } - acmeAz, err := az.toACME(mockdb, newDirectory("ca.smallstep.com", "acme"), prov) + acmeAz, err := az.toACME(ctx, mockdb, newDirectory("ca.smallstep.com", "acme")) assert.FatalError(t, err) count = 0 @@ -825,7 +895,7 @@ func TestAuthorityGetAuthz(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeAz, err := tc.auth.GetAuthz(prov, tc.accID, tc.id); err != nil { + if acmeAz, err := tc.auth.GetAuthz(ctx, tc.accID, tc.id); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -850,6 +920,8 @@ func TestAuthorityGetAuthz(t *testing.T) { func TestAuthorityNewOrder(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority ops OrderOptions @@ -903,7 +975,7 @@ func TestAuthorityNewOrder(t *testing.T) { assert.Equals(t, bucket, orderTable) var o order assert.FatalError(t, json.Unmarshal(newval, &o)) - *acmeO, err = o.toACME(nil, dir, prov) + *acmeO, err = o.toACME(ctx, nil, dir) assert.FatalError(t, err) *accID = o.AccountID case 9: @@ -928,7 +1000,7 @@ func TestAuthorityNewOrder(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeO, err := tc.auth.NewOrder(prov, tc.ops); err != nil { + if acmeO, err := tc.auth.NewOrder(ctx, tc.ops); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -951,6 +1023,10 @@ func TestAuthorityNewOrder(t *testing.T) { func TestAuthorityGetOrdersByAccount(t *testing.T) { prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) type test struct { auth *Authority id string @@ -1051,8 +1127,8 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) { auth: auth, id: id, res: []string{ - fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", URLSafeProvisionerName(prov), foo.ID), - fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", URLSafeProvisionerName(prov), baz.ID), + fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID), + fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, baz.ID), }, } }, @@ -1060,7 +1136,7 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if orderLinks, err := tc.auth.GetOrdersByAccount(prov, tc.id); err != nil { + if orderLinks, err := tc.auth.GetOrdersByAccount(ctx, tc.id); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1079,6 +1155,8 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) { func TestAuthorityFinalizeOrder(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id, accID string @@ -1174,7 +1252,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeO, err := tc.auth.FinalizeOrder(prov, tc.accID, tc.id, nil); err != nil { + if acmeO, err := tc.auth.FinalizeOrder(ctx, tc.accID, tc.id, nil); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1187,7 +1265,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) { gotb, err := json.Marshal(acmeO) assert.FatalError(t, err) - acmeExp, err := tc.o.toACME(nil, tc.auth.dir, prov) + acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) @@ -1201,6 +1279,8 @@ func TestAuthorityFinalizeOrder(t *testing.T) { func TestAuthorityValidateChallenge(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id, accID string @@ -1302,7 +1382,7 @@ func TestAuthorityValidateChallenge(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeCh, err := tc.auth.ValidateChallenge(prov, tc.accID, tc.id, nil); err != nil { + if acmeCh, err := tc.auth.ValidateChallenge(ctx, tc.accID, tc.id, nil); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1314,7 +1394,7 @@ func TestAuthorityValidateChallenge(t *testing.T) { if assert.Nil(t, tc.err) { gotb, err := json.Marshal(acmeCh) assert.FatalError(t, err) - acmeExp, err := tc.ch.toACME(tc.auth.dir, prov) + acmeExp, err := tc.ch.toACME(ctx, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) @@ -1328,6 +1408,8 @@ func TestAuthorityValidateChallenge(t *testing.T) { func TestAuthorityUpdateAccount(t *testing.T) { contact := []string{"baz", "zap"} prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id string @@ -1407,7 +1489,7 @@ func TestAuthorityUpdateAccount(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeAcc, err := tc.auth.UpdateAccount(prov, tc.id, tc.contact); err != nil { + if acmeAcc, err := tc.auth.UpdateAccount(ctx, tc.id, tc.contact); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1420,7 +1502,7 @@ func TestAuthorityUpdateAccount(t *testing.T) { gotb, err := json.Marshal(acmeAcc) assert.FatalError(t, err) - acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) @@ -1434,6 +1516,8 @@ func TestAuthorityUpdateAccount(t *testing.T) { func TestAuthorityDeactivateAccount(t *testing.T) { prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { auth *Authority id string @@ -1510,7 +1594,7 @@ func TestAuthorityDeactivateAccount(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if acmeAcc, err := tc.auth.DeactivateAccount(prov, tc.id); err != nil { + if acmeAcc, err := tc.auth.DeactivateAccount(ctx, tc.id); err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) assert.True(t, ok) @@ -1523,7 +1607,7 @@ func TestAuthorityDeactivateAccount(t *testing.T) { gotb, err := json.Marshal(acmeAcc) assert.FatalError(t, err) - acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) + acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir) assert.FatalError(t, err) expb, err := json.Marshal(acmeExp) assert.FatalError(t, err) diff --git a/acme/authz.go b/acme/authz.go index 789dcab2..1a118bcc 100644 --- a/acme/authz.go +++ b/acme/authz.go @@ -1,12 +1,12 @@ package acme import ( + "context" "encoding/json" "strings" "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/nosql" ) @@ -51,7 +51,7 @@ type authz interface { getChallenges() []string getCreated() time.Time updateStatus(db nosql.DB) (authz, error) - toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error) + toACME(context.Context, nosql.DB, *directory) (*Authz, error) } // baseAuthz is the base authz type that others build from. @@ -141,14 +141,14 @@ func (ba *baseAuthz) getCreated() time.Time { // toACME converts the internal Authz type into the public acmeAuthz type for // presentation in the ACME protocol. -func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) { +func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) { var chs = make([]*Challenge, len(ba.Challenges)) for i, chID := range ba.Challenges { ch, err := getChallenge(db, chID) if err != nil { return nil, err } - chs[i], err = ch.toACME(dir, p) + chs[i], err = ch.toACME(ctx, dir) if err != nil { return nil, err } diff --git a/acme/authz_test.go b/acme/authz_test.go index 1cbb939e..6fa24f25 100644 --- a/acme/authz_test.go +++ b/acme/authz_test.go @@ -1,6 +1,7 @@ package acme import ( + "context" "encoding/json" "strings" "testing" @@ -369,7 +370,10 @@ func TestAuthzToACME(t *testing.T) { } az, err := newAuthz(mockdb, "1234", iden) assert.FatalError(t, err) + prov := newProv() + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080") type test struct { db nosql.DB @@ -419,7 +423,7 @@ func TestAuthzToACME(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - acmeAz, err := az.toACME(tc.db, dir, prov) + acmeAz, err := az.toACME(ctx, tc.db, dir) if err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) @@ -434,9 +438,9 @@ func TestAuthzToACME(t *testing.T) { assert.Equals(t, acmeAz.Identifier, iden) assert.Equals(t, acmeAz.Status, StatusPending) - acmeCh1, err := ch1.toACME(dir, prov) + acmeCh1, err := ch1.toACME(ctx, dir) assert.FatalError(t, err) - acmeCh2, err := ch2.toACME(dir, prov) + acmeCh2, err := ch2.toACME(ctx, dir) assert.FatalError(t, err) assert.Equals(t, acmeAz.Challenges[0], acmeCh1) diff --git a/acme/challenge.go b/acme/challenge.go index 3ee27af8..07d8c4f1 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -1,6 +1,7 @@ package acme import ( + "context" "crypto" "crypto/sha256" "crypto/subtle" @@ -17,7 +18,6 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/jose" "github.com/smallstep/nosql" ) @@ -81,7 +81,7 @@ type challenge interface { getAccountID() string getValidated() time.Time getCreated() time.Time - toACME(*directory, provisioner.Interface) (*Challenge, error) + toACME(context.Context, *directory) (*Challenge, error) } // ChallengeOptions is the type used to created a new Challenge. @@ -184,12 +184,12 @@ func (bc *baseChallenge) getError() *AError { // toACME converts the internal Challenge type into the public acmeChallenge // type for presentation in the ACME protocol. -func (bc *baseChallenge) toACME(dir *directory, p provisioner.Interface) (*Challenge, error) { +func (bc *baseChallenge) toACME(ctx context.Context, dir *directory) (*Challenge, error) { ac := &Challenge{ Type: bc.getType(), Status: bc.getStatus(), Token: bc.getToken(), - URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()), + URL: dir.getLink(ctx, ChallengeLink, true, bc.getID()), ID: bc.getID(), AuthzID: bc.getAuthzID(), } diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 1eaa4ced..b574995f 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -2,6 +2,7 @@ package acme import ( "bytes" + "context" "crypto" "crypto/rand" "crypto/rsa" @@ -20,6 +21,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -280,6 +282,10 @@ func TestChallengeToACME_Valid(t *testing.T) { } prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) tests := map[string]challenge{ "dns": chs[0], "http": chs[1], @@ -288,15 +294,15 @@ func TestChallengeToACME_Valid(t *testing.T) { for name, ch := range tests { t.Run(name, func(t *testing.T) { - ach, err := ch.toACME(dir, prov) + ach, err := ch.toACME(ctx, dir) assert.FatalError(t, err) assert.Equals(t, ach.Type, ch.getType()) assert.Equals(t, ach.Status, ch.getStatus()) assert.Equals(t, ach.Token, ch.getToken()) assert.Equals(t, ach.URL, - fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/%s", - URLSafeProvisionerName(prov), ch.getID())) + fmt.Sprintf("%s/acme/%s/challenge/%s", + baseURL.String(), provName, ch.getID())) assert.Equals(t, ach.ID, ch.getID()) assert.Equals(t, ach.AuthzID, ch.getAuthzID()) @@ -337,6 +343,11 @@ func TestChallengeToACME_Retry(t *testing.T) { } prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "example.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) + tests := map[string]challenge{ "dns_no-retry": chs[0+0*len(fns)], "http_no-retry": chs[1+0*len(fns)], @@ -347,15 +358,15 @@ func TestChallengeToACME_Retry(t *testing.T) { } for name, ch := range tests { t.Run(name, func(t *testing.T) { - ach, err := ch.toACME(dir, prov) + ach, err := ch.toACME(ctx, dir) assert.FatalError(t, err) assert.Equals(t, ach.Type, ch.getType()) assert.Equals(t, ach.Status, ch.getStatus()) assert.Equals(t, ach.Token, ch.getToken()) assert.Equals(t, ach.URL, - fmt.Sprintf("https://example.com/acme/%s/challenge/%s", - URLSafeProvisionerName(prov), ch.getID())) + fmt.Sprintf("%s/acme/%s/challenge/%s", + baseURL.String(), provName, ch.getID())) assert.Equals(t, ach.ID, ch.getID()) assert.Equals(t, ach.AuthzID, ch.getAuthzID()) diff --git a/acme/common.go b/acme/common.go index 936574e3..08c609d1 100644 --- a/acme/common.go +++ b/acme/common.go @@ -1,6 +1,7 @@ package acme import ( + "context" "crypto/x509" "net/url" "time" @@ -8,8 +9,75 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/randutil" + "github.com/smallstep/cli/jose" ) +// ContextKey is the key type for storing and searching for ACME request +// essentials in the context of a request. +type ContextKey string + +const ( + // AccContextKey account key + AccContextKey = ContextKey("acc") + // BaseURLContextKey baseURL key + BaseURLContextKey = ContextKey("baseURL") + // JwsContextKey jws key + JwsContextKey = ContextKey("jws") + // JwkContextKey jwk key + JwkContextKey = ContextKey("jwk") + // PayloadContextKey payload key + PayloadContextKey = ContextKey("payload") + // ProvisionerContextKey provisioner key + ProvisionerContextKey = ContextKey("provisioner") +) + +// AccountFromContext searches the context for an ACME account. Returns the +// account or an error. +func AccountFromContext(ctx context.Context) (*Account, error) { + val, ok := ctx.Value(AccContextKey).(*Account) + if !ok || val == nil { + return nil, AccountDoesNotExistErr(nil) + } + return val, nil +} + +// BaseURLFromContext returns the baseURL if one is stored in the context. +func BaseURLFromContext(ctx context.Context) *url.URL { + val, ok := ctx.Value(BaseURLContextKey).(*url.URL) + if !ok || val == nil { + return nil + } + return val +} + +// JwkFromContext searches the context for a JWK. Returns the JWK or an error. +func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { + val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey) + if !ok || val == nil { + return nil, ServerInternalErr(errors.Errorf("jwk expected in request context")) + } + return val, nil +} + +// JwsFromContext searches the context for a JWS. Returns the JWS or an error. +func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) { + val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature) + if !ok || val == nil { + return nil, ServerInternalErr(errors.Errorf("jws expected in request context")) + } + return val, nil +} + +// ProvisionerFromContext searches the context for a provisioner. Returns the +// provisioner or an error. +func ProvisionerFromContext(ctx context.Context) (provisioner.Interface, error) { + val, ok := ctx.Value(ProvisionerContextKey).(provisioner.Interface) + if !ok || val == nil { + return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context")) + } + return val, nil +} + // SignAuthority is the interface implemented by a CA authority. type SignAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) @@ -59,9 +127,3 @@ func (c *Clock) Now() time.Time { } var clock = new(Clock) - -// URLSafeProvisionerName returns a path escaped version of the ACME provisioner -// ID that is safe to use in URL paths. -func URLSafeProvisionerName(p provisioner.Interface) string { - return url.PathEscape(p.GetName()) -} diff --git a/acme/directory.go b/acme/directory.go index 85819f10..d5681b73 100644 --- a/acme/directory.go +++ b/acme/directory.go @@ -1,8 +1,10 @@ package acme import ( + "context" "encoding/json" "fmt" + "net/url" "github.com/pkg/errors" ) @@ -100,8 +102,18 @@ func (l Link) String() string { } } -// getLink returns an absolute or partial path to the given resource. -func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string { +func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string { + var provName string + if p, err := ProvisionerFromContext(ctx); err == nil && p != nil { + provName = p.GetName() + } + return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...) +} + +// getLinkExplicit returns an absolute or partial path to the given resource and a base +// URL dynamically obtained from the request for which the link is being +// calculated. +func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { var link string switch typ { case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: @@ -113,8 +125,26 @@ func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs . case FinalizeLink: link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) } + if abs { - return fmt.Sprintf("https://%s/%s%s", d.dns, d.prefix, link) + // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 + u := url.URL{} + if baseURL != nil { + u = *baseURL + } + + // If no Scheme is set, then default to https. + if u.Scheme == "" { + u.Scheme = "https" + } + + // If no Host is set, then use the default (first DNS attr in the ca.json). + if u.Host == "" { + u.Host = d.dns + } + + u.Path = d.prefix + link + return u.String() } return link } diff --git a/acme/directory_test.go b/acme/directory_test.go index 2cfc8bb8..dd4c534c 100644 --- a/acme/directory_test.go +++ b/acme/directory_test.go @@ -1,7 +1,9 @@ package acme import ( + "context" "fmt" + "net/url" "testing" "github.com/smallstep/assert" @@ -14,47 +16,84 @@ func TestDirectoryGetLink(t *testing.T) { id := "1234" prov := newProv() - provID := URLSafeProvisionerName(prov) + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID)) - assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID)) + assert.Equals(t, dir.getLink(ctx, NewNonceLink, true), + fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) + assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID)) - assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID)) + // No provisioner + ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL) + assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true), + fmt.Sprintf("%s/acme//new-nonce", baseURL.String())) + assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce") - assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID)) - assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID)) + // No baseURL + ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov) + assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true), + fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName)) + assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName)) - assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID)) - assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID)) - - assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID)) - assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID)) - - assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID)) - assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID)) - - assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID)) - assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) - - assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID)) - assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID)) - - assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID)) - assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID)) - - assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID)) - assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID)) - - assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID)) - assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID)) - - assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID)) - assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID)) - - assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID)) - assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID)) - - assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID)) - assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID)) + assert.Equals(t, dir.getLink(ctx, OrderLink, true, id), + fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName)) + assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName)) +} + +func TestDirectoryGetLinkExplicit(t *testing.T) { + dns := "ca.smallstep.com" + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + prefix := "acme" + dir := newDirectory(dns, prefix) + id := "1234" + + prov := newProv() + provID := url.PathEscape(prov.GetName()) + + assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) + assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) + assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) + assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) + + assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) + + assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) + + assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) + + assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) + + assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) + + assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) + + assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) + + assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) + + assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) + + assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) + + assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) + + assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID)) + + assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) + assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) } diff --git a/acme/order.go b/acme/order.go index 27e030e9..3f02bc51 100644 --- a/acme/order.go +++ b/acme/order.go @@ -332,10 +332,10 @@ func getOrder(db nosql.DB, id string) (*order, error) { // toACME converts the internal Order type into the public acmeOrder type for // presentation in the ACME protocol. -func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) { +func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) { azs := make([]string, len(o.Authorizations)) for i, aid := range o.Authorizations { - azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid) + azs[i] = dir.getLink(ctx, AuthzLink, true, aid) } ao := &Order{ Status: o.Status, @@ -344,12 +344,12 @@ func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*O NotBefore: o.NotBefore.Format(time.RFC3339), NotAfter: o.NotAfter.Format(time.RFC3339), Authorizations: azs, - Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID), + Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), ID: o.ID, } if o.Certificate != "" { - ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate) + ao.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate) } return ao, nil } diff --git a/acme/order_test.go b/acme/order_test.go index b0453754..d3314a3e 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -6,6 +6,7 @@ import ( "crypto/x509/pkix" "encoding/json" "fmt" + "net/url" "testing" "time" @@ -150,6 +151,10 @@ func TestGetOrder(t *testing.T) { func TestOrderToACME(t *testing.T) { dir := newDirectory("ca.smallstep.com", "acme") prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) type test struct { o *order @@ -172,7 +177,7 @@ func TestOrderToACME(t *testing.T) { for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { - acmeOrder, err := tc.o.toACME(nil, dir, prov) + acmeOrder, err := tc.o.toACME(ctx, nil, dir) if err != nil { if assert.NotNil(t, tc.err) { ae, ok := err.(*Error) @@ -186,9 +191,10 @@ func TestOrderToACME(t *testing.T) { assert.Equals(t, acmeOrder.ID, tc.o.ID) assert.Equals(t, acmeOrder.Status, tc.o.Status) assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers) - assert.Equals(t, acmeOrder.Finalize, fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s/finalize", URLSafeProvisionerName(prov), tc.o.ID)) + assert.Equals(t, acmeOrder.Finalize, + fmt.Sprintf("%s/acme/%s/order/%s/finalize", baseURL.String(), provName, tc.o.ID)) if tc.o.Certificate != "" { - assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/%s", URLSafeProvisionerName(prov), tc.o.Certificate)) + assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, tc.o.Certificate)) } expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires) diff --git a/authority/authority.go b/authority/authority.go index 8cf4cfc1..828adf2f 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -67,7 +67,6 @@ func New(config *Config, opts ...Option) (*Authority, error) { var a = &Authority{ config: config, certificates: new(sync.Map), - provisioners: provisioner.NewCollection(config.getAudiences()), } // Apply options. @@ -85,6 +84,44 @@ func New(config *Config, opts ...Option) (*Authority, error) { return a, nil } +// NewEmbedded initializes an authority that can be embedded in a different +// project without the limitations of the config. +func NewEmbedded(opts ...Option) (*Authority, error) { + a := &Authority{ + config: &Config{}, + certificates: new(sync.Map), + } + + // Apply options. + for _, fn := range opts { + if err := fn(a); err != nil { + return nil, err + } + } + + // Validate required options + switch { + case a.config == nil: + return nil, errors.New("cannot create an authority without a configuration") + case len(a.rootX509Certs) == 0 && a.config.Root.HasEmpties(): + return nil, errors.New("cannot create an authority without a root certificate") + case a.x509Issuer == nil && a.config.IntermediateCert == "": + return nil, errors.New("cannot create an authority without an issuer certificate") + case a.x509Signer == nil && a.config.IntermediateKey == "": + return nil, errors.New("cannot create an authority without an issuer signer") + } + + // Initialize config required fields. + a.config.init() + + // Initialize authority from options or configuration. + if err := a.init(); err != nil { + return nil, err + } + + return a, nil +} + // init performs validation and initializes the fields of an Authority struct. func (a *Authority) init() error { // Check if handler has already been validated/initialized. @@ -232,9 +269,11 @@ func (a *Authority) init() error { return err } // Initialize provisioners + audiences := a.config.getAudiences() + a.provisioners = provisioner.NewCollection(audiences) config := provisioner.Config{ Claims: claimer.Claims(), - Audiences: a.config.getAudiences(), + Audiences: audiences, DB: a.db, SSHKeys: &provisioner.SSHKeys{ UserKeys: sshKeys.UserKeys, diff --git a/authority/authority_test.go b/authority/authority_test.go index 058a4c25..3ab3e142 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -1,8 +1,13 @@ package authority import ( + "crypto" + "crypto/rand" "crypto/sha256" + "crypto/x509" "encoding/hex" + "io/ioutil" + "net" "reflect" "testing" @@ -10,6 +15,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" + "github.com/smallstep/cli/crypto/pemutil" stepJOSE "github.com/smallstep/cli/jose" ) @@ -182,3 +188,123 @@ func TestAuthority_GetDatabase(t *testing.T) { }) } } + +func TestNewEmbedded(t *testing.T) { + caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + key, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + + type args struct { + opts []Option + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{[]Option{WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))}}, false}, + {"ok empty config", args{[]Option{WithConfig(&Config{}), WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))}}, false}, + {"ok config file", args{[]Option{WithConfigFile("../ca/testdata/ca.json")}}, false}, + {"ok config", args{[]Option{WithConfig(&Config{ + Root: []string{"testdata/certs/root_ca.crt"}, + IntermediateCert: "testdata/certs/intermediate_ca.crt", + IntermediateKey: "testdata/secrets/intermediate_ca_key", + Password: "pass", + AuthorityConfig: &AuthConfig{}, + })}}, false}, + {"fail options", args{[]Option{WithX509RootBundle([]byte("bad data"))}}, true}, + {"fail missing config", args{[]Option{WithConfig(nil), WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))}}, true}, + {"fail missing root", args{[]Option{WithX509Signer(crt, key.(crypto.Signer))}}, true}, + {"fail missing signer", args{[]Option{WithX509RootBundle(caPEM)}}, true}, + {"fail missing root file", args{[]Option{WithConfig(&Config{ + IntermediateCert: "testdata/certs/intermediate_ca.crt", + IntermediateKey: "testdata/secrets/intermediate_ca_key", + Password: "pass", + AuthorityConfig: &AuthConfig{}, + })}}, true}, + {"fail missing issuer", args{[]Option{WithConfig(&Config{ + Root: []string{"testdata/certs/root_ca.crt"}, + IntermediateKey: "testdata/secrets/intermediate_ca_key", + Password: "pass", + AuthorityConfig: &AuthConfig{}, + })}}, true}, + {"fail missing signer", args{[]Option{WithConfig(&Config{ + Root: []string{"testdata/certs/root_ca.crt"}, + IntermediateCert: "testdata/certs/intermediate_ca.crt", + Password: "pass", + AuthorityConfig: &AuthConfig{}, + })}}, true}, + {"fail bad password", args{[]Option{WithConfig(&Config{ + Root: []string{"testdata/certs/root_ca.crt"}, + IntermediateCert: "testdata/certs/intermediate_ca.crt", + IntermediateKey: "testdata/secrets/intermediate_ca_key", + Password: "bad", + AuthorityConfig: &AuthConfig{}, + })}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewEmbedded(tt.args.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("NewEmbedded() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + assert.True(t, got.initOnce) + assert.NotNil(t, got.rootX509Certs) + assert.NotNil(t, got.x509Signer) + assert.NotNil(t, got.x509Issuer) + } + }) + } +} + +func TestNewEmbedded_Sign(t *testing.T) { + caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + key, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + + a, err := NewEmbedded(WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))) + assert.FatalError(t, err) + + // Sign + cr, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{ + DNSNames: []string{"foo.bar.zar"}, + }, key) + assert.FatalError(t, err) + csr, err := x509.ParseCertificateRequest(cr) + assert.FatalError(t, err) + + cert, err := a.Sign(csr, provisioner.Options{}) + assert.FatalError(t, err) + assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) + assert.Equals(t, crt, cert[1]) +} + +func TestNewEmbedded_GetTLSCertificate(t *testing.T) { + caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + key, err := pemutil.Read("testdata/secrets/intermediate_ca_key", pemutil.WithPassword([]byte("pass"))) + assert.FatalError(t, err) + + a, err := NewEmbedded(WithX509RootBundle(caPEM), WithX509Signer(crt, key.(crypto.Signer))) + assert.FatalError(t, err) + + // GetTLSCertificate + cert, err := a.GetTLSCertificate() + assert.FatalError(t, err) + assert.Equals(t, []string{"localhost"}, cert.Leaf.DNSNames) + assert.True(t, cert.Leaf.IPAddresses[0].Equal(net.ParseIP("127.0.0.1"))) + assert.True(t, cert.Leaf.IPAddresses[1].Equal(net.ParseIP("::1"))) +} diff --git a/authority/config.go b/authority/config.go index ceb2ea89..a26d19ad 100644 --- a/authority/config.go +++ b/authority/config.go @@ -75,12 +75,31 @@ type AuthConfig struct { Backdate *provisioner.Duration `json:"backdate,omitempty"` } +// init initializes the required fields in the AuthConfig if they are not +// provided. +func (c *AuthConfig) init() { + if c.Provisioners == nil { + c.Provisioners = provisioner.List{} + } + if c.Template == nil { + c.Template = &x509util.ASN1DN{} + } + if c.Backdate == nil { + c.Backdate = &provisioner.Duration{ + Duration: defaultBackdate, + } + } +} + // Validate validates the authority configuration. func (c *AuthConfig) Validate(audiences provisioner.Audiences) error { if c == nil { return errors.New("authority cannot be undefined") } + // Initialize required fields. + c.init() + // Check that only one K8sSA is enabled var k8sCount int for _, p := range c.Provisioners { @@ -92,18 +111,8 @@ func (c *AuthConfig) Validate(audiences provisioner.Audiences) error { return errors.New("cannot have more than one kubernetes service account provisioner") } - if c.Template == nil { - c.Template = &x509util.ASN1DN{} - } - - if c.Backdate != nil { - if c.Backdate.Duration < 0 { - return errors.New("authority.backdate cannot be less than 0") - } - } else { - c.Backdate = &provisioner.Duration{ - Duration: defaultBackdate, - } + if c.Backdate.Duration < 0 { + return errors.New("authority.backdate cannot be less than 0") } return nil @@ -126,6 +135,21 @@ func LoadConfiguration(filename string) (*Config, error) { return &c, nil } +// initializes the minimal configuration required to create an authority. This +// is mainly used on embedded authorities. +func (c *Config) init() { + if c.DNSNames == nil { + c.DNSNames = []string{"localhost", "127.0.0.1", "::1"} + } + if c.TLS == nil { + c.TLS = &DefaultTLSOptions + } + if c.AuthorityConfig == nil { + c.AuthorityConfig = &AuthConfig{} + } + c.AuthorityConfig.init() +} + // Save saves the configuration to the given filename. func (c *Config) Save(filename string) error { f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) diff --git a/authority/options.go b/authority/options.go index 04cd7bef..59566822 100644 --- a/authority/options.go +++ b/authority/options.go @@ -17,6 +17,24 @@ import ( // Option sets options to the Authority. type Option func(*Authority) error +// WithConfig replaces the current config with the given one. No validation is +// performed in the given value. +func WithConfig(config *Config) Option { + return func(a *Authority) error { + a.config = config + return nil + } +} + +// WithConfigFile reads the given filename as a configuration file and replaces +// the current one. No validation is performed in the given configuration. +func WithConfigFile(filename string) Option { + return func(a *Authority) (err error) { + a.config, err = LoadConfiguration(filename) + return + } +} + // WithDatabase sets an already initialized authority database to a new // authority. This option is intended to be use on graceful reloads. func WithDatabase(db db.AuthDB) Option { diff --git a/docs/GETTING_STARTED.md b/docs/GETTING_STARTED.md index a442c2b2..44bd8876 100644 --- a/docs/GETTING_STARTED.md +++ b/docs/GETTING_STARTED.md @@ -34,8 +34,11 @@ provisioners and its options. To initialize a PKI and configure the Step Certificate Authority run: +> **NOTE**: `step ca init` only initialize an x509 CA. If you +> would like to initialize an SSH CA as well, add the `--ssh` flag. + ``` -step ca init +step ca init [--ssh] ``` You'll be asked for a name for your PKI. This name will appear in your CA @@ -54,27 +57,40 @@ You should see: . ├── certs │   ├── intermediate_ca.crt -│   └── root_ca.crt +│   ├── root_ca.crt +│   ├── ssh_host_key.pub (--ssh only) +│   └── ssh_user_key.pub (--ssh only) ├── config │   ├── ca.json │   └── defaults.json └── secrets ├── intermediate_ca_key - └── root_ca_key + ├── root_ca_key + ├── ssh_host_key (--ssh only) + └── ssh_user_key (--ssh only) ``` The files created include: * `root_ca.crt` and `root_ca_key`: the root certificate and private key for - your PKI +your PKI. + * `intermediate_ca.crt` and `intermediate_ca_key`: the intermediate certificate - and private key that will be used to sign leaf certificates +and private key that will be used to sign leaf certificates. + +* `ssh_host_key.pub` and `ssh_host_key` (`--ssh` only): the SSH host pub/priv key +pair that will be used to sign new host SSH certificates. + +* `ssh_user_key.pub` and `ssh_user_key` (`--ssh` only): the SSH user pub/priv key +pair that will be used to sign new user SSH certificates. + * `ca.json`: the configuration file necessary for running the Step CA. + * `defaults.json`: file containing default parameters for the `step` CA cli interface. You can override these values with the appropriate flags or environment variables. -All of the files endinging in `_key` are password protected using the password +All of the files ending in `_key` are password protected using the password you chose during PKI initialization. We advise you to change these passwords (using the `step crypto change-pass` utility) if you plan to run your CA in a non-development environment. @@ -146,10 +162,34 @@ ciphersuites, min/max TLS version, etc. against token reuse. The default value is `false`. Do not change this unless you know what you are doing. - - `provisioners`: list of provisioners. Each provisioner has a `name`, - associated public/private keys, and an optional `claims` attribute that will - override any values set in the global `claims` directly underneath `authority`. + SSH CA properties + * `minUserSSHDuration`: do not allow certificates with a duration less + than this value. + + * `maxUserSSHDuration`: do not allow certificates with a duration + greater than this value. + + * `defaultUserSSHDuration`: if no certificate validity period is specified, + use this value. + + * `minHostSSHDuration`: do not allow certificates with a duration less + than this value. + + * `maxHostSSHDuration`: do not allow certificates with a duration + greater than this value. + + * `defaultHostSSHDuration`: if no certificate validity period is specified, + use this value. + + * `enableSSHCA`: enable all provisioners to generate SSH Certificates. + The deault value is `false`. You can enable this option per provisioner + by setting it to `true` in the provisioner claims. + + - `provisioners`: list of provisioners. + See the [provisioners documentation](./provisioners.md). Each provisioner + has an optional `claims` attribute that can override any attribute defined + at the level above in the `authority.claims`. `step ca init` will generate one provisioner. New provisioners can be added by running `step ca provisioner add`. @@ -445,9 +485,17 @@ Please enter the password to decrypt ~/.step/secrets/intermediate_ca_key: passwo 2019/02/21 12:09:51 Serving HTTPS on :9443 ... ``` -Please [`step ca provisioner`](https://smallstep.com/docs/cli/ca/provisioner/)'s docs for details on all available claims properties. The durations are strings which are a sequence of decimal numbers, each with optional fraction and a unit suffix, such as "300ms" or "2h45m". Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +See the [`provisioner doc`][1] for details on all available provisioner claims. +The durations are strings which are a sequence of decimal numbers, each with +optional fraction and a unit suffix, such as "300ms" or "2h45m". Valid time +units are "ns", "us" (or "µs"), "ms", "s", "m", "h". -Now certs issued by the `dev@smallstep.com` provisioner will be valid for two hours and deny renewals. Command line flags allow validity extension up to 12h, please see [`step ca certificate`](https://smallstep.com/docs/cli/ca/certificate/)'s docs for details. +Now certs issued by the `dev@smallstep.com` provisioner will be valid for two +hours and deny renewals. Command line flags allow validity extension up to 12h, +please see [`step ca certificate`][2]'s docs for details. + +[1]: ./provisioners.md +[2]: https://smallstep.com/docs/cli/ca/certificate/ ```bash # grab a cert, will also work with 'step ca token' flow @@ -610,3 +658,31 @@ are features that we plan to implement, but are not yet available. In the mean time short lived certificates are a decent alternative. * Keep your hosts secure by enforcing AuthN and AuthZ for every connection. SSH access is a big one. + + +## Notes on Running Step CA as a Highly Available Service + +**CAUTION**: `step-ca` is built to scale horizontally. However, the creators +and maintainers do not regularly test in an HA environment using mulitple +instances. You may run into issues we did not plan for. If this happens, please +[open an issue][3]. + +### Considerations + +A few things to consider / implement when running multiple instances of `step-ca`: + +* Use `MySQL` DB: The default `Badger` DB cannot be read / written by more than one +process simultaneously. The only supported DB that can support multiple instances +is `MySQL`. See the [database documentation][4] for guidance on configuring `MySQL`. + +* Synchronize `ca.json` across instances: `step-ca` reads all of it's +configuration (and all of the provisioner configuration) from the `ca.json` file +specified on the command line. If the `ca.json` of one instance is modified +(either manually or using a command like `step ca provisioner (add | remove)`) +the other instances will not pick up on this change until the `ca.json` is +copied over to the correct location for each instance and the instance itself +is `SIGHUP`'ed (or restarted). It's recommended to use a configuration management +(ansible, chef, salt, puppet, etc.) tool to synchronize `ca.json` across instances. + +[3]: https://github.com/smallstep/certificates/issues +[4]: ./database.md diff --git a/docs/provisioners.md b/docs/provisioners.md index 070b819b..5ab7c997 100644 --- a/docs/provisioners.md +++ b/docs/provisioners.md @@ -4,6 +4,70 @@ Provisioners are people or code that are registered with the CA and authorized to issue "provisioning tokens". Provisioning tokens are single-use tokens that can be used to authenticate with the CA and get a certificate. +Each provisioner can define an optional `claims` attribute. The settings in this +attribute override any settings in the global `claims` attribute in the authority +configuration. + +Example `claims`: + +``` + ... + "claims": { + "minTLSCertDuration": "5m", + "maxTLSCertDuration": "24h", + "defaultTLSCertDuration": "24h", + "disableRenewal": false + "minHostSSHCertDuration": "5m", + "maxHostSSHCertDuration": "1680h", + "minUserSSHCertDuration": "5m", + "maxUserSSHCertDuration": "24h", + "maxTLSCertDuration": "16h", + "enableSSHCA": true, + } + ... +``` + +* `claims` (optional): overwrites the default claims set in the authority. + You can set one or more of the following claims: + + * `minTLSCertDuration`: do not allow certificates with a duration less than + this value. + + * `maxTLSCertDuration`: do not allow certificates with a duration greater than + this value. + + * `defaultTLSCertDuration`: if no certificate validity period is specified, + use this value. + + * `disableIssuedAtCheck`: disable a check verifying that provisioning tokens + must be issued after the CA has booted. This claim is one prevention against + token reuse. The default value is `false`. Do not change this unless you + know what you are doing. + + SSH CA properties + + * `minUserSSHDuration`: do not allow certificates with a duration less + than this value. + + * `maxUserSSHDuration`: do not allow certificates with a duration + greater than this value. + + * `defaultUserSSHDuration`: if no certificate validity period is specified, + use this value. + + * `minHostSSHDuration`: do not allow certificates with a duration less + than this value. + + * `maxHostSSHDuration`: do not allow certificates with a duration + greater than this value. + + * `defaultHostSSHDuration`: if no certificate validity period is specified, + use this value. + + * `enableSSHCA`: enable all provisioners to generate SSH Certificates. + The deault value is `false`. You can enable this option per provisioner + by setting it to `true` in the provisioner claims. + ## JWK JWK is the default provisioner type. It uses public-key cryptography to sign and @@ -35,6 +99,12 @@ In the ca.json configuration file, a complete JWK provisioner example looks like "maxTLSCertDuration": "24h", "defaultTLSCertDuration": "24h", "disableRenewal": false + "minHostSSHCertDuration": "5m", + "maxHostSSHCertDuration": "1680h", + "minUserSSHCertDuration": "5m", + "maxUserSSHCertDuration": "24h", + "maxTLSCertDuration": "16h", + "enableSSHCA": true, } } ``` @@ -75,23 +145,6 @@ In the ca.json configuration file, a complete JWK provisioner example looks like provided using the `--key` flag of the `step ca token` to be able to sign the token. -* `claims` (optional): overwrites the default claims set in the authority. - You can set one or more of the following claims: - - * `minTLSCertDuration`: do not allow certificates with a duration less than - this value. - - * `maxTLSCertDuration`: do not allow certificates with a duration greater than - this value. - - * `defaultTLSCertDuration`: if no certificate validity period is specified, - use this value. - - * `disableIssuedAtCheck`: disable a check verifying that provisioning tokens - must be issued after the CA has booted. This claim is one prevention against - token reuse. The default value is `false`. Do not change this unless you - know what you are doing. - ## OIDC An OIDC provisioner allows a user to get a certificate after authenticating @@ -149,7 +202,7 @@ is G-Suite. port to be specified at the time of the request for loopback IP redirect URIs. * `claims` (optional): overwrites the default claims set in the authority, see - the [JWK](#jwk) section for all the options. + the [top](#provisioners) section for all the options. ## Provisioners for Cloud Identities @@ -213,7 +266,7 @@ In the ca.json, an AWS provisioner looks like: certificate. The instance age is a string using the duration format. * `claims` (optional): overwrites the default claims set in the authority, see - the [JWK](#jwk) section for all the options. + the [top](#provisioners) section for all the options. ### GCP @@ -265,7 +318,7 @@ In the ca.json, a GCP provisioner looks like: certificate. The instance age is a string using the duration format. * `claims` (optional): overwrites the default claims set in the authority, see - the [JWK](#jwk) section for all the options. + the [top](#provisioners) section for all the options. ### Azure @@ -315,4 +368,4 @@ In the ca.json, an Azure provisioner looks like: and different tokens can be used to get different certificates. * `claims` (optional): overwrites the default claims set in the authority, see - the [JWK](#jwk) section for all the options. + the [top](#provisioners) section for all the options. diff --git a/go.mod b/go.mod index 34810b17..9cbf1418 100644 --- a/go.mod +++ b/go.mod @@ -7,12 +7,14 @@ require ( github.com/Masterminds/sprig/v3 v3.0.0 github.com/go-chi/chi v4.0.2+incompatible github.com/googleapis/gax-go/v2 v2.0.5 + github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a // indirect + github.com/lunixbochs/vtclean v1.0.0 // indirect github.com/newrelic/go-agent v2.15.0+incompatible github.com/pkg/errors v0.8.1 github.com/rs/xid v1.2.1 github.com/sirupsen/logrus v1.4.2 github.com/smallstep/assert v0.0.0-20200103212524-b99dc1097b15 - github.com/smallstep/cli v0.14.3-rc.1 + github.com/smallstep/cli v0.14.3 github.com/smallstep/nosql v0.3.0 github.com/urfave/cli v1.22.2 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 diff --git a/go.sum b/go.sum index f9de86dc..8bb4b771 100644 --- a/go.sum +++ b/go.sum @@ -21,16 +21,12 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM= github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= -github.com/Masterminds/glide v0.13.2/go.mod h1:STyF5vcenH/rUqTEv+/hBXlSTo7KYwg2oc2f4tzPWic= github.com/Masterminds/goutils v1.1.0 h1:zukEsf/1JZwCMgHiK3GZftabmxiCw4apj3a28RPBiVg= github.com/Masterminds/goutils v1.1.0/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= -github.com/Masterminds/semver v1.4.2 h1:WBLTQ37jOCzSLtXNdoo8bNM8876KhNqOKvrlGITgsTc= -github.com/Masterminds/semver v1.4.2/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/semver/v3 v3.0.1 h1:2kKm5lb7dKVrt5TYUiAavE6oFc1cFT0057UVGT+JqLk= github.com/Masterminds/semver/v3 v3.0.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/Masterminds/sprig/v3 v3.0.0 h1:KSQz7Nb08/3VU9E4ns29dDxcczhOD1q7O1UfM4G3t3g= github.com/Masterminds/sprig/v3 v3.0.0/go.mod h1:NEUY/Qq8Gdm2xgYA+NwJM6wmfdRV9xkh8h/Rld20R0U= -github.com/Masterminds/vcs v1.13.0/go.mod h1:N09YCmOQr6RLxC6UNHzuVwAdodYbbnycGHSmwVJjcKA= github.com/Microsoft/go-winio v0.4.14 h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= @@ -56,7 +52,6 @@ github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bombsimon/wsl/v2 v2.0.0 h1:+Vjcn+/T5lSrO8Bjzhk4v14Un/2UyCA1E3V5j9nwTkQ= github.com/bombsimon/wsl/v2 v2.0.0/go.mod h1:mf25kr/SqFEPhhcxW1+7pxzGlW+hIl/hYTKY95VwV8U= -github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/census-instrumentation/opencensus-proto v0.2.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= @@ -70,7 +65,6 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1 h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/codegangsta/cli v1.20.0/go.mod h1:/qJNoX69yVSKu5o4jLyXAENLRyk1uhi7zkbQ3slBdOA= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/bbolt v1.3.3 h1:n6AiVyVRKQFNb6mJlwESEvvLoDyiTzXX7ORAUlkeBdY= github.com/coreos/bbolt v1.3.3/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= @@ -86,7 +80,6 @@ github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e h1:Wf6HqHfScWJN9 github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbpBpLoyyu8B6e44T7hJy6potg= github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/corpix/uarand v0.1.1/go.mod h1:SFKZvkcRoLqVRFZ4u25xPmp6m9ktANfbpXZ7SJ0/FNU= github.com/cpuguy83/go-md2man v1.0.10 h1:BSKMNlYxDvnunlTymqtgONjNnaRV1sTpcovwwjF22jk= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= @@ -273,7 +266,6 @@ github.com/imdario/mergo v0.3.7 h1:Y+UAYTZ7gDEuOfhxKWy+dvb5dRQ6rJjFSdX2HZY1/gI= github.com/imdario/mergo v0.3.7/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= @@ -288,7 +280,6 @@ github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a h1:FaWFmfWdAUKbSCtOU github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a/go.mod h1:UJSiEoRfvx3hP73CvoARgeLjaIOjybY9vj8PUPPFGeU= github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v0.0.0-20161130080628-0de1eaf82fa3/go.mod h1:jxZFDH7ILpTPQTk+E2s+z4CUas9lVNjIuKR4c5/zKgM= @@ -371,7 +362,6 @@ github.com/nbutton23/zxcvbn-go v0.0.0-20180912185939-ae427f1e4c1d h1:AREM5mwr4u1 github.com/nbutton23/zxcvbn-go v0.0.0-20180912185939-ae427f1e4c1d/go.mod h1:o96djdrsSGy3AWPyBgZMAGfxZNfgntdJG+11KU4QvbU= github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU= github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ= -github.com/ngdinhtoan/glide-cleanup v0.2.0/go.mod h1:UQzsmiDOb8YV3nOsCxK/c9zPpCZVNoHScRE3EO9pVMM= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/olekukonko/tablewriter v0.0.1/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.4 h1:vHD/YYe1Wolo78koG299f7V/VAS08c6IpCLn+Ejf/w8= @@ -456,14 +446,11 @@ github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5 h1:lX6ybsQW9Agn3q github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20200103212524-b99dc1097b15 h1:kSImCuenAkXtCaBeQ1UhmzzJGRhSm8sVH7I3sHE2Qdg= github.com/smallstep/assert v0.0.0-20200103212524-b99dc1097b15/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= -github.com/smallstep/certificates v0.14.2/go.mod h1:eleWnbKTXDdV9GxWNOtbdjBiitdK5SuO4FCXOvcdLEY= -github.com/smallstep/certinfo v1.2.0 h1:XJCH6fLKwGFcUndQ6ARtFdaqoujBCQnGbGRegf6PWcc= -github.com/smallstep/certinfo v1.2.0/go.mod h1:1gQJekdPwPvUwFWGTi7bZELmQT09cxC9wJ0VBkBNiwU= -github.com/smallstep/cli v0.14.2 h1:0Z1MtcgJfVS9RstNokWSNqE20xPwdiEhZgNuZxYRWRI= -github.com/smallstep/cli v0.14.2/go.mod h1:JOTzEzQ4/l863KUqs9qlAqPagWPOqu6lc3C59S1nYzU= -github.com/smallstep/cli v0.14.3-rc.1 h1:u5oUKbm6HL2lD7Xoary+DmIRJ1Ni6uov/DyA78u0CzA= -github.com/smallstep/cli v0.14.3-rc.1/go.mod h1:9dsTyViHYZRwU+YjQYMHdRVk9jONeZSioYC5rqA3LoE= -github.com/smallstep/nosql v0.2.0/go.mod h1:qyxCqeyGwkuM6bfJSY3sg+aiXEiD0GbQOPzIF8/ZD8Q= +github.com/smallstep/certificates v0.14.4/go.mod h1:Y9ug0+ZTB0k22BBV/2K+LAZIVDCMjAAtbQ0XWS+E870= +github.com/smallstep/certinfo v1.2.1/go.mod h1:1gQJekdPwPvUwFWGTi7bZELmQT09cxC9wJ0VBkBNiwU= +github.com/smallstep/certinfo v1.3.0/go.mod h1:1gQJekdPwPvUwFWGTi7bZELmQT09cxC9wJ0VBkBNiwU= +github.com/smallstep/cli v0.14.3 h1:GghXS0NBoj7psz10Cds86vvcrz6Pmww8aMGKEiD04AI= +github.com/smallstep/cli v0.14.3/go.mod h1:U3WVQBS6udgpyHg+d0FQxodNGXBwikoXbERyfgcAhs8= github.com/smallstep/nosql v0.3.0 h1:V1X5vfDsDt89499h3jZFUlR4VnnsYYs5tXaQZ0w8z5U= github.com/smallstep/nosql v0.3.0/go.mod h1:QG7gNOpidifn99MjZaiNbm7HPesIyBd97F/OfacNz8Q= github.com/smallstep/truststore v0.9.3/go.mod h1:PRSkpRIhAYBK/KLWkHNgRdYgzWMEy45bN7PSJCfKKGE= @@ -618,7 +605,6 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190522155817-f3200d17e092/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -657,7 +643,6 @@ golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190424175732-18eb32c0e2f0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -791,7 +776,6 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -howett.net/plist v0.0.0-20200225050739-77e249a2e2ba/go.mod h1:vMygbs4qMhSZSc4lCUl2OEE+rDiIIJAIdR4m7MiMcm0= mvdan.cc/interfacer v0.0.0-20180901003855-c20040233aed h1:WX1yoOaKQfddO/mLzdV4wptyWgoH/6hwLs7QHTixo0I= mvdan.cc/interfacer v0.0.0-20180901003855-c20040233aed/go.mod h1:Xkxe497xwlCKkIaQYRfC7CSLworTXY9RMqwhhCm+8Nc= mvdan.cc/lint v0.0.0-20170908181259-adc824a0674b h1:DxJ5nJdkhDlLok9K6qO+5290kphDJbHOQO1DFFFTeBo=