Merge pull request #752 from smallstep/errors-bad-request

Bad request errors
This commit is contained in:
Mariano Cano 2021-11-22 13:16:04 -08:00 committed by GitHub
commit 4f84cef0cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 165 additions and 124 deletions

View file

@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r)
if err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
if v := q.Get("limit"); len(v) > 0 {
limit, err = strconv.Atoi(v)
if err != nil {
return "", 0, errors.Wrapf(err, "error converting %s to integer", v)
return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
}
}
return

View file

@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err)
}
expectedError400 := errs.BadRequest("force")
expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force")

View file

@ -18,7 +18,7 @@ func (s *RekeyRequest) Validate() error {
return errs.BadRequest("missing csr")
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
return errs.BadRequestErr(err, "invalid csr")
}
return nil
@ -26,15 +26,14 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate"))
WriteError(w, errs.BadRequest("missing client certificate"))
return
}
var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}

View file

@ -10,7 +10,7 @@ import (
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate"))
WriteError(w, errs.BadRequest("missing client certificate"))
return
}

View file

@ -49,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
@ -80,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or peer certificate"))
WriteError(w, errs.BadRequest("missing ott or client certificate"))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body"))
WriteError(w, errs.BadRequest("serial number in client certificate different than body"))
return
}
// TODO: should probably be checking if the certificate was revoked here.

View file

@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
return errs.BadRequest("missing csr")
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
return errs.BadRequestErr(err, "invalid csr")
}
if s.OTT == "" {
return errs.BadRequest("missing ott")
@ -50,7 +50,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}

View file

@ -49,16 +49,16 @@ type SSHSignRequest struct {
func (s *SSHSignRequest) Validate() error {
switch {
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
return errors.Errorf("unknown certType %s", s.CertType)
return errs.BadRequest("invalid certType '%s'", s.CertType)
case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey")
return errs.BadRequest("missing or empty publicKey")
case s.OTT == "":
return errors.New("missing or empty ott")
return errs.BadRequest("missing or empty ott")
default:
// Validate identity signature if provided
if s.IdentityCSR.CertificateRequest != nil {
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
return errors.Wrap(err, "invalid identityCSR")
return errs.BadRequestErr(err, "invalid identityCSR")
}
}
return nil
@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
case provisioner.SSHUserCert, provisioner.SSHHostCert:
return nil
default:
return errors.Errorf("unsupported type %s", r.Type)
return errs.BadRequest("invalid type '%s'", r.Type)
}
}
@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
func (r *SSHCheckPrincipalRequest) Validate() error {
switch {
case r.Type != provisioner.SSHHostCert:
return errors.Errorf("unsupported type %s", r.Type)
return errs.BadRequest("unsupported type '%s'", r.Type)
case r.Principal == "":
return errors.New("missing or empty principal")
return errs.BadRequest("missing or empty principal")
default:
return nil
}
@ -232,7 +232,7 @@ type SSHBastionRequest struct {
// Validate checks the values of the SSHBastionRequest.
func (r *SSHBastionRequest) Validate() error {
if r.Hostname == "" {
return errors.New("missing or empty hostname")
return errs.BadRequest("missing or empty hostname")
}
return nil
}
@ -250,19 +250,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
return
}
@ -270,7 +270,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey"))
WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return
}
}
@ -394,11 +394,11 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -426,11 +426,11 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -465,11 +465,11 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}

View file

@ -4,7 +4,6 @@ import (
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
func (s *SSHRekeyRequest) Validate() error {
switch {
case s.OTT == "":
return errors.New("missing or empty ott")
return errs.BadRequest("missing or empty ott")
case len(s.PublicKey) == 0:
return errors.New("missing or empty public key")
return errs.BadRequest("missing or empty public key")
default:
return nil
}
@ -40,19 +39,19 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey"))
WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
return
}

View file

@ -19,7 +19,7 @@ type SSHRenewRequest struct {
func (s *SSHRenewRequest) Validate() error {
switch {
case s.OTT == "":
return errors.New("missing or empty ott")
return errs.BadRequest("missing or empty ott")
default:
return nil
}
@ -37,13 +37,13 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}

View file

@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}

View file

@ -93,7 +93,7 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error decoding json")
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
@ -103,7 +103,7 @@ func ReadJSON(r io.Reader, v interface{}) error {
func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error reading request body")
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

View file

@ -228,7 +228,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Use options in the token.
if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "jwk.AuthorizeSSHSign")
return nil, errs.BadRequestErr(err, err.Error())
}
}
if opts.KeyID != "" {

View file

@ -8,9 +8,7 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"reflect"
"time"
@ -372,17 +370,6 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
return &validityValidator{min: min, max: max}
}
// TODO(mariano): refactor errs package to allow sending real errors to the
// user.
func badRequest(format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &errs.Error{
Status: http.StatusBadRequest,
Msg: msg,
Err: errors.New(msg),
}
}
// Valid validates the certificate validity settings (notBefore/notAfter) and
// total duration.
func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
@ -395,20 +382,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
d := na.Sub(nb)
if na.Before(now) {
return badRequest("notAfter cannot be in the past; na=%v", na)
return errs.BadRequest("notAfter cannot be in the past; na=%v", na)
}
if na.Before(nb) {
return badRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
}
if d < v.min {
return badRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
}
// NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough.
if d > v.max+o.Backdate {
return badRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
}
return nil
}

View file

@ -9,6 +9,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil"
"golang.org/x/crypto/ssh"
)
@ -55,7 +56,7 @@ type SignSSHOptions struct {
// Validate validates the given SignSSHOptions.
func (o SignSSHOptions) Validate() error {
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
return errors.Errorf("unknown certType %s", o.CertType)
return errs.BadRequest("unknown certificate type '%s'", o.CertType)
}
return nil
}
@ -335,11 +336,11 @@ type sshCertValidityValidator struct {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error {
switch {
case cert.ValidAfter == 0:
return badRequest("ssh certificate validAfter cannot be 0")
return errs.BadRequest("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()):
return badRequest("ssh certificate validBefore cannot be in the past")
return errs.BadRequest("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter:
return badRequest("ssh certificate validBefore cannot be before validAfter")
return errs.BadRequest("ssh certificate validBefore cannot be before validAfter")
}
var min, max time.Duration
@ -351,9 +352,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
min = v.MinHostSSHCertDuration()
max = v.MaxHostSSHCertDuration()
case 0:
return badRequest("ssh certificate type has not been set")
return errs.BadRequest("ssh certificate type has not been set")
default:
return badRequest("unknown ssh certificate type %d", cert.CertType)
return errs.BadRequest("unknown ssh certificate type %d", cert.CertType)
}
// To not take into account the backdate, time.Now() will be used to
@ -362,9 +363,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch {
case dur < min:
return badRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate:
return badRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default:
return nil
}

View file

@ -191,8 +191,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
}
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " +
"must be equivalent to sshpop certificate serial number")
return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number")
}
return nil
}
@ -205,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate")
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, nil
@ -220,7 +219,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
}
if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate")
return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
}
return claims.sshCert, []SignOption{
// Validate public key

View file

@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"),
err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
}
},
"ok": func(t *testing.T) test {
@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"),
err: errors.New("sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {
@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"),
err: errors.New("sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {

View file

@ -271,7 +271,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Use options in the token.
if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "x5c.AuthorizeSSHSign")
return nil, errs.BadRequestErr(err, err.Error())
}
}
if opts.KeyID != "" {

View file

@ -69,7 +69,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
ts = a.templates.SSH.Host
}
default:
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ)
return nil, errs.BadRequest("invalid certificate type '%s'", typ)
}
// Merge user and default data
@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// Check for required variables.
if err := t.ValidateRequiredData(data); err != nil {
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err))
return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
}
o, err := t.Output(mergedData)
@ -151,7 +151,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Validate given options.
if err := opts.Validate(); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
return nil, err
}
// Set backdate with the configured value
@ -194,8 +194,8 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
certificate, err := sshutil.NewCertificate(cr, certOptions...)
if err != nil {
if _, ok := err.(*sshutil.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err,
errs.WithMessage(err.Error()),
return nil, errs.ApplyOptions(
errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
)
}
@ -208,7 +208,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Use SignSSHOptions to modify the certificate validity. It will be later
// checked or set if not defined.
if err := opts.ModifyValidity(certTpl); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
return nil, errs.BadRequestErr(err, err.Error())
}
// Use provisioner modifiers.
@ -258,7 +258,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period")
return nil, errs.BadRequest("cannot renew a certificate without validity period")
}
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
@ -329,7 +329,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
}
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period")
return nil, errs.BadRequest("cannot rekey a certificate without validity period")
}
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
@ -369,7 +369,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
}
signer = a.sshCAHostCertSignKey
default:
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType)
return nil, errs.BadRequest("unexpected certificate type '%d'", cert.CertType)
}
var err error

View file

@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest,
}
},
@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"),
err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest,
}
},
@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; unexpected ssh certificate type: 0"),
err: errors.New("unexpected certificate type '0'"),
code: http.StatusBadRequest,
}
},

View file

@ -76,7 +76,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
if err := csr.CheckSignature(); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...)
return nil, errs.ApplyOptions(
errs.BadRequestErr(err, "invalid certificate request"),
opts...,
)
}
// Set backdate with the configured value
@ -114,8 +117,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
cert, err := x509util.NewCertificate(csr, certOptions...)
if err != nil {
if _, ok := err.(*x509util.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err,
errs.WithMessage(err.Error()),
return nil, errs.ApplyOptions(
errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts),
)
@ -433,8 +436,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
case db.ErrNotImplemented:
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
case db.ErrAlreadyExists:
return errs.BadRequest("authority.Revoke; certificate with serial "+
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...)
return errs.ApplyOptions(
errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial),
opts...,
)
default:
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
}

View file

@ -256,7 +256,7 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr,
extraOpts: extraOpts,
signOpts: signOpts,
err: errors.New("authority.Sign; invalid certificate request"),
err: errors.New("invalid certificate request"),
code: http.StatusBadRequest,
}
},
@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
Reason: reason,
OTT: raw,
},
err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"),
err: errors.New("certificate with serial number 'sn' is already revoked"),
code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)

View file

@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
ca: ca,
body: "invalid json",
status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg,
errMsg: errs.BadRequestPrefix,
}
},
"fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
ca: ca,
body: string(body),
status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg,
errMsg: errs.BadRequestPrefix,
}
},
"fail unauthorized-ott": func(t *testing.T) *signTest {
@ -588,7 +588,7 @@ func TestCARenew(t *testing.T) {
ca: ca,
tlsConnState: nil,
status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg,
errMsg: errs.BadRequestPrefix,
}
},
"request-missing-peer-certificate": func(t *testing.T) *renewTest {
@ -596,7 +596,7 @@ func TestCARenew(t *testing.T) {
ca: ca,
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg,
errMsg: errs.BadRequestPrefix,
}
},
"success": func(t *testing.T) *renewTest {

View file

@ -662,7 +662,7 @@ retry:
// verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw)
if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
return nil, errs.BadRequest("root certificate fingerprint does not match")
}
return &root, nil
}
@ -1108,8 +1108,7 @@ retry:
retried = true
goto retry
}
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
return nil, readError(resp.Body)
}
var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil {

View file

@ -337,8 +337,8 @@ func TestClient_Sign(t *testing.T) {
}{
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")},
}
srv := httptest.NewServer(nil)
@ -410,7 +410,7 @@ func TestClient_Revoke(t *testing.T) {
}{
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
@ -455,7 +455,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got)
}
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error())
assert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -484,8 +484,8 @@ func TestClient_Renew(t *testing.T) {
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
@ -519,7 +519,7 @@ func TestClient_Renew(t *testing.T) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -553,8 +553,8 @@ func TestClient_Rekey(t *testing.T) {
}{
{"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
@ -588,7 +588,7 @@ func TestClient_Rekey(t *testing.T) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -735,7 +735,7 @@ func TestClient_Roots(t *testing.T) {
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
@ -768,7 +768,7 @@ func TestClient_Roots(t *testing.T) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -1016,7 +1016,7 @@ func TestClient_SSHBastion(t *testing.T) {
}{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
@ -1050,7 +1050,7 @@ func TestClient_SSHBastion(t *testing.T) {
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error())
assert.HasPrefix(t, err.Error(), tt.err.Error())
}
default:
if !reflect.DeepEqual(got, tt.response) {

View file

@ -25,7 +25,7 @@ type Option func(e *Error) error
// message only if it is empty.
func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
if len(e.Msg) > 0 {
if e.Msg != "" {
return e
}
e.Msg = fmt.Sprintf(format, args...)
@ -164,7 +164,8 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error {
switch code {
case http.StatusBadRequest:
return BadRequestErr(e, opts...)
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, e, opts...)
case http.StatusUnauthorized:
return UnauthorizedErr(e, opts...)
case http.StatusForbidden:
@ -194,6 +195,21 @@ var (
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
)
var (
// BadRequestPrefix is the prefix added to the bad request messages that are
// directly sent to the cli.
BadRequestPrefix = "The request could not be completed: "
)
func formatMessage(status int, msg string) string {
switch status {
case http.StatusBadRequest:
return BadRequestPrefix + msg + "."
default:
return msg
}
}
// splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
@ -218,6 +234,32 @@ func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
return args[:indexOptionStart], opts
}
// New creates a new http error with the given status and message.
func New(status int, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: errors.New(msg),
}
}
// NewError creates a new http error with the given error and message.
func NewError(status int, err error, format string, args ...interface{}) error {
if _, ok := err.(*Error); ok {
return err
}
msg := fmt.Sprintf(format, args...)
if _, ok := err.(StackTracer); !ok {
err = errors.Wrap(err, msg)
}
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: err,
}
}
// NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error {
@ -254,6 +296,18 @@ func Errorf(code int, format string, args ...interface{}) error {
return e
}
// ApplyOptions applies the given options to the error if is the type *Error.
// TODO(mariano): try to get rid of this.
func ApplyOptions(err error, opts ...interface{}) error {
if e, ok := err.(*Error); ok {
_, o := splitOptionArgs(opts)
for _, fn := range o {
fn(e)
}
}
return err
}
// InternalServer creates a 500 error with the given format and arguments.
func InternalServer(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
@ -280,14 +334,12 @@ func NotImplementedErr(err error, opts ...Option) error {
// BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg))
return Errorf(http.StatusBadRequest, format, args...)
return New(http.StatusBadRequest, format, args...)
}
// BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, err, opts...)
func BadRequestErr(err error, format string, args ...interface{}) error {
return NewError(http.StatusBadRequest, err, format, args...)
}
// Unauthorized creates a 401 error with the given format and arguments.