Merge pull request #625 from hslatman/hs/acme-revocation

ACME Certificate Revocation
This commit is contained in:
Herman Slatman 2021-12-09 09:48:02 +01:00 committed by GitHub
commit fbd3fd2145
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 2334 additions and 44 deletions

View file

@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased - 0.18.1] - DATE ## [Unreleased - 0.18.1] - DATE
### Added ### Added
- Support for ACME revocation.
### Changed ### Changed
### Deprecated ### Deprecated
### Removed ### Removed

View file

@ -100,11 +100,17 @@ func (h *Handler) Route(r api.Router) {
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory)))
validatingMiddleware := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))
}
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKid := func(next nextHTTP) nextHTTP { extractPayloadByKid := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next)))
}
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next)))
} }
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount))
@ -117,6 +123,7 @@ func (h *Handler) Route(r api.Router) {
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization)))
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge))
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate)))
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert))
} }
// GetNonce just sets the right header since a Nonce is added to each response // GetNonce just sets the right header since a Nonce is added to each response

View file

@ -262,11 +262,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
// Store the JWK in the context. // Store the JWK in the context.
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
// Get Account or continue to generate a new one. // Get Account OR continue to generate a new one OR continue Revoke with certificate private key
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID)
switch { switch {
case errors.Is(err, acme.ErrNotFound): case errors.Is(err, acme.ErrNotFound):
// For NewAccount requests ... // For NewAccount and Revoke requests ...
break break
case err != nil: case err != nil:
api.WriteError(w, err) api.WriteError(w, err)
@ -352,6 +352,42 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
} }
} }
// extractOrLookupJWK forwards handling to either extractJWK or
// lookupJWK based on the presence of a JWK or a KID, respectively.
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
// at this point the JWS has already been verified (if correctly configured in middleware),
// and it can be used to check if a JWK exists. This flow is used when the ACME client
// signed the payload with a certificate private key.
if canExtractJWKFrom(jws) {
h.extractJWK(next)(w, r)
return
}
// default to looking up the JWK based on KeyID. This flow is used when the ACME client
// signed the payload with an account private key.
h.lookupJWK(next)(w, r)
}
}
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
if jws == nil {
return false
}
if len(jws.Signatures) == 0 {
return false
}
return jws.Signatures[0].Protected.JSONWebKey != nil
}
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {

View file

@ -1472,3 +1472,187 @@ func TestHandler_validateJWS(t *testing.T) {
}) })
} }
} }
func Test_canExtractJWKFrom(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
type args struct {
jws *jose.JSONWebSignature
}
tests := []struct {
name string
args args
want bool
}{
{
name: "no-jws",
args: args{
jws: nil,
},
want: false,
},
{
name: "no-signatures",
args: args{
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{},
},
},
want: false,
},
{
name: "no-jwk",
args: args{
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{
{
Protected: jose.Header{},
},
},
},
},
want: false,
},
{
name: "ok",
args: args{
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{
{
Protected: jose.Header{
JSONWebKey: jwk,
},
},
},
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
}
})
}
}
func TestHandler_extractOrLookupJWK(t *testing.T) {
u := "https://ca.smallstep.com/acme/account"
type test struct {
db acme.DB
linker Linker
statusCode int
ctx context.Context
err *acme.Error
next func(w http.ResponseWriter, r *http.Request)
}
var tests = map[string]func(t *testing.T) test{
"ok/extract": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
kid, err := jwk.Thumbprint(crypto.SHA256)
assert.FatalError(t, err)
pub := jwk.Public()
pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
so := new(jose.SignerOptions)
so.WithHeader("jwk", pub) // JWK for certificate private key flow
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
signed, err := signer.Sign([]byte("foo"))
assert.FatalError(t, err)
raw, err := signed.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
return test{
linker: NewLinker("dns", "acme"),
db: &acme.MockDB{
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
assert.Equals(t, kid, pub.KeyID)
return nil, acme.ErrNotFound
},
},
ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS),
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
}
},
"ok/lookup": func(t *testing.T) test {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
accID := "accID"
prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
so := new(jose.SignerOptions)
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) // KID for account private key flow
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign([]byte("baz"))
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
linker: NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, acc.ID)
return acc, nil
},
},
ctx: ctx,
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.extractOrLookupJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
}
})
}
}

287
acme/api/revoke.go Normal file
View file

@ -0,0 +1,287 @@
package api
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ocsp"
)
type revokePayload struct {
Certificate string `json:"certificate"`
ReasonCode *int `json:"reason,omitempty"`
}
// RevokeCert attempts to revoke a certificate.
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
jws, err := jwsFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
prov, err := provisionerFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
payload, err := payloadFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
var p revokePayload
err = json.Unmarshal(payload.value, &p)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload"))
return
}
certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate)
if err != nil {
// in this case the most likely cause is a client that didn't properly encode the certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property"))
return
}
certToBeRevoked, err := x509.ParseCertificate(certBytes)
if err != nil {
// in this case a client may have encoded something different than a certificate
api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate"))
return
}
serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return
}
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return
}
if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx)
if err != nil {
api.WriteError(w, err)
return
}
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil {
api.WriteError(w, acmeErr)
return
}
} else {
// if account doesn't need to be checked, the JWS should be verified to be signed by the
// private key that belongs to the public key in the certificate to be revoked.
_, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return
}
}
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial)
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return
}
if hasBeenRevokedBefore {
api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked"))
return
}
reasonCode := p.ReasonCode
acmeErr := validateReasonCode(reasonCode)
if acmeErr != nil {
api.WriteError(w, acmeErr)
return
}
// Authorize revocation by ACME provisioner
ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod)
err = prov.AuthorizeRevoke(ctx, "")
if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner"))
return
}
options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options)
if err != nil {
api.WriteError(w, wrapRevokeErr(err))
return
}
logRevoke(w, options)
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index"))
w.Write(nil)
}
// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized
// to revoke the certificate. An Account must always be valid in order to revoke a certificate.
// In case the certificate retrieved from the database belongs to the Account, the Account is
// authorized. If the certificate retrieved from the database doesn't belong to the Account,
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
if !account.IsValid() {
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
}
certificateBelongsToAccount := dbCert.AccountID == account.ID
if certificateBelongsToAccount {
return nil // return early
}
// TODO(hs): according to RFC8555: 7.6, a server MUST consider the following accounts authorized
// to revoke a certificate:
//
// o the account that issued the certificate.
// o an account that holds authorizations for all of the identifiers in the certificate.
//
// We currently only support the first case. The second might result in step going OOM when
// large numbers of Authorizations are involved when the current nosql interface is in use.
// We want to protect users from this failure scenario, so that's why it hasn't been added yet.
// This issue is tracked in https://github.com/smallstep/certificates/issues/767
// not authorized; fail closed.
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' is not authorized", account.ID), nil)
}
// wrapRevokeErr is a best effort implementation to transform an error during
// revocation into an ACME error, so that clients can understand the error.
func wrapRevokeErr(err error) *acme.Error {
t := err.Error()
if strings.Contains(t, "is already revoked") {
return acme.NewError(acme.ErrorAlreadyRevokedType, t)
}
return acme.WrapErrorISE(err, "error when revoking certificate")
}
// unauthorizedError returns an ACME error indicating the request was
// not authorized to revoke the certificate.
func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error {
var acmeErr *acme.Error
if err == nil {
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg)
} else {
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
}
acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403
switch {
case len(unauthorizedIdentifiers) > 0:
identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too?
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value)
case cert.Subject.String() != "":
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName)
default:
acmeErr.Detail = "No authorization provided"
}
return acmeErr
}
// logRevoke logs successful revocation of certificate
func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) {
if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"serial": ri.Serial,
"reasonCode": ri.ReasonCode,
"reason": ri.Reason,
"passiveOnly": ri.PassiveOnly,
"ACME": ri.ACME,
})
}
}
// validateReasonCode validates the revocation reason
func validateReasonCode(reasonCode *int) *acme.Error {
if reasonCode != nil && ((*reasonCode < ocsp.Unspecified || *reasonCode > ocsp.AACompromise) || *reasonCode == 7) {
return acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds")
}
// NOTE: it's possible to add additional requirements to the reason code:
// The server MAY disallow a subset of reasonCodes from being
// used by the user. If a request contains a disallowed reasonCode,
// then the server MUST reject it with the error type
// "urn:ietf:params:acme:error:badRevocationReason"
// No additional checks have been implemented so far.
return nil
}
// revokeOptions determines the RevokeOptions for the Authority to use in revocation
func revokeOptions(serial string, certToBeRevoked *x509.Certificate, reasonCode *int) *authority.RevokeOptions {
opts := &authority.RevokeOptions{
Serial: serial,
ACME: true,
Crt: certToBeRevoked,
}
if reasonCode != nil { // NOTE: when implementing CRL and/or OCSP, and reason code is missing, CRL entry extension should be omitted
opts.Reason = reason(*reasonCode)
opts.ReasonCode = *reasonCode
}
return opts
}
// reason transforms an integer reason code to a
// textual description of the revocation reason.
func reason(reasonCode int) string {
switch reasonCode {
case ocsp.Unspecified:
return "unspecified reason"
case ocsp.KeyCompromise:
return "key compromised"
case ocsp.CACompromise:
return "ca compromised"
case ocsp.AffiliationChanged:
return "affiliation changed"
case ocsp.Superseded:
return "superseded"
case ocsp.CessationOfOperation:
return "cessation of operation"
case ocsp.CertificateHold:
return "certificate hold"
case ocsp.RemoveFromCRL:
return "remove from crl"
case ocsp.PrivilegeWithdrawn:
return "privilege withdrawn"
case ocsp.AACompromise:
return "aa compromised"
default:
return "unspecified reason"
}
}
// shouldCheckAccountFrom indicates whether an account should be
// retrieved from the context, so that it can be used for
// additional checks. This should only be done when no JWK
// can be extracted from the request, as that would indicate
// that the revocation request was signed with a certificate
// key pair (and not an account key pair). Looking up such
// a JWK would result in no Account being found.
func shouldCheckAccountFrom(jws *jose.JSONWebSignature) bool {
return !canExtractJWKFrom(jws)
}

1316
acme/api/revoke_test.go Normal file

File diff suppressed because it is too large Load diff

View file

@ -26,8 +26,11 @@ import (
type ChallengeType string type ChallengeType string
const ( const (
HTTP01 ChallengeType = "http-01" // HTTP01 is the http-01 ACME challenge type
DNS01 ChallengeType = "dns-01" HTTP01 ChallengeType = "http-01"
// DNS01 is the dns-01 ACME challenge type
DNS01 ChallengeType = "dns-01"
// TLSALPN01 is the tls-alpn-01 ACME challenge type
TLSALPN01 ChallengeType = "tls-alpn-01" TLSALPN01 ChallengeType = "tls-alpn-01"
) )

View file

@ -5,12 +5,15 @@ import (
"crypto/x509" "crypto/x509"
"time" "time"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
// CertificateAuthority is the interface implemented by a CA authority. // CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface { type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error)
} }
@ -28,6 +31,7 @@ var clock Clock
// only those methods required by the ACME api/authority. // only those methods required by the ACME api/authority.
type Provisioner interface { type Provisioner interface {
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
AuthorizeRevoke(ctx context.Context, token string) error
GetID() string GetID() string
GetName() string GetName() string
DefaultTLSCertDuration() time.Duration DefaultTLSCertDuration() time.Duration
@ -41,6 +45,7 @@ type MockProvisioner struct {
MgetID func() string MgetID func() string
MgetName func() string MgetName func() string
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error
MdefaultTLSCertDuration func() time.Duration MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options MgetOptions func() *provisioner.Options
} }
@ -61,6 +66,14 @@ func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]prov
return m.Mret1.([]provisioner.SignOption), m.Merr return m.Mret1.([]provisioner.SignOption), m.Merr
} }
// AuthorizeRevoke mock
func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error {
if m.MauthorizeRevoke != nil {
return m.MauthorizeRevoke(ctx, token)
}
return m.Merr
}
// DefaultTLSCertDuration mock // DefaultTLSCertDuration mock
func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration { func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration {
if m.MdefaultTLSCertDuration != nil { if m.MdefaultTLSCertDuration != nil {

View file

@ -25,9 +25,11 @@ type DB interface {
CreateAuthorization(ctx context.Context, az *Authorization) error CreateAuthorization(ctx context.Context, az *Authorization) error
GetAuthorization(ctx context.Context, id string) (*Authorization, error) GetAuthorization(ctx context.Context, id string) (*Authorization, error)
UpdateAuthorization(ctx context.Context, az *Authorization) error UpdateAuthorization(ctx context.Context, az *Authorization) error
GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error)
CreateCertificate(ctx context.Context, cert *Certificate) error CreateCertificate(ctx context.Context, cert *Certificate) error
GetCertificate(ctx context.Context, id string) (*Certificate, error) GetCertificate(ctx context.Context, id string) (*Certificate, error)
GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error)
CreateChallenge(ctx context.Context, ch *Challenge) error CreateChallenge(ctx context.Context, ch *Challenge) error
GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error) GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error)
@ -50,12 +52,14 @@ type MockDB struct {
MockCreateNonce func(ctx context.Context) (Nonce, error) MockCreateNonce func(ctx context.Context) (Nonce, error)
MockDeleteNonce func(ctx context.Context, nonce Nonce) error MockDeleteNonce func(ctx context.Context, nonce Nonce) error
MockCreateAuthorization func(ctx context.Context, az *Authorization) error MockCreateAuthorization func(ctx context.Context, az *Authorization) error
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error)
MockCreateCertificate func(ctx context.Context, cert *Certificate) error MockCreateCertificate func(ctx context.Context, cert *Certificate) error
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
MockGetCertificateBySerial func(ctx context.Context, serial string) (*Certificate, error)
MockCreateChallenge func(ctx context.Context, ch *Challenge) error MockCreateChallenge func(ctx context.Context, ch *Challenge) error
MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error) MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error)
@ -160,6 +164,16 @@ func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) err
return m.MockError return m.MockError
} }
// GetAuthorizationsByAccountID mock
func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) {
if m.MockGetAuthorizationsByAccountID != nil {
return m.MockGetAuthorizationsByAccountID(ctx, accountID)
} else if m.MockError != nil {
return nil, m.MockError
}
return nil, m.MockError
}
// CreateCertificate mock // CreateCertificate mock
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
if m.MockCreateCertificate != nil { if m.MockCreateCertificate != nil {
@ -180,6 +194,16 @@ func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, e
return m.MockRet1.(*Certificate), m.MockError return m.MockRet1.(*Certificate), m.MockError
} }
// GetCertificateBySerial mock
func (m *MockDB) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) {
if m.MockGetCertificateBySerial != nil {
return m.MockGetCertificateBySerial(ctx, serial)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*Certificate), m.MockError
}
// CreateChallenge mock // CreateChallenge mock
func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error { func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error {
if m.MockCreateChallenge != nil { if m.MockCreateChallenge != nil {

View file

@ -116,3 +116,37 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e
nu.Error = az.Error nu.Error = az.Error
return db.save(ctx, old.ID, nu, old, "authz", authzTable) return db.save(ctx, old.ID, nu, old, "authz", authzTable)
} }
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) {
entries, err := db.db.List(authzTable)
if err != nil {
return nil, errors.Wrapf(err, "error listing authz")
}
authzs := []*acme.Authorization{}
for _, entry := range entries {
dbaz := new(dbAuthz)
if err = json.Unmarshal(entry.Value, dbaz); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key))
}
// Filter out all dbAuthzs that don't belong to the accountID. This
// could be made more efficient with additional data structures mapping the
// Account ID to authorizations. Not trivial to do, though.
if dbaz.AccountID != accountID {
continue
}
authzs = append(authzs, &acme.Authorization{
ID: dbaz.ID,
AccountID: dbaz.AccountID,
Identifier: dbaz.Identifier,
Status: dbaz.Status,
Challenges: nil, // challenges not required for current use case
Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt,
Token: dbaz.Token,
Error: dbaz.Error,
})
}
return authzs, nil
}

View file

@ -3,9 +3,11 @@ package nosql
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
@ -614,3 +616,154 @@ func TestDB_UpdateAuthorization(t *testing.T) {
}) })
} }
} }
func TestDB_GetAuthorizationsByAccountID(t *testing.T) {
azID := "azID"
accountID := "accountID"
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
authzs []*acme.Authorization
}
var tests = map[string]func(t *testing.T) test{
"fail/db.List-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return nil, errors.New("force")
},
},
err: errors.New("error listing authz: force"),
}
},
"fail/unmarshal": func(t *testing.T) test {
b := []byte(`{malformed}`)
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return []*nosqldb.Entry{
{
Bucket: bucket,
Key: []byte(azID),
Value: b,
},
}, nil
},
},
authzs: nil,
err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID),
}
},
"ok": func(t *testing.T) test {
now := clock.Now()
dbaz := &dbAuthz{
ID: azID,
AccountID: accountID,
Identifier: acme.Identifier{
Type: "dns",
Value: "test.ca.smallstep.com",
},
Status: acme.StatusValid,
Token: "token",
CreatedAt: now,
ExpiresAt: now.Add(5 * time.Minute),
ChallengeIDs: []string{"foo", "bar"},
Wildcard: true,
}
b, err := json.Marshal(dbaz)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return []*nosqldb.Entry{
{
Bucket: bucket,
Key: []byte(azID),
Value: b,
},
}, nil
},
},
authzs: []*acme.Authorization{
{
ID: dbaz.ID,
AccountID: dbaz.AccountID,
Token: dbaz.Token,
Identifier: dbaz.Identifier,
Status: dbaz.Status,
Challenges: nil,
Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt,
Error: dbaz.Error,
},
},
}
},
"ok/skip-different-account": func(t *testing.T) test {
now := clock.Now()
dbaz := &dbAuthz{
ID: azID,
AccountID: "differentAccountID",
Identifier: acme.Identifier{
Type: "dns",
Value: "test.ca.smallstep.com",
},
Status: acme.StatusValid,
Token: "token",
CreatedAt: now,
ExpiresAt: now.Add(5 * time.Minute),
ChallengeIDs: []string{"foo", "bar"},
Wildcard: true,
}
b, err := json.Marshal(dbaz)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return []*nosqldb.Entry{
{
Bucket: bucket,
Key: []byte(azID),
Value: b,
},
}, nil
},
},
authzs: []*acme.Authorization{},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db}
if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else if assert.Nil(t, tc.err) {
if !cmp.Equal(azs, tc.authzs) {
t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs))
}
}
})
}
}

View file

@ -21,6 +21,11 @@ type dbCert struct {
Intermediates []byte `json:"intermediates"` Intermediates []byte `json:"intermediates"`
} }
type dbSerial struct {
Serial string `json:"serial"`
CertificateID string `json:"certificateID"`
}
// CreateCertificate creates and stores an ACME certificate type. // CreateCertificate creates and stores an ACME certificate type.
func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error { func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error {
var err error var err error
@ -49,7 +54,17 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
Intermediates: intermediates, Intermediates: intermediates,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
} }
return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) err = db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
if err != nil {
return err
}
serial := cert.Leaf.SerialNumber.String()
dbSerial := &dbSerial{
Serial: serial,
CertificateID: cert.ID,
}
return db.save(ctx, serial, dbSerial, nil, "serial", certBySerialTable)
} }
// GetCertificate retrieves and unmarshals an ACME certificate type from the // GetCertificate retrieves and unmarshals an ACME certificate type from the
@ -80,6 +95,24 @@ func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate,
}, nil }, nil
} }
// GetCertificateBySerial retrieves and unmarshals an ACME certificate type from the
// datastore based on a certificate serial number.
func (db *DB) GetCertificateBySerial(ctx context.Context, serial string) (*acme.Certificate, error) {
b, err := db.db.Get(certBySerialTable, []byte(serial))
if nosql.IsErrNotFound(err) {
return nil, acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading certificate ID for serial %s", serial)
}
dbSerial := new(dbSerial)
if err := json.Unmarshal(b, dbSerial); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling certificate with serial %s", serial)
}
return db.GetCertificate(ctx, dbSerial.CertificateID)
}
func parseBundle(b []byte) ([]*x509.Certificate, error) { func parseBundle(b []byte) ([]*x509.Certificate, error) {
var ( var (
err error err error

View file

@ -1,10 +1,12 @@
package nosql package nosql
import ( import (
"bytes"
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt"
"testing" "testing"
"time" "time"
@ -14,7 +16,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database" nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -75,18 +76,36 @@ func TestDB_CreateCertificate(t *testing.T) {
return test{ return test{
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
*idPtr = string(key) if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) {
assert.Equals(t, bucket, certTable) t.Fail()
assert.Equals(t, key, []byte(cert.ID)) }
assert.Equals(t, old, nil) if bytes.Equal(bucket, certTable) {
*idPtr = string(key)
assert.Equals(t, bucket, certTable)
assert.Equals(t, key, []byte(cert.ID))
assert.Equals(t, old, nil)
dbc := new(dbCert)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.ID, cert.ID)
assert.Equals(t, dbc.AccountID, cert.AccountID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
}
if bytes.Equal(bucket, certBySerialTable) {
assert.Equals(t, bucket, certBySerialTable)
assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String()))
assert.Equals(t, old, nil)
dbs := new(dbSerial)
assert.FatalError(t, json.Unmarshal(nu, dbs))
assert.Equals(t, dbs.Serial, string(key))
assert.Equals(t, dbs.CertificateID, cert.ID)
*idPtr = cert.ID
}
dbc := new(dbCert)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.ID, cert.ID)
assert.Equals(t, dbc.AccountID, cert.AccountID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, true, nil return nil, true, nil
}, },
}, },
@ -317,3 +336,135 @@ func Test_parseBundle(t *testing.T) {
}) })
} }
} }
func TestDB_GetCertificateBySerial(t *testing.T) {
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
assert.FatalError(t, err)
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
assert.FatalError(t, err)
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
assert.FatalError(t, err)
certID := "certID"
serial := ""
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
return nil, nosqldb.ErrNotFound
}
return nil, errors.New("wrong table")
},
},
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial),
}
},
"fail/db-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
return nil, errors.New("force")
}
return nil, errors.New("wrong table")
},
},
err: fmt.Errorf("error loading certificate ID for serial %s", serial),
}
},
"fail/unmarshal-dbSerial": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
return []byte(`{"serial":malformed!}`), nil
}
return nil, errors.New("wrong table")
},
},
err: fmt.Errorf("error unmarshaling certificate with serial %s", serial),
}
},
"ok": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
certSerial := dbSerial{
Serial: serial,
CertificateID: certID,
}
b, err := json.Marshal(certSerial)
assert.FatalError(t, err)
return b, nil
}
if bytes.Equal(bucket, certTable) {
cert := dbCert{
ID: certID,
AccountID: "accountID",
OrderID: "orderID",
Leaf: pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: leaf.Raw,
}),
Intermediates: append(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: inter.Raw,
}), pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: root.Raw,
})...),
CreatedAt: clock.Now(),
}
b, err := json.Marshal(cert)
assert.FatalError(t, err)
return b, nil
}
return nil, errors.New("wrong table")
},
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db}
cert, err := d.GetCertificateBySerial(context.Background(), serial)
if err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else if assert.Nil(t, tc.err) {
assert.Equals(t, cert.ID, certID)
assert.Equals(t, cert.AccountID, "accountID")
assert.Equals(t, cert.OrderID, "orderID")
assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
}
})
}
}

View file

@ -19,6 +19,7 @@ var (
orderTable = []byte("acme_orders") orderTable = []byte("acme_orders")
ordersByAccountIDTable = []byte("acme_account_orders_index") ordersByAccountIDTable = []byte("acme_account_orders_index")
certTable = []byte("acme_certs") certTable = []byte("acme_certs")
certBySerialTable = []byte("acme_serial_certs_index")
) )
// DB is a struct that implements the AcmeDB interface. // DB is a struct that implements the AcmeDB interface.
@ -29,7 +30,7 @@ type DB struct {
// New configures and returns a new ACME DB backend implemented using a nosql DB. // New configures and returns a new ACME DB backend implemented using a nosql DB.
func New(db nosqlDB.DB) (*DB, error) { func New(db nosqlDB.DB) (*DB, error) {
tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, tables := [][]byte{accountTable, accountByKeyIDTable, authzTable,
challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable} challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable, certBySerialTable}
for _, b := range tables { for _, b := range tables {
if err := db.CreateTable(b); err != nil { if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s", return nil, errors.Wrapf(err, "error creating table %s",

View file

@ -147,7 +147,7 @@ var (
}, },
ErrorAlreadyRevokedType: { ErrorAlreadyRevokedType: {
typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(), typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(),
details: "Certificate already Revoked", details: "Certificate already revoked",
status: 400, status: 400,
}, },
ErrorBadCSRType: { ErrorBadCSRType: {

View file

@ -17,7 +17,9 @@ import (
type IdentifierType string type IdentifierType string
const ( const (
IP IdentifierType = "ip" // IP is the ACME ip identifier type
IP IdentifierType = "ip"
// DNS is the ACME dns identifier type
DNS IdentifierType = "dns" DNS IdentifierType = "dns"
) )
@ -288,6 +290,9 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
// MUST appear either in the commonName portion of the requested subject // MUST appear either in the commonName portion of the requested subject
// name or in an extensionRequest attribute [RFC2985] requesting a // name or in an extensionRequest attribute [RFC2985] requesting a
// subjectAltName extension, or both. // subjectAltName extension, or both.
// TODO(hs): we might want to check if the CommonName is in fact a DNS (and cannot
// be parsed as IP). This is related to https://github.com/smallstep/cli/pull/576
// (ACME IP SANS)
if csr.Subject.CommonName != "" { if csr.Subject.CommonName != "" {
// nolint:gocritic // nolint:gocritic
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)

View file

@ -12,6 +12,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -286,6 +287,14 @@ func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface
return m.ret1.(provisioner.Interface), m.err return m.ret1.(provisioner.Interface), m.err
} }
func (m *mockSignAuth) IsRevoked(sn string) (bool, error) {
return false, nil
}
func (m *mockSignAuth) Revoke(context.Context, *authority.RevokeOptions) error {
return nil
}
func TestOrder_Finalize(t *testing.T) { func TestOrder_Finalize(t *testing.T) {
type test struct { type test struct {
o *Order o *Order

View file

@ -588,6 +588,19 @@ func (a *Authority) CloseForReload() {
} }
} }
// IsRevoked returns whether or not a certificate has been
// revoked before.
func (a *Authority) IsRevoked(sn string) (bool, error) {
// Check the passive revocation table.
if lca, ok := a.adminDB.(interface {
IsRevoked(string) (bool, error)
}); ok {
return lca.IsRevoked(sn)
}
return a.db.IsRevoked(sn)
}
// requiresDecrypter returns whether the Authority // requiresDecrypter returns whether the Authority
// requires a KMS that provides a crypto.Decrypter // requires a KMS that provides a crypto.Decrypter
// Currently this is only required when SCEP is // Currently this is only required when SCEP is

View file

@ -274,19 +274,9 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
// //
// TODO(mariano): should we authorize by default? // TODO(mariano): should we authorize by default?
func (a *Authority) authorizeRenew(cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
var err error
var isRevoked bool
var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())}
// Check the passive revocation table.
serial := cert.SerialNumber.String() serial := cert.SerialNumber.String()
if lca, ok := a.adminDB.(interface { var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
IsRevoked(string) (bool, error) isRevoked, err := a.IsRevoked(serial)
}); ok {
isRevoked, err = lca.IsRevoked(serial)
} else {
isRevoked, err = a.db.IsRevoked(serial)
}
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
} }

View file

@ -99,6 +99,15 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
}, nil }, nil
} }
// AuthorizeRevoke is called just before the certificate is to be revoked by
// the CA. It can be used to authorize revocation of a certificate. It
// currently is a no-op.
// TODO(hs): add configuration option that toggles revocation? Or change function signature to make it more useful?
// Or move certain logic out of the Revoke API to here? Would likely involve some more stuff in the ctx.
func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
return nil
}
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
// NOTE: This method does not actually validate the certificate or check it's // NOTE: This method does not actually validate the certificate or check it's
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the

View file

@ -184,7 +184,6 @@ func TestUnimplementedMethods(t *testing.T) {
{"x5c/sshRenew", &X5C{}, SSHRenewMethod}, {"x5c/sshRenew", &X5C{}, SSHRenewMethod},
{"x5c/sshRekey", &X5C{}, SSHRekeyMethod}, {"x5c/sshRekey", &X5C{}, SSHRekeyMethod},
{"x5c/sshRevoke", &X5C{}, SSHRekeyMethod}, {"x5c/sshRevoke", &X5C{}, SSHRekeyMethod},
{"acme/revoke", &ACME{}, RevokeMethod},
{"acme/sshSign", &ACME{}, SSHSignMethod}, {"acme/sshSign", &ACME{}, SSHSignMethod},
{"acme/sshRekey", &ACME{}, SSHRekeyMethod}, {"acme/sshRekey", &ACME{}, SSHRekeyMethod},
{"acme/sshRenew", &ACME{}, SSHRenewMethod}, {"acme/sshRenew", &ACME{}, SSHRenewMethod},

View file

@ -343,6 +343,7 @@ type RevokeOptions struct {
ReasonCode int ReasonCode int
PassiveOnly bool PassiveOnly bool
MTLS bool MTLS bool
ACME bool
Crt *x509.Certificate Crt *x509.Certificate
OTT string OTT string
} }
@ -360,9 +361,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
errs.WithKeyVal("reason", revokeOpts.Reason), errs.WithKeyVal("reason", revokeOpts.Reason),
errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly), errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly),
errs.WithKeyVal("MTLS", revokeOpts.MTLS), errs.WithKeyVal("MTLS", revokeOpts.MTLS),
errs.WithKeyVal("ACME", revokeOpts.ACME),
errs.WithKeyVal("context", provisioner.MethodFromContext(ctx).String()), errs.WithKeyVal("context", provisioner.MethodFromContext(ctx).String()),
} }
if revokeOpts.MTLS { if revokeOpts.MTLS || revokeOpts.ACME {
opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw))) opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw)))
} else { } else {
opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT)) opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT))
@ -373,6 +375,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
ReasonCode: revokeOpts.ReasonCode, ReasonCode: revokeOpts.ReasonCode,
Reason: revokeOpts.Reason, Reason: revokeOpts.Reason,
MTLS: revokeOpts.MTLS, MTLS: revokeOpts.MTLS,
ACME: revokeOpts.ACME,
RevokedAt: time.Now().UTC(), RevokedAt: time.Now().UTC(),
} }
@ -380,8 +383,8 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
p provisioner.Interface p provisioner.Interface
err error err error
) )
// If not mTLS then get the TokenID of the token. // If not mTLS nor ACME, then get the TokenID of the token.
if !revokeOpts.MTLS { if !(revokeOpts.MTLS || revokeOpts.ACME) {
token, err := jose.ParseSigned(revokeOpts.OTT) token, err := jose.ParseSigned(revokeOpts.OTT)
if err != nil { if err != nil {
return errs.Wrap(http.StatusUnauthorized, err, return errs.Wrap(http.StatusUnauthorized, err,

View file

@ -1267,6 +1267,23 @@ func TestAuthority_Revoke(t *testing.T) {
}, },
} }
}, },
"ok/ACME": func() test {
_a := testAuthority(t, WithDatabase(&db.MockAuthDB{}))
crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt")
assert.FatalError(t, err)
return test{
auth: _a,
opts: &RevokeOptions{
Crt: crt,
Serial: "102012593071130646873265215610956555026",
ReasonCode: reasonCode,
Reason: reason,
ACME: true,
},
}
},
} }
for name, f := range tests { for name, f := range tests {
tc := f() tc := f()

View file

@ -442,7 +442,7 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
return tlsConfig, nil return tlsConfig, nil
} }
// shouldMountSCEPEndpoints returns if the CA should be // shouldServeSCEPEndpoints returns if the CA should be
// configured with endpoints for SCEP. This is assumed to be // configured with endpoints for SCEP. This is assumed to be
// true if a SCEPService exists, which is true in case a // true if a SCEPService exists, which is true in case a
// SCEP provisioner was configured. // SCEP provisioner was configured.

View file

@ -104,6 +104,7 @@ type RevokedCertificateInfo struct {
RevokedAt time.Time RevokedAt time.Time
TokenID string TokenID string
MTLS bool MTLS bool
ACME bool
} }
// IsRevoked returns whether or not a certificate with the given identifier // IsRevoked returns whether or not a certificate with the given identifier

1
go.mod
View file

@ -18,6 +18,7 @@ require (
github.com/go-kit/kit v0.10.0 // indirect github.com/go-kit/kit v0.10.0 // indirect
github.com/go-piv/piv-go v1.7.0 github.com/go-piv/piv-go v1.7.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.6
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/googleapis/gax-go/v2 v2.0.5 github.com/googleapis/gax-go/v2 v2.0.5
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect