Merge pull request #752 from smallstep/errors-bad-request

Bad request errors
This commit is contained in:
Mariano Cano 2021-11-22 13:16:04 -08:00 committed by GitHub
commit 4f84cef0cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 165 additions and 124 deletions

View file

@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
if v := q.Get("limit"); len(v) > 0 { if v := q.Get("limit"); len(v) > 0 {
limit, err = strconv.Atoi(v) limit, err = strconv.Atoi(v)
if err != nil { if err != nil {
return "", 0, errors.Wrapf(err, "error converting %s to integer", v) return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
} }
} }
return return

View file

@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expectedError400 := errs.BadRequest("force") expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
expectedError400Bytes, err := json.Marshal(expectedError400) expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err) assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force") expectedError500 := errs.InternalServer("force")

View file

@ -18,7 +18,7 @@ func (s *RekeyRequest) Validate() error {
return errs.BadRequest("missing csr") return errs.BadRequest("missing csr")
} }
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr") return errs.BadRequestErr(err, "invalid csr")
} }
return nil return nil
@ -26,15 +26,14 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate")) WriteError(w, errs.BadRequest("missing client certificate"))
return return
} }
var body RekeyRequest var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -10,7 +10,7 @@ import (
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate")) WriteError(w, errs.BadRequest("missing client certificate"))
return return
} }

View file

@ -49,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -80,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or peer certificate")) WriteError(w, errs.BadRequest("missing ott or client certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body")) WriteError(w, errs.BadRequest("serial number in client certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.

View file

@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
return errs.BadRequest("missing csr") return errs.BadRequest("missing csr")
} }
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr") return errs.BadRequestErr(err, "invalid csr")
} }
if s.OTT == "" { if s.OTT == "" {
return errs.BadRequest("missing ott") return errs.BadRequest("missing ott")
@ -50,7 +50,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -49,16 +49,16 @@ type SSHSignRequest struct {
func (s *SSHSignRequest) Validate() error { func (s *SSHSignRequest) Validate() error {
switch { switch {
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
return errors.Errorf("unknown certType %s", s.CertType) return errs.BadRequest("invalid certType '%s'", s.CertType)
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey") return errs.BadRequest("missing or empty publicKey")
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
default: default:
// Validate identity signature if provided // Validate identity signature if provided
if s.IdentityCSR.CertificateRequest != nil { if s.IdentityCSR.CertificateRequest != nil {
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
return errors.Wrap(err, "invalid identityCSR") return errs.BadRequestErr(err, "invalid identityCSR")
} }
} }
return nil return nil
@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
case provisioner.SSHUserCert, provisioner.SSHHostCert: case provisioner.SSHUserCert, provisioner.SSHHostCert:
return nil return nil
default: default:
return errors.Errorf("unsupported type %s", r.Type) return errs.BadRequest("invalid type '%s'", r.Type)
} }
} }
@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
func (r *SSHCheckPrincipalRequest) Validate() error { func (r *SSHCheckPrincipalRequest) Validate() error {
switch { switch {
case r.Type != provisioner.SSHHostCert: case r.Type != provisioner.SSHHostCert:
return errors.Errorf("unsupported type %s", r.Type) return errs.BadRequest("unsupported type '%s'", r.Type)
case r.Principal == "": case r.Principal == "":
return errors.New("missing or empty principal") return errs.BadRequest("missing or empty principal")
default: default:
return nil return nil
} }
@ -232,7 +232,7 @@ type SSHBastionRequest struct {
// Validate checks the values of the SSHBastionRequest. // Validate checks the values of the SSHBastionRequest.
func (r *SSHBastionRequest) Validate() error { func (r *SSHBastionRequest) Validate() error {
if r.Hostname == "" { if r.Hostname == "" {
return errors.New("missing or empty hostname") return errs.BadRequest("missing or empty hostname")
} }
return nil return nil
} }
@ -250,19 +250,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
@ -270,7 +270,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey")) WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -394,11 +394,11 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -426,11 +426,11 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -465,11 +465,11 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -4,7 +4,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
func (s *SSHRekeyRequest) Validate() error { func (s *SSHRekeyRequest) Validate() error {
switch { switch {
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty public key") return errs.BadRequest("missing or empty public key")
default: default:
return nil return nil
} }
@ -40,19 +39,19 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }

View file

@ -19,7 +19,7 @@ type SSHRenewRequest struct {
func (s *SSHRenewRequest) Validate() error { func (s *SSHRenewRequest) Validate() error {
switch { switch {
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
default: default:
return nil return nil
} }
@ -37,13 +37,13 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -93,7 +93,7 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
// pointed by v. // pointed by v.
func ReadJSON(r io.Reader, v interface{}) error { func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil { if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error decoding json") return errs.BadRequestErr(err, "error decoding json")
} }
return nil return nil
} }
@ -103,7 +103,7 @@ func ReadJSON(r io.Reader, v interface{}) error {
func ReadProtoJSON(r io.Reader, m proto.Message) error { func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error reading request body") return errs.BadRequestErr(err, "error reading request body")
} }
return protojson.Unmarshal(data, m) return protojson.Unmarshal(data, m)
} }

View file

@ -228,7 +228,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Use options in the token. // Use options in the token.
if opts.CertType != "" { if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "jwk.AuthorizeSSHSign") return nil, errs.BadRequestErr(err, err.Error())
} }
} }
if opts.KeyID != "" { if opts.KeyID != "" {

View file

@ -8,9 +8,7 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"time" "time"
@ -372,17 +370,6 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
return &validityValidator{min: min, max: max} return &validityValidator{min: min, max: max}
} }
// TODO(mariano): refactor errs package to allow sending real errors to the
// user.
func badRequest(format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &errs.Error{
Status: http.StatusBadRequest,
Msg: msg,
Err: errors.New(msg),
}
}
// Valid validates the certificate validity settings (notBefore/notAfter) and // Valid validates the certificate validity settings (notBefore/notAfter) and
// total duration. // total duration.
func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
@ -395,20 +382,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
d := na.Sub(nb) d := na.Sub(nb)
if na.Before(now) { if na.Before(now) {
return badRequest("notAfter cannot be in the past; na=%v", na) return errs.BadRequest("notAfter cannot be in the past; na=%v", na)
} }
if na.Before(nb) { if na.Before(nb) {
return badRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return badRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min) return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
} }
// NOTE: this check is not "technically correct". We're allowing the max // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will // duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not // be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return badRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate) return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
} }
return nil return nil
} }

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -55,7 +56,7 @@ type SignSSHOptions struct {
// Validate validates the given SignSSHOptions. // Validate validates the given SignSSHOptions.
func (o SignSSHOptions) Validate() error { func (o SignSSHOptions) Validate() error {
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
return errors.Errorf("unknown certType %s", o.CertType) return errs.BadRequest("unknown certificate type '%s'", o.CertType)
} }
return nil return nil
} }
@ -335,11 +336,11 @@ type sshCertValidityValidator struct {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return badRequest("ssh certificate validAfter cannot be 0") return errs.BadRequest("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()): case cert.ValidBefore < uint64(now().Unix()):
return badRequest("ssh certificate validBefore cannot be in the past") return errs.BadRequest("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter: case cert.ValidBefore < cert.ValidAfter:
return badRequest("ssh certificate validBefore cannot be before validAfter") return errs.BadRequest("ssh certificate validBefore cannot be before validAfter")
} }
var min, max time.Duration var min, max time.Duration
@ -351,9 +352,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
min = v.MinHostSSHCertDuration() min = v.MinHostSSHCertDuration()
max = v.MaxHostSSHCertDuration() max = v.MaxHostSSHCertDuration()
case 0: case 0:
return badRequest("ssh certificate type has not been set") return errs.BadRequest("ssh certificate type has not been set")
default: default:
return badRequest("unknown ssh certificate type %d", cert.CertType) return errs.BadRequest("unknown ssh certificate type %d", cert.CertType)
} }
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
@ -362,9 +363,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch { switch {
case dur < min: case dur < min:
return badRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min) return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate: case dur > max+opts.Backdate:
return badRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate) return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }

View file

@ -191,8 +191,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " + return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number")
"must be equivalent to sshpop certificate serial number")
} }
return nil return nil
} }
@ -205,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate") return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, nil return claims.sshCert, nil
@ -220,7 +219,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate") return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, []SignOption{ return claims.sshCert, []SignOption{
// Validate public key // Validate public key

View file

@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"), err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"), err: errors.New("sshpop certificate must be a host ssh certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"), err: errors.New("sshpop certificate must be a host ssh certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {

View file

@ -271,7 +271,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Use options in the token. // Use options in the token.
if opts.CertType != "" { if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "x5c.AuthorizeSSHSign") return nil, errs.BadRequestErr(err, err.Error())
} }
} }
if opts.KeyID != "" { if opts.KeyID != "" {

View file

@ -69,7 +69,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
ts = a.templates.SSH.Host ts = a.templates.SSH.Host
} }
default: default:
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ) return nil, errs.BadRequest("invalid certificate type '%s'", typ)
} }
// Merge user and default data // Merge user and default data
@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// Check for required variables. // Check for required variables.
if err := t.ValidateRequiredData(data); err != nil { if err := t.ValidateRequiredData(data); err != nil {
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err)) return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
} }
o, err := t.Output(mergedData) o, err := t.Output(mergedData)
@ -151,7 +151,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Validate given options. // Validate given options.
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") return nil, err
} }
// Set backdate with the configured value // Set backdate with the configured value
@ -194,8 +194,8 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
certificate, err := sshutil.NewCertificate(cr, certOptions...) certificate, err := sshutil.NewCertificate(cr, certOptions...)
if err != nil { if err != nil {
if _, ok := err.(*sshutil.TemplateError); ok { if _, ok := err.(*sshutil.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err, return nil, errs.ApplyOptions(
errs.WithMessage(err.Error()), errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
) )
} }
@ -208,7 +208,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Use SignSSHOptions to modify the certificate validity. It will be later // Use SignSSHOptions to modify the certificate validity. It will be later
// checked or set if not defined. // checked or set if not defined.
if err := opts.ModifyValidity(certTpl); err != nil { if err := opts.ModifyValidity(certTpl); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") return nil, errs.BadRequestErr(err, err.Error())
} }
// Use provisioner modifiers. // Use provisioner modifiers.
@ -258,7 +258,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period") return nil, errs.BadRequest("cannot renew a certificate without validity period")
} }
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
@ -329,7 +329,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
} }
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") return nil, errs.BadRequest("cannot rekey a certificate without validity period")
} }
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
@ -369,7 +369,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType) return nil, errs.BadRequest("unexpected certificate type '%d'", cert.CertType)
} }
var err error var err error

View file

@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{}, cert: &ssh.Certificate{},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"), err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"), err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), err: errors.New("unexpected certificate type '0'"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },

View file

@ -76,7 +76,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
if err := csr.CheckSignature(); err != nil { if err := csr.CheckSignature(); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...) return nil, errs.ApplyOptions(
errs.BadRequestErr(err, "invalid certificate request"),
opts...,
)
} }
// Set backdate with the configured value // Set backdate with the configured value
@ -114,8 +117,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
cert, err := x509util.NewCertificate(csr, certOptions...) cert, err := x509util.NewCertificate(csr, certOptions...)
if err != nil { if err != nil {
if _, ok := err.(*x509util.TemplateError); ok { if _, ok := err.(*x509util.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err, return nil, errs.ApplyOptions(
errs.WithMessage(err.Error()), errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("csr", csr), errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
) )
@ -433,8 +436,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
case db.ErrNotImplemented: case db.ErrNotImplemented:
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
case db.ErrAlreadyExists: case db.ErrAlreadyExists:
return errs.BadRequest("authority.Revoke; certificate with serial "+ return errs.ApplyOptions(
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...) errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial),
opts...,
)
default: default:
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
} }

View file

@ -256,7 +256,7 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: errors.New("authority.Sign; invalid certificate request"), err: errors.New("invalid certificate request"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
Reason: reason, Reason: reason,
OTT: raw, OTT: raw,
}, },
err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"), err: errors.New("certificate with serial number 'sn' is already revoked"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) { checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw) assert.Equals(t, err.Details["token"], raw)

View file

@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
ca: ca, ca: ca,
body: "invalid json", body: "invalid json",
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"fail invalid-csr-sig": func(t *testing.T) *signTest { "fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"fail unauthorized-ott": func(t *testing.T) *signTest { "fail unauthorized-ott": func(t *testing.T) *signTest {
@ -588,7 +588,7 @@ func TestCARenew(t *testing.T) {
ca: ca, ca: ca,
tlsConnState: nil, tlsConnState: nil,
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"request-missing-peer-certificate": func(t *testing.T) *renewTest { "request-missing-peer-certificate": func(t *testing.T) *renewTest {
@ -596,7 +596,7 @@ func TestCARenew(t *testing.T) {
ca: ca, ca: ca,
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"success": func(t *testing.T) *renewTest { "success": func(t *testing.T) *renewTest {

View file

@ -662,7 +662,7 @@ retry:
// verify the sha256 // verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw) sum := sha256.Sum256(root.RootPEM.Raw)
if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) { if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") return nil, errs.BadRequest("root certificate fingerprint does not match")
} }
return &root, nil return &root, nil
} }
@ -1108,8 +1108,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body)
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil { if err := readJSON(resp.Body, &check); err != nil {

View file

@ -337,8 +337,8 @@ func TestClient_Sign(t *testing.T) {
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -410,7 +410,7 @@ func TestClient_Revoke(t *testing.T) {
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -455,7 +455,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got) t.Errorf("Client.Revoke() = %v, want nil", got)
} }
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -484,8 +484,8 @@ func TestClient_Renew(t *testing.T) {
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -519,7 +519,7 @@ func TestClient_Renew(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -553,8 +553,8 @@ func TestClient_Rekey(t *testing.T) {
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -588,7 +588,7 @@ func TestClient_Rekey(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -735,7 +735,7 @@ func TestClient_Roots(t *testing.T) {
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -768,7 +768,7 @@ func TestClient_Roots(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response) t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -1016,7 +1016,7 @@ func TestClient_SSHBastion(t *testing.T) {
}{ }{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -1050,7 +1050,7 @@ func TestClient_SSHBastion(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {

View file

@ -25,7 +25,7 @@ type Option func(e *Error) error
// message only if it is empty. // message only if it is empty.
func withDefaultMessage(format string, args ...interface{}) Option { func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error { return func(e *Error) error {
if len(e.Msg) > 0 { if e.Msg != "" {
return e return e
} }
e.Msg = fmt.Sprintf(format, args...) e.Msg = fmt.Sprintf(format, args...)
@ -164,7 +164,8 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error { func StatusCodeError(code int, e error, opts ...Option) error {
switch code { switch code {
case http.StatusBadRequest: case http.StatusBadRequest:
return BadRequestErr(e, opts...) opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, e, opts...)
case http.StatusUnauthorized: case http.StatusUnauthorized:
return UnauthorizedErr(e, opts...) return UnauthorizedErr(e, opts...)
case http.StatusForbidden: case http.StatusForbidden:
@ -194,6 +195,21 @@ var (
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
) )
var (
// BadRequestPrefix is the prefix added to the bad request messages that are
// directly sent to the cli.
BadRequestPrefix = "The request could not be completed: "
)
func formatMessage(status int, msg string) string {
switch status {
case http.StatusBadRequest:
return BadRequestPrefix + msg + "."
default:
return msg
}
}
// splitOptionArgs splits the variadic length args into string formatting args // splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error. // and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
@ -218,6 +234,32 @@ func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
return args[:indexOptionStart], opts return args[:indexOptionStart], opts
} }
// New creates a new http error with the given status and message.
func New(status int, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: errors.New(msg),
}
}
// NewError creates a new http error with the given error and message.
func NewError(status int, err error, format string, args ...interface{}) error {
if _, ok := err.(*Error); ok {
return err
}
msg := fmt.Sprintf(format, args...)
if _, ok := err.(StackTracer); !ok {
err = errors.Wrap(err, msg)
}
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: err,
}
}
// NewErr returns a new Error. If the given error implements the StatusCoder // NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status. // interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error { func NewErr(status int, err error, opts ...Option) error {
@ -254,6 +296,18 @@ func Errorf(code int, format string, args ...interface{}) error {
return e return e
} }
// ApplyOptions applies the given options to the error if is the type *Error.
// TODO(mariano): try to get rid of this.
func ApplyOptions(err error, opts ...interface{}) error {
if e, ok := err.(*Error); ok {
_, o := splitOptionArgs(opts)
for _, fn := range o {
fn(e)
}
}
return err
}
// InternalServer creates a 500 error with the given format and arguments. // InternalServer creates a 500 error with the given format and arguments.
func InternalServer(format string, args ...interface{}) error { func InternalServer(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
@ -280,14 +334,12 @@ func NotImplementedErr(err error, opts ...Option) error {
// BadRequest creates a 400 error with the given format and arguments. // BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error { func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg)) return New(http.StatusBadRequest, format, args...)
return Errorf(http.StatusBadRequest, format, args...)
} }
// BadRequestErr returns an 400 error with the given error. // BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error { func BadRequestErr(err error, format string, args ...interface{}) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) return NewError(http.StatusBadRequest, err, format, args...)
return NewErr(http.StatusBadRequest, err, opts...)
} }
// Unauthorized creates a 401 error with the given format and arguments. // Unauthorized creates a 401 error with the given format and arguments.