Allow relative URL for all links in ACME api ...

* Pass the request context all the way down the ACME stack.
* Save baseURL in context and use when generating ACME urls.
This commit is contained in:
max furman 2020-05-06 20:18:12 -07:00
parent 639993bd09
commit e1409349f3
23 changed files with 1097 additions and 1074 deletions

View file

@ -1,11 +1,11 @@
package acme package acme
import ( import (
"context"
"encoding/json" "encoding/json"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
@ -79,11 +79,11 @@ func newAccount(db nosql.DB, ops AccountOptions) (*account, error) {
// toACME converts the internal Account type into the public acmeAccount // toACME converts the internal Account type into the public acmeAccount
// type for presentation in the ACME protocol. // type for presentation in the ACME protocol.
func (a *account) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Account, error) { func (a *account) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Account, error) {
return &Account{ return &Account{
Status: a.Status, Status: a.Status,
Contact: a.Contact, Contact: a.Contact,
Orders: dir.getLink(OrdersByAccountLink, URLSafeProvisionerName(p), true, a.ID), Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID),
Key: a.Key, Key: a.Key,
ID: a.ID, ID: a.ID,
}, nil }, nil

View file

@ -1,8 +1,10 @@
package acme package acme
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url"
"testing" "testing"
"time" "time"
@ -332,6 +334,10 @@ func TestGetAccountIDsByAccount(t *testing.T) {
func TestAccountToACME(t *testing.T) { func TestAccountToACME(t *testing.T) {
dir := newDirectory("ca.smallstep.com", "acme") dir := newDirectory("ca.smallstep.com", "acme")
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
type test struct { type test struct {
acc *account acc *account
@ -347,7 +353,7 @@ func TestAccountToACME(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
acmeAccount, err := tc.acc.toACME(nil, dir, prov) acmeAccount, err := tc.acc.toACME(ctx, nil, dir)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
@ -363,7 +369,7 @@ func TestAccountToACME(t *testing.T) {
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact) assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID) assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
assert.Equals(t, acmeAccount.Orders, assert.Equals(t, acmeAccount.Orders,
fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s/orders", URLSafeProvisionerName(prov), tc.acc.ID)) fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID))
} }
} }
}) })

View file

@ -73,12 +73,7 @@ func (u *UpdateAccountRequest) Validate() error {
// NewAccount is the handler resource for creating new ACME accounts. // NewAccount is the handler resource for creating new ACME accounts.
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) payload, err := payloadFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -95,7 +90,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
} }
httpStatus := http.StatusCreated httpStatus := http.StatusCreated
acc, err := accountFromContext(r) acc, err := acme.AccountFromContext(r.Context())
if err != nil { if err != nil {
acmeErr, ok := err.(*acme.Error) acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest { if !ok || acmeErr.Status != http.StatusBadRequest {
@ -109,13 +104,13 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.AccountDoesNotExistErr(nil)) api.WriteError(w, acme.AccountDoesNotExistErr(nil))
return return
} }
jwk, err := jwkFromContext(r) jwk, err := acme.JwkFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if acc, err = h.Auth.NewAccount(prov, acme.AccountOptions{ if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{
Key: jwk, Key: jwk,
Contact: nar.Contact, Contact: nar.Contact,
}); err != nil { }); err != nil {
@ -127,24 +122,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK httpStatus = http.StatusOK
} }
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
acme.URLSafeProvisionerName(prov), true, acc.GetID())) true, acc.GetID()))
api.JSONStatus(w, acc, httpStatus) api.JSONStatus(w, acc, httpStatus)
} }
// GetUpdateAccount is the api for updating an ACME account. // GetUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) acc, err := acme.AccountFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
acc, err := accountFromContext(r) payload, err := payloadFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -167,16 +157,17 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
// the updates and return 200. This conforms with the behavior detailed // the updates and return 200. This conforms with the behavior detailed
// in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2).
if uar.IsDeactivateRequest() { if uar.IsDeactivateRequest() {
acc, err = h.Auth.DeactivateAccount(prov, acc.GetID()) acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID())
} else if len(uar.Contact) > 0 { } else if len(uar.Contact) > 0 {
acc, err = h.Auth.UpdateAccount(prov, acc.GetID(), uar.Contact) acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact)
} }
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
} }
w.Header().Set("Location", h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, acc.GetID())) w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
true, acc.GetID()))
api.JSON(w, acc) api.JSON(w, acc)
} }
@ -191,23 +182,17 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. // GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) acc, err := acme.AccountFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
acc, err := accountFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
accID := chi.URLParam(r, "accID") accID := chi.URLParam(r, "accID")
if acc.ID != accID { if acc.ID != accID {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param")))
return return
} }
orders, err := h.Auth.GetOrdersByAccount(prov, acc.GetID()) orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
@ -187,33 +188,17 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.Errorf("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: ctx, ctx: ctx,
@ -223,8 +208,8 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "foo"} acc := &acme.Account{ID: "foo"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
@ -235,8 +220,8 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
}, },
"fail/getOrdersByAccount-error": func(t *testing.T) test { "fail/getOrdersByAccount-error": func(t *testing.T) test {
acc := &acme.Account{ID: accID} acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -249,12 +234,14 @@ func TestHandlerGetOrdersByAccount(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID} acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getOrdersByAccount: func(p provisioner.Interface, id string) ([]string, error) { getOrdersByAccount: func(ctx context.Context, id string) ([]string, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, id, acc.ID) assert.Equals(t, id, acc.ID)
return oids, nil return oids, nil
@ -309,8 +296,8 @@ func TestHandlerNewAccount(t *testing.T) {
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
} }
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
url := "https://ca.smallstep.com/acme/new-account" baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -319,31 +306,16 @@ func TestHandlerNewAccount(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")), problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
} }
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -351,8 +323,8 @@ func TestHandlerNewAccount(t *testing.T) {
} }
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -365,8 +337,8 @@ func TestHandlerNewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -379,8 +351,8 @@ func TestHandlerNewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -393,8 +365,8 @@ func TestHandlerNewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -407,9 +379,9 @@ func TestHandlerNewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, acme.JwkContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -424,12 +396,14 @@ func TestHandlerNewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, ops.Contact, nar.Contact) assert.Equals(t, ops.Contact, nar.Contact)
assert.Equals(t, ops.Key, jwk) assert.Equals(t, ops.Key, jwk)
@ -449,24 +423,27 @@ func TestHandlerNewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
newAccount: func(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { newAccount: func(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, ops.Contact, nar.Contact) assert.Equals(t, ops.Contact, nar.Contact)
assert.Equals(t, ops.Key, jwk) assert.Equals(t, ops.Key, jwk)
return &acc, nil return &acc, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{accID}) assert.True(t, abs)
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx))
acme.URLSafeProvisionerName(prov), accID) return fmt.Sprintf("%s/acme/%s/account/%s",
baseURL.String(), provName, accID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -479,18 +456,19 @@ func TestHandlerNewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{accID}) assert.Equals(t, baseURL, acme.BaseURLFromContext(ctx))
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", assert.Equals(t, ins, []string{accID})
acme.URLSafeProvisionerName(prov), accID) return fmt.Sprintf("%s/acme/%s/account/%s",
baseURL.String(), provName, accID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -502,7 +480,7 @@ func TestHandlerNewAccount(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler) h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewAccount(w, req) h.NewAccount(w, req)
@ -529,8 +507,8 @@ func TestHandlerNewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
acme.URLSafeProvisionerName(prov), accID)}) provName, accID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -545,9 +523,8 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
} }
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
// Request with chi context baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("http://ca.smallstep.com/acme/account/%s", accID)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -556,31 +533,16 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
return test{
ctx: ctx,
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -588,8 +550,8 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -597,9 +559,9 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -607,9 +569,9 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -622,9 +584,9 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -637,12 +599,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) { deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
return nil, acme.ServerInternalErr(errors.New("force")) return nil, acme.ServerInternalErr(errors.New("force"))
@ -659,12 +623,14 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) { updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
assert.Equals(t, contacts, uar.Contact) assert.Equals(t, contacts, uar.Contact)
@ -682,23 +648,26 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
deactivateAccount: func(p provisioner.Interface, id string) (*acme.Account, error) { deactivateAccount: func(ctx context.Context, id string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
return &acc, nil return &acc, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{accID}) assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", assert.Equals(t, ins, []string{accID})
acme.URLSafeProvisionerName(prov), accID) return fmt.Sprintf("%s/acme/%s/account/%s",
baseURL.String(), provName, accID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -709,18 +678,19 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
uar := &UpdateAccountRequest{} uar := &UpdateAccountRequest{}
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{accID}) assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", assert.Equals(t, ins, []string{accID})
acme.URLSafeProvisionerName(prov), accID) return fmt.Sprintf("%s/acme/%s/account/%s",
baseURL.String(), provName, accID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -733,24 +703,27 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
updateAccount: func(p provisioner.Interface, id string, contacts []string) (*acme.Account, error) { updateAccount: func(ctx context.Context, id string, contacts []string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
assert.Equals(t, contacts, uar.Contact) assert.Equals(t, contacts, uar.Contact)
return &acc, nil return &acc, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{accID}) assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", assert.Equals(t, ins, []string{accID})
acme.URLSafeProvisionerName(prov), accID) return fmt.Sprintf("%s/acme/%s/account/%s",
baseURL.String(), provName, accID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -758,18 +731,19 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
} }
}, },
"ok/post-as-get": func(t *testing.T) test { "ok/post-as-get": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, acme.AccContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{accID}) assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", assert.Equals(t, ins, []string{accID})
acme.URLSafeProvisionerName(prov), accID) return fmt.Sprintf("%s/acme/%s/account/%s",
baseURL, provName, accID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -781,7 +755,7 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler) h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetUpdateAccount(w, req) h.GetUpdateAccount(w, req)
@ -808,8 +782,8 @@ func TestHandlerGetUpdateAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/%s", []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
acme.URLSafeProvisionerName(prov), accID)}) provName, accID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
@ -8,65 +9,27 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose"
) )
func link(url, typ string) string { func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
} }
type contextKey string
const (
accContextKey = contextKey("acc")
jwsContextKey = contextKey("jws")
jwkContextKey = contextKey("jwk")
payloadContextKey = contextKey("payload")
provisionerContextKey = contextKey("provisioner")
)
type payloadInfo struct { type payloadInfo struct {
value []byte value []byte
isPostAsGet bool isPostAsGet bool
isEmptyJSON bool isEmptyJSON bool
} }
func accountFromContext(r *http.Request) (*acme.Account, error) { // payloadFromContext searches the context for a payload. Returns the payload
val, ok := r.Context().Value(accContextKey).(*acme.Account) // or an error.
if !ok || val == nil { func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
return nil, acme.AccountDoesNotExistErr(nil) val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo)
}
return val, nil
}
func jwkFromContext(r *http.Request) (*jose.JSONWebKey, error) {
val, ok := r.Context().Value(jwkContextKey).(*jose.JSONWebKey)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("jwk expected in request context"))
}
return val, nil
}
func jwsFromContext(r *http.Request) (*jose.JSONWebSignature, error) {
val, ok := r.Context().Value(jwsContextKey).(*jose.JSONWebSignature)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("jws expected in request context"))
}
return val, nil
}
func payloadFromContext(r *http.Request) (*payloadInfo, error) {
val, ok := r.Context().Value(payloadContextKey).(*payloadInfo)
if !ok || val == nil { if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context")) return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
} }
return val, nil return val, nil
} }
func provisionerFromContext(r *http.Request) (provisioner.Interface, error) {
val, ok := r.Context().Value(provisionerContextKey).(provisioner.Interface)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("provisioner expected in request context"))
}
return val, nil
}
// New returns a new ACME API router. // New returns a new ACME API router.
func New(acmeAuth acme.Interface) api.RouterHandler { func New(acmeAuth acme.Interface) api.RouterHandler {
@ -80,29 +43,29 @@ type Handler struct {
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
getLink := h.Auth.GetLink getLink := h.Auth.GetLinkExplicit
// Standard ACME API // Standard ACME API
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce))) r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetNonce))) r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory))) r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false), h.lookupProvisioner(h.addNonce(h.GetDirectory))) r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
} }
extractPayloadByKid := func(next nextHTTP) nextHTTP { extractPayloadByKid := func(next nextHTTP) nextHTTP {
return h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
} }
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false), extractPayloadByJWK(h.NewAccount)) r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount))
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder))
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge))
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
} }
// GetNonce just sets the right header since a Nonce is added to each response // GetNonce just sets the right header since a Nonce is added to each response
@ -118,52 +81,41 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
// GetDirectory is the ACME resource for returning a directory configuration // GetDirectory is the ACME resource for returning a directory configuration
// for client configuration. // for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) dir, err := h.Auth.GetDirectory(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
dir := h.Auth.GetDirectory(prov, baseURLFromRequest(r))
api.JSON(w, dir) api.JSON(w, dir)
} }
// GetAuthz ACME api for retrieving an Authz. // GetAuthz ACME api for retrieving an Authz.
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) acc, err := acme.AccountFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
acc, err := accountFromContext(r) authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID"))
if err != nil {
api.WriteError(w, err)
return
}
authz, err := h.Auth.GetAuthz(prov, acc.GetID(), chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
w.Header().Set("Location", h.Auth.GetLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, authz.GetID())) w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID()))
api.JSON(w, authz) api.JSON(w, authz)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) acc, err := acme.AccountFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
// Just verify that the payload was set, since we're not strictly adhering // Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below. // to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(r) _, err = payloadFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -178,21 +130,20 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
ch *acme.Challenge ch *acme.Challenge
chID = chi.URLParam(r, "chID") chID = chi.URLParam(r, "chID")
) )
ch, err = h.Auth.ValidateChallenge(prov, acc.GetID(), chID, acc.GetKey()) ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
getLink := h.Auth.GetLink w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up"))
w.Header().Add("Link", link(getLink(acme.AuthzLink, acme.URLSafeProvisionerName(prov), true, ch.GetAuthzID()), "up")) w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID()))
w.Header().Set("Location", getLink(acme.ChallengeLink, acme.URLSafeProvisionerName(prov), true, ch.GetID()))
api.JSON(w, ch) api.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
acc, err := accountFromContext(r) acc, err := acme.AccountFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return

View file

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
@ -23,74 +24,79 @@ import (
) )
type mockAcmeAuthority struct { type mockAcmeAuthority struct {
deactivateAccount func(provisioner.Interface, string) (*acme.Account, error) getLink func(ctx context.Context, link acme.Link, absPath bool, ins ...string) string
finalizeOrder func(p provisioner.Interface, accID string, id string, csr *x509.CertificateRequest) (*acme.Order, error) getLinkExplicit func(acme.Link, string, bool, *url.URL, ...string) string
getAccount func(p provisioner.Interface, id string) (*acme.Account, error)
getAccountByKey func(provisioner.Interface, *jose.JSONWebKey) (*acme.Account, error) deactivateAccount func(ctx context.Context, accID string) (*acme.Account, error)
getAuthz func(p provisioner.Interface, accID string, id string) (*acme.Authz, error) getAccount func(ctx context.Context, accID string) (*acme.Account, error)
getCertificate func(accID string, id string) ([]byte, error) getAccountByKey func(ctx context.Context, key *jose.JSONWebKey) (*acme.Account, error)
getChallenge func(p provisioner.Interface, accID string, id string) (*acme.Challenge, error) newAccount func(ctx context.Context, ao acme.AccountOptions) (*acme.Account, error)
getDirectory func(provisioner.Interface, string) *acme.Directory updateAccount func(context.Context, string, []string) (*acme.Account, error)
getLink func(acme.Link, string, bool, ...string) string
getOrder func(p provisioner.Interface, accID string, id string) (*acme.Order, error) getChallenge func(ctx context.Context, accID string, chID string) (*acme.Challenge, error)
getOrdersByAccount func(p provisioner.Interface, id string) ([]string, error) validateChallenge func(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*acme.Challenge, error)
getAuthz func(ctx context.Context, accID string, authzID string) (*acme.Authz, error)
getDirectory func(ctx context.Context) (*acme.Directory, error)
getCertificate func(string, string) ([]byte, error)
finalizeOrder func(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*acme.Order, error)
getOrder func(ctx context.Context, accID string, orderID string) (*acme.Order, error)
getOrdersByAccount func(ctx context.Context, accID string) ([]string, error)
newOrder func(ctx context.Context, oo acme.OrderOptions) (*acme.Order, error)
loadProvisionerByID func(string) (provisioner.Interface, error) loadProvisionerByID func(string) (provisioner.Interface, error)
newAccount func(provisioner.Interface, acme.AccountOptions) (*acme.Account, error)
newNonce func() (string, error) newNonce func() (string, error)
newOrder func(provisioner.Interface, acme.OrderOptions) (*acme.Order, error)
updateAccount func(provisioner.Interface, string, []string) (*acme.Account, error)
useNonce func(string) error useNonce func(string) error
validateChallenge func(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error)
ret1 interface{} ret1 interface{}
err error err error
} }
func (m *mockAcmeAuthority) DeactivateAccount(p provisioner.Interface, id string) (*acme.Account, error) { func (m *mockAcmeAuthority) DeactivateAccount(ctx context.Context, id string) (*acme.Account, error) {
if m.deactivateAccount != nil { if m.deactivateAccount != nil {
return m.deactivateAccount(p, id) return m.deactivateAccount(ctx, id)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Account), m.err return m.ret1.(*acme.Account), m.err
} }
func (m *mockAcmeAuthority) FinalizeOrder(p provisioner.Interface, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) { func (m *mockAcmeAuthority) FinalizeOrder(ctx context.Context, accID, id string, csr *x509.CertificateRequest) (*acme.Order, error) {
if m.finalizeOrder != nil { if m.finalizeOrder != nil {
return m.finalizeOrder(p, accID, id, csr) return m.finalizeOrder(ctx, accID, id, csr)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Order), m.err return m.ret1.(*acme.Order), m.err
} }
func (m *mockAcmeAuthority) GetAccount(p provisioner.Interface, id string) (*acme.Account, error) { func (m *mockAcmeAuthority) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
if m.getAccount != nil { if m.getAccount != nil {
return m.getAccount(p, id) return m.getAccount(ctx, id)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Account), m.err return m.ret1.(*acme.Account), m.err
} }
func (m *mockAcmeAuthority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { func (m *mockAcmeAuthority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
if m.getAccountByKey != nil { if m.getAccountByKey != nil {
return m.getAccountByKey(p, jwk) return m.getAccountByKey(ctx, jwk)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Account), m.err return m.ret1.(*acme.Account), m.err
} }
func (m *mockAcmeAuthority) GetAuthz(p provisioner.Interface, accID, id string) (*acme.Authz, error) { func (m *mockAcmeAuthority) GetAuthz(ctx context.Context, accID, id string) (*acme.Authz, error) {
if m.getAuthz != nil { if m.getAuthz != nil {
return m.getAuthz(p, accID, id) return m.getAuthz(ctx, accID, id)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Authz), m.err return m.ret1.(*acme.Authz), m.err
} }
func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) { func (m *mockAcmeAuthority) GetCertificate(accID string, id string) ([]byte, error) {
if m.getCertificate != nil { if m.getCertificate != nil {
return m.getCertificate(accID, id) return m.getCertificate(accID, id)
} else if m.err != nil { } else if m.err != nil {
@ -99,41 +105,48 @@ func (m *mockAcmeAuthority) GetCertificate(accID, id string) ([]byte, error) {
return m.ret1.([]byte), m.err return m.ret1.([]byte), m.err
} }
func (m *mockAcmeAuthority) GetChallenge(p provisioner.Interface, accID, id string) (*acme.Challenge, error) { func (m *mockAcmeAuthority) GetChallenge(ctx context.Context, accID, id string) (*acme.Challenge, error) {
if m.getChallenge != nil { if m.getChallenge != nil {
return m.getChallenge(p, accID, id) return m.getChallenge(ctx, accID, id)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Challenge), m.err return m.ret1.(*acme.Challenge), m.err
} }
func (m *mockAcmeAuthority) GetDirectory(p provisioner.Interface, baseURLFromRequest string) *acme.Directory { func (m *mockAcmeAuthority) GetDirectory(ctx context.Context) (*acme.Directory, error) {
if m.getDirectory != nil { if m.getDirectory != nil {
return m.getDirectory(p, baseURLFromRequest) return m.getDirectory(ctx)
} }
return m.ret1.(*acme.Directory) return m.ret1.(*acme.Directory), m.err
} }
func (m *mockAcmeAuthority) GetLink(typ acme.Link, provID string, abs bool, in ...string) string { func (m *mockAcmeAuthority) GetLink(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
if m.getLink != nil { if m.getLink != nil {
return m.getLink(typ, provID, abs, in...) return m.getLink(ctx, typ, abs, ins...)
} }
return m.ret1.(string) return m.ret1.(string)
} }
func (m *mockAcmeAuthority) GetOrder(p provisioner.Interface, accID, id string) (*acme.Order, error) { func (m *mockAcmeAuthority) GetLinkExplicit(typ acme.Link, provID string, abs bool, baseURL *url.URL, ins ...string) string {
if m.getLinkExplicit != nil {
return m.getLinkExplicit(typ, provID, abs, baseURL, ins...)
}
return m.ret1.(string)
}
func (m *mockAcmeAuthority) GetOrder(ctx context.Context, accID, id string) (*acme.Order, error) {
if m.getOrder != nil { if m.getOrder != nil {
return m.getOrder(p, accID, id) return m.getOrder(ctx, accID, id)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Order), m.err return m.ret1.(*acme.Order), m.err
} }
func (m *mockAcmeAuthority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) { func (m *mockAcmeAuthority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
if m.getOrdersByAccount != nil { if m.getOrdersByAccount != nil {
return m.getOrdersByAccount(p, id) return m.getOrdersByAccount(ctx, id)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
@ -149,9 +162,9 @@ func (m *mockAcmeAuthority) LoadProvisionerByID(provID string) (provisioner.Inte
return m.ret1.(provisioner.Interface), m.err return m.ret1.(provisioner.Interface), m.err
} }
func (m *mockAcmeAuthority) NewAccount(p provisioner.Interface, ops acme.AccountOptions) (*acme.Account, error) { func (m *mockAcmeAuthority) NewAccount(ctx context.Context, ops acme.AccountOptions) (*acme.Account, error) {
if m.newAccount != nil { if m.newAccount != nil {
return m.newAccount(p, ops) return m.newAccount(ctx, ops)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
@ -167,18 +180,18 @@ func (m *mockAcmeAuthority) NewNonce() (string, error) {
return m.ret1.(string), m.err return m.ret1.(string), m.err
} }
func (m *mockAcmeAuthority) NewOrder(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { func (m *mockAcmeAuthority) NewOrder(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
if m.newOrder != nil { if m.newOrder != nil {
return m.newOrder(p, ops) return m.newOrder(ctx, ops)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
return m.ret1.(*acme.Order), m.err return m.ret1.(*acme.Order), m.err
} }
func (m *mockAcmeAuthority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*acme.Account, error) { func (m *mockAcmeAuthority) UpdateAccount(ctx context.Context, id string, contact []string) (*acme.Account, error) {
if m.updateAccount != nil { if m.updateAccount != nil {
return m.updateAccount(p, id, contact) return m.updateAccount(ctx, id, contact)
} else if m.err != nil { } else if m.err != nil {
return nil, m.err return nil, m.err
} }
@ -192,10 +205,10 @@ func (m *mockAcmeAuthority) UseNonce(nonce string) error {
return m.err return m.err
} }
func (m *mockAcmeAuthority) ValidateChallenge(p provisioner.Interface, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { func (m *mockAcmeAuthority) ValidateChallenge(ctx context.Context, accID string, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
switch { switch {
case m.validateChallenge != nil: case m.validateChallenge != nil:
return m.validateChallenge(p, accID, id, jwk) return m.validateChallenge(ctx, accID, id, jwk)
case m.err != nil: case m.err != nil:
return nil, m.err return nil, m.err
default: default:
@ -233,40 +246,28 @@ func TestHandlerGetNonce(t *testing.T) {
func TestHandlerGetDirectory(t *testing.T) { func TestHandlerGetDirectory(t *testing.T) {
auth, err := acme.NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) auth, err := acme.NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newProv() prov := newProv()
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/directory", acme.URLSafeProvisionerName(prov)) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
expDir := acme.Directory{ expDir := acme.Directory{
NewNonce: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", acme.URLSafeProvisionerName(prov)), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", acme.URLSafeProvisionerName(prov)), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
NewOrder: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", acme.URLSafeProvisionerName(prov)), NewOrder: fmt.Sprintf("%s/acme/%s/new-order", baseURL.String(), provName),
RevokeCert: fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", acme.URLSafeProvisionerName(prov)), RevokeCert: fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), provName),
KeyChange: fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", acme.URLSafeProvisionerName(prov)), KeyChange: fmt.Sprintf("%s/acme/%s/key-change", baseURL.String(), provName),
} }
type test struct { type test struct {
ctx context.Context
statusCode int statusCode int
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
statusCode: 200, statusCode: 200,
} }
}, },
@ -275,9 +276,8 @@ func TestHandlerGetDirectory(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(auth).(*Handler) h := New(auth).(*Handler)
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req.Header.Add("X-Forwarded-Proto", "https") req = req.WithContext(ctx)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetDirectory(w, req) h.GetDirectory(w, req)
res := w.Result() res := w.Result()
@ -339,12 +339,14 @@ func TestHandlerGetAuthz(t *testing.T) {
}, },
} }
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("authzID", az.ID) chiCtx.URLParams.Add("authzID", az.ID)
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/challenge/%s", url := fmt.Sprintf("%s/acme/%s/challenge/%s",
acme.URLSafeProvisionerName(prov), az.ID) baseURL.String(), provName, az.ID)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -353,33 +355,17 @@ func TestHandlerGetAuthz(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: ctx, ctx: ctx,
@ -389,8 +375,8 @@ func TestHandlerGetAuthz(t *testing.T) {
}, },
"fail/getAuthz-error": func(t *testing.T) test { "fail/getAuthz-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -403,20 +389,23 @@ func TestHandlerGetAuthz(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAuthz: func(p provisioner.Interface, accID, id string) (*acme.Authz, error) { getAuthz: func(ctx context.Context, accID, id string) (*acme.Authz, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, az.ID) assert.Equals(t, id, az.ID)
return &az, nil return &az, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, typ, acme.AuthzLink) assert.Equals(t, typ, acme.AuthzLink)
assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{az.ID}) assert.Equals(t, in, []string{az.ID})
return url return url
@ -431,7 +420,7 @@ func TestHandlerGetAuthz(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler) h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAuthz(w, req) h.GetAuthz(w, req)
@ -488,11 +477,13 @@ func TestHandlerGetCertificate(t *testing.T) {
certID := "certID" certID := "certID"
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("certID", certID) chiCtx.URLParams.Add("certID", certID)
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/certificate/%s", url := fmt.Sprintf("%s/acme/%s/certificate/%s",
acme.URLSafeProvisionerName(prov), certID) baseURL.String(), provName, certID)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -504,13 +495,13 @@ func TestHandlerGetCertificate(t *testing.T) {
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, nil) ctx := context.WithValue(context.Background(), acme.AccContextKey, nil)
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: ctx, ctx: ctx,
@ -520,7 +511,7 @@ func TestHandlerGetCertificate(t *testing.T) {
}, },
"fail/getCertificate-error": func(t *testing.T) test { "fail/getCertificate-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -533,7 +524,7 @@ func TestHandlerGetCertificate(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -596,8 +587,10 @@ func ch() acme.Challenge {
func TestHandlerGetChallenge(t *testing.T) { func TestHandlerGetChallenge(t *testing.T) {
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("chID", "chID") chiCtx.URLParams.Add("chID", "chID")
url := fmt.Sprintf("http://ca.smallstep.com/acme/challenge/%s", "chID")
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/challenge/%s", baseURL, "chID")
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -607,30 +600,16 @@ func TestHandlerGetChallenge(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -639,8 +618,8 @@ func TestHandlerGetChallenge(t *testing.T) {
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -649,9 +628,9 @@ func TestHandlerGetChallenge(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -660,9 +639,9 @@ func TestHandlerGetChallenge(t *testing.T) {
}, },
"fail/validate-challenge-error": func(t *testing.T) test { "fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -675,9 +654,9 @@ func TestHandlerGetChallenge(t *testing.T) {
}, },
"fail/get-challenge-error": func(t *testing.T) test { "fail/get-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -692,35 +671,36 @@ func TestHandlerGetChallenge(t *testing.T) {
key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) key, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: key} acc := &acme.Account{ID: "accID", Key: key}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
ch := ch() ch := ch()
ch.Status = "valid" ch.Status = "valid"
ch.Validated = time.Now().UTC().Format(time.RFC3339) ch.Validated = time.Now().UTC().Format(time.RFC3339)
count := 0 count := 0
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
validateChallenge: func(p provisioner.Interface, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) { validateChallenge: func(ctx context.Context, accID, id string, jwk *jose.JSONWebKey) (*acme.Challenge, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, ch.ID) assert.Equals(t, id, ch.ID)
assert.Equals(t, jwk.KeyID, key.KeyID) assert.Equals(t, jwk.KeyID, key.KeyID)
return &ch, nil return &ch, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
var ret string var ret string
switch count { switch count {
case 0: case 0:
assert.Equals(t, typ, acme.AuthzLink) assert.Equals(t, typ, acme.AuthzLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{ch.AuthzID}) assert.Equals(t, in, []string{ch.AuthzID})
ret = fmt.Sprintf("https://ca.smallstep.com/acme/authz/%s", ch.AuthzID) ret = fmt.Sprintf("%s/acme/%s/authz/%s", baseURL.String(), provName, ch.AuthzID)
case 1: case 1:
assert.Equals(t, typ, acme.ChallengeLink) assert.Equals(t, typ, acme.ChallengeLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{ch.ID}) assert.Equals(t, in, []string{ch.ID})
ret = url ret = url
@ -765,7 +745,7 @@ func TestHandlerGetChallenge(t *testing.T) {
expB, err := json.Marshal(tc.ch) expB, err := json.Marshal(tc.ch)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<https://ca.smallstep.com/acme/authz/%s>;rel=\"up\"", tc.ch.AuthzID)}) assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, tc.ch.AuthzID)})
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{url})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }

View file

@ -1,28 +0,0 @@
package api
import (
"net/http"
)
// baseURLFromRequest determines the base URL which should be used for constructing link URLs in e.g. the ACME directory
// result by taking the request Host, TLS and Header[X-Forwarded-Proto] values into consideration.
// If the Request.Host is an empty string, we return an empty string, to indicate that the configured
// URL values should be used instead.
// If this function returns a non-empty result, then this should be used in constructing ACME link URLs.
func baseURLFromRequest(r *http.Request) string {
// TODO: I semantically copied the functionality of determining the protol from boulder web/relative.go
// which allows HTTP. Previously this was always forced to be HTTPS for absolute URLs. Should this be
// changed to also always force HTTPS protocol?
proto := "http"
if specifiedProto := r.Header.Get("X-Forwarded-Proto"); specifiedProto != "" {
proto = specifiedProto
} else if r.TLS != nil {
proto += "s"
}
host := r.Host
if host == "" {
return ""
}
return proto + "://" + host
}

View file

@ -1,70 +0,0 @@
package api
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestGetBaseUrl(t *testing.T) {
tests := []struct {
testFailedDescription string
targetURL string
expectedResult string
requestPreparer func(*http.Request)
}{
{
"HTTP host pass-through failed.",
"http://my.dummy.host",
"http://my.dummy.host",
nil,
},
{
"HTTPS host pass-through failed.",
"https://my.dummy.host",
"https://my.dummy.host",
nil,
},
{
"Port pass-through failed",
"http://host.with.port:8080",
"http://host.with.port:8080",
nil,
},
{
"Explicit host from Request.Host was not used.",
"http://some.target.host:8080",
"http://proxied.host",
func(r *http.Request) {
r.Host = "proxied.host"
},
},
{
"Explicit forwarded protocol from request header X-Forwarded-Proto was not used.",
"http://some.host",
"ssl://some.host",
func(r *http.Request) {
r.Header.Add("X-Forwarded-Proto", "ssl")
},
},
{
"Missing Request.Host value did not result in empty string result.",
"http://some.host",
"",
func(r *http.Request) {
r.Host = ""
},
},
}
for _, test := range tests {
request := httptest.NewRequest("GET", test.targetURL, nil)
if test.requestPreparer != nil {
test.requestPreparer(request)
}
result := baseURLFromRequest(request)
if result != test.expectedResult {
t.Errorf("Expected %q, but got %q", test.expectedResult, result)
}
}
}

View file

@ -30,6 +30,35 @@ func logNonce(w http.ResponseWriter, nonce string) {
} }
} }
// baseURLFromRequest determines the base URL which should be used for
// constructing link URLs in e.g. the ACME directory result by taking the
// request Host into consideration.
//
// If the Request.Host is an empty string, we return an empty string, to
// indicate that the configured URL values should be used instead. If this
// function returns a non-empty result, then this should be used in
// constructing ACME link URLs.
func baseURLFromRequest(r *http.Request) *url.URL {
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
// for an implementation that allows HTTP requests using the x-forwarded-proto
// header.
if r.Host == "" {
return nil
}
return &url.URL{Scheme: "https", Host: r.Host}
}
// baseURLFromRequest is a middleware that extracts and caches the baseURL
// from the request.
// E.g. https://ca.smallstep.com/
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
// addNonce is a middleware that adds a nonce to the response header. // addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP { func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
@ -49,12 +78,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// directory index url. // directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP { func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) w.Header().Add("Link", link(h.Auth.GetLink(r.Context(),
if err != nil { acme.DirectoryLink, true), "index"))
api.WriteError(w, err)
return
}
w.Header().Add("Link", link(h.Auth.GetLink(acme.DirectoryLink, acme.URLSafeProvisionerName(prov), true), "index"))
next(w, r) next(w, r)
} }
} }
@ -63,14 +88,9 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
// application/jose+json. // application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r)
if err != nil {
api.WriteError(w, err)
return
}
ct := r.Header.Get("Content-Type") ct := r.Header.Get("Content-Type")
var expected []string var expected []string
if strings.Contains(r.URL.Path, h.Auth.GetLink(acme.CertificateLink, acme.URLSafeProvisionerName(prov), false, "")) { if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
} else { } else {
@ -101,7 +121,7 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body")))
return return
} }
ctx := context.WithValue(r.Context(), jwsContextKey, jws) ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
@ -123,7 +143,7 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste> // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
jws, err := jwsFromContext(r) jws, err := acme.JwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -207,12 +227,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
func (h *Handler) extractJWK(next nextHTTP) nextHTTP { func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
prov, err := provisionerFromContext(r) jws, err := acme.JwsFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
jws, err := jwsFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -226,8 +241,8 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header")))
return return
} }
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, acme.JwkContextKey, jwk)
acc, err := h.Auth.GetAccountByKey(prov, jwk) acc, err := h.Auth.GetAccountByKey(ctx, jwk)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
// For NewAccount requests ... // For NewAccount requests ...
@ -240,7 +255,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
} }
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
@ -267,7 +282,7 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME")))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, p) ctx = context.WithValue(ctx, acme.ProvisionerContextKey, p)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
@ -278,18 +293,13 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
prov, err := provisionerFromContext(r) jws, err := acme.JwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
jws, err := jwsFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
kidPrefix := h.Auth.GetLink(acme.AccountLink, acme.URLSafeProvisionerName(prov), true, "") kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+
@ -298,7 +308,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
} }
accID := strings.TrimPrefix(kid, kidPrefix) accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := h.Auth.GetAccount(prov, accID) acc, err := h.Auth.GetAccount(r.Context(), accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
api.WriteError(w, acme.AccountDoesNotExistErr(nil)) api.WriteError(w, acme.AccountDoesNotExistErr(nil))
@ -311,8 +321,8 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active")))
return return
} }
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, jwkContextKey, acc.Key) ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
return return
} }
@ -323,12 +333,12 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
jws, err := jwsFromContext(r) jws, err := acme.JwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
jwk, err := jwkFromContext(r) jwk, err := acme.JwkFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -342,7 +352,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws")))
return return
} }
ctx := context.WithValue(r.Context(), payloadContextKey, &payloadInfo{ ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{
value: payload, value: payload,
isPostAsGet: string(payload) == "", isPostAsGet: string(payload) == "",
isEmptyJSON: string(payload) == "{}", isEmptyJSON: string(payload) == "{}",
@ -354,7 +364,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return

View file

@ -11,13 +11,13 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"strings" "strings"
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
) )
@ -28,6 +28,85 @@ func testNext(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
} }
func Test_baseURLFromRequest(t *testing.T) {
tests := []struct {
name string
targetURL string
expectedResult *url.URL
requestPreparer func(*http.Request)
}{
{
"HTTPS host pass-through failed.",
"https://my.dummy.host",
&url.URL{Scheme: "https", Host: "my.dummy.host"},
nil,
},
{
"Port pass-through failed",
"https://host.with.port:8080",
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
nil,
},
{
"Explicit host from Request.Host was not used.",
"https://some.target.host:8080",
&url.URL{Scheme: "https", Host: "proxied.host"},
func(r *http.Request) {
r.Host = "proxied.host"
},
},
{
"Missing Request.Host value did not result in empty string result.",
"https://some.host",
nil,
func(r *http.Request) {
r.Host = ""
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest("GET", tc.targetURL, nil)
if tc.requestPreparer != nil {
tc.requestPreparer(request)
}
result := baseURLFromRequest(request)
if result == nil || tc.expectedResult == nil {
assert.Equals(t, result, tc.expectedResult)
} else if result.String() != tc.expectedResult.String() {
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
}
})
}
}
func TestHandlerBaseURLFromRequest(t *testing.T) {
h := New(&mockAcmeAuthority{}).(*Handler)
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
next := func(w http.ResponseWriter, r *http.Request) {
bu := acme.BaseURLFromContext(r.Context())
if assert.NotNil(t, bu) {
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
assert.Equals(t, bu.Scheme, "https")
}
}
h.baseURLFromRequest(next)(w, req)
req = httptest.NewRequest("GET", "/foo", nil)
req.Host = ""
next = func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, acme.BaseURLFromContext(r.Context()), nil)
}
h.baseURLFromRequest(next)(w, req)
}
func TestHandlerAddNonce(t *testing.T) { func TestHandlerAddNonce(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-nonce" url := "https://ca.smallstep.com/acme/new-nonce"
type test struct { type test struct {
@ -93,8 +172,9 @@ func TestHandlerAddNonce(t *testing.T) {
} }
func TestHandlerAddDirLink(t *testing.T) { func TestHandlerAddDirLink(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-nonce"
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
auth acme.Interface auth acme.Interface
link string link string
@ -103,33 +183,18 @@ func TestHandlerAddDirLink(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
link := "https://ca.smallstep.com/acme/directory" ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, ins ...string) string {
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov)) assert.Equals(t, acme.BaseURLFromContext(ctx), baseURL)
return link return fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName)
}, },
}, },
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: ctx,
link: link, link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
statusCode: 200, statusCode: 200,
} }
}, },
@ -138,7 +203,7 @@ func TestHandlerAddDirLink(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*Handler) h := New(tc.auth).(*Handler)
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addDirLink(testNext)(w, req) h.addDirLink(testNext)(w, req)
@ -170,8 +235,9 @@ func TestHandlerAddDirLink(t *testing.T) {
func TestHandlerVerifyContentType(t *testing.T) { func TestHandlerVerifyContentType(t *testing.T) {
prov := newProv() prov := newProv()
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/abc123", provName := prov.GetName()
acme.URLSafeProvisionerName(prov)) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), provName)
type test struct { type test struct {
h Handler h Handler
ctx context.Context ctx context.Context
@ -181,38 +247,20 @@ func TestHandlerVerifyContentType(t *testing.T) {
url string url string
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
h: Handler{Auth: &mockAcmeAuthority{}},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
h: Handler{Auth: &mockAcmeAuthority{}},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/general-bad-content-type": func(t *testing.T) test { "fail/general-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{ h: Handler{
Auth: &mockAcmeAuthority{ Auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, typ, acme.CertificateLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, abs, false) assert.Equals(t, abs, false)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return "/certificate/" return fmt.Sprintf("/acme/%s/certificate/", provName)
}, },
}, },
}, },
url: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
acme.URLSafeProvisionerName(prov)), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")), problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json], but got foo")),
@ -222,16 +270,15 @@ func TestHandlerVerifyContentType(t *testing.T) {
return test{ return test{
h: Handler{ h: Handler{
Auth: &mockAcmeAuthority{ Auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, typ, acme.CertificateLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, abs, false) assert.Equals(t, abs, false)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return "/certificate/" return "/certificate/"
}, },
}, },
}, },
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")), problem: acme.MalformedErr(errors.New("expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo")),
@ -241,16 +288,15 @@ func TestHandlerVerifyContentType(t *testing.T) {
return test{ return test{
h: Handler{ h: Handler{
Auth: &mockAcmeAuthority{ Auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, typ, acme.CertificateLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, abs, false) assert.Equals(t, abs, false)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return "/certificate/" return "/certificate/"
}, },
}, },
}, },
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
@ -259,16 +305,15 @@ func TestHandlerVerifyContentType(t *testing.T) {
return test{ return test{
h: Handler{ h: Handler{
Auth: &mockAcmeAuthority{ Auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, typ, acme.CertificateLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, abs, false) assert.Equals(t, abs, false)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return "/certificate/" return "/certificate/"
}, },
}, },
}, },
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
contentType: "application/pkix-cert", contentType: "application/pkix-cert",
statusCode: 200, statusCode: 200,
} }
@ -277,16 +322,15 @@ func TestHandlerVerifyContentType(t *testing.T) {
return test{ return test{
h: Handler{ h: Handler{
Auth: &mockAcmeAuthority{ Auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, typ, acme.CertificateLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, abs, false) assert.Equals(t, abs, false)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return "/certificate/" return "/certificate/"
}, },
}, },
}, },
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
@ -295,16 +339,15 @@ func TestHandlerVerifyContentType(t *testing.T) {
return test{ return test{
h: Handler{ h: Handler{
Auth: &mockAcmeAuthority{ Auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.CertificateLink) assert.Equals(t, typ, acme.CertificateLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.Equals(t, abs, false) assert.Equals(t, abs, false)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return "/certificate/" return "/certificate/"
}, },
}, },
}, },
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
contentType: "application/pkcs7-mime", contentType: "application/pkcs7-mime",
statusCode: 200, statusCode: 200,
} }
@ -364,21 +407,21 @@ func TestHandlerIsPostAsGet(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), payloadContextKey, nil), ctx: context.WithValue(context.Background(), acme.PayloadContextKey, nil),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("payload expected in request context")), problem: acme.ServerInternalErr(errors.New("payload expected in request context")),
} }
}, },
"fail/not-post-as-get": func(t *testing.T) test { "fail/not-post-as-get": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}), ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{}),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("expected POST-as-GET")), problem: acme.MalformedErr(errors.New("expected POST-as-GET")),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), payloadContextKey, &payloadInfo{isPostAsGet: true}), ctx: context.WithValue(context.Background(), acme.PayloadContextKey, &payloadInfo{isPostAsGet: true}),
statusCode: 200, statusCode: 200,
} }
}, },
@ -464,7 +507,7 @@ func TestHandlerParseJWS(t *testing.T) {
return test{ return test{
body: strings.NewReader(expRaw), body: strings.NewReader(expRaw),
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
jws, err := jwsFromContext(r) jws, err := acme.JwsFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
gotRaw, err := jws.CompactSerialize() gotRaw, err := jws.CompactSerialize()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -542,22 +585,22 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, nil), ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("jws expected in request context")), problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
} }
}, },
"fail/no-jwk": func(t *testing.T) test { "fail/no-jwk": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), problem: acme.ServerInternalErr(errors.New("jwk expected in request context")),
} }
}, },
"fail/nil-jwk": func(t *testing.T) test { "fail/nil-jwk": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
return test{ return test{
ctx: context.WithValue(ctx, jwkContextKey, nil), ctx: context.WithValue(ctx, acme.JwkContextKey, nil),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("jwk expected in request context")), problem: acme.ServerInternalErr(errors.New("jwk expected in request context")),
} }
@ -566,8 +609,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
_pub := _jwk.Public() _pub := _jwk.Public()
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, acme.JwkContextKey, &_pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -578,8 +621,8 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
_pub := *pub _pub := *pub
clone := &_pub clone := &_pub
clone.Algorithm = jose.HS256 clone.Algorithm = jose.HS256
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwkContextKey, clone) ctx = context.WithValue(ctx, acme.JwkContextKey, clone)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -587,13 +630,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwkContextKey, pub) ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
p, err := payloadFromContext(r) p, err := payloadFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
if assert.NotNil(t, p) { if assert.NotNil(t, p) {
assert.Equals(t, p.value, []byte("baz")) assert.Equals(t, p.value, []byte("baz"))
@ -608,13 +651,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
_pub := *pub _pub := *pub
clone := &_pub clone := &_pub
clone.Algorithm = "" clone.Algorithm = ""
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwkContextKey, pub) ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
p, err := payloadFromContext(r) p, err := payloadFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
if assert.NotNil(t, p) { if assert.NotNil(t, p) {
assert.Equals(t, p.value, []byte("baz")) assert.Equals(t, p.value, []byte("baz"))
@ -632,13 +675,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed)
ctx = context.WithValue(ctx, jwkContextKey, pub) ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
p, err := payloadFromContext(r) p, err := payloadFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
if assert.NotNil(t, p) { if assert.NotNil(t, p) {
assert.Equals(t, p.value, []byte{}) assert.Equals(t, p.value, []byte{})
@ -656,13 +699,13 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) ctx := context.WithValue(context.Background(), acme.JwsContextKey, _parsed)
ctx = context.WithValue(ctx, jwkContextKey, pub) ctx = context.WithValue(ctx, acme.JwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
p, err := payloadFromContext(r) p, err := payloadFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
if assert.NotNil(t, p) { if assert.NotNil(t, p) {
assert.Equals(t, p.value, []byte("{}")) assert.Equals(t, p.value, []byte("{}"))
@ -709,13 +752,15 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
func TestHandlerLookupJWK(t *testing.T) { func TestHandlerLookupJWK(t *testing.T) {
prov := newProv() prov := newProv()
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provName := url.PathEscape(prov.GetName())
acme.URLSafeProvisionerName(prov)) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
accID := "account-id" accID := "account-id"
prefix := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", prefix := fmt.Sprintf("%s/acme/%s/account/",
acme.URLSafeProvisionerName(prov)) baseURL, provName)
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID))
signer, err := jose.NewSigner(jose.SigningKey{ signer, err := jose.NewSigner(jose.SigningKey{
@ -737,30 +782,16 @@ func TestHandlerLookupJWK(t *testing.T) {
statusCode int statusCode int
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("jws expected in request context")), problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, acme.JwsContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -775,13 +806,13 @@ func TestHandlerLookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz")) _jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, acme.JwsContextKey, _jws)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return prefix return prefix
@ -806,16 +837,16 @@ func TestHandlerLookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, _parsed) ctx = context.WithValue(ctx, acme.JwsContextKey, _parsed)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -824,21 +855,23 @@ func TestHandlerLookupJWK(t *testing.T) {
} }
}, },
"fail/account-not-found": func(t *testing.T) test { "fail/account-not-found": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
return nil, database.ErrNotFound return nil, database.ErrNotFound
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -847,21 +880,23 @@ func TestHandlerLookupJWK(t *testing.T) {
} }
}, },
"fail/GetAccount-error": func(t *testing.T) test { "fail/GetAccount-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
return nil, acme.ServerInternalErr(errors.New("force")) return nil, acme.ServerInternalErr(errors.New("force"))
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -871,21 +906,23 @@ func TestHandlerLookupJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
return acc, nil return acc, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -895,29 +932,31 @@ func TestHandlerLookupJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk} acc := &acme.Account{Status: "valid", Key: jwk}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccount: func(p provisioner.Interface, _accID string) (*acme.Account, error) { getAccount: func(ctx context.Context, _accID string) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
return acc, nil return acc, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.AccountLink) assert.Equals(t, typ, acme.AccountLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{""}) assert.Equals(t, in, []string{""})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/", acme.URLSafeProvisionerName(prov)) return fmt.Sprintf("%s/acme/%s/account/", baseURL.String(), provName)
}, },
}, },
ctx: ctx, ctx: ctx,
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
_acc, err := accountFromContext(r) _acc, err := acme.AccountFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, _acc, acc) assert.Equals(t, _acc, acc)
_jwk, err := jwkFromContext(r) _jwk, err := acme.JwkFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, _jwk, jwk) assert.Equals(t, _jwk, jwk)
w.Write(testBody) w.Write(testBody)
@ -961,6 +1000,7 @@ func TestHandlerLookupJWK(t *testing.T) {
func TestHandlerExtractJWK(t *testing.T) { func TestHandlerExtractJWK(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
kid, err := jwk.Thumbprint(crypto.SHA256) kid, err := jwk.Thumbprint(crypto.SHA256)
@ -982,7 +1022,7 @@ func TestHandlerExtractJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
acme.URLSafeProvisionerName(prov)) provName)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
ctx context.Context ctx context.Context
@ -991,30 +1031,16 @@ func TestHandlerExtractJWK(t *testing.T) {
statusCode int statusCode int
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("jws expected in request context")), problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, acme.JwsContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -1031,8 +1057,8 @@ func TestHandlerExtractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, acme.JwsContextKey, _jws)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -1049,8 +1075,8 @@ func TestHandlerExtractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, acme.JwsContextKey, _jws)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -1058,12 +1084,14 @@ func TestHandlerExtractJWK(t *testing.T) {
} }
}, },
"fail/GetAccountByKey-error": func(t *testing.T) test { "fail/GetAccountByKey-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, jwk.KeyID, pub.KeyID) assert.Equals(t, jwk.KeyID, pub.KeyID)
return nil, acme.ServerInternalErr(errors.New("force")) return nil, acme.ServerInternalErr(errors.New("force"))
@ -1075,12 +1103,14 @@ func TestHandlerExtractJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, jwk.KeyID, pub.KeyID) assert.Equals(t, jwk.KeyID, pub.KeyID)
return acc, nil return acc, nil
@ -1092,22 +1122,24 @@ func TestHandlerExtractJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid"} acc := &acme.Account{Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, jwk.KeyID, pub.KeyID) assert.Equals(t, jwk.KeyID, pub.KeyID)
return acc, nil return acc, nil
}, },
}, },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
_acc, err := accountFromContext(r) _acc, err := acme.AccountFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, _acc, acc) assert.Equals(t, _acc, acc)
_jwk, err := jwkFromContext(r) _jwk, err := acme.JwkFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, _jwk.KeyID, pub.KeyID) assert.Equals(t, _jwk.KeyID, pub.KeyID)
w.Write(testBody) w.Write(testBody)
@ -1116,22 +1148,24 @@ func TestHandlerExtractJWK(t *testing.T) {
} }
}, },
"ok/no-account": func(t *testing.T) test { "ok/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, acme.JwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getAccountByKey: func(p provisioner.Interface, jwk *jose.JSONWebKey) (*acme.Account, error) { getAccountByKey: func(ctx context.Context, jwk *jose.JSONWebKey) (*acme.Account, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, jwk.KeyID, pub.KeyID) assert.Equals(t, jwk.KeyID, pub.KeyID)
return nil, database.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
_acc, err := accountFromContext(r) _acc, err := acme.AccountFromContext(r.Context())
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Nil(t, _acc) assert.Nil(t, _acc)
_jwk, err := jwkFromContext(r) _jwk, err := acme.JwkFromContext(r.Context())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, _jwk.KeyID, pub.KeyID) assert.Equals(t, _jwk.KeyID, pub.KeyID)
w.Write(testBody) w.Write(testBody)
@ -1192,14 +1226,14 @@ func TestHandlerValidateJWS(t *testing.T) {
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, nil), ctx: context.WithValue(context.Background(), acme.JwsContextKey, nil),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("jws expected in request context")), problem: acme.ServerInternalErr(errors.New("jws expected in request context")),
} }
}, },
"fail/no-signature": func(t *testing.T) test { "fail/no-signature": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), ctx: context.WithValue(context.Background(), acme.JwsContextKey, &jose.JSONWebSignature{}),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("request body does not contain a signature")), problem: acme.MalformedErr(errors.New("request body does not contain a signature")),
} }
@ -1212,7 +1246,7 @@ func TestHandlerValidateJWS(t *testing.T) {
}, },
} }
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("request body contains more than one signature")), problem: acme.MalformedErr(errors.New("request body contains more than one signature")),
} }
@ -1224,7 +1258,7 @@ func TestHandlerValidateJWS(t *testing.T) {
}, },
} }
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("unprotected header must not be used")), problem: acme.MalformedErr(errors.New("unprotected header must not be used")),
} }
@ -1236,7 +1270,7 @@ func TestHandlerValidateJWS(t *testing.T) {
}, },
} }
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")), problem: acme.MalformedErr(errors.New("unsuitable algorithm: none")),
} }
@ -1248,7 +1282,7 @@ func TestHandlerValidateJWS(t *testing.T) {
}, },
} }
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)), problem: acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", jose.HS256)),
} }
@ -1276,7 +1310,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")), problem: acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match")),
} }
@ -1304,7 +1338,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")), problem: acme.MalformedErr(errors.Errorf("rsa keys must be at least 2048 bits (256 bytes) in size")),
} }
@ -1321,7 +1355,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return acme.ServerInternalErr(errors.New("force")) return acme.ServerInternalErr(errors.New("force"))
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 500, statusCode: 500,
problem: acme.ServerInternalErr(errors.New("force")), problem: acme.ServerInternalErr(errors.New("force")),
} }
@ -1338,7 +1372,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.New("jws missing url protected header")), problem: acme.MalformedErr(errors.New("jws missing url protected header")),
} }
@ -1362,7 +1396,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)), problem: acme.MalformedErr(errors.Errorf("url header in JWS (foo) does not match request url (%s)", url)),
} }
@ -1391,7 +1425,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")), problem: acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive")),
} }
@ -1415,7 +1449,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
statusCode: 400, statusCode: 400,
problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")), problem: acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header")),
} }
@ -1440,7 +1474,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
}, },
@ -1470,7 +1504,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
}, },
@ -1500,7 +1534,7 @@ func TestHandlerValidateJWS(t *testing.T) {
return nil return nil
}, },
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), acme.JwsContextKey, jws),
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
}, },

View file

@ -58,17 +58,13 @@ func (f *FinalizeRequest) Validate() error {
// NewOrder ACME api for creating a new order. // NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) ctx := r.Context()
acc, err := acme.AccountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
acc, err := accountFromContext(r) payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -84,7 +80,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
o, err := h.Auth.NewOrder(prov, acme.OrderOptions{ o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{
AccountID: acc.GetID(), AccountID: acc.GetID(),
Identifiers: nor.Identifiers, Identifiers: nor.Identifiers,
NotBefore: nor.NotBefore, NotBefore: nor.NotBefore,
@ -95,46 +91,38 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID())) w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID()))
api.JSONStatus(w, o, http.StatusCreated) api.JSONStatus(w, o, http.StatusCreated)
} }
// GetOrder ACME api for retrieving an order. // GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) ctx := r.Context()
if err != nil { acc, err := acme.AccountFromContext(ctx)
api.WriteError(w, err)
return
}
acc, err := accountFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
oid := chi.URLParam(r, "ordID") oid := chi.URLParam(r, "ordID")
o, err := h.Auth.GetOrder(prov, acc.GetID(), oid) o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.GetID())) w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID()))
api.JSON(w, o) api.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attemptst to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
prov, err := provisionerFromContext(r) ctx := r.Context()
acc, err := acme.AccountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
acc, err := accountFromContext(r) payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(r)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -150,12 +138,12 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
} }
oid := chi.URLParam(r, "ordID") oid := chi.URLParam(r, "ordID")
o, err := h.Auth.FinalizeOrder(prov, acc.GetID(), oid, fr.csr) o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
w.Header().Set("Location", h.Auth.GetLink(acme.OrderLink, acme.URLSafeProvisionerName(prov), true, o.ID)) w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID))
api.JSON(w, o) api.JSON(w, o)
} }

View file

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
@ -16,7 +17,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/pemutil" "github.com/smallstep/cli/crypto/pemutil"
) )
@ -175,8 +175,10 @@ func TestHandlerGetOrder(t *testing.T) {
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
prov := newProv() prov := newProv()
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s", provName := url.PathEscape(prov.GetName())
acme.URLSafeProvisionerName(prov), o.ID) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), provName, o.ID)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -185,33 +187,17 @@ func TestHandlerGetOrder(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: ctx, ctx: ctx,
@ -221,8 +207,8 @@ func TestHandlerGetOrder(t *testing.T) {
}, },
"fail/getOrder-error": func(t *testing.T) test { "fail/getOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
@ -235,20 +221,22 @@ func TestHandlerGetOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
getOrder: func(p provisioner.Interface, accID, id string) (*acme.Order, error) { getOrder: func(ctx context.Context, accID, id string) (*acme.Order, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, o.ID) assert.Equals(t, id, o.ID)
return &o, nil return &o, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink) assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{o.ID}) assert.Equals(t, in, []string{o.ID})
return url return url
@ -314,8 +302,10 @@ func TestHandlerNewOrder(t *testing.T) {
} }
prov := newProv() prov := newProv()
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provName := url.PathEscape(prov.GetName())
acme.URLSafeProvisionerName(prov)) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/new-order",
baseURL.String(), provName)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -324,32 +314,16 @@ func TestHandlerNewOrder(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -358,8 +332,8 @@ func TestHandlerNewOrder(t *testing.T) {
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -368,9 +342,9 @@ func TestHandlerNewOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -379,9 +353,9 @@ func TestHandlerNewOrder(t *testing.T) {
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -393,9 +367,9 @@ func TestHandlerNewOrder(t *testing.T) {
nor := &NewOrderRequest{} nor := &NewOrderRequest{}
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -412,12 +386,14 @@ func TestHandlerNewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, ops.AccountID, acc.ID) assert.Equals(t, ops.AccountID, acc.ID)
assert.Equals(t, ops.Identifiers, nor.Identifiers) assert.Equals(t, ops.Identifiers, nor.Identifiers)
@ -441,12 +417,15 @@ func TestHandlerNewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, ops.AccountID, acc.ID) assert.Equals(t, ops.AccountID, acc.ID)
assert.Equals(t, ops.Identifiers, nor.Identifiers) assert.Equals(t, ops.Identifiers, nor.Identifiers)
@ -454,12 +433,11 @@ func TestHandlerNewOrder(t *testing.T) {
assert.Equals(t, ops.NotAfter, naf) assert.Equals(t, ops.NotAfter, naf)
return &o, nil return &o, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink) assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{o.ID}) assert.Equals(t, in, []string{o.ID})
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID) return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -476,12 +454,15 @@ func TestHandlerNewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
newOrder: func(p provisioner.Interface, ops acme.OrderOptions) (*acme.Order, error) { newOrder: func(ctx context.Context, ops acme.OrderOptions) (*acme.Order, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, ops.AccountID, acc.ID) assert.Equals(t, ops.AccountID, acc.ID)
assert.Equals(t, ops.Identifiers, nor.Identifiers) assert.Equals(t, ops.Identifiers, nor.Identifiers)
@ -490,12 +471,11 @@ func TestHandlerNewOrder(t *testing.T) {
assert.True(t, ops.NotAfter.IsZero()) assert.True(t, ops.NotAfter.IsZero())
return &o, nil return &o, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink) assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{o.ID}) assert.Equals(t, in, []string{o.ID})
return fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID) return fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, o.ID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -534,7 +514,8 @@ func TestHandlerNewOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/order/%s", o.ID)}) []string{fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(),
provName, o.ID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -567,8 +548,10 @@ func TestHandlerFinalizeOrder(t *testing.T) {
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
prov := newProv() prov := newProv()
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/order/%s/finalize", provName := url.PathEscape(prov.GetName())
acme.URLSafeProvisionerName(prov), o.ID) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/%s/finalize",
baseURL.String(), provName, o.ID)
type test struct { type test struct {
auth acme.Interface auth acme.Interface
@ -577,33 +560,17 @@ func TestHandlerFinalizeOrder(t *testing.T) {
problem *acme.Error problem *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.Background(),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/nil-provisioner": func(t *testing.T) test {
return test{
auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, nil),
statusCode: 500,
problem: acme.ServerInternalErr(errors.New("provisioner expected in request context")),
}
},
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), acme.ProvisionerContextKey, prov),
statusCode: 400, statusCode: 400,
problem: acme.AccountDoesNotExistErr(nil), problem: acme.AccountDoesNotExistErr(nil),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, acme.AccContextKey, nil)
return test{ return test{
auth: &mockAcmeAuthority{}, auth: &mockAcmeAuthority{},
ctx: ctx, ctx: ctx,
@ -613,8 +580,8 @@ func TestHandlerFinalizeOrder(t *testing.T) {
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -623,9 +590,9 @@ func TestHandlerFinalizeOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, acme.PayloadContextKey, nil)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -634,9 +601,9 @@ func TestHandlerFinalizeOrder(t *testing.T) {
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -648,9 +615,9 @@ func TestHandlerFinalizeOrder(t *testing.T) {
fr := &FinalizeRequest{} fr := &FinalizeRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -664,13 +631,15 @@ func TestHandlerFinalizeOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, o.ID) assert.Equals(t, id, o.ID)
@ -690,26 +659,28 @@ func TestHandlerFinalizeOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), acme.ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, acme.AccContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, acme.PayloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, acme.BaseURLContextKey, baseURL)
return test{ return test{
auth: &mockAcmeAuthority{ auth: &mockAcmeAuthority{
finalizeOrder: func(p provisioner.Interface, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) { finalizeOrder: func(ctx context.Context, accID, id string, incsr *x509.CertificateRequest) (*acme.Order, error) {
p, err := acme.ProvisionerFromContext(ctx)
assert.FatalError(t, err)
assert.Equals(t, p, prov) assert.Equals(t, p, prov)
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
assert.Equals(t, id, o.ID) assert.Equals(t, id, o.ID)
assert.Equals(t, incsr.Raw, csr.Raw) assert.Equals(t, incsr.Raw, csr.Raw)
return &o, nil return &o, nil
}, },
getLink: func(typ acme.Link, provID string, abs bool, in ...string) string { getLink: func(ctx context.Context, typ acme.Link, abs bool, in ...string) string {
assert.Equals(t, typ, acme.OrderLink) assert.Equals(t, typ, acme.OrderLink)
assert.Equals(t, provID, acme.URLSafeProvisionerName(prov))
assert.True(t, abs) assert.True(t, abs)
assert.Equals(t, in, []string{o.ID}) assert.Equals(t, in, []string{o.ID})
return fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", return fmt.Sprintf("%s/acme/%s/order/%s",
acme.URLSafeProvisionerName(prov), o.ID) baseURL.String(), provName, o.ID)
}, },
}, },
ctx: ctx, ctx: ctx,
@ -748,8 +719,8 @@ func TestHandlerFinalizeOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", []string{fmt.Sprintf("%s/acme/%s/order/%s",
acme.URLSafeProvisionerName(prov), o.ID)}) baseURL, provName, o.ID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -1,6 +1,7 @@
package acme package acme
import ( import (
"context"
"crypto" "crypto"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@ -19,23 +20,29 @@ import (
// Interface is the acme authority interface. // Interface is the acme authority interface.
type Interface interface { type Interface interface {
DeactivateAccount(provisioner.Interface, string) (*Account, error) GetDirectory(ctx context.Context) (*Directory, error)
FinalizeOrder(provisioner.Interface, string, string, *x509.CertificateRequest) (*Order, error)
GetAccount(provisioner.Interface, string) (*Account, error)
GetAccountByKey(provisioner.Interface, *jose.JSONWebKey) (*Account, error)
GetAuthz(provisioner.Interface, string, string) (*Authz, error)
GetCertificate(string, string) ([]byte, error)
GetDirectory(provisioner.Interface, string) *Directory
GetLink(Link, string, bool, ...string) string
GetOrder(provisioner.Interface, string, string) (*Order, error)
GetOrdersByAccount(provisioner.Interface, string) ([]string, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
NewAccount(provisioner.Interface, AccountOptions) (*Account, error)
NewNonce() (string, error) NewNonce() (string, error)
NewOrder(provisioner.Interface, OrderOptions) (*Order, error)
UpdateAccount(provisioner.Interface, string, []string) (*Account, error)
UseNonce(string) error UseNonce(string) error
ValidateChallenge(provisioner.Interface, string, string, *jose.JSONWebKey) (*Challenge, error)
DeactivateAccount(ctx context.Context, accID string) (*Account, error)
GetAccount(ctx context.Context, accID string) (*Account, error)
GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error)
NewAccount(ctx context.Context, ao AccountOptions) (*Account, error)
UpdateAccount(context.Context, string, []string) (*Account, error)
GetAuthz(ctx context.Context, accID string, authzID string) (*Authz, error)
ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error)
FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error)
GetOrder(ctx context.Context, accID string, orderID string) (*Order, error)
GetOrdersByAccount(ctx context.Context, accID string) ([]string, error)
NewOrder(ctx context.Context, oo OrderOptions) (*Order, error)
GetCertificate(string, string) ([]byte, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string
GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string
} }
// Authority is the layer that handles all ACME interactions. // Authority is the layer that handles all ACME interactions.
@ -77,20 +84,24 @@ func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Aut
} }
// GetLink returns the requested link from the directory. // GetLink returns the requested link from the directory.
func (a *Authority) GetLink(typ Link, provID string, abs bool, inputs ...string) string { func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
return a.dir.getLink(typ, provID, abs, inputs...) return a.dir.getLink(ctx, typ, abs, inputs...)
}
// GetLinkExplicit returns the requested link from the directory.
func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string {
return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...)
} }
// GetDirectory returns the ACME directory object. // GetDirectory returns the ACME directory object.
func (a *Authority) GetDirectory(p provisioner.Interface, baseURLFromRequest string) *Directory { func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) {
name := url.PathEscape(p.GetName())
return &Directory{ return &Directory{
NewNonce: a.dir.getLinkFromBaseURL(NewNonceLink, name, true, baseURLFromRequest), NewNonce: a.dir.getLink(ctx, NewNonceLink, true),
NewAccount: a.dir.getLinkFromBaseURL(NewAccountLink, name, true, baseURLFromRequest), NewAccount: a.dir.getLink(ctx, NewAccountLink, true),
NewOrder: a.dir.getLinkFromBaseURL(NewOrderLink, name, true, baseURLFromRequest), NewOrder: a.dir.getLink(ctx, NewOrderLink, true),
RevokeCert: a.dir.getLinkFromBaseURL(RevokeCertLink, name, true, baseURLFromRequest), RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true),
KeyChange: a.dir.getLinkFromBaseURL(KeyChangeLink, name, true, baseURLFromRequest), KeyChange: a.dir.getLink(ctx, KeyChangeLink, true),
} }, nil
} }
// LoadProvisionerByID calls out to the SignAuthority interface to load a // LoadProvisionerByID calls out to the SignAuthority interface to load a
@ -114,16 +125,16 @@ func (a *Authority) UseNonce(nonce string) error {
} }
// NewAccount creates, stores, and returns a new ACME account. // NewAccount creates, stores, and returns a new ACME account.
func (a *Authority) NewAccount(p provisioner.Interface, ao AccountOptions) (*Account, error) { func (a *Authority) NewAccount(ctx context.Context, ao AccountOptions) (*Account, error) {
acc, err := newAccount(a.db, ao) acc, err := newAccount(a.db, ao)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return acc.toACME(a.db, a.dir, p) return acc.toACME(ctx, a.db, a.dir)
} }
// UpdateAccount updates an ACME account. // UpdateAccount updates an ACME account.
func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []string) (*Account, error) { func (a *Authority) UpdateAccount(ctx context.Context, id string, contact []string) (*Account, error) {
acc, err := getAccountByID(a.db, id) acc, err := getAccountByID(a.db, id)
if err != nil { if err != nil {
return nil, ServerInternalErr(err) return nil, ServerInternalErr(err)
@ -131,20 +142,20 @@ func (a *Authority) UpdateAccount(p provisioner.Interface, id string, contact []
if acc, err = acc.update(a.db, contact); err != nil { if acc, err = acc.update(a.db, contact); err != nil {
return nil, err return nil, err
} }
return acc.toACME(a.db, a.dir, p) return acc.toACME(ctx, a.db, a.dir)
} }
// GetAccount returns an ACME account. // GetAccount returns an ACME account.
func (a *Authority) GetAccount(p provisioner.Interface, id string) (*Account, error) { func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) {
acc, err := getAccountByID(a.db, id) acc, err := getAccountByID(a.db, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return acc.toACME(a.db, a.dir, p) return acc.toACME(ctx, a.db, a.dir)
} }
// DeactivateAccount deactivates an ACME account. // DeactivateAccount deactivates an ACME account.
func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Account, error) { func (a *Authority) DeactivateAccount(ctx context.Context, id string) (*Account, error) {
acc, err := getAccountByID(a.db, id) acc, err := getAccountByID(a.db, id)
if err != nil { if err != nil {
return nil, err return nil, err
@ -152,7 +163,7 @@ func (a *Authority) DeactivateAccount(p provisioner.Interface, id string) (*Acco
if acc, err = acc.deactivate(a.db); err != nil { if acc, err = acc.deactivate(a.db); err != nil {
return nil, err return nil, err
} }
return acc.toACME(a.db, a.dir, p) return acc.toACME(ctx, a.db, a.dir)
} }
func keyToID(jwk *jose.JSONWebKey) (string, error) { func keyToID(jwk *jose.JSONWebKey) (string, error) {
@ -164,7 +175,7 @@ func keyToID(jwk *jose.JSONWebKey) (string, error) {
} }
// GetAccountByKey returns the ACME associated with the jwk id. // GetAccountByKey returns the ACME associated with the jwk id.
func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKey) (*Account, error) { func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) {
kid, err := keyToID(jwk) kid, err := keyToID(jwk)
if err != nil { if err != nil {
return nil, err return nil, err
@ -173,11 +184,11 @@ func (a *Authority) GetAccountByKey(p provisioner.Interface, jwk *jose.JSONWebKe
if err != nil { if err != nil {
return nil, err return nil, err
} }
return acc.toACME(a.db, a.dir, p) return acc.toACME(ctx, a.db, a.dir)
} }
// GetOrder returns an ACME order. // GetOrder returns an ACME order.
func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*Order, error) { func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) {
o, err := getOrder(a.db, orderID) o, err := getOrder(a.db, orderID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -188,11 +199,11 @@ func (a *Authority) GetOrder(p provisioner.Interface, accID, orderID string) (*O
if o, err = o.updateStatus(a.db); err != nil { if o, err = o.updateStatus(a.db); err != nil {
return nil, err return nil, err
} }
return o.toACME(a.db, a.dir, p) return o.toACME(ctx, a.db, a.dir)
} }
// GetOrdersByAccount returns the list of order urls owned by the account. // GetOrdersByAccount returns the list of order urls owned by the account.
func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]string, error) { func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
oids, err := getOrderIDsByAccount(a.db, id) oids, err := getOrderIDsByAccount(a.db, id)
if err != nil { if err != nil {
return nil, err return nil, err
@ -207,22 +218,26 @@ func (a *Authority) GetOrdersByAccount(p provisioner.Interface, id string) ([]st
if o.Status == StatusInvalid { if o.Status == StatusInvalid {
continue continue
} }
ret = append(ret, a.dir.getLink(OrderLink, URLSafeProvisionerName(p), true, o.ID)) ret = append(ret, a.dir.getLink(ctx, OrderLink, true, o.ID))
} }
return ret, nil return ret, nil
} }
// NewOrder generates, stores, and returns a new ACME order. // NewOrder generates, stores, and returns a new ACME order.
func (a *Authority) NewOrder(p provisioner.Interface, ops OrderOptions) (*Order, error) { func (a *Authority) NewOrder(ctx context.Context, ops OrderOptions) (*Order, error) {
order, err := newOrder(a.db, ops) order, err := newOrder(a.db, ops)
if err != nil { if err != nil {
return nil, Wrap(err, "error creating order") return nil, Wrap(err, "error creating order")
} }
return order.toACME(a.db, a.dir, p) return order.toACME(ctx, a.db, a.dir)
} }
// FinalizeOrder attempts to finalize an order and generate a new certificate. // FinalizeOrder attempts to finalize an order and generate a new certificate.
func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) { func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
prov, err := ProvisionerFromContext(ctx)
if err != nil {
return nil, err
}
o, err := getOrder(a.db, orderID) o, err := getOrder(a.db, orderID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -230,16 +245,16 @@ func (a *Authority) FinalizeOrder(p provisioner.Interface, accID, orderID string
if accID != o.AccountID { if accID != o.AccountID {
return nil, UnauthorizedErr(errors.New("account does not own order")) return nil, UnauthorizedErr(errors.New("account does not own order"))
} }
o, err = o.finalize(a.db, csr, a.signAuth, p) o, err = o.finalize(a.db, csr, a.signAuth, prov)
if err != nil { if err != nil {
return nil, Wrap(err, "error finalizing order") return nil, Wrap(err, "error finalizing order")
} }
return o.toACME(a.db, a.dir, p) return o.toACME(ctx, a.db, a.dir)
} }
// GetAuthz retrieves and attempts to update the status on an ACME authz // GetAuthz retrieves and attempts to update the status on an ACME authz
// before returning. // before returning.
func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*Authz, error) { func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Authz, error) {
az, err := getAuthz(a.db, authzID) az, err := getAuthz(a.db, authzID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -251,11 +266,11 @@ func (a *Authority) GetAuthz(p provisioner.Interface, accID, authzID string) (*A
if err != nil { if err != nil {
return nil, Wrap(err, "error updating authz status") return nil, Wrap(err, "error updating authz status")
} }
return az.toACME(a.db, a.dir, p) return az.toACME(ctx, a.db, a.dir)
} }
// ValidateChallenge attempts to validate the challenge. // ValidateChallenge attempts to validate the challenge.
func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) { func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, jwk *jose.JSONWebKey) (*Challenge, error) {
ch, err := getChallenge(a.db, chID) ch, err := getChallenge(a.db, chID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -279,7 +294,7 @@ func (a *Authority) ValidateChallenge(p provisioner.Interface, accID, chID strin
if err != nil { if err != nil {
return nil, Wrap(err, "error attempting challenge validation") return nil, Wrap(err, "error attempting challenge validation")
} }
return ch.toACME(a.db, a.dir, p) return ch.toACME(ctx, a.db, a.dir)
} }
// GetCertificate retrieves the Certificate by ID. // GetCertificate retrieves the Certificate by ID.

View file

@ -1,8 +1,10 @@
package acme package acme
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url"
"testing" "testing"
"time" "time"
@ -16,7 +18,11 @@ import (
func TestAuthorityGetLink(t *testing.T) { func TestAuthorityGetLink(t *testing.T) {
auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
assert.FatalError(t, err) assert.FatalError(t, err)
provID := "acme-test-provisioner" prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
type test struct { type test struct {
auth *Authority auth *Authority
typ Link typ Link
@ -30,7 +36,7 @@ func TestAuthorityGetLink(t *testing.T) {
auth: auth, auth: auth,
typ: NewAccountLink, typ: NewAccountLink,
abs: true, abs: true,
res: fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID), res: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
} }
}, },
"ok/new-account/no-abs": func(t *testing.T) test { "ok/new-account/no-abs": func(t *testing.T) test {
@ -38,7 +44,7 @@ func TestAuthorityGetLink(t *testing.T) {
auth: auth, auth: auth,
typ: NewAccountLink, typ: NewAccountLink,
abs: false, abs: false,
res: fmt.Sprintf("/%s/new-account", provID), res: fmt.Sprintf("/%s/new-account", provName),
} }
}, },
"ok/order/abs": func(t *testing.T) test { "ok/order/abs": func(t *testing.T) test {
@ -47,7 +53,7 @@ func TestAuthorityGetLink(t *testing.T) {
typ: OrderLink, typ: OrderLink,
abs: true, abs: true,
inputs: []string{"foo"}, inputs: []string{"foo"},
res: fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/foo", provID), res: fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName),
} }
}, },
"ok/order/no-abs": func(t *testing.T) test { "ok/order/no-abs": func(t *testing.T) test {
@ -56,14 +62,14 @@ func TestAuthorityGetLink(t *testing.T) {
typ: OrderLink, typ: OrderLink,
abs: false, abs: false,
inputs: []string{"foo"}, inputs: []string{"foo"},
res: fmt.Sprintf("/%s/order/foo", provID), res: fmt.Sprintf("/%s/order/foo", provName),
} }
}, },
} }
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
link := tc.auth.GetLink(tc.typ, provID, tc.abs, tc.inputs...) link := tc.auth.GetLink(ctx, tc.typ, tc.abs, tc.inputs...)
assert.Equals(t, tc.res, link) assert.Equals(t, tc.res, link)
}) })
} }
@ -72,28 +78,68 @@ func TestAuthorityGetLink(t *testing.T) {
func TestAuthorityGetDirectory(t *testing.T) { func TestAuthorityGetDirectory(t *testing.T) {
auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil) auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newProv()
acmeDir := auth.GetDirectory(prov, "")
assert.Equals(t, acmeDir.NewNonce, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", URLSafeProvisionerName(prov)))
assert.Equals(t, acmeDir.NewAccount, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", URLSafeProvisionerName(prov)))
assert.Equals(t, acmeDir.NewOrder, fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", URLSafeProvisionerName(prov)))
//assert.Equals(t, acmeDir.NewOrder, "httsp://ca.smallstep.com/acme/new-authz")
assert.Equals(t, acmeDir.RevokeCert, fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", URLSafeProvisionerName(prov)))
assert.Equals(t, acmeDir.KeyChange, fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", URLSafeProvisionerName(prov)))
}
func TestAuthorityGetDirectoryWithBaseURL(t *testing.T) {
baseURL := "http://my.proxied.host"
auth, err := NewAuthority(new(db.MockNoSQLDB), "ca.smallstep.com", "acme", nil)
assert.FatalError(t, err)
prov := newProv() prov := newProv()
acmeDir := auth.GetDirectory(prov, baseURL) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
assert.Equals(t, acmeDir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, URLSafeProvisionerName(prov))) ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
assert.Equals(t, acmeDir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", baseURL, URLSafeProvisionerName(prov))) ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
assert.Equals(t, acmeDir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", baseURL, URLSafeProvisionerName(prov)))
//assert.Equals(t, acmeDir.NewOrder, "%s/acme/new-authz") type test struct {
assert.Equals(t, acmeDir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, URLSafeProvisionerName(prov))) ctx context.Context
assert.Equals(t, acmeDir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", baseURL, URLSafeProvisionerName(prov))) err *Error
}
tests := map[string]func(t *testing.T) test{
"ok/empty-provisioner": func(t *testing.T) test {
return test{
ctx: context.Background(),
}
},
"ok/no-baseURL": func(t *testing.T) test {
return test{
ctx: context.WithValue(context.Background(), ProvisionerContextKey, prov),
}
},
"ok/baseURL": func(t *testing.T) test {
return test{
ctx: ctx,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if dir, err := auth.GetDirectory(tc.ctx); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
bu := BaseURLFromContext(tc.ctx)
if bu == nil {
bu = &url.URL{Scheme: "https", Host: "ca.smallstep.com"}
}
var provName string
prov, err := ProvisionerFromContext(tc.ctx)
if err != nil {
provName = ""
} else {
provName = url.PathEscape(prov.GetName())
}
assert.Equals(t, dir.NewNonce, fmt.Sprintf("%s/acme/%s/new-nonce", bu.String(), provName))
assert.Equals(t, dir.NewAccount, fmt.Sprintf("%s/acme/%s/new-account", bu.String(), provName))
assert.Equals(t, dir.NewOrder, fmt.Sprintf("%s/acme/%s/new-order", bu.String(), provName))
assert.Equals(t, dir.RevokeCert, fmt.Sprintf("%s/acme/%s/revoke-cert", bu.String(), provName))
assert.Equals(t, dir.KeyChange, fmt.Sprintf("%s/acme/%s/key-change", bu.String(), provName))
}
}
})
}
} }
func TestAuthorityNewNonce(t *testing.T) { func TestAuthorityNewNonce(t *testing.T) {
@ -207,6 +253,8 @@ func TestAuthorityNewAccount(t *testing.T) {
Key: jwk, Contact: []string{"foo", "bar"}, Key: jwk, Contact: []string{"foo", "bar"},
} }
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
ops AccountOptions ops AccountOptions
@ -239,7 +287,7 @@ func TestAuthorityNewAccount(t *testing.T) {
if count == 1 { if count == 1 {
var acc *account var acc *account
assert.FatalError(t, json.Unmarshal(newval, &acc)) assert.FatalError(t, json.Unmarshal(newval, &acc))
*acmeacc, err = acc.toACME(nil, dir, prov) *acmeacc, err = acc.toACME(ctx, nil, dir)
return nil, true, nil return nil, true, nil
} }
count++ count++
@ -257,7 +305,7 @@ func TestAuthorityNewAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeAcc, err := tc.auth.NewAccount(prov, tc.ops); err != nil { if acmeAcc, err := tc.auth.NewAccount(ctx, tc.ops); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -280,6 +328,8 @@ func TestAuthorityNewAccount(t *testing.T) {
func TestAuthorityGetAccount(t *testing.T) { func TestAuthorityGetAccount(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id string id string
@ -324,7 +374,7 @@ func TestAuthorityGetAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeAcc, err := tc.auth.GetAccount(prov, tc.id); err != nil { if acmeAcc, err := tc.auth.GetAccount(ctx, tc.id); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -337,7 +387,7 @@ func TestAuthorityGetAccount(t *testing.T) {
gotb, err := json.Marshal(acmeAcc) gotb, err := json.Marshal(acmeAcc)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -351,6 +401,8 @@ func TestAuthorityGetAccount(t *testing.T) {
func TestAuthorityGetAccountByKey(t *testing.T) { func TestAuthorityGetAccountByKey(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
@ -425,7 +477,7 @@ func TestAuthorityGetAccountByKey(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeAcc, err := tc.auth.GetAccountByKey(prov, tc.jwk); err != nil { if acmeAcc, err := tc.auth.GetAccountByKey(ctx, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -438,7 +490,7 @@ func TestAuthorityGetAccountByKey(t *testing.T) {
gotb, err := json.Marshal(acmeAcc) gotb, err := json.Marshal(acmeAcc)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -452,6 +504,8 @@ func TestAuthorityGetAccountByKey(t *testing.T) {
func TestAuthorityGetOrder(t *testing.T) { func TestAuthorityGetOrder(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id, accID string id, accID string
@ -549,7 +603,7 @@ func TestAuthorityGetOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeO, err := tc.auth.GetOrder(prov, tc.accID, tc.id); err != nil { if acmeO, err := tc.auth.GetOrder(ctx, tc.accID, tc.id); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -562,7 +616,7 @@ func TestAuthorityGetOrder(t *testing.T) {
gotb, err := json.Marshal(acmeO) gotb, err := json.Marshal(acmeO)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.o.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -669,6 +723,8 @@ func TestAuthorityGetCertificate(t *testing.T) {
func TestAuthorityGetAuthz(t *testing.T) { func TestAuthorityGetAuthz(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id, accID string id, accID string
@ -798,7 +854,7 @@ func TestAuthorityGetAuthz(t *testing.T) {
return ret, nil return ret, nil
}, },
} }
acmeAz, err := az.toACME(mockdb, newDirectory("ca.smallstep.com", "acme"), prov) acmeAz, err := az.toACME(ctx, mockdb, newDirectory("ca.smallstep.com", "acme"))
assert.FatalError(t, err) assert.FatalError(t, err)
count = 0 count = 0
@ -839,7 +895,7 @@ func TestAuthorityGetAuthz(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeAz, err := tc.auth.GetAuthz(prov, tc.accID, tc.id); err != nil { if acmeAz, err := tc.auth.GetAuthz(ctx, tc.accID, tc.id); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -864,6 +920,8 @@ func TestAuthorityGetAuthz(t *testing.T) {
func TestAuthorityNewOrder(t *testing.T) { func TestAuthorityNewOrder(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
ops OrderOptions ops OrderOptions
@ -917,7 +975,7 @@ func TestAuthorityNewOrder(t *testing.T) {
assert.Equals(t, bucket, orderTable) assert.Equals(t, bucket, orderTable)
var o order var o order
assert.FatalError(t, json.Unmarshal(newval, &o)) assert.FatalError(t, json.Unmarshal(newval, &o))
*acmeO, err = o.toACME(nil, dir, prov) *acmeO, err = o.toACME(ctx, nil, dir)
assert.FatalError(t, err) assert.FatalError(t, err)
*accID = o.AccountID *accID = o.AccountID
case 9: case 9:
@ -942,7 +1000,7 @@ func TestAuthorityNewOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeO, err := tc.auth.NewOrder(prov, tc.ops); err != nil { if acmeO, err := tc.auth.NewOrder(ctx, tc.ops); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -965,6 +1023,10 @@ func TestAuthorityNewOrder(t *testing.T) {
func TestAuthorityGetOrdersByAccount(t *testing.T) { func TestAuthorityGetOrdersByAccount(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
type test struct { type test struct {
auth *Authority auth *Authority
id string id string
@ -1065,8 +1127,8 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) {
auth: auth, auth: auth,
id: id, id: id,
res: []string{ res: []string{
fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", URLSafeProvisionerName(prov), foo.ID), fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, foo.ID),
fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s", URLSafeProvisionerName(prov), baz.ID), fmt.Sprintf("%s/acme/%s/order/%s", baseURL.String(), provName, baz.ID),
}, },
} }
}, },
@ -1074,7 +1136,7 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if orderLinks, err := tc.auth.GetOrdersByAccount(prov, tc.id); err != nil { if orderLinks, err := tc.auth.GetOrdersByAccount(ctx, tc.id); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -1093,6 +1155,8 @@ func TestAuthorityGetOrdersByAccount(t *testing.T) {
func TestAuthorityFinalizeOrder(t *testing.T) { func TestAuthorityFinalizeOrder(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id, accID string id, accID string
@ -1188,7 +1252,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeO, err := tc.auth.FinalizeOrder(prov, tc.accID, tc.id, nil); err != nil { if acmeO, err := tc.auth.FinalizeOrder(ctx, tc.accID, tc.id, nil); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -1201,7 +1265,7 @@ func TestAuthorityFinalizeOrder(t *testing.T) {
gotb, err := json.Marshal(acmeO) gotb, err := json.Marshal(acmeO)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.o.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.o.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1215,6 +1279,8 @@ func TestAuthorityFinalizeOrder(t *testing.T) {
func TestAuthorityValidateChallenge(t *testing.T) { func TestAuthorityValidateChallenge(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id, accID string id, accID string
@ -1311,7 +1377,7 @@ func TestAuthorityValidateChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeCh, err := tc.auth.ValidateChallenge(prov, tc.accID, tc.id, nil); err != nil { if acmeCh, err := tc.auth.ValidateChallenge(ctx, tc.accID, tc.id, nil); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -1324,7 +1390,7 @@ func TestAuthorityValidateChallenge(t *testing.T) {
gotb, err := json.Marshal(acmeCh) gotb, err := json.Marshal(acmeCh)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.ch.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.ch.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1339,6 +1405,8 @@ func TestAuthorityValidateChallenge(t *testing.T) {
func TestAuthorityUpdateAccount(t *testing.T) { func TestAuthorityUpdateAccount(t *testing.T) {
contact := []string{"baz", "zap"} contact := []string{"baz", "zap"}
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id string id string
@ -1418,7 +1486,7 @@ func TestAuthorityUpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeAcc, err := tc.auth.UpdateAccount(prov, tc.id, tc.contact); err != nil { if acmeAcc, err := tc.auth.UpdateAccount(ctx, tc.id, tc.contact); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -1431,7 +1499,7 @@ func TestAuthorityUpdateAccount(t *testing.T) {
gotb, err := json.Marshal(acmeAcc) gotb, err := json.Marshal(acmeAcc)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1445,6 +1513,8 @@ func TestAuthorityUpdateAccount(t *testing.T) {
func TestAuthorityDeactivateAccount(t *testing.T) { func TestAuthorityDeactivateAccount(t *testing.T) {
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
auth *Authority auth *Authority
id string id string
@ -1521,7 +1591,7 @@ func TestAuthorityDeactivateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if acmeAcc, err := tc.auth.DeactivateAccount(prov, tc.id); err != nil { if acmeAcc, err := tc.auth.DeactivateAccount(ctx, tc.id); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
assert.True(t, ok) assert.True(t, ok)
@ -1534,7 +1604,7 @@ func TestAuthorityDeactivateAccount(t *testing.T) {
gotb, err := json.Marshal(acmeAcc) gotb, err := json.Marshal(acmeAcc)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeExp, err := tc.acc.toACME(nil, tc.auth.dir, prov) acmeExp, err := tc.acc.toACME(ctx, nil, tc.auth.dir)
assert.FatalError(t, err) assert.FatalError(t, err)
expb, err := json.Marshal(acmeExp) expb, err := json.Marshal(acmeExp)
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -1,12 +1,12 @@
package acme package acme
import ( import (
"context"
"encoding/json" "encoding/json"
"strings" "strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
@ -51,7 +51,7 @@ type authz interface {
getChallenges() []string getChallenges() []string
getCreated() time.Time getCreated() time.Time
updateStatus(db nosql.DB) (authz, error) updateStatus(db nosql.DB) (authz, error)
toACME(nosql.DB, *directory, provisioner.Interface) (*Authz, error) toACME(context.Context, nosql.DB, *directory) (*Authz, error)
} }
// baseAuthz is the base authz type that others build from. // baseAuthz is the base authz type that others build from.
@ -141,14 +141,14 @@ func (ba *baseAuthz) getCreated() time.Time {
// toACME converts the internal Authz type into the public acmeAuthz type for // toACME converts the internal Authz type into the public acmeAuthz type for
// presentation in the ACME protocol. // presentation in the ACME protocol.
func (ba *baseAuthz) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Authz, error) { func (ba *baseAuthz) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Authz, error) {
var chs = make([]*Challenge, len(ba.Challenges)) var chs = make([]*Challenge, len(ba.Challenges))
for i, chID := range ba.Challenges { for i, chID := range ba.Challenges {
ch, err := getChallenge(db, chID) ch, err := getChallenge(db, chID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
chs[i], err = ch.toACME(db, dir, p) chs[i], err = ch.toACME(ctx, db, dir)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package acme package acme
import ( import (
"context"
"encoding/json" "encoding/json"
"strings" "strings"
"testing" "testing"
@ -369,7 +370,10 @@ func TestAuthzToACME(t *testing.T) {
} }
az, err := newAuthz(mockdb, "1234", iden) az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newProv() prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct { type test struct {
db nosql.DB db nosql.DB
@ -419,7 +423,7 @@ func TestAuthzToACME(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
acmeAz, err := az.toACME(tc.db, dir, prov) acmeAz, err := az.toACME(ctx, tc.db, dir)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
@ -434,9 +438,9 @@ func TestAuthzToACME(t *testing.T) {
assert.Equals(t, acmeAz.Identifier, iden) assert.Equals(t, acmeAz.Identifier, iden)
assert.Equals(t, acmeAz.Status, StatusPending) assert.Equals(t, acmeAz.Status, StatusPending)
acmeCh1, err := ch1.toACME(nil, dir, prov) acmeCh1, err := ch1.toACME(ctx, nil, dir)
assert.FatalError(t, err) assert.FatalError(t, err)
acmeCh2, err := ch2.toACME(nil, dir, prov) acmeCh2, err := ch2.toACME(ctx, nil, dir)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, acmeAz.Challenges[0], acmeCh1) assert.Equals(t, acmeAz.Challenges[0], acmeCh1)

View file

@ -1,6 +1,7 @@
package acme package acme
import ( import (
"context"
"crypto" "crypto"
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
@ -17,7 +18,6 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
@ -79,7 +79,7 @@ type challenge interface {
getAccountID() string getAccountID() string
getValidated() time.Time getValidated() time.Time
getCreated() time.Time getCreated() time.Time
toACME(nosql.DB, *directory, provisioner.Interface) (*Challenge, error) toACME(context.Context, nosql.DB, *directory) (*Challenge, error)
} }
// ChallengeOptions is the type used to created a new Challenge. // ChallengeOptions is the type used to created a new Challenge.
@ -175,12 +175,12 @@ func (bc *baseChallenge) getError() *AError {
// toACME converts the internal Challenge type into the public acmeChallenge // toACME converts the internal Challenge type into the public acmeChallenge
// type for presentation in the ACME protocol. // type for presentation in the ACME protocol.
func (bc *baseChallenge) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Challenge, error) { func (bc *baseChallenge) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Challenge, error) {
ac := &Challenge{ ac := &Challenge{
Type: bc.getType(), Type: bc.getType(),
Status: bc.getStatus(), Status: bc.getStatus(),
Token: bc.getToken(), Token: bc.getToken(),
URL: dir.getLink(ChallengeLink, URLSafeProvisionerName(p), true, bc.getID()), URL: dir.getLink(ctx, ChallengeLink, true, bc.getID()),
ID: bc.getID(), ID: bc.getID(),
AuthzID: bc.getAuthzID(), AuthzID: bc.getAuthzID(),
} }

View file

@ -2,6 +2,7 @@ package acme
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@ -20,6 +21,7 @@ import (
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"time" "time"
@ -273,6 +275,10 @@ func TestChallengeToACME(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
tests := map[string]challenge{ tests := map[string]challenge{
"dns": dnsCh, "dns": dnsCh,
"http": httpCh, "http": httpCh,
@ -280,15 +286,15 @@ func TestChallengeToACME(t *testing.T) {
} }
for name, ch := range tests { for name, ch := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ach, err := ch.toACME(nil, dir, prov) ach, err := ch.toACME(ctx, nil, dir)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, ach.Type, ch.getType()) assert.Equals(t, ach.Type, ch.getType())
assert.Equals(t, ach.Status, ch.getStatus()) assert.Equals(t, ach.Status, ch.getStatus())
assert.Equals(t, ach.Token, ch.getToken()) assert.Equals(t, ach.Token, ch.getToken())
assert.Equals(t, ach.URL, assert.Equals(t, ach.URL,
fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/%s", fmt.Sprintf("%s/acme/%s/challenge/%s",
URLSafeProvisionerName(prov), ch.getID())) baseURL.String(), provName, ch.getID()))
assert.Equals(t, ach.ID, ch.getID()) assert.Equals(t, ach.ID, ch.getID())
assert.Equals(t, ach.AuthzID, ch.getAuthzID()) assert.Equals(t, ach.AuthzID, ch.getAuthzID())

View file

@ -1,6 +1,7 @@
package acme package acme
import ( import (
"context"
"crypto/x509" "crypto/x509"
"net/url" "net/url"
"time" "time"
@ -8,8 +9,75 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
) )
// ContextKey is the key type for storing and searching for ACME request
// essentials in the context of a request.
type ContextKey string
const (
// AccContextKey account key
AccContextKey = ContextKey("acc")
// BaseURLContextKey baseURL key
BaseURLContextKey = ContextKey("baseURL")
// JwsContextKey jws key
JwsContextKey = ContextKey("jws")
// JwkContextKey jwk key
JwkContextKey = ContextKey("jwk")
// PayloadContextKey payload key
PayloadContextKey = ContextKey("payload")
// ProvisionerContextKey provisioner key
ProvisionerContextKey = ContextKey("provisioner")
)
// AccountFromContext searches the context for an ACME account. Returns the
// account or an error.
func AccountFromContext(ctx context.Context) (*Account, error) {
val, ok := ctx.Value(AccContextKey).(*Account)
if !ok || val == nil {
return nil, AccountDoesNotExistErr(nil)
}
return val, nil
}
// BaseURLFromContext returns the baseURL if one is stored in the context.
func BaseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(BaseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// JwkFromContext searches the context for a JWK. Returns the JWK or an error.
func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey)
if !ok || val == nil {
return nil, ServerInternalErr(errors.Errorf("jwk expected in request context"))
}
return val, nil
}
// JwsFromContext searches the context for a JWS. Returns the JWS or an error.
func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature)
if !ok || val == nil {
return nil, ServerInternalErr(errors.Errorf("jws expected in request context"))
}
return val, nil
}
// ProvisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func ProvisionerFromContext(ctx context.Context) (provisioner.Interface, error) {
val, ok := ctx.Value(ProvisionerContextKey).(provisioner.Interface)
if !ok || val == nil {
return nil, ServerInternalErr(errors.Errorf("provisioner expected in request context"))
}
return val, nil
}
// SignAuthority is the interface implemented by a CA authority. // SignAuthority is the interface implemented by a CA authority.
type SignAuthority interface { type SignAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -57,9 +125,3 @@ func (c *Clock) Now() time.Time {
} }
var clock = new(Clock) var clock = new(Clock)
// URLSafeProvisionerName returns a path escaped version of the ACME provisioner
// ID that is safe to use in URL paths.
func URLSafeProvisionerName(p provisioner.Interface) string {
return url.PathEscape(p.GetName())
}

View file

@ -1,8 +1,10 @@
package acme package acme
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -100,14 +102,18 @@ func (l Link) String() string {
} }
} }
// getLink returns an absolute or partial path to the given resource. func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
func (d *directory) getLink(typ Link, provisionerName string, abs bool, inputs ...string) string { var provName string
return d.getLinkFromBaseURL(typ, provisionerName, abs, "", inputs...) if p, err := ProvisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...)
} }
// getLinkFromBaseURL returns an absolute or partial path to the given resource and a base URL dynamically obtained from the request for which // getLinkExplicit returns an absolute or partial path to the given resource and a base
// the link is being calculated. // URL dynamically obtained from the request for which the link is being
func (d *directory) getLinkFromBaseURL(typ Link, provisionerName string, abs bool, baseURLFromRequest string, inputs ...string) string { // calculated.
func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
var link string var link string
switch typ { switch typ {
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink: case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
@ -119,12 +125,26 @@ func (d *directory) getLinkFromBaseURL(typ Link, provisionerName string, abs boo
case FinalizeLink: case FinalizeLink:
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0]) link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
} }
if abs { if abs {
baseURL := baseURLFromRequest // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL == "" { u := url.URL{}
baseURL = "https://" + d.dns if baseURL != nil {
u = *baseURL
} }
return fmt.Sprintf("%s/%s%s", baseURL, d.prefix, link)
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = d.dns
}
u.Path = d.prefix + link
return u.String()
} }
return link return link
} }

View file

@ -1,7 +1,9 @@
package acme package acme
import ( import (
"context"
"fmt" "fmt"
"net/url"
"testing" "testing"
"github.com/smallstep/assert" "github.com/smallstep/assert"
@ -14,47 +16,84 @@ func TestDirectoryGetLink(t *testing.T) {
id := "1234" id := "1234"
prov := newProv() prov := newProv()
provID := URLSafeProvisionerName(prov) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
assert.Equals(t, dir.getLink(NewNonceLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-nonce", provID)) assert.Equals(t, dir.getLink(ctx, NewNonceLink, true),
assert.Equals(t, dir.getLink(NewNonceLink, provID, false), fmt.Sprintf("/%s/new-nonce", provID)) fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName))
assert.Equals(t, dir.getLink(ctx, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
assert.Equals(t, dir.getLink(NewAccountLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-account", provID)) // No provisioner
assert.Equals(t, dir.getLink(NewAccountLink, provID, false), fmt.Sprintf("/%s/new-account", provID)) ctxNoProv := context.WithValue(context.Background(), BaseURLContextKey, baseURL)
assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, true),
fmt.Sprintf("%s/acme//new-nonce", baseURL.String()))
assert.Equals(t, dir.getLink(ctxNoProv, NewNonceLink, false), "//new-nonce")
assert.Equals(t, dir.getLink(AccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", provID)) // No baseURL
assert.Equals(t, dir.getLink(AccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234", provID)) ctxNoBaseURL := context.WithValue(context.Background(), ProvisionerContextKey, prov)
assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, true),
fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName))
assert.Equals(t, dir.getLink(ctxNoBaseURL, NewNonceLink, false), fmt.Sprintf("/%s/new-nonce", provName))
assert.Equals(t, dir.getLink(NewOrderLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-order", provID)) assert.Equals(t, dir.getLink(ctx, OrderLink, true, id),
assert.Equals(t, dir.getLink(NewOrderLink, provID, false), fmt.Sprintf("/%s/new-order", provID)) fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName))
assert.Equals(t, dir.getLink(ctx, OrderLink, false, id), fmt.Sprintf("/%s/order/1234", provName))
assert.Equals(t, dir.getLink(OrderLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234", provID)) }
assert.Equals(t, dir.getLink(OrderLink, provID, false, id), fmt.Sprintf("/%s/order/1234", provID))
func TestDirectoryGetLinkExplicit(t *testing.T) {
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234/orders", provID)) dns := "ca.smallstep.com"
assert.Equals(t, dir.getLink(OrdersByAccountLink, provID, false, id), fmt.Sprintf("/%s/account/1234/orders", provID)) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prefix := "acme"
assert.Equals(t, dir.getLink(FinalizeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/1234/finalize", provID)) dir := newDirectory(dns, prefix)
assert.Equals(t, dir.getLink(FinalizeLink, provID, false, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) id := "1234"
assert.Equals(t, dir.getLink(NewAuthzLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/new-authz", provID)) prov := newProv()
assert.Equals(t, dir.getLink(NewAuthzLink, provID, false), fmt.Sprintf("/%s/new-authz", provID)) provID := url.PathEscape(prov.GetName())
assert.Equals(t, dir.getLink(AuthzLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/authz/1234", provID)) assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
assert.Equals(t, dir.getLink(AuthzLink, provID, false, id), fmt.Sprintf("/%s/authz/1234", provID)) assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID))
assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID))
assert.Equals(t, dir.getLink(DirectoryLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/directory", provID)) assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID))
assert.Equals(t, dir.getLink(DirectoryLink, provID, false), fmt.Sprintf("/%s/directory", provID)) assert.Equals(t, dir.getLinkExplicit(NewNonceLink, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID))
assert.Equals(t, dir.getLink(RevokeCertLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/revoke-cert", provID)) assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID))
assert.Equals(t, dir.getLink(RevokeCertLink, provID, false), fmt.Sprintf("/%s/revoke-cert", provID)) assert.Equals(t, dir.getLinkExplicit(NewAccountLink, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID))
assert.Equals(t, dir.getLink(KeyChangeLink, provID, true), fmt.Sprintf("https://ca.smallstep.com/acme/%s/key-change", provID)) assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID))
assert.Equals(t, dir.getLink(KeyChangeLink, provID, false), fmt.Sprintf("/%s/key-change", provID)) assert.Equals(t, dir.getLinkExplicit(AccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID))
assert.Equals(t, dir.getLink(ChallengeLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/challenge/1234", provID)) assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID))
assert.Equals(t, dir.getLink(ChallengeLink, provID, false, id), fmt.Sprintf("/%s/challenge/1234", provID)) assert.Equals(t, dir.getLinkExplicit(NewOrderLink, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID))
assert.Equals(t, dir.getLink(CertificateLink, provID, true, id), fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/1234", provID)) assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID))
assert.Equals(t, dir.getLink(CertificateLink, provID, false, id), fmt.Sprintf("/%s/certificate/1234", provID)) assert.Equals(t, dir.getLinkExplicit(OrderLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID))
assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(OrdersByAccountLink, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID))
assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(FinalizeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID))
assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(NewAuthzLink, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID))
assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(AuthzLink, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID))
assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(DirectoryLink, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID))
assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(RevokeCertLink, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID))
assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(KeyChangeLink, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID))
assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/1234", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(ChallengeLink, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/1234", provID))
assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID))
assert.Equals(t, dir.getLinkExplicit(CertificateLink, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID))
} }

View file

@ -332,10 +332,10 @@ func getOrder(db nosql.DB, id string) (*order, error) {
// toACME converts the internal Order type into the public acmeOrder type for // toACME converts the internal Order type into the public acmeOrder type for
// presentation in the ACME protocol. // presentation in the ACME protocol.
func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*Order, error) { func (o *order) toACME(ctx context.Context, db nosql.DB, dir *directory) (*Order, error) {
azs := make([]string, len(o.Authorizations)) azs := make([]string, len(o.Authorizations))
for i, aid := range o.Authorizations { for i, aid := range o.Authorizations {
azs[i] = dir.getLink(AuthzLink, URLSafeProvisionerName(p), true, aid) azs[i] = dir.getLink(ctx, AuthzLink, true, aid)
} }
ao := &Order{ ao := &Order{
Status: o.Status, Status: o.Status,
@ -344,12 +344,12 @@ func (o *order) toACME(db nosql.DB, dir *directory, p provisioner.Interface) (*O
NotBefore: o.NotBefore.Format(time.RFC3339), NotBefore: o.NotBefore.Format(time.RFC3339),
NotAfter: o.NotAfter.Format(time.RFC3339), NotAfter: o.NotAfter.Format(time.RFC3339),
Authorizations: azs, Authorizations: azs,
Finalize: dir.getLink(FinalizeLink, URLSafeProvisionerName(p), true, o.ID), Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID),
ID: o.ID, ID: o.ID,
} }
if o.Certificate != "" { if o.Certificate != "" {
ao.Certificate = dir.getLink(CertificateLink, URLSafeProvisionerName(p), true, o.Certificate) ao.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate)
} }
return ao, nil return ao, nil
} }

View file

@ -6,6 +6,7 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url"
"testing" "testing"
"time" "time"
@ -150,6 +151,10 @@ func TestGetOrder(t *testing.T) {
func TestOrderToACME(t *testing.T) { func TestOrderToACME(t *testing.T) {
dir := newDirectory("ca.smallstep.com", "acme") dir := newDirectory("ca.smallstep.com", "acme")
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
type test struct { type test struct {
o *order o *order
@ -172,7 +177,7 @@ func TestOrderToACME(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
acmeOrder, err := tc.o.toACME(nil, dir, prov) acmeOrder, err := tc.o.toACME(ctx, nil, dir)
if err != nil { if err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error) ae, ok := err.(*Error)
@ -186,9 +191,10 @@ func TestOrderToACME(t *testing.T) {
assert.Equals(t, acmeOrder.ID, tc.o.ID) assert.Equals(t, acmeOrder.ID, tc.o.ID)
assert.Equals(t, acmeOrder.Status, tc.o.Status) assert.Equals(t, acmeOrder.Status, tc.o.Status)
assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers) assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers)
assert.Equals(t, acmeOrder.Finalize, fmt.Sprintf("https://ca.smallstep.com/acme/%s/order/%s/finalize", URLSafeProvisionerName(prov), tc.o.ID)) assert.Equals(t, acmeOrder.Finalize,
fmt.Sprintf("%s/acme/%s/order/%s/finalize", baseURL.String(), provName, tc.o.ID))
if tc.o.Certificate != "" { if tc.o.Certificate != "" {
assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("https://ca.smallstep.com/acme/%s/certificate/%s", URLSafeProvisionerName(prov), tc.o.Certificate)) assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, tc.o.Certificate))
} }
expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires) expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires)