From a7fbbc47483986899bb420942b0c5de515520371 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Sun, 28 Nov 2021 21:20:57 +0100 Subject: [PATCH] Add tests for GetCertificateBySerial --- acme/api/revoke.go | 6 +- acme/api/revoke_test.go | 24 ++--- acme/db/nosql/certificate_test.go | 145 ++++++++++++++++++++++++++++-- authority/provisioner/acme.go | 2 +- 4 files changed, 152 insertions(+), 25 deletions(-) diff --git a/acme/api/revoke.go b/acme/api/revoke.go index 7ae93152..1c664dde 100644 --- a/acme/api/revoke.go +++ b/acme/api/revoke.go @@ -82,11 +82,11 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)) return } - if existingCert.AccountID != account.ID { // TODO: combine with the below; ony one of the two has to be true + if existingCert.AccountID != account.ID { // TODO(hs): combine this check with the one below; ony one of the two has to be true api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' does not own certificate '%s'", account.ID, existingCert.ID), nil)) return } - // TODO: check and implement "an account that holds authorizations for all of the identifiers in the certificate." + // TODO(hs): check and implement "an account that holds authorizations for all of the identifiers in the certificate." // In that case the certificate may not have been created by this account, but another account that was authorized before. } else { // if account doesn't need to be checked, the JWS should be verified to be signed by the @@ -157,7 +157,7 @@ func wrapUnauthorizedError(cert *x509.Certificate, msg string, err error) *acme. acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg) } acmeErr.Status = http.StatusForbidden - acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO: what about other SANs? When no Subject is in the certificate? + acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO(hs): what about other SANs? When no Subject is in the certificate? return acmeErr } diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 2feae989..05952240 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -616,10 +616,10 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/wrong-certificate-encoding": func(t *testing.T) test { - rp := &revokePayload{ + wrongPayload := &revokePayload{ Certificate: base64.StdEncoding.EncodeToString(cert.Raw), } - wronglyEncodedPayloadBytes, err := json.Marshal(rp) + wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) assert.FatalError(t, err) jws := &jose.JSONWebSignature{ Signatures: []jose.Signature{ @@ -648,10 +648,10 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/no-certificate-encoded": func(t *testing.T) test { - rp := &revokePayload{ + emptyPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString([]byte{}), } - wrongPayloadBytes, err := json.Marshal(rp) + wrongPayloadBytes, err := json.Marshal(emptyPayload) assert.FatalError(t, err) jws := &jose.JSONWebSignature{ Signatures: []jose.Signature{ @@ -856,15 +856,15 @@ func TestHandler_RevokeCert(t *testing.T) { "fail/unauthorized-certificate-key": func(t *testing.T) test { _, unauthorizedKey, err := generateCertKeyPair() assert.FatalError(t, err) - rp := &revokePayload{ + jwsPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), - ReasonCode: v(1), + ReasonCode: v(2), } jwsBytes, err := jwsEncodeJSON(rp, unauthorizedKey, "", "nonce", revokeURL) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) - unauthorizedPayloadBytes, err := json.Marshal(rp) + unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) @@ -981,11 +981,11 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "fail/invalid-reasoncode": func(t *testing.T) test { - rp := &revokePayload{ + invalidReasonPayload := &revokePayload{ Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), ReasonCode: v(7), } - wrongReasonCodePayloadBytes, err := json.Marshal(rp) + wrongReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) assert.FatalError(t, err) jws := &jose.JSONWebSignature{ Signatures: []jose.Signature{ @@ -1205,16 +1205,10 @@ func TestHandler_RevokeCert(t *testing.T) { } }, "ok/using-certificate-key": func(t *testing.T) test { - rp := &revokePayload{ - Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), - ReasonCode: v(1), - } jwsBytes, err := jwsEncodeJSON(rp, key, "", "nonce", revokeURL) assert.FatalError(t, err) jws, err := jose.ParseJWS(string(jwsBytes)) assert.FatalError(t, err) - payloadBytes, err := json.Marshal(rp) - assert.FatalError(t, err) ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, jwsContextKey, jws) diff --git a/acme/db/nosql/certificate_test.go b/acme/db/nosql/certificate_test.go index 8b6b6ef3..d64b3015 100644 --- a/acme/db/nosql/certificate_test.go +++ b/acme/db/nosql/certificate_test.go @@ -1,10 +1,12 @@ package nosql import ( + "bytes" "context" "crypto/x509" "encoding/json" "encoding/pem" + "fmt" "testing" "time" @@ -14,7 +16,6 @@ import ( "github.com/smallstep/certificates/db" "github.com/smallstep/nosql" nosqldb "github.com/smallstep/nosql/database" - "go.step.sm/crypto/pemutil" ) @@ -31,7 +32,6 @@ func TestDB_CreateCertificate(t *testing.T) { err error _id *string } - countOfCmpAndSwapCalls := 0 var tests = map[string]func(t *testing.T) test{ "fail/cmpAndSwap-error": func(t *testing.T) test { cert := &acme.Certificate{ @@ -76,7 +76,10 @@ func TestDB_CreateCertificate(t *testing.T) { return test{ db: &db.MockNoSQLDB{ MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { - if countOfCmpAndSwapCalls == 0 { + if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) { + t.Fail() + } + if bytes.Equal(bucket, certTable) { *idPtr = string(key) assert.Equals(t, bucket, certTable) assert.Equals(t, key, []byte(cert.ID)) @@ -90,7 +93,7 @@ func TestDB_CreateCertificate(t *testing.T) { assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) } - if countOfCmpAndSwapCalls == 1 { + if bytes.Equal(bucket, certBySerialTable) { assert.Equals(t, bucket, certBySerialTable) assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String())) assert.Equals(t, old, nil) @@ -103,8 +106,6 @@ func TestDB_CreateCertificate(t *testing.T) { *idPtr = cert.ID } - countOfCmpAndSwapCalls++ - return nil, true, nil }, }, @@ -335,3 +336,135 @@ func Test_parseBundle(t *testing.T) { }) } } + +func TestDB_GetCertificateBySerial(t *testing.T) { + leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt") + assert.FatalError(t, err) + inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt") + assert.FatalError(t, err) + root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt") + assert.FatalError(t, err) + + certID := "certID" + serial := "" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certBySerialTable) { + return nil, nosqldb.ErrNotFound + } + return nil, errors.New("wrong table") + }, + }, + acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial), + } + }, + "fail/db-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certBySerialTable) { + return nil, errors.New("force") + } + return nil, errors.New("wrong table") + }, + }, + err: fmt.Errorf("error loading certificate ID for serial %s", serial), + } + }, + "fail/unmarshal-dbSerial": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + if bytes.Equal(bucket, certBySerialTable) { + return []byte(`{"serial":malformed!}`), nil + } + return nil, errors.New("wrong table") + }, + }, + err: fmt.Errorf("error unmarshaling certificate with serial %s", serial), + } + }, + "ok": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + + if bytes.Equal(bucket, certBySerialTable) { + certSerial := dbSerial{ + Serial: serial, + CertificateID: certID, + } + + b, err := json.Marshal(certSerial) + assert.FatalError(t, err) + + return b, nil + } + + if bytes.Equal(bucket, certTable) { + cert := dbCert{ + ID: certID, + AccountID: "accountID", + OrderID: "orderID", + Leaf: pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: leaf.Raw, + }), + Intermediates: append(pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: inter.Raw, + }), pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: root.Raw, + })...), + CreatedAt: clock.Now(), + } + b, err := json.Marshal(cert) + assert.FatalError(t, err) + + return b, nil + } + return nil, errors.New("wrong table") + }, + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + cert, err := d.GetCertificateBySerial(context.Background(), serial) + if err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) { + assert.Equals(t, cert.ID, certID) + assert.Equals(t, cert.AccountID, "accountID") + assert.Equals(t, cert.OrderID, "orderID") + assert.Equals(t, cert.Leaf, leaf) + assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root}) + } + }) + } +} diff --git a/authority/provisioner/acme.go b/authority/provisioner/acme.go index 25821051..c8950568 100644 --- a/authority/provisioner/acme.go +++ b/authority/provisioner/acme.go @@ -102,7 +102,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e // AuthorizeRevoke is called just before the certificate is to be revoked by // the CA. It can be used to authorize revocation of a certificate. It // currently is a no-op. -// TODO: add configuration option that toggles revocation? Or change function signature to make it more useful? +// TODO(hs): add configuration option that toggles revocation? Or change function signature to make it more useful? // Or move certain logic out of the Revoke API to here? Would likely involve some more stuff in the ctx. func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { return nil