diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e9d0bc4..b41a90e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased - 0.18.1] - DATE ### Added +- Support for ACME revocation. ### Changed ### Deprecated ### Removed diff --git a/README.md b/README.md index 65116b38..5c29ccdf 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,8 @@ To get up and running quickly, or as an alternative to running your own `step-ca [![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers) [![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs) +![star us](https://github.com/smallstep/certificates/raw/master/docs/images/star.gif) + ## Features ### 🦾 A fast, stable, flexible private CA diff --git a/acme/api/handler.go b/acme/api/handler.go index 394986e1..09ca03a3 100644 --- a/acme/api/handler.go +++ b/acme/api/handler.go @@ -100,11 +100,17 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.GetDirectory))) + validatingMiddleware := func(next nextHTTP) nextHTTP { + return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next))))))) + } extractPayloadByJWK := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.extractJWK(h.verifyAndExtractJWSPayload(next))))))))) + return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) } extractPayloadByKid := func(next nextHTTP) nextHTTP { - return h.baseURLFromRequest(h.lookupProvisioner(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(h.lookupJWK(h.verifyAndExtractJWSPayload(next))))))))) + return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) + } + extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { + return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) } r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) @@ -117,6 +123,7 @@ func (h *Handler) Route(r api.Router) { r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) + r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) } // GetNonce just sets the right header since a Nonce is added to each response diff --git a/acme/api/middleware.go b/acme/api/middleware.go index be531ca8..d701f240 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -262,11 +262,11 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP { // Store the JWK in the context. ctx = context.WithValue(ctx, jwkContextKey, jwk) - // Get Account or continue to generate a new one. + // Get Account OR continue to generate a new one OR continue Revoke with certificate private key acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) switch { case errors.Is(err, acme.ErrNotFound): - // For NewAccount requests ... + // For NewAccount and Revoke requests ... break case err != nil: api.WriteError(w, err) @@ -352,6 +352,42 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { } } +// extractOrLookupJWK forwards handling to either extractJWK or +// lookupJWK based on the presence of a JWK or a KID, respectively. +func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + jws, err := jwsFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + + // at this point the JWS has already been verified (if correctly configured in middleware), + // and it can be used to check if a JWK exists. This flow is used when the ACME client + // signed the payload with a certificate private key. + if canExtractJWKFrom(jws) { + h.extractJWK(next)(w, r) + return + } + + // default to looking up the JWK based on KeyID. This flow is used when the ACME client + // signed the payload with an account private key. + h.lookupJWK(next)(w, r) + } +} + +// canExtractJWKFrom checks if the JWS has a JWK that can be extracted +func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { + if jws == nil { + return false + } + if len(jws.Signatures) == 0 { + return false + } + return jws.Signatures[0].Protected.JSONWebKey != nil +} + // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // Make sure to parse and validate the JWS before running this middleware. func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index 9b36d316..050b46a5 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1472,3 +1472,187 @@ func TestHandler_validateJWS(t *testing.T) { }) } } + +func Test_canExtractJWKFrom(t *testing.T) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + type args struct { + jws *jose.JSONWebSignature + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "no-jws", + args: args{ + jws: nil, + }, + want: false, + }, + { + name: "no-signatures", + args: args{ + jws: &jose.JSONWebSignature{ + Signatures: []jose.Signature{}, + }, + }, + want: false, + }, + { + name: "no-jwk", + args: args{ + jws: &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{}, + }, + }, + }, + }, + want: false, + }, + { + name: "ok", + args: args{ + jws: &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + JSONWebKey: jwk, + }, + }, + }, + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := canExtractJWKFrom(tt.args.jws); got != tt.want { + t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandler_extractOrLookupJWK(t *testing.T) { + u := "https://ca.smallstep.com/acme/account" + type test struct { + db acme.DB + linker Linker + statusCode int + ctx context.Context + err *acme.Error + next func(w http.ResponseWriter, r *http.Request) + } + var tests = map[string]func(t *testing.T) test{ + "ok/extract": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + kid, err := jwk.Thumbprint(crypto.SHA256) + assert.FatalError(t, err) + pub := jwk.Public() + pub.KeyID = base64.RawURLEncoding.EncodeToString(kid) + so := new(jose.SignerOptions) + so.WithHeader("jwk", pub) // JWK for certificate private key flow + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + signed, err := signer.Sign([]byte("foo")) + assert.FatalError(t, err) + raw, err := signed.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + return test{ + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) + return nil, acme.ErrNotFound + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS), + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + } + }, + "ok/lookup": func(t *testing.T) test { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + accID := "accID" + prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName) + so := new(jose.SignerOptions) + so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) // KID for account private key flow + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign([]byte("baz")) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + linker: NewLinker("test.ca.smallstep.com", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { + assert.Equals(t, accID, acc.ID) + return acc, nil + }, + }, + ctx: ctx, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: tc.db, linker: tc.linker} + req := httptest.NewRequest("GET", u, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.extractOrLookupJWK(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +} diff --git a/acme/api/revoke.go b/acme/api/revoke.go new file mode 100644 index 00000000..d01e401c --- /dev/null +++ b/acme/api/revoke.go @@ -0,0 +1,287 @@ +package api + +import ( + "bytes" + "context" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/logging" + "go.step.sm/crypto/jose" + "golang.org/x/crypto/ocsp" +) + +type revokePayload struct { + Certificate string `json:"certificate"` + ReasonCode *int `json:"reason,omitempty"` +} + +// RevokeCert attempts to revoke a certificate. +func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { + + ctx := r.Context() + jws, err := jwsFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + + prov, err := provisionerFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + + payload, err := payloadFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + + var p revokePayload + err = json.Unmarshal(payload.value, &p) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error unmarshaling payload")) + return + } + + certBytes, err := base64.RawURLEncoding.DecodeString(p.Certificate) + if err != nil { + // in this case the most likely cause is a client that didn't properly encode the certificate + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding payload certificate property")) + return + } + + certToBeRevoked, err := x509.ParseCertificate(certBytes) + if err != nil { + // in this case a client may have encoded something different than a certificate + api.WriteError(w, acme.WrapError(acme.ErrorMalformedType, err, "error parsing certificate")) + return + } + + serial := certToBeRevoked.SerialNumber.String() + dbCert, err := h.db.GetCertificateBySerial(ctx, serial) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) + return + } + + if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { + // this should never happen + api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) + return + } + + if shouldCheckAccountFrom(jws) { + account, err := accountFromContext(ctx) + if err != nil { + api.WriteError(w, err) + return + } + acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) + if acmeErr != nil { + api.WriteError(w, acmeErr) + return + } + } else { + // if account doesn't need to be checked, the JWS should be verified to be signed by the + // private key that belongs to the public key in the certificate to be revoked. + _, err := jws.Verify(certToBeRevoked.PublicKey) + if err != nil { + // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? + api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) + return + } + } + + hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) + return + } + + if hasBeenRevokedBefore { + api.WriteError(w, acme.NewError(acme.ErrorAlreadyRevokedType, "certificate was already revoked")) + return + } + + reasonCode := p.ReasonCode + acmeErr := validateReasonCode(reasonCode) + if acmeErr != nil { + api.WriteError(w, acmeErr) + return + } + + // Authorize revocation by ACME provisioner + ctx = provisioner.NewContextWithMethod(ctx, provisioner.RevokeMethod) + err = prov.AuthorizeRevoke(ctx, "") + if err != nil { + api.WriteError(w, acme.WrapErrorISE(err, "error authorizing revocation on provisioner")) + return + } + + options := revokeOptions(serial, certToBeRevoked, reasonCode) + err = h.ca.Revoke(ctx, options) + if err != nil { + api.WriteError(w, wrapRevokeErr(err)) + return + } + + logRevoke(w, options) + w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) + w.Write(nil) +} + +// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized +// to revoke the certificate. An Account must always be valid in order to revoke a certificate. +// In case the certificate retrieved from the database belongs to the Account, the Account is +// authorized. If the certificate retrieved from the database doesn't belong to the Account, +// the identifiers in the certificate are extracted and compared against the (valid) Authorizations +// that are stored for the ACME Account. If these sets match, the Account is considered authorized +// to revoke the certificate. If this check fails, the client will receive an unauthorized error. +func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { + if !account.IsValid() { + return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) + } + certificateBelongsToAccount := dbCert.AccountID == account.ID + if certificateBelongsToAccount { + return nil // return early + } + + // TODO(hs): according to RFC8555: 7.6, a server MUST consider the following accounts authorized + // to revoke a certificate: + // + // o the account that issued the certificate. + // o an account that holds authorizations for all of the identifiers in the certificate. + // + // We currently only support the first case. The second might result in step going OOM when + // large numbers of Authorizations are involved when the current nosql interface is in use. + // We want to protect users from this failure scenario, so that's why it hasn't been added yet. + // This issue is tracked in https://github.com/smallstep/certificates/issues/767 + + // not authorized; fail closed. + return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' is not authorized", account.ID), nil) +} + +// wrapRevokeErr is a best effort implementation to transform an error during +// revocation into an ACME error, so that clients can understand the error. +func wrapRevokeErr(err error) *acme.Error { + t := err.Error() + if strings.Contains(t, "is already revoked") { + return acme.NewError(acme.ErrorAlreadyRevokedType, t) + } + return acme.WrapErrorISE(err, "error when revoking certificate") +} + +// unauthorizedError returns an ACME error indicating the request was +// not authorized to revoke the certificate. +func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error { + var acmeErr *acme.Error + if err == nil { + acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg) + } else { + acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg) + } + acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403 + + switch { + case len(unauthorizedIdentifiers) > 0: + identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too? + acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value) + case cert.Subject.String() != "": + acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName) + default: + acmeErr.Detail = "No authorization provided" + } + + return acmeErr +} + +// logRevoke logs successful revocation of certificate +func logRevoke(w http.ResponseWriter, ri *authority.RevokeOptions) { + if rl, ok := w.(logging.ResponseLogger); ok { + rl.WithFields(map[string]interface{}{ + "serial": ri.Serial, + "reasonCode": ri.ReasonCode, + "reason": ri.Reason, + "passiveOnly": ri.PassiveOnly, + "ACME": ri.ACME, + }) + } +} + +// validateReasonCode validates the revocation reason +func validateReasonCode(reasonCode *int) *acme.Error { + if reasonCode != nil && ((*reasonCode < ocsp.Unspecified || *reasonCode > ocsp.AACompromise) || *reasonCode == 7) { + return acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds") + } + // NOTE: it's possible to add additional requirements to the reason code: + // The server MAY disallow a subset of reasonCodes from being + // used by the user. If a request contains a disallowed reasonCode, + // then the server MUST reject it with the error type + // "urn:ietf:params:acme:error:badRevocationReason" + // No additional checks have been implemented so far. + return nil +} + +// revokeOptions determines the RevokeOptions for the Authority to use in revocation +func revokeOptions(serial string, certToBeRevoked *x509.Certificate, reasonCode *int) *authority.RevokeOptions { + opts := &authority.RevokeOptions{ + Serial: serial, + ACME: true, + Crt: certToBeRevoked, + } + if reasonCode != nil { // NOTE: when implementing CRL and/or OCSP, and reason code is missing, CRL entry extension should be omitted + opts.Reason = reason(*reasonCode) + opts.ReasonCode = *reasonCode + } + return opts +} + +// reason transforms an integer reason code to a +// textual description of the revocation reason. +func reason(reasonCode int) string { + switch reasonCode { + case ocsp.Unspecified: + return "unspecified reason" + case ocsp.KeyCompromise: + return "key compromised" + case ocsp.CACompromise: + return "ca compromised" + case ocsp.AffiliationChanged: + return "affiliation changed" + case ocsp.Superseded: + return "superseded" + case ocsp.CessationOfOperation: + return "cessation of operation" + case ocsp.CertificateHold: + return "certificate hold" + case ocsp.RemoveFromCRL: + return "remove from crl" + case ocsp.PrivilegeWithdrawn: + return "privilege withdrawn" + case ocsp.AACompromise: + return "aa compromised" + default: + return "unspecified reason" + } +} + +// shouldCheckAccountFrom indicates whether an account should be +// retrieved from the context, so that it can be used for +// additional checks. This should only be done when no JWK +// can be extracted from the request, as that would indicate +// that the revocation request was signed with a certificate +// key pair (and not an account key pair). Looking up such +// a JWK would result in no Account being found. +func shouldCheckAccountFrom(jws *jose.JSONWebSignature) bool { + return !canExtractJWKFrom(jws) +} diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go new file mode 100644 index 00000000..4ff54405 --- /dev/null +++ b/acme/api/revoke_test.go @@ -0,0 +1,1316 @@ +package api + +import ( + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" + "github.com/smallstep/assert" + "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/authority/provisioner" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ocsp" +) + +// v is a utility function to return the pointer to an integer +func v(v int) *int { + return &v +} + +func generateSerial() (*big.Int, error) { + return rand.Int(rand.Reader, big.NewInt(1000000000000000000)) +} + +// generateCertKeyPair generates fresh x509 certificate/key pairs for testing +func generateCertKeyPair() (*x509.Certificate, crypto.Signer, error) { + + pub, priv, err := keyutil.GenerateKeyPair("EC", "P-256", 0) + if err != nil { + return nil, nil, err + } + + serial, err := generateSerial() + if err != nil { + return nil, nil, err + } + + now := time.Now() + template := &x509.Certificate{ + Subject: pkix.Name{CommonName: "127.0.0.1"}, + Issuer: pkix.Name{CommonName: "Test ACME Revoke Certificate"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + IsCA: false, + MaxPathLen: 0, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + NotBefore: now, + NotAfter: now.Add(time.Hour), + SerialNumber: serial, + } + + signer, ok := priv.(crypto.Signer) + if !ok { + return nil, nil, errors.Errorf("result is not a crypto.Signer: type %T", priv) + } + + cert, err := x509util.CreateCertificate(template, template, pub, signer) + + return cert, signer, err +} + +var errUnsupportedKey = fmt.Errorf("unknown key type; only RSA and ECDSA are supported") + +// keyID is the account identity provided by a CA during registration. +type keyID string + +// noKeyID indicates that jwsEncodeJSON should compute and use JWK instead of a KID. +// See jwsEncodeJSON for details. +const noKeyID = keyID("") + +// jwsEncodeJSON signs claimset using provided key and a nonce. +// The result is serialized in JSON format containing either kid or jwk +// fields based on the provided keyID value. +// +// If kid is non-empty, its quoted value is inserted in the protected head +// as "kid" field value. Otherwise, JWK is computed using jwkEncode and inserted +// as "jwk" field value. The "jwk" and "kid" fields are mutually exclusive. +// +// See https://tools.ietf.org/html/rfc7515#section-7. +// +// If nonce is empty, it will not be encoded into the header. +// Implementation taken from github.com/mholt/acmez, which seems to be based on +// https://github.com/golang/crypto/blob/master/acme/jws.go. +func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid keyID, nonce, u string) ([]byte, error) { + alg, sha := jwsHasher(key.Public()) + if alg == "" || !sha.Available() { + return nil, errUnsupportedKey + } + + phead, err := jwsHead(alg, nonce, u, kid, key) + if err != nil { + return nil, err + } + + var payload string + if claimset != nil { + cs, err := json.Marshal(claimset) + if err != nil { + return nil, err + } + payload = base64.RawURLEncoding.EncodeToString(cs) + } + + payloadToSign := []byte(phead + "." + payload) + hash := sha.New() + _, _ = hash.Write(payloadToSign) + digest := hash.Sum(nil) + + sig, err := jwsSign(key, sha, digest) + if err != nil { + return nil, err + } + + return jwsFinal(sha, sig, phead, payload) +} + +// jwsHasher indicates suitable JWS algorithm name and a hash function +// to use for signing a digest with the provided key. +// It returns ("", 0) if the key is not supported. +// Implementation taken from github.com/mholt/acmez, which seems to be based on +// https://github.com/golang/crypto/blob/master/acme/jws.go. +func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) { + switch pub := pub.(type) { + case *rsa.PublicKey: + return "RS256", crypto.SHA256 + case *ecdsa.PublicKey: + switch pub.Params().Name { + case "P-256": + return "ES256", crypto.SHA256 + case "P-384": + return "ES384", crypto.SHA384 + case "P-521": + return "ES512", crypto.SHA512 + } + } + return "", 0 +} + +// jwsSign signs the digest using the given key. +// The hash is unused for ECDSA keys. +// +// Note: non-stdlib crypto.Signer implementations are expected to return +// the signature in the format as specified in RFC7518. +// See https://tools.ietf.org/html/rfc7518 for more details. +// Implementation taken from github.com/mholt/acmez, which seems to be based on +// https://github.com/golang/crypto/blob/master/acme/jws.go. +func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) { + if key, ok := key.(*ecdsa.PrivateKey); ok { + // The key.Sign method of ecdsa returns ASN1-encoded signature. + // So, we use the package Sign function instead + // to get R and S values directly and format the result accordingly. + r, s, err := ecdsa.Sign(rand.Reader, key, digest) + if err != nil { + return nil, err + } + rb, sb := r.Bytes(), s.Bytes() + size := key.Params().BitSize / 8 + if size%8 > 0 { + size++ + } + sig := make([]byte, size*2) + copy(sig[size-len(rb):], rb) + copy(sig[size*2-len(sb):], sb) + return sig, nil + } + return key.Sign(rand.Reader, digest, hash) +} + +// jwsHead constructs the protected JWS header for the given fields. +// Since jwk and kid are mutually-exclusive, the jwk will be encoded +// only if kid is empty. If nonce is empty, it will not be encoded. +// Implementation taken from github.com/mholt/acmez, which seems to be based on +// https://github.com/golang/crypto/blob/master/acme/jws.go. +func jwsHead(alg, nonce, u string, kid keyID, key crypto.Signer) (string, error) { + phead := fmt.Sprintf(`{"alg":%q`, alg) + if kid == noKeyID { + jwk, err := jwkEncode(key.Public()) + if err != nil { + return "", err + } + phead += fmt.Sprintf(`,"jwk":%s`, jwk) + } else { + phead += fmt.Sprintf(`,"kid":%q`, kid) + } + if nonce != "" { + phead += fmt.Sprintf(`,"nonce":%q`, nonce) + } + phead += fmt.Sprintf(`,"url":%q}`, u) + phead = base64.RawURLEncoding.EncodeToString([]byte(phead)) + return phead, nil +} + +// jwkEncode encodes public part of an RSA or ECDSA key into a JWK. +// The result is also suitable for creating a JWK thumbprint. +// https://tools.ietf.org/html/rfc7517 +// Implementation taken from github.com/mholt/acmez, which seems to be based on +// https://github.com/golang/crypto/blob/master/acme/jws.go. +func jwkEncode(pub crypto.PublicKey) (string, error) { + switch pub := pub.(type) { + case *rsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.3.1 + n := pub.N + e := big.NewInt(int64(pub.E)) + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"e":%q,"kty":"RSA","n":%q}`, + base64.RawURLEncoding.EncodeToString(e.Bytes()), + base64.RawURLEncoding.EncodeToString(n.Bytes()), + ), nil + case *ecdsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.2.1 + p := pub.Curve.Params() + n := p.BitSize / 8 + if p.BitSize%8 != 0 { + n++ + } + x := pub.X.Bytes() + if n > len(x) { + x = append(make([]byte, n-len(x)), x...) + } + y := pub.Y.Bytes() + if n > len(y) { + y = append(make([]byte, n-len(y)), y...) + } + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`, + p.Name, + base64.RawURLEncoding.EncodeToString(x), + base64.RawURLEncoding.EncodeToString(y), + ), nil + } + return "", errUnsupportedKey +} + +// jwsFinal constructs the final JWS object. +// Implementation taken from github.com/mholt/acmez, which seems to be based on +// https://github.com/golang/crypto/blob/master/acme/jws.go. +func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error) { + enc := struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Sig string `json:"signature"` + }{ + Protected: phead, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(sig), + } + result, err := json.Marshal(&enc) + if err != nil { + return nil, err + } + return result, nil +} + +type mockCA struct { + MockIsRevoked func(sn string) (bool, error) + MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error +} + +func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + return nil, nil +} + +func (m *mockCA) IsRevoked(sn string) (bool, error) { + if m.MockIsRevoked != nil { + return m.MockIsRevoked(sn) + } + return false, nil +} + +func (m *mockCA) Revoke(ctx context.Context, opts *authority.RevokeOptions) error { + if m.MockRevoke != nil { + return m.MockRevoke(ctx, opts) + } + return nil +} + +func (m *mockCA) LoadProvisionerByName(string) (provisioner.Interface, error) { + return nil, nil +} + +func Test_validateReasonCode(t *testing.T) { + tests := []struct { + name string + reasonCode *int + want *acme.Error + }{ + { + name: "ok", + reasonCode: v(ocsp.Unspecified), + want: nil, + }, + { + name: "fail/too-low", + reasonCode: v(-1), + want: acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"), + }, + { + name: "fail/too-high", + reasonCode: v(11), + want: acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"), + }, + { + name: "fail/missing-7", + reasonCode: v(7), + + want: acme.NewError(acme.ErrorBadRevocationReasonType, "reasonCode out of bounds"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateReasonCode(tt.reasonCode) + if (err != nil) != (tt.want != nil) { + t.Errorf("validateReasonCode() = %v, want %v", err, tt.want) + } + if err != nil { + assert.Equals(t, err.Type, tt.want.Type) + assert.Equals(t, err.Detail, tt.want.Detail) + assert.Equals(t, err.Status, tt.want.Status) + assert.Equals(t, err.Err.Error(), tt.want.Err.Error()) + assert.Equals(t, err.Detail, tt.want.Detail) + } + }) + } +} + +func Test_reason(t *testing.T) { + tests := []struct { + name string + reasonCode int + want string + }{ + { + name: "unspecified reason", + reasonCode: ocsp.Unspecified, + want: "unspecified reason", + }, + { + name: "key compromised", + reasonCode: ocsp.KeyCompromise, + want: "key compromised", + }, + { + name: "ca compromised", + reasonCode: ocsp.CACompromise, + want: "ca compromised", + }, + { + name: "affiliation changed", + reasonCode: ocsp.AffiliationChanged, + want: "affiliation changed", + }, + { + name: "superseded", + reasonCode: ocsp.Superseded, + want: "superseded", + }, + { + name: "cessation of operation", + reasonCode: ocsp.CessationOfOperation, + want: "cessation of operation", + }, + { + name: "certificate hold", + reasonCode: ocsp.CertificateHold, + want: "certificate hold", + }, + { + name: "remove from crl", + reasonCode: ocsp.RemoveFromCRL, + want: "remove from crl", + }, + { + name: "privilege withdrawn", + reasonCode: ocsp.PrivilegeWithdrawn, + want: "privilege withdrawn", + }, + { + name: "aa compromised", + reasonCode: ocsp.AACompromise, + want: "aa compromised", + }, + { + name: "default", + reasonCode: -1, + want: "unspecified reason", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := reason(tt.reasonCode); got != tt.want { + t.Errorf("reason() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_revokeOptions(t *testing.T) { + cert, _, err := generateCertKeyPair() + assert.FatalError(t, err) + type args struct { + serial string + certToBeRevoked *x509.Certificate + reasonCode *int + } + tests := []struct { + name string + args args + want *authority.RevokeOptions + }{ + { + name: "ok/no-reasoncode", + args: args{ + serial: "1234", + certToBeRevoked: cert, + }, + want: &authority.RevokeOptions{ + Serial: "1234", + Crt: cert, + ACME: true, + }, + }, + { + name: "ok/including-reasoncode", + args: args{ + serial: "1234", + certToBeRevoked: cert, + reasonCode: v(ocsp.KeyCompromise), + }, + want: &authority.RevokeOptions{ + Serial: "1234", + Crt: cert, + ACME: true, + ReasonCode: 1, + Reason: "key compromised", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := revokeOptions(tt.args.serial, tt.args.certToBeRevoked, tt.args.reasonCode); !cmp.Equal(got, tt.want) { + t.Errorf("revokeOptions() diff =\n%s", cmp.Diff(got, tt.want)) + } + }) + } +} + +func TestHandler_RevokeCert(t *testing.T) { + prov := &provisioner.ACME{ + Type: "ACME", + Name: "testprov", + } + escProvName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + + chiCtx := chi.NewRouteContext() + revokeURL := fmt.Sprintf("%s/acme/%s/revoke-cert", baseURL.String(), escProvName) + + cert, key, err := generateCertKeyPair() + assert.FatalError(t, err) + rp := &revokePayload{ + Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), + } + payloadBytes, err := json.Marshal(rp) + assert.FatalError(t, err) + + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + KeyID: "bar", + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": revokeURL, + }, + }, + }, + }, + } + + type test struct { + db acme.DB + ca acme.CertificateAuthority + ctx context.Context + statusCode int + err *acme.Error + } + + var tests = map[string]func(t *testing.T) test{ + "fail/no-jws": func(t *testing.T) test { + ctx := context.Background() + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("jws expected in request context"), + } + }, + "fail/nil-jws": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("jws expected in request context"), + } + }, + "fail/no-provisioner": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, jws) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/nil-provisioner": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, jws) + ctx = context.WithValue(ctx, provisionerContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("provisioner does not exist"), + } + }, + "fail/no-payload": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, jws) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("payload does not exist"), + } + }, + "fail/nil-payload": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), jwsContextKey, jws) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, nil) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("payload does not exist"), + } + }, + "fail/unmarshal-payload": func(t *testing.T) test { + malformedPayload := []byte(`{"payload":malformed?}`) + ctx := context.WithValue(context.Background(), jwsContextKey, jws) + ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) + return test{ + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("error unmarshaling payload"), + } + }, + "fail/wrong-certificate-encoding": func(t *testing.T) test { + wrongPayload := &revokePayload{ + Certificate: base64.StdEncoding.EncodeToString(cert.Raw), + } + wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + return test{ + ctx: ctx, + statusCode: 400, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:malformed", + Status: 400, + Detail: "The request message was malformed", + }, + } + }, + "fail/no-certificate-encoded": func(t *testing.T) test { + emptyPayload := &revokePayload{ + Certificate: base64.RawURLEncoding.EncodeToString([]byte{}), + } + emptyPayloadBytes, err := json.Marshal(emptyPayload) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + return test{ + ctx: ctx, + statusCode: 400, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:malformed", + Status: 400, + Detail: "The request message was malformed", + }, + } + }, + "fail/db.GetCertificateBySerial": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + return nil, errors.New("force") + }, + } + return test{ + db: db, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("error retrieving certificate by serial"), + } + }, + "fail/different-certificate-contents": func(t *testing.T) test { + aDifferentCert, _, err := generateCertKeyPair() + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + Leaf: aDifferentCert, + }, nil + }, + } + return test{ + db: db, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("certificate raw bytes are not equal"), + } + }, + "fail/no-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + Leaf: cert, + }, nil + }, + } + return test{ + db: db, + ctx: ctx, + statusCode: 400, + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context"), + } + }, + "fail/nil-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, accContextKey, nil) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + Leaf: cert, + }, nil + }, + } + return test{ + db: db, + ctx: ctx, + statusCode: 400, + err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context"), + } + }, + "fail/account-not-valid": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{} + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 403, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Detail: "No authorization provided for name 127.0.0.1", + Status: 403, + }, + } + }, + "fail/account-not-authorized": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "differentAccountID", + Leaf: cert, + }, nil + }, + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: "accountID", + Status: acme.StatusValid, + Identifier: acme.Identifier{ + Type: acme.IP, + Value: "127.0.1.0", + }, + }, + }, nil + }, + } + ca := &mockCA{} + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 403, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Detail: "No authorization provided for name 127.0.0.1", + Status: 403, + }, + } + }, + "fail/unauthorized-certificate-key": func(t *testing.T) test { + _, unauthorizedKey, err := generateCertKeyPair() + assert.FatalError(t, err) + jwsPayload := &revokePayload{ + Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), + ReasonCode: v(2), + } + jwsBytes, err := jwsEncodeJSON(rp, unauthorizedKey, "", "nonce", revokeURL) + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(string(jwsBytes)) + assert.FatalError(t, err) + unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{} + acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "verification of jws using certificate public key failed") + acmeErr.Detail = "No authorization provided for name 127.0.0.1" + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 403, + err: acmeErr, + } + }, + "fail/certificate-revoked-check-fails": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{ + MockIsRevoked: func(sn string) (bool, error) { + return false, errors.New("force") + }, + } + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 500, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Detail: "The server experienced an internal error", + Status: 500, + }, + } + }, + "fail/certificate-already-revoked": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{ + MockIsRevoked: func(sn string) (bool, error) { + return true, nil + }, + } + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 400, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:alreadyRevoked", + Detail: "Certificate already revoked", + Status: 400, + }, + } + }, + "fail/invalid-reasoncode": func(t *testing.T) test { + invalidReasonPayload := &revokePayload{ + Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), + ReasonCode: v(7), + } + invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) + assert.FatalError(t, err) + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{ + MockIsRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 400, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:badRevocationReason", + Detail: "The revocation reason provided is not allowed by the server", + Status: 400, + }, + } + }, + "fail/prov.AuthorizeRevoke": func(t *testing.T) test { + assert.FatalError(t, err) + mockACMEProv := &acme.MockProvisioner{ + MauthorizeRevoke: func(ctx context.Context, token string) error { + return errors.New("force") + }, + } + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{ + MockIsRevoked: func(sn string) (bool, error) { + return false, nil + }, + } + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 500, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Detail: "The server experienced an internal error", + Status: 500, + }, + } + }, + "fail/ca.Revoke": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{ + MockRevoke: func(ctx context.Context, opts *authority.RevokeOptions) error { + return errors.New("force") + }, + } + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 500, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:serverInternal", + Detail: "The server experienced an internal error", + Status: 500, + }, + } + }, + "fail/ca.Revoke-already-revoked": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{ + MockIsRevoked: func(sn string) (bool, error) { + return false, nil + }, + MockRevoke: func(ctx context.Context, opts *authority.RevokeOptions) error { + return fmt.Errorf("certificate with serial number '%s' is already revoked", cert.SerialNumber.String()) + }, + } + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 400, + err: acme.NewError(acme.ErrorAlreadyRevokedType, "certificate with serial number '%s' is already revoked", cert.SerialNumber.String()), + } + }, + "ok/using-account-key": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "accountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{} + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 200, + } + }, + "ok/using-certificate-key": func(t *testing.T) test { + jwsBytes, err := jwsEncodeJSON(rp, key, "", "nonce", revokeURL) + assert.FatalError(t, err) + jws, err := jose.ParseJWS(string(jwsBytes)) + assert.FatalError(t, err) + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "someDifferentAccountID", + Leaf: cert, + }, nil + }, + } + ca := &mockCA{} + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 200, + } + }, + } + for name, setup := range tests { + tc := setup(t) + t.Run(name, func(t *testing.T) { + h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} + req := httptest.NewRequest("POST", revokeURL, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.RevokeCert(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := io.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.True(t, bytes.Equal(bytes.TrimSpace(body), []byte{})) + assert.Equals(t, int64(0), req.ContentLength) + assert.Equals(t, []string{fmt.Sprintf("<%s/acme/%s/directory>;rel=\"index\"", baseURL.String(), escProvName)}, res.Header["Link"]) + } + }) + } +} + +func TestHandler_isAccountAuthorized(t *testing.T) { + type test struct { + db acme.DB + ctx context.Context + existingCert *acme.Certificate + certToBeRevoked *x509.Certificate + account *acme.Account + err *acme.Error + } + accountID := "accountID" + var tests = map[string]func(t *testing.T) test{ + "fail/account-invalid": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusInvalid, + } + certToBeRevoked := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + } + return test{ + ctx: context.TODO(), + certToBeRevoked: certToBeRevoked, + account: account, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Status: http.StatusForbidden, + Detail: "No authorization provided for name 127.0.0.1", + Err: errors.New("account 'accountID' has status 'invalid'"), + }, + } + }, + "fail/different-account": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + existingCert := &acme.Certificate{ + AccountID: "differentAccountID", + } + return test{ + db: &acme.MockDB{ + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: accountID, + Status: acme.StatusValid, + Identifier: acme.Identifier{ + Type: acme.IP, + Value: "127.0.0.1", + }, + }, + }, nil + }, + }, + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Status: http.StatusForbidden, + Detail: "No authorization provided", + Err: errors.New("account 'accountID' is not authorized"), + }, + } + }, + "ok": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + existingCert := &acme.Certificate{ + AccountID: "accountID", + } + return test{ + db: &acme.MockDB{ + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: accountID, + Status: acme.StatusValid, + Identifier: acme.Identifier{ + Type: acme.IP, + Value: "127.0.0.1", + }, + }, + }, nil + }, + }, + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: nil, + } + }, + } + for name, setup := range tests { + tc := setup(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: tc.db} + acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) + + expectError := tc.err != nil + gotError := acmeErr != nil + if expectError != gotError { + t.Errorf("expected: %t, got: %t", expectError, gotError) + return + } + + if !gotError { + return // nothing to check; return early + } + + assert.Equals(t, acmeErr.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, acmeErr.Type, tc.err.Type) + assert.Equals(t, acmeErr.Status, tc.err.Status) + assert.Equals(t, acmeErr.Detail, tc.err.Detail) + assert.Equals(t, acmeErr.Identifier, tc.err.Identifier) + assert.Equals(t, acmeErr.Subproblems, tc.err.Subproblems) + + }) + } +} + +func Test_wrapUnauthorizedError(t *testing.T) { + type test struct { + cert *x509.Certificate + unauthorizedIdentifiers []acme.Identifier + msg string + err error + want *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "unauthorizedIdentifiers": func(t *testing.T) test { + acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "account 'accountID' is not authorized") + acmeErr.Status = http.StatusForbidden + acmeErr.Detail = "No authorization provided for name 127.0.0.1" + return test{ + err: nil, + cert: nil, + unauthorizedIdentifiers: []acme.Identifier{ + { + Type: acme.IP, + Value: "127.0.0.1", + }, + }, + msg: "account 'accountID' is not authorized", + want: acmeErr, + } + }, + "subject": func(t *testing.T) test { + acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "account 'accountID' is not authorized") + acmeErr.Status = http.StatusForbidden + acmeErr.Detail = "No authorization provided for name test.example.com" + cert := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "test.example.com", + }, + } + return test{ + err: nil, + cert: cert, + unauthorizedIdentifiers: []acme.Identifier{}, + msg: "account 'accountID' is not authorized", + want: acmeErr, + } + }, + "wrap-subject": func(t *testing.T) test { + acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "verification of jws using certificate public key failed: square/go-jose: error in cryptographic primitive") + acmeErr.Status = http.StatusForbidden + acmeErr.Detail = "No authorization provided for name test.example.com" + cert := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "test.example.com", + }, + } + return test{ + err: errors.New("square/go-jose: error in cryptographic primitive"), + cert: cert, + unauthorizedIdentifiers: []acme.Identifier{}, + msg: "verification of jws using certificate public key failed", + want: acmeErr, + } + }, + "default": func(t *testing.T) test { + acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "account 'accountID' is not authorized") + acmeErr.Status = http.StatusForbidden + acmeErr.Detail = "No authorization provided" + cert := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "", + }, + } + return test{ + err: nil, + cert: cert, + unauthorizedIdentifiers: []acme.Identifier{}, + msg: "account 'accountID' is not authorized", + want: acmeErr, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + acmeErr := wrapUnauthorizedError(tc.cert, tc.unauthorizedIdentifiers, tc.msg, tc.err) + assert.Equals(t, acmeErr.Err.Error(), tc.want.Err.Error()) + assert.Equals(t, acmeErr.Type, tc.want.Type) + assert.Equals(t, acmeErr.Status, tc.want.Status) + assert.Equals(t, acmeErr.Detail, tc.want.Detail) + assert.Equals(t, acmeErr.Identifier, tc.want.Identifier) + assert.Equals(t, acmeErr.Subproblems, tc.want.Subproblems) + }) + } +} diff --git a/acme/challenge.go b/acme/challenge.go index bfe1937d..0e1994e4 100644 --- a/acme/challenge.go +++ b/acme/challenge.go @@ -26,8 +26,11 @@ import ( type ChallengeType string const ( - HTTP01 ChallengeType = "http-01" - DNS01 ChallengeType = "dns-01" + // HTTP01 is the http-01 ACME challenge type + HTTP01 ChallengeType = "http-01" + // DNS01 is the dns-01 ACME challenge type + DNS01 ChallengeType = "dns-01" + // TLSALPN01 is the tls-alpn-01 ACME challenge type TLSALPN01 ChallengeType = "tls-alpn-01" ) diff --git a/acme/common.go b/acme/common.go index f18907fe..0c9e83dc 100644 --- a/acme/common.go +++ b/acme/common.go @@ -5,12 +5,15 @@ import ( "crypto/x509" "time" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" ) // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + IsRevoked(sn string) (bool, error) + Revoke(context.Context, *authority.RevokeOptions) error LoadProvisionerByName(string) (provisioner.Interface, error) } @@ -28,6 +31,7 @@ var clock Clock // only those methods required by the ACME api/authority. type Provisioner interface { AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) + AuthorizeRevoke(ctx context.Context, token string) error GetID() string GetName() string DefaultTLSCertDuration() time.Duration @@ -41,6 +45,7 @@ type MockProvisioner struct { MgetID func() string MgetName func() string MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) + MauthorizeRevoke func(ctx context.Context, token string) error MdefaultTLSCertDuration func() time.Duration MgetOptions func() *provisioner.Options } @@ -61,6 +66,14 @@ func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]prov return m.Mret1.([]provisioner.SignOption), m.Merr } +// AuthorizeRevoke mock +func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { + if m.MauthorizeRevoke != nil { + return m.MauthorizeRevoke(ctx, token) + } + return m.Merr +} + // DefaultTLSCertDuration mock func (m *MockProvisioner) DefaultTLSCertDuration() time.Duration { if m.MdefaultTLSCertDuration != nil { diff --git a/acme/db.go b/acme/db.go index d678fef4..1675c7e7 100644 --- a/acme/db.go +++ b/acme/db.go @@ -25,9 +25,11 @@ type DB interface { CreateAuthorization(ctx context.Context, az *Authorization) error GetAuthorization(ctx context.Context, id string) (*Authorization, error) UpdateAuthorization(ctx context.Context, az *Authorization) error + GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) CreateCertificate(ctx context.Context, cert *Certificate) error GetCertificate(ctx context.Context, id string) (*Certificate, error) + GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) CreateChallenge(ctx context.Context, ch *Challenge) error GetChallenge(ctx context.Context, id, authzID string) (*Challenge, error) @@ -50,12 +52,14 @@ type MockDB struct { MockCreateNonce func(ctx context.Context) (Nonce, error) MockDeleteNonce func(ctx context.Context, nonce Nonce) error - MockCreateAuthorization func(ctx context.Context, az *Authorization) error - MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) - MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + MockCreateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) + MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error) - MockCreateCertificate func(ctx context.Context, cert *Certificate) error - MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) + MockCreateCertificate func(ctx context.Context, cert *Certificate) error + MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) + MockGetCertificateBySerial func(ctx context.Context, serial string) (*Certificate, error) MockCreateChallenge func(ctx context.Context, ch *Challenge) error MockGetChallenge func(ctx context.Context, id, authzID string) (*Challenge, error) @@ -160,6 +164,16 @@ func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) err return m.MockError } +// GetAuthorizationsByAccountID mock +func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) { + if m.MockGetAuthorizationsByAccountID != nil { + return m.MockGetAuthorizationsByAccountID(ctx, accountID) + } else if m.MockError != nil { + return nil, m.MockError + } + return nil, m.MockError +} + // CreateCertificate mock func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { if m.MockCreateCertificate != nil { @@ -180,6 +194,16 @@ func (m *MockDB) GetCertificate(ctx context.Context, id string) (*Certificate, e return m.MockRet1.(*Certificate), m.MockError } +// GetCertificateBySerial mock +func (m *MockDB) GetCertificateBySerial(ctx context.Context, serial string) (*Certificate, error) { + if m.MockGetCertificateBySerial != nil { + return m.MockGetCertificateBySerial(ctx, serial) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Certificate), m.MockError +} + // CreateChallenge mock func (m *MockDB) CreateChallenge(ctx context.Context, ch *Challenge) error { if m.MockCreateChallenge != nil { diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 6decbe4f..01cb7fed 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -116,3 +116,37 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e nu.Error = az.Error return db.save(ctx, old.ID, nu, old, "authz", authzTable) } + +// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database. +func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + entries, err := db.db.List(authzTable) + if err != nil { + return nil, errors.Wrapf(err, "error listing authz") + } + authzs := []*acme.Authorization{} + for _, entry := range entries { + dbaz := new(dbAuthz) + if err = json.Unmarshal(entry.Value, dbaz); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key)) + } + // Filter out all dbAuthzs that don't belong to the accountID. This + // could be made more efficient with additional data structures mapping the + // Account ID to authorizations. Not trivial to do, though. + if dbaz.AccountID != accountID { + continue + } + authzs = append(authzs, &acme.Authorization{ + ID: dbaz.ID, + AccountID: dbaz.AccountID, + Identifier: dbaz.Identifier, + Status: dbaz.Status, + Challenges: nil, // challenges not required for current use case + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Token: dbaz.Token, + Error: dbaz.Error, + }) + } + + return authzs, nil +} diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go index 01c255dc..2e5dd3ea 100644 --- a/acme/db/nosql/authz_test.go +++ b/acme/db/nosql/authz_test.go @@ -3,9 +3,11 @@ package nosql import ( "context" "encoding/json" + "fmt" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" @@ -614,3 +616,154 @@ func TestDB_UpdateAuthorization(t *testing.T) { }) } } + +func TestDB_GetAuthorizationsByAccountID(t *testing.T) { + azID := "azID" + accountID := "accountID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + authzs []*acme.Authorization + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.List-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return nil, errors.New("force") + }, + }, + err: errors.New("error listing authz: force"), + } + }, + "fail/unmarshal": func(t *testing.T) test { + b := []byte(`{malformed}`) + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(azID), + Value: b, + }, + }, nil + }, + }, + authzs: nil, + err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: accountID, + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusValid, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(azID), + Value: b, + }, + }, nil + }, + }, + authzs: []*acme.Authorization{ + { + ID: dbaz.ID, + AccountID: dbaz.AccountID, + Token: dbaz.Token, + Identifier: dbaz.Identifier, + Status: dbaz.Status, + Challenges: nil, + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Error: dbaz.Error, + }, + }, + } + }, + "ok/skip-different-account": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "differentAccountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusValid, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(azID), + Value: b, + }, + }, nil + }, + }, + authzs: []*acme.Authorization{}, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + if !cmp.Equal(azs, tc.authzs) { + t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs)) + } + } + }) + } +} diff --git a/acme/db/nosql/certificate.go b/acme/db/nosql/certificate.go index d3e15833..ee37c570 100644 --- a/acme/db/nosql/certificate.go +++ b/acme/db/nosql/certificate.go @@ -21,6 +21,11 @@ type dbCert struct { Intermediates []byte `json:"intermediates"` } +type dbSerial struct { + Serial string `json:"serial"` + CertificateID string `json:"certificateID"` +} + // CreateCertificate creates and stores an ACME certificate type. func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) error { var err error @@ -49,7 +54,17 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err Intermediates: intermediates, CreatedAt: time.Now().UTC(), } - return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) + err = db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) + if err != nil { + return err + } + + serial := cert.Leaf.SerialNumber.String() + dbSerial := &dbSerial{ + Serial: serial, + CertificateID: cert.ID, + } + return db.save(ctx, serial, dbSerial, nil, "serial", certBySerialTable) } // GetCertificate retrieves and unmarshals an ACME certificate type from the @@ -80,6 +95,24 @@ func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, }, nil } +// GetCertificateBySerial retrieves and unmarshals an ACME certificate type from the +// datastore based on a certificate serial number. +func (db *DB) GetCertificateBySerial(ctx context.Context, serial string) (*acme.Certificate, error) { + b, err := db.db.Get(certBySerialTable, []byte(serial)) + if nosql.IsErrNotFound(err) { + return nil, acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial) + } else if err != nil { + return nil, errors.Wrapf(err, "error loading certificate ID for serial %s", serial) + } + + dbSerial := new(dbSerial) + if err := json.Unmarshal(b, dbSerial); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling certificate with serial %s", serial) + } + + return db.GetCertificate(ctx, dbSerial.CertificateID) +} + func parseBundle(b []byte) ([]*x509.Certificate, error) { var ( err error diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go index 37a61352..d64b3015 100644 --- a/acme/db/nosql/certificate_test.go +++ b/acme/db/nosql/certificate_test.go @@ -1,10 +1,12 @@ package nosql import ( + "bytes" "context" "crypto/x509" "encoding/json" "encoding/pem" + "fmt" "testing" "time" @@ -14,7 +16,6 @@ import ( "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" - "go.step.sm/crypto/pemutil" ) @@ -75,18 +76,36 @@ func TestDB_CreateCertificate(t *testing.T) { return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { - *idPtr = string(key) - assert.Equals(t, bucket, certTable) - assert.Equals(t, key, []byte(cert.ID)) - assert.Equals(t, old, nil) + if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) { + t.Fail() + } + if bytes.Equal(bucket, certTable) { + *idPtr = string(key) + assert.Equals(t, bucket, certTable) + assert.Equals(t, key, []byte(cert.ID)) + assert.Equals(t, old, nil) + + dbc := new(dbCert) + assert.FatalError(t, json.Unmarshal(nu, dbc)) + assert.Equals(t, dbc.ID, string(key)) + assert.Equals(t, dbc.ID, cert.ID) + assert.Equals(t, dbc.AccountID, cert.AccountID) + assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) + assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) + } + if bytes.Equal(bucket, certBySerialTable) { + assert.Equals(t, bucket, certBySerialTable) + assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String())) + assert.Equals(t, old, nil) + + dbs := new(dbSerial) + assert.FatalError(t, json.Unmarshal(nu, dbs)) + assert.Equals(t, dbs.Serial, string(key)) + assert.Equals(t, dbs.CertificateID, cert.ID) + + *idPtr = cert.ID + } - dbc := new(dbCert) - assert.FatalError(t, json.Unmarshal(nu, dbc)) - assert.Equals(t, dbc.ID, string(key)) - assert.Equals(t, dbc.ID, cert.ID) - assert.Equals(t, dbc.AccountID, cert.AccountID) - assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) - assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) return nil, true, nil }, }, @@ -317,3 +336,135 @@ func Test_parseBundle(t *testing.T) { }) } } + +func TestDB_GetCertificateBySerial(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + certID := "certID" + serial := "" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certBySerialTable) { + return nil, nosqldb.ErrNotFound + } + return nil, errors.New("wrong table") + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial), + } + }, + "fail/db-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certBySerialTable) { + return nil, errors.New("force") + } + return nil, errors.New("wrong table") + }, + }, + err: fmt.Errorf("error loading certificate ID for serial %s", serial), + } + }, + "fail/unmarshal-dbSerial": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certBySerialTable) { + return []byte(`{"serial":malformed!}`), nil + } + return nil, errors.New("wrong table") + }, + }, + err: fmt.Errorf("error unmarshaling certificate with serial %s", serial), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + + if bytes.Equal(bucket, certBySerialTable) { + certSerial := dbSerial{ + Serial: serial, + CertificateID: certID, + } + + b, err := json.Marshal(certSerial) + assert.FatalError(t, err) + + return b, nil + } + + if bytes.Equal(bucket, certTable) { + cert := dbCert{ + ID: certID, + AccountID: "accountID", + OrderID: "orderID", + Leaf: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), + Intermediates: append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: inter.Raw, + }), pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: root.Raw, + })...), + CreatedAt: clock.Now(), + } + b, err := json.Marshal(cert) + assert.FatalError(t, err) + + return b, nil + } + return nil, errors.New("wrong table") + }, + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + cert, err := d.GetCertificateBySerial(context.Background(), serial) + if err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, cert.ID, certID) + assert.Equals(t, cert.AccountID, "accountID") + assert.Equals(t, cert.OrderID, "orderID") + assert.Equals(t, cert.Leaf, leaf) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) + } + }) + } +} diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index b1547373..34932361 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -19,6 +19,7 @@ var ( orderTable = []byte("acme_orders") ordersByAccountIDTable = []byte("acme_account_orders_index") certTable = []byte("acme_certs") + certBySerialTable = []byte("acme_serial_certs_index") ) // DB is a struct that implements the AcmeDB interface. @@ -29,7 +30,7 @@ type DB struct { // New configures and returns a new ACME DB backend implemented using a nosql DB. func New(db nosqlDB.DB) (*DB, error) { tables := [][]byte{accountTable, accountByKeyIDTable, authzTable, - challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable} + challengeTable, nonceTable, orderTable, ordersByAccountIDTable, certTable, certBySerialTable} for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", diff --git a/acme/errors.go b/acme/errors.go index 6ecf0912..a5c820ba 100644 --- a/acme/errors.go +++ b/acme/errors.go @@ -147,7 +147,7 @@ var ( }, ErrorAlreadyRevokedType: { typ: officialACMEPrefix + ErrorAlreadyRevokedType.String(), - details: "Certificate already Revoked", + details: "Certificate already revoked", status: 400, }, ErrorBadCSRType: { diff --git a/acme/order.go b/acme/order.go index 7e65b5d7..366d1a5e 100644 --- a/acme/order.go +++ b/acme/order.go @@ -17,7 +17,9 @@ import ( type IdentifierType string const ( - IP IdentifierType = "ip" + // IP is the ACME ip identifier type + IP IdentifierType = "ip" + // DNS is the ACME dns identifier type DNS IdentifierType = "dns" ) diff --git a/acme/order_test.go b/acme/order_test.go index dee828f7..73f72065 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -13,6 +13,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "go.step.sm/crypto/x509util" ) @@ -287,6 +288,14 @@ func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface return m.ret1.(provisioner.Interface), m.err } +func (m *mockSignAuth) IsRevoked(sn string) (bool, error) { + return false, nil +} + +func (m *mockSignAuth) Revoke(context.Context, *authority.RevokeOptions) error { + return nil +} + func TestOrder_Finalize(t *testing.T) { type test struct { o *Order diff --git a/api/api.go b/api/api.go index e057caaa..16e24bb2 100644 --- a/api/api.go +++ b/api/api.go @@ -348,7 +348,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { roots, err := h.Authority.GetRoots() if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error getting roots")) return } @@ -366,7 +366,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { federated, err := h.Authority.GetFederation() if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error getting federated roots")) return } diff --git a/api/revoke.go b/api/revoke.go index 44d52cb9..25520e3e 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -96,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { } if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error revoking certificate")) return } diff --git a/api/sign.go b/api/sign.go index a1e5b998..93c5f599 100644 --- a/api/sign.go +++ b/api/sign.go @@ -74,7 +74,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error signing certificate")) return } certChainPEM := certChainToPEM(certChain) diff --git a/api/ssh.go b/api/ssh.go index 43ee6b98..c9be1527 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -293,7 +293,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } @@ -301,7 +301,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return } addUserCertificate = &SSHCertificate{addUserCert} @@ -326,7 +326,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate")) return } identityCertificate = certChainToPEM(certChain) diff --git a/api/sshRekey.go b/api/sshRekey.go index 8d2ba5ee..8670f0bd 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -68,7 +68,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return } @@ -78,7 +78,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index 5dfd5983..57b6f432 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -60,7 +60,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { newCert, err := h.Authority.RenewSSH(ctx, oldCert) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return } @@ -70,7 +70,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) if err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index cfc25f04..60f44f2a 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -75,7 +75,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { opts.OTT = body.OTT if err := h.Authority.Revoke(ctx, opts); err != nil { - WriteError(w, errs.ForbiddenErr(err)) + WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return } diff --git a/authority/authority.go b/authority/authority.go index aa8698d7..b6fcdf23 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -588,6 +588,19 @@ func (a *Authority) CloseForReload() { } } +// IsRevoked returns whether or not a certificate has been +// revoked before. +func (a *Authority) IsRevoked(sn string) (bool, error) { + // Check the passive revocation table. + if lca, ok := a.adminDB.(interface { + IsRevoked(string) (bool, error) + }); ok { + return lca.IsRevoked(sn) + } + + return a.db.IsRevoked(sn) +} + // requiresDecrypter returns whether the Authority // requires a KMS that provides a crypto.Decrypter // Currently this is only required when SCEP is diff --git a/authority/authorize.go b/authority/authorize.go index a4e7e591..5108f567 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -274,19 +274,9 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error { // // TODO(mariano): should we authorize by default? func (a *Authority) authorizeRenew(cert *x509.Certificate) error { - var err error - var isRevoked bool - var opts = []interface{}{errs.WithKeyVal("serialNumber", cert.SerialNumber.String())} - - // Check the passive revocation table. serial := cert.SerialNumber.String() - if lca, ok := a.adminDB.(interface { - IsRevoked(string) (bool, error) - }); ok { - isRevoked, err = lca.IsRevoked(serial) - } else { - isRevoked, err = a.db.IsRevoked(serial) - } + var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} + isRevoked, err := a.IsRevoked(serial) if err != nil { return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) } diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index d81b0231..c8950568 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -99,6 +99,15 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e }, nil } +// AuthorizeRevoke is called just before the certificate is to be revoked by +// the CA. It can be used to authorize revocation of a certificate. It +// currently is a no-op. +// TODO(hs): add configuration option that toggles revocation? Or change function signature to make it more useful? +// Or move certain logic out of the Revoke API to here? Would likely involve some more stuff in the ctx. +func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { + return nil +} + // AuthorizeRenew returns an error if the renewal is disabled. // NOTE: This method does not actually validate the certificate or check it's // revocation status. Just confirms that the provisioner that created the diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 3d895277..330d1b57 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -184,7 +184,6 @@ func TestUnimplementedMethods(t *testing.T) { {"x5c/sshRenew", &X5C{}, SSHRenewMethod}, {"x5c/sshRekey", &X5C{}, SSHRekeyMethod}, {"x5c/sshRevoke", &X5C{}, SSHRekeyMethod}, - {"acme/revoke", &ACME{}, RevokeMethod}, {"acme/sshSign", &ACME{}, SSHSignMethod}, {"acme/sshRekey", &ACME{}, SSHRekeyMethod}, {"acme/sshRenew", &ACME{}, SSHRenewMethod}, diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index c4779ea3..34b2e99b 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -9,12 +9,14 @@ import ( "encoding/asn1" "encoding/json" "net" + "net/http" "net/url" "reflect" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/errs" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/x509util" ) @@ -83,19 +85,19 @@ type emailOnlyIdentity string func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error { switch { case len(req.DNSNames) > 0: - return errors.New("certificate request cannot contain DNS names") + return errs.Forbidden("certificate request cannot contain DNS names") case len(req.IPAddresses) > 0: - return errors.New("certificate request cannot contain IP addresses") + return errs.Forbidden("certificate request cannot contain IP addresses") case len(req.URIs) > 0: - return errors.New("certificate request cannot contain URIs") + return errs.Forbidden("certificate request cannot contain URIs") case len(req.EmailAddresses) == 0: - return errors.New("certificate request does not contain any email address") + return errs.Forbidden("certificate request does not contain any email address") case len(req.EmailAddresses) > 1: - return errors.New("certificate request contains too many email addresses") + return errs.Forbidden("certificate request contains too many email addresses") case req.EmailAddresses[0] == "": - return errors.New("certificate request cannot contain an empty email address") + return errs.Forbidden("certificate request cannot contain an empty email address") case req.EmailAddresses[0] != string(e): - return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e) + return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e) default: return nil } @@ -108,12 +110,13 @@ type defaultPublicKeyValidator struct{} func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error { switch k := req.PublicKey.(type) { case *rsa.PublicKey: - if k.Size() < 256 { - return errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)") + if k.Size() < keyutil.MinRSAKeyBytes { + return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)", + 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) } case *ecdsa.PublicKey, ed25519.PublicKey: default: - return errors.Errorf("unrecognized public key of type '%T' in CSR", k) + return errs.BadRequest("certificate request key of type '%T' is not supported", k) } return nil } @@ -139,11 +142,12 @@ func (v publicKeyMinimumLengthValidator) Valid(req *x509.CertificateRequest) err case *rsa.PublicKey: minimumLengthInBytes := v.length / 8 if k.Size() < minimumLengthInBytes { - return errors.Errorf("rsa key in CSR must be at least %d bits (%d bytes)", v.length, minimumLengthInBytes) + return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)", + v.length, minimumLengthInBytes) } case *ecdsa.PublicKey, ed25519.PublicKey: default: - return errors.Errorf("unrecognized public key of type '%T' in CSR", k) + return errs.BadRequest("certificate request key of type '%T' is not supported", k) } return nil } @@ -158,7 +162,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error { return nil } if req.Subject.CommonName != string(v) { - return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v) + return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) } return nil } @@ -176,7 +180,7 @@ func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error { return nil } } - return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v) + return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) } // dnsNamesValidator validates the DNS names SAN of a certificate request. @@ -197,7 +201,7 @@ func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error { got[s] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) + return errs.Forbidden("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) } return nil } @@ -220,7 +224,7 @@ func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error { got[ip.String()] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v) + return errs.Forbidden("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v) } return nil } @@ -243,7 +247,7 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error { got[s] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("certificate request does not contain the valid Email Addresses - got %v, want %v", req.EmailAddresses, v) + return errs.Forbidden("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v) } return nil } @@ -266,7 +270,7 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error { got[u.String()] = true } if !reflect.DeepEqual(want, got) { - return errors.Errorf("URIs claim failed - got %v, want %v", req.URIs, v) + return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v) } return nil } @@ -334,15 +338,15 @@ func (v profileLimitDuration) Modify(cert *x509.Certificate, so SignOptions) err backdate = -1 * so.Backdate } if notBefore.Before(v.notBefore) { - return errors.Errorf("requested certificate notBefore (%s) is before "+ - "the active validity window of the provisioning credential (%s)", + return errs.Forbidden( + "requested certificate notBefore (%s) is before the active validity window of the provisioning credential (%s)", notBefore, v.notBefore) } notAfter := so.NotAfter.RelativeTime(notBefore) if notAfter.After(v.notAfter) { - return errors.Errorf("requested certificate notAfter (%s) is after "+ - "the expiration of the provisioning credential (%s)", + return errs.Forbidden( + "requested certificate notAfter (%s) is after the expiration of the provisioning credential (%s)", notAfter, v.notAfter) } if notAfter.IsZero() { @@ -388,14 +392,14 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) } if d < v.min { - return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) + return errs.Forbidden("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) } // NOTE: this check is not "technically correct". We're allowing the max // duration of a cert to be "max + backdate" and not all certificates will // be backdated (e.g. if a user passes the NotBefore value then we do not // apply a backdate). This is good enough. if d > v.max+o.Backdate { - return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) + return errs.Forbidden("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) } return nil } @@ -422,16 +426,15 @@ func newForceCNOption(forceCN bool) *forceCNOption { func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { if !o.ForceCN { - // Forcing CN is disabled, do nothing to certificate return nil } + // Force the common name to be the first DNS if not provided. if cert.Subject.CommonName == "" { - if len(cert.DNSNames) > 0 { - cert.Subject.CommonName = cert.DNSNames[0] - } else { - return errors.New("Cannot force CN, DNSNames is empty") + if len(cert.DNSNames) == 0 { + return errs.BadRequest("cannot force common name, DNS names is empty") } + cert.Subject.CommonName = cert.DNSNames[0] } return nil @@ -456,7 +459,7 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValue func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) if err != nil { - return err + return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") } // Prepend the provisioner extension. In the auth.Sign code we will // force the resulting certificate to only have one extension, the @@ -477,7 +480,7 @@ func createProvisionerExtension(typ int, name, credentialID string, keyValuePair KeyValuePairs: keyValuePairs, }) if err != nil { - return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension") + return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension") } return pkix.Extension{ Id: stepOIDProvisioner, diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index cf8f7a54..32b8e3c6 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -77,12 +77,12 @@ func Test_defaultPublicKeyValidator_Valid(t *testing.T) { { "fail/unrecognized-key-type", &x509.CertificateRequest{PublicKey: "foo"}, - errors.New("unrecognized public key of type 'string' in CSR"), + errors.New("certificate request key of type 'string' is not supported"), }, { "fail/rsa/too-short", shortRSA, - errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)"), + errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"), }, { "ok/rsa", @@ -303,14 +303,14 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { return test{ csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}}, expectedSANs: []string{"dcow@fx.com"}, - err: errors.New("certificate request does not contain the valid Email Addresses"), + err: errors.New("certificate request does not contain the valid email addresses"), } }, "fail/ipAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}}, expectedSANs: []string{"127.0.0.1"}, - err: errors.New("IP Addresses claim failed"), + err: errors.New("certificate request does not contain the valid IP addresses"), } }, "fail/urisValidator": func() test { @@ -321,7 +321,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { return test{ csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, - err: errors.New("URIs claim failed"), + err: errors.New("certificate request does not contain the valid URIs"), } }, "ok": func() test { @@ -512,7 +512,7 @@ func Test_forceCN_Option(t *testing.T) { Subject: pkix.Name{}, DNSNames: []string{}, }, - err: errors.New("Cannot force CN, DNSNames is empty"), + err: errors.New("cannot force common name, DNS names is empty"), } }, } diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 6cd38c59..a2ca78b1 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -56,7 +56,12 @@ type SignSSHOptions struct { // Validate validates the given SignSSHOptions. func (o SignSSHOptions) Validate() error { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { - return errs.BadRequest("unknown certificate type '%s'", o.CertType) + return errs.BadRequest("certType '%s' is not valid", o.CertType) + } + for _, p := range o.Principals { + if p == "" { + return errs.BadRequest("principals cannot contain empty values") + } } return nil } @@ -75,7 +80,7 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { case SSHHostCert: cert.CertType = ssh.HostCert default: - return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType) + return errs.BadRequest("ssh certificate has an unknown type '%s'", o.CertType) } cert.KeyId = o.KeyID @@ -95,7 +100,7 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error { cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix()) } if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { - return errors.New("ssh certificate valid after cannot be greater than valid before") + return errs.BadRequest("ssh certificate validAfter cannot be greater than validBefore") } return nil } @@ -104,16 +109,16 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error { // ignores zero values. func (o SignSSHOptions) match(got SignSSHOptions) error { if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { - return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) + return errs.Forbidden("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) } if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) { - return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) + return errs.Forbidden("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) } if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { - return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) + return errs.Forbidden("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) } if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { - return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) + return errs.Forbidden("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) } return nil } @@ -206,7 +211,7 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOpt cert.Extensions["permit-user-rc"] = "" return nil default: - return errors.New("ssh certificate type has not been set or is invalid") + return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType) } } @@ -272,7 +277,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error certValidAfter := time.Unix(int64(cert.ValidAfter), 0) if certValidAfter.After(m.NotAfter) { - return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", + return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", m.NotAfter, certValidAfter) } @@ -285,7 +290,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error } else { certValidBefore := time.Unix(int64(cert.ValidBefore), 0) if m.NotAfter.Before(certValidBefore) { - return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", + return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", m.NotAfter, certValidBefore) } } @@ -319,11 +324,11 @@ type sshCertOptionsRequireValidator struct { func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error { switch { case v.CertType && got.CertType == "": - return errors.New("ssh certificate certType cannot be empty") + return errs.BadRequest("ssh certificate certType cannot be empty") case v.KeyID && got.KeyID == "": - return errors.New("ssh certificate keyID cannot be empty") + return errs.BadRequest("ssh certificate keyID cannot be empty") case v.Principals && len(got.Principals) == 0: - return errors.New("ssh certificate principals cannot be empty") + return errs.BadRequest("ssh certificate principals cannot be empty") default: return nil } @@ -354,7 +359,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti case 0: return errs.BadRequest("ssh certificate type has not been set") default: - return errs.BadRequest("unknown ssh certificate type %d", cert.CertType) + return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType) } // To not take into account the backdate, time.Now() will be used to @@ -363,9 +368,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti switch { case dur < min: - return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) + return errs.Forbidden("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) case dur > max+opts.Backdate: - return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) + return errs.Forbidden("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) default: return nil } @@ -381,25 +386,25 @@ type sshCertDefaultValidator struct{} func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error { switch { case len(cert.Nonce) == 0: - return errors.New("ssh certificate nonce cannot be empty") + return errs.Forbidden("ssh certificate nonce cannot be empty") case cert.Key == nil: - return errors.New("ssh certificate key cannot be nil") + return errs.Forbidden("ssh certificate key cannot be nil") case cert.Serial == 0: - return errors.New("ssh certificate serial cannot be 0") + return errs.Forbidden("ssh certificate serial cannot be 0") case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: - return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType) + return errs.Forbidden("ssh certificate has an unknown type '%d'", cert.CertType) case cert.KeyId == "": - return errors.New("ssh certificate key id cannot be empty") + return errs.Forbidden("ssh certificate key id cannot be empty") case cert.ValidAfter == 0: - return errors.New("ssh certificate validAfter cannot be 0") + return errs.Forbidden("ssh certificate validAfter cannot be 0") case cert.ValidBefore < uint64(now().Unix()): - return errors.New("ssh certificate validBefore cannot be in the past") + return errs.Forbidden("ssh certificate validBefore cannot be in the past") case cert.ValidBefore < cert.ValidAfter: - return errors.New("ssh certificate validBefore cannot be before validAfter") + return errs.Forbidden("ssh certificate validBefore cannot be before validAfter") case cert.SignatureKey == nil: - return errors.New("ssh certificate signature key cannot be nil") + return errs.Forbidden("ssh certificate signature key cannot be nil") case cert.Signature == nil: - return errors.New("ssh certificate signature cannot be nil") + return errs.Forbidden("ssh certificate signature cannot be nil") default: return nil } @@ -409,27 +414,31 @@ func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) type sshDefaultPublicKeyValidator struct{} // Valid checks that certificate request common name matches the one configured. +// +// TODO: this is the only validator that checks the key type. We should execute +// this before the signing. We should add a new validations interface or extend +// SSHCertOptionsValidator with the key. func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error { if cert.Key == nil { - return errors.New("ssh certificate key cannot be nil") + return errs.BadRequest("ssh certificate key cannot be nil") } switch cert.Key.Type() { case ssh.KeyAlgoRSA: _, in, ok := sshParseString(cert.Key.Marshal()) if !ok { - return errors.New("ssh certificate key is invalid") + return errs.BadRequest("ssh certificate key is invalid") } key, err := sshParseRSAPublicKey(in) if err != nil { - return err + return errs.BadRequestErr(err, "error parsing public key") } if key.Size() < keyutil.MinRSAKeyBytes { - return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)", + return errs.Forbidden("ssh certificate key must be at least %d bits (%d bytes)", 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) } return nil case ssh.KeyAlgoDSA: - return errors.New("ssh certificate key algorithm (DSA) is not supported") + return errs.BadRequest("ssh certificate key algorithm (DSA) is not supported") default: return nil } diff --git a/authority/provisioner/sign_ssh_options_test.go b/authority/provisioner/sign_ssh_options_test.go index 3a1ff324..b59d6945 100644 --- a/authority/provisioner/sign_ssh_options_test.go +++ b/authority/provisioner/sign_ssh_options_test.go @@ -49,14 +49,14 @@ func TestSSHOptions_Modify(t *testing.T) { return test{ so: SignSSHOptions{CertType: "foo"}, cert: new(ssh.Certificate), - err: errors.Errorf("ssh certificate has an unknown type - foo"), + err: errors.Errorf("ssh certificate has an unknown type 'foo'"), } }, "fail/validAfter-greater-validBefore": func() test { return test{ so: SignSSHOptions{CertType: "user"}, cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)}, - err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"), + err: errors.Errorf("ssh certificate validAfter cannot be greater than validBefore"), } }, "ok/user-cert": func() test { @@ -136,14 +136,14 @@ func TestSSHOptions_Match(t *testing.T) { return test{ so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))}, cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))}, - err: errors.Errorf("ssh certificate valid after does not match"), + err: errors.Errorf("ssh certificate validAfter does not match"), } }, "fail/validBefore": func() test { return test{ so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))}, cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))}, - err: errors.Errorf("ssh certificate valid before does not match"), + err: errors.Errorf("ssh certificate validBefore does not match"), } }, "ok/original-empty": func() test { @@ -394,7 +394,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) { return test{ modifier: sshDefaultExtensionModifier{}, cert: cert, - err: errors.New("ssh certificate type has not been set or is invalid"), + err: errors.New("ssh certificate has an unknown type '3'"), } }, "ok/host": func() test { @@ -518,7 +518,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) { "fail/unexpected-cert-type", // UserCert = 1, HostCert = 2 &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1}, - errors.New("ssh certificate has an unknown type: 3"), + errors.New("ssh certificate has an unknown type '3'"), }, { "fail/empty-cert-key-id", @@ -725,7 +725,7 @@ func Test_sshCertValidityValidator(t *testing.T) { ValidBefore: uint64(now().Add(10 * time.Minute).Unix()), }, SignSSHOptions{}, - errors.New("unknown ssh certificate type 3"), + errors.New("ssh certificate has an unknown type '3'"), }, { "fail/duration 0 { return nil } - return errors.New("certificate does not have only one principal") + return errs.Forbidden("certificate does not have only one principal") default: - return errors.New("certificate does not have only one principal") + return errs.Forbidden("certificate does not have only one principal") } } @@ -433,7 +432,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") } if err := IsValidForAddUser(subject); err != nil { - return nil, errs.Wrap(http.StatusForbidden, err, "signSSHAddUser") + return nil, err } nonce, err := randutil.ASCII(32) diff --git a/authority/tls.go b/authority/tls.go index 716d8956..dfa88ac3 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -94,7 +94,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign // Validate the given certificate request. case provisioner.CertificateRequestValidator: if err := k.Valid(csr); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, "error validating certificate"), + opts..., + ) } // Validates the unsigned certificate template. @@ -131,26 +134,38 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign // Set default subject if err := withDefaultASN1DN(a.config.AuthorityConfig.Template).Modify(leaf, signOpts); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, "error creating certificate"), + opts..., + ) } for _, m := range certModifiers { if err := m.Modify(leaf, signOpts); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, "error creating certificate"), + opts..., + ) } } // Certificate validation. for _, v := range certValidators { if err := v.Valid(leaf, signOpts); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, "error validating certificate"), + opts..., + ) } } // Certificate modifiers after validation for _, m := range certEnforcers { if err := m.Enforce(leaf); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) + return nil, errs.ApplyOptions( + errs.ForbiddenErr(err, "error creating certificate"), + opts..., + ) } } @@ -328,6 +343,7 @@ type RevokeOptions struct { ReasonCode int PassiveOnly bool MTLS bool + ACME bool Crt *x509.Certificate OTT string } @@ -345,9 +361,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error errs.WithKeyVal("reason", revokeOpts.Reason), errs.WithKeyVal("passiveOnly", revokeOpts.PassiveOnly), errs.WithKeyVal("MTLS", revokeOpts.MTLS), + errs.WithKeyVal("ACME", revokeOpts.ACME), errs.WithKeyVal("context", provisioner.MethodFromContext(ctx).String()), } - if revokeOpts.MTLS { + if revokeOpts.MTLS || revokeOpts.ACME { opts = append(opts, errs.WithKeyVal("certificate", base64.StdEncoding.EncodeToString(revokeOpts.Crt.Raw))) } else { opts = append(opts, errs.WithKeyVal("token", revokeOpts.OTT)) @@ -358,6 +375,7 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error ReasonCode: revokeOpts.ReasonCode, Reason: revokeOpts.Reason, MTLS: revokeOpts.MTLS, + ACME: revokeOpts.ACME, RevokedAt: time.Now().UTC(), } @@ -365,8 +383,8 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error p provisioner.Interface err error ) - // If not mTLS then get the TokenID of the token. - if !revokeOpts.MTLS { + // If not mTLS nor ACME, then get the TokenID of the token. + if !(revokeOpts.MTLS || revokeOpts.ACME) { token, err := jose.ParseSigned(revokeOpts.OTT) if err != nil { return errs.Wrap(http.StatusUnauthorized, err, diff --git a/authority/tls_test.go b/authority/tls_test.go index 03beb5c1..3acf05f5 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -281,8 +281,8 @@ func TestAuthority_Sign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: errors.New("authority.Sign: default ASN1DN template cannot be nil"), - code: http.StatusUnauthorized, + err: errors.New("default ASN1DN template cannot be nil"), + code: http.StatusForbidden, } }, "fail create cert": func(t *testing.T) *signTest { @@ -309,8 +309,8 @@ func TestAuthority_Sign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: _signOpts, - err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"), - code: http.StatusBadRequest, + err: errors.New("requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"), + code: http.StatusForbidden, } }, "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { @@ -322,8 +322,8 @@ func TestAuthority_Sign(t *testing.T) { csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), - code: http.StatusUnauthorized, + err: errors.New("certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), + code: http.StatusForbidden, } }, "fail rsa key too short": func(t *testing.T) *signTest { @@ -348,8 +348,8 @@ ZYtQ9Ot36qc= csr: csr, extraOpts: extraOpts, signOpts: signOpts, - err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"), - code: http.StatusUnauthorized, + err: errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"), + code: http.StatusForbidden, } }, "fail store cert in db": func(t *testing.T) *signTest { @@ -1267,6 +1267,23 @@ func TestAuthority_Revoke(t *testing.T) { }, } }, + "ok/ACME": func() test { + _a := testAuthority(t, WithDatabase(&db.MockAuthDB{})) + + crt, err := pemutil.ReadCertificate("./testdata/certs/foo.crt") + assert.FatalError(t, err) + + return test{ + auth: _a, + opts: &RevokeOptions{ + Crt: crt, + Serial: "102012593071130646873265215610956555026", + ReasonCode: reasonCode, + Reason: reason, + ACME: true, + }, + } + }, } for name, f := range tests { tc := f() diff --git a/ca/ca.go b/ca/ca.go index c76e8c0a..da0fb874 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -442,7 +442,7 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) { return tlsConfig, nil } -// shouldMountSCEPEndpoints returns if the CA should be +// shouldServeSCEPEndpoints returns if the CA should be // configured with endpoints for SCEP. This is assumed to be // true if a SCEPService exists, which is true in case a // SCEP provisioner was configured. diff --git a/ca/ca_test.go b/ca/ca_test.go index 1271659a..e4c35a90 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -200,8 +200,8 @@ ZEp7knvU2psWRw== return &signTest{ ca: ca, body: string(body), - status: http.StatusUnauthorized, - errMsg: errs.UnauthorizedDefaultMsg, + status: http.StatusForbidden, + errMsg: errs.ForbiddenPrefix, } }, "ok": func(t *testing.T) *signTest { diff --git a/db/db.go b/db/db.go index 2643e577..6d48723f 100644 --- a/db/db.go +++ b/db/db.go @@ -104,6 +104,7 @@ type RevokedCertificateInfo struct { RevokedAt time.Time TokenID string MTLS bool + ACME bool } // IsRevoked returns whether or not a certificate with the given identifier diff --git a/docs/images/star.gif b/docs/images/star.gif new file mode 100644 index 00000000..eef9d707 Binary files /dev/null and b/docs/images/star.gif differ diff --git a/errs/error.go b/errs/error.go index 60312313..2c1fe6a9 100644 --- a/errs/error.go +++ b/errs/error.go @@ -169,7 +169,8 @@ func StatusCodeError(code int, e error, opts ...Option) error { case http.StatusUnauthorized: return UnauthorizedErr(e, opts...) case http.StatusForbidden: - return ForbiddenErr(e, opts...) + opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) + return NewErr(http.StatusForbidden, e, opts...) case http.StatusInternalServerError: return InternalServerErr(e, opts...) case http.StatusNotImplemented: @@ -199,12 +200,18 @@ var ( // BadRequestPrefix is the prefix added to the bad request messages that are // directly sent to the cli. BadRequestPrefix = "The request could not be completed: " + + // ForbiddenPrefix is the prefix added to the forbidden messates that are + // sent to the cli. + ForbiddenPrefix = "The request was forbidden by the certificate authority: " ) func formatMessage(status int, msg string) string { switch status { case http.StatusBadRequest: return BadRequestPrefix + msg + "." + case http.StatusForbidden: + return ForbiddenPrefix + msg + "." default: return msg } @@ -356,14 +363,12 @@ func UnauthorizedErr(err error, opts ...Option) error { // Forbidden creates a 403 error with the given format and arguments. func Forbidden(format string, args ...interface{}) error { - args = append(args, withDefaultMessage(ForbiddenDefaultMsg)) - return Errorf(http.StatusForbidden, format, args...) + return New(http.StatusForbidden, format, args...) } // ForbiddenErr returns an 403 error with the given error. -func ForbiddenErr(err error, opts ...Option) error { - opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) - return NewErr(http.StatusForbidden, err, opts...) +func ForbiddenErr(err error, format string, args ...interface{}) error { + return NewError(http.StatusForbidden, err, format, args...) } // NotFound creates a 404 error with the given format and arguments. diff --git a/go.mod b/go.mod index 394eb1a4..830cab82 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 github.com/golang/mock v1.6.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.0.5 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect