From 06bb97c91e3d94c615b2b07010d765a91c990a95 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 2 Dec 2021 16:25:35 +0100 Subject: [PATCH] Add logic for Account authorizations and improve tests --- acme/api/revoke.go | 139 ++++++- acme/api/revoke_test.go | 777 ++++++++++++++++++++++++------------ acme/db.go | 18 +- acme/db/nosql/authz.go | 34 ++ acme/db/nosql/authz_test.go | 153 +++++++ acme/order.go | 3 + go.mod | 1 + 7 files changed, 845 insertions(+), 280 deletions(-) diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 1c664dde..209bc204 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -1,6 +1,8 @@ package api import ( + "bytes" + "context" "crypto/x509" "encoding/base64" "encoding/json" @@ -66,35 +68,36 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { } serial := certToBeRevoked.SerialNumber.String() - existingCert, err := h.db.GetCertificateBySerial(ctx, serial) + dbCert, err := h.db.GetCertificateBySerial(ctx, serial) if err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) return } + if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) { + // this should never happen + api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal")) + return + } + if shouldCheckAccountFrom(jws) { account, err := accountFromContext(ctx) if err != nil { api.WriteError(w, err) return } - if !account.IsValid() { - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)) + acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) + if acmeErr != nil { + api.WriteError(w, acmeErr) return } - if existingCert.AccountID != account.ID { // TODO(hs): combine this check with the one below; ony one of the two has to be true - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' does not own certificate '%s'", account.ID, existingCert.ID), nil)) - return - } - // TODO(hs): check and implement "an account that holds authorizations for all of the identifiers in the certificate." - // In that case the certificate may not have been created by this account, but another account that was authorized before. } else { // if account doesn't need to be checked, the JWS should be verified to be signed by the // private key that belongs to the public key in the certificate to be revoked. _, err := jws.Verify(certToBeRevoked.PublicKey) if err != nil { // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? - api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, "verification of jws using certificate public key failed", err)) + api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err)) return } } @@ -137,6 +140,107 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { w.Write(nil) } +// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized +// to revoke the certificate. An Account must always be valid in order to revoke a certificate. +// In case the certificate retrieved from the database belongs to the Account, the Account is +// authorized. If the certificate retrieved from the database doesn't belong to the Account, +// the identifiers in the certificate are extracted and compared against the (valid) Authorizations +// that are stored for the ACME Account. If these sets match, the Account is considered authorized +// to revoke the certificate. If this check fails, the client will receive an unauthorized error. +func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { + if !account.IsValid() { + return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) + } + certificateBelongsToAccount := dbCert.AccountID == account.ID + if certificateBelongsToAccount { + return nil // return early; skip relatively expensive database check + } + requiredIdentifiers := extractIdentifiers(certToBeRevoked) + if len(requiredIdentifiers) == 0 { + return wrapUnauthorizedError(certToBeRevoked, nil, "cannot authorize revocation without providing identifiers to authorize", nil) + } + authzs, err := h.db.GetAuthorizationsByAccountID(ctx, account.ID) + if err != nil { + return acme.WrapErrorISE(err, "error retrieving authorizations for Account %s", account.ID) + } + authorizedIdentifiers := map[string]acme.Identifier{} + for _, authz := range authzs { + // Only valid Authorizations are included + if authz.Status != acme.StatusValid { + continue + } + authorizedIdentifiers[identifierKey(authz.Identifier)] = authz.Identifier + } + if len(authorizedIdentifiers) == 0 { + unauthorizedIdentifiers := []acme.Identifier{} + for _, identifier := range requiredIdentifiers { + unauthorizedIdentifiers = append(unauthorizedIdentifiers, identifier) + } + return wrapUnauthorizedError(certToBeRevoked, unauthorizedIdentifiers, fmt.Sprintf("account '%s' does not have valid authorizations", account.ID), nil) + } + unauthorizedIdentifiers := []acme.Identifier{} + for key := range requiredIdentifiers { + _, ok := authorizedIdentifiers[key] + if !ok { + unauthorizedIdentifiers = append(unauthorizedIdentifiers, requiredIdentifiers[key]) + } + } + if len(unauthorizedIdentifiers) != 0 { + return wrapUnauthorizedError(certToBeRevoked, unauthorizedIdentifiers, fmt.Sprintf("account '%s' does not have authorizations for all identifiers", account.ID), nil) + } + + return nil +} + +// identifierKey creates a unique key for an ACME identifier using +// the following format: ip|127.0.0.1; dns|*.example.com +func identifierKey(identifier acme.Identifier) string { + if identifier.Type == acme.IP { + return "ip|" + identifier.Value + } + if identifier.Type == acme.DNS { + return "dns|" + identifier.Value + } + return "unsupported|" + identifier.Value +} + +// extractIdentifiers extracts ACME identifiers from an x509 certificate and +// creates a map from them. The map ensures that double SANs are deduplicated. +// The Subject CommonName is included, because RFC8555 7.4 states that DNS +// identifiers can come from either the CommonName or a DNS SAN or both. When +// authorizing issuance, the DNS identifier must be in the request and will be +// included in the validation (see Order.sans()) as of now. This means that the +// CommonName will in fact have an authorization available. +func extractIdentifiers(cert *x509.Certificate) map[string]acme.Identifier { + result := map[string]acme.Identifier{} + for _, name := range cert.DNSNames { + identifier := acme.Identifier{ + Type: acme.DNS, + Value: name, + } + result[identifierKey(identifier)] = identifier + } + for _, ip := range cert.IPAddresses { + identifier := acme.Identifier{ + Type: acme.IP, + Value: ip.String(), + } + result[identifierKey(identifier)] = identifier + } + // TODO(hs): should we include the CommonName or not? + if cert.Subject.CommonName != "" { + identifier := acme.Identifier{ + // assuming only DNS can be in Common Name (RFC8555, 7.4); RFC8738 + // IP Identifier Validation Extension does not state anything about this. + // This logic is in accordance with the logic in order.canonicalize() + Type: acme.DNS, + Value: cert.Subject.CommonName, + } + result[identifierKey(identifier)] = identifier + } + return result +} + // wrapRevokeErr is a best effort implementation to transform an error during // revocation into an ACME error, so that clients can understand the error. func wrapRevokeErr(err error) *acme.Error { @@ -149,15 +253,24 @@ func wrapRevokeErr(err error) *acme.Error { // unauthorizedError returns an ACME error indicating the request was // not authorized to revoke the certificate. -func wrapUnauthorizedError(cert *x509.Certificate, msg string, err error) *acme.Error { +func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error { var acmeErr *acme.Error if err == nil { acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg) } else { acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg) } - acmeErr.Status = http.StatusForbidden - acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO(hs): what about other SANs? When no Subject is in the certificate? + acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403 + + switch { + case len(unauthorizedIdentifiers) > 0: + identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too? + acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value) + case cert.Subject.String() != "": + acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName) + default: + acmeErr.Detail = "No authorization provided" + } return acmeErr } diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 05952240..cf036abb 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -14,6 +14,8 @@ import ( "fmt" "io" "math/big" + "net" + "net/http" "net/http/httptest" "net/url" "testing" @@ -37,6 +39,10 @@ func v(v int) *int { return &v } +func generateSerial() (*big.Int, error) { + return rand.Int(rand.Reader, big.NewInt(1000000000000000000)) +} + // generateCertKeyPair generates fresh x509 certificate/key pairs for testing func generateCertKeyPair() (*x509.Certificate, crypto.Signer, error) { @@ -45,15 +51,16 @@ func generateCertKeyPair() (*x509.Certificate, crypto.Signer, error) { return nil, nil, err } - serial, err := rand.Int(rand.Reader, big.NewInt(1000000000000000000)) + serial, err := generateSerial() if err != nil { return nil, nil, err } now := time.Now() template := &x509.Certificate{ - Subject: pkix.Name{CommonName: "Test ACME Revoke Certificate"}, + Subject: pkix.Name{CommonName: "127.0.0.1"}, Issuer: pkix.Name{CommonName: "Test ACME Revoke Certificate"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, IsCA: false, MaxPathLen: 0, KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, @@ -453,7 +460,7 @@ func Test_revokeOptions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := revokeOptions(tt.args.serial, tt.args.certToBeRevoked, tt.args.reasonCode); !cmp.Equal(got, tt.want) { - t.Errorf("revokeOptions() diff = %s", cmp.Diff(got, tt.want)) + t.Errorf("revokeOptions() diff =\n%s", cmp.Diff(got, tt.want)) } }) } @@ -478,6 +485,20 @@ func TestHandler_RevokeCert(t *testing.T) { payloadBytes, err := json.Marshal(rp) assert.FatalError(t, err) + jws := &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + Algorithm: jose.ES256, + KeyID: "bar", + ExtraHeaders: map[jose.HeaderKey]interface{}{ + "url": revokeURL, + }, + }, + }, + }, + } + type test struct { db acme.DB ca acme.CertificateAuthority @@ -504,19 +525,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-provisioner": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), jwsContextKey, jws) return test{ ctx: ctx, @@ -525,19 +533,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-provisioner": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = context.WithValue(ctx, provisionerContextKey, nil) return test{ @@ -547,19 +542,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-payload": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = context.WithValue(ctx, provisionerContextKey, prov) return test{ @@ -569,19 +551,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-payload": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, nil) @@ -592,19 +561,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/unmarshal-payload": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } malformedPayload := []byte(`{"payload":malformed?}`) ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx = context.WithValue(ctx, provisionerContextKey, prov) @@ -621,19 +577,6 @@ func TestHandler_RevokeCert(t *testing.T) { } wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -651,23 +594,10 @@ func TestHandler_RevokeCert(t *testing.T) { emptyPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString([]byte{}), } - wrongPayloadBytes, err := json.Marshal(emptyPayload) + emptyPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wrongPayloadBytes}) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) return test{ ctx: ctx, @@ -680,19 +610,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/db.GetCertificateBySerial": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -708,27 +625,37 @@ func TestHandler_RevokeCert(t *testing.T) { err: acme.NewErrorISE("error retrieving certificate by serial"), } }, - "fail/no-account": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } + "fail/different-certificate-contents": func(t *testing.T) test { + aDifferentCert, _, err := generateCertKeyPair() + assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) - return &acme.Certificate{}, nil + return &acme.Certificate{ + Leaf: aDifferentCert, + }, nil + }, + } + return test{ + db: db, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("certificate raw bytes are not equal"), + } + }, + "fail/no-account": func(t *testing.T) test { + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + Leaf: cert, + }, nil }, } return test{ @@ -739,19 +666,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/nil-account": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) @@ -759,7 +673,9 @@ func TestHandler_RevokeCert(t *testing.T) { db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) - return &acme.Certificate{}, nil + return &acme.Certificate{ + Leaf: cert, + }, nil }, } return test{ @@ -770,19 +686,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/account-not-valid": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -795,6 +698,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -806,25 +710,12 @@ func TestHandler_RevokeCert(t *testing.T) { statusCode: 403, err: &acme.Error{ Type: "urn:ietf:params:acme:error:unauthorized", - Detail: fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()), + Detail: "No authorization provided for name 127.0.0.1", Status: 403, }, } }, - "fail/account-not-authorized": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } + "fail/db.GetAuthorizationsByAccountID-error": func(t *testing.T) test { acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -837,6 +728,49 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "differentAccountID", + Leaf: cert, + }, nil + }, + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + return nil, errors.New("force") + }, + } + ca := &mockCA{} + return test{ + db: db, + ca: ca, + ctx: ctx, + statusCode: 500, + err: acme.NewErrorISE("error retrieving authorizations for Account %s", "accountID"), + } + }, + "fail/account-not-authorized": func(t *testing.T) test { + acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, accContextKey, acc) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) + ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) + db := &acme.MockDB{ + MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { + assert.Equals(t, cert.SerialNumber.String(), serial) + return &acme.Certificate{ + AccountID: "differentAccountID", + Leaf: cert, + }, nil + }, + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: "accountID", + Status: acme.StatusValid, + Identifier: acme.Identifier{ + Type: acme.IP, + Value: "127.0.1.0", + }, + }, }, nil }, } @@ -848,7 +782,7 @@ func TestHandler_RevokeCert(t *testing.T) { statusCode: 403, err: &acme.Error{ Type: "urn:ietf:params:acme:error:unauthorized", - Detail: fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()), + Detail: "No authorization provided for name 127.0.0.1", Status: 403, }, } @@ -862,13 +796,13 @@ func TestHandler_RevokeCert(t *testing.T) { } jwsBytes, err := jwsEncodeJSON(rp, unauthorizedKey, "", "nonce", revokeURL) assert.FatalError(t, err) - jws, err := jose.ParseJWS(string(jwsBytes)) + parsedJWS, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) - ctx = context.WithValue(ctx, jwsContextKey, jws) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) db := &acme.MockDB{ @@ -876,12 +810,13 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } ca := &mockCA{} acmeErr := acme.NewError(acme.ErrorUnauthorizedType, "verification of jws using certificate public key failed") - acmeErr.Detail = "No authorization provided for name CN=Test ACME Revoke Certificate" + acmeErr.Detail = "No authorization provided for name 127.0.0.1" return test{ db: db, ca: ca, @@ -891,19 +826,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/certificate-revoked-check-fails": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -916,6 +838,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -937,19 +860,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/certificate-already-revoked": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -960,6 +870,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -985,31 +896,19 @@ func TestHandler_RevokeCert(t *testing.T) { Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), ReasonCode: v(7), } - wrongReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) + invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) - ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wrongReasonCodePayloadBytes}) + ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) db := &acme.MockDB{ MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -1032,19 +931,6 @@ func TestHandler_RevokeCert(t *testing.T) { }, "fail/prov.AuthorizeRevoke": func(t *testing.T) test { assert.FatalError(t, err) - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } mockACMEProv := &acme.MockProvisioner{ MauthorizeRevoke: func(ctx context.Context, token string) error { return errors.New("force") @@ -1060,6 +946,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -1081,19 +968,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/ca.Revoke": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -1104,6 +978,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -1125,19 +1000,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/ca.Revoke-already-revoked": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -1148,6 +1010,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -1168,19 +1031,6 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "ok/using-account-key": func(t *testing.T) test { - jws := &jose.JSONWebSignature{ - Signatures: []jose.Signature{ - { - Protected: jose.Header{ - Algorithm: jose.ES256, - KeyID: "bar", - ExtraHeaders: map[jose.HeaderKey]interface{}{ - "url": revokeURL, - }, - }, - }, - }, - } acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, accContextKey, acc) @@ -1193,6 +1043,7 @@ func TestHandler_RevokeCert(t *testing.T) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ AccountID: "accountID", + Leaf: cert, }, nil }, } @@ -1218,7 +1069,8 @@ func TestHandler_RevokeCert(t *testing.T) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { assert.Equals(t, cert.SerialNumber.String(), serial) return &acme.Certificate{ - AccountID: "accountID", + AccountID: "someDifferentAccountID", + Leaf: cert, }, nil }, } @@ -1264,3 +1116,400 @@ func TestHandler_RevokeCert(t *testing.T) { }) } } + +func TestHandler_isAccountAuthorized(t *testing.T) { + type test struct { + db acme.DB + ctx context.Context + existingCert *acme.Certificate + certToBeRevoked *x509.Certificate + account *acme.Account + err *acme.Error + } + accountID := "accountID" + var tests = map[string]func(t *testing.T) test{ + "fail/account-invalid": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusInvalid, + } + certToBeRevoked := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + } + return test{ + ctx: context.TODO(), + certToBeRevoked: certToBeRevoked, + account: account, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Status: http.StatusForbidden, + Detail: "No authorization provided for name 127.0.0.1", + Err: errors.New("account 'accountID' has status 'invalid'"), + }, + } + }, + "fail/no-certificate-identifiers": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{} + existingCert := &acme.Certificate{ + AccountID: "differentAccountID", + } + return test{ + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Status: http.StatusForbidden, + Detail: "No authorization provided", + Err: errors.New("cannot authorize revocation without providing identifiers to authorize"), + }, + } + }, + "fail/db.GetAuthorizationsByAccountID-error": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + } + existingCert := &acme.Certificate{ + AccountID: "differentAccountID", + } + return test{ + db: &acme.MockDB{ + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return nil, errors.New("force") + }, + }, + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: acme.NewErrorISE("error retrieving authorizations for Account %s: force", accountID), + } + }, + "fail/no-valid-authorizations": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + } + existingCert := &acme.Certificate{ + AccountID: "differentAccountID", + } + return test{ + db: &acme.MockDB{ + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: accountID, + Status: acme.StatusInvalid, + }, + }, nil + }, + }, + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Status: http.StatusForbidden, + Detail: "No authorization provided for name 127.0.0.1", + Err: errors.New("account 'accountID' does not have valid authorizations"), + }, + } + }, + "fail/authorizations-do-not-match-identifiers": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + } + existingCert := &acme.Certificate{ + AccountID: "differentAccountID", + } + return test{ + db: &acme.MockDB{ + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: accountID, + Status: acme.StatusValid, + Identifier: acme.Identifier{ + Type: acme.IP, + Value: "127.0.0.2", + }, + }, + }, nil + }, + }, + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: &acme.Error{ + Type: "urn:ietf:params:acme:error:unauthorized", + Status: http.StatusForbidden, + Detail: "No authorization provided for name 127.0.0.1", + Err: errors.New("account 'accountID' does not have authorizations for all identifiers"), + }, + } + }, + "ok": func(t *testing.T) test { + account := &acme.Account{ + ID: accountID, + Status: acme.StatusValid, + } + certToBeRevoked := &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + existingCert := &acme.Certificate{ + AccountID: "differentAccountID", + } + return test{ + db: &acme.MockDB{ + MockGetAuthorizationsByAccountID: func(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + assert.Equals(t, "accountID", accountID) + return []*acme.Authorization{ + { + AccountID: accountID, + Status: acme.StatusValid, + Identifier: acme.Identifier{ + Type: acme.IP, + Value: "127.0.0.1", + }, + }, + }, nil + }, + }, + ctx: context.TODO(), + existingCert: existingCert, + certToBeRevoked: certToBeRevoked, + account: account, + err: nil, + } + }, + } + for name, setup := range tests { + tc := setup(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: tc.db} + acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) + + expectError := tc.err != nil + gotError := acmeErr != nil + if expectError != gotError { + t.Errorf("expected: %t, got: %t", expectError, gotError) + return + } + + if !gotError { + return // nothing to check; return early + } + + assert.Equals(t, acmeErr.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, acmeErr.Type, tc.err.Type) + assert.Equals(t, acmeErr.Status, tc.err.Status) + assert.Equals(t, acmeErr.Detail, tc.err.Detail) + assert.Equals(t, acmeErr.Identifier, tc.err.Identifier) + assert.Equals(t, acmeErr.Subproblems, tc.err.Subproblems) + + }) + } +} + +func Test_identifierKey(t *testing.T) { + tests := []struct { + name string + identifier acme.Identifier + want string + }{ + { + name: "ip", + identifier: acme.Identifier{ + Type: acme.IP, + Value: "10.0.0.1", + }, + want: "ip|10.0.0.1", + }, + { + name: "dns", + identifier: acme.Identifier{ + Type: acme.DNS, + Value: "*.example.com", + }, + want: "dns|*.example.com", + }, + { + name: "unknown", + identifier: acme.Identifier{ + Type: "InvalidType", + Value: "<<>>", + }, + want: "unsupported|<<>>", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := identifierKey(tt.identifier) + if !cmp.Equal(tt.want, got) { + t.Errorf("identifierKey() diff = \n%s", cmp.Diff(tt.want, got)) + } + }) + } +} + +func Test_extractIdentifiers(t *testing.T) { + tests := []struct { + name string + cert *x509.Certificate + want map[string]acme.Identifier + }{ + { + name: "ip", + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + }, + want: map[string]acme.Identifier{ + "ip|127.0.0.1": { + Type: acme.IP, + Value: "127.0.0.1", + }, + }, + }, + { + name: "dns", + cert: &x509.Certificate{ + DNSNames: []string{"*.example.com"}, + }, + want: map[string]acme.Identifier{ + "dns|*.example.com": { + Type: acme.DNS, + Value: "*.example.com", + }, + }, + }, + { + name: "dns-subject", + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "www.example.com", + }, + }, + want: map[string]acme.Identifier{ + "dns|www.example.com": { + Type: acme.DNS, + Value: "www.example.com", + }, + }, + }, + { + name: "ip-subject", + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + }, + want: map[string]acme.Identifier{ + "dns|127.0.0.1": { // this is the currently expected behavior + Type: acme.DNS, + Value: "127.0.0.1", + }, + }, + }, + { + name: "combined", + cert: &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "127.0.0.1", + }, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.2")}, + DNSNames: []string{"*.example.com", "www.example.com"}, + }, + want: map[string]acme.Identifier{ + "ip|127.0.0.1": { + Type: acme.IP, + Value: "127.0.0.1", + }, + "ip|127.0.0.2": { + Type: acme.IP, + Value: "127.0.0.2", + }, + "dns|*.example.com": { + Type: acme.DNS, + Value: "*.example.com", + }, + "dns|www.example.com": { + Type: acme.DNS, + Value: "www.example.com", + }, + "dns|127.0.0.1": { // this is the currently expected behavior + Type: acme.DNS, + Value: "127.0.0.1", + }, + }, + }, + { + name: "ip-duplicates", + cert: &x509.Certificate{ + IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.1"), net.ParseIP("127.0.0.2")}, + }, + want: map[string]acme.Identifier{ + "ip|127.0.0.1": { + Type: acme.IP, + Value: "127.0.0.1", + }, + "ip|127.0.0.2": { + Type: acme.IP, + Value: "127.0.0.2", + }, + }, + }, + { + name: "dns-duplicates", + cert: &x509.Certificate{ + DNSNames: []string{"*.example.com", "www.example.com", "www.example.com"}, + }, + want: map[string]acme.Identifier{ + "dns|*.example.com": { + Type: acme.DNS, + Value: "*.example.com", + }, + "dns|www.example.com": { + Type: acme.DNS, + Value: "www.example.com", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractIdentifiers(tt.cert) + if !cmp.Equal(tt.want, got) { + t.Errorf("extractIdentifiers() diff=\n%s", cmp.Diff(tt.want, got)) + } + }) + } +} diff --git a/acme/db.go b/acme/db.go index 67053269..1675c7e7 100644 --- a/acme/db.go +++ b/acme/db.go @@ -25,6 +25,7 @@ type DB interface { CreateAuthorization(ctx context.Context, az *Authorization) error GetAuthorization(ctx context.Context, id string) (*Authorization, error) UpdateAuthorization(ctx context.Context, az *Authorization) error + GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) CreateCertificate(ctx context.Context, cert *Certificate) error GetCertificate(ctx context.Context, id string) (*Certificate, error) @@ -51,9 +52,10 @@ type MockDB struct { MockCreateNonce func(ctx context.Context) (Nonce, error) MockDeleteNonce func(ctx context.Context, nonce Nonce) error - MockCreateAuthorization func(ctx context.Context, az *Authorization) error - MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) - MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + MockCreateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) + MockUpdateAuthorization func(ctx context.Context, az *Authorization) error + MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error) MockCreateCertificate func(ctx context.Context, cert *Certificate) error MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) @@ -162,6 +164,16 @@ func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) err return m.MockError } +// GetAuthorizationsByAccountID mock +func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) { + if m.MockGetAuthorizationsByAccountID != nil { + return m.MockGetAuthorizationsByAccountID(ctx, accountID) + } else if m.MockError != nil { + return nil, m.MockError + } + return nil, m.MockError +} + // CreateCertificate mock func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { if m.MockCreateCertificate != nil { diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 6decbe4f..01cb7fed 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -116,3 +116,37 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e nu.Error = az.Error return db.save(ctx, old.ID, nu, old, "authz", authzTable) } + +// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database. +func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) { + entries, err := db.db.List(authzTable) + if err != nil { + return nil, errors.Wrapf(err, "error listing authz") + } + authzs := []*acme.Authorization{} + for _, entry := range entries { + dbaz := new(dbAuthz) + if err = json.Unmarshal(entry.Value, dbaz); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key)) + } + // Filter out all dbAuthzs that don't belong to the accountID. This + // could be made more efficient with additional data structures mapping the + // Account ID to authorizations. Not trivial to do, though. + if dbaz.AccountID != accountID { + continue + } + authzs = append(authzs, &acme.Authorization{ + ID: dbaz.ID, + AccountID: dbaz.AccountID, + Identifier: dbaz.Identifier, + Status: dbaz.Status, + Challenges: nil, // challenges not required for current use case + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Token: dbaz.Token, + Error: dbaz.Error, + }) + } + + return authzs, nil +} diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go index 01c255dc..2e5dd3ea 100644 --- a/acme/db/nosql/authz_test.go +++ b/acme/db/nosql/authz_test.go @@ -3,9 +3,11 @@ package nosql import ( "context" "encoding/json" + "fmt" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/acme" @@ -614,3 +616,154 @@ func TestDB_UpdateAuthorization(t *testing.T) { }) } } + +func TestDB_GetAuthorizationsByAccountID(t *testing.T) { + azID := "azID" + accountID := "accountID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + authzs []*acme.Authorization + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.List-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return nil, errors.New("force") + }, + }, + err: errors.New("error listing authz: force"), + } + }, + "fail/unmarshal": func(t *testing.T) test { + b := []byte(`{malformed}`) + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(azID), + Value: b, + }, + }, nil + }, + }, + authzs: nil, + err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID), + } + }, + "ok": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: accountID, + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusValid, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(azID), + Value: b, + }, + }, nil + }, + }, + authzs: []*acme.Authorization{ + { + ID: dbaz.ID, + AccountID: dbaz.AccountID, + Token: dbaz.Token, + Identifier: dbaz.Identifier, + Status: dbaz.Status, + Challenges: nil, + Wildcard: dbaz.Wildcard, + ExpiresAt: dbaz.ExpiresAt, + Error: dbaz.Error, + }, + }, + } + }, + "ok/skip-different-account": func(t *testing.T) test { + now := clock.Now() + dbaz := &dbAuthz{ + ID: azID, + AccountID: "differentAccountID", + Identifier: acme.Identifier{ + Type: "dns", + Value: "test.ca.smallstep.com", + }, + Status: acme.StatusValid, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, + } + b, err := json.Marshal(dbaz) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MList: func(bucket []byte) ([]*nosqldb.Entry, error) { + assert.Equals(t, bucket, authzTable) + return []*nosqldb.Entry{ + { + Bucket: bucket, + Key: []byte(azID), + Value: b, + }, + }, nil + }, + }, + authzs: []*acme.Authorization{}, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + if !cmp.Equal(azs, tc.authzs) { + t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs)) + } + } + }) + } +} diff --git a/acme/order.go b/acme/order.go index bd820da1..d4a4c300 100644 --- a/acme/order.go +++ b/acme/order.go @@ -290,6 +290,9 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate // MUST appear either in the commonName portion of the requested subject // name or in an extensionRequest attribute [RFC2985] requesting a // subjectAltName extension, or both. + // TODO(hs): we might want to check if the CommonName is in fact a DNS (and cannot + // be parsed as IP). This is related to https://github.com/smallstep/cli/pull/576 + // (ACME IP SANS) if csr.Subject.CommonName != "" { // nolint:gocritic canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) diff --git a/go.mod b/go.mod index 394eb1a4..830cab82 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/go-kit/kit v0.10.0 // indirect github.com/go-piv/piv-go v1.7.0 github.com/golang/mock v1.6.0 + github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.0.5 github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect