Modify errs.BadRequestErr() to always return an error to the client.

This commit is contained in:
Mariano Cano 2021-11-18 18:17:36 -08:00
parent 8ce807a6cb
commit 8c8db0d4b7
14 changed files with 65 additions and 46 deletions

View file

@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
if v := q.Get("limit"); len(v) > 0 { if v := q.Get("limit"); len(v) > 0 {
limit, err = strconv.Atoi(v) limit, err = strconv.Atoi(v)
if err != nil { if err != nil {
return "", 0, errors.Wrapf(err, "error converting %s to integer", v) return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
} }
} }
return return

View file

@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) {
fields fields fields fields
err error err error
}{ }{
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")}, {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")}, {"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")}, {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expectedError400 := errs.BadRequestErr(errors.New("force")) expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
expectedError400Bytes, err := json.Marshal(expectedError400) expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err) assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force") expectedError500 := errs.InternalServer("force")

View file

@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) {
tests := map[string]test{ tests := map[string]test{
"error/missing serial": { "error/missing serial": {
rr: &RevokeRequest{}, rr: &RevokeRequest{},
err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest}, err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
}, },
"error/bad reasonCode": { "error/bad reasonCode": {
rr: &RevokeRequest{ rr: &RevokeRequest{
@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) {
ReasonCode: 15, ReasonCode: 15,
Passive: true, Passive: true,
}, },
err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest}, err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
}, },
"error/non-passive not implemented": { "error/non-passive not implemented": {
rr: &RevokeRequest{ rr: &RevokeRequest{

View file

@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
return errs.BadRequest("missing csr") return errs.BadRequest("missing csr")
} }
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr") return errs.BadRequestErr(err, "invalid csr")
} }
if s.OTT == "" { if s.OTT == "" {
return errs.BadRequest("missing ott") return errs.BadRequest("missing ott")
@ -50,7 +50,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -49,16 +49,16 @@ type SSHSignRequest struct {
func (s *SSHSignRequest) Validate() error { func (s *SSHSignRequest) Validate() error {
switch { switch {
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
return errors.Errorf("unknown certType %s", s.CertType) return errs.BadRequest("invalid certType '%s'", s.CertType)
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey") return errs.BadRequest("missing or empty publicKey")
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
default: default:
// Validate identity signature if provided // Validate identity signature if provided
if s.IdentityCSR.CertificateRequest != nil { if s.IdentityCSR.CertificateRequest != nil {
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
return errors.Wrap(err, "invalid identityCSR") return errs.BadRequestErr(err, "invalid identityCSR")
} }
} }
return nil return nil
@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
case provisioner.SSHUserCert, provisioner.SSHHostCert: case provisioner.SSHUserCert, provisioner.SSHHostCert:
return nil return nil
default: default:
return errors.Errorf("unsupported type %s", r.Type) return errs.BadRequest("invalid type '%s'", r.Type)
} }
} }
@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
func (r *SSHCheckPrincipalRequest) Validate() error { func (r *SSHCheckPrincipalRequest) Validate() error {
switch { switch {
case r.Type != provisioner.SSHHostCert: case r.Type != provisioner.SSHHostCert:
return errors.Errorf("unsupported type %s", r.Type) return errs.BadRequest("unsupported type '%s'", r.Type)
case r.Principal == "": case r.Principal == "":
return errors.New("missing or empty principal") return errs.BadRequest("missing or empty principal")
default: default:
return nil return nil
} }
@ -232,7 +232,7 @@ type SSHBastionRequest struct {
// Validate checks the values of the SSHBastionRequest. // Validate checks the values of the SSHBastionRequest.
func (r *SSHBastionRequest) Validate() error { func (r *SSHBastionRequest) Validate() error {
if r.Hostname == "" { if r.Hostname == "" {
return errors.New("missing or empty hostname") return errs.BadRequest("missing or empty hostname")
} }
return nil return nil
} }
@ -256,7 +256,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -398,7 +398,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -430,7 +430,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -469,7 +469,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -4,7 +4,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
func (s *SSHRekeyRequest) Validate() error { func (s *SSHRekeyRequest) Validate() error {
switch { switch {
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty public key") return errs.BadRequest("missing or empty public key")
default: default:
return nil return nil
} }
@ -46,7 +45,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -19,7 +19,7 @@ type SSHRenewRequest struct {
func (s *SSHRenewRequest) Validate() error { func (s *SSHRenewRequest) Validate() error {
switch { switch {
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
default: default:
return nil return nil
} }
@ -43,7 +43,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."), err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), err: errors.New("sshpop certificate must be a host ssh certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), err: errors.New("sshpop certificate must be a host ssh certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {

View file

@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// Check for required variables. // Check for required variables.
if err := t.ValidateRequiredData(data); err != nil { if err := t.ValidateRequiredData(data); err != nil {
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err)) return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
} }
o, err := t.Output(mergedData) o, err := t.Output(mergedData)

View file

@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{}, cert: &ssh.Certificate{},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("The request could not be completed: unexpected certificate type '0'."), err: errors.New("unexpected certificate type '0'"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },

View file

@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
Reason: reason, Reason: reason,
OTT: raw, OTT: raw,
}, },
err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"), err: errors.New("certificate with serial number 'sn' is already revoked"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) { checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw) assert.Equals(t, err.Details["token"], raw)

View file

@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
ca: ca, ca: ca,
body: "invalid json", body: "invalid json",
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"fail invalid-csr-sig": func(t *testing.T) *signTest { "fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"fail unauthorized-ott": func(t *testing.T) *signTest { "fail unauthorized-ott": func(t *testing.T) *signTest {

View file

@ -1108,8 +1108,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body)
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil { if err := readJSON(resp.Body, &check); err != nil {

View file

@ -25,7 +25,7 @@ type Option func(e *Error) error
// message only if it is empty. // message only if it is empty.
func withDefaultMessage(format string, args ...interface{}) Option { func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error { return func(e *Error) error {
if len(e.Msg) > 0 { if e.Msg != "" {
return e return e
} }
e.Msg = fmt.Sprintf(format, args...) e.Msg = fmt.Sprintf(format, args...)
@ -164,7 +164,8 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error { func StatusCodeError(code int, e error, opts ...Option) error {
switch code { switch code {
case http.StatusBadRequest: case http.StatusBadRequest:
return BadRequestErr(e, opts...) opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, e, opts...)
case http.StatusUnauthorized: case http.StatusUnauthorized:
return UnauthorizedErr(e, opts...) return UnauthorizedErr(e, opts...)
case http.StatusForbidden: case http.StatusForbidden:
@ -200,6 +201,15 @@ var (
BadRequestPrefix = "The request could not be completed: " BadRequestPrefix = "The request could not be completed: "
) )
func formatMessage(status int, msg string) string {
switch status {
case http.StatusBadRequest:
return BadRequestPrefix + msg + "."
default:
return msg
}
}
// splitOptionArgs splits the variadic length args into string formatting args // splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error. // and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
@ -229,11 +239,24 @@ func New(status int, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf(format, args...)
return &Error{ return &Error{
Status: status, Status: status,
Msg: msg, Msg: formatMessage(status, msg),
Err: errors.New(msg), Err: errors.New(msg),
} }
} }
// NewError creates a new http error with the given error and message.
func NewError(status int, err error, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
if _, ok := err.(StackTracer); !ok {
err = errors.Wrap(err, msg)
}
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: err,
}
}
// NewErr returns a new Error. If the given error implements the StatusCoder // NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status. // interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error { func NewErr(status int, err error, opts ...Option) error {
@ -308,14 +331,12 @@ func NotImplementedErr(err error, opts ...Option) error {
// BadRequest creates a 400 error with the given format and arguments. // BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error { func BadRequest(format string, args ...interface{}) error {
format = BadRequestPrefix + format + "."
return New(http.StatusBadRequest, format, args...) return New(http.StatusBadRequest, format, args...)
} }
// BadRequestErr returns an 400 error with the given error. // BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error { func BadRequestErr(err error, format string, args ...interface{}) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) return NewError(http.StatusBadRequest, err, format, args...)
return NewErr(http.StatusBadRequest, err, opts...)
} }
// Unauthorized creates a 401 error with the given format and arguments. // Unauthorized creates a 401 error with the given format and arguments.