Add logic for Account authorizations and improve tests

This commit is contained in:
Herman Slatman 2021-12-02 16:25:35 +01:00
parent bae1d256ee
commit 06bb97c91e
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
7 changed files with 845 additions and 280 deletions

View file

@ -1,6 +1,8 @@
package api package api
import ( import (
"bytes"
"context"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -66,35 +68,36 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
existingCert, err := h.db.GetCertificateBySerial(ctx, serial) dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
} }
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
// this should never happen
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
return
}
if shouldCheckAccountFrom(jws) { if shouldCheckAccountFrom(jws) {
account, err := accountFromContext(ctx) account, err := accountFromContext(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if !account.IsValid() { acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)) if acmeErr != nil {
api.WriteError(w, acmeErr)
return return
} }
if existingCert.AccountID != account.ID { // TODO(hs): combine this check with the one below; ony one of the two has to be true
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' does not own certificate '%s'", account.ID, existingCert.ID), nil))
return
}
// TODO(hs): check and implement "an account that holds authorizations for all of the identifiers in the certificate."
// In that case the certificate may not have been created by this account, but another account that was authorized before.
} else { } else {
// if account doesn't need to be checked, the JWS should be verified to be signed by the // if account doesn't need to be checked, the JWS should be verified to be signed by the
// private key that belongs to the public key in the certificate to be revoked. // private key that belongs to the public key in the certificate to be revoked.
_, err := jws.Verify(certToBeRevoked.PublicKey) _, err := jws.Verify(certToBeRevoked.PublicKey)
if err != nil { if err != nil {
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized? // TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, "verification of jws using certificate public key failed", err)) api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
return return
} }
} }
@ -137,6 +140,107 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
w.Write(nil) w.Write(nil)
} }
// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized
// to revoke the certificate. An Account must always be valid in order to revoke a certificate.
// In case the certificate retrieved from the database belongs to the Account, the Account is
// authorized. If the certificate retrieved from the database doesn't belong to the Account,
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
if !account.IsValid() {
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
}
certificateBelongsToAccount := dbCert.AccountID == account.ID
if certificateBelongsToAccount {
return nil // return early; skip relatively expensive database check
}
requiredIdentifiers := extractIdentifiers(certToBeRevoked)
if len(requiredIdentifiers) == 0 {
return wrapUnauthorizedError(certToBeRevoked, nil, "cannot authorize revocation without providing identifiers to authorize", nil)
}
authzs, err := h.db.GetAuthorizationsByAccountID(ctx, account.ID)
if err != nil {
return acme.WrapErrorISE(err, "error retrieving authorizations for Account %s", account.ID)
}
authorizedIdentifiers := map[string]acme.Identifier{}
for _, authz := range authzs {
// Only valid Authorizations are included
if authz.Status != acme.StatusValid {
continue
}
authorizedIdentifiers[identifierKey(authz.Identifier)] = authz.Identifier
}
if len(authorizedIdentifiers) == 0 {
unauthorizedIdentifiers := []acme.Identifier{}
for _, identifier := range requiredIdentifiers {
unauthorizedIdentifiers = append(unauthorizedIdentifiers, identifier)
}
return wrapUnauthorizedError(certToBeRevoked, unauthorizedIdentifiers, fmt.Sprintf("account '%s' does not have valid authorizations", account.ID), nil)
}
unauthorizedIdentifiers := []acme.Identifier{}
for key := range requiredIdentifiers {
_, ok := authorizedIdentifiers[key]
if !ok {
unauthorizedIdentifiers = append(unauthorizedIdentifiers, requiredIdentifiers[key])
}
}
if len(unauthorizedIdentifiers) != 0 {
return wrapUnauthorizedError(certToBeRevoked, unauthorizedIdentifiers, fmt.Sprintf("account '%s' does not have authorizations for all identifiers", account.ID), nil)
}
return nil
}
// identifierKey creates a unique key for an ACME identifier using
// the following format: ip|127.0.0.1; dns|*.example.com
func identifierKey(identifier acme.Identifier) string {
if identifier.Type == acme.IP {
return "ip|" + identifier.Value
}
if identifier.Type == acme.DNS {
return "dns|" + identifier.Value
}
return "unsupported|" + identifier.Value
}
// extractIdentifiers extracts ACME identifiers from an x509 certificate and
// creates a map from them. The map ensures that double SANs are deduplicated.
// The Subject CommonName is included, because RFC8555 7.4 states that DNS
// identifiers can come from either the CommonName or a DNS SAN or both. When
// authorizing issuance, the DNS identifier must be in the request and will be
// included in the validation (see Order.sans()) as of now. This means that the
// CommonName will in fact have an authorization available.
func extractIdentifiers(cert *x509.Certificate) map[string]acme.Identifier {
result := map[string]acme.Identifier{}
for _, name := range cert.DNSNames {
identifier := acme.Identifier{
Type: acme.DNS,
Value: name,
}
result[identifierKey(identifier)] = identifier
}
for _, ip := range cert.IPAddresses {
identifier := acme.Identifier{
Type: acme.IP,
Value: ip.String(),
}
result[identifierKey(identifier)] = identifier
}
// TODO(hs): should we include the CommonName or not?
if cert.Subject.CommonName != "" {
identifier := acme.Identifier{
// assuming only DNS can be in Common Name (RFC8555, 7.4); RFC8738
// IP Identifier Validation Extension does not state anything about this.
// This logic is in accordance with the logic in order.canonicalize()
Type: acme.DNS,
Value: cert.Subject.CommonName,
}
result[identifierKey(identifier)] = identifier
}
return result
}
// wrapRevokeErr is a best effort implementation to transform an error during // wrapRevokeErr is a best effort implementation to transform an error during
// revocation into an ACME error, so that clients can understand the error. // revocation into an ACME error, so that clients can understand the error.
func wrapRevokeErr(err error) *acme.Error { func wrapRevokeErr(err error) *acme.Error {
@ -149,15 +253,24 @@ func wrapRevokeErr(err error) *acme.Error {
// unauthorizedError returns an ACME error indicating the request was // unauthorizedError returns an ACME error indicating the request was
// not authorized to revoke the certificate. // not authorized to revoke the certificate.
func wrapUnauthorizedError(cert *x509.Certificate, msg string, err error) *acme.Error { func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error {
var acmeErr *acme.Error var acmeErr *acme.Error
if err == nil { if err == nil {
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg) acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg)
} else { } else {
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg) acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
} }
acmeErr.Status = http.StatusForbidden acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO(hs): what about other SANs? When no Subject is in the certificate?
switch {
case len(unauthorizedIdentifiers) > 0:
identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too?
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value)
case cert.Subject.String() != "":
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName)
default:
acmeErr.Detail = "No authorization provided"
}
return acmeErr return acmeErr
} }

File diff suppressed because it is too large Load diff

View file

@ -25,6 +25,7 @@ type DB interface {
CreateAuthorization(ctx context.Context, az *Authorization) error CreateAuthorization(ctx context.Context, az *Authorization) error
GetAuthorization(ctx context.Context, id string) (*Authorization, error) GetAuthorization(ctx context.Context, id string) (*Authorization, error)
UpdateAuthorization(ctx context.Context, az *Authorization) error UpdateAuthorization(ctx context.Context, az *Authorization) error
GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error)
CreateCertificate(ctx context.Context, cert *Certificate) error CreateCertificate(ctx context.Context, cert *Certificate) error
GetCertificate(ctx context.Context, id string) (*Certificate, error) GetCertificate(ctx context.Context, id string) (*Certificate, error)
@ -54,6 +55,7 @@ type MockDB struct {
MockCreateAuthorization func(ctx context.Context, az *Authorization) error MockCreateAuthorization func(ctx context.Context, az *Authorization) error
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error) MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error)
MockCreateCertificate func(ctx context.Context, cert *Certificate) error MockCreateCertificate func(ctx context.Context, cert *Certificate) error
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error) MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
@ -162,6 +164,16 @@ func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) err
return m.MockError return m.MockError
} }
// GetAuthorizationsByAccountID mock
func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) {
if m.MockGetAuthorizationsByAccountID != nil {
return m.MockGetAuthorizationsByAccountID(ctx, accountID)
} else if m.MockError != nil {
return nil, m.MockError
}
return nil, m.MockError
}
// CreateCertificate mock // CreateCertificate mock
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error { func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
if m.MockCreateCertificate != nil { if m.MockCreateCertificate != nil {

View file

@ -116,3 +116,37 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e
nu.Error = az.Error nu.Error = az.Error
return db.save(ctx, old.ID, nu, old, "authz", authzTable) return db.save(ctx, old.ID, nu, old, "authz", authzTable)
} }
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) {
entries, err := db.db.List(authzTable)
if err != nil {
return nil, errors.Wrapf(err, "error listing authz")
}
authzs := []*acme.Authorization{}
for _, entry := range entries {
dbaz := new(dbAuthz)
if err = json.Unmarshal(entry.Value, dbaz); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key))
}
// Filter out all dbAuthzs that don't belong to the accountID. This
// could be made more efficient with additional data structures mapping the
// Account ID to authorizations. Not trivial to do, though.
if dbaz.AccountID != accountID {
continue
}
authzs = append(authzs, &acme.Authorization{
ID: dbaz.ID,
AccountID: dbaz.AccountID,
Identifier: dbaz.Identifier,
Status: dbaz.Status,
Challenges: nil, // challenges not required for current use case
Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt,
Token: dbaz.Token,
Error: dbaz.Error,
})
}
return authzs, nil
}

View file

@ -3,9 +3,11 @@ package nosql
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
@ -614,3 +616,154 @@ func TestDB_UpdateAuthorization(t *testing.T) {
}) })
} }
} }
func TestDB_GetAuthorizationsByAccountID(t *testing.T) {
azID := "azID"
accountID := "accountID"
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
authzs []*acme.Authorization
}
var tests = map[string]func(t *testing.T) test{
"fail/db.List-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return nil, errors.New("force")
},
},
err: errors.New("error listing authz: force"),
}
},
"fail/unmarshal": func(t *testing.T) test {
b := []byte(`{malformed}`)
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return []*nosqldb.Entry{
{
Bucket: bucket,
Key: []byte(azID),
Value: b,
},
}, nil
},
},
authzs: nil,
err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID),
}
},
"ok": func(t *testing.T) test {
now := clock.Now()
dbaz := &dbAuthz{
ID: azID,
AccountID: accountID,
Identifier: acme.Identifier{
Type: "dns",
Value: "test.ca.smallstep.com",
},
Status: acme.StatusValid,
Token: "token",
CreatedAt: now,
ExpiresAt: now.Add(5 * time.Minute),
ChallengeIDs: []string{"foo", "bar"},
Wildcard: true,
}
b, err := json.Marshal(dbaz)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return []*nosqldb.Entry{
{
Bucket: bucket,
Key: []byte(azID),
Value: b,
},
}, nil
},
},
authzs: []*acme.Authorization{
{
ID: dbaz.ID,
AccountID: dbaz.AccountID,
Token: dbaz.Token,
Identifier: dbaz.Identifier,
Status: dbaz.Status,
Challenges: nil,
Wildcard: dbaz.Wildcard,
ExpiresAt: dbaz.ExpiresAt,
Error: dbaz.Error,
},
},
}
},
"ok/skip-different-account": func(t *testing.T) test {
now := clock.Now()
dbaz := &dbAuthz{
ID: azID,
AccountID: "differentAccountID",
Identifier: acme.Identifier{
Type: "dns",
Value: "test.ca.smallstep.com",
},
Status: acme.StatusValid,
Token: "token",
CreatedAt: now,
ExpiresAt: now.Add(5 * time.Minute),
ChallengeIDs: []string{"foo", "bar"},
Wildcard: true,
}
b, err := json.Marshal(dbaz)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
assert.Equals(t, bucket, authzTable)
return []*nosqldb.Entry{
{
Bucket: bucket,
Key: []byte(azID),
Value: b,
},
}, nil
},
},
authzs: []*acme.Authorization{},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db}
if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else if assert.Nil(t, tc.err) {
if !cmp.Equal(azs, tc.authzs) {
t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs))
}
}
})
}
}

View file

@ -290,6 +290,9 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
// MUST appear either in the commonName portion of the requested subject // MUST appear either in the commonName portion of the requested subject
// name or in an extensionRequest attribute [RFC2985] requesting a // name or in an extensionRequest attribute [RFC2985] requesting a
// subjectAltName extension, or both. // subjectAltName extension, or both.
// TODO(hs): we might want to check if the CommonName is in fact a DNS (and cannot
// be parsed as IP). This is related to https://github.com/smallstep/cli/pull/576
// (ACME IP SANS)
if csr.Subject.CommonName != "" { if csr.Subject.CommonName != "" {
// nolint:gocritic // nolint:gocritic
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)

1
go.mod
View file

@ -18,6 +18,7 @@ require (
github.com/go-kit/kit v0.10.0 // indirect github.com/go-kit/kit v0.10.0 // indirect
github.com/go-piv/piv-go v1.7.0 github.com/go-piv/piv-go v1.7.0
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.6
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/googleapis/gax-go/v2 v2.0.5 github.com/googleapis/gax-go/v2 v2.0.5
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect