[acme db interface] wip

This commit is contained in:
max furman 2021-03-04 23:10:46 -08:00
parent 491c188a5e
commit 80a6640103
21 changed files with 787 additions and 975 deletions

View file

@ -22,7 +22,7 @@ type Account struct {
func (a *Account) ToLog() (interface{}, error) { func (a *Account) ToLog() (interface{}, error) {
b, err := json.Marshal(a) b, err := json.Marshal(a)
if err != nil { if err != nil {
return nil, ErrorWrap(ErrorServerInternalType, err, "error marshaling account for logging") return nil, WrapErrorISE(err, "error marshaling account for logging")
} }
return string(b), nil return string(b), nil
} }
@ -46,7 +46,7 @@ func (a *Account) IsValid() bool {
func KeyToID(jwk *jose.JSONWebKey) (string, error) { func KeyToID(jwk *jose.JSONWebKey) (string, error) {
kid, err := jwk.Thumbprint(crypto.SHA256) kid, err := jwk.Thumbprint(crypto.SHA256)
if err != nil { if err != nil {
return "", ErrorWrap(ErrorServerInternalType, err, "error generating jwk thumbprint") return "", WrapErrorISE(err, "error generating jwk thumbprint")
} }
return base64.RawURLEncoding.EncodeToString(kid), nil return base64.RawURLEncoding.EncodeToString(kid), nil
} }

View file

@ -4,8 +4,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
@ -38,13 +36,7 @@ func (n *NewAccountRequest) Validate() error {
// UpdateAccountRequest represents an update-account request. // UpdateAccountRequest represents an update-account request.
type UpdateAccountRequest struct { type UpdateAccountRequest struct {
Contact []string `json:"contact"` Contact []string `json:"contact"`
Status string `json:"status"` Status acme.Status `json:"status"`
}
// IsDeactivateRequest returns true if the update request is a deactivation
// request, false otherwise.
func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
return u.Status == string(acme.StatusDeactivated)
} }
// Validate validates a update-account request body. // Validate validates a update-account request body.
@ -59,7 +51,7 @@ func (u *UpdateAccountRequest) Validate() error {
} }
return nil return nil
case len(u.Status) > 0: case len(u.Status) > 0:
if u.Status != string(acme.StatusDeactivated) { if u.Status != acme.StatusDeactivated {
return acme.NewError(acme.ErrorMalformedType, "cannot update account "+ return acme.NewError(acme.ErrorMalformedType, "cannot update account "+
"status to %s, only deactivated", u.Status) "status to %s, only deactivated", u.Status)
} }
@ -80,7 +72,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
} }
var nar NewAccountRequest var nar NewAccountRequest
if err := json.Unmarshal(payload.value, &nar); err != nil { if err := json.Unmarshal(payload.value, &nar); err != nil {
api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
@ -90,7 +82,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
} }
httpStatus := http.StatusCreated httpStatus := http.StatusCreated
acc, err := acme.AccountFromContext(r.Context()) acc, err := accountFromContext(r.Context())
if err != nil { if err != nil {
acmeErr, ok := err.(*acme.Error) acmeErr, ok := err.(*acme.Error)
if !ok || acmeErr.Status != http.StatusBadRequest { if !ok || acmeErr.Status != http.StatusBadRequest {
@ -105,18 +97,19 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
"account does not exist")) "account does not exist"))
return return
} }
jwk, err := acme.JwkFromContext(r.Context()) jwk, err := jwkFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if acc, err = h.Auth.NewAccount(r.Context(), &acme.Account{ acc := &acme.Account{
Key: jwk, Key: jwk,
Contact: nar.Contact, Contact: nar.Contact,
Status: acme.StatusValid, Status: acme.StatusValid,
}); err != nil { }
api.WriteError(w, err) if err := h.db.CreateAccount(r.Context(), acc); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating account"))
return return
} }
} else { } else {
@ -124,14 +117,16 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK httpStatus = http.StatusOK
} }
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink, h.linker.LinkAccount(ctx, acc)
true, acc.GetID()))
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType,
true, acc.ID))
api.JSONStatus(w, acc, httpStatus) api.JSONStatus(w, acc, httpStatus)
} }
// GetUpdateAccount is the api for updating an ACME account. // GetUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc, err := acme.AccountFromContext(r.Context()) acc, err := accountFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -147,7 +142,7 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
if !payload.isPostAsGet { if !payload.isPostAsGet {
var uar UpdateAccountRequest var uar UpdateAccountRequest
if err := json.Unmarshal(payload.value, &uar); err != nil { if err := json.Unmarshal(payload.value, &uar); err != nil {
api.WriteError(w, acme.ErrorWrap(acme.ErrorMalformedType, err, api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-account request payload")) "failed to unmarshal new-account request payload"))
return return
} }
@ -159,18 +154,18 @@ func (h *Handler) GetUpdateAccount(w http.ResponseWriter, r *http.Request) {
// If neither the status nor the contacts are being updated then ignore // If neither the status nor the contacts are being updated then ignore
// the updates and return 200. This conforms with the behavior detailed // the updates and return 200. This conforms with the behavior detailed
// in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2). // in the ACME spec (https://tools.ietf.org/html/rfc8555#section-7.3.2).
if uar.IsDeactivateRequest() { acc.Status = uar.Status
acc, err = h.Auth.DeactivateAccount(r.Context(), acc.GetID()) acc.Contact = uar.Contact
} else if len(uar.Contact) > 0 { if err = h.db.UpdateAccount(r.Context(), acc); err != nil {
acc, err = h.Auth.UpdateAccount(r.Context(), acc.GetID(), uar.Contact) api.WriteError(w, acme.WrapErrorISE(err, "error updating account"))
}
if err != nil {
api.WriteError(w, err)
return return
} }
} }
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AccountLink,
true, acc.GetID())) h.linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType,
true, acc.ID))
api.JSON(w, acc) api.JSON(w, acc)
} }
@ -185,6 +180,7 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
// GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account. // GetOrdersByAccount ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
/*
acc, err := acme.AccountFromContext(r.Context()) acc, err := acme.AccountFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
@ -192,7 +188,7 @@ func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
} }
accID := chi.URLParam(r, "accID") accID := chi.URLParam(r, "accID")
if acc.ID != accID { if acc.ID != accID {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account ID does not match url param"))) api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID does not match url param"))
return return
} }
orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID()) orders, err := h.Auth.GetOrdersByAccount(r.Context(), acc.GetID())
@ -202,4 +198,6 @@ func (h *Handler) GetOrdersByAccount(w http.ResponseWriter, r *http.Request) {
} }
api.JSON(w, orders) api.JSON(w, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)
*/
return
} }

View file

@ -1,56 +1,82 @@
package api package api
import ( import (
"context" "crypto/tls"
"crypto/x509" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"net"
"net/http" "net/http"
"time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
) )
func link(url, typ string) string { func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ)
} }
// Clock that returns time in UTC rounded to seconds.
type Clock int
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Round(time.Second)
}
var clock = new(Clock)
type payloadInfo struct { type payloadInfo struct {
value []byte value []byte
isPostAsGet bool isPostAsGet bool
isEmptyJSON bool isEmptyJSON bool
} }
// payloadFromContext searches the context for a payload. Returns the payload // Handler is the ACME API request handler.
// or an error.
func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
val, ok := ctx.Value(acme.PayloadContextKey).(*payloadInfo)
if !ok || val == nil {
return nil, acme.ServerInternalErr(errors.Errorf("payload expected in request context"))
}
return val, nil
}
// New returns a new ACME API router.
func New(acmeAuth acme.Interface) api.RouterHandler {
return &Handler{acmeAuth}
}
// Handler is the ACME request handler.
type Handler struct { type Handler struct {
Auth acme.Interface db acme.DB
backdate provisioner.Duration
ca acme.CertificateAuthority
linker *Linker
}
// HandlerOptions required to create a new ACME API request handler.
type HandlerOptions struct {
Backdate provisioner.Duration
// DB storage backend that impements the acme.DB interface.
DB acme.DB
// DNS the host used to generate accurate ACME links. By default the authority
// will use the Host from the request, so this value will only be used if
// request.Host is empty.
DNS string
// Prefix is a URL path prefix under which the ACME api is served. This
// prefix is required to generate accurate ACME links.
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
// "acme" is the prefix from which the ACME api is accessed.
Prefix string
CA acme.CertificateAuthority
}
// NewHandler returns a new ACME API handler.
func NewHandler(ops HandlerOptions) api.RouterHandler {
return &Handler{
ca: ops.CA,
db: ops.DB,
backdate: ops.Backdate,
linker: NewLinker(ops.DNS, ops.Prefix),
}
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
getLink := h.Auth.GetLinkExplicit getLink := h.linker.GetLinkExplicit
// Standard ACME API // Standard ACME API
r.MethodFunc("GET", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) r.MethodFunc("GET", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
r.MethodFunc("HEAD", getLink(acme.NewNonceLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce)))) r.MethodFunc("HEAD", getLink(NewNonceLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetNonce))))
r.MethodFunc("GET", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("GET", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
r.MethodFunc("HEAD", getLink(acme.DirectoryLink, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory)))) r.MethodFunc("HEAD", getLink(DirectoryLinkType, "{provisionerID}", false, nil), h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.GetDirectory))))
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next)))))))))
@ -59,16 +85,16 @@ func (h *Handler) Route(r api.Router) {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))))))))
} }
r.MethodFunc("POST", getLink(acme.NewAccountLink, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount)) r.MethodFunc("POST", getLink(NewAccountLinkType, "{provisionerID}", false, nil), extractPayloadByJWK(h.NewAccount))
r.MethodFunc("POST", getLink(acme.AccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount)) r.MethodFunc("POST", getLink(AccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.GetUpdateAccount))
r.MethodFunc("POST", getLink(acme.KeyChangeLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented)) r.MethodFunc("POST", getLink(KeyChangeLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.NotImplemented))
r.MethodFunc("POST", getLink(acme.NewOrderLink, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder)) r.MethodFunc("POST", getLink(NewOrderLinkType, "{provisionerID}", false, nil), extractPayloadByKid(h.NewOrder))
r.MethodFunc("POST", getLink(acme.OrderLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) r.MethodFunc("POST", getLink(OrderLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder)))
r.MethodFunc("POST", getLink(acme.OrdersByAccountLink, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount))) r.MethodFunc("POST", getLink(OrdersByAccountLinkType, "{provisionerID}", false, nil, "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccount)))
r.MethodFunc("POST", getLink(acme.FinalizeLink, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) r.MethodFunc("POST", getLink(FinalizeLinkType, "{provisionerID}", false, nil, "{ordID}"), extractPayloadByKid(h.FinalizeOrder))
r.MethodFunc("POST", getLink(acme.AuthzLink, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz))) r.MethodFunc("POST", getLink(AuthzLinkType, "{provisionerID}", false, nil, "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthz)))
r.MethodFunc("POST", getLink(acme.ChallengeLink, "{provisionerID}", false, nil, "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getLink(ChallengeLinkType, "{provisionerID}", false, nil, "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
r.MethodFunc("POST", getLink(acme.CertificateLink, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) r.MethodFunc("POST", getLink(CertificateLinkType, "{provisionerID}", false, nil, "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
} }
// GetNonce just sets the right header since a Nonce is added to each response // GetNonce just sets the right header since a Nonce is added to each response
@ -81,101 +107,165 @@ func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) {
} }
} }
// Directory represents an ACME directory for configuring clients.
type Directory struct {
NewNonce string `json:"newNonce,omitempty"`
NewAccount string `json:"newAccount,omitempty"`
NewOrder string `json:"newOrder,omitempty"`
NewAuthz string `json:"newAuthz,omitempty"`
RevokeCert string `json:"revokeCert,omitempty"`
KeyChange string `json:"keyChange,omitempty"`
}
// ToLog enables response logging for the Directory type.
func (d *Directory) ToLog() (interface{}, error) {
b, err := json.Marshal(d)
if err != nil {
return nil, acme.WrapErrorISE(err, "error marshaling directory for logging")
}
return string(b), nil
}
type directory struct {
prefix, dns string
}
// GetDirectory is the ACME resource for returning a directory configuration // GetDirectory is the ACME resource for returning a directory configuration
// for client configuration. // for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
dir, err := h.Auth.GetDirectory(r.Context()) ctx := r.Context()
if err != nil { api.JSON(w, &Directory{
api.WriteError(w, err) NewNonce: h.linker.GetLink(ctx, NewNonceLinkType, true),
return NewAccount: h.linker.GetLink(ctx, NewAccountLinkType, true),
} NewOrder: h.linker.GetLink(ctx, NewOrderLinkType, true),
api.JSON(w, dir) RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType, true),
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType, true),
})
} }
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, acme.NotImplemented(nil).ToACME()) api.WriteError(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthz ACME api for retrieving an Authz. // GetAuthz ACME api for retrieving an Authz.
func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAuthz(w http.ResponseWriter, r *http.Request) {
acc, err := acme.AccountFromContext(r.Context()) ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
authz, err := h.Auth.GetAuthz(r.Context(), acc.GetID(), chi.URLParam(r, "authzID")) az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
} }
if acc.ID != az.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own authorization '%s'", acc.ID, az.ID))
return
}
if err = az.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating authorization status"))
}
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.AuthzLink, true, authz.GetID())) h.linker.LinkAuthorization(ctx, az)
api.JSON(w, authz)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, true, az.ID))
api.JSON(w, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
acc, err := acme.AccountFromContext(r.Context()) ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
// Just verify that the payload was set, since we're not strictly adhering // Just verify that the payload was set, since we're not strictly adhering
// to ACME V2 spec for reasons specified below. // to ACME V2 spec for reasons specified below.
_, err = payloadFromContext(r.Context()) _, err = payloadFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
// NOTE: We should be checking that the request is either a POST-as-GET, or // NOTE: We should be checking ^^^ that the request is either a POST-as-GET, or
// that the payload is an empty JSON block ({}). However, older ACME clients // that the payload is an empty JSON block ({}). However, older ACME clients
// still send a vestigial body (rather than an empty JSON block) and // still send a vestigial body (rather than an empty JSON block) and
// strict enforcement would render these clients broken. For the time being // strict enforcement would render these clients broken. For the time being
// we'll just ignore the body. // we'll just ignore the body.
var (
ch *acme.Challenge ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), chi.URLParam(r, "authzID"))
chID = chi.URLParam(r, "chID") if err != nil {
) api.WriteError(w, acme.WrapErrorISE(err, "error retrieving challenge"))
ch, err = h.Auth.ValidateChallenge(r.Context(), acc.GetID(), chID, acc.GetKey()) return
}
if acc.ID != ch.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own challenge '%s'", acc.ID, ch.ID))
return
}
client := http.Client{
Timeout: time.Duration(30 * time.Second),
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if err = ch.Validate(ctx, h.db, jwk, acme.ValidateOptions{
HTTPGet: client.Get,
LookupTxt: net.LookupTXT,
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
}); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error validating challenge"))
return
}
w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), acme.AuthzLink, true, ch.GetAuthzID()), "up")) h.linker.LinkChallenge(ctx, ch)
w.Header().Set("Location", h.Auth.GetLink(r.Context(), acme.ChallengeLink, true, ch.GetID()))
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, true, ch.AuthzID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID))
api.JSON(w, ch) api.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
acc, err := acme.AccountFromContext(r.Context()) ctx := r.Context()
acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
certID := chi.URLParam(r, "certID") certID := chi.URLParam(r, "certID")
certBytes, err := h.Auth.GetCertificate(acc.GetID(), certID)
cert, err := h.db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return
}
if cert.AccountID != acc.ID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own certificate '%s'", acc.ID, certID))
return return
} }
block, _ := pem.Decode(certBytes) certBytes, err := cert.ToACME()
if block == nil {
api.WriteError(w, acme.ServerInternalErr(errors.New("failed to decode any certificates from generated certBytes")))
return
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
api.WriteError(w, acme.Wrap(err, "failed to parse generated leaf certificate")) api.WriteError(w, acme.WrapErrorISE(err, "error converting cert to ACME representation"))
return return
} }
api.LogCertificate(w, cert) api.LogCertificate(w, cert.Leaf)
w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8") w.Header().Set("Content-Type", "application/pem-certificate-chain; charset=utf-8")
w.Write(certBytes) w.Write(certBytes)
} }

164
acme/api/linker.go Normal file
View file

@ -0,0 +1,164 @@
package api
import (
"context"
"fmt"
"net/url"
"github.com/smallstep/certificates/acme"
)
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) *Linker {
return &Linker{Prefix: prefix, DNS: dns}
}
// Linker generates ACME links.
type Linker struct {
Prefix string
DNS string
}
// GetLink is a helper for GetLinkExplicit
func (l *Linker) GetLink(ctx context.Context, typ LinkType, abs bool, inputs ...string) string {
var provName string
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
return l.GetLinkExplicit(typ, provName, abs, baseURLFromContext(ctx), inputs...)
}
// GetLinkExplicit returns an absolute or partial path to the given resource and a base
// URL dynamically obtained from the request for which the link is being
// calculated.
func (l *Linker) GetLinkExplicit(typ LinkType, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
var link string
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
link = fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
link = fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
}
if abs {
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
u := url.URL{}
if baseURL != nil {
u = *baseURL
}
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = l.DNS
}
u.Path = l.Prefix + link
return u.String()
}
return link
}
// LinkType captures the link type.
type LinkType int
const (
// NewNonceLinkType new-nonce
NewNonceLinkType LinkType = iota
// NewAccountLinkType new-account
NewAccountLinkType
// AccountLinkType account
AccountLinkType
// OrderLinkType order
OrderLinkType
// NewOrderLinkType new-order
NewOrderLinkType
// OrdersByAccountLinkType list of orders owned by account
OrdersByAccountLinkType
// FinalizeLinkType finalize order
FinalizeLinkType
// NewAuthzLinkType authz
NewAuthzLinkType
// AuthzLinkType new-authz
AuthzLinkType
// ChallengeLinkType challenge
ChallengeLinkType
// CertificateLinkType certificate
CertificateLinkType
// DirectoryLinkType directory
DirectoryLinkType
// RevokeCertLinkType revoke certificate
RevokeCertLinkType
// KeyChangeLinkType key rollover
KeyChangeLinkType
)
func (l LinkType) String() string {
switch l {
case NewNonceLinkType:
return "new-nonce"
case NewAccountLinkType:
return "new-account"
case AccountLinkType:
return "account"
case NewOrderLinkType:
return "new-order"
case OrderLinkType:
return "order"
case NewAuthzLinkType:
return "new-authz"
case AuthzLinkType:
return "authz"
case ChallengeLinkType:
return "challenge"
case CertificateLinkType:
return "certificate"
case DirectoryLinkType:
return "directory"
case RevokeCertLinkType:
return "revoke-cert"
case KeyChangeLinkType:
return "key-change"
default:
return fmt.Sprintf("unexpected LinkType '%d'", int(l))
}
}
// LinkOrder sets the ACME links required by an ACME order.
func (l *Linker) LinkOrder(ctx context.Context, o *acme.Order) {
o.azURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AutohrizationIDs {
o.azURLs[i] = l.GetLink(ctx, AuthzLinkType, true, azID)
}
o.FinalizeURL = l.GetLink(ctx, FinalizeLinkType, true, o.ID)
if o.CertificateID != "" {
o.CertificateURL = l.GetLink(ctx, CertificateLinkType, true, o.CertificateID)
}
}
// LinkAccount sets the ACME links required by an ACME account.
func (l *Linker) LinkAccount(ctx context.Context, acc *acme.Account) {
a.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID)
}
// LinkChallenge sets the ACME links required by an ACME account.
func (l *Linker) LinkChallenge(ctx context.Context, ch *acme.Challenge) {
a.URL = l.GetLink(ctx, ChallengeLinkType, true, ch.AuthzID, ch.ID)
}
// LinkAuthorization sets the ACME links required by an ACME account.
func (l *Linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) {
for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch)
}
}

View file

@ -9,7 +9,6 @@ import (
"strings" "strings"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -54,7 +53,7 @@ func baseURLFromRequest(r *http.Request) *url.URL {
// E.g. https://ca.smallstep.com/ // E.g. https://ca.smallstep.com/
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP { func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), acme.BaseURLContextKey, baseURLFromRequest(r)) ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
@ -62,14 +61,14 @@ func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
// addNonce is a middleware that adds a nonce to the response header. // addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP { func (h *Handler) addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.Auth.NewNonce() nonce, err := h.db.CreateNonce(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
w.Header().Set("Replay-Nonce", nonce) w.Header().Set("Replay-Nonce", string(nonce))
w.Header().Set("Cache-Control", "no-store") w.Header().Set("Cache-Control", "no-store")
logNonce(w, nonce) logNonce(w, string(nonce))
next(w, r) next(w, r)
} }
} }
@ -78,8 +77,8 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// directory index url. // directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP { func (h *Handler) addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.Auth.GetLink(r.Context(), w.Header().Add("Link", link(h.linker.GetLink(r.Context(),
acme.DirectoryLink, true), "index")) DirectoryLinkType, true), "index"))
next(w, r) next(w, r)
} }
} }
@ -90,7 +89,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ct := r.Header.Get("Content-Type") ct := r.Header.Get("Content-Type")
var expected []string var expected []string
if strings.Contains(r.URL.Path, h.Auth.GetLink(r.Context(), acme.CertificateLink, false, "")) { if strings.Contains(r.URL.Path, h.linker.GetLink(r.Context(), CertificateLinkType, false, "")) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
} else { } else {
@ -103,8 +102,8 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
return return
} }
} }
api.WriteError(w, acme.MalformedErr(errors.Errorf( api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
"expected content-type to be in %s, but got %s", expected, ct))) "expected content-type to be in %s, but got %s", expected, ct))
} }
} }
@ -113,15 +112,15 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.ServerInternalErr(errors.Wrap(err, "failed to read request body"))) api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
} }
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
if err != nil { if err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to parse JWS from request body"))) api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "failed to parse JWS from request body"))
return return
} }
ctx := context.WithValue(r.Context(), acme.JwsContextKey, jws) ctx := context.WithValue(r.Context(), jwsContextKey, jws)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
@ -143,17 +142,18 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste> // * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
jws, err := acme.JwsFromContext(r.Context()) ctx := r.Context()
jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if len(jws.Signatures) == 0 { if len(jws.Signatures) == 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body does not contain a signature"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"))
return return
} }
if len(jws.Signatures) > 1 { if len(jws.Signatures) > 1 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("request body contains more than one signature"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"))
return return
} }
@ -164,7 +164,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
len(uh.Algorithm) > 0 || len(uh.Algorithm) > 0 ||
len(uh.Nonce) > 0 || len(uh.Nonce) > 0 ||
len(uh.ExtraHeaders) > 0 { len(uh.ExtraHeaders) > 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("unprotected header must not be used"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"))
return return
} }
hdr := sig.Protected hdr := sig.Protected
@ -174,25 +174,26 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
switch k := hdr.JSONWebKey.Key.(type) { switch k := hdr.JSONWebKey.Key.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if k.Size() < keyutil.MinRSAKeyBytes { if k.Size() < keyutil.MinRSAKeyBytes {
api.WriteError(w, acme.MalformedErr(errors.Errorf("rsa "+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
"keys must be at least %d bits (%d bytes) in size", "rsa keys must be at least %d bits (%d bytes) in size",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes))
return return
} }
default: default:
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws key type and algorithm do not match"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
"jws key type and algorithm do not match"))
return return
} }
} }
case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA: case jose.ES256, jose.ES384, jose.ES512, jose.EdDSA:
// we good // we good
default: default:
api.WriteError(w, acme.MalformedErr(errors.Errorf("unsuitable algorithm: %s", hdr.Algorithm))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "unsuitable algorithm: %s", hdr.Algorithm))
return return
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := h.Auth.UseNonce(hdr.Nonce); err != nil { if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
@ -200,21 +201,22 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
// Check that the JWS url matches the requested url. // Check that the JWS url matches the requested url.
jwsURL, ok := hdr.ExtraHeaders["url"].(string) jwsURL, ok := hdr.ExtraHeaders["url"].(string)
if !ok { if !ok {
api.WriteError(w, acme.MalformedErr(errors.Errorf("jws missing url protected header"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jws missing url protected header"))
return return
} }
reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path} reqURL := &url.URL{Scheme: "https", Host: r.Host, Path: r.URL.Path}
if jwsURL != reqURL.String() { if jwsURL != reqURL.String() {
api.WriteError(w, acme.MalformedErr(errors.Errorf("url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
"url header in JWS (%s) does not match request url (%s)", jwsURL, reqURL))
return return
} }
if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 { if hdr.JSONWebKey != nil && len(hdr.KeyID) > 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk and kid are mutually exclusive"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
api.WriteError(w, acme.MalformedErr(errors.Errorf("either jwk or kid must be defined in jws protected header"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
next(w, r) next(w, r)
@ -227,22 +229,27 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
func (h *Handler) extractJWK(next nextHTTP) nextHTTP { func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := acme.JwsFromContext(r.Context()) jws, err := jwsFromContext(r.Context())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
jwk := jws.Signatures[0].Protected.JSONWebKey jwk := jws.Signatures[0].Protected.JSONWebKey
if jwk == nil { if jwk == nil {
api.WriteError(w, acme.MalformedErr(errors.Errorf("jwk expected in protected header"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"))
return return
} }
if !jwk.Valid() { if !jwk.Valid() {
api.WriteError(w, acme.MalformedErr(errors.Errorf("invalid jwk in protected header"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"))
return return
} }
ctx = context.WithValue(ctx, acme.JwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
acc, err := h.Auth.GetAccountByKey(ctx, jwk) kid, err := acme.KeyToID(jwk)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error getting KeyID from JWK"))
return
}
acc, err := h.db.GetAccountByKeyID(ctx, kid)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
// For NewAccount requests ... // For NewAccount requests ...
@ -252,10 +259,10 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
} }
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
@ -270,20 +277,20 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
name := chi.URLParam(r, "provisionerID") name := chi.URLParam(r, "provisionerID")
provID, err := url.PathUnescape(name) provID, err := url.PathUnescape(name)
if err != nil { if err != nil {
api.WriteError(w, acme.ServerInternalErr(errors.Wrapf(err, "error url unescaping provisioner id '%s'", name))) api.WriteError(w, acme.WrapErrorISE(err, "error url unescaping provisioner id '%s'", name))
return return
} }
p, err := h.Auth.LoadProvisionerByID("acme/" + provID) p, err := h.ca.LoadProvisionerByID("acme/" + provID)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
acmeProv, ok := p.(*provisioner.ACME) acmeProv, ok := p.(*provisioner.ACME)
if !ok { if !ok {
api.WriteError(w, acme.AccountDoesNotExistErr(errors.New("provisioner must be of type ACME"))) api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return return
} }
ctx = context.WithValue(ctx, acme.ProvisionerContextKey, acme.Provisioner(acmeProv)) ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
@ -294,36 +301,37 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := acme.JwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
kidPrefix := h.Auth.GetLink(ctx, acme.AccountLink, true, "") kidPrefix := h.linker.GetLink(ctx, AccountLinkType, true, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
api.WriteError(w, acme.MalformedErr(errors.Errorf("kid does not have "+ api.WriteError(w, acme.NewError(acme.ErrorMalformedType,
"required prefix; expected %s, but got %s", kidPrefix, kid))) "kid does not have required prefix; expected %s, but got %s",
kidPrefix, kid))
return return
} }
accID := strings.TrimPrefix(kid, kidPrefix) accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := h.Auth.GetAccount(r.Context(), accID) acc, err := h.db.GetAccount(ctx, accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
api.WriteError(w, acme.AccountDoesNotExistErr(nil)) api.WriteError(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
return return
case err != nil: case err != nil:
api.WriteError(w, err) api.WriteError(w, err)
return return
default: default:
if !acc.IsValid() { if !acc.IsValid() {
api.WriteError(w, acme.UnauthorizedErr(errors.New("account is not active"))) api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
return return
} }
ctx = context.WithValue(ctx, acme.AccContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, acme.JwkContextKey, acc.Key) ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
return return
} }
@ -334,26 +342,27 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
jws, err := acme.JwsFromContext(r.Context()) ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
jwk, err := acme.JwkFromContext(r.Context()) jwk, err := jwkFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.MalformedErr(errors.New("verifier and signature algorithm do not match"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }
payload, err := jws.Verify(jwk) payload, err := jws.Verify(jwk)
if err != nil { if err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "error verifying jws"))) api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error verifying jws"))
return return
} }
ctx := context.WithValue(r.Context(), acme.PayloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
value: payload, value: payload,
isPostAsGet: string(payload) == "", isPostAsGet: string(payload) == "",
isEmptyJSON: string(payload) == "{}", isEmptyJSON: string(payload) == "{}",
@ -371,9 +380,89 @@ func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP {
return return
} }
if !payload.isPostAsGet { if !payload.isPostAsGet {
api.WriteError(w, acme.MalformedErr(errors.Errorf("expected POST-as-GET"))) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "expected POST-as-GET"))
return return
} }
next(w, r) next(w, r)
} }
} }
// ContextKey is the key type for storing and searching for ACME request
// essentials in the context of a request.
type ContextKey string
const (
// accContextKey account key
accContextKey = ContextKey("acc")
// baseURLContextKey baseURL key
baseURLContextKey = ContextKey("baseURL")
// jwsContextKey jws key
jwsContextKey = ContextKey("jws")
// jwkContextKey jwk key
jwkContextKey = ContextKey("jwk")
// payloadContextKey payload key
payloadContextKey = ContextKey("payload")
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
)
// accountFromContext searches the context for an ACME account. Returns the
// account or an error.
func accountFromContext(ctx context.Context) (*acme.Account, error) {
val, ok := ctx.Value(accContextKey).(*acme.Account)
if !ok || val == nil {
return nil, acme.NewErrorISE("account not in context")
}
return val, nil
}
// baseURLFromContext returns the baseURL if one is stored in the context.
func baseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// jwkFromContext searches the context for a JWK. Returns the JWK or an error.
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
if !ok || val == nil {
return nil, acme.NewErrorISE("jwk expected in request context")
}
return val, nil
}
// jwsFromContext searches the context for a JWS. Returns the JWS or an error.
func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
val, ok := ctx.Value(jwsContextKey).(*jose.JSONWebSignature)
if !ok || val == nil {
return nil, acme.NewErrorISE("jws expected in request context")
}
return val, nil
}
// provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
val := ctx.Value(provisionerContextKey)
if val == nil {
return nil, acme.NewErrorISE("provisioner expected in request context")
}
pval, ok := val.(acme.Provisioner)
if !ok || pval == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return pval, nil
}
// payloadFromContext searches the context for a payload. Returns the payload
// or an error.
func payloadFromContext(ctx context.Context) (*payloadInfo, error) {
val, ok := ctx.Value(payloadContextKey).(*payloadInfo)
if !ok || val == nil {
return nil, acme.NewErrorISE("payload expected in request context")
}
return val, nil
}

View file

@ -1,16 +1,18 @@
package api package api
import ( import (
"context"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"go.step.sm/crypto/randutil"
) )
// NewOrderRequest represents the body for a NewOrder request. // NewOrderRequest represents the body for a NewOrder request.
@ -23,11 +25,11 @@ type NewOrderRequest struct {
// Validate validates a new-order request body. // Validate validates a new-order request body.
func (n *NewOrderRequest) Validate() error { func (n *NewOrderRequest) Validate() error {
if len(n.Identifiers) == 0 { if len(n.Identifiers) == 0 {
return acme.NewError(ErrorMalformedType, "identifiers list cannot be empty") return acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty")
} }
for _, id := range n.Identifiers { for _, id := range n.Identifiers {
if id.Type != "dns" { if id.Type != "dns" {
return acme.NewError(ErrorMalformedType, "identifier type unsupported: %s", id.Type) return acme.NewError(acme.ErrorMalformedType, "identifier type unsupported: %s", id.Type)
} }
} }
return nil return nil
@ -44,22 +46,29 @@ func (f *FinalizeRequest) Validate() error {
var err error var err error
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR)
if err != nil { if err != nil {
return acme.MalformedErr(errors.Wrap(err, "error base64url decoding csr")) return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
} }
f.csr, err = x509.ParseCertificateRequest(csrBytes) f.csr, err = x509.ParseCertificateRequest(csrBytes)
if err != nil { if err != nil {
return acme.MalformedErr(errors.Wrap(err, "unable to parse csr")) return acme.WrapError(acme.ErrorMalformedType, err, "unable to parse csr")
} }
if err = f.csr.CheckSignature(); err != nil { if err = f.csr.CheckSignature(); err != nil {
return acme.MalformedErr(errors.Wrap(err, "csr failed signature check")) return acme.WrapError(acme.ErrorMalformedType, err, "csr failed signature check")
} }
return nil return nil
} }
var defaultOrderExpiry = time.Hour * 24
// NewOrder ACME api for creating a new order. // NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := acme.AccountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -71,8 +80,8 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
} }
var nor NewOrderRequest var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil { if err := json.Unmarshal(payload.value, &nor); err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal new-order request payload"))) "failed to unmarshal new-order request payload"))
return return
} }
if err := nor.Validate(); err != nil { if err := nor.Validate(); err != nil {
@ -80,44 +89,133 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
o, err := h.Auth.NewOrder(ctx, acme.OrderOptions{ // New order.
AccountID: acc.GetID(), o := &acme.Order{Identifiers: nor.Identifiers}
Identifiers: nor.Identifiers,
NotBefore: nor.NotBefore, o.AuthorizationIDs = make([]string, len(o.Identifiers))
NotAfter: nor.NotAfter, for i, identifier := range o.Identifiers {
}) az := &acme.Authorization{
if err != nil { AccountID: acc.ID,
Identifier: identifier,
}
if err := h.newAuthorization(ctx, az); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
o.AuthorizationIDs[i] = az.ID
}
w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) now := clock.Now()
if o.NotBefore.IsZero() {
o.NotBefore = now
}
if o.NotAfter.IsZero() {
o.NotAfter = o.NotBefore.Add(prov.DefaultTLSCertDuration())
}
o.Expires = now.Add(defaultOrderExpiry)
if err := h.db.CreateOrder(ctx, o); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error creating order"))
return
}
h.linker.Link(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID))
api.JSONStatus(w, o, http.StatusCreated) api.JSONStatus(w, o, http.StatusCreated)
} }
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error {
if strings.HasPrefix(az.Identifier.Value, "*.") {
az.Wildcard = true
az.Identifier = acme.Identifier{
Value: strings.TrimPrefix(az.Identifier.Value, "*."),
Type: az.Identifier.Type,
}
}
var (
err error
chTypes = []string{"dns-01"}
)
// HTTP and TLS challenges can only be used for identifiers without wildcards.
if !az.Wildcard {
chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...)
}
az.Token, err = randutil.Alphanumeric(32)
if err != nil {
return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
}
az.Challenges = make([]*acme.Challenge, len(chTypes))
for i, typ := range chTypes {
ch := &acme.Challenge{
AccountID: az.AccountID,
AuthzID: az.ID,
Value: az.Identifier.Value,
Type: typ,
Token: az.Token,
}
if err := h.db.CreateChallenge(ctx, ch); err != nil {
return err
}
az.Challenges[i] = ch
}
if err = h.db.CreateAuthorization(ctx, az); err != nil {
return acme.WrapErrorISE(err, "error creating authorization")
}
return nil
}
// GetOrder ACME api for retrieving an order. // GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := acme.AccountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
oid := chi.URLParam(r, "ordID") prov, err := provisionerFromContext(ctx)
o, err := h.Auth.GetOrder(ctx, acc.GetID(), oid)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
return
}
if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID))
return
}
if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.UpdateStatus(ctx, h.db); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error updating order status"))
return
}
w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.GetID())) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID))
api.JSON(w, o) api.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attemptst to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acc, err := acme.AccountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -129,7 +227,8 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
} }
var fr FinalizeRequest var fr FinalizeRequest
if err := json.Unmarshal(payload.value, &fr); err != nil { if err := json.Unmarshal(payload.value, &fr); err != nil {
api.WriteError(w, acme.MalformedErr(errors.Wrap(err, "failed to unmarshal finalize-order request payload"))) api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err,
"failed to unmarshal finalize-order request payload"))
return return
} }
if err := fr.Validate(); err != nil { if err := fr.Validate(); err != nil {
@ -137,13 +236,28 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
oid := chi.URLParam(r, "ordID") o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
o, err := h.Auth.FinalizeOrder(ctx, acc.GetID(), oid, fr.csr)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, acme.WrapErrorISE(err, "error retrieving order"))
return
}
if acc.ID != o.AccountID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"account '%s' does not own order '%s'", acc.ID, o.ID))
return
}
if prov.GetID() != o.ProvisionerID {
api.WriteError(w, acme.NewError(acme.ErrorUnauthorizedType,
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return
}
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }
w.Header().Set("Location", h.Auth.GetLink(ctx, acme.OrderLink, true, o.ID)) h.linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, true, o.ID))
api.JSON(w, o) api.JSON(w, o)
} }

View file

@ -1,420 +0,0 @@
package acme
import (
"context"
"crypto/tls"
"crypto/x509"
"log"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
)
// Interface is the acme authority interface.
type Interface interface {
GetDirectory(ctx context.Context) (*Directory, error)
NewNonce() (string, error)
UseNonce(string) error
DeactivateAccount(ctx context.Context, accID string) (*Account, error)
GetAccount(ctx context.Context, accID string) (*Account, error)
GetAccountByKey(ctx context.Context, key *jose.JSONWebKey) (*Account, error)
NewAccount(ctx context.Context, acc *Account) (*Account, error)
UpdateAccount(ctx context.Context, acc *Account) (*Account, error)
GetAuthz(ctx context.Context, accID string, authzID string) (*Authorization, error)
ValidateChallenge(ctx context.Context, accID string, chID string, key *jose.JSONWebKey) (*Challenge, error)
FinalizeOrder(ctx context.Context, accID string, orderID string, csr *x509.CertificateRequest) (*Order, error)
GetOrder(ctx context.Context, accID string, orderID string) (*Order, error)
GetOrdersByAccount(ctx context.Context, accID string) ([]string, error)
NewOrder(ctx context.Context, o *Order) (*Order, error)
GetCertificate(string, string) ([]byte, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
GetLink(ctx context.Context, linkType Link, absoluteLink bool, inputs ...string) string
GetLinkExplicit(linkType Link, provName string, absoluteLink bool, baseURL *url.URL, inputs ...string) string
}
// Authority is the layer that handles all ACME interactions.
type Authority struct {
backdate provisioner.Duration
db DB
dir *directory
signAuth SignAuthority
}
// AuthorityOptions required to create a new ACME Authority.
type AuthorityOptions struct {
Backdate provisioner.Duration
// DB storage backend that impements the acme.DB interface.
DB DB
// DNS the host used to generate accurate ACME links. By default the authority
// will use the Host from the request, so this value will only be used if
// request.Host is empty.
DNS string
// Prefix is a URL path prefix under which the ACME api is served. This
// prefix is required to generate accurate ACME links.
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
// "acme" is the prefix from which the ACME api is accessed.
Prefix string
}
// NewAuthority returns a new Authority that implements the ACME interface.
//
// Deprecated: NewAuthority exists for hitorical compatibility and should not
// be used. Use acme.New() instead.
func NewAuthority(db DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) {
return New(signAuth, AuthorityOptions{
DB: db,
DNS: dns,
Prefix: prefix,
})
}
// New returns a new Authority that implements the ACME interface.
func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) {
return &Authority{
backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth,
}, nil
}
// GetLink returns the requested link from the directory.
func (a *Authority) GetLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
return a.dir.getLink(ctx, typ, abs, inputs...)
}
// GetLinkExplicit returns the requested link from the directory.
func (a *Authority) GetLinkExplicit(typ Link, provName string, abs bool, baseURL *url.URL, inputs ...string) string {
return a.dir.getLinkExplicit(typ, provName, abs, baseURL, inputs...)
}
// GetDirectory returns the ACME directory object.
func (a *Authority) GetDirectory(ctx context.Context) (*Directory, error) {
return &Directory{
NewNonce: a.dir.getLink(ctx, NewNonceLink, true),
NewAccount: a.dir.getLink(ctx, NewAccountLink, true),
NewOrder: a.dir.getLink(ctx, NewOrderLink, true),
RevokeCert: a.dir.getLink(ctx, RevokeCertLink, true),
KeyChange: a.dir.getLink(ctx, KeyChangeLink, true),
}, nil
}
// LoadProvisionerByID calls out to the SignAuthority interface to load a
// provisioner by ID.
func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error) {
return a.signAuth.LoadProvisionerByID(id)
}
// NewNonce generates, stores, and returns a new ACME nonce.
func (a *Authority) NewNonce(ctx context.Context) (Nonce, error) {
return a.db.CreateNonce(ctx)
}
// UseNonce consumes the given nonce if it is valid, returns error otherwise.
func (a *Authority) UseNonce(ctx context.Context, nonce string) error {
return a.db.DeleteNonce(ctx, Nonce(nonce))
}
// NewAccount creates, stores, and returns a new ACME account.
func (a *Authority) NewAccount(ctx context.Context, acc *Account) error {
if err := a.db.CreateAccount(ctx, acc); err != nil {
return ErrorISEWrap(err, "error creating account")
}
return nil
}
// UpdateAccount updates an ACME account.
func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account, error) {
/*
acc.Contact = auo.Contact
acc.Status = auo.Status
*/
if err := a.db.UpdateAccount(ctx, acc); err != nil {
return nil, ErrorISEWrap(err, "error updating account")
}
return acc, nil
}
// GetAccount returns an ACME account.
func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error) {
acc, err := a.db.GetAccount(ctx, id)
if err != nil {
return nil, ErrorISEWrap(err, "error retrieving account")
}
return acc, nil
}
// GetAccountByKey returns the ACME associated with the jwk id.
func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) {
kid, err := KeyToID(jwk)
if err != nil {
return nil, err
}
acc, err := a.db.GetAccountByKeyID(ctx, kid)
return acc, err
}
// GetOrder returns an ACME order.
func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order, error) {
prov, err := ProvisionerFromContext(ctx)
if err != nil {
return nil, err
}
o, err := a.db.GetOrder(ctx, orderID)
if err != nil {
return nil, ErrorISEWrap(err, "error retrieving order")
}
if accID != o.AccountID {
log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID)
return nil, NewError(ErrorUnauthorizedType, "account does not own order")
}
if prov.GetID() != o.ProvisionerID {
log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID)
return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order")
}
if err = o.UpdateStatus(ctx, a.db); err != nil {
return nil, ErrorISEWrap(err, "error updating order")
}
return o, nil
}
/*
// GetOrdersByAccount returns the list of order urls owned by the account.
func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
ordersByAccountMux.Lock()
defer ordersByAccountMux.Unlock()
var oiba = orderIDsByAccount{}
oids, err := oiba.unsafeGetOrderIDsByAccount(a.db, id)
if err != nil {
return nil, err
}
var ret = []string{}
for _, oid := range oids {
ret = append(ret, a.dir.getLink(ctx, OrderLink, true, oid))
}
return ret, nil
}
*/
// NewOrder generates, stores, and returns a new ACME order.
func (a *Authority) NewOrder(ctx context.Context, o *Order) error {
if len(o.AccountID) == 0 {
return NewErrorISE("account-id cannot be empty")
}
if len(o.ProvisionerID) == 0 {
return NewErrorISE("provisioner-id cannot be empty")
}
if len(o.Identifiers) == 0 {
return NewErrorISE("identifiers cannot be empty")
}
if o.DefaultDuration == 0 {
return NewErrorISE("default-duration cannot be empty")
}
o.AuthorizationIDs = make([]string, len(o.Identifiers))
for i, identifier := range o.Identifiers {
az := &Authorization{
AccountID: o.AccountID,
Identifier: identifier,
}
if err := a.NewAuthorization(ctx, az); err != nil {
return err
}
o.AuthorizationIDs[i] = az.ID
}
now := clock.Now()
if o.NotBefore.IsZero() {
o.NotBefore = now
}
if o.NotAfter.IsZero() {
o.NotAfter = o.NotBefore.Add(o.DefaultDuration)
}
if err := a.db.CreateOrder(ctx, o); err != nil {
return ErrorISEWrap(err, "error creating order")
}
return nil
/*
o.DefaultDuration = prov.DefaultTLSCertDuration()
o.Backdate = a.backdate.Duration
o.ProvisionerID = prov.GetID()
if err = a.db.CreateOrder(ctx, o); err != nil {
return nil, ErrorWrap(ErrorServerInternalType, err, "error creating order")
}
return o, nil
*/
}
// FinalizeOrder attempts to finalize an order and generate a new certificate.
func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, csr *x509.CertificateRequest) (*Order, error) {
prov, err := ProvisionerFromContext(ctx)
if err != nil {
return nil, err
}
o, err := a.db.GetOrder(ctx, orderID)
if err != nil {
return nil, ErrorISEWrap(err, "error retrieving order")
}
if accID != o.AccountID {
log.Printf("account-id from request ('%s') does not match order account-id ('%s')", accID, o.AccountID)
return nil, NewError(ErrorUnauthorizedType, "account does not own order")
}
if prov.GetID() != o.ProvisionerID {
log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID)
return nil, NewError(ErrorUnauthorizedType, "provisioner does not own order")
}
if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil {
return nil, ErrorISEWrap(err, "error finalizing order")
}
return o, nil
}
// NewAuthorization generates and stores an ACME Authorization type along with
// any associated resources.
func (a *Authority) NewAuthorization(ctx context.Context, az *Authorization) error {
if len(az.AccountID) == 0 {
return NewErrorISE("account-id cannot be empty")
}
if len(az.Identifier.Value) == 0 {
return NewErrorISE("identifier cannot be empty")
}
if strings.HasPrefix(az.Identifier.Value, "*.") {
az.Wildcard = true
az.Identifier = Identifier{
Value: strings.TrimPrefix(az.Identifier.Value, "*."),
Type: az.Identifier.Type,
}
}
var (
err error
chTypes = []string{"dns-01"}
)
// HTTP and TLS challenges can only be used for identifiers without wildcards.
if !az.Wildcard {
chTypes = append(chTypes, []string{"http-01", "tls-alpn-01"}...)
}
az.Token, err = randutil.Alphanumeric(32)
if err != nil {
return ErrorISEWrap(err, "error generating random alphanumeric ID")
}
az.Challenges = make([]*Challenge, len(chTypes))
for i, typ := range chTypes {
ch := &Challenge{
AccountID: az.AccountID,
AuthzID: az.ID,
Value: az.Identifier.Value,
Type: typ,
Token: az.Token,
}
if err := a.NewChallenge(ctx, ch); err != nil {
return err
}
az.Challenges[i] = ch
}
if err = a.db.CreateAuthorization(ctx, az); err != nil {
return ErrorISEWrap(err, "error creating authorization")
}
return nil
}
// GetAuthorization retrieves and attempts to update the status on an ACME authz
// before returning.
func (a *Authority) GetAuthorization(ctx context.Context, accID, authzID string) (*Authorization, error) {
az, err := a.db.GetAuthorization(ctx, authzID)
if err != nil {
return nil, ErrorISEWrap(err, "error retrieving authorization")
}
if accID != az.AccountID {
log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID)
return nil, NewError(ErrorUnauthorizedType, "account does not own order")
}
if err = az.UpdateStatus(ctx, a.db); err != nil {
return nil, ErrorISEWrap(err, "error updating authorization status")
}
return az, nil
}
// NewChallenge generates and stores an ACME challenge and associated resources.
func (a *Authority) NewChallenge(ctx context.Context, ch *Challenge) error {
if len(ch.AccountID) == 0 {
return NewErrorISE("account-id cannot be empty")
}
if len(ch.AuthzID) == 0 {
return NewErrorISE("authz-id cannot be empty")
}
if len(ch.Token) == 0 {
return NewErrorISE("token cannot be empty")
}
if len(ch.Value) == 0 {
return NewErrorISE("value cannot be empty")
}
switch ch.Type {
case "dns-01", "http-01", "tls-alpn-01":
break
default:
return NewErrorISE("unexpected error type '%s'", ch.Type)
}
if err := a.db.CreateChallenge(ctx, ch); err != nil {
return ErrorISEWrap(err, "error creating challenge")
}
return nil
}
// GetValidateChallenge attempts to validate the challenge.
func (a *Authority) GetValidateChallenge(ctx context.Context, accID, chID, azID string, jwk *jose.JSONWebKey) (*Challenge, error) {
ch, err := a.db.GetChallenge(ctx, chID, "todo")
if err != nil {
return nil, ErrorISEWrap(err, "error retrieving challenge")
}
if accID != ch.AccountID {
log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, ch.AccountID)
return nil, NewError(ErrorUnauthorizedType, "account does not own order")
}
client := http.Client{
Timeout: time.Duration(30 * time.Second),
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
if err = ch.Validate(ctx, a.db, jwk, validateOptions{
httpGet: client.Get,
lookupTxt: net.LookupTXT,
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
}); err != nil {
return nil, ErrorISEWrap(err, "error validating challenge")
}
return ch, nil
}
// GetCertificate retrieves the Certificate by ID.
func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) {
cert, err := a.db.GetCertificate(ctx, certID)
if err != nil {
return nil, ErrorISEWrap(err, "error retrieving certificate")
}
if cert.AccountID != accID {
log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID)
return nil, NewError(ErrorUnauthorizedType, "account does not own order")
}
return cert.ToACME(ctx)
}

View file

@ -22,7 +22,7 @@ type Authorization struct {
func (az *Authorization) ToLog() (interface{}, error) { func (az *Authorization) ToLog() (interface{}, error) {
b, err := json.Marshal(az) b, err := json.Marshal(az)
if err != nil { if err != nil {
return nil, ErrorISEWrap(err, "error marshaling authz for logging") return nil, WrapErrorISE(err, "error marshaling authz for logging")
} }
return string(b), nil return string(b), nil
} }
@ -30,11 +30,7 @@ func (az *Authorization) ToLog() (interface{}, error) {
// UpdateStatus updates the ACME Authorization Status if necessary. // UpdateStatus updates the ACME Authorization Status if necessary.
// Changes to the Authorization are saved using the database interface. // Changes to the Authorization are saved using the database interface.
func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error { func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
now := time.Now().UTC() now := clock.Now()
expiry, err := time.Parse(time.RFC3339, az.Expires)
if err != nil {
return ErrorISEWrap(err, "error converting expiry string to time")
}
switch az.Status { switch az.Status {
case StatusInvalid: case StatusInvalid:
@ -43,7 +39,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
return nil return nil
case StatusPending: case StatusPending:
// check expiry // check expiry
if now.After(expiry) { if now.After(az.Expires) {
az.Status = StatusInvalid az.Status = StatusInvalid
break break
} }
@ -61,11 +57,11 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
} }
az.Status = StatusValid az.Status = StatusValid
default: default:
return NewError(ErrorServerInternalType, "unrecognized authorization status: %s", az.Status) return NewErrorISE("unrecognized authorization status: %s", az.Status)
} }
if err = db.UpdateAuthorization(ctx, az); err != nil { if err := db.UpdateAuthorization(ctx, az); err != nil {
return ErrorISEWrap(err, "error updating authorization") return WrapErrorISE(err, "error updating authorization")
} }
return nil return nil
} }

View file

@ -1,7 +1,6 @@
package acme package acme
import ( import (
"context"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
) )
@ -16,7 +15,7 @@ type Certificate struct {
} }
// ToACME encodes the entire X509 chain into a PEM list. // ToACME encodes the entire X509 chain into a PEM list.
func (cert *Certificate) ToACME(ctx context.Context) ([]byte, error) { func (cert *Certificate) ToACME() ([]byte, error) {
var ret []byte var ret []byte
for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
ret = append(ret, pem.EncodeToMemory(&pem.Block{ ret = append(ret, pem.EncodeToMemory(&pem.Block{

View file

@ -38,7 +38,7 @@ type Challenge struct {
func (ch *Challenge) ToLog() (interface{}, error) { func (ch *Challenge) ToLog() (interface{}, error) {
b, err := json.Marshal(ch) b, err := json.Marshal(ch)
if err != nil { if err != nil {
return nil, ErrorISEWrap(err, "error marshaling challenge for logging") return nil, WrapErrorISE(err, "error marshaling challenge for logging")
} }
return string(b), nil return string(b), nil
} }
@ -47,7 +47,7 @@ func (ch *Challenge) ToLog() (interface{}, error) {
// type using the DB interface. // type using the DB interface.
// satisfactorily validated, the 'status' and 'validated' attributes are // satisfactorily validated, the 'status' and 'validated' attributes are
// updated. // updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error {
// If already valid or invalid then return without performing validation. // If already valid or invalid then return without performing validation.
if ch.Status == StatusValid || ch.Status == StatusInvalid { if ch.Status == StatusValid || ch.Status == StatusInvalid {
return nil return nil
@ -60,16 +60,16 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
case "tls-alpn-01": case "tls-alpn-01":
return tlsalpn01Validate(ctx, ch, db, jwk, vo) return tlsalpn01Validate(ctx, ch, db, jwk, vo)
default: default:
return NewError(ErrorServerInternalType, "unexpected challenge type '%s'", ch.Type) return NewErrorISE("unexpected challenge type '%s'", ch.Type)
} }
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error {
url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token) url := fmt.Sprintf("http://%s/.well-known/acme-challenge/%s", ch.Value, ch.Token)
resp, err := vo.httpGet(url) resp, err := vo.HTTPGet(url)
if err != nil { if err != nil {
return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", url)) "error doing http GET for url %s", url))
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
@ -80,7 +80,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return ErrorISEWrap(err, "error reading "+ return WrapErrorISE(err, "error reading "+
"response body for url %s", url) "response body for url %s", url)
} }
keyAuth := strings.Trim(string(body), "\r\n") keyAuth := strings.Trim(string(body), "\r\n")
@ -100,12 +100,12 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
ch.Validated = clock.Now().Format(time.RFC3339) ch.Validated = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil { if err = db.UpdateChallenge(ctx, ch); err != nil {
return ErrorISEWrap(err, "error updating challenge") return WrapErrorISE(err, "error updating challenge")
} }
return nil return nil
} }
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error {
config := &tls.Config{ config := &tls.Config{
NextProtos: []string{"acme-tls/1"}, NextProtos: []string{"acme-tls/1"},
ServerName: ch.Value, ServerName: ch.Value,
@ -114,9 +114,9 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
hostPort := net.JoinHostPort(ch.Value, "443") hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.tlsDial("tcp", hostPort, config) conn, err := vo.TLSDial("tcp", hostPort, config)
if err != nil { if err != nil {
return storeError(ctx, ch, db, ErrorWrap(ErrorConnectionType, err, return storeError(ctx, ch, db, WrapError(ErrorConnectionType, err,
"error doing TLS dial for %s", hostPort)) "error doing TLS dial for %s", hostPort))
} }
defer conn.Close() defer conn.Close()
@ -178,7 +178,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
ch.Validated = clock.Now().Format(time.RFC3339) ch.Validated = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil { if err = db.UpdateChallenge(ctx, ch); err != nil {
return ErrorISEWrap(err, "tlsalpn01ValidateChallenge - error updating challenge") return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge")
} }
return nil return nil
} }
@ -197,16 +197,16 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
} }
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo ValidateOptions) error {
// Normalize domain for wildcard DNS names // Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like // This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com // _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com // Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(ch.Value, "*.") domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) txtRecords, err := vo.LookupTxt("_acme-challenge." + domain)
if err != nil { if err != nil {
return storeError(ctx, ch, db, ErrorWrap(ErrorDNSType, err, return storeError(ctx, ch, db, WrapError(ErrorDNSType, err,
"error looking up TXT records for domain %s", domain)) "error looking up TXT records for domain %s", domain))
} }
@ -234,7 +234,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
ch.Validated = clock.Now().UTC().Format(time.RFC3339) ch.Validated = clock.Now().UTC().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil { if err = db.UpdateChallenge(ctx, ch); err != nil {
return ErrorISEWrap(err, "error updating challenge") return WrapErrorISE(err, "error updating challenge")
} }
return nil return nil
} }
@ -244,7 +244,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) { func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
thumbprint, err := jwk.Thumbprint(crypto.SHA256) thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil { if err != nil {
return "", ErrorISEWrap(err, "error generating JWK thumbprint") return "", WrapErrorISE(err, "error generating JWK thumbprint")
} }
encPrint := base64.RawURLEncoding.EncodeToString(thumbprint) encPrint := base64.RawURLEncoding.EncodeToString(thumbprint)
return fmt.Sprintf("%s.%s", token, encPrint), nil return fmt.Sprintf("%s.%s", token, encPrint), nil
@ -254,7 +254,7 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error { func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error {
ch.Error = err ch.Error = err
if err := db.UpdateChallenge(ctx, ch); err != nil { if err := db.UpdateChallenge(ctx, ch); err != nil {
return ErrorISEWrap(err, "failure saving error to acme challenge") return WrapErrorISE(err, "failure saving error to acme challenge")
} }
return nil return nil
} }
@ -263,8 +263,9 @@ type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error) type lookupTxt func(string) ([]string, error)
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error) type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
type validateOptions struct { // ValidateOptions are ACME challenge validator functions.
httpGet httpGetter type ValidateOptions struct {
lookupTxt lookupTxt HTTPGet httpGetter
tlsDial tlsDialer LookupTxt lookupTxt
TLSDial tlsDialer
} }

View file

@ -3,13 +3,27 @@ package acme
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"net/url"
"time" "time"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
) )
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
}
// Clock that returns time in UTC rounded to seconds.
type Clock int
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Round(time.Second)
}
var clock = new(Clock)
// Provisioner is an interface that implements a subset of the provisioner.Interface -- // Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority. // only those methods required by the ACME api/authority.
type Provisioner interface { type Provisioner interface {
@ -70,89 +84,3 @@ func (m *MockProvisioner) GetID() string {
} }
return m.Mret1.(string) return m.Mret1.(string)
} }
// ContextKey is the key type for storing and searching for ACME request
// essentials in the context of a request.
type ContextKey string
const (
// AccContextKey account key
AccContextKey = ContextKey("acc")
// BaseURLContextKey baseURL key
BaseURLContextKey = ContextKey("baseURL")
// JwsContextKey jws key
JwsContextKey = ContextKey("jws")
// JwkContextKey jwk key
JwkContextKey = ContextKey("jwk")
// PayloadContextKey payload key
PayloadContextKey = ContextKey("payload")
// ProvisionerContextKey provisioner key
ProvisionerContextKey = ContextKey("provisioner")
)
// AccountFromContext searches the context for an ACME account. Returns the
// account or an error.
func AccountFromContext(ctx context.Context) (*Account, error) {
val, ok := ctx.Value(AccContextKey).(*Account)
if !ok || val == nil {
return nil, NewError(ErrorServerInternalType, "account not in context")
}
return val, nil
}
// BaseURLFromContext returns the baseURL if one is stored in the context.
func BaseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(BaseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// JwkFromContext searches the context for a JWK. Returns the JWK or an error.
func JwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(JwkContextKey).(*jose.JSONWebKey)
if !ok || val == nil {
return nil, NewError(ErrorServerInternalType, "jwk expected in request context")
}
return val, nil
}
// JwsFromContext searches the context for a JWS. Returns the JWS or an error.
func JwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
val, ok := ctx.Value(JwsContextKey).(*jose.JSONWebSignature)
if !ok || val == nil {
return nil, NewError(ErrorServerInternalType, "jws expected in request context")
}
return val, nil
}
// ProvisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error.
func ProvisionerFromContext(ctx context.Context) (Provisioner, error) {
val := ctx.Value(ProvisionerContextKey)
if val == nil {
return nil, NewError(ErrorServerInternalType, "provisioner expected in request context")
}
pval, ok := val.(Provisioner)
if !ok || pval == nil {
return nil, NewError(ErrorServerInternalType, "provisioner in context is not an ACME provisioner")
}
return pval, nil
}
// SignAuthority is the interface implemented by a CA authority.
type SignAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
LoadProvisionerByID(string) (provisioner.Interface, error)
}
// Clock that returns time in UTC rounded to seconds.
type Clock int
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Round(time.Second)
}
var clock = new(Clock)

View file

@ -74,7 +74,6 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error)
return &acme.Account{ return &acme.Account{
Status: dbacc.Status, Status: dbacc.Status,
Contact: dbacc.Contact, Contact: dbacc.Contact,
Orders: dir.getLink(ctx, OrdersByAccountLink, true, dbacc.ID),
Key: dbacc.Key, Key: dbacc.Key,
ID: dbacc.ID, ID: dbacc.ID,
}, nil }, nil

View file

@ -16,7 +16,7 @@ var defaultExpiryDuration = time.Hour * 24
type dbAuthz struct { type dbAuthz struct {
ID string `json:"id"` ID string `json:"id"`
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
Identifier *acme.Identifier `json:"identifier"` Identifier acme.Identifier `json:"identifier"`
Status acme.Status `json:"status"` Status acme.Status `json:"status"`
Expires time.Time `json:"expires"` Expires time.Time `json:"expires"`
Challenges []string `json:"challenges"` Challenges []string `json:"challenges"`
@ -66,7 +66,7 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat
Status: dbaz.Status, Status: dbaz.Status,
Challenges: chs, Challenges: chs,
Wildcard: dbaz.Wildcard, Wildcard: dbaz.Wildcard,
Expires: dbaz.Expires.Format(time.RFC3339), Expires: dbaz.Expires,
ID: dbaz.ID, ID: dbaz.ID,
}, nil }, nil
} }

View file

@ -21,7 +21,7 @@ type dbChallenge struct {
Value string `json:"value"` Value string `json:"value"`
Validated string `json:"validated"` Validated string `json:"validated"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Error *AError `json:"error"` Error *acme.Error `json:"error"`
} }
func (dbc *dbChallenge) clone() *dbChallenge { func (dbc *dbChallenge) clone() *dbChallenge {
@ -79,7 +79,6 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall
Type: dbch.Type, Type: dbch.Type,
Status: dbch.Status, Status: dbch.Status,
Token: dbch.Token, Token: dbch.Token,
URL: dir.getLink(ctx, ChallengeLink, true, dbch.ID),
ID: dbch.ID, ID: dbch.ID,
AuthzID: dbch.AuthzID, AuthzID: dbch.AuthzID,
Error: dbch.Error, Error: dbch.Error,

View file

@ -11,8 +11,6 @@ import (
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
var defaultOrderExpiry = time.Hour * 24
// Mutex for locking ordersByAccount index operations. // Mutex for locking ordersByAccount index operations.
var ordersByAccountMux sync.Mutex var ordersByAccountMux sync.Mutex
@ -26,16 +24,16 @@ type dbOrder struct {
Identifiers []acme.Identifier `json:"identifiers"` Identifiers []acme.Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"` NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"`
Error *Error `json:"error,omitempty"` Error *acme.Error `json:"error,omitempty"`
Authorizations []string `json:"authorizations"` Authorizations []string `json:"authorizations"`
Certificate string `json:"certificate,omitempty"` CertificateID string `json:"certificate,omitempty"`
} }
// getDBOrder retrieves and unmarshals an ACME Order type from the database. // getDBOrder retrieves and unmarshals an ACME Order type from the database.
func (db *DB) getDBOrder(id string) (*dbOrder, error) { func (db *DB) getDBOrder(id string) (*dbOrder, error) {
b, err := db.db.Get(orderTable, []byte(id)) b, err := db.db.Get(orderTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, errors.Wrapf(err, "order %s not found", id) return nil, acme.WrapError(acme.ErrorMalformedType, err, "order %s not found", id)
} else if err != nil { } else if err != nil {
return nil, errors.Wrapf(err, "error loading order %s", id) return nil, errors.Wrapf(err, "error loading order %s", id)
} }
@ -49,34 +47,31 @@ func (db *DB) getDBOrder(id string) (*dbOrder, error) {
// GetOrder retrieves an ACME Order from the database. // GetOrder retrieves an ACME Order from the database.
func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) { func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
dbo, err := db.getDBOrder(id) dbo, err := db.getDBOrder(id)
if err != nil {
azs := make([]string, len(dbo.Authorizations)) return nil, err
for i, aid := range dbo.Authorizations {
azs[i] = dir.getLink(ctx, AuthzLink, true, aid)
} }
o := &acme.Order{ o := &acme.Order{
Status: dbo.Status, Status: dbo.Status,
Expires: dbo.Expires.Format(time.RFC3339), Expires: dbo.Expires,
Identifiers: dbo.Identifiers, Identifiers: dbo.Identifiers,
NotBefore: dbo.NotBefore.Format(time.RFC3339), NotBefore: dbo.NotBefore,
NotAfter: dbo.NotAfter.Format(time.RFC3339), NotAfter: dbo.NotAfter,
Authorizations: azs, AuthorizationIDs: dbo.Authorizations,
FinalizeURL: dir.getLink(ctx, FinalizeLink, true, o.ID),
ID: dbo.ID, ID: dbo.ID,
ProvisionerID: dbo.ProvisionerID, ProvisionerID: dbo.ProvisionerID,
CertificateID: dbo.CertificateID,
} }
if dbo.Certificate != "" {
o.Certificate = dir.getLink(ctx, CertificateLink, true, o.Certificate)
}
return o, nil return o, nil
} }
// CreateOrder creates ACME Order resources and saves them to the DB. // CreateOrder creates ACME Order resources and saves them to the DB.
func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error { func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
var err error
o.ID, err = randID() o.ID, err = randID()
if err != nil { if err != nil {
return nil, err return err
} }
now := clock.Now() now := clock.Now()
@ -85,23 +80,23 @@ func (db *DB) CreateOrder(ctx context.Context, o *acme.Order) error {
AccountID: o.AccountID, AccountID: o.AccountID,
ProvisionerID: o.ProvisionerID, ProvisionerID: o.ProvisionerID,
Created: now, Created: now,
Status: StatusPending, Status: acme.StatusPending,
Expires: now.Add(defaultOrderExpiry), Expires: o.Expires,
Identifiers: o.Identifiers, Identifiers: o.Identifiers,
NotBefore: o.NotBefore, NotBefore: o.NotBefore,
NotAfter: o.NotBefore, NotAfter: o.NotBefore,
Authorizations: o.AuthorizationIDs, Authorizations: o.AuthorizationIDs,
} }
if err := db.save(ctx, o.ID, dbo, nil, orderTable); err != nil { if err := db.save(ctx, o.ID, dbo, nil, "order", orderTable); err != nil {
return nil, err return err
} }
var oidHelper = orderIDsByAccount{} var oidHelper = orderIDsByAccount{}
_, err = oidHelper.addOrderID(db, o.AccountID, o.ID) _, err = oidHelper.addOrderID(db, o.AccountID, o.ID)
if err != nil { if err != nil {
return nil, err return err
} }
return o, nil return nil
} }
type orderIDsByAccount struct{} type orderIDsByAccount struct{}
@ -135,11 +130,11 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return []string{}, nil return []string{}, nil
} }
return nil, ServerInternalErr(errors.Wrapf(err, "error loading orderIDs for account %s", accID)) return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID)
} }
var oids []string var oids []string
if err := json.Unmarshal(b, &oids); err != nil { if err := json.Unmarshal(b, &oids); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)) return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID)
} }
// Remove any order that is not in PENDING state and update the stored list // Remove any order that is not in PENDING state and update the stored list
@ -152,21 +147,21 @@ func (oiba orderIDsByAccount) unsafeGetOrderIDsByAccount(db nosql.DB, accID stri
for _, oid := range oids { for _, oid := range oids {
o, err := getOrder(db, oid) o, err := getOrder(db, oid)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s for account %s", oid, accID)) return nil, errors.Wrapf(err, "error loading order %s for account %s", oid, accID)
} }
if o, err = o.UpdateStatus(db); err != nil { if o, err = o.UpdateStatus(db); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error updating order %s for account %s", oid, accID)) return nil, errors.Wrapf(err, "error updating order %s for account %s", oid, accID)
} }
if o.Status == StatusPending { if o.Status == acme.StatusPending {
pendOids = append(pendOids, oid) pendOids = append(pendOids, oid)
} }
} }
// If the number of pending orders is less than the number of orders in the // If the number of pending orders is less than the number of orders in the
// list, then update the pending order list. // list, then update the pending order list.
if len(pendOids) != len(oids) { if len(pendOids) != len(oids) {
if err = orderIDs(pendOiUs).save(db, oids, accID); err != nil { if err = orderIDs(pendOids).save(db, oids, accID); err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+ return nil, errors.Wrapf(err, "error storing orderIDs as part of getOrderIDsByAccount logic: "+
"len(orderIDs) = %d", len(pendOids))) "len(orderIDs) = %d", len(pendOids))
} }
} }
@ -192,7 +187,7 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
} else { } else {
oldb, err = json.Marshal(old) oldb, err = json.Marshal(old)
if err != nil { if err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling old order IDs slice")) return errors.Wrap(err, "error marshaling old order IDs slice")
} }
} }
if len(oids) == 0 { if len(oids) == 0 {
@ -200,13 +195,13 @@ func (oids orderIDs) save(db nosql.DB, old orderIDs, accID string) error {
} else { } else {
newb, err = json.Marshal(oids) newb, err = json.Marshal(oids)
if err != nil { if err != nil {
return ServerInternalErr(errors.Wrap(err, "error marshaling new order IDs slice")) return errors.Wrap(err, "error marshaling new order IDs slice")
} }
} }
_, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb) _, swapped, err := db.CmpAndSwap(ordersByAccountIDTable, []byte(accID), oldb, newb)
switch { switch {
case err != nil: case err != nil:
return ServerInternalErr(errors.Wrapf(err, "error storing order IDs for account %s", accID)) return errors.Wrapf(err, "error storing order IDs for account %s", accID)
case !swapped: case !swapped:
return ServerInternalErr(errors.Errorf("error storing order IDs "+ return ServerInternalErr(errors.Errorf("error storing order IDs "+
"for account %s; order IDs changed since last read", accID)) "for account %s; order IDs changed since last read", accID))

View file

@ -1,148 +0,0 @@
package acme
import (
"context"
"encoding/json"
"fmt"
"net/url"
)
// Directory represents an ACME directory for configuring clients.
type Directory struct {
NewNonce string `json:"newNonce,omitempty"`
NewAccount string `json:"newAccount,omitempty"`
NewOrder string `json:"newOrder,omitempty"`
NewAuthz string `json:"newAuthz,omitempty"`
RevokeCert string `json:"revokeCert,omitempty"`
KeyChange string `json:"keyChange,omitempty"`
}
// ToLog enables response logging for the Directory type.
func (d *Directory) ToLog() (interface{}, error) {
b, err := json.Marshal(d)
if err != nil {
return nil, ErrorISEWrap(err, "error marshaling directory for logging")
}
return string(b), nil
}
type directory struct {
prefix, dns string
}
// newDirectory returns a new Directory type.
func newDirectory(dns, prefix string) *directory {
return &directory{prefix: prefix, dns: dns}
}
// Link captures the link type.
type Link int
const (
// NewNonceLink new-nonce
NewNonceLink Link = iota
// NewAccountLink new-account
NewAccountLink
// AccountLink account
AccountLink
// OrderLink order
OrderLink
// NewOrderLink new-order
NewOrderLink
// OrdersByAccountLink list of orders owned by account
OrdersByAccountLink
// FinalizeLink finalize order
FinalizeLink
// NewAuthzLink authz
NewAuthzLink
// AuthzLink new-authz
AuthzLink
// ChallengeLink challenge
ChallengeLink
// CertificateLink certificate
CertificateLink
// DirectoryLink directory
DirectoryLink
// RevokeCertLink revoke certificate
RevokeCertLink
// KeyChangeLink key rollover
KeyChangeLink
)
func (l Link) String() string {
switch l {
case NewNonceLink:
return "new-nonce"
case NewAccountLink:
return "new-account"
case AccountLink:
return "account"
case NewOrderLink:
return "new-order"
case OrderLink:
return "order"
case NewAuthzLink:
return "new-authz"
case AuthzLink:
return "authz"
case ChallengeLink:
return "challenge"
case CertificateLink:
return "certificate"
case DirectoryLink:
return "directory"
case RevokeCertLink:
return "revoke-cert"
case KeyChangeLink:
return "key-change"
default:
return "unexpected"
}
}
func (d *directory) getLink(ctx context.Context, typ Link, abs bool, inputs ...string) string {
var provName string
if p, err := ProvisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
return d.getLinkExplicit(typ, provName, abs, BaseURLFromContext(ctx), inputs...)
}
// getLinkExplicit returns an absolute or partial path to the given resource and a base
// URL dynamically obtained from the request for which the link is being
// calculated.
func (d *directory) getLinkExplicit(typ Link, provisionerName string, abs bool, baseURL *url.URL, inputs ...string) string {
var link string
switch typ {
case NewNonceLink, NewAccountLink, NewOrderLink, NewAuthzLink, DirectoryLink, KeyChangeLink, RevokeCertLink:
link = fmt.Sprintf("/%s/%s", provisionerName, typ.String())
case AccountLink, OrderLink, AuthzLink, ChallengeLink, CertificateLink:
link = fmt.Sprintf("/%s/%s/%s", provisionerName, typ.String(), inputs[0])
case OrdersByAccountLink:
link = fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLink.String(), inputs[0])
case FinalizeLink:
link = fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLink.String(), inputs[0])
}
if abs {
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
u := url.URL{}
if baseURL != nil {
u = *baseURL
}
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = d.dns
}
u.Path = d.prefix + link
return u.String()
}
return link
}

View file

@ -262,7 +262,7 @@ var (
// Error represents an ACME // Error represents an ACME
type Error struct { type Error struct {
Type string `json:"type"` Type string `json:"type"`
Details string `json:"detail"` Detail string `json:"detail"`
Subproblems []interface{} `json:"subproblems,omitempty"` Subproblems []interface{} `json:"subproblems,omitempty"`
Identifier interface{} `json:"identifier,omitempty"` Identifier interface{} `json:"identifier,omitempty"`
Err error `json:"-"` Err error `json:"-"`
@ -276,7 +276,7 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
meta = errorServerInternalMetadata meta = errorServerInternalMetadata
return &Error{ return &Error{
Type: meta.typ, Type: meta.typ,
Details: meta.details, Detail: meta.details,
Status: meta.status, Status: meta.status,
Err: errors.Errorf("unrecognized problemType %v", pt), Err: errors.Errorf("unrecognized problemType %v", pt),
} }
@ -284,7 +284,7 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
return &Error{ return &Error{
Type: meta.typ, Type: meta.typ,
Details: meta.details, Detail: meta.details,
Status: meta.status, Status: meta.status,
Err: errors.Errorf(msg, args...), Err: errors.Errorf(msg, args...),
} }
@ -295,14 +295,14 @@ func NewErrorISE(msg string, args ...interface{}) *Error {
return NewError(ErrorServerInternalType, msg, args...) return NewError(ErrorServerInternalType, msg, args...)
} }
// ErrorWrap attempts to wrap the internal error. // WrapError attempts to wrap the internal error.
func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Error { func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Error {
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
return nil return nil
case *Error: case *Error:
if e.Err == nil { if e.Err == nil {
e.Err = errors.Errorf(msg+"; "+e.Details, args...) e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
} else { } else {
e.Err = errors.Wrapf(e.Err, msg, args...) e.Err = errors.Wrapf(e.Err, msg, args...)
} }
@ -312,9 +312,9 @@ func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Err
} }
} }
// ErrorISEWrap shortcut to wrap an internal server error type. // WrapErrorISE shortcut to wrap an internal server error type.
func ErrorISEWrap(err error, msg string, args ...interface{}) *Error { func WrapErrorISE(err error, msg string, args ...interface{}) *Error {
return ErrorWrap(ErrorServerInternalType, err, msg, args...) return WrapError(ErrorServerInternalType, err, msg, args...)
} }
// StatusCode returns the status code and implements the StatusCoder interface. // StatusCode returns the status code and implements the StatusCoder interface.
@ -324,13 +324,13 @@ func (e *Error) StatusCode() int {
// Error allows AError to implement the error interface. // Error allows AError to implement the error interface.
func (e *Error) Error() string { func (e *Error) Error() string {
return e.Details return e.Detail
} }
// Cause returns the internal error and implements the Causer interface. // Cause returns the internal error and implements the Causer interface.
func (e *Error) Cause() error { func (e *Error) Cause() error {
if e.Err == nil { if e.Err == nil {
return errors.New(e.Details) return errors.New(e.Detail)
} }
return e.Err return e.Err
} }

View file

@ -1,3 +1,9 @@
package acme package acme
// Nonce represents an ACME nonce type.
type Nonce string type Nonce string
// String implements the ToString interface.
func (n Nonce) String() string {
return string(n)
}

View file

@ -26,10 +26,11 @@ type Order struct {
NotBefore time.Time `json:"notBefore,omitempty"` NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"`
Error interface{} `json:"error,omitempty"` Error interface{} `json:"error,omitempty"`
AuthorizationURLs []string `json:"authorizations"`
AuthorizationIDs []string `json:"-"` AuthorizationIDs []string `json:"-"`
AuthorizationURLs []string `json:"authorizations"`
FinalizeURL string `json:"finalize"` FinalizeURL string `json:"finalize"`
Certificate string `json:"certificate,omitempty"` CertificateID string `json:"-"`
CertificateURL string `json:"certificate,omitempty"`
ID string `json:"-"` ID string `json:"-"`
AccountID string `json:"-"` AccountID string `json:"-"`
ProvisionerID string `json:"-"` ProvisionerID string `json:"-"`
@ -41,7 +42,7 @@ type Order struct {
func (o *Order) ToLog() (interface{}, error) { func (o *Order) ToLog() (interface{}, error) {
b, err := json.Marshal(o) b, err := json.Marshal(o)
if err != nil { if err != nil {
return nil, ErrorISEWrap(err, "error marshaling order for logging") return nil, WrapErrorISE(err, "error marshaling order for logging")
} }
return string(b), nil return string(b), nil
} }
@ -111,7 +112,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
// Finalize signs a certificate if the necessary conditions for Order completion // Finalize signs a certificate if the necessary conditions for Order completion
// have been met. // have been met.
func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth CertificateAuthority, p Provisioner) error {
if err := o.UpdateStatus(ctx, db); err != nil { if err := o.UpdateStatus(ctx, db); err != nil {
return err return err
} }
@ -170,7 +171,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOps, err := p.AuthorizeSign(ctx, "") signOps, err := p.AuthorizeSign(ctx, "")
if err != nil { if err != nil {
return ErrorISEWrap(err, "error retrieving authorization options from ACME provisioner") return WrapErrorISE(err, "error retrieving authorization options from ACME provisioner")
} }
// Template data // Template data
@ -180,7 +181,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data)
if err != nil { if err != nil {
return ErrorISEWrap(err, "error creating template options from ACME provisioner") return WrapErrorISE(err, "error creating template options from ACME provisioner")
} }
signOps = append(signOps, templateOptions) signOps = append(signOps, templateOptions)
@ -190,7 +191,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
NotAfter: provisioner.NewTimeDuration(o.NotAfter), NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...) }, signOps...)
if err != nil { if err != nil {
return ErrorISEWrap(err, "error signing certificate for order %s", o.ID) return WrapErrorISE(err, "error signing certificate for order %s", o.ID)
} }
cert := &Certificate{ cert := &Certificate{
@ -203,7 +204,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques
return err return err
} }
o.Certificate = cert.ID o.CertificateID = cert.ID
o.Status = StatusValid o.Status = StatusValid
return db.UpdateOrder(ctx, o) return db.UpdateOrder(ctx, o)
} }

View file

@ -21,7 +21,7 @@ import (
type ACMEClient struct { type ACMEClient struct {
client *http.Client client *http.Client
dirLoc string dirLoc string
dir *acme.Directory dir *acmeAPI.Directory
acc *acme.Account acc *acme.Account
Key *jose.JSONWebKey Key *jose.JSONWebKey
kid string kid string
@ -53,7 +53,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, readACMEError(resp.Body) return nil, readACMEError(resp.Body)
} }
var dir acme.Directory var dir acmeAPI.Directory
if err := readJSON(resp.Body, &dir); err != nil { if err := readJSON(resp.Body, &dir); err != nil {
return nil, errors.Wrapf(err, "error reading %s", endpoint) return nil, errors.Wrapf(err, "error reading %s", endpoint)
} }
@ -93,7 +93,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC
// GetDirectory makes a directory request to the ACME api and returns an // GetDirectory makes a directory request to the ACME api and returns an
// ACME directory object. // ACME directory object.
func (c *ACMEClient) GetDirectory() (*acme.Directory, error) { func (c *ACMEClient) GetDirectory() (*acmeAPI.Directory, error) {
return c.dir, nil return c.dir, nil
} }
@ -231,7 +231,7 @@ func (c *ACMEClient) ValidateChallenge(url string) error {
} }
// GetAuthz returns the Authz at the given path. // GetAuthz returns the Authz at the given path.
func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) { func (c *ACMEClient) GetAuthz(url string) (*acme.Authorization, error) {
resp, err := c.post(nil, url, withKid(c)) resp, err := c.post(nil, url, withKid(c))
if err != nil { if err != nil {
return nil, err return nil, err
@ -240,7 +240,7 @@ func (c *ACMEClient) GetAuthz(url string) (*acme.Authz, error) {
return nil, readACMEError(resp.Body) return nil, readACMEError(resp.Body)
} }
var az acme.Authz var az acme.Authorization
if err := readJSON(resp.Body, &az); err != nil { if err := readJSON(resp.Body, &az); err != nil {
return nil, errors.Wrapf(err, "error reading %s", url) return nil, errors.Wrapf(err, "error reading %s", url)
} }
@ -342,7 +342,7 @@ func readACMEError(r io.ReadCloser) error {
if err != nil { if err != nil {
return errors.Wrap(err, "error reading from body") return errors.Wrap(err, "error reading from body")
} }
ae := new(acme.AError) ae := new(acme.Error)
err = json.Unmarshal(b, &ae) err = json.Unmarshal(b, &ae)
// If we successfully marshaled to an ACMEError then return the ACMEError. // If we successfully marshaled to an ACMEError then return the ACMEError.
if err != nil || len(ae.Error()) == 0 { if err != nil || len(ae.Error()) == 0 {

View file

@ -11,8 +11,8 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api" acmeAPI "github.com/smallstep/certificates/acme/api"
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
@ -124,11 +124,12 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) {
} }
prefix := "acme" prefix := "acme"
acmeAuth, err := acme.New(auth, acme.AuthorityOptions{ acmeAuth, err := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
Backdate: *config.AuthorityConfig.Backdate, Backdate: *config.AuthorityConfig.Backdate,
DB: auth.GetDatabase().(nosql.DB), DB: acmeNoSQL.New(auth.GetDatabase().(nosql.DB)),
DNS: dns, DNS: dns,
Prefix: prefix, Prefix: prefix,
CA: auth,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error creating ACME authority") return nil, errors.Wrap(err, "error creating ACME authority")