Improve error creation and testing for core policy engine

This commit is contained in:
Herman Slatman 2022-04-26 01:47:07 +02:00
parent 20f5d12b99
commit 76112c2da1
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
9 changed files with 974 additions and 409 deletions

View file

@ -274,7 +274,7 @@ func isAllowed(engine authPolicy.X509Policy, sans []string) error {
if allowed, err = engine.AreSANsAllowed(sans); err != nil {
var policyErr *policy.NamePolicyError
isNamePolicyError := errors.As(err, &policyErr)
if isNamePolicyError && policyErr.Reason == policy.NotAuthorizedForThisName {
if isNamePolicyError && policyErr.Reason == policy.NotAllowed {
return &PolicyError{
Typ: AdminLockOut,
Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans),

View file

@ -58,7 +58,7 @@ func TestAuthority_checkPolicy(t *testing.T) {
},
err: &PolicyError{
Typ: EvaluationFailure,
Err: errors.New("cannot parse domain: dns \"*\" cannot be converted to ASCII"),
Err: errors.New("cannot parse dns domain \"*\""),
},
}
},
@ -105,7 +105,7 @@ func TestAuthority_checkPolicy(t *testing.T) {
},
err: &PolicyError{
Typ: EvaluationFailure,
Err: errors.New("cannot parse domain: dns \"**\" cannot be converted to ASCII"),
Err: errors.New("cannot parse dns domain \"**\""),
},
}
},

View file

@ -6,6 +6,7 @@ import (
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"net/http"
"strings"
"time"
@ -256,10 +257,14 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
allowed, err := a.sshUserPolicy.IsSSHCertificateAllowed(certTpl)
if err != nil {
var pe *policy.NamePolicyError
if errors.As(err, &pe) && pe.Reason == policy.NotAuthorizedForThisName {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: %s", err.Error()),
)
if errors.As(err, &pe) && pe.Reason == policy.NotAllowed {
return nil, &errs.Error{
// NOTE: custom forbidden error, so that denied name is sent to client
// as well as shown in the logs.
Status: http.StatusForbidden,
Err: fmt.Errorf("authority not allowed to sign: %w", err),
Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()),
}
}
return nil, errs.InternalServerErr(err,
errs.WithMessage("authority.SignSSH: error creating ssh user certificate"),
@ -279,11 +284,14 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
allowed, err := a.sshHostPolicy.IsSSHCertificateAllowed(certTpl)
if err != nil {
var pe *policy.NamePolicyError
if errors.As(err, &pe) && pe.Reason == policy.NotAuthorizedForThisName {
return nil, errs.ApplyOptions(
// TODO: show which names were not allowed; they are in the err
errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: %s", err.Error()),
)
if errors.As(err, &pe) && pe.Reason == policy.NotAllowed {
return nil, &errs.Error{
// NOTE: custom forbidden error, so that denied name is sent to client
// as well as shown in the logs.
Status: http.StatusForbidden,
Err: fmt.Errorf("authority not allowed to sign: %w", err),
Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()),
}
}
return nil, errs.InternalServerErr(err,
errs.WithMessage("authority.SignSSH: error creating ssh host certificate"),

View file

@ -203,11 +203,14 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
var allowedToSign bool
if allowedToSign, err = a.isAllowedToSign(leaf); err != nil {
var pe *policy.NamePolicyError
if errors.As(err, &pe) && pe.Reason == policy.NotAuthorizedForThisName {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(errors.New("authority not allowed to sign"), err.Error()),
opts...,
)
if errors.As(err, &pe) && pe.Reason == policy.NotAllowed {
return nil, errs.ApplyOptions(&errs.Error{
// NOTE: custom forbidden error, so that denied name is sent to client
// as well as shown in the logs.
Status: http.StatusForbidden,
Err: fmt.Errorf("authority not allowed to sign: %w", err),
Msg: fmt.Sprintf("The request was forbidden by the certificate authority: %s", err.Error()),
}, opts...)
}
return nil, errs.InternalServerErr(err,
errs.WithKeyVal("csr", csr),

View file

@ -13,15 +13,17 @@ import (
"time"
"github.com/pkg/errors"
adminAPI "github.com/smallstep/certificates/authority/admin/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"google.golang.org/protobuf/encoding/protojson"
"go.step.sm/cli-utils/token"
"go.step.sm/cli-utils/token/provision"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/randutil"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
adminAPI "github.com/smallstep/certificates/authority/admin/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
const (
@ -818,7 +820,7 @@ retry:
func (c *AdminClient) GetProvisionerPolicy(provisionerName string) (*linkedca.Policy, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioner", provisionerName, "policy")})
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
@ -853,7 +855,7 @@ func (c *AdminClient) CreateProvisionerPolicy(provisionerName string, p *linkedc
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioner", provisionerName, "policy")})
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
@ -888,7 +890,7 @@ func (c *AdminClient) UpdateProvisionerPolicy(provisionerName string, p *linkedc
if err != nil {
return nil, fmt.Errorf("error marshaling request: %w", err)
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioner", provisionerName, "policy")})
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return nil, fmt.Errorf("error generating admin token: %w", err)
@ -919,7 +921,7 @@ retry:
func (c *AdminClient) RemoveProvisionerPolicy(provisionerName string) error {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioner", provisionerName, "policy")})
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", provisionerName, "policy")})
tok, err := c.generateAdminToken(u)
if err != nil {
return fmt.Errorf("error generating admin token: %w", err)

View file

@ -15,11 +15,11 @@ import (
type NamePolicyReason int
const (
// NotAuthorizedForThisName results when an instance of
// NamePolicyEngine determines that there's a constraint which
// doesn't permit a DNS or another type of SAN to be signed
// (or otherwise used).
NotAuthorizedForThisName NamePolicyReason = iota
_ NamePolicyReason = iota
// NotAllowed results when an instance of NamePolicyEngine
// determines that there's a constraint which doesn't permit
// a DNS or another type of SAN to be signed (or otherwise used).
NotAllowed
// CannotParseDomain is returned when an error occurs
// when parsing the domain part of SAN or subject.
CannotParseDomain
@ -31,26 +31,42 @@ const (
CannotMatchNameToConstraint
)
type NameType string
const (
DNSNameType NameType = "dns"
IPNameType NameType = "ip"
EmailNameType NameType = "email"
URINameType NameType = "uri"
PrincipalNameType NameType = "principal"
)
type NamePolicyError struct {
Reason NamePolicyReason
Detail string
NameType NameType
Name string
detail string
}
func (e *NamePolicyError) Error() string {
switch e.Reason {
case NotAuthorizedForThisName:
return "not authorized to sign for this name: " + e.Detail
case NotAllowed:
return fmt.Sprintf("%s name %q not allowed", e.NameType, e.Name)
case CannotParseDomain:
return "cannot parse domain: " + e.Detail
return fmt.Sprintf("cannot parse %s domain %q", e.NameType, e.Name)
case CannotParseRFC822Name:
return "cannot parse rfc822Name: " + e.Detail
return fmt.Sprintf("cannot parse %s rfc822Name %q", e.NameType, e.Name)
case CannotMatchNameToConstraint:
return "error matching name to constraint: " + e.Detail
return fmt.Sprintf("error matching %s name %q to constraint", e.NameType, e.Name)
default:
return "unknown error: " + e.Detail
return fmt.Sprintf("unknown error reason (%d): %s", e.Reason, e.detail)
}
}
func (e *NamePolicyError) Detail() string {
return e.detail
}
// NamePolicyEngine can be used to check that a CSR or Certificate meets all allowed and
// denied names before a CA creates and/or signs the Certificate.
// TODO(hs): the X509 RFC also defines name checks on directory name; support that?
@ -98,13 +114,13 @@ func New(opts ...NamePolicyOption) (*NamePolicyEngine, error) {
}
e.permittedDNSDomains = removeDuplicates(e.permittedDNSDomains)
e.permittedIPRanges = removeDuplicateIPRanges(e.permittedIPRanges)
e.permittedIPRanges = removeDuplicateIPNets(e.permittedIPRanges)
e.permittedEmailAddresses = removeDuplicates(e.permittedEmailAddresses)
e.permittedURIDomains = removeDuplicates(e.permittedURIDomains)
e.permittedPrincipals = removeDuplicates(e.permittedPrincipals)
e.excludedDNSDomains = removeDuplicates(e.excludedDNSDomains)
e.excludedIPRanges = removeDuplicateIPRanges(e.excludedIPRanges)
e.excludedIPRanges = removeDuplicateIPNets(e.excludedIPRanges)
e.excludedEmailAddresses = removeDuplicates(e.excludedEmailAddresses)
e.excludedURIDomains = removeDuplicates(e.excludedURIDomains)
e.excludedPrincipals = removeDuplicates(e.excludedPrincipals)
@ -126,35 +142,59 @@ func New(opts ...NamePolicyOption) (*NamePolicyEngine, error) {
return e, nil
}
func removeDuplicates(strSlice []string) []string {
if len(strSlice) == 0 {
return nil
}
keys := make(map[string]bool)
result := []string{}
for _, item := range strSlice {
if _, value := keys[item]; !value && item != "" { // skip empty constraints
keys[item] = true
result = append(result, item)
}
}
return result
// removeDuplicates returns a new slice of strings with
// duplicate values removed. It retains the order of elements
// in the source slice.
func removeDuplicates(items []string) (ret []string) {
// no need to remove dupes; return original
if len(items) <= 1 {
return items
}
func removeDuplicateIPRanges(ipRanges []*net.IPNet) []*net.IPNet {
if len(ipRanges) == 0 {
return nil
keys := make(map[string]struct{}, len(items))
ret = make([]string, 0, len(items))
for _, item := range items {
if _, ok := keys[item]; ok {
continue
}
keys := make(map[string]bool)
result := []*net.IPNet{}
for _, item := range ipRanges {
key := item.String()
if _, value := keys[key]; !value {
keys[key] = true
result = append(result, item)
keys[item] = struct{}{}
ret = append(ret, item)
}
return
}
return result
// removeDuplicateIPNets returns a new slice of net.IPNets with
// duplicate values removed. It retains the order of elements in
// the source slice. An IPNet is considered duplicate if its CIDR
// notation exists multiple times in the slice.
func removeDuplicateIPNets(items []*net.IPNet) (ret []*net.IPNet) {
// no need to remove dupes; return original
if len(items) <= 1 {
return items
}
keys := make(map[string]struct{}, len(items))
ret = make([]*net.IPNet, 0, len(items))
for _, item := range items {
key := item.String() // use CIDR notation as key
if _, ok := keys[key]; ok {
continue
}
keys[key] = struct{}{}
ret = append(ret, item)
}
// TODO(hs): implement filter of fully overlapping ranges,
// so that the smaller ones are automatically removed?
return
}
// IsX509CertificateAllowed verifies that all SANs in a Certificate are allowed.

File diff suppressed because it is too large Load diff

View file

@ -5,8 +5,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/smallstep/assert"
"github.com/stretchr/testify/assert"
)
func Test_normalizeAndValidateDNSDomainConstraint(t *testing.T) {
@ -368,9 +367,9 @@ func TestNew(t *testing.T) {
},
"ok/with-permitted-ip-ranges": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
options := []NamePolicyOption{
WithPermittedIPRanges(nw1, nw2),
}
@ -389,9 +388,9 @@ func TestNew(t *testing.T) {
},
"ok/with-excluded-ip-ranges": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
options := []NamePolicyOption{
WithExcludedIPRanges(nw1, nw2),
}
@ -410,9 +409,9 @@ func TestNew(t *testing.T) {
},
"ok/with-permitted-cidrs": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
options := []NamePolicyOption{
WithPermittedCIDRs("127.0.0.1/24", "192.168.0.1/24"),
}
@ -431,9 +430,9 @@ func TestNew(t *testing.T) {
},
"ok/with-excluded-cidrs": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
options := []NamePolicyOption{
WithExcludedCIDRs("127.0.0.1/24", "192.168.0.1/24"),
}
@ -452,11 +451,11 @@ func TestNew(t *testing.T) {
},
"ok/with-permitted-ipsOrCIDRs-cidr": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.31/32")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw3, err := net.ParseCIDR("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")
assert.FatalError(t, err)
assert.NoError(t, err)
options := []NamePolicyOption{
WithPermittedIPsOrCIDRs("127.0.0.1/24", "192.168.0.31", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
}
@ -475,11 +474,11 @@ func TestNew(t *testing.T) {
},
"ok/with-excluded-ipsOrCIDRs-cidr": func(t *testing.T) test {
_, nw1, err := net.ParseCIDR("127.0.0.1/24")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw2, err := net.ParseCIDR("192.168.0.31/32")
assert.FatalError(t, err)
assert.NoError(t, err)
_, nw3, err := net.ParseCIDR("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")
assert.FatalError(t, err)
assert.NoError(t, err)
options := []NamePolicyOption{
WithExcludedIPsOrCIDRs("127.0.0.1/24", "192.168.0.31", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"),
}

View file

@ -25,8 +25,6 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
return nil
}
// TODO: implement check that requires at least a single name in all of the SANs + subject?
// TODO: set limit on total of all names validated? In x509 there's a limit on the number of comparisons
// that protects the CA from a DoS (i.e. many heavy comparisons). The x509 implementation takes
// this number as a total of all checks and keeps a (pointer to a) counter of the number of checks
@ -40,29 +38,37 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
// (other) excluded constraints, we'll allow a DNS (implicit allow; currently).
if e.numberOfDNSDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("dns %q is not explicitly permitted by any constraint", dns),
Reason: NotAllowed,
NameType: DNSNameType,
Name: dns,
detail: fmt.Sprintf("dns %q is not explicitly permitted by any constraint", dns),
}
}
didCutWildcard := false
if strings.HasPrefix(dns, "*.") {
dns = dns[1:]
parsedDNS := dns
if strings.HasPrefix(parsedDNS, "*.") {
parsedDNS = parsedDNS[1:]
didCutWildcard = true
}
parsedDNS, err := idna.Lookup.ToASCII(dns)
// TODO(hs): fix this above; we need separate rule for Subject Common Name?
parsedDNS, err := idna.Lookup.ToASCII(parsedDNS)
if err != nil {
return &NamePolicyError{
Reason: CannotParseDomain,
Detail: fmt.Sprintf("dns %q cannot be converted to ASCII", dns),
NameType: DNSNameType,
Name: dns,
detail: fmt.Sprintf("dns %q cannot be converted to ASCII", dns),
}
}
if didCutWildcard {
parsedDNS = "*" + parsedDNS
}
if _, ok := domainToReverseLabels(parsedDNS); !ok {
if _, ok := domainToReverseLabels(parsedDNS); !ok { // TODO(hs): this also fails with spaces
return &NamePolicyError{
Reason: CannotParseDomain,
Detail: fmt.Sprintf("cannot parse dns %q", dns),
NameType: DNSNameType,
Name: dns,
detail: fmt.Sprintf("cannot parse dns %q", dns),
}
}
if err := checkNameConstraints("dns", dns, parsedDNS,
@ -76,8 +82,10 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
for _, ip := range ips {
if e.numberOfIPRangeConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("ip %q is not explicitly permitted by any constraint", ip.String()),
Reason: NotAllowed,
NameType: IPNameType,
Name: ip.String(),
detail: fmt.Sprintf("ip %q is not explicitly permitted by any constraint", ip.String()),
}
}
if err := checkNameConstraints("ip", ip.String(), ip,
@ -91,15 +99,19 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
for _, email := range emailAddresses {
if e.numberOfEmailAddressConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("email %q is not explicitly permitted by any constraint", email),
Reason: NotAllowed,
NameType: EmailNameType,
Name: email,
detail: fmt.Sprintf("email %q is not explicitly permitted by any constraint", email),
}
}
mailbox, ok := parseRFC2821Mailbox(email)
if !ok {
return &NamePolicyError{
Reason: CannotParseRFC822Name,
Detail: fmt.Sprintf("invalid rfc822Name %q", mailbox),
NameType: EmailNameType,
Name: email,
detail: fmt.Sprintf("invalid rfc822Name %q", mailbox),
}
}
// According to RFC 5280, section 7.5, emails are considered to match if the local part is
@ -109,7 +121,9 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
if err != nil {
return &NamePolicyError{
Reason: CannotParseDomain,
Detail: fmt.Sprintf("cannot parse email domain %q", email),
NameType: EmailNameType,
Name: email,
detail: fmt.Sprintf("cannot parse email domain %q", email),
}
}
mailbox.domain = domainASCII
@ -126,10 +140,14 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
for _, uri := range uris {
if e.numberOfURIDomainConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("uri %q is not explicitly permitted by any constraint", uri.String()),
Reason: NotAllowed,
NameType: URINameType,
Name: uri.String(),
detail: fmt.Sprintf("uri %q is not explicitly permitted by any constraint", uri.String()),
}
}
// TODO(hs): ideally we'd like the uri.String() to be the original contents; now
// it's transformed into ASCII. Prevent that here?
if err := checkNameConstraints("uri", uri.String(), uri,
func(parsedName, constraint interface{}) (bool, error) {
return e.matchURIConstraint(parsedName.(*url.URL), constraint.(string))
@ -141,8 +159,10 @@ func (e *NamePolicyEngine) validateNames(dnsNames []string, ips []net.IP, emailA
for _, principal := range principals {
if e.numberOfPrincipalConstraints == 0 && e.totalNumberOfPermittedConstraints > 0 {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("username principal %q is not explicitly permitted by any constraint", principal),
Reason: NotAllowed,
NameType: PrincipalNameType,
Name: principal,
detail: fmt.Sprintf("username principal %q is not explicitly permitted by any constraint", principal),
}
}
// TODO: some validation? I.e. allowed characters?
@ -176,14 +196,18 @@ func checkNameConstraints(
if err != nil {
return &NamePolicyError{
Reason: CannotMatchNameToConstraint,
Detail: err.Error(),
NameType: NameType(nameType),
Name: name,
detail: err.Error(),
}
}
if match {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint),
Reason: NotAllowed,
NameType: NameType(nameType),
Name: name,
detail: fmt.Sprintf("%s %q is excluded by constraint %q", nameType, name, constraint),
}
}
}
@ -197,7 +221,9 @@ func checkNameConstraints(
if ok, err = match(parsedName, constraint); err != nil {
return &NamePolicyError{
Reason: CannotMatchNameToConstraint,
Detail: err.Error(),
NameType: NameType(nameType),
Name: name,
detail: err.Error(),
}
}
@ -208,8 +234,10 @@ func checkNameConstraints(
if !ok {
return &NamePolicyError{
Reason: NotAuthorizedForThisName,
Detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name),
Reason: NotAllowed,
NameType: NameType(nameType),
Name: name,
detail: fmt.Sprintf("%s %q is not permitted by any constraint", nameType, name),
}
}