Change some bad requests to forbidded.

Change in the sign options bad requests to forbidded if is the
provisioner the one adding a restriction, e.g. list of dns names,
validity, ...
This commit is contained in:
Mariano Cano 2021-11-24 11:32:35 -08:00
parent ff04873a2a
commit c3f98fd04d
2 changed files with 22 additions and 22 deletions

View file

@ -85,19 +85,19 @@ type emailOnlyIdentity string
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error { func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
switch { switch {
case len(req.DNSNames) > 0: case len(req.DNSNames) > 0:
return errs.BadRequest("certificate request cannot contain DNS names") return errs.Forbidden("certificate request cannot contain DNS names")
case len(req.IPAddresses) > 0: case len(req.IPAddresses) > 0:
return errs.BadRequest("certificate request cannot contain IP addresses") return errs.Forbidden("certificate request cannot contain IP addresses")
case len(req.URIs) > 0: case len(req.URIs) > 0:
return errs.BadRequest("certificate request cannot contain URIs") return errs.Forbidden("certificate request cannot contain URIs")
case len(req.EmailAddresses) == 0: case len(req.EmailAddresses) == 0:
return errs.BadRequest("certificate request does not contain any email address") return errs.Forbidden("certificate request does not contain any email address")
case len(req.EmailAddresses) > 1: case len(req.EmailAddresses) > 1:
return errs.BadRequest("certificate request contains too many email addresses") return errs.Forbidden("certificate request contains too many email addresses")
case req.EmailAddresses[0] == "": case req.EmailAddresses[0] == "":
return errs.BadRequest("certificate request cannot contain an empty email address") return errs.Forbidden("certificate request cannot contain an empty email address")
case req.EmailAddresses[0] != string(e): case req.EmailAddresses[0] != string(e):
return errs.BadRequest("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e) return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e)
default: default:
return nil return nil
} }
@ -162,7 +162,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
return nil return nil
} }
if req.Subject.CommonName != string(v) { if req.Subject.CommonName != string(v) {
return errs.BadRequest("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v)
} }
return nil return nil
} }
@ -180,7 +180,7 @@ func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error {
return nil return nil
} }
} }
return errs.BadRequest("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v) return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v)
} }
// dnsNamesValidator validates the DNS names SAN of a certificate request. // dnsNamesValidator validates the DNS names SAN of a certificate request.
@ -201,7 +201,7 @@ func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
got[s] = true got[s] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errs.BadRequest("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) return errs.Forbidden("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
} }
return nil return nil
} }
@ -224,7 +224,7 @@ func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
got[ip.String()] = true got[ip.String()] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errs.BadRequest("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v) return errs.Forbidden("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v)
} }
return nil return nil
} }
@ -247,7 +247,7 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error {
got[s] = true got[s] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errs.BadRequest("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v) return errs.Forbidden("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v)
} }
return nil return nil
} }
@ -270,7 +270,7 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error {
got[u.String()] = true got[u.String()] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errs.BadRequest("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v) return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v)
} }
return nil return nil
} }
@ -392,14 +392,14 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) return errs.Forbidden("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
} }
// NOTE: this check is not "technically correct". We're allowing the max // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will // duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not // be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) return errs.Forbidden("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
} }
return nil return nil
} }
@ -432,7 +432,7 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
// Force the common name to be the first DNS if not provided. // Force the common name to be the first DNS if not provided.
if cert.Subject.CommonName == "" { if cert.Subject.CommonName == "" {
if len(cert.DNSNames) == 0 { if len(cert.DNSNames) == 0 {
return errs.Forbidden("cannot force common name, DNS names is empty") return errs.BadRequest("cannot force common name, DNS names is empty")
} }
cert.Subject.CommonName = cert.DNSNames[0] cert.Subject.CommonName = cert.DNSNames[0]
} }

View file

@ -109,16 +109,16 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
// ignores zero values. // ignores zero values.
func (o SignSSHOptions) match(got SignSSHOptions) error { func (o SignSSHOptions) match(got SignSSHOptions) error {
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
return errs.BadRequest("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) return errs.Forbidden("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
} }
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) { if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
return errs.BadRequest("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) return errs.Forbidden("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
} }
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
return errs.BadRequest("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) return errs.Forbidden("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
} }
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
return errs.BadRequest("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) return errs.Forbidden("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
} }
return nil return nil
} }
@ -368,9 +368,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch { switch {
case dur < min: case dur < min:
return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) return errs.Forbidden("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate: case dur > max+opts.Backdate:
return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) return errs.Forbidden("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }