[acme db interface] wip errors

This commit is contained in:
max furman 2021-02-28 22:49:20 -08:00
parent 121cc34cca
commit 2ae43ef2dc
18 changed files with 564 additions and 715 deletions

View file

@ -1,9 +1,10 @@
package acme package acme
import ( import (
"crypto"
"encoding/base64"
"encoding/json" "encoding/json"
"github.com/pkg/errors"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
) )
@ -11,7 +12,7 @@ import (
// attributes required for responses in the ACME protocol. // attributes required for responses in the ACME protocol.
type Account struct { type Account struct {
Contact []string `json:"contact,omitempty"` Contact []string `json:"contact,omitempty"`
Status string `json:"status"` Status Status `json:"status"`
Orders string `json:"orders"` Orders string `json:"orders"`
ID string `json:"-"` ID string `json:"-"`
Key *jose.JSONWebKey `json:"-"` Key *jose.JSONWebKey `json:"-"`
@ -21,7 +22,7 @@ type Account struct {
func (a *Account) ToLog() (interface{}, error) { func (a *Account) ToLog() (interface{}, error) {
b, err := json.Marshal(a) b, err := json.Marshal(a)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling account for logging")) return nil, ErrorWrap(ErrorServerInternalType, err, "error marshaling account for logging")
} }
return string(b), nil return string(b), nil
} }
@ -40,3 +41,12 @@ func (a *Account) GetKey() *jose.JSONWebKey {
func (a *Account) IsValid() bool { func (a *Account) IsValid() bool {
return Status(a.Status) == StatusValid return Status(a.Status) == StatusValid
} }
// KeyToID converts a JWK to a thumbprint.
func KeyToID(jwk *jose.JSONWebKey) (string, error) {
kid, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", ErrorWrap(ErrorServerInternalType, err, "error generating jwk thumbprint")
}
return base64.RawURLEncoding.EncodeToString(kid), nil
}

View file

@ -44,7 +44,7 @@ type UpdateAccountRequest struct {
// IsDeactivateRequest returns true if the update request is a deactivation // IsDeactivateRequest returns true if the update request is a deactivation
// request, false otherwise. // request, false otherwise.
func (u *UpdateAccountRequest) IsDeactivateRequest() bool { func (u *UpdateAccountRequest) IsDeactivateRequest() bool {
return u.Status == acme.StatusDeactivated return u.Status == string(acme.StatusDeactivated)
} }
// Validate validates a update-account request body. // Validate validates a update-account request body.
@ -59,7 +59,7 @@ func (u *UpdateAccountRequest) Validate() error {
} }
return nil return nil
case len(u.Status) > 0: case len(u.Status) > 0:
if u.Status != acme.StatusDeactivated { if u.Status != string(acme.StatusDeactivated) {
return acme.MalformedErr(errors.Errorf("cannot update account "+ return acme.MalformedErr(errors.Errorf("cannot update account "+
"status to %s, only deactivated", u.Status)) "status to %s, only deactivated", u.Status))
} }
@ -110,9 +110,10 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
if acc, err = h.Auth.NewAccount(r.Context(), acme.AccountOptions{ if acc, err = h.Auth.NewAccount(r.Context(), &acme.Account{
Key: jwk, Key: jwk,
Contact: nar.Contact, Contact: nar.Contact,
Status: acme.StatusValid,
}); err != nil { }); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return

View file

@ -2,10 +2,8 @@ package acme
import ( import (
"context" "context"
"crypto"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/base64"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -14,8 +12,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
database "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
) )
@ -49,7 +45,7 @@ type Interface interface {
// Authority is the layer that handles all ACME interactions. // Authority is the layer that handles all ACME interactions.
type Authority struct { type Authority struct {
backdate provisioner.Duration backdate provisioner.Duration
db nosql.DB db DB
dir *directory dir *directory
signAuth SignAuthority signAuth SignAuthority
} }
@ -57,8 +53,8 @@ type Authority struct {
// AuthorityOptions required to create a new ACME Authority. // AuthorityOptions required to create a new ACME Authority.
type AuthorityOptions struct { type AuthorityOptions struct {
Backdate provisioner.Duration Backdate provisioner.Duration
// DB is the database used by nosql. // DB storage backend that impements the acme.DB interface.
DB nosql.DB DB DB
// DNS the host used to generate accurate ACME links. By default the authority // DNS the host used to generate accurate ACME links. By default the authority
// will use the Host from the request, so this value will only be used if // will use the Host from the request, so this value will only be used if
// request.Host is empty. // request.Host is empty.
@ -74,7 +70,7 @@ type AuthorityOptions struct {
// //
// Deprecated: NewAuthority exists for hitorical compatibility and should not // Deprecated: NewAuthority exists for hitorical compatibility and should not
// be used. Use acme.New() instead. // be used. Use acme.New() instead.
func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) { func NewAuthority(db DB, dns, prefix string, signAuth SignAuthority) (*Authority, error) {
return New(signAuth, AuthorityOptions{ return New(signAuth, AuthorityOptions{
DB: db, DB: db,
DNS: dns, DNS: dns,
@ -84,19 +80,6 @@ func NewAuthority(db nosql.DB, dns, prefix string, signAuth SignAuthority) (*Aut
// New returns a new Authority that implements the ACME interface. // New returns a new Authority that implements the ACME interface.
func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) { func New(signAuth SignAuthority, ops AuthorityOptions) (*Authority, error) {
if _, ok := ops.DB.(*database.SimpleDB); !ok {
// If it's not a SimpleDB then go ahead and bootstrap the DB with the
// necessary ACME tables. SimpleDB should ONLY be used for testing.
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
challengeTable, nonceTable, orderTable, ordersByAccountIDTable,
certTable}
for _, b := range tables {
if err := ops.DB.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s",
string(b))
}
}
}
return &Authority{ return &Authority{
backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth, backdate: ops.Backdate, db: ops.DB, dir: newDirectory(ops.DNS, ops.Prefix), signAuth: signAuth,
}, nil }, nil
@ -130,21 +113,21 @@ func (a *Authority) LoadProvisionerByID(id string) (provisioner.Interface, error
} }
// NewNonce generates, stores, and returns a new ACME nonce. // NewNonce generates, stores, and returns a new ACME nonce.
func (a *Authority) NewNonce(ctx context.Context) (string, error) { func (a *Authority) NewNonce(ctx context.Context) (Nonce, error) {
return a.db.CreateNonce(ctx) return a.db.CreateNonce(ctx)
} }
// UseNonce consumes the given nonce if it is valid, returns error otherwise. // UseNonce consumes the given nonce if it is valid, returns error otherwise.
func (a *Authority) UseNonce(ctx context.Context, nonce string) error { func (a *Authority) UseNonce(ctx context.Context, nonce string) error {
return a.db.DeleteNonce(ctx, nonce) return a.db.DeleteNonce(ctx, Nonce(nonce))
} }
// NewAccount creates, stores, and returns a new ACME account. // NewAccount creates, stores, and returns a new ACME account.
func (a *Authority) NewAccount(ctx context.Context, acc *Account) (*Account, error) { func (a *Authority) NewAccount(ctx context.Context, acc *Account) error {
if err := a.db.CreateAccount(ctx, acc); err != nil { if err := a.db.CreateAccount(ctx, acc); err != nil {
return ServerInternalErr(err) return ErrorWrap(ErrorServerInternalType, err, "newAccount: error creating account")
} }
return a, nil return nil
} }
// UpdateAccount updates an ACME account. // UpdateAccount updates an ACME account.
@ -153,8 +136,8 @@ func (a *Authority) UpdateAccount(ctx context.Context, acc *Account) (*Account,
acc.Contact = auo.Contact acc.Contact = auo.Contact
acc.Status = auo.Status acc.Status = auo.Status
*/ */
if err = a.db.UpdateAccount(ctx, acc); err != nil { if err := a.db.UpdateAccount(ctx, acc); err != nil {
return ServerInternalErr(err) return nil, ErrorWrap(ErrorServerInternalType, err, "authority.UpdateAccount - database update failed"
} }
return acc, nil return acc, nil
} }
@ -164,17 +147,9 @@ func (a *Authority) GetAccount(ctx context.Context, id string) (*Account, error)
return a.db.GetAccount(ctx, id) return a.db.GetAccount(ctx, id)
} }
func keyToID(jwk *jose.JSONWebKey) (string, error) {
kid, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", ServerInternalErr(errors.Wrap(err, "error generating jwk thumbprint"))
}
return base64.RawURLEncoding.EncodeToString(kid), nil
}
// GetAccountByKey returns the ACME associated with the jwk id. // GetAccountByKey returns the ACME associated with the jwk id.
func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) { func (a *Authority) GetAccountByKey(ctx context.Context, jwk *jose.JSONWebKey) (*Account, error) {
kid, err := keyToID(jwk) kid, err := KeyToID(jwk)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -200,12 +175,13 @@ func (a *Authority) GetOrder(ctx context.Context, accID, orderID string) (*Order
log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID)
return nil, UnauthorizedErr(errors.New("provisioner does not own order")) return nil, UnauthorizedErr(errors.New("provisioner does not own order"))
} }
if err = a.updateOrderStatus(ctx, o); err != nil { if err = o.UpdateStatus(ctx, a.db); err != nil {
return nil, err return nil, err
} }
return o.toACME(ctx, a.db, a.dir) return o, nil
} }
/*
// GetOrdersByAccount returns the list of order urls owned by the account. // GetOrdersByAccount returns the list of order urls owned by the account.
func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) { func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string, error) {
ordersByAccountMux.Lock() ordersByAccountMux.Lock()
@ -223,6 +199,7 @@ func (a *Authority) GetOrdersByAccount(ctx context.Context, id string) ([]string
} }
return ret, nil return ret, nil
} }
*/
// NewOrder generates, stores, and returns a new ACME order. // NewOrder generates, stores, and returns a new ACME order.
func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) { func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) {
@ -234,7 +211,7 @@ func (a *Authority) NewOrder(ctx context.Context, o *Order) (*Order, error) {
o.Backdate = a.backdate.Duration o.Backdate = a.backdate.Duration
o.ProvisionerID = prov.GetID() o.ProvisionerID = prov.GetID()
if err = db.CreateOrder(ctx, o); err != nil { if err = a.db.CreateOrder(ctx, o); err != nil {
return nil, ServerInternalErr(err) return nil, ServerInternalErr(err)
} }
return o, nil return o, nil
@ -258,8 +235,7 @@ func (a *Authority) FinalizeOrder(ctx context.Context, accID, orderID string, cs
log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID) log.Printf("provisioner-id from request ('%s') does not match order provisioner-id ('%s')", prov.GetID(), o.ProvisionerID)
return nil, UnauthorizedErr(errors.New("provisioner does not own order")) return nil, UnauthorizedErr(errors.New("provisioner does not own order"))
} }
o, err = o.Finalize(ctx, a.db, csr, a.signAuth, prov) if err = o.Finalize(ctx, a.db, csr, a.signAuth, prov); err != nil {
if err != nil {
return nil, Wrap(err, "error finalizing order") return nil, Wrap(err, "error finalizing order")
} }
return o, nil return o, nil
@ -276,8 +252,7 @@ func (a *Authority) GetAuthz(ctx context.Context, accID, authzID string) (*Autho
log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID) log.Printf("account-id from request ('%s') does not match authz account-id ('%s')", accID, az.AccountID)
return nil, UnauthorizedErr(errors.New("account does not own authz")) return nil, UnauthorizedErr(errors.New("account does not own authz"))
} }
az, err = az.UpdateStatus(ctx, a.db) if err = az.UpdateStatus(ctx, a.db); err != nil {
if err != nil {
return nil, Wrap(err, "error updating authz status") return nil, Wrap(err, "error updating authz status")
} }
return az, nil return az, nil
@ -313,7 +288,7 @@ func (a *Authority) ValidateChallenge(ctx context.Context, accID, chID string, j
// GetCertificate retrieves the Certificate by ID. // GetCertificate retrieves the Certificate by ID.
func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) { func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([]byte, error) {
cert, err := a.db.GetCertificate(a.db, certID) cert, err := a.db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -321,5 +296,5 @@ func (a *Authority) GetCertificate(ctx context.Context, accID, certID string) ([
log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID) log.Printf("account-id from request ('%s') does not match challenge account-id ('%s')", accID, cert.AccountID)
return nil, UnauthorizedErr(errors.New("account does not own challenge")) return nil, UnauthorizedErr(errors.New("account does not own challenge"))
} }
return cert.toACME(a.db, a.dir) return cert.ToACME(ctx)
} }

View file

@ -11,7 +11,7 @@ import (
// Authorization representst an ACME Authorization. // Authorization representst an ACME Authorization.
type Authorization struct { type Authorization struct {
Identifier *Identifier `json:"identifier"` Identifier *Identifier `json:"identifier"`
Status string `json:"status"` Status Status `json:"status"`
Expires string `json:"expires"` Expires string `json:"expires"`
Challenges []*Challenge `json:"challenges"` Challenges []*Challenge `json:"challenges"`
Wildcard bool `json:"wildcard"` Wildcard bool `json:"wildcard"`
@ -34,7 +34,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
now := time.Now().UTC() now := time.Now().UTC()
expiry, err := time.Parse(time.RFC3339, az.Expires) expiry, err := time.Parse(time.RFC3339, az.Expires)
if err != nil { if err != nil {
return ServerInternalErr(errors.Wrap("error converting expiry string to time")) return ServerInternalErr(errors.Wrap(err, "error converting expiry string to time"))
} }
switch az.Status { switch az.Status {
@ -46,16 +46,11 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
// check expiry // check expiry
if now.After(expiry) { if now.After(expiry) {
az.Status = StatusInvalid az.Status = StatusInvalid
az.Error = MalformedErr(errors.New("authz has expired"))
break break
} }
var isValid = false var isValid = false
for _, chID := range ba.Challenges { for _, ch := range az.Challenges {
ch, err := db.GetChallenge(ctx, chID, az.ID)
if err != nil {
return ServerInternalErr(err)
}
if ch.Status == StatusValid { if ch.Status == StatusValid {
isValid = true isValid = true
break break
@ -66,10 +61,12 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
return nil return nil
} }
az.Status = StatusValid az.Status = StatusValid
az.Error = nil
default: default:
return nil, ServerInternalErr(errors.Errorf("unrecognized authz status: %s", ba.Status)) return ServerInternalErr(errors.Errorf("unrecognized authorization status: %s", az.Status))
} }
return ServerInternalErr(db.UpdateAuthorization(ctx, az)) if err = db.UpdateAuthorization(ctx, az); err != nil {
return ServerInternalErr(err)
}
return nil
} }

View file

@ -1,10 +1,9 @@
package acme package acme
import ( import (
"context"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"github.com/smallstep/nosql"
) )
// Certificate options with which to create and store a cert object. // Certificate options with which to create and store a cert object.
@ -17,7 +16,7 @@ type Certificate struct {
} }
// ToACME encodes the entire X509 chain into a PEM list. // ToACME encodes the entire X509 chain into a PEM list.
func (cert *Certificate) ToACME(db nosql.DB, dir *directory) ([]byte, error) { func (cert *Certificate) ToACME(ctx context.Context) ([]byte, error) {
var ret []byte var ret []byte
for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) { for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
ret = append(ret, pem.EncodeToMemory(&pem.Block{ ret = append(ret, pem.EncodeToMemory(&pem.Block{

View file

@ -18,14 +18,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/nosql"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
) )
// Challenge represents an ACME response Challenge type. // Challenge represents an ACME response Challenge type.
type Challenge struct { type Challenge struct {
Type string `json:"type"` Type string `json:"type"`
Status string `json:"status"` Status Status `json:"status"`
Token string `json:"token"` Token string `json:"token"`
Validated string `json:"validated,omitempty"` Validated string `json:"validated,omitempty"`
URL string `json:"url"` URL string `json:"url"`
@ -99,7 +98,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
// Update and store the challenge. // Update and store the challenge.
ch.Status = StatusValid ch.Status = StatusValid
ch.Error = nil ch.Error = nil
ch.Validated = clock.Now() ch.Validated = clock.Now().Format(time.RFC3339)
return ServerInternalErr(db.UpdateChallenge(ctx, ch)) return ServerInternalErr(db.UpdateChallenge(ctx, ch))
} }
@ -107,11 +106,11 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error { func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error {
config := &tls.Config{ config := &tls.Config{
NextProtos: []string{"acme-tls/1"}, NextProtos: []string{"acme-tls/1"},
ServerName: tc.Value, ServerName: ch.Value,
InsecureSkipVerify: true, // we expect a self-signed challenge certificate InsecureSkipVerify: true, // we expect a self-signed challenge certificate
} }
hostPort := net.JoinHostPort(tc.Value, "443") hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.tlsDial("tcp", hostPort, config) conn, err := vo.tlsDial("tcp", hostPort, config)
if err != nil { if err != nil {
@ -125,7 +124,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
if len(certs) == 0 { if len(certs) == 0 {
return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("%s "+ return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("%s "+
"challenge for %s resulted in no certificates", tc.Type, tc.Value))) "challenge for %s resulted in no certificates", ch.Type, ch.Value)))
} }
if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" { if !cs.NegotiatedProtocolIsMutual || cs.NegotiatedProtocol != "acme-tls/1" {
@ -135,18 +134,18 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
leafCert := certs[0] leafCert := certs[0]
if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], tc.Value) { if len(leafCert.DNSNames) != 1 || !strings.EqualFold(leafCert.DNSNames[0], ch.Value) {
return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+ return storeError(ctx, ch, db, RejectedIdentifierErr(errors.Errorf("incorrect certificate for tls-alpn-01 challenge: "+
"leaf certificate must contain a single DNS name, %v", tc.Value))) "leaf certificate must contain a single DNS name, %v", ch.Value)))
} }
idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} idPeAcmeIdentifier := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} idPeAcmeIdentifierV1Obsolete := asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
foundIDPeAcmeIdentifierV1Obsolete := false foundIDPeAcmeIdentifierV1Obsolete := false
keyAuth, err := KeyAuthorization(tc.Token, jwk) keyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil { if err != nil {
return nil, err return err
} }
hashedKeyAuth := sha256.Sum256([]byte(keyAuth)) hashedKeyAuth := sha256.Sum256([]byte(keyAuth))
@ -173,9 +172,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
ch.Status = StatusValid ch.Status = StatusValid
ch.Error = nil ch.Error = nil
ch.Validated = clock.Now() ch.Validated = clock.Now().Format(time.RFC3339)
return ServerInternalErr(db.UpdateChallenge(ctx, ch)) if err = db.UpdateChallenge(ctx, ch); err != nil {
return ServerInternalErr(errors.Wrap(err, "tlsalpn01ValidateChallenge - error updating challenge"))
}
return nil
} }
if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) { if idPeAcmeIdentifierV1Obsolete.Equal(ext.Id) {
@ -192,12 +194,12 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))) "certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")))
} }
func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JSONWebKey, vo validateOptions) error { func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo validateOptions) error {
// Normalize domain for wildcard DNS names // Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like // This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com // _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com // Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(dc.Value, "*.") domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.lookupTxt("_acme-challenge." + domain) txtRecords, err := vo.lookupTxt("_acme-challenge." + domain)
if err != nil { if err != nil {
@ -205,9 +207,9 @@ func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JS
"records for domain %s", domain))) "records for domain %s", domain)))
} }
expectedKeyAuth, err := KeyAuthorization(dc.Token, jwk) expectedKeyAuth, err := KeyAuthorization(ch.Token, jwk)
if err != nil { if err != nil {
return nil, err return err
} }
h := sha256.Sum256([]byte(expectedKeyAuth)) h := sha256.Sum256([]byte(expectedKeyAuth))
expected := base64.RawURLEncoding.EncodeToString(h[:]) expected := base64.RawURLEncoding.EncodeToString(h[:])
@ -226,7 +228,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db nosql.DB, jwk *jose.JS
// Update and store the challenge. // Update and store the challenge.
ch.Status = StatusValid ch.Status = StatusValid
ch.Error = nil ch.Error = nil
ch.Validated = time.Now().UTC() ch.Validated = clock.Now().UTC().Format(time.RFC3339)
return ServerInternalErr(db.UpdateChallenge(ctx, ch)) return ServerInternalErr(db.UpdateChallenge(ctx, ch))
} }
@ -243,7 +245,7 @@ func KeyAuthorization(token string, jwk *jose.JSONWebKey) (string, error) {
} }
// storeError the given error to an ACME error and saves using the DB interface. // storeError the given error to an ACME error and saves using the DB interface.
func (bc *baseChallenge) storeError(ctx context.Context, ch Challenge, db nosql.DB, err *Error) error { func storeError(ctx context.Context, ch *Challenge, db DB, err *Error) error {
ch.Error = err.ToACME() ch.Error = err.ToACME()
if err := db.UpdateChallenge(ctx, ch); err != nil { if err := db.UpdateChallenge(ctx, ch); err != nil {
return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge")) return ServerInternalErr(errors.Wrap(err, "failure saving error to acme challenge"))

View file

@ -9,7 +9,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
) )
// Provisioner is an interface that implements a subset of the provisioner.Interface -- // Provisioner is an interface that implements a subset of the provisioner.Interface --
@ -149,38 +148,6 @@ type SignAuthority interface {
LoadProvisionerByID(string) (provisioner.Interface, error) LoadProvisionerByID(string) (provisioner.Interface, error)
} }
// Identifier encodes the type that an order pertains to.
type Identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
var (
// StatusValid -- valid
StatusValid = "valid"
// StatusInvalid -- invalid
StatusInvalid = "invalid"
// StatusPending -- pending; e.g. an Order that is not ready to be finalized.
StatusPending = "pending"
// StatusDeactivated -- deactivated; e.g. for an Account that is not longer valid.
StatusDeactivated = "deactivated"
// StatusReady -- ready; e.g. for an Order that is ready to be finalized.
StatusReady = "ready"
//statusExpired = "expired"
//statusActive = "active"
//statusProcessing = "processing"
)
var idLen = 32
func randID() (val string, err error) {
val, err = randutil.Alphanumeric(idLen)
if err != nil {
return "", ServerInternalErr(errors.Wrap(err, "error generating random alphanumeric ID"))
}
return val, nil
}
// Clock that returns time in UTC rounded to seconds. // Clock that returns time in UTC rounded to seconds.
type Clock int type Clock int

View file

@ -4,7 +4,7 @@ import "context"
// DB is the DB interface expected by the step-ca ACME API. // DB is the DB interface expected by the step-ca ACME API.
type DB interface { type DB interface {
CreateAccount(ctx context.Context, acc *Account) (*Account, error) CreateAccount(ctx context.Context, acc *Account) error
GetAccount(ctx context.Context, id string) (*Account, error) GetAccount(ctx context.Context, id string) (*Account, error)
GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error)
UpdateAccount(ctx context.Context, acc *Account) error UpdateAccount(ctx context.Context, acc *Account) error

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql" nosqlDB "github.com/smallstep/nosql"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
) )
@ -17,7 +18,7 @@ type dbAccount struct {
Deactivated time.Time `json:"deactivated"` Deactivated time.Time `json:"deactivated"`
Key *jose.JSONWebKey `json:"key"` Key *jose.JSONWebKey `json:"key"`
Contact []string `json:"contact,omitempty"` Contact []string `json:"contact,omitempty"`
Status string `json:"status"` Status acme.Status `json:"status"`
} }
func (dba *dbAccount) clone() *dbAccount { func (dba *dbAccount) clone() *dbAccount {
@ -26,33 +27,34 @@ func (dba *dbAccount) clone() *dbAccount {
} }
// CreateAccount imlements the AcmeDB.CreateAccount interface. // CreateAccount imlements the AcmeDB.CreateAccount interface.
func (db *DB) CreateAccount(ctx context.Context, acc *Account) error { func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
var err error
acc.ID, err = randID() acc.ID, err = randID()
if err != nil { if err != nil {
return nil, err return err
} }
dba := &dbAccount{ dba := &dbAccount{
ID: acc.ID, ID: acc.ID,
Key: acc.Key, Key: acc.Key,
Contact: acc.Contact, Contact: acc.Contact,
Status: acc.Valid, Status: acc.Status,
Created: clock.Now(), Created: clock.Now(),
} }
kid, err := keyToID(dba.Key) kid, err := acme.KeyToID(dba.Key)
if err != nil { if err != nil {
return err return err
} }
kidB := []byte(kid) kidB := []byte(kid)
// Set the jwkID -> acme account ID index // Set the jwkID -> acme account ID index
_, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(a.ID)) _, swapped, err := db.db.CmpAndSwap(accountByKeyIDTable, kidB, nil, []byte(acc.ID))
switch { switch {
case err != nil: case err != nil:
return ServerInternalErr(errors.Wrap(err, "error setting key-id to account-id index")) return errors.Wrap(err, "error storing keyID to accountID index")
case !swapped: case !swapped:
return ServerInternalErr(errors.Errorf("key-id to account-id index already exists")) return errors.Errorf("key-id to account-id index already exists")
default: default:
if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil { if err = db.save(ctx, acc.ID, dba, nil, "account", accountTable); err != nil {
db.db.Del(accountByKeyIDTable, kidB) db.db.Del(accountByKeyIDTable, kidB)
@ -63,24 +65,24 @@ func (db *DB) CreateAccount(ctx context.Context, acc *Account) error {
} }
// GetAccount retrieves an ACME account by ID. // GetAccount retrieves an ACME account by ID.
func (db *DB) GetAccount(ctx context.Context, id string) (*Account, error) { func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error) {
acc, err := db.getDBAccount(ctx, id) dbacc, err := db.getDBAccount(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Account{ return &acme.Account{
Status: dbacc.Status, Status: dbacc.Status,
Contact: dbacc.Contact, Contact: dbacc.Contact,
Orders: dir.getLink(ctx, OrdersByAccountLink, true, a.ID), Orders: dir.getLink(ctx, OrdersByAccountLink, true, dbacc.ID),
Key: dbacc.Key, Key: dbacc.Key,
ID: dbacc.ID, ID: dbacc.ID,
}, nil }, nil
} }
// GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK). // GetAccountByKeyID retrieves an ACME account by KeyID (thumbprint of the Account Key -- JWK).
func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, error) { func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*acme.Account, error) {
id, err := db.getAccountIDByKeyID(kid) id, err := db.getAccountIDByKeyID(ctx, kid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,9 +90,9 @@ func (db *DB) GetAccountByKeyID(ctx context.Context, kid string) (*Account, erro
} }
// UpdateAccount imlements the AcmeDB.UpdateAccount interface. // UpdateAccount imlements the AcmeDB.UpdateAccount interface.
func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error { func (db *DB) UpdateAccount(ctx context.Context, acc *acme.Account) error {
if len(acc.ID) == 0 { if len(acc.ID) == 0 {
return ServerInternalErr(errors.New("id cannot be empty")) return errors.New("id cannot be empty")
} }
old, err := db.getDBAccount(ctx, acc.ID) old, err := db.getDBAccount(ctx, acc.ID)
@ -99,24 +101,24 @@ func (db *DB) UpdateAccount(ctx context.Context, acc *Account) error {
} }
nu := old.clone() nu := old.clone()
nu.Contact = acc.contact nu.Contact = acc.Contact
nu.Status = acc.Status nu.Status = acc.Status
// If the status has changed to 'deactivated', then set deactivatedAt timestamp. // If the status has changed to 'deactivated', then set deactivatedAt timestamp.
if acc.Status == StatusDeactivated && old.Status != Status.Deactivated { if acc.Status == acme.StatusDeactivated && old.Status != acme.StatusDeactivated {
nu.Deactivated = clock.Now() nu.Deactivated = clock.Now()
} }
return db.save(ctx, old.ID, newdba, dba, "account", accountTable) return db.save(ctx, old.ID, nu, old, "account", accountTable)
} }
func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) { func (db *DB) getAccountIDByKeyID(ctx context.Context, kid string) (string, error) {
id, err := db.db.Get(accountByKeyIDTable, []byte(kid)) id, err := db.db.Get(accountByKeyIDTable, []byte(kid))
if err != nil { if err != nil {
if nosqlDB.IsErrNotFound(err) { if nosqlDB.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "account with key id %s not found", kid)) return "", errors.Wrapf(err, "account with key id %s not found", kid)
} }
return nil, ServerInternalErr(errors.Wrapf(err, "error loading key-account index")) return "", errors.Wrapf(err, "error loading key-account index")
} }
return string(id), nil return string(id), nil
} }
@ -126,14 +128,14 @@ func (db *DB) getDBAccount(ctx context.Context, id string) (*dbAccount, error) {
data, err := db.db.Get(accountTable, []byte(id)) data, err := db.db.Get(accountTable, []byte(id))
if err != nil { if err != nil {
if nosqlDB.IsErrNotFound(err) { if nosqlDB.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "account %s not found", id)) return nil, errors.Wrapf(err, "account %s not found", id)
} }
return nil, ServerInternalErr(errors.Wrapf(err, "error loading account %s", id)) return nil, errors.Wrapf(err, "error loading account %s", id)
} }
dbacc := new(account) dbacc := new(dbAccount)
if err = json.Unmarshal(data, dbacc); err != nil { if err = json.Unmarshal(data, dbacc); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling account")) return nil, errors.Wrap(err, "error unmarshaling account")
} }
return dbacc, nil return dbacc, nil
} }

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
@ -16,13 +17,13 @@ var defaultExpiryDuration = time.Hour * 24
type dbAuthz struct { type dbAuthz struct {
ID string `json:"id"` ID string `json:"id"`
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
Identifier *Identifier `json:"identifier"` Identifier *acme.Identifier `json:"identifier"`
Status string `json:"status"` Status acme.Status `json:"status"`
Expires time.Time `json:"expires"` Expires time.Time `json:"expires"`
Challenges []string `json:"challenges"` Challenges []string `json:"challenges"`
Wildcard bool `json:"wildcard"` Wildcard bool `json:"wildcard"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Error *Error `json:"error"` Error *acme.Error `json:"error"`
} }
func (ba *dbAuthz) clone() *dbAuthz { func (ba *dbAuthz) clone() *dbAuthz {
@ -35,33 +36,33 @@ func (ba *dbAuthz) clone() *dbAuthz {
func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) { func (db *DB) getDBAuthz(ctx context.Context, id string) (*dbAuthz, error) {
data, err := db.db.Get(authzTable, []byte(id)) data, err := db.db.Get(authzTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "authz %s not found", id)) return nil, errors.Wrapf(err, "authz %s not found", id)
} else if err != nil { } else if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading authz %s", id)) return nil, errors.Wrapf(err, "error loading authz %s", id)
} }
var dbaz dbAuthz var dbaz dbAuthz
if err = json.Unmarshal(data, &dbaz); err != nil { if err = json.Unmarshal(data, &dbaz); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling authz type into dbAuthz")) return nil, errors.Wrap(err, "error unmarshaling authz type into dbAuthz")
} }
return &dbaz return &dbaz, nil
} }
// GetAuthorization retrieves and unmarshals an ACME authz type from the database. // GetAuthorization retrieves and unmarshals an ACME authz type from the database.
// Implements acme.DB GetAuthorization interface. // Implements acme.DB GetAuthorization interface.
func (db *DB) GetAuthorization(ctx context.Context, id string) (*types.Authorization, error) { func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorization, error) {
dbaz, err := getDBAuthz(id) dbaz, err := db.getDBAuthz(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var chs = make([]*Challenge, len(ba.Challenges)) var chs = make([]*acme.Challenge, len(dbaz.Challenges))
for i, chID := range dbaz.Challenges { for i, chID := range dbaz.Challenges {
chs[i], err = db.GetChallenge(ctx, chID) chs[i], err = db.GetChallenge(ctx, chID, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
return &types.Authorization{ return &acme.Authorization{
Identifier: dbaz.Identifier, Identifier: dbaz.Identifier,
Status: dbaz.Status, Status: dbaz.Status,
Challenges: chs, Challenges: chs,
@ -73,23 +74,24 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*types.Authoriza
// CreateAuthorization creates an entry in the database for the Authorization. // CreateAuthorization creates an entry in the database for the Authorization.
// Implements the acme.DB.CreateAuthorization interface. // Implements the acme.DB.CreateAuthorization interface.
func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization) error { func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) error {
if len(az.AccountID) == 0 { if len(az.AccountID) == 0 {
return ServerInternalErr(errors.New("account-id cannot be empty")) return errors.New("account-id cannot be empty")
} }
if az.Identifier == nil { if az.Identifier == nil {
return ServerInternalErr(errors.New("identifier cannot be nil")) return errors.New("identifier cannot be nil")
} }
var err error
az.ID, err = randID() az.ID, err = randID()
if err != nil { if err != nil {
return nil, err return err
} }
now := clock.Now() now := clock.Now()
dbaz := &dbAuthz{ dbaz := &dbAuthz{
ID: az.ID, ID: az.ID,
AccountID: az.AccountID, AccountID: az.AccountID,
Status: types.StatusPending, Status: acme.StatusPending,
Created: now, Created: now,
Expires: now.Add(defaultExpiryDuration), Expires: now.Add(defaultExpiryDuration),
Identifier: az.Identifier, Identifier: az.Identifier,
@ -97,9 +99,9 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization)
if strings.HasPrefix(az.Identifier.Value, "*.") { if strings.HasPrefix(az.Identifier.Value, "*.") {
dbaz.Wildcard = true dbaz.Wildcard = true
dbaz.Identifier = Identifier{ dbaz.Identifier = &acme.Identifier{
Value: strings.TrimPrefix(identifier.Value, "*."), Value: strings.TrimPrefix(az.Identifier.Value, "*."),
Type: identifier.Type, Type: az.Identifier.Type,
} }
} }
@ -111,14 +113,14 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization)
} }
for _, typ := range chTypes { for _, typ := range chTypes {
ch, err := db.CreateChallenge(ctx, &types.Challenge{ ch := &acme.Challenge{
AccountID: az.AccountID, AccountID: az.AccountID,
AuthzID: az.ID, AuthzID: az.ID,
Value: az.Identifier.Value, Value: az.Identifier.Value,
Type: typ, Type: typ,
}) }
if err != nil { if err = db.CreateChallenge(ctx, ch); err != nil {
return nil, Wrapf(err, "error creating '%s' challenge", typ) return errors.Wrapf(err, "error creating challenge")
} }
chIDs = append(chIDs, ch.ID) chIDs = append(chIDs, ch.ID)
@ -129,9 +131,9 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *types.Authorization)
} }
// UpdateAuthorization saves an updated ACME Authorization to the database. // UpdateAuthorization saves an updated ACME Authorization to the database.
func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization) error { func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) error {
if len(az.ID) == 0 { if len(az.ID) == 0 {
return ServerInternalErr(errors.New("id cannot be empty")) return errors.New("id cannot be empty")
} }
old, err := db.getDBAuthz(ctx, az.ID) old, err := db.getDBAuthz(ctx, az.ID)
if err != nil { if err != nil {
@ -141,6 +143,5 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *types.Authorization)
nu := old.clone() nu := old.clone()
nu.Status = az.Status nu.Status = az.Status
nu.Error = az.Error
return db.save(ctx, old.ID, nu, old, "authz", authzTable) return db.save(ctx, old.ID, nu, old, "authz", authzTable)
} }

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
@ -21,25 +22,26 @@ type dbCert struct {
} }
// CreateCertificate creates and stores an ACME certificate type. // CreateCertificate creates and stores an ACME certificate type.
func (db *DB) CreateCertificate(ctx context.Context, cert *Certificate) error { func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
cert.id, err = randID() var err error
cert.ID, err = randID()
if err != nil { if err != nil {
return err return err
} }
leaf := pem.EncodeToMemory(&pem.Block{ leaf := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: ops.Leaf.Raw, Bytes: cert.Leaf.Raw,
}) })
var intermediates []byte var intermediates []byte
for _, cert := range ops.Intermediates { for _, cert := range cert.Intermediates {
intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{ intermediates = append(intermediates, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
})...) })...)
} }
cert := &dbCert{ dbch := &dbCert{
ID: cert.ID, ID: cert.ID,
AccountID: cert.AccountID, AccountID: cert.AccountID,
OrderID: cert.OrderID, OrderID: cert.OrderID,
@ -47,74 +49,80 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *Certificate) error {
Intermediates: intermediates, Intermediates: intermediates,
Created: time.Now().UTC(), Created: time.Now().UTC(),
} }
return db.save(ctx, cert.ID, cert, nil, "certificate", certTable) return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
} }
// GetCertificate retrieves and unmarshals an ACME certificate type from the // GetCertificate retrieves and unmarshals an ACME certificate type from the
// datastore. // datastore.
func (db *DB) GetCertificate(ctx context.Context, id string) (*Certificate, error) { func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
b, err := db.db.Get(certTable, []byte(id)) b, err := db.db.Get(certTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "certificate %s not found", id)) return nil, errors.Wrapf(err, "certificate %s not found", id)
} else if err != nil { } else if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error loading certificate")) return nil, errors.Wrap(err, "error loading certificate")
} }
var dbCert certificate dbC := new(dbCert)
if err := json.Unmarshal(b, &dbCert); err != nil { if err := json.Unmarshal(b, dbC); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling certificate")) return nil, errors.Wrap(err, "error unmarshaling certificate")
} }
leaf, err := parseCert(dbCert.Leaf) leaf, err := parseCert(dbC.Leaf)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrapf("error parsing leaf of ACME Certificate with ID '%s'", id)) return nil, errors.Wrapf(err, "error parsing leaf of ACME Certificate with ID '%s'", id)
} }
intermediates, err := parseBundle(dbCert.Intermediates) intermediates, err := parseBundle(dbC.Intermediates)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrapf("error parsing intermediate bundle of ACME Certificate with ID '%s'", id)) return nil, errors.Wrapf(err, "error parsing intermediate bundle of ACME Certificate with ID '%s'", id)
} }
return &Certificate{ return &acme.Certificate{
ID: dbCert.ID, ID: dbC.ID,
AccountID: dbCert.AccountID, AccountID: dbC.AccountID,
OrderID: dbCert.OrderID, OrderID: dbC.OrderID,
Leaf: leaf, Leaf: leaf,
Intermediates: intermediate, Intermediates: intermediates,
} }, nil
} }
func parseCert(b []byte) (*x509.Certificate, error) { func parseCert(b []byte) (*x509.Certificate, error) {
block, rest := pem.Decode(dbCert.Leaf) block, rest := pem.Decode(b)
if block == nil || len(rest) > 0 { if block == nil || len(rest) > 0 {
return nil, errors.New("error decoding PEM block: contains unexpected data") return nil, errors.New("error decoding PEM block: contains unexpected data")
} }
if block.Type != "CERTIFICATE" { if block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding PEM: block is not a certificate bundle") return nil, errors.New("error decoding PEM: block is not a certificate bundle")
} }
var crt *x509.Certificate cert, err := x509.ParseCertificate(block.Bytes)
crt, err = x509.ParseCertificate(block.Bytes) if err != nil {
return nil, errors.Wrap(err, "error parsing x509 certificate")
}
return cert, nil
} }
func parseBundle(b []byte) ([]*x509.Certificate, error) { func parseBundle(b []byte) ([]*x509.Certificate, error) {
var block *pem.Block var (
var bundle []*x509.Certificate err error
block *pem.Block
bundle []*x509.Certificate
)
for len(b) > 0 { for len(b) > 0 {
block, b = pem.Decode(b) block, b = pem.Decode(b)
if block == nil { if block == nil {
break break
} }
if block.Type != "CERTIFICATE" { if block.Type != "CERTIFICATE" {
return nil, errors.Errorf("error decoding PEM: file '%s' is not a certificate bundle", filename) return nil, errors.New("error decoding PEM: data contains block that is not a certificate")
} }
var crt *x509.Certificate var crt *x509.Certificate
crt, err = x509.ParseCertificate(block.Bytes) crt, err = x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error parsing %s", filename) return nil, errors.Wrapf(err, "error parsing x509 certificate")
} }
bundle = append(bundle, crt) bundle = append(bundle, crt)
} }
if len(b) > 0 { if len(b) > 0 {
return nil, errors.Errorf("error decoding PEM: file '%s' contains unexpected data", filename) return nil, errors.New("error decoding PEM: unexpected data")
} }
return bundle, nil return bundle, nil

View file

@ -6,75 +6,69 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
// ChallengeOptions is the type used to created a new Challenge.
type ChallengeOptions struct {
AccountID string
AuthzID string
Identifier Identifier
}
// dbChallenge is the base Challenge type that others build from. // dbChallenge is the base Challenge type that others build from.
type dbChallenge struct { type dbChallenge struct {
ID string `json:"id"` ID string `json:"id"`
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
AuthzID string `json:"authzID"` AuthzID string `json:"authzID"`
Type string `json:"type"` Type string `json:"type"`
Status string `json:"status"` Status acme.Status `json:"status"`
Token string `json:"token"` Token string `json:"token"`
Value string `json:"value"` Value string `json:"value"`
Validated time.Time `json:"validated"` Validated string `json:"validated"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Error *AError `json:"error"` Error *AError `json:"error"`
} }
func (dbc *dbChallenge) clone() *dbChallenge { func (dbc *dbChallenge) clone() *dbChallenge {
u := *bc u := *dbc
return &u return &u
} }
func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
data, err := db.db.Get(challengeTable, []byte(id)) data, err := db.db.Get(challengeTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "challenge %s not found", id)) return nil, errors.Wrapf(err, "challenge %s not found", id)
} else if err != nil { } else if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading challenge %s", id)) return nil, errors.Wrapf(err, "error loading challenge %s", id)
} }
dbch := new(baseChallenge) dbch := new(dbChallenge)
if err := json.Unmarshal(data, dbch); err != nil { if err := json.Unmarshal(data, dbch); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling "+ return nil, errors.Wrap(err, "error unmarshaling dbChallenge")
"challenge type into dbChallenge"))
} }
return dbch return dbch, nil
} }
// CreateChallenge creates a new ACME challenge data structure in the database. // CreateChallenge creates a new ACME challenge data structure in the database.
// Implements acme.DB.CreateChallenge interface. // Implements acme.DB.CreateChallenge interface.
func (db *DB) CreateChallenge(ctx context.context, ch *types.Challenge) error { func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
if len(ch.AuthzID) == 0 { if len(ch.AuthzID) == 0 {
return ServerInternalError(errors.New("AuthzID cannot be empty")) return errors.New("AuthzID cannot be empty")
} }
if len(ch.AccountID) == 0 { if len(ch.AccountID) == 0 {
return ServerInternalError(errors.New("AccountID cannot be empty")) return errors.New("AccountID cannot be empty")
} }
if len(ch.Value) == 0 { if len(ch.Value) == 0 {
return ServerInternalError(errors.New("AccountID cannot be empty")) return errors.New("AccountID cannot be empty")
} }
// TODO: verify that challenge type is set and is one of expected types. // TODO: verify that challenge type is set and is one of expected types.
if len(ch.Type) == 0 { if len(ch.Type) == 0 {
return ServerInternalError(errors.New("Type cannot be empty")) return errors.New("Type cannot be empty")
} }
var err error
ch.ID, err = randID() ch.ID, err = randID()
if err != nil { if err != nil {
return nil, Wrap(err, "error generating random id for ACME challenge") return errors.Wrap(err, "error generating random id for ACME challenge")
} }
ch.Token, err = randID() ch.Token, err = randID()
if err != nil { if err != nil {
return nil, Wrap(err, "error generating token for ACME challenge") return errors.Wrap(err, "error generating token for ACME challenge")
} }
dbch := &dbChallenge{ dbch := &dbChallenge{
@ -82,42 +76,40 @@ func (db *DB) CreateChallenge(ctx context.context, ch *types.Challenge) error {
AuthzID: ch.AuthzID, AuthzID: ch.AuthzID,
AccountID: ch.AccountID, AccountID: ch.AccountID,
Value: ch.Value, Value: ch.Value,
Status: types.StatusPending, Status: acme.StatusPending,
Token: ch.Token, Token: ch.Token,
Created: clock.Now(), Created: clock.Now(),
Type: ch.Type, Type: ch.Type,
} }
return dbch.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable) return db.save(ctx, ch.ID, dbch, nil, "challenge", challengeTable)
} }
// GetChallenge retrieves and unmarshals an ACME challenge type from the database. // GetChallenge retrieves and unmarshals an ACME challenge type from the database.
// Implements the acme.DB GetChallenge interface. // Implements the acme.DB GetChallenge interface.
func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*types.Challenge, error) { func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Challenge, error) {
dbch, err := db.getDBChallenge(ctx, id) dbch, err := db.getDBChallenge(ctx, id)
if err != nil { if err != nil {
return err return nil, err
} }
ch := &Challenge{ ch := &acme.Challenge{
Type: dbch.Type, Type: dbch.Type,
Status: dbch.Status, Status: dbch.Status,
Token: dbch.Token, Token: dbch.Token,
URL: dir.getLink(ctx, ChallengeLink, true, dbch.getID()), URL: dir.getLink(ctx, ChallengeLink, true, dbch.ID),
ID: dbch.ID, ID: dbch.ID,
AuthzID: dbch.AuthzID(), AuthzID: dbch.AuthzID,
Error: dbch.Error, Error: dbch.Error,
} Validated: dbch.Validated,
if !dbch.Validated.IsZero() {
ac.Validated = dbch.Validated.Format(time.RFC3339)
} }
return ch, nil return ch, nil
} }
// UpdateChallenge updates an ACME challenge type in the database. // UpdateChallenge updates an ACME challenge type in the database.
func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error { func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
if len(ch.ID) == 0 { if len(ch.ID) == 0 {
return ServerInternalErr(errors.New("id cannot be empty")) return errors.New("id cannot be empty")
} }
old, err := db.getDBChallenge(ctx, ch.ID) old, err := db.getDBChallenge(ctx, ch.ID)
if err != nil { if err != nil {
@ -129,9 +121,7 @@ func (db *DB) UpdateChallenge(ctx context.Context, ch *types.Challenge) error {
// These should be the only values chaning in an Update request. // These should be the only values chaning in an Update request.
nu.Status = ch.Status nu.Status = ch.Status
nu.Error = ch.Error nu.Error = ch.Error
if nu.Status == types.StatusValid { nu.Validated = ch.Validated
nu.Validated = clock.Now()
}
return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
} }

View file

@ -1,11 +1,13 @@
package nosql package nosql
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql" nosqlDB "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
) )
@ -18,10 +20,10 @@ type dbNonce struct {
// CreateNonce creates, stores, and returns an ACME replay-nonce. // CreateNonce creates, stores, and returns an ACME replay-nonce.
// Implements the acme.DB interface. // Implements the acme.DB interface.
func (db *DB) CreateNonce() (Nonce, error) { func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
_id, err := randID() _id, err := randID()
if err != nil { if err != nil {
return nil, err return "", err
} }
id := base64.RawURLEncoding.EncodeToString([]byte(_id)) id := base64.RawURLEncoding.EncodeToString([]byte(_id))
@ -31,12 +33,12 @@ func (db *DB) CreateNonce() (Nonce, error) {
} }
b, err := json.Marshal(n) b, err := json.Marshal(n)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error marshaling nonce")) return "", errors.Wrap(err, "error marshaling nonce")
} }
if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil { if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil {
return "", err return "", err
} }
return Nonce(id), nil return acme.Nonce(id), nil
} }
// DeleteNonce verifies that the nonce is valid (by checking if it exists), // DeleteNonce verifies that the nonce is valid (by checking if it exists),
@ -59,9 +61,9 @@ func (db *DB) DeleteNonce(nonce string) error {
switch { switch {
case nosqlDB.IsErrNotFound(err): case nosqlDB.IsErrNotFound(err):
return BadNonceErr(nil) return errors.New("not found")
case err != nil: case err != nil:
return ServerInternalErr(errors.Wrapf(err, "error deleting nonce %s", nonce)) return errors.Wrapf(err, "error deleting nonce %s", nonce)
default: default:
return nil return nil
} }

View file

@ -3,9 +3,11 @@ package nosql
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
nosqlDB "github.com/smallstep/nosql" nosqlDB "github.com/smallstep/nosql"
"go.step.sm/crypto/randutil"
) )
var ( var (
@ -24,13 +26,26 @@ type DB struct {
db nosqlDB.DB db nosqlDB.DB
} }
// New configures and returns a new ACME DB backend implemented using a nosql DB.
func New(db nosqlDB.DB) (*DB, error) {
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable}
for _, b := range tables {
if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s",
string(b))
}
}
return &DB{db}, nil
}
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
newB, err := json.Marshal(nu) newB, err := json.Marshal(nu)
if err != nil { if err != nil {
return ServerInternalErr(errors.Wrapf(err, return errors.Wrapf(err,
"error marshaling new acme %s", typ)) "error marshaling new acme %s", typ)
} }
var oldB []byte var oldB []byte
if old == nil { if old == nil {
@ -38,19 +53,39 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
} else { } else {
oldB, err = json.Marshal(old) oldB, err = json.Marshal(old)
if err != nil { if err != nil {
return ServerInternalErr(errors.Wrapf(err, return errors.Wrapf(err,
"error marshaling old acme %s", typ)) "error marshaling old acme %s", typ)
} }
} }
_, swapped, err := db.CmpAndSwap(table, []byte(id), oldB, newB) _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
switch { switch {
case err != nil: case err != nil:
return ServerInternalErr(errors.Wrapf(err, "error saving acme %s", typ)) return errors.Wrapf(err, "error saving acme %s", typ)
case !swapped: case !swapped:
return ServerInternalErr(errors.Errorf("error saving acme %s; "+ return errors.Errorf("error saving acme %s; "+
"changed since last read", typ)) "changed since last read", typ)
default: default:
return nil return nil
} }
} }
var idLen = 32
func randID() (val string, err error) {
val, err = randutil.Alphanumeric(idLen)
if err != nil {
return "", errors.Wrap(err, "error generating random alphanumeric ID")
}
return val, nil
}
// Clock that returns time in UTC rounded to seconds.
type Clock int
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Round(time.Second)
}
var clock = new(Clock)

View file

@ -22,8 +22,8 @@ type dbOrder struct {
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Created time.Time `json:"created"` Created time.Time `json:"created"`
Expires time.Time `json:"expires,omitempty"` Expires time.Time `json:"expires,omitempty"`
Status string `json:"status"` Status acme.Status `json:"status"`
Identifiers []Identifier `json:"identifiers"` Identifiers []acme.Identifier `json:"identifiers"`
NotBefore time.Time `json:"notBefore,omitempty"` NotBefore time.Time `json:"notBefore,omitempty"`
NotAfter time.Time `json:"notAfter,omitempty"` NotAfter time.Time `json:"notAfter,omitempty"`
Error *Error `json:"error,omitempty"` Error *Error `json:"error,omitempty"`
@ -35,33 +35,33 @@ type dbOrder struct {
func (db *DB) getDBOrder(id string) (*dbOrder, error) { func (db *DB) getDBOrder(id string) (*dbOrder, error) {
b, err := db.db.Get(orderTable, []byte(id)) b, err := db.db.Get(orderTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, MalformedErr(errors.Wrapf(err, "order %s not found", id)) return nil, errors.Wrapf(err, "order %s not found", id)
} else if err != nil { } else if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error loading order %s", id)) return nil, errors.Wrapf(err, "error loading order %s", id)
} }
o := new(dbOrder) o := new(dbOrder)
if err := json.Unmarshal(b, &o); err != nil { if err := json.Unmarshal(b, &o); err != nil {
return nil, ServerInternalErr(errors.Wrap(err, "error unmarshaling order")) return nil, errors.Wrap(err, "error unmarshaling order")
} }
return o, nil return o, nil
} }
// GetOrder retrieves an ACME Order from the database. // GetOrder retrieves an ACME Order from the database.
func (db *DB) GetOrder(id string) (*acme.Order, error) { func (db *DB) GetOrder(ctx context.Context, id string) (*acme.Order, error) {
dbo, err := db.getDBOrder(id) dbo, err := db.getDBOrder(id)
azs := make([]string, len(dbo.Authorizations)) azs := make([]string, len(dbo.Authorizations))
for i, aid := range dbo.Authorizations { for i, aid := range dbo.Authorizations {
azs[i] = dir.getLink(ctx, AuthzLink, true, aid) azs[i] = dir.getLink(ctx, AuthzLink, true, aid)
} }
o := &Order{ o := &acme.Order{
Status: dbo.Status, Status: dbo.Status,
Expires: dbo.Expires.Format(time.RFC3339), Expires: dbo.Expires.Format(time.RFC3339),
Identifiers: dbo.Identifiers, Identifiers: dbo.Identifiers,
NotBefore: dbo.NotBefore.Format(time.RFC3339), NotBefore: dbo.NotBefore.Format(time.RFC3339),
NotAfter: dbo.NotAfter.Format(time.RFC3339), NotAfter: dbo.NotAfter.Format(time.RFC3339),
Authorizations: azs, Authorizations: azs,
Finalize: dir.getLink(ctx, FinalizeLink, true, o.ID), FinalizeURL: dir.getLink(ctx, FinalizeLink, true, o.ID),
ID: dbo.ID, ID: dbo.ID,
ProvisionerID: dbo.ProvisionerID, ProvisionerID: dbo.ProvisionerID,
} }

View file

@ -1,410 +1,324 @@
// Error represents an ACME
package acme package acme
import ( import (
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
// AccountDoesNotExistErr returns a new acme error. // ProblemType is the type of the ACME problem.
func AccountDoesNotExistErr(err error) *Error { type ProblemType int
return &Error{
Type: accountDoesNotExistErr,
Detail: "Account does not exist",
Status: 400,
Err: err,
}
}
// AlreadyRevokedErr returns a new acme error.
func AlreadyRevokedErr(err error) *Error {
return &Error{
Type: alreadyRevokedErr,
Detail: "Certificate already revoked",
Status: 400,
Err: err,
}
}
// BadCSRErr returns a new acme error.
func BadCSRErr(err error) *Error {
return &Error{
Type: badCSRErr,
Detail: "The CSR is unacceptable",
Status: 400,
Err: err,
}
}
// BadNonceErr returns a new acme error.
func BadNonceErr(err error) *Error {
return &Error{
Type: badNonceErr,
Detail: "Unacceptable anti-replay nonce",
Status: 400,
Err: err,
}
}
// BadPublicKeyErr returns a new acme error.
func BadPublicKeyErr(err error) *Error {
return &Error{
Type: badPublicKeyErr,
Detail: "The jws was signed by a public key the server does not support",
Status: 400,
Err: err,
}
}
// BadRevocationReasonErr returns a new acme error.
func BadRevocationReasonErr(err error) *Error {
return &Error{
Type: badRevocationReasonErr,
Detail: "The revocation reason provided is not allowed by the server",
Status: 400,
Err: err,
}
}
// BadSignatureAlgorithmErr returns a new acme error.
func BadSignatureAlgorithmErr(err error) *Error {
return &Error{
Type: badSignatureAlgorithmErr,
Detail: "The JWS was signed with an algorithm the server does not support",
Status: 400,
Err: err,
}
}
// CaaErr returns a new acme error.
func CaaErr(err error) *Error {
return &Error{
Type: caaErr,
Detail: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
Status: 400,
Err: err,
}
}
// CompoundErr returns a new acme error.
func CompoundErr(err error) *Error {
return &Error{
Type: compoundErr,
Detail: "Specific error conditions are indicated in the “subproblems” array",
Status: 400,
Err: err,
}
}
// ConnectionErr returns a new acme error.
func ConnectionErr(err error) *Error {
return &Error{
Type: connectionErr,
Detail: "The server could not connect to validation target",
Status: 400,
Err: err,
}
}
// DNSErr returns a new acme error.
func DNSErr(err error) *Error {
return &Error{
Type: dnsErr,
Detail: "There was a problem with a DNS query during identifier validation",
Status: 400,
Err: err,
}
}
// ExternalAccountRequiredErr returns a new acme error.
func ExternalAccountRequiredErr(err error) *Error {
return &Error{
Type: externalAccountRequiredErr,
Detail: "The request must include a value for the \"externalAccountBinding\" field",
Status: 400,
Err: err,
}
}
// IncorrectResponseErr returns a new acme error.
func IncorrectResponseErr(err error) *Error {
return &Error{
Type: incorrectResponseErr,
Detail: "Response received didn't match the challenge's requirements",
Status: 400,
Err: err,
}
}
// InvalidContactErr returns a new acme error.
func InvalidContactErr(err error) *Error {
return &Error{
Type: invalidContactErr,
Detail: "A contact URL for an account was invalid",
Status: 400,
Err: err,
}
}
// MalformedErr returns a new acme error.
func MalformedErr(err error) *Error {
return &Error{
Type: malformedErr,
Detail: "The request message was malformed",
Status: 400,
Err: err,
}
}
// OrderNotReadyErr returns a new acme error.
func OrderNotReadyErr(err error) *Error {
return &Error{
Type: orderNotReadyErr,
Detail: "The request attempted to finalize an order that is not ready to be finalized",
Status: 400,
Err: err,
}
}
// RateLimitedErr returns a new acme error.
func RateLimitedErr(err error) *Error {
return &Error{
Type: rateLimitedErr,
Detail: "The request exceeds a rate limit",
Status: 400,
Err: err,
}
}
// RejectedIdentifierErr returns a new acme error.
func RejectedIdentifierErr(err error) *Error {
return &Error{
Type: rejectedIdentifierErr,
Detail: "The server will not issue certificates for the identifier",
Status: 400,
Err: err,
}
}
// ServerInternalErr returns a new acme error.
func ServerInternalErr(err error) *Error {
if err == nil {
return nil
}
return &Error{
Type: serverInternalErr,
Detail: "The server experienced an internal error",
Status: 500,
Err: err,
}
}
// NotImplemented returns a new acme error.
func NotImplemented(err error) *Error {
return &Error{
Type: notImplemented,
Detail: "The requested operation is not implemented",
Status: 501,
Err: err,
}
}
// TLSErr returns a new acme error.
func TLSErr(err error) *Error {
return &Error{
Type: tlsErr,
Detail: "The server received a TLS error during validation",
Status: 400,
Err: err,
}
}
// UnauthorizedErr returns a new acme error.
func UnauthorizedErr(err error) *Error {
return &Error{
Type: unauthorizedErr,
Detail: "The client lacks sufficient authorization",
Status: 401,
Err: err,
}
}
// UnsupportedContactErr returns a new acme error.
func UnsupportedContactErr(err error) *Error {
return &Error{
Type: unsupportedContactErr,
Detail: "A contact URL for an account used an unsupported protocol scheme",
Status: 400,
Err: err,
}
}
// UnsupportedIdentifierErr returns a new acme error.
func UnsupportedIdentifierErr(err error) *Error {
return &Error{
Type: unsupportedIdentifierErr,
Detail: "An identifier is of an unsupported type",
Status: 400,
Err: err,
}
}
// UserActionRequiredErr returns a new acme error.
func UserActionRequiredErr(err error) *Error {
return &Error{
Type: userActionRequiredErr,
Detail: "Visit the “instance” URL and take actions specified there",
Status: 400,
Err: err,
}
}
// ProbType is the type of the ACME problem.
type ProbType int
const ( const (
// The request specified an account that does not exist // The request specified an account that does not exist
accountDoesNotExistErr ProbType = iota ErrorAccountDoesNotExistType ProblemType = iota
// The request specified a certificate to be revoked that has already been revoked // The request specified a certificate to be revoked that has already been revoked
alreadyRevokedErr ErrorAlreadyRevokedType
// The CSR is unacceptable (e.g., due to a short key) // The CSR is unacceptable (e.g., due to a short key)
badCSRErr ErrorBadCSRType
// The client sent an unacceptable anti-replay nonce // The client sent an unacceptable anti-replay nonce
badNonceErr ErrorBadNonceType
// The JWS was signed by a public key the server does not support // The JWS was signed by a public key the server does not support
badPublicKeyErr ErrorBadPublicKeyType
// The revocation reason provided is not allowed by the server // The revocation reason provided is not allowed by the server
badRevocationReasonErr ErrorBadRevocationReasonType
// The JWS was signed with an algorithm the server does not support // The JWS was signed with an algorithm the server does not support
badSignatureAlgorithmErr ErrorBadSignatureAlgorithmType
// Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate // Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate
caaErr ErrorCaaType
// Specific error conditions are indicated in the “subproblems” array. // Specific error conditions are indicated in the “subproblems” array.
compoundErr ErrorCompoundType
// The server could not connect to validation target // The server could not connect to validation target
connectionErr ErrorConnectionType
// There was a problem with a DNS query during identifier validation // There was a problem with a DNS query during identifier validation
dnsErr ErrorDNSType
// The request must include a value for the “externalAccountBinding” field // The request must include a value for the “externalAccountBinding” field
externalAccountRequiredErr ErrorExternalAccountRequiredType
// Response received didnt match the challenges requirements // Response received didnt match the challenges requirements
incorrectResponseErr ErrorIncorrectResponseType
// A contact URL for an account was invalid // A contact URL for an account was invalid
invalidContactErr ErrorInvalidContactType
// The request message was malformed // The request message was malformed
malformedErr ErrorMalformedType
// The request attempted to finalize an order that is not ready to be finalized // The request attempted to finalize an order that is not ready to be finalized
orderNotReadyErr ErrorOrderNotReadyType
// The request exceeds a rate limit // The request exceeds a rate limit
rateLimitedErr ErrorRateLimitedType
// The server will not issue certificates for the identifier // The server will not issue certificates for the identifier
rejectedIdentifierErr ErrorRejectedIdentifierType
// The server experienced an internal error // The server experienced an internal error
serverInternalErr ErrorServerInternalType
// The server received a TLS error during validation // The server received a TLS error during validation
tlsErr ErrorTLSType
// The client lacks sufficient authorization // The client lacks sufficient authorization
unauthorizedErr ErrorUnauthorizedType
// A contact URL for an account used an unsupported protocol scheme // A contact URL for an account used an unsupported protocol scheme
unsupportedContactErr ErrorUnsupportedContactType
// An identifier is of an unsupported type // An identifier is of an unsupported type
unsupportedIdentifierErr ErrorUnsupportedIdentifierType
// Visit the “instance” URL and take actions specified there // Visit the “instance” URL and take actions specified there
userActionRequiredErr ErrorUserActionRequiredType
// The operation is not implemented // The operation is not implemented
notImplemented ErrorNotImplementedType
) )
// String returns the string representation of the acme problem type, // String returns the string representation of the acme problem type,
// fulfilling the Stringer interface. // fulfilling the Stringer interface.
func (ap ProbType) String() string { func (ap ProblemType) String() string {
switch ap { switch ap {
case accountDoesNotExistErr: case ErrorAccountDoesNotExistType:
return "accountDoesNotExist" return "accountDoesNotExist"
case alreadyRevokedErr: case ErrorAlreadyRevokedType:
return "alreadyRevoked" return "alreadyRevoked"
case badCSRErr: case ErrorBadCSRType:
return "badCSR" return "badCSR"
case badNonceErr: case ErrorBadNonceType:
return "badNonce" return "badNonce"
case badPublicKeyErr: case ErrorBadPublicKeyType:
return "badPublicKey" return "badPublicKey"
case badRevocationReasonErr: case ErrorBadRevocationReasonType:
return "badRevocationReason" return "badRevocationReason"
case badSignatureAlgorithmErr: case ErrorBadSignatureAlgorithmType:
return "badSignatureAlgorithm" return "badSignatureAlgorithm"
case caaErr: case ErrorCaaType:
return "caa" return "caa"
case compoundErr: case ErrorCompoundType:
return "compound" return "compound"
case connectionErr: case ErrorConnectionType:
return "connection" return "connection"
case dnsErr: case ErrorDNSType:
return "dns" return "dns"
case externalAccountRequiredErr: case ErrorExternalAccountRequiredType:
return "externalAccountRequired" return "externalAccountRequired"
case incorrectResponseErr: case ErrorInvalidContactType:
return "incorrectResponse" return "incorrectResponse"
case invalidContactErr: case ErrorInvalidContactType:
return "invalidContact" return "invalidContact"
case malformedErr: case ErrorMalformedType:
return "malformed" return "malformed"
case orderNotReadyErr: case ErrorOrderNotReadyType:
return "orderNotReady" return "orderNotReady"
case rateLimitedErr: case ErrorRateLimitedType:
return "rateLimited" return "rateLimited"
case rejectedIdentifierErr: case ErrorRejectedIdentifierType:
return "rejectedIdentifier" return "rejectedIdentifier"
case serverInternalErr: case ErrorServerInternalType:
return "serverInternal" return "serverInternal"
case tlsErr: case ErrorTLSType:
return "tls" return "tls"
case unauthorizedErr: case ErrorUnauthorizedType:
return "unauthorized" return "unauthorized"
case unsupportedContactErr: case ErrorUnsupportedContactType:
return "unsupportedContact" return "unsupportedContact"
case unsupportedIdentifierErr: case ErrorUnsupportedIdentifierType:
return "unsupportedIdentifier" return "unsupportedIdentifier"
case userActionRequiredErr: case ErrorUserActionRequiredType:
return "userActionRequired" return "userActionRequired"
case notImplemented: case ErrorNotImplementedType:
return "notImplemented" return "notImplemented"
default: default:
return "unsupported type" return fmt.Sprintf("unsupported type ACME error type %v", ap)
} }
} }
// Error is an ACME error type complete with problem document. type errorMetadata struct {
type Error struct { details string
Type ProbType status int
Detail string typ string
Err error String string
Status int
Sub []*Error
Identifier *Identifier
} }
// Wrap attempts to wrap the internal error. var (
func Wrap(err error, wrap string) *Error { officialACMEPrefix = "urn:ietf:params:acme:error:"
stepACMEPrefix = "urn:step:acme:error:"
errorServerInternalMetadata = errorMetadata{
ErrorAccountDoesNotExistType: {
typ: officialACMEPrefix + ErrorServerInternalType.String(),
details: "The server experienced an internal error",
status: 500,
},
}
errorMap = [ProblemType]errorMetadata{
ErrorAccountDoesNotExistType: {
typ: officialACMEPrefix + ErrorAccountDoesNotExistType.String(),
details: "Account does not exist",
status: 400,
},
ErrorAlreadyRevokedType: {
typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
details: "Certificate already Revoked",
status: 400,
},
ErrorBadCSRType: {
typ: officialACMEPrefix + ErrorBadCSRType.String(),
details: "The CSR is unacceptable",
status: 400,
},
ErrorBadNonceType: {
typ: officialACMEPrefix + ErrorBadNonceType.String(),
details: "Unacceptable anti-replay nonce",
status: 400,
},
ErrorBadPublicKeyType: {
typ: officialACMEPrefix + ErrorBadPublicKeyType.String(),
details: "The jws was signed by a public key the server does not support",
status: 400,
},
ErrorBadRevocationReasonType: {
typ: officialACMEPrefix + ErrorBadRevocationReasonType.String(),
details: "The revocation reason provided is not allowed by the server",
status: 400,
},
ErrorBadSignatureAlgorithmType: {
typ: officialACMEPrefix + ErrorBadSignatureAlgorithmType.String(),
details: "The JWS was signed with an algorithm the server does not support",
status: 400,
},
ErrorCaaType: {
typ: officialACMEPrefix + ErrorCaaType.String(),
details: "Certification Authority Authorization (CAA) records forbid the CA from issuing a certificate",
status: 400,
},
ErrorCompoundType: {
typ: officialACMEPrefix + ErrorCompoundType.String(),
details: "Specific error conditions are indicated in the “subproblems” array",
status: 400,
},
ErrorConnectionType: {
typ: officialACMEPrefix + ErrorConnectionType.String(),
details: "The server could not connect to validation target",
status: 400,
},
ErrorDNSType: {
typ: officialACMEPrefix + ErrorDNSType.String(),
details: "There was a problem with a DNS query during identifier validation",
status: 400,
},
ErrorExternalAccountRequiredType: {
typ: officialACMEPrefix + ErrorExternalAccountRequiredType.String(),
details: "The request must include a value for the \"externalAccountBinding\" field",
status: 400,
},
ErrorIncorrectResponseType: {
typ: officialACMEPrefix + ErrorIncorrectResponseType.String(),
details: "Response received didn't match the challenge's requirements",
status: 400,
},
ErrorInvalidContactType: {
typ: officialACMEPrefix + ErrorInvalidContactType.String(),
details: "A contact URL for an account was invalid",
status: 400,
},
ErrorMalformedType: {
typ: officialACMEPrefix + ErrorMalformedType.String(),
details: "The request message was malformed",
status: 400,
},
ErrorOrderNotReadyType: {
typ: officialACMEPrefix + ErrorOrderNotReadyType.String(),
details: "The request attempted to finalize an order that is not ready to be finalized",
status: 400,
},
ErrorRateLimitedType: {
typ: officialACMEPrefix + ErrorRateLimitedType.String(),
details: "The request exceeds a rate limit",
status: 400,
},
ErrorRejectedIdentifierType: {
typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
details: "The server will not issue certificates for the identifier",
status: 400,
},
ErrorNotImplementedType: {
typ: officialACMEPrefix + ErrorRejectedIdentifierType.String(),
details: "The requested operation is not implemented",
status: 501,
},
ErrorTLSType: {
typ: officialACMEPrefix + ErrorTLSType.String(),
details: "The server received a TLS error during validation",
status: 400,
},
ErrorUnauthorizedType: {
typ: officialACMEPrefix + ErrorUnauthorizedType.String(),
details: "The client lacks sufficient authorization",
status: 401,
},
ErrorUnsupportedContactType: {
typ: officialACMEPrefix + ErrorUnsupportedContactType.String(),
details: "A contact URL for an account used an unsupported protocol scheme",
status: 400,
},
ErrorUnsupportedIdentifierType: {
typ: officialACMEPrefix + ErrorUnsupportedIdentifierType.String(),
details: "An identifier is of an unsupported type",
status: 400,
},
ErrorUserActionRequiredType: {
typ: officialACMEPrefix + ErrorUserActionRequiredType.String(),
details: "Visit the “instance” URL and take actions specified there",
status: 400,
},
ErrorServerInternalType: errorServerInternalMetadata,
}
)
// Error represents an ACME
type Error struct {
Type string `json:"type"`
Detail string `json:"detail"`
Subproblems []interface{} `json:"subproblems,omitempty"`
Identifier interface{} `json:"identifier,omitempty"`
Err error `json:"-"`
Status int `json:"-"`
}
func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
meta, ok := errorMetadata[typ]
if !ok {
meta = errorServerInternalMetadata
return &Error{
Type: meta.typ,
Details: meta.details,
Status: meta.Status,
Err: errors.Errorf("unrecognized problemType %v", pt),
}
}
return &Error{
Type: meta.typ,
Details: meta.details,
Status: meta.status,
Err: errors.Errorf(msg, args...),
}
}
// ErrorWrap attempts to wrap the internal error.
func ErrorWrap(typ ProblemType, err error, msg string, args ...interface{}) *Error {
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:
return nil return nil
case *Error: case *Error:
if e.Err == nil { if e.Err == nil {
e.Err = errors.New(wrap + "; " + e.Detail) e.Err = errors.Errorf(msg+"; "+e.Detail, args...)
} else { } else {
e.Err = errors.Wrap(e.Err, wrap) e.Err = errors.Wrapf(e.Err, msg, args...)
} }
return e return e
default: default:
return ServerInternalErr(errors.Wrap(err, wrap)) return NewError(ErrorServerInternalType, msg, args...)
} }
} }
// Error implements the error interface. // StatusCode returns the status code and implements the StatusCoder interface.
func (e *Error) StatusCode() int {
return e.Status
}
// Error allows AError to implement the error interface.
func (e *Error) Error() string { func (e *Error) Error() string {
if e.Err == nil {
return e.Detail return e.Detail
}
return e.Err.Error()
} }
// Cause returns the internal error and implements the Causer interface. // Cause returns the internal error and implements the Causer interface.
@ -414,71 +328,3 @@ func (e *Error) Cause() error {
} }
return e.Err return e.Err
} }
// Official returns true if this error's type is listed in §6.7 of RFC 8555.
// Error types in §6.7 are registered under IETF urn namespace:
//
// "urn:ietf:params:acme:error:"
//
// and should include the namespace as a prefix when appearing as a problem
// document.
//
// RFC 8555 also says:
//
// This list is not exhaustive. The server MAY return errors whose
// "type" field is set to a URI other than those defined above. Servers
// MUST NOT use the ACME URN namespace for errors not listed in the
// appropriate IANA registry (see Section 9.6). Clients SHOULD display
// the "detail" field of all errors.
//
// In this case Official returns `false` so that a different namespace can
// be used.
func (e *Error) Official() bool {
return e.Type != notImplemented
}
// ToACME returns an acme representation of the problem type.
// For official errors, the IETF ACME namespace is prepended to the error type.
// For our own errors, we use an (yet) unregistered smallstep acme namespace.
func (e *Error) ToACME() *AError {
prefix := "urn:step:acme:error"
if e.Official() {
prefix = "urn:ietf:params:acme:error:"
}
ae := &AError{
Type: prefix + e.Type.String(),
Detail: e.Error(),
Status: e.Status,
}
if e.Identifier != nil {
ae.Identifier = *e.Identifier
}
for _, p := range e.Sub {
ae.Subproblems = append(ae.Subproblems, p.ToACME())
}
return ae
}
// StatusCode returns the status code and implements the StatusCode interface.
func (e *Error) StatusCode() int {
return e.Status
}
// AError is the error type as seen in acme request/responses.
type AError struct {
Type string `json:"type"`
Detail string `json:"detail"`
Identifier interface{} `json:"identifier,omitempty"`
Subproblems []interface{} `json:"subproblems,omitempty"`
Status int `json:"-"`
}
// Error allows AError to implement the error interface.
func (ae *AError) Error() string {
return ae.Detail
}
// StatusCode returns the status code and implements the StatusCode interface.
func (ae *AError) StatusCode() int {
return ae.Status
}

View file

@ -13,18 +13,25 @@ import (
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
// Identifier encodes the type that an order pertains to.
type Identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
// Order contains order metadata for the ACME protocol order type. // Order contains order metadata for the ACME protocol order type.
type Order struct { type Order struct {
Status string `json:"status"` Status Status `json:"status"`
Expires string `json:"expires,omitempty"` Expires string `json:"expires,omitempty"`
Identifiers []Identifier `json:"identifiers"` Identifiers []Identifier `json:"identifiers"`
NotBefore string `json:"notBefore,omitempty"` NotBefore string `json:"notBefore,omitempty"`
NotAfter string `json:"notAfter,omitempty"` NotAfter string `json:"notAfter,omitempty"`
Error interface{} `json:"error,omitempty"` Error interface{} `json:"error,omitempty"`
Authorizations []string `json:"authorizations"` Authorizations []string `json:"authorizations"`
Finalize string `json:"finalize"` FinalizeURL string `json:"finalize"`
Certificate string `json:"certificate,omitempty"` Certificate string `json:"certificate,omitempty"`
ID string `json:"-"` ID string `json:"-"`
AccountID string `json:"-"`
ProvisionerID string `json:"-"` ProvisionerID string `json:"-"`
DefaultDuration time.Duration `json:"-"` DefaultDuration time.Duration `json:"-"`
Backdate time.Duration `json:"-"` Backdate time.Duration `json:"-"`
@ -45,7 +52,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
now := time.Now().UTC() now := time.Now().UTC()
expiry, err := time.Parse(time.RFC3339, o.Expires) expiry, err := time.Parse(time.RFC3339, o.Expires)
if err != nil { if err != nil {
return ServerInternalErr(errors.Wrap("error converting expiry string to time")) return ServerInternalErr(errors.Wrap(err, "order.UpdateStatus - error converting expiry string to time"))
} }
switch o.Status { switch o.Status {
@ -69,7 +76,7 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
break break
} }
var count = map[string]int{ var count = map[Status]int{
StatusValid: 0, StatusValid: 0,
StatusInvalid: 0, StatusInvalid: 0,
StatusPending: 0, StatusPending: 0,
@ -77,10 +84,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
for _, azID := range o.Authorizations { for _, azID := range o.Authorizations {
az, err := db.GetAuthorization(ctx, azID) az, err := db.GetAuthorization(ctx, azID)
if err != nil { if err != nil {
return false, err return err
} }
if az, err = az.UpdateStatus(db); err != nil { if err = az.UpdateStatus(ctx, db); err != nil {
return false, err return err
} }
st := az.Status st := az.Status
count[st]++ count[st]++
@ -98,20 +105,19 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
o.Status = StatusReady o.Status = StatusReady
default: default:
return nil, ServerInternalErr(errors.New("unexpected authz status")) return ServerInternalErr(errors.New("unexpected authz status"))
} }
default: default:
return nil, ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status)) return ServerInternalErr(errors.Errorf("unrecognized order status: %s", o.Status))
} }
return db.UpdateOrder(ctx, o) return db.UpdateOrder(ctx, o)
} }
// finalize signs a certificate if the necessary conditions for Order completion // Finalize signs a certificate if the necessary conditions for Order completion
// have been met. // have been met.
func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error { func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateRequest, auth SignAuthority, p Provisioner) error {
var err error if err := o.UpdateStatus(ctx, db); err != nil {
if o, err = o.UpdateStatus(db); err != nil { return err
return nil, err
} }
switch o.Status { switch o.Status {
@ -124,7 +130,7 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth
case StatusReady: case StatusReady:
break break
default: default:
return nil, ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID)) return ServerInternalErr(errors.Errorf("unexpected status %s for order %s", o.Status, o.ID))
} }
// RFC8555: The CSR MUST indicate the exact same set of requested // RFC8555: The CSR MUST indicate the exact same set of requested
@ -135,7 +141,7 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth
if csr.Subject.CommonName != "" { if csr.Subject.CommonName != "" {
csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) csr.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
} }
csr.DNSNames = uniqueLowerNames(csr.DNSNames) csr.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
orderNames := make([]string, len(o.Identifiers)) orderNames := make([]string, len(o.Identifiers))
for i, n := range o.Identifiers { for i, n := range o.Identifiers {
orderNames[i] = n.Value orderNames[i] = n.Value
@ -148,13 +154,13 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth
// absence of other SANs as they will only be set if the templates allows // absence of other SANs as they will only be set if the templates allows
// them. // them.
if len(csr.DNSNames) != len(orderNames) { if len(csr.DNSNames) != len(orderNames) {
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) return BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames))
} }
sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames)) sans := make([]x509util.SubjectAlternativeName, len(csr.DNSNames))
for i := range csr.DNSNames { for i := range csr.DNSNames {
if csr.DNSNames[i] != orderNames[i] { if csr.DNSNames[i] != orderNames[i] {
return nil, BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames)) return BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly: CSR names = %v, Order names = %v", csr.DNSNames, orderNames))
} }
sans[i] = x509util.SubjectAlternativeName{ sans[i] = x509util.SubjectAlternativeName{
Type: x509util.DNSType, Type: x509util.DNSType,
@ -163,10 +169,10 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth
} }
// Get authorizations from the ACME provisioner. // Get authorizations from the ACME provisioner.
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOps, err := p.AuthorizeSign(ctx, "") signOps, err := p.AuthorizeSign(ctx, "")
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner")) return ServerInternalErr(errors.Wrapf(err, "error retrieving authorization options from ACME provisioner"))
} }
// Template data // Template data
@ -176,27 +182,36 @@ func (o *order) Finalize(ctx, db DB, csr *x509.CertificateRequest, auth SignAuth
templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data) templateOptions, err := provisioner.TemplateOptions(p.GetOptions(), data)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner")) return ServerInternalErr(errors.Wrapf(err, "error creating template options from ACME provisioner"))
} }
signOps = append(signOps, templateOptions) signOps = append(signOps, templateOptions)
// Create and store a new certificate. nbf, err := time.Parse(time.RFC3339, o.NotBefore)
certChain, err := auth.Sign(csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(o.NotBefore),
NotAfter: provisioner.NewTimeDuration(o.NotAfter),
}, signOps...)
if err != nil { if err != nil {
return nil, ServerInternalErr(errors.Wrapf(err, "error generating certificate for order %s", o.ID)) return ServerInternalErr(errors.Wrap(err, "error parsing order NotBefore"))
}
naf, err := time.Parse(time.RFC3339, o.NotAfter)
if err != nil {
return ServerInternalErr(errors.Wrap(err, "error parsing order NotAfter"))
} }
cert, err := db.CreateCertificate(ctx, &Certificate{ // Sign a new certificate.
certChain, err := auth.Sign(csr, provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nbf),
NotAfter: provisioner.NewTimeDuration(naf),
}, signOps...)
if err != nil {
return ServerInternalErr(errors.Wrapf(err, "error signing certificate for order %s", o.ID))
}
cert := &Certificate{
AccountID: o.AccountID, AccountID: o.AccountID,
OrderID: o.ID, OrderID: o.ID,
Leaf: certChain[0], Leaf: certChain[0],
Intermediates: certChain[1:], Intermediates: certChain[1:],
}) }
if err != nil { if err := db.CreateCertificate(ctx, cert); err != nil {
return nil, err return err
} }
o.Certificate = cert.ID o.Certificate = cert.ID

View file

@ -56,8 +56,7 @@ func NewContextWithMethod(ctx context.Context, method Method) context.Context {
return context.WithValue(ctx, methodKey{}, method) return context.WithValue(ctx, methodKey{}, method)
} }
// MethodFromContext returns the Method saved in ctx. Returns Sign if the given // MethodFromContext returns the Method saved in ctx.
// context has no Method associated with it.
func MethodFromContext(ctx context.Context) Method { func MethodFromContext(ctx context.Context) Method {
m, _ := ctx.Value(methodKey{}).(Method) m, _ := ctx.Value(methodKey{}).(Method)
return m return m