Add tests for GetCertificateBySerial

This commit is contained in:
Herman Slatman 2021-11-28 21:20:57 +01:00
parent 4d01cf8135
commit a7fbbc4748
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
4 changed files with 152 additions and 25 deletions

View file

@ -82,11 +82,11 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)) api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil))
return return
} }
if existingCert.AccountID != account.ID { // TODO: combine with the below; ony one of the two has to be true if existingCert.AccountID != account.ID { // TODO(hs): combine this check with the one below; ony one of the two has to be true
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' does not own certificate '%s'", account.ID, existingCert.ID), nil)) api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' does not own certificate '%s'", account.ID, existingCert.ID), nil))
return return
} }
// TODO: check and implement "an account that holds authorizations for all of the identifiers in the certificate." // TODO(hs): check and implement "an account that holds authorizations for all of the identifiers in the certificate."
// In that case the certificate may not have been created by this account, but another account that was authorized before. // In that case the certificate may not have been created by this account, but another account that was authorized before.
} else { } else {
// if account doesn't need to be checked, the JWS should be verified to be signed by the // if account doesn't need to be checked, the JWS should be verified to be signed by the
@ -157,7 +157,7 @@ func wrapUnauthorizedError(cert *x509.Certificate, msg string, err error) *acme.
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg) acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
} }
acmeErr.Status = http.StatusForbidden acmeErr.Status = http.StatusForbidden
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO: what about other SANs? When no Subject is in the certificate? acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO(hs): what about other SANs? When no Subject is in the certificate?
return acmeErr return acmeErr
} }

View file

@ -616,10 +616,10 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/wrong-certificate-encoding": func(t *testing.T) test { "fail/wrong-certificate-encoding": func(t *testing.T) test {
rp := &revokePayload{ wrongPayload := &revokePayload{
Certificate: base64.StdEncoding.EncodeToString(cert.Raw), Certificate: base64.StdEncoding.EncodeToString(cert.Raw),
} }
wronglyEncodedPayloadBytes, err := json.Marshal(rp) wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
jws := &jose.JSONWebSignature{ jws := &jose.JSONWebSignature{
Signatures: []jose.Signature{ Signatures: []jose.Signature{
@ -648,10 +648,10 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/no-certificate-encoded": func(t *testing.T) test { "fail/no-certificate-encoded": func(t *testing.T) test {
rp := &revokePayload{ emptyPayload := &revokePayload{
Certificate: base64.RawURLEncoding.EncodeToString([]byte{}), Certificate: base64.RawURLEncoding.EncodeToString([]byte{}),
} }
wrongPayloadBytes, err := json.Marshal(rp) wrongPayloadBytes, err := json.Marshal(emptyPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
jws := &jose.JSONWebSignature{ jws := &jose.JSONWebSignature{
Signatures: []jose.Signature{ Signatures: []jose.Signature{
@ -856,15 +856,15 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/unauthorized-certificate-key": func(t *testing.T) test { "fail/unauthorized-certificate-key": func(t *testing.T) test {
_, unauthorizedKey, err := generateCertKeyPair() _, unauthorizedKey, err := generateCertKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
rp := &revokePayload{ jwsPayload := &revokePayload{
Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw),
ReasonCode: v(1), ReasonCode: v(2),
} }
jwsBytes, err := jwsEncodeJSON(rp, unauthorizedKey, "", "nonce", revokeURL) jwsBytes, err := jwsEncodeJSON(rp, unauthorizedKey, "", "nonce", revokeURL)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(jwsBytes)) jws, err := jose.ParseJWS(string(jwsBytes))
assert.FatalError(t, err) assert.FatalError(t, err)
unauthorizedPayloadBytes, err := json.Marshal(rp) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
@ -981,11 +981,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/invalid-reasoncode": func(t *testing.T) test { "fail/invalid-reasoncode": func(t *testing.T) test {
rp := &revokePayload{ invalidReasonPayload := &revokePayload{
Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw), Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw),
ReasonCode: v(7), ReasonCode: v(7),
} }
wrongReasonCodePayloadBytes, err := json.Marshal(rp) wrongReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
jws := &jose.JSONWebSignature{ jws := &jose.JSONWebSignature{
Signatures: []jose.Signature{ Signatures: []jose.Signature{
@ -1205,16 +1205,10 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"ok/using-certificate-key": func(t *testing.T) test { "ok/using-certificate-key": func(t *testing.T) test {
rp := &revokePayload{
Certificate: base64.RawURLEncoding.EncodeToString(cert.Raw),
ReasonCode: v(1),
}
jwsBytes, err := jwsEncodeJSON(rp, key, "", "nonce", revokeURL) jwsBytes, err := jwsEncodeJSON(rp, key, "", "nonce", revokeURL)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(jwsBytes)) jws, err := jose.ParseJWS(string(jwsBytes))
assert.FatalError(t, err) assert.FatalError(t, err)
payloadBytes, err := json.Marshal(rp)
assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)

View file

@ -1,10 +1,12 @@
package nosql package nosql
import ( import (
"bytes"
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt"
"testing" "testing"
"time" "time"
@ -14,7 +16,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database" nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -31,7 +32,6 @@ func TestDB_CreateCertificate(t *testing.T) {
err error err error
_id *string _id *string
} }
countOfCmpAndSwapCalls := 0
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test { "fail/cmpAndSwap-error": func(t *testing.T) test {
cert := &acme.Certificate{ cert := &acme.Certificate{
@ -76,7 +76,10 @@ func TestDB_CreateCertificate(t *testing.T) {
return test{ return test{
db: &db.MockNoSQLDB{ db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
if countOfCmpAndSwapCalls == 0 { if !bytes.Equal(bucket, certTable) && !bytes.Equal(bucket, certBySerialTable) {
t.Fail()
}
if bytes.Equal(bucket, certTable) {
*idPtr = string(key) *idPtr = string(key)
assert.Equals(t, bucket, certTable) assert.Equals(t, bucket, certTable)
assert.Equals(t, key, []byte(cert.ID)) assert.Equals(t, key, []byte(cert.ID))
@ -90,7 +93,7 @@ func TestDB_CreateCertificate(t *testing.T) {
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
} }
if countOfCmpAndSwapCalls == 1 { if bytes.Equal(bucket, certBySerialTable) {
assert.Equals(t, bucket, certBySerialTable) assert.Equals(t, bucket, certBySerialTable)
assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String())) assert.Equals(t, key, []byte(cert.Leaf.SerialNumber.String()))
assert.Equals(t, old, nil) assert.Equals(t, old, nil)
@ -103,8 +106,6 @@ func TestDB_CreateCertificate(t *testing.T) {
*idPtr = cert.ID *idPtr = cert.ID
} }
countOfCmpAndSwapCalls++
return nil, true, nil return nil, true, nil
}, },
}, },
@ -335,3 +336,135 @@ func Test_parseBundle(t *testing.T) {
}) })
} }
} }
func TestDB_GetCertificateBySerial(t *testing.T) {
leaf, err := pemutil.ReadCertificate("../../../authority/testdata/certs/foo.crt")
assert.FatalError(t, err)
inter, err := pemutil.ReadCertificate("../../../authority/testdata/certs/intermediate_ca.crt")
assert.FatalError(t, err)
root, err := pemutil.ReadCertificate("../../../authority/testdata/certs/root_ca.crt")
assert.FatalError(t, err)
certID := "certID"
serial := ""
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
return nil, nosqldb.ErrNotFound
}
return nil, errors.New("wrong table")
},
},
acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate with serial %s not found", serial),
}
},
"fail/db-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
return nil, errors.New("force")
}
return nil, errors.New("wrong table")
},
},
err: fmt.Errorf("error loading certificate ID for serial %s", serial),
}
},
"fail/unmarshal-dbSerial": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
return []byte(`{"serial":malformed!}`), nil
}
return nil, errors.New("wrong table")
},
},
err: fmt.Errorf("error unmarshaling certificate with serial %s", serial),
}
},
"ok": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if bytes.Equal(bucket, certBySerialTable) {
certSerial := dbSerial{
Serial: serial,
CertificateID: certID,
}
b, err := json.Marshal(certSerial)
assert.FatalError(t, err)
return b, nil
}
if bytes.Equal(bucket, certTable) {
cert := dbCert{
ID: certID,
AccountID: "accountID",
OrderID: "orderID",
Leaf: pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: leaf.Raw,
}),
Intermediates: append(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: inter.Raw,
}), pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: root.Raw,
})...),
CreatedAt: clock.Now(),
}
b, err := json.Marshal(cert)
assert.FatalError(t, err)
return b, nil
}
return nil, errors.New("wrong table")
},
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db}
cert, err := d.GetCertificateBySerial(context.Background(), serial)
if err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else if assert.Nil(t, tc.err) {
assert.Equals(t, cert.ID, certID)
assert.Equals(t, cert.AccountID, "accountID")
assert.Equals(t, cert.OrderID, "orderID")
assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
}
})
}
}

View file

@ -102,7 +102,7 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// AuthorizeRevoke is called just before the certificate is to be revoked by // AuthorizeRevoke is called just before the certificate is to be revoked by
// the CA. It can be used to authorize revocation of a certificate. It // the CA. It can be used to authorize revocation of a certificate. It
// currently is a no-op. // currently is a no-op.
// TODO: add configuration option that toggles revocation? Or change function signature to make it more useful? // TODO(hs): add configuration option that toggles revocation? Or change function signature to make it more useful?
// Or move certain logic out of the Revoke API to here? Would likely involve some more stuff in the ctx. // Or move certain logic out of the Revoke API to here? Would likely involve some more stuff in the ctx.
func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error { func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
return nil return nil