Merge pull request #539 from smallstep/max/escaped-route-fix

Use different method for unescpaed paths for the router
This commit is contained in:
Max 2021-04-14 15:43:12 -07:00 committed by GitHub
commit 0ec75c98cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 140 deletions

View file

@ -121,8 +121,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID))
true, acc.ID))
api.JSONStatus(w, acc, httpStatus) api.JSONStatus(w, acc, httpStatus)
} }
@ -169,7 +168,7 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAccount(ctx, acc) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, true, acc.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID))
api.JSON(w, acc) api.JSON(w, acc)
} }

View file

@ -87,12 +87,12 @@ func NewHandler(ops HandlerOptions) api.RouterHandler {
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
getLink := h.linker.GetLinkExplicit getPath := h.linker.GetUnescapedPathSuffix
// Standard ACME API // Standard ACME API
r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce))))) r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.GetNonce)))))
r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
@ -101,16 +101,16 @@ func (h *Handler) Route(r api.Router) {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
} }
r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount)) r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented))
r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder))
r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID)))
r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
} }
// GetNonce just sets the right header since a Nonce is added to each response // GetNonce just sets the right header since a Nonce is added to each response
@ -146,11 +146,11 @@ func (d *Directory) ToLog() (interface{}, error) {
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
api.JSON(w, &Directory{ api.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType, true), NewNonce: h.linker.GetLink(ctx, NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType, true), NewAccount: h.linker.GetLink(ctx, NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType, true), NewOrder: h.linker.GetLink(ctx, NewOrderLinkType),
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType, true), RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType),
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType, true), KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType),
}) })
} }
@ -185,7 +185,7 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
h.linker.LinkAuthorization(ctx, az) h.linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, true, az.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID))
api.JSON(w, az) api.JSON(w, az)
} }
@ -235,8 +235,8 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
h.linker.LinkChallenge(ctx, ch, azID) h.linker.LinkChallenge(ctx, ch, azID)
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, azID), "up")) w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID))
api.JSON(w, ch) api.JSON(w, ch)
} }

View file

@ -15,8 +15,8 @@ func NewLinker(dns, prefix string) Linker {
// Linker interface for generating links for ACME resources. // Linker interface for generating links for ACME resources.
type Linker interface { type Linker interface {
GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string GetLink(ctx context.Context, typ LinkType, inputs ...string) string
GetLinkExplicit(typ LinkType, provName string, abs bool, baseURL *url.URL, inputs ...string) string GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
LinkOrder(ctx context.Context, o *acme.Order) LinkOrder(ctx context.Context, o *acme.Order)
LinkAccount(ctx context.Context, o *acme.Account) LinkAccount(ctx context.Context, o *acme.Account)
@ -31,39 +31,40 @@ type linker struct {
dns string dns string
} }
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// GetLink is a helper for GetLinkExplicit // GetLink is a helper for GetLinkExplicit
func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string { func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var provName string var (
provName string
baseURL = baseURLFromContext(ctx)
u = url.URL{}
)
if p, err := provisionerFromContext(ctx); err == nil && p != nil { if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName() provName = p.GetName()
} }
return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...)
}
// GetLinkExplicit returns an absolute or partial path to the given resource and a base
// URL dynamically obtained from the request for which the link is being
// calculated.
func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
var u = url.URL{}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351 // Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil { if baseURL != nil {
u = *baseURL u = *baseURL
} }
switch typ { u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
u.Path = fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
u.Path = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
u.Path = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
u.Path = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
u.Path = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
}
if abs {
// If no Scheme is set, then default to https. // If no Scheme is set, then default to https.
if u.Scheme == "" { if u.Scheme == "" {
u.Scheme = "https" u.Scheme = "https"
@ -76,8 +77,6 @@ func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool,
u.Path = l.prefix + u.Path u.Path = l.prefix + u.Path
return u.String() return u.String()
}
return u.EscapedPath()
} }
// LinkType captures the link type. // LinkType captures the link type.
@ -149,22 +148,22 @@ func (l LinkType) String() string {
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AuthorizationIDs { for i, azID := range o.AuthorizationIDs {
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID) o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
} }
o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID) o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, o.ID)
if o.CertificateID != "" { if o.CertificateID != "" {
o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID) o.CertificateURL = l.GetLink(ctx, CertificateLinkType, o.CertificateID)
} }
} }
// LinkAccount sets the ACME links required by an ACME account. // LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
} }
// LinkChallenge sets the ACME links required by an ACME challenge. // LinkChallenge sets the ACME links required by an ACME challenge.
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, true, azID, ch.ID) ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
} }
// LinkAuthorization sets the ACME links required by an ACME authorization. // LinkAuthorization sets the ACME links required by an ACME authorization.
@ -177,6 +176,6 @@ func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization)
// LinkOrdersByAccountID converts each order ID to an ACME link. // LinkOrdersByAccountID converts each order ID to an ACME link.
func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) { func (l *linker) LinkOrdersByAccountID(ctx context.Context, orders []string) {
for i, id := range orders { for i, id := range orders {
orders[i] = l.GetLink(ctx, OrderLinkType, true, id) orders[i] = l.GetLink(ctx, OrderLinkType, id)
} }
} }

View file

@ -10,6 +10,27 @@ import (
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
) )
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
dns := "ca.smallstep.com"
prefix := "acme"
linker := NewLinker(dns, prefix)
getPath := linker.GetUnescapedPathSuffix
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
assert.Equals(t, getPath(NewAccountLinkType, "{provisionerID}"), "/{provisionerID}/new-account")
assert.Equals(t, getPath(AccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}")
assert.Equals(t, getPath(KeyChangeLinkType, "{provisionerID}"), "/{provisionerID}/key-change")
assert.Equals(t, getPath(NewOrderLinkType, "{provisionerID}"), "/{provisionerID}/new-order")
assert.Equals(t, getPath(OrderLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}")
assert.Equals(t, getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), "/{provisionerID}/account/{accID}/orders")
assert.Equals(t, getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), "/{provisionerID}/order/{ordID}/finalize")
assert.Equals(t, getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), "/{provisionerID}/authz/{authzID}")
assert.Equals(t, getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), "/{provisionerID}/challenge/{authzID}/{chID}")
assert.Equals(t, getPath(CertificateLinkType, "{provisionerID}", "{certID}"), "/{provisionerID}/certificate/{certID}")
}
func TestLinker_GetLink(t *testing.T) { func TestLinker_GetLink(t *testing.T) {
dns := "ca.smallstep.com" dns := "ca.smallstep.com"
prefix := "acme" prefix := "acme"
@ -17,87 +38,47 @@ func TestLinker_GetLink(t *testing.T) {
id := "1234" id := "1234"
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, true), // No provisioner and no BaseURL from request
fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName)) assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName)) // Provisioner: yes, BaseURL: no
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
// No provisioner // Provisioner: no, BaseURL: yes
ctxNoProv := context.WithValue(context.Background(), baseURLContextKey, baseURL) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, true),
fmt.Sprintf("%s/acme//new-nonce", baseURL.String()))
assert.Equals(t, linker.GetLink(ctxNoProv, NewNonceLinkType, false), "//new-nonce")
// No baseURL assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
ctxNoBaseURL := context.WithValue(context.Background(), provisionerContextKey, prov) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, true),
fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provName))
assert.Equals(t, linker.GetLink(ctxNoBaseURL, NewNonceLinkType, false), fmt.Sprintf("/%s/new-nonce", provName))
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, true, id), assert.Equals(t, linker.GetLink(ctx, NewAccountLinkType), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))
fmt.Sprintf("%s/acme/%s/order/1234", baseURL.String(), provName))
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName))
}
func TestLinker_GetLinkExplicit(t *testing.T) { assert.Equals(t, linker.GetLink(ctx, AccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))
dns := "ca.smallstep.com"
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prefix := "acme"
linker := NewLinker(dns, prefix)
id := "1234"
prov := newProv() assert.Equals(t, linker.GetLink(ctx, NewOrderLinkType), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))
provName := prov.GetName()
escProvName := url.PathEscape(provName)
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) assert.Equals(t, linker.GetLink(ctx, OrderLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-nonce", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, OrdersByAccountLinkType, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-account", escProvName))
assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, FinalizeLinkType, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewAuthzLinkType), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-order", escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, AuthzLinkType, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234", escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, DirectoryLinkType), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", escProvName))
assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, RevokeCertLinkType, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, KeyChangeLinkType), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-authz", escProvName))
assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, ChallengeLinkType, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", escProvName))
assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, CertificateLinkType, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, false, baseURL), fmt.Sprintf("/%s/directory", escProvName))
assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, false, baseURL), fmt.Sprintf("/%s/revoke-cert", escProvName))
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, false, baseURL), fmt.Sprintf("/%s/key-change", escProvName))
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", escProvName, id, id))
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", escProvName))
} }
func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) {

View file

@ -78,8 +78,7 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// directory index url. // directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP { func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index"))
DirectoryLinkType, true), "index"))
next(w, r) next(w, r)
} }
} }
@ -88,15 +87,23 @@ func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
// application/jose+json. // application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ct := r.Header.Get("Content-Type")
var expected []string var expected []string
if strings.Contains(r.URL.String(), h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { p, err := provisionerFromContext(r.Context())
if err != nil {
api.WriteError(w, err)
return
}
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")}
if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
} else { } else {
// By default every request should have content-type applictaion/jose+json. // By default every request should have content-type applictaion/jose+json.
expected = []string{"application/jose+json"} expected = []string{"application/jose+json"}
} }
ct := r.Header.Get("Content-Type")
for _, e := range expected { for _, e := range expected {
if ct == e { if ct == e {
next(w, r) next(w, r)
@ -314,7 +321,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
return return
} }
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "") kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, api.WriteError(w, acme.NewError(acme.ErrorMalformedType,

View file

@ -240,6 +240,18 @@ func TestHandler_verifyContentType(t *testing.T) {
url string url string
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/provisioner-not-set": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: url,
ctx: context.Background(),
contentType: "foo",
statusCode: 500,
err: acme.NewErrorISE("provisioner expected in request context"),
}
},
"fail/general-bad-content-type": func(t *testing.T) test { "fail/general-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{ h: Handler{

View file

@ -136,7 +136,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSONStatus(w, o, http.StatusCreated) api.JSONStatus(w, o, http.StatusCreated)
} }
@ -217,7 +217,7 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o) api.JSON(w, o)
} }
@ -272,6 +272,6 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
h.linker.LinkOrder(ctx, o) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID)) w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID))
api.JSON(w, o) api.JSON(w, o)
} }

View file

@ -167,8 +167,8 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
acmeHandler.Route(r) acmeHandler.Route(r)
}) })
/*
// helpful routine for logging all routes // // helpful routine for logging all routes //
/*
walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error { walkFunc := func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error {
fmt.Printf("%s %s\n", method, route) fmt.Printf("%s %s\n", method, route)
return nil return nil