Few ACME fixes ...

- always URL escape linker output
- validateJWS should accept RSAPSS
- GetUpdateAccount -> GetOrUpdateAccount
This commit is contained in:
max furman 2021-04-12 19:06:07 -07:00
parent 2e0e62bc4c
commit 672e3f976e
9 changed files with 107 additions and 107 deletions

View file

@ -126,8 +126,8 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
api.JSONStatus(w, acc, httpStatus) api.JSONStatus(w, acc, httpStatus)
} }
// GetUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {

View file

@ -32,7 +32,7 @@ func newProv() acme.Provisioner {
// Initialize provisioners // Initialize provisioners
p := &provisioner.ACME{ p := &provisioner.ACME{
Type: "ACME", Type: "ACME",
Name: "test@acme-provisioner.com", Name: "test@acme-<test>provisioner.com",
} }
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil { if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err) fmt.Printf("%v", err)
@ -168,11 +168,6 @@ func TestUpdateAccountRequest_Validate(t *testing.T) {
} }
func TestHandler_GetOrdersByAccountID(t *testing.T) { func TestHandler_GetOrdersByAccountID(t *testing.T) {
oids := []string{"foo", "bar"}
oidURLs := []string{
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/foo",
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/bar",
}
accID := "account-id" accID := "account-id"
// Request with chi context // Request with chi context
@ -185,6 +180,12 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
oids := []string{"foo", "bar"}
oidURLs := []string{
fmt.Sprintf("%s/acme/%s/order/foo", baseURL.String(), provName),
fmt.Sprintf("%s/acme/%s/order/bar", baseURL.String(), provName),
}
type test struct { type test struct {
db acme.DB db acme.DB
ctx context.Context ctx context.Context
@ -287,7 +288,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
func TestHandler_NewAccount(t *testing.T) { func TestHandler_NewAccount(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
@ -424,7 +425,7 @@ func TestHandler_NewAccount(t *testing.T) {
Key: jwk, Key: jwk,
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
OrdersURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders", OrdersURL: fmt.Sprintf("%s/acme/%s/account/accountID/orders", baseURL.String(), escProvName),
}, },
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
@ -486,14 +487,14 @@ func TestHandler_NewAccount(t *testing.T) {
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
provName, "accountID")}) escProvName, "accountID")})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
} }
} }
func TestHandler_GetUpdateAccount(t *testing.T) { func TestHandler_GetOrUpdateAccount(t *testing.T) {
accID := "accountID" accID := "accountID"
acc := acme.Account{ acc := acme.Account{
ID: accID, ID: accID,
@ -501,7 +502,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) {
OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
} }
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
@ -662,7 +663,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) {
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetUpdateAccount(w, req) h.GetOrUpdateAccount(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -686,7 +687,7 @@ func TestHandler_GetUpdateAccount(t *testing.T) {
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], assert.Equals(t, res.Header["Location"],
[]string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(), []string{fmt.Sprintf("%s/acme/%s/account/%s", baseURL.String(),
provName, accID)}) escProvName, accID)})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -102,7 +102,7 @@ func (h *Handler) Route(r api.Router) {
} }
r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount))
r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented))
r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder))
r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
@ -125,12 +125,11 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
// Directory represents an ACME directory for configuring clients. // Directory represents an ACME directory for configuring clients.
type Directory struct { type Directory struct {
NewNonce string `json:"newNonce,omitempty"` NewNonce string `json:"newNonce"`
NewAccount string `json:"newAccount,omitempty"` NewAccount string `json:"newAccount"`
NewOrder string `json:"newOrder,omitempty"` NewOrder string `json:"newOrder"`
NewAuthz string `json:"newAuthz,omitempty"` RevokeCert string `json:"revokeCert"`
RevokeCert string `json:"revokeCert,omitempty"` KeyChange string `json:"keyChange"`
KeyChange string `json:"keyChange,omitempty"`
} }
// ToLog enables response logging for the Directory type. // ToLog enables response logging for the Directory type.

View file

@ -44,27 +44,26 @@ func (l *linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...
// URL dynamically obtained from the request for which the link is being // URL dynamically obtained from the request for which the link is being
// calculated. // calculated.
func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string { func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
var link string var u = url.URL{}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil {
u = *baseURL
}
switch typ { switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType: case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
link = fmt.Sprintf("/%s/%s", provisionerName, typ) u.Path = fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType: case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0]) u.Path = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType: case ChallengeLinkType:
link = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1]) u.Path = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType: case OrdersByAccountLinkType:
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0]) u.Path = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType: case FinalizeLinkType:
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0]) u.Path = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
} }
if abs { if abs {
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
u := url.URL{}
if baseURL != nil {
u = *baseURL
}
// If no Scheme is set, then default to https. // If no Scheme is set, then default to https.
if u.Scheme == "" { if u.Scheme == "" {
u.Scheme = "https" u.Scheme = "https"
@ -75,10 +74,10 @@ func (l *linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool,
u.Host = l.dns u.Host = l.dns
} }
u.Path = l.prefix + link u.Path = l.prefix + u.Path
return u.String() return u.String()
} }
return link return u.EscapedPath()
} }
// LinkType captures the link type. // LinkType captures the link type.

View file

@ -51,52 +51,53 @@ func TestLinker_GetLinkExplicit(t *testing.T) {
id := "1234" id := "1234"
prov := newProv() prov := newProv()
provID := url.PathEscape(prov.GetName()) provName := prov.GetName()
escProvName := url.PathEscape(provName)
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, nil), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", provID)) assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{}), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", provID)) assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, &url.URL{Scheme: "http"}), fmt.Sprintf("%s/acme/%s/new-nonce", "http://ca.smallstep.com", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-nonce", provID)) assert.Equals(t, linker.GetLinkExplicit(NewNonceLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-nonce", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-account", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-account", provID)) assert.Equals(t, linker.GetLinkExplicit(NewAccountLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-account", escProvName))
assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234", provID)) assert.Equals(t, linker.GetLinkExplicit(AccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-order", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-order", provID)) assert.Equals(t, linker.GetLinkExplicit(NewOrderLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-order", escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234", provID)) assert.Equals(t, linker.GetLinkExplicit(OrderLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234", escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/account/1234/orders", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", provID)) assert.Equals(t, linker.GetLinkExplicit(OrdersByAccountLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/account/1234/orders", escProvName))
assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/order/1234/finalize", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", provID)) assert.Equals(t, linker.GetLinkExplicit(FinalizeLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/order/1234/finalize", escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/new-authz", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provID, false, baseURL), fmt.Sprintf("/%s/new-authz", provID)) assert.Equals(t, linker.GetLinkExplicit(NewAuthzLinkType, provName, false, baseURL), fmt.Sprintf("/%s/new-authz", escProvName))
assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/authz/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", provID)) assert.Equals(t, linker.GetLinkExplicit(AuthzLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/authz/1234", escProvName))
assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/directory", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provID, false, baseURL), fmt.Sprintf("/%s/directory", provID)) assert.Equals(t, linker.GetLinkExplicit(DirectoryLinkType, provName, false, baseURL), fmt.Sprintf("/%s/directory", escProvName))
assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provID, false, baseURL), fmt.Sprintf("/%s/revoke-cert", provID)) assert.Equals(t, linker.GetLinkExplicit(RevokeCertLinkType, provName, false, baseURL), fmt.Sprintf("/%s/revoke-cert", escProvName))
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provName, false, baseURL), fmt.Sprintf("/%s/key-change", escProvName))
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, escProvName, id, id))
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provName, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", escProvName, id, id))
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, escProvName))
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provName, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", escProvName))
} }
func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) {

View file

@ -90,7 +90,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ct := r.Header.Get("Content-Type") ct := r.Header.Get("Content-Type")
var expected []string var expected []string
if strings.Contains(r.URL.Path, h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) { if strings.Contains(r.URL.String(), h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
} else { } else {
@ -170,7 +170,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
} }
hdr := sig.Protected hdr := sig.Protected
switch hdr.Algorithm { switch hdr.Algorithm {
case jose.RS256, jose.RS384, jose.RS512: case jose.RS256, jose.RS384, jose.RS512, jose.PS256, jose.PS384, jose.PS512:
if hdr.JSONWebKey != nil { if hdr.JSONWebKey != nil {
switch k := hdr.JSONWebKey.Key.(type) { switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
@ -189,7 +189,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good // we good
default: default:
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", hdr.Algorithm)) api.WriteError(w, acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", hdr.Algorithm))
return return
} }

View file

@ -228,9 +228,9 @@ func TestHandler_addDirLink(t *testing.T) {
func TestHandler_verifyContentType(t *testing.T) { func TestHandler_verifyContentType(t *testing.T) {
prov := newProv() prov := newProv()
provName := prov.GetName() escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), provName) url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct { type test struct {
h Handler h Handler
ctx context.Context ctx context.Context
@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) {
h: Handler{ h: Handler{
linker: NewLinker("dns", "acme"), linker: NewLinker("dns", "acme"),
}, },
url: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), url: url,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
@ -1160,7 +1160,7 @@ func TestHandler_validateJWS(t *testing.T) {
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: none"), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
} }
}, },
"fail/unsuitable-algorithm-mac": func(t *testing.T) test { "fail/unsuitable-algorithm-mac": func(t *testing.T) test {
@ -1172,7 +1172,7 @@ func TestHandler_validateJWS(t *testing.T) {
return test{ return test{
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", jose.HS256), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
} }
}, },
"fail/rsa-key-&-alg-mismatch": func(t *testing.T) test { "fail/rsa-key-&-alg-mismatch": func(t *testing.T) test {

View file

@ -149,6 +149,10 @@ func TestFinalizeRequestValidate(t *testing.T) {
} }
func TestHandler_GetOrder(t *testing.T) { func TestHandler_GetOrder(t *testing.T) {
prov := newProv()
escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
now := clock.Now() now := clock.Now()
nbf := now nbf := now
naf := now.Add(24 * time.Hour) naf := now.Add(24 * time.Hour)
@ -171,21 +175,18 @@ func TestHandler_GetOrder(t *testing.T) {
Status: acme.StatusInvalid, Status: acme.StatusInvalid,
Error: acme.NewError(acme.ErrorMalformedType, "order has expired"), Error: acme.NewError(acme.ErrorMalformedType, "order has expired"),
AuthorizationURLs: []string{ AuthorizationURLs: []string{
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName),
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName),
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName),
}, },
FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName),
} }
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/%s", url := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), provName, o.ID) baseURL.String(), escProvName, o.ID)
type test struct { type test struct {
db acme.DB db acme.DB
@ -285,7 +286,7 @@ func TestHandler_GetOrder(t *testing.T) {
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
return &acme.Order{ return &acme.Order{
AccountID: "accountID", AccountID: "accountID",
ProvisionerID: "acme/test@acme-provisioner.com", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
ExpiresAt: clock.Now().Add(-time.Hour), ExpiresAt: clock.Now().Add(-time.Hour),
Status: acme.StatusReady, Status: acme.StatusReady,
}, nil }, nil
@ -311,7 +312,7 @@ func TestHandler_GetOrder(t *testing.T) {
return &acme.Order{ return &acme.Order{
ID: "orderID", ID: "orderID",
AccountID: "accountID", AccountID: "accountID",
ProvisionerID: "acme/test@acme-provisioner.com", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
ExpiresAt: expiry, ExpiresAt: expiry,
Status: acme.StatusReady, Status: acme.StatusReady,
AuthorizationIDs: []string{"foo", "bar", "baz"}, AuthorizationIDs: []string{"foo", "bar", "baz"},
@ -581,10 +582,10 @@ func TestHandler_newAuthorization(t *testing.T) {
func TestHandler_NewOrder(t *testing.T) { func TestHandler_NewOrder(t *testing.T) {
// Request with chi context // Request with chi context
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/ordID", url := fmt.Sprintf("%s/acme/%s/order/ordID",
baseURL.String(), provName) baseURL.String(), escProvName)
type test struct { type test struct {
db acme.DB db acme.DB
@ -877,8 +878,8 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID", fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName),
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az2ID", fmt.Sprintf("%s/acme/%s/authz/az2ID", baseURL.String(), escProvName),
}) })
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
@ -968,7 +969,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
@ -1059,7 +1060,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
@ -1149,7 +1150,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
@ -1240,7 +1241,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers) assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/az1ID"}) assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf)) assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf)) assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf)) assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
@ -1291,6 +1292,10 @@ func TestHandler_NewOrder(t *testing.T) {
} }
func TestHandler_FinalizeOrder(t *testing.T) { func TestHandler_FinalizeOrder(t *testing.T) {
prov := newProv()
escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
now := clock.Now() now := clock.Now()
nbf := now nbf := now
naf := now.Add(24 * time.Hour) naf := now.Add(24 * time.Hour)
@ -1311,22 +1316,19 @@ func TestHandler_FinalizeOrder(t *testing.T) {
ExpiresAt: naf, ExpiresAt: naf,
Status: acme.StatusValid, Status: acme.StatusValid,
AuthorizationURLs: []string{ AuthorizationURLs: []string{
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/foo", fmt.Sprintf("%s/acme/%s/authz/foo", baseURL.String(), escProvName),
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/bar", fmt.Sprintf("%s/acme/%s/authz/bar", baseURL.String(), escProvName),
"https://test.ca.smallstep.com/acme/test@acme-provisioner.com/authz/baz", fmt.Sprintf("%s/acme/%s/authz/baz", baseURL.String(), escProvName),
}, },
FinalizeURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/order/orderID/finalize", FinalizeURL: fmt.Sprintf("%s/acme/%s/order/orderID/finalize", baseURL.String(), escProvName),
CertificateURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/certificate/certID", CertificateURL: fmt.Sprintf("%s/acme/%s/certificate/certID", baseURL.String(), escProvName),
} }
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/%s", url := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), provName, o.ID) baseURL.String(), escProvName, o.ID)
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1488,7 +1490,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
return &acme.Order{ return &acme.Order{
AccountID: "accountID", AccountID: "accountID",
ProvisionerID: "acme/test@acme-provisioner.com", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
ExpiresAt: clock.Now().Add(-time.Hour), ExpiresAt: clock.Now().Add(-time.Hour),
Status: acme.StatusReady, Status: acme.StatusReady,
}, nil }, nil
@ -1515,7 +1517,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
return &acme.Order{ return &acme.Order{
ID: "orderID", ID: "orderID",
AccountID: "accountID", AccountID: "accountID",
ProvisionerID: "acme/test@acme-provisioner.com", ProvisionerID: fmt.Sprintf("acme/%s", prov.GetName()),
ExpiresAt: naf, ExpiresAt: naf,
Status: acme.StatusValid, Status: acme.StatusValid,
AuthorizationIDs: []string{"foo", "bar", "baz"}, AuthorizationIDs: []string{"foo", "bar", "baz"},

View file

@ -35,7 +35,6 @@ func TestNewACMEClient(t *testing.T) {
NewNonce: srv.URL + "/foo", NewNonce: srv.URL + "/foo",
NewAccount: srv.URL + "/bar", NewAccount: srv.URL + "/bar",
NewOrder: srv.URL + "/baz", NewOrder: srv.URL + "/baz",
NewAuthz: srv.URL + "/zap",
RevokeCert: srv.URL + "/zip", RevokeCert: srv.URL + "/zip",
KeyChange: srv.URL + "/blorp", KeyChange: srv.URL + "/blorp",
} }
@ -146,7 +145,6 @@ func TestACMEClient_GetDirectory(t *testing.T) {
NewNonce: "/foo", NewNonce: "/foo",
NewAccount: "/bar", NewAccount: "/bar",
NewOrder: "/baz", NewOrder: "/baz",
NewAuthz: "/zap",
RevokeCert: "/zip", RevokeCert: "/zip",
KeyChange: "/blorp", KeyChange: "/blorp",
}, },