Modify errs.BadRequestErr() to always return an error to the client.

This commit is contained in:
Mariano Cano 2021-11-18 18:17:36 -08:00
parent 8ce807a6cb
commit 8c8db0d4b7
14 changed files with 65 additions and 46 deletions

View file

@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r)
if err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
if v := q.Get("limit"); len(v) > 0 {
limit, err = strconv.Atoi(v)
if err != nil {
return "", 0, errors.Wrapf(err, "error converting %s to integer", v)
return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
}
}
return

View file

@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) {
fields fields
err error
}{
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")},
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")},
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err)
}
expectedError400 := errs.BadRequestErr(errors.New("force"))
expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force")

View file

@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) {
tests := map[string]test{
"error/missing serial": {
rr: &RevokeRequest{},
err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest},
err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
},
"error/bad reasonCode": {
rr: &RevokeRequest{
@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) {
ReasonCode: 15,
Passive: true,
},
err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest},
err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
},
"error/non-passive not implemented": {
rr: &RevokeRequest{

View file

@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
return errs.BadRequest("missing csr")
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
return errs.BadRequestErr(err, "invalid csr")
}
if s.OTT == "" {
return errs.BadRequest("missing ott")
@ -50,7 +50,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return
}

View file

@ -49,16 +49,16 @@ type SSHSignRequest struct {
func (s *SSHSignRequest) Validate() error {
switch {
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
return errors.Errorf("unknown certType %s", s.CertType)
return errs.BadRequest("invalid certType '%s'", s.CertType)
case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey")
return errs.BadRequest("missing or empty publicKey")
case s.OTT == "":
return errors.New("missing or empty ott")
return errs.BadRequest("missing or empty ott")
default:
// Validate identity signature if provided
if s.IdentityCSR.CertificateRequest != nil {
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
return errors.Wrap(err, "invalid identityCSR")
return errs.BadRequestErr(err, "invalid identityCSR")
}
}
return nil
@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
case provisioner.SSHUserCert, provisioner.SSHHostCert:
return nil
default:
return errors.Errorf("unsupported type %s", r.Type)
return errs.BadRequest("invalid type '%s'", r.Type)
}
}
@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
func (r *SSHCheckPrincipalRequest) Validate() error {
switch {
case r.Type != provisioner.SSHHostCert:
return errors.Errorf("unsupported type %s", r.Type)
return errs.BadRequest("unsupported type '%s'", r.Type)
case r.Principal == "":
return errors.New("missing or empty principal")
return errs.BadRequest("missing or empty principal")
default:
return nil
}
@ -232,7 +232,7 @@ type SSHBastionRequest struct {
// Validate checks the values of the SSHBastionRequest.
func (r *SSHBastionRequest) Validate() error {
if r.Hostname == "" {
return errors.New("missing or empty hostname")
return errs.BadRequest("missing or empty hostname")
}
return nil
}
@ -256,7 +256,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -398,7 +398,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -430,7 +430,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}
@ -469,7 +469,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return
}
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}

View file

@ -4,7 +4,6 @@ import (
"net/http"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
func (s *SSHRekeyRequest) Validate() error {
switch {
case s.OTT == "":
return errors.New("missing or empty ott")
return errs.BadRequest("missing or empty ott")
case len(s.PublicKey) == 0:
return errors.New("missing or empty public key")
return errs.BadRequest("missing or empty public key")
default:
return nil
}
@ -46,7 +45,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}

View file

@ -19,7 +19,7 @@ type SSHRenewRequest struct {
func (s *SSHRenewRequest) Validate() error {
switch {
case s.OTT == "":
return errors.New("missing or empty ott")
return errs.BadRequest("missing or empty ott")
default:
return nil
}
@ -43,7 +43,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err))
WriteError(w, err)
return
}

View file

@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."),
err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
}
},
"ok": func(t *testing.T) test {
@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."),
err: errors.New("sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {
@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
p: p,
token: tok,
code: http.StatusBadRequest,
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."),
err: errors.New("sshpop certificate must be a host ssh certificate"),
}
},
"ok": func(t *testing.T) test {

View file

@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// Check for required variables.
if err := t.ValidateRequiredData(data); err != nil {
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err))
return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
}
o, err := t.Output(mergedData)

View file

@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."),
err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest,
}
},
@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."),
err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest,
}
},
@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub,
signOpts: []provisioner.SignOption{},
err: errors.New("The request could not be completed: unexpected certificate type '0'."),
err: errors.New("unexpected certificate type '0'"),
code: http.StatusBadRequest,
}
},

View file

@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
Reason: reason,
OTT: raw,
},
err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"),
err: errors.New("certificate with serial number 'sn' is already revoked"),
code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw)

View file

@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
ca: ca,
body: "invalid json",
status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg,
errMsg: errs.BadRequestPrefix,
}
},
"fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
ca: ca,
body: string(body),
status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg,
errMsg: errs.BadRequestPrefix,
}
},
"fail unauthorized-ott": func(t *testing.T) *signTest {

View file

@ -1108,8 +1108,7 @@ retry:
retried = true
goto retry
}
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
return nil, readError(resp.Body)
}
var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil {

View file

@ -25,7 +25,7 @@ type Option func(e *Error) error
// message only if it is empty.
func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
if len(e.Msg) > 0 {
if e.Msg != "" {
return e
}
e.Msg = fmt.Sprintf(format, args...)
@ -164,7 +164,8 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error {
switch code {
case http.StatusBadRequest:
return BadRequestErr(e, opts...)
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, e, opts...)
case http.StatusUnauthorized:
return UnauthorizedErr(e, opts...)
case http.StatusForbidden:
@ -200,6 +201,15 @@ var (
BadRequestPrefix = "The request could not be completed: "
)
func formatMessage(status int, msg string) string {
switch status {
case http.StatusBadRequest:
return BadRequestPrefix + msg + "."
default:
return msg
}
}
// splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
@ -229,11 +239,24 @@ func New(status int, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &Error{
Status: status,
Msg: msg,
Msg: formatMessage(status, msg),
Err: errors.New(msg),
}
}
// NewError creates a new http error with the given error and message.
func NewError(status int, err error, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
if _, ok := err.(StackTracer); !ok {
err = errors.Wrap(err, msg)
}
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: err,
}
}
// NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error {
@ -308,14 +331,12 @@ func NotImplementedErr(err error, opts ...Option) error {
// BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error {
format = BadRequestPrefix + format + "."
return New(http.StatusBadRequest, format, args...)
}
// BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, err, opts...)
func BadRequestErr(err error, format string, args ...interface{}) error {
return NewError(http.StatusBadRequest, err, format, args...)
}
// Unauthorized creates a 401 error with the given format and arguments.