Add logic for Account authorizations and improve tests
This commit is contained in:
parent
bae1d256ee
commit
06bb97c91e
7 changed files with 845 additions and 280 deletions
|
@ -1,6 +1,8 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -66,35 +68,36 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
serial := certToBeRevoked.SerialNumber.String()
|
||||
existingCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||
dbCert, err := h.db.GetCertificateBySerial(ctx, serial)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.Equal(dbCert.Leaf.Raw, certToBeRevoked.Raw) {
|
||||
// this should never happen
|
||||
api.WriteError(w, acme.NewErrorISE("certificate raw bytes are not equal"))
|
||||
return
|
||||
}
|
||||
|
||||
if shouldCheckAccountFrom(jws) {
|
||||
account, err := accountFromContext(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
if !account.IsValid() {
|
||||
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil))
|
||||
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
|
||||
if acmeErr != nil {
|
||||
api.WriteError(w, acmeErr)
|
||||
return
|
||||
}
|
||||
if existingCert.AccountID != account.ID { // TODO(hs): combine this check with the one below; ony one of the two has to be true
|
||||
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, fmt.Sprintf("account '%s' does not own certificate '%s'", account.ID, existingCert.ID), nil))
|
||||
return
|
||||
}
|
||||
// TODO(hs): check and implement "an account that holds authorizations for all of the identifiers in the certificate."
|
||||
// In that case the certificate may not have been created by this account, but another account that was authorized before.
|
||||
} else {
|
||||
// if account doesn't need to be checked, the JWS should be verified to be signed by the
|
||||
// private key that belongs to the public key in the certificate to be revoked.
|
||||
_, err := jws.Verify(certToBeRevoked.PublicKey)
|
||||
if err != nil {
|
||||
// TODO(hs): possible to determine an error vs. unauthorized and thus provide an ISE vs. Unauthorized?
|
||||
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, "verification of jws using certificate public key failed", err))
|
||||
api.WriteError(w, wrapUnauthorizedError(certToBeRevoked, nil, "verification of jws using certificate public key failed", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -137,6 +140,107 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
|
|||
w.Write(nil)
|
||||
}
|
||||
|
||||
// isAccountAuthorized checks if an ACME account that was retrieved earlier is authorized
|
||||
// to revoke the certificate. An Account must always be valid in order to revoke a certificate.
|
||||
// In case the certificate retrieved from the database belongs to the Account, the Account is
|
||||
// authorized. If the certificate retrieved from the database doesn't belong to the Account,
|
||||
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations
|
||||
// that are stored for the ACME Account. If these sets match, the Account is considered authorized
|
||||
// to revoke the certificate. If this check fails, the client will receive an unauthorized error.
|
||||
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
|
||||
if !account.IsValid() {
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
|
||||
}
|
||||
certificateBelongsToAccount := dbCert.AccountID == account.ID
|
||||
if certificateBelongsToAccount {
|
||||
return nil // return early; skip relatively expensive database check
|
||||
}
|
||||
requiredIdentifiers := extractIdentifiers(certToBeRevoked)
|
||||
if len(requiredIdentifiers) == 0 {
|
||||
return wrapUnauthorizedError(certToBeRevoked, nil, "cannot authorize revocation without providing identifiers to authorize", nil)
|
||||
}
|
||||
authzs, err := h.db.GetAuthorizationsByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
return acme.WrapErrorISE(err, "error retrieving authorizations for Account %s", account.ID)
|
||||
}
|
||||
authorizedIdentifiers := map[string]acme.Identifier{}
|
||||
for _, authz := range authzs {
|
||||
// Only valid Authorizations are included
|
||||
if authz.Status != acme.StatusValid {
|
||||
continue
|
||||
}
|
||||
authorizedIdentifiers[identifierKey(authz.Identifier)] = authz.Identifier
|
||||
}
|
||||
if len(authorizedIdentifiers) == 0 {
|
||||
unauthorizedIdentifiers := []acme.Identifier{}
|
||||
for _, identifier := range requiredIdentifiers {
|
||||
unauthorizedIdentifiers = append(unauthorizedIdentifiers, identifier)
|
||||
}
|
||||
return wrapUnauthorizedError(certToBeRevoked, unauthorizedIdentifiers, fmt.Sprintf("account '%s' does not have valid authorizations", account.ID), nil)
|
||||
}
|
||||
unauthorizedIdentifiers := []acme.Identifier{}
|
||||
for key := range requiredIdentifiers {
|
||||
_, ok := authorizedIdentifiers[key]
|
||||
if !ok {
|
||||
unauthorizedIdentifiers = append(unauthorizedIdentifiers, requiredIdentifiers[key])
|
||||
}
|
||||
}
|
||||
if len(unauthorizedIdentifiers) != 0 {
|
||||
return wrapUnauthorizedError(certToBeRevoked, unauthorizedIdentifiers, fmt.Sprintf("account '%s' does not have authorizations for all identifiers", account.ID), nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// identifierKey creates a unique key for an ACME identifier using
|
||||
// the following format: ip|127.0.0.1; dns|*.example.com
|
||||
func identifierKey(identifier acme.Identifier) string {
|
||||
if identifier.Type == acme.IP {
|
||||
return "ip|" + identifier.Value
|
||||
}
|
||||
if identifier.Type == acme.DNS {
|
||||
return "dns|" + identifier.Value
|
||||
}
|
||||
return "unsupported|" + identifier.Value
|
||||
}
|
||||
|
||||
// extractIdentifiers extracts ACME identifiers from an x509 certificate and
|
||||
// creates a map from them. The map ensures that double SANs are deduplicated.
|
||||
// The Subject CommonName is included, because RFC8555 7.4 states that DNS
|
||||
// identifiers can come from either the CommonName or a DNS SAN or both. When
|
||||
// authorizing issuance, the DNS identifier must be in the request and will be
|
||||
// included in the validation (see Order.sans()) as of now. This means that the
|
||||
// CommonName will in fact have an authorization available.
|
||||
func extractIdentifiers(cert *x509.Certificate) map[string]acme.Identifier {
|
||||
result := map[string]acme.Identifier{}
|
||||
for _, name := range cert.DNSNames {
|
||||
identifier := acme.Identifier{
|
||||
Type: acme.DNS,
|
||||
Value: name,
|
||||
}
|
||||
result[identifierKey(identifier)] = identifier
|
||||
}
|
||||
for _, ip := range cert.IPAddresses {
|
||||
identifier := acme.Identifier{
|
||||
Type: acme.IP,
|
||||
Value: ip.String(),
|
||||
}
|
||||
result[identifierKey(identifier)] = identifier
|
||||
}
|
||||
// TODO(hs): should we include the CommonName or not?
|
||||
if cert.Subject.CommonName != "" {
|
||||
identifier := acme.Identifier{
|
||||
// assuming only DNS can be in Common Name (RFC8555, 7.4); RFC8738
|
||||
// IP Identifier Validation Extension does not state anything about this.
|
||||
// This logic is in accordance with the logic in order.canonicalize()
|
||||
Type: acme.DNS,
|
||||
Value: cert.Subject.CommonName,
|
||||
}
|
||||
result[identifierKey(identifier)] = identifier
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// wrapRevokeErr is a best effort implementation to transform an error during
|
||||
// revocation into an ACME error, so that clients can understand the error.
|
||||
func wrapRevokeErr(err error) *acme.Error {
|
||||
|
@ -149,15 +253,24 @@ func wrapRevokeErr(err error) *acme.Error {
|
|||
|
||||
// unauthorizedError returns an ACME error indicating the request was
|
||||
// not authorized to revoke the certificate.
|
||||
func wrapUnauthorizedError(cert *x509.Certificate, msg string, err error) *acme.Error {
|
||||
func wrapUnauthorizedError(cert *x509.Certificate, unauthorizedIdentifiers []acme.Identifier, msg string, err error) *acme.Error {
|
||||
var acmeErr *acme.Error
|
||||
if err == nil {
|
||||
acmeErr = acme.NewError(acme.ErrorUnauthorizedType, msg)
|
||||
} else {
|
||||
acmeErr = acme.WrapError(acme.ErrorUnauthorizedType, err, msg)
|
||||
}
|
||||
acmeErr.Status = http.StatusForbidden
|
||||
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.String()) // TODO(hs): what about other SANs? When no Subject is in the certificate?
|
||||
acmeErr.Status = http.StatusForbidden // RFC8555 7.6 shows example with 403
|
||||
|
||||
switch {
|
||||
case len(unauthorizedIdentifiers) > 0:
|
||||
identifier := unauthorizedIdentifiers[0] // picking the first; compound may be an option too?
|
||||
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", identifier.Value)
|
||||
case cert.Subject.String() != "":
|
||||
acmeErr.Detail = fmt.Sprintf("No authorization provided for name %s", cert.Subject.CommonName)
|
||||
default:
|
||||
acmeErr.Detail = "No authorization provided"
|
||||
}
|
||||
|
||||
return acmeErr
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
12
acme/db.go
12
acme/db.go
|
@ -25,6 +25,7 @@ type DB interface {
|
|||
CreateAuthorization(ctx context.Context, az *Authorization) error
|
||||
GetAuthorization(ctx context.Context, id string) (*Authorization, error)
|
||||
UpdateAuthorization(ctx context.Context, az *Authorization) error
|
||||
GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error)
|
||||
|
||||
CreateCertificate(ctx context.Context, cert *Certificate) error
|
||||
GetCertificate(ctx context.Context, id string) (*Certificate, error)
|
||||
|
@ -54,6 +55,7 @@ type MockDB struct {
|
|||
MockCreateAuthorization func(ctx context.Context, az *Authorization) error
|
||||
MockGetAuthorization func(ctx context.Context, id string) (*Authorization, error)
|
||||
MockUpdateAuthorization func(ctx context.Context, az *Authorization) error
|
||||
MockGetAuthorizationsByAccountID func(ctx context.Context, accountID string) ([]*Authorization, error)
|
||||
|
||||
MockCreateCertificate func(ctx context.Context, cert *Certificate) error
|
||||
MockGetCertificate func(ctx context.Context, id string) (*Certificate, error)
|
||||
|
@ -162,6 +164,16 @@ func (m *MockDB) UpdateAuthorization(ctx context.Context, az *Authorization) err
|
|||
return m.MockError
|
||||
}
|
||||
|
||||
// GetAuthorizationsByAccountID mock
|
||||
func (m *MockDB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*Authorization, error) {
|
||||
if m.MockGetAuthorizationsByAccountID != nil {
|
||||
return m.MockGetAuthorizationsByAccountID(ctx, accountID)
|
||||
} else if m.MockError != nil {
|
||||
return nil, m.MockError
|
||||
}
|
||||
return nil, m.MockError
|
||||
}
|
||||
|
||||
// CreateCertificate mock
|
||||
func (m *MockDB) CreateCertificate(ctx context.Context, cert *Certificate) error {
|
||||
if m.MockCreateCertificate != nil {
|
||||
|
|
|
@ -116,3 +116,37 @@ func (db *DB) UpdateAuthorization(ctx context.Context, az *acme.Authorization) e
|
|||
nu.Error = az.Error
|
||||
return db.save(ctx, old.ID, nu, old, "authz", authzTable)
|
||||
}
|
||||
|
||||
// GetAuthorizationsByAccountID retrieves and unmarshals ACME authz types from the database.
|
||||
func (db *DB) GetAuthorizationsByAccountID(ctx context.Context, accountID string) ([]*acme.Authorization, error) {
|
||||
entries, err := db.db.List(authzTable)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error listing authz")
|
||||
}
|
||||
authzs := []*acme.Authorization{}
|
||||
for _, entry := range entries {
|
||||
dbaz := new(dbAuthz)
|
||||
if err = json.Unmarshal(entry.Value, dbaz); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling dbAuthz key '%s' into dbAuthz struct", string(entry.Key))
|
||||
}
|
||||
// Filter out all dbAuthzs that don't belong to the accountID. This
|
||||
// could be made more efficient with additional data structures mapping the
|
||||
// Account ID to authorizations. Not trivial to do, though.
|
||||
if dbaz.AccountID != accountID {
|
||||
continue
|
||||
}
|
||||
authzs = append(authzs, &acme.Authorization{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: nil, // challenges not required for current use case
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Token: dbaz.Token,
|
||||
Error: dbaz.Error,
|
||||
})
|
||||
}
|
||||
|
||||
return authzs, nil
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ package nosql
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
|
@ -614,3 +616,154 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB_GetAuthorizationsByAccountID(t *testing.T) {
|
||||
azID := "azID"
|
||||
accountID := "accountID"
|
||||
type test struct {
|
||||
db nosql.DB
|
||||
err error
|
||||
acmeErr *acme.Error
|
||||
authzs []*acme.Authorization
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"fail/db.List-error": func(t *testing.T) test {
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
err: errors.New("error listing authz: force"),
|
||||
}
|
||||
},
|
||||
"fail/unmarshal": func(t *testing.T) test {
|
||||
b := []byte(`{malformed}`)
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return []*nosqldb.Entry{
|
||||
{
|
||||
Bucket: bucket,
|
||||
Key: []byte(azID),
|
||||
Value: b,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authzs: nil,
|
||||
err: fmt.Errorf("error unmarshaling dbAuthz key '%s' into dbAuthz struct", azID),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: accountID,
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusValid,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return []*nosqldb.Entry{
|
||||
{
|
||||
Bucket: bucket,
|
||||
Key: []byte(azID),
|
||||
Value: b,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authzs: []*acme.Authorization{
|
||||
{
|
||||
ID: dbaz.ID,
|
||||
AccountID: dbaz.AccountID,
|
||||
Token: dbaz.Token,
|
||||
Identifier: dbaz.Identifier,
|
||||
Status: dbaz.Status,
|
||||
Challenges: nil,
|
||||
Wildcard: dbaz.Wildcard,
|
||||
ExpiresAt: dbaz.ExpiresAt,
|
||||
Error: dbaz.Error,
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/skip-different-account": func(t *testing.T) test {
|
||||
now := clock.Now()
|
||||
dbaz := &dbAuthz{
|
||||
ID: azID,
|
||||
AccountID: "differentAccountID",
|
||||
Identifier: acme.Identifier{
|
||||
Type: "dns",
|
||||
Value: "test.ca.smallstep.com",
|
||||
},
|
||||
Status: acme.StatusValid,
|
||||
Token: "token",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(5 * time.Minute),
|
||||
ChallengeIDs: []string{"foo", "bar"},
|
||||
Wildcard: true,
|
||||
}
|
||||
b, err := json.Marshal(dbaz)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
return test{
|
||||
db: &db.MockNoSQLDB{
|
||||
MList: func(bucket []byte) ([]*nosqldb.Entry, error) {
|
||||
assert.Equals(t, bucket, authzTable)
|
||||
return []*nosqldb.Entry{
|
||||
{
|
||||
Bucket: bucket,
|
||||
Key: []byte(azID),
|
||||
Value: b,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authzs: []*acme.Authorization{},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, run := range tests {
|
||||
tc := run(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
d := DB{db: tc.db}
|
||||
if azs, err := d.GetAuthorizationsByAccountID(context.Background(), accountID); err != nil {
|
||||
switch k := err.(type) {
|
||||
case *acme.Error:
|
||||
if assert.NotNil(t, tc.acmeErr) {
|
||||
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
||||
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
||||
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
||||
}
|
||||
default:
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||
}
|
||||
}
|
||||
} else if assert.Nil(t, tc.err) {
|
||||
if !cmp.Equal(azs, tc.authzs) {
|
||||
t.Errorf("db.GetAuthorizationsByAccountID() diff =\n%s", cmp.Diff(azs, tc.authzs))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -290,6 +290,9 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
|
|||
// MUST appear either in the commonName portion of the requested subject
|
||||
// name or in an extensionRequest attribute [RFC2985] requesting a
|
||||
// subjectAltName extension, or both.
|
||||
// TODO(hs): we might want to check if the CommonName is in fact a DNS (and cannot
|
||||
// be parsed as IP). This is related to https://github.com/smallstep/cli/pull/576
|
||||
// (ACME IP SANS)
|
||||
if csr.Subject.CommonName != "" {
|
||||
// nolint:gocritic
|
||||
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
||||
|
|
1
go.mod
1
go.mod
|
@ -18,6 +18,7 @@ require (
|
|||
github.com/go-kit/kit v0.10.0 // indirect
|
||||
github.com/go-piv/piv-go v1.7.0
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.5.6
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/googleapis/gax-go/v2 v2.0.5
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect
|
||||
|
|
Loading…
Reference in a new issue