Merge branch 'master' into hs/acme-revocation

This commit is contained in:
Herman Slatman 2021-12-09 09:36:52 +01:00
commit 3bc3957b06
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
18 changed files with 151 additions and 118 deletions

View file

@ -42,6 +42,8 @@ To get up and running quickly, or as an alternative to running your own `step-ca
[![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers) [![GitHub stars](https://img.shields.io/github/stars/smallstep/certificates.svg?style=social)](https://github.com/smallstep/certificates/stargazers)
[![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs) [![Twitter followers](https://img.shields.io/twitter/follow/smallsteplabs.svg?label=Follow&style=social)](https://twitter.com/intent/follow?screen_name=smallsteplabs)
![star us](https://github.com/smallstep/certificates/raw/master/docs/images/star.gif)
## Features ## Features
### 🦾 A fast, stable, flexible private CA ### 🦾 A fast, stable, flexible private CA

View file

@ -348,7 +348,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := h.Authority.GetRoots()
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error getting roots"))
return return
} }
@ -366,7 +366,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := h.Authority.GetFederation()
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error getting federated roots"))
return return
} }

View file

@ -96,7 +96,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error revoking certificate"))
return return
} }

View file

@ -74,7 +74,7 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error signing certificate"))
return return
} }
certChainPEM := certChainToPEM(certChain) certChainPEM := certChainToPEM(certChain)

View file

@ -293,7 +293,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
@ -301,7 +301,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
} }
addUserCertificate = &SSHCertificate{addUserCert} addUserCertificate = &SSHCertificate{addUserCert}
@ -326,7 +326,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return
} }
identityCertificate = certChainToPEM(certChain) identityCertificate = certChainToPEM(certChain)

View file

@ -68,7 +68,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return return
} }
@ -78,7 +78,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }

View file

@ -60,7 +60,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
newCert, err := h.Authority.RenewSSH(ctx, oldCert) newCert, err := h.Authority.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return return
} }
@ -70,7 +70,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := h.renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
} }

View file

@ -75,7 +75,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return return
} }

View file

@ -9,12 +9,14 @@ import (
"encoding/asn1" "encoding/asn1"
"encoding/json" "encoding/json"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -83,19 +85,19 @@ type emailOnlyIdentity string
func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error { func (e emailOnlyIdentity) Valid(req *x509.CertificateRequest) error {
switch { switch {
case len(req.DNSNames) > 0: case len(req.DNSNames) > 0:
return errors.New("certificate request cannot contain DNS names") return errs.Forbidden("certificate request cannot contain DNS names")
case len(req.IPAddresses) > 0: case len(req.IPAddresses) > 0:
return errors.New("certificate request cannot contain IP addresses") return errs.Forbidden("certificate request cannot contain IP addresses")
case len(req.URIs) > 0: case len(req.URIs) > 0:
return errors.New("certificate request cannot contain URIs") return errs.Forbidden("certificate request cannot contain URIs")
case len(req.EmailAddresses) == 0: case len(req.EmailAddresses) == 0:
return errors.New("certificate request does not contain any email address") return errs.Forbidden("certificate request does not contain any email address")
case len(req.EmailAddresses) > 1: case len(req.EmailAddresses) > 1:
return errors.New("certificate request contains too many email addresses") return errs.Forbidden("certificate request contains too many email addresses")
case req.EmailAddresses[0] == "": case req.EmailAddresses[0] == "":
return errors.New("certificate request cannot contain an empty email address") return errs.Forbidden("certificate request cannot contain an empty email address")
case req.EmailAddresses[0] != string(e): case req.EmailAddresses[0] != string(e):
return errors.Errorf("certificate request does not contain the valid email address, got %s, want %s", req.EmailAddresses[0], e) return errs.Forbidden("certificate request does not contain the valid email address - got %s, want %s", req.EmailAddresses[0], e)
default: default:
return nil return nil
} }
@ -108,12 +110,13 @@ type defaultPublicKeyValidator struct{}
func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error { func (v defaultPublicKeyValidator) Valid(req *x509.CertificateRequest) error {
switch k := req.PublicKey.(type) { switch k := req.PublicKey.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
if k.Size() < 256 { if k.Size() < keyutil.MinRSAKeyBytes {
return errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)") return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)
} }
case *ecdsa.PublicKey, ed25519.PublicKey: case *ecdsa.PublicKey, ed25519.PublicKey:
default: default:
return errors.Errorf("unrecognized public key of type '%T' in CSR", k) return errs.BadRequest("certificate request key of type '%T' is not supported", k)
} }
return nil return nil
} }
@ -139,11 +142,12 @@ func (v publicKeyMinimumLengthValidator) Valid(req *x509.CertificateRequest) err
case *rsa.PublicKey: case *rsa.PublicKey:
minimumLengthInBytes := v.length / 8 minimumLengthInBytes := v.length / 8
if k.Size() < minimumLengthInBytes { if k.Size() < minimumLengthInBytes {
return errors.Errorf("rsa key in CSR must be at least %d bits (%d bytes)", v.length, minimumLengthInBytes) return errs.Forbidden("certificate request RSA key must be at least %d bits (%d bytes)",
v.length, minimumLengthInBytes)
} }
case *ecdsa.PublicKey, ed25519.PublicKey: case *ecdsa.PublicKey, ed25519.PublicKey:
default: default:
return errors.Errorf("unrecognized public key of type '%T' in CSR", k) return errs.BadRequest("certificate request key of type '%T' is not supported", k)
} }
return nil return nil
} }
@ -158,7 +162,7 @@ func (v commonNameValidator) Valid(req *x509.CertificateRequest) error {
return nil return nil
} }
if req.Subject.CommonName != string(v) { if req.Subject.CommonName != string(v) {
return errors.Errorf("certificate request does not contain the valid common name; requested common name = %s, token subject = %s", req.Subject.CommonName, v) return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v)
} }
return nil return nil
} }
@ -176,7 +180,7 @@ func (v commonNameSliceValidator) Valid(req *x509.CertificateRequest) error {
return nil return nil
} }
} }
return errors.Errorf("certificate request does not contain the valid common name, got %s, want %s", req.Subject.CommonName, v) return errs.Forbidden("certificate request does not contain the valid common name - got %s, want %s", req.Subject.CommonName, v)
} }
// dnsNamesValidator validates the DNS names SAN of a certificate request. // dnsNamesValidator validates the DNS names SAN of a certificate request.
@ -197,7 +201,7 @@ func (v dnsNamesValidator) Valid(req *x509.CertificateRequest) error {
got[s] = true got[s] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errors.Errorf("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v) return errs.Forbidden("certificate request does not contain the valid DNS names - got %v, want %v", req.DNSNames, v)
} }
return nil return nil
} }
@ -220,7 +224,7 @@ func (v ipAddressesValidator) Valid(req *x509.CertificateRequest) error {
got[ip.String()] = true got[ip.String()] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errors.Errorf("IP Addresses claim failed - got %v, want %v", req.IPAddresses, v) return errs.Forbidden("certificate request does not contain the valid IP addresses - got %v, want %v", req.IPAddresses, v)
} }
return nil return nil
} }
@ -243,7 +247,7 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error {
got[s] = true got[s] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errors.Errorf("certificate request does not contain the valid Email Addresses - got %v, want %v", req.EmailAddresses, v) return errs.Forbidden("certificate request does not contain the valid email addresses - got %v, want %v", req.EmailAddresses, v)
} }
return nil return nil
} }
@ -266,7 +270,7 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error {
got[u.String()] = true got[u.String()] = true
} }
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
return errors.Errorf("URIs claim failed - got %v, want %v", req.URIs, v) return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v)
} }
return nil return nil
} }
@ -334,15 +338,15 @@ func (v profileLimitDuration) Modify(cert *x509.Certificate, so SignOptions) err
backdate = -1 * so.Backdate backdate = -1 * so.Backdate
} }
if notBefore.Before(v.notBefore) { if notBefore.Before(v.notBefore) {
return errors.Errorf("requested certificate notBefore (%s) is before "+ return errs.Forbidden(
"the active validity window of the provisioning credential (%s)", "requested certificate notBefore (%s) is before the active validity window of the provisioning credential (%s)",
notBefore, v.notBefore) notBefore, v.notBefore)
} }
notAfter := so.NotAfter.RelativeTime(notBefore) notAfter := so.NotAfter.RelativeTime(notBefore)
if notAfter.After(v.notAfter) { if notAfter.After(v.notAfter) {
return errors.Errorf("requested certificate notAfter (%s) is after "+ return errs.Forbidden(
"the expiration of the provisioning credential (%s)", "requested certificate notAfter (%s) is after the expiration of the provisioning credential (%s)",
notAfter, v.notAfter) notAfter, v.notAfter)
} }
if notAfter.IsZero() { if notAfter.IsZero() {
@ -388,14 +392,14 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) return errs.Forbidden("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
} }
// NOTE: this check is not "technically correct". We're allowing the max // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will // duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not // be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) return errs.Forbidden("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
} }
return nil return nil
} }
@ -422,16 +426,15 @@ func newForceCNOption(forceCN bool) *forceCNOption {
func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error { func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
if !o.ForceCN { if !o.ForceCN {
// Forcing CN is disabled, do nothing to certificate
return nil return nil
} }
// Force the common name to be the first DNS if not provided.
if cert.Subject.CommonName == "" { if cert.Subject.CommonName == "" {
if len(cert.DNSNames) > 0 { if len(cert.DNSNames) == 0 {
cert.Subject.CommonName = cert.DNSNames[0] return errs.BadRequest("cannot force common name, DNS names is empty")
} else {
return errors.New("Cannot force CN, DNSNames is empty")
} }
cert.Subject.CommonName = cert.DNSNames[0]
} }
return nil return nil
@ -456,7 +459,7 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValue
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
if err != nil { if err != nil {
return err return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
} }
// Prepend the provisioner extension. In the auth.Sign code we will // Prepend the provisioner extension. In the auth.Sign code we will
// force the resulting certificate to only have one extension, the // force the resulting certificate to only have one extension, the
@ -477,7 +480,7 @@ func createProvisionerExtension(typ int, name, credentialID string, keyValuePair
KeyValuePairs: keyValuePairs, KeyValuePairs: keyValuePairs,
}) })
if err != nil { if err != nil {
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension") return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
} }
return pkix.Extension{ return pkix.Extension{
Id: stepOIDProvisioner, Id: stepOIDProvisioner,

View file

@ -77,12 +77,12 @@ func Test_defaultPublicKeyValidator_Valid(t *testing.T) {
{ {
"fail/unrecognized-key-type", "fail/unrecognized-key-type",
&x509.CertificateRequest{PublicKey: "foo"}, &x509.CertificateRequest{PublicKey: "foo"},
errors.New("unrecognized public key of type 'string' in CSR"), errors.New("certificate request key of type 'string' is not supported"),
}, },
{ {
"fail/rsa/too-short", "fail/rsa/too-short",
shortRSA, shortRSA,
errors.New("rsa key in CSR must be at least 2048 bits (256 bytes)"), errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"),
}, },
{ {
"ok/rsa", "ok/rsa",
@ -303,14 +303,14 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
return test{ return test{
csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}}, csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}},
expectedSANs: []string{"dcow@fx.com"}, expectedSANs: []string{"dcow@fx.com"},
err: errors.New("certificate request does not contain the valid Email Addresses"), err: errors.New("certificate request does not contain the valid email addresses"),
} }
}, },
"fail/ipAddressesValidator": func() test { "fail/ipAddressesValidator": func() test {
return test{ return test{
csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}}, csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}},
expectedSANs: []string{"127.0.0.1"}, expectedSANs: []string{"127.0.0.1"},
err: errors.New("IP Addresses claim failed"), err: errors.New("certificate request does not contain the valid IP addresses"),
} }
}, },
"fail/urisValidator": func() test { "fail/urisValidator": func() test {
@ -321,7 +321,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) {
return test{ return test{
csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}},
expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"},
err: errors.New("URIs claim failed"), err: errors.New("certificate request does not contain the valid URIs"),
} }
}, },
"ok": func() test { "ok": func() test {
@ -512,7 +512,7 @@ func Test_forceCN_Option(t *testing.T) {
Subject: pkix.Name{}, Subject: pkix.Name{},
DNSNames: []string{}, DNSNames: []string{},
}, },
err: errors.New("Cannot force CN, DNSNames is empty"), err: errors.New("cannot force common name, DNS names is empty"),
} }
}, },
} }

View file

@ -56,7 +56,12 @@ type SignSSHOptions struct {
// Validate validates the given SignSSHOptions. // Validate validates the given SignSSHOptions.
func (o SignSSHOptions) Validate() error { func (o SignSSHOptions) Validate() error {
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
return errs.BadRequest("unknown certificate type '%s'", o.CertType) return errs.BadRequest("certType '%s' is not valid", o.CertType)
}
for _, p := range o.Principals {
if p == "" {
return errs.BadRequest("principals cannot contain empty values")
}
} }
return nil return nil
} }
@ -75,7 +80,7 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
case SSHHostCert: case SSHHostCert:
cert.CertType = ssh.HostCert cert.CertType = ssh.HostCert
default: default:
return errors.Errorf("ssh certificate has an unknown type - %s", o.CertType) return errs.BadRequest("ssh certificate has an unknown type '%s'", o.CertType)
} }
cert.KeyId = o.KeyID cert.KeyId = o.KeyID
@ -95,7 +100,7 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix()) cert.ValidBefore = uint64(o.ValidBefore.RelativeTime(t).Unix())
} }
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errors.New("ssh certificate valid after cannot be greater than valid before") return errs.BadRequest("ssh certificate validAfter cannot be greater than validBefore")
} }
return nil return nil
} }
@ -104,16 +109,16 @@ func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
// ignores zero values. // ignores zero values.
func (o SignSSHOptions) match(got SignSSHOptions) error { func (o SignSSHOptions) match(got SignSSHOptions) error {
if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType { if o.CertType != "" && got.CertType != "" && o.CertType != got.CertType {
return errors.Errorf("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType) return errs.Forbidden("ssh certificate type does not match - got %v, want %v", got.CertType, o.CertType)
} }
if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) { if len(o.Principals) > 0 && len(got.Principals) > 0 && !containsAllMembers(o.Principals, got.Principals) {
return errors.Errorf("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals) return errs.Forbidden("ssh certificate principals does not match - got %v, want %v", got.Principals, o.Principals)
} }
if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) { if !o.ValidAfter.IsZero() && !got.ValidAfter.IsZero() && !o.ValidAfter.Equal(&got.ValidAfter) {
return errors.Errorf("ssh certificate valid after does not match - got %v, want %v", got.ValidAfter, o.ValidAfter) return errs.Forbidden("ssh certificate validAfter does not match - got %v, want %v", got.ValidAfter, o.ValidAfter)
} }
if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) { if !o.ValidBefore.IsZero() && !got.ValidBefore.IsZero() && !o.ValidBefore.Equal(&got.ValidBefore) {
return errors.Errorf("ssh certificate valid before does not match - got %v, want %v", got.ValidBefore, o.ValidBefore) return errs.Forbidden("ssh certificate validBefore does not match - got %v, want %v", got.ValidBefore, o.ValidBefore)
} }
return nil return nil
} }
@ -206,7 +211,7 @@ func (m *sshDefaultExtensionModifier) Modify(cert *ssh.Certificate, _ SignSSHOpt
cert.Extensions["permit-user-rc"] = "" cert.Extensions["permit-user-rc"] = ""
return nil return nil
default: default:
return errors.New("ssh certificate type has not been set or is invalid") return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
} }
} }
@ -272,7 +277,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
certValidAfter := time.Unix(int64(cert.ValidAfter), 0) certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
if certValidAfter.After(m.NotAfter) { if certValidAfter.After(m.NotAfter) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
m.NotAfter, certValidAfter) m.NotAfter, certValidAfter)
} }
@ -285,7 +290,7 @@ func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error
} else { } else {
certValidBefore := time.Unix(int64(cert.ValidBefore), 0) certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
if m.NotAfter.Before(certValidBefore) { if m.NotAfter.Before(certValidBefore) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", return errs.Forbidden("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore) m.NotAfter, certValidBefore)
} }
} }
@ -319,11 +324,11 @@ type sshCertOptionsRequireValidator struct {
func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error { func (v *sshCertOptionsRequireValidator) Valid(got SignSSHOptions) error {
switch { switch {
case v.CertType && got.CertType == "": case v.CertType && got.CertType == "":
return errors.New("ssh certificate certType cannot be empty") return errs.BadRequest("ssh certificate certType cannot be empty")
case v.KeyID && got.KeyID == "": case v.KeyID && got.KeyID == "":
return errors.New("ssh certificate keyID cannot be empty") return errs.BadRequest("ssh certificate keyID cannot be empty")
case v.Principals && len(got.Principals) == 0: case v.Principals && len(got.Principals) == 0:
return errors.New("ssh certificate principals cannot be empty") return errs.BadRequest("ssh certificate principals cannot be empty")
default: default:
return nil return nil
} }
@ -354,7 +359,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
case 0: case 0:
return errs.BadRequest("ssh certificate type has not been set") return errs.BadRequest("ssh certificate type has not been set")
default: default:
return errs.BadRequest("unknown ssh certificate type %d", cert.CertType) return errs.BadRequest("ssh certificate has an unknown type '%d'", cert.CertType)
} }
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
@ -363,9 +368,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch { switch {
case dur < min: case dur < min:
return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) return errs.Forbidden("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate: case dur > max+opts.Backdate:
return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) return errs.Forbidden("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }
@ -381,25 +386,25 @@ type sshCertDefaultValidator struct{}
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error { func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
switch { switch {
case len(cert.Nonce) == 0: case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty") return errs.Forbidden("ssh certificate nonce cannot be empty")
case cert.Key == nil: case cert.Key == nil:
return errors.New("ssh certificate key cannot be nil") return errs.Forbidden("ssh certificate key cannot be nil")
case cert.Serial == 0: case cert.Serial == 0:
return errors.New("ssh certificate serial cannot be 0") return errs.Forbidden("ssh certificate serial cannot be 0")
case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert: case cert.CertType != ssh.UserCert && cert.CertType != ssh.HostCert:
return errors.Errorf("ssh certificate has an unknown type: %d", cert.CertType) return errs.Forbidden("ssh certificate has an unknown type '%d'", cert.CertType)
case cert.KeyId == "": case cert.KeyId == "":
return errors.New("ssh certificate key id cannot be empty") return errs.Forbidden("ssh certificate key id cannot be empty")
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return errs.Forbidden("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()): case cert.ValidBefore < uint64(now().Unix()):
return errors.New("ssh certificate validBefore cannot be in the past") return errs.Forbidden("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter: case cert.ValidBefore < cert.ValidAfter:
return errors.New("ssh certificate validBefore cannot be before validAfter") return errs.Forbidden("ssh certificate validBefore cannot be before validAfter")
case cert.SignatureKey == nil: case cert.SignatureKey == nil:
return errors.New("ssh certificate signature key cannot be nil") return errs.Forbidden("ssh certificate signature key cannot be nil")
case cert.Signature == nil: case cert.Signature == nil:
return errors.New("ssh certificate signature cannot be nil") return errs.Forbidden("ssh certificate signature cannot be nil")
default: default:
return nil return nil
} }
@ -409,27 +414,31 @@ func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SignSSHOptions)
type sshDefaultPublicKeyValidator struct{} type sshDefaultPublicKeyValidator struct{}
// Valid checks that certificate request common name matches the one configured. // Valid checks that certificate request common name matches the one configured.
//
// TODO: this is the only validator that checks the key type. We should execute
// this before the signing. We should add a new validations interface or extend
// SSHCertOptionsValidator with the key.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error { func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOptions) error {
if cert.Key == nil { if cert.Key == nil {
return errors.New("ssh certificate key cannot be nil") return errs.BadRequest("ssh certificate key cannot be nil")
} }
switch cert.Key.Type() { switch cert.Key.Type() {
case ssh.KeyAlgoRSA: case ssh.KeyAlgoRSA:
_, in, ok := sshParseString(cert.Key.Marshal()) _, in, ok := sshParseString(cert.Key.Marshal())
if !ok { if !ok {
return errors.New("ssh certificate key is invalid") return errs.BadRequest("ssh certificate key is invalid")
} }
key, err := sshParseRSAPublicKey(in) key, err := sshParseRSAPublicKey(in)
if err != nil { if err != nil {
return err return errs.BadRequestErr(err, "error parsing public key")
} }
if key.Size() < keyutil.MinRSAKeyBytes { if key.Size() < keyutil.MinRSAKeyBytes {
return errors.Errorf("ssh certificate key must be at least %d bits (%d bytes)", return errs.Forbidden("ssh certificate key must be at least %d bits (%d bytes)",
8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes) 8*keyutil.MinRSAKeyBytes, keyutil.MinRSAKeyBytes)
} }
return nil return nil
case ssh.KeyAlgoDSA: case ssh.KeyAlgoDSA:
return errors.New("ssh certificate key algorithm (DSA) is not supported") return errs.BadRequest("ssh certificate key algorithm (DSA) is not supported")
default: default:
return nil return nil
} }

View file

@ -49,14 +49,14 @@ func TestSSHOptions_Modify(t *testing.T) {
return test{ return test{
so: SignSSHOptions{CertType: "foo"}, so: SignSSHOptions{CertType: "foo"},
cert: new(ssh.Certificate), cert: new(ssh.Certificate),
err: errors.Errorf("ssh certificate has an unknown type - foo"), err: errors.Errorf("ssh certificate has an unknown type 'foo'"),
} }
}, },
"fail/validAfter-greater-validBefore": func() test { "fail/validAfter-greater-validBefore": func() test {
return test{ return test{
so: SignSSHOptions{CertType: "user"}, so: SignSSHOptions{CertType: "user"},
cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)}, cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
err: errors.Errorf("ssh certificate valid after cannot be greater than valid before"), err: errors.Errorf("ssh certificate validAfter cannot be greater than validBefore"),
} }
}, },
"ok/user-cert": func() test { "ok/user-cert": func() test {
@ -136,14 +136,14 @@ func TestSSHOptions_Match(t *testing.T) {
return test{ return test{
so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))}, so: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))}, cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
err: errors.Errorf("ssh certificate valid after does not match"), err: errors.Errorf("ssh certificate validAfter does not match"),
} }
}, },
"fail/validBefore": func() test { "fail/validBefore": func() test {
return test{ return test{
so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))}, so: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))}, cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
err: errors.Errorf("ssh certificate valid before does not match"), err: errors.Errorf("ssh certificate validBefore does not match"),
} }
}, },
"ok/original-empty": func() test { "ok/original-empty": func() test {
@ -394,7 +394,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
return test{ return test{
modifier: sshDefaultExtensionModifier{}, modifier: sshDefaultExtensionModifier{},
cert: cert, cert: cert,
err: errors.New("ssh certificate type has not been set or is invalid"), err: errors.New("ssh certificate has an unknown type '3'"),
} }
}, },
"ok/host": func() test { "ok/host": func() test {
@ -518,7 +518,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
"fail/unexpected-cert-type", "fail/unexpected-cert-type",
// UserCert = 1, HostCert = 2 // UserCert = 1, HostCert = 2
&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1}, &ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1},
errors.New("ssh certificate has an unknown type: 3"), errors.New("ssh certificate has an unknown type '3'"),
}, },
{ {
"fail/empty-cert-key-id", "fail/empty-cert-key-id",
@ -725,7 +725,7 @@ func Test_sshCertValidityValidator(t *testing.T) {
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
}, },
SignSSHOptions{}, SignSSHOptions{},
errors.New("unknown ssh certificate type 3"), errors.New("ssh certificate has an unknown type '3'"),
}, },
{ {
"fail/duration<min", "fail/duration<min",

View file

@ -9,7 +9,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
@ -174,7 +173,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// validate the given SSHOptions // validate the given SSHOptions
case provisioner.SSHCertOptionsValidator: case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil { if err := o.Valid(opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH") return nil, errs.BadRequestErr(err, "error validating ssh certificate options")
} }
default: default:
@ -214,7 +213,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Use provisioner modifiers. // Use provisioner modifiers.
for _, m := range mods { for _, m := range mods {
if err := m.Modify(certTpl, opts); err != nil { if err := m.Modify(certTpl, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH") return nil, errs.ForbiddenErr(err, "error creating ssh certificate")
} }
} }
@ -244,7 +243,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// User provisioners validators. // User provisioners validators.
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert, opts); err != nil { if err := v.Valid(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH") return nil, errs.ForbiddenErr(err, "error validating ssh certificate")
} }
} }
@ -382,7 +381,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
// Apply validators from provisioner. // Apply validators from provisioner.
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert, provisioner.SignSSHOptions{Backdate: backdate}); err != nil { if err := v.Valid(cert, provisioner.SignSSHOptions{Backdate: backdate}); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH") return nil, errs.ForbiddenErr(err, "error validating ssh certificate")
} }
} }
@ -407,12 +406,12 @@ func (a *Authority) storeSSHCertificate(cert *ssh.Certificate) error {
// the given certificate. // the given certificate.
func IsValidForAddUser(cert *ssh.Certificate) error { func IsValidForAddUser(cert *ssh.Certificate) error {
if cert.CertType != ssh.UserCert { if cert.CertType != ssh.UserCert {
return errors.New("certificate is not a user certificate") return errs.Forbidden("certificate is not a user certificate")
} }
switch len(cert.ValidPrincipals) { switch len(cert.ValidPrincipals) {
case 0: case 0:
return errors.New("certificate does not have any principals") return errs.Forbidden("certificate does not have any principals")
case 1: case 1:
return nil return nil
case 2: case 2:
@ -421,9 +420,9 @@ func IsValidForAddUser(cert *ssh.Certificate) error {
if strings.Index(cert.ValidPrincipals[1], "@") > 0 { if strings.Index(cert.ValidPrincipals[1], "@") > 0 {
return nil return nil
} }
return errors.New("certificate does not have only one principal") return errs.Forbidden("certificate does not have only one principal")
default: default:
return errors.New("certificate does not have only one principal") return errs.Forbidden("certificate does not have only one principal")
} }
} }
@ -433,7 +432,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
} }
if err := IsValidForAddUser(subject); err != nil { if err := IsValidForAddUser(subject); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSHAddUser") return nil, err
} }
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)

View file

@ -94,7 +94,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
// Validate the given certificate request. // Validate the given certificate request.
case provisioner.CertificateRequestValidator: case provisioner.CertificateRequestValidator:
if err := k.Valid(csr); err != nil { if err := k.Valid(csr); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error validating certificate"),
opts...,
)
} }
// Validates the unsigned certificate template. // Validates the unsigned certificate template.
@ -131,26 +134,38 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
// Set default subject // Set default subject
if err := withDefaultASN1DN(a.config.AuthorityConfig.Template).Modify(leaf, signOpts); err != nil { if err := withDefaultASN1DN(a.config.AuthorityConfig.Template).Modify(leaf, signOpts); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
)
} }
for _, m := range certModifiers { for _, m := range certModifiers {
if err := m.Modify(leaf, signOpts); err != nil { if err := m.Modify(leaf, signOpts); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
)
} }
} }
// Certificate validation. // Certificate validation.
for _, v := range certValidators { for _, v := range certValidators {
if err := v.Valid(leaf, signOpts); err != nil { if err := v.Valid(leaf, signOpts); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error validating certificate"),
opts...,
)
} }
} }
// Certificate modifiers after validation // Certificate modifiers after validation
for _, m := range certEnforcers { for _, m := range certEnforcers {
if err := m.Enforce(leaf); err != nil { if err := m.Enforce(leaf); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.Sign", opts...) return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
)
} }
} }

View file

@ -281,8 +281,8 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: errors.New("authority.Sign: default ASN1DN template cannot be nil"), err: errors.New("default ASN1DN template cannot be nil"),
code: http.StatusUnauthorized, code: http.StatusForbidden,
} }
}, },
"fail create cert": func(t *testing.T) *signTest { "fail create cert": func(t *testing.T) *signTest {
@ -309,8 +309,8 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: _signOpts, signOpts: _signOpts,
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"), err: errors.New("requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
code: http.StatusBadRequest, code: http.StatusForbidden,
} }
}, },
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
@ -322,8 +322,8 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: errors.New("authority.Sign: certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"), err: errors.New("certificate request does not contain the valid DNS names - got [test.smallstep.com smallstep test], want [test.smallstep.com]"),
code: http.StatusUnauthorized, code: http.StatusForbidden,
} }
}, },
"fail rsa key too short": func(t *testing.T) *signTest { "fail rsa key too short": func(t *testing.T) *signTest {
@ -348,8 +348,8 @@ ZYtQ9Ot36qc=
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: errors.New("authority.Sign: rsa key in CSR must be at least 2048 bits (256 bytes)"), err: errors.New("certificate request RSA key must be at least 2048 bits (256 bytes)"),
code: http.StatusUnauthorized, code: http.StatusForbidden,
} }
}, },
"fail store cert in db": func(t *testing.T) *signTest { "fail store cert in db": func(t *testing.T) *signTest {

View file

@ -200,8 +200,8 @@ ZEp7knvU2psWRw==
return &signTest{ return &signTest{
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusUnauthorized, status: http.StatusForbidden,
errMsg: errs.UnauthorizedDefaultMsg, errMsg: errs.ForbiddenPrefix,
} }
}, },
"ok": func(t *testing.T) *signTest { "ok": func(t *testing.T) *signTest {

BIN
docs/images/star.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 89 KiB

View file

@ -169,7 +169,8 @@ func StatusCodeError(code int, e error, opts ...Option) error {
case http.StatusUnauthorized: case http.StatusUnauthorized:
return UnauthorizedErr(e, opts...) return UnauthorizedErr(e, opts...)
case http.StatusForbidden: case http.StatusForbidden:
return ForbiddenErr(e, opts...) opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg))
return NewErr(http.StatusForbidden, e, opts...)
case http.StatusInternalServerError: case http.StatusInternalServerError:
return InternalServerErr(e, opts...) return InternalServerErr(e, opts...)
case http.StatusNotImplemented: case http.StatusNotImplemented:
@ -199,12 +200,18 @@ var (
// BadRequestPrefix is the prefix added to the bad request messages that are // BadRequestPrefix is the prefix added to the bad request messages that are
// directly sent to the cli. // directly sent to the cli.
BadRequestPrefix = "The request could not be completed: " BadRequestPrefix = "The request could not be completed: "
// ForbiddenPrefix is the prefix added to the forbidden messates that are
// sent to the cli.
ForbiddenPrefix = "The request was forbidden by the certificate authority: "
) )
func formatMessage(status int, msg string) string { func formatMessage(status int, msg string) string {
switch status { switch status {
case http.StatusBadRequest: case http.StatusBadRequest:
return BadRequestPrefix + msg + "." return BadRequestPrefix + msg + "."
case http.StatusForbidden:
return ForbiddenPrefix + msg + "."
default: default:
return msg return msg
} }
@ -356,14 +363,12 @@ func UnauthorizedErr(err error, opts ...Option) error {
// Forbidden creates a 403 error with the given format and arguments. // Forbidden creates a 403 error with the given format and arguments.
func Forbidden(format string, args ...interface{}) error { func Forbidden(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(ForbiddenDefaultMsg)) return New(http.StatusForbidden, format, args...)
return Errorf(http.StatusForbidden, format, args...)
} }
// ForbiddenErr returns an 403 error with the given error. // ForbiddenErr returns an 403 error with the given error.
func ForbiddenErr(err error, opts ...Option) error { func ForbiddenErr(err error, format string, args ...interface{}) error {
opts = append(opts, withDefaultMessage(ForbiddenDefaultMsg)) return NewError(http.StatusForbidden, err, format, args...)
return NewErr(http.StatusForbidden, err, opts...)
} }
// NotFound creates a 404 error with the given format and arguments. // NotFound creates a 404 error with the given format and arguments.