From 8c8db0d4b7ab17e764375aa5f2419592628415db Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 18 Nov 2021 18:17:36 -0800 Subject: [PATCH] Modify errs.BadRequestErr() to always return an error to the client. --- api/api.go | 4 ++-- api/api_test.go | 6 ++--- api/revoke_test.go | 4 ++-- api/sign.go | 4 ++-- api/ssh.go | 24 +++++++++---------- api/sshRekey.go | 7 +++--- api/sshRenew.go | 4 ++-- authority/provisioner/sshpop_test.go | 6 ++--- authority/ssh.go | 2 +- authority/ssh_test.go | 6 ++--- authority/tls_test.go | 2 +- ca/ca_test.go | 4 ++-- ca/client.go | 3 +-- errs/error.go | 35 ++++++++++++++++++++++------ 14 files changed, 65 insertions(+), 46 deletions(-) diff --git a/api/api.go b/api/api.go index 30ba03f9..e057caaa 100644 --- a/api/api.go +++ b/api/api.go @@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) { if v := q.Get("limit"); len(v) > 0 { limit, err = strconv.Atoi(v) if err != nil { - return "", 0, errors.Wrapf(err, "error converting %s to integer", v) + return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v) } } return diff --git a/api/api_test.go b/api/api_test.go index 0fab1a5b..5cbce8b3 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) { fields fields err error }{ - {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")}, + {"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")}, {"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")}, - {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")}, + {"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) { t.Fatal(err) } - expectedError400 := errs.BadRequestErr(errors.New("force")) + expectedError400 := errs.BadRequest("limit 'abc' is not an integer") expectedError400Bytes, err := json.Marshal(expectedError400) assert.FatalError(t, err) expectedError500 := errs.InternalServer("force") diff --git a/api/revoke_test.go b/api/revoke_test.go index b6ba30fb..4ed4e3fe 100644 --- a/api/revoke_test.go +++ b/api/revoke_test.go @@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) { tests := map[string]test{ "error/missing serial": { rr: &RevokeRequest{}, - err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest}, }, "error/bad reasonCode": { rr: &RevokeRequest{ @@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) { ReasonCode: 15, Passive: true, }, - err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest}, + err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest}, }, "error/non-passive not implemented": { rr: &RevokeRequest{ diff --git a/api/sign.go b/api/sign.go index d6fd2bc6..a1e5b998 100644 --- a/api/sign.go +++ b/api/sign.go @@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error { return errs.BadRequest("missing csr") } if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { - return errs.Wrap(http.StatusBadRequest, err, "invalid csr") + return errs.BadRequestErr(err, "invalid csr") } if s.OTT == "" { return errs.BadRequest("missing ott") @@ -50,7 +50,7 @@ type SignResponse struct { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := ReadJSON(r.Body, &body); err != nil { - WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) + WriteError(w, errs.BadRequestErr(err, "error reading request body")) return } diff --git a/api/ssh.go b/api/ssh.go index 7c7a5acd..315b3e83 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -49,16 +49,16 @@ type SSHSignRequest struct { func (s *SSHSignRequest) Validate() error { switch { case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: - return errors.Errorf("unknown certType %s", s.CertType) + return errs.BadRequest("invalid certType '%s'", s.CertType) case len(s.PublicKey) == 0: - return errors.New("missing or empty publicKey") + return errs.BadRequest("missing or empty publicKey") case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") default: // Validate identity signature if provided if s.IdentityCSR.CertificateRequest != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { - return errors.Wrap(err, "invalid identityCSR") + return errs.BadRequestErr(err, "invalid identityCSR") } } return nil @@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error { case provisioner.SSHUserCert, provisioner.SSHHostCert: return nil default: - return errors.Errorf("unsupported type %s", r.Type) + return errs.BadRequest("invalid type '%s'", r.Type) } } @@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct { func (r *SSHCheckPrincipalRequest) Validate() error { switch { case r.Type != provisioner.SSHHostCert: - return errors.Errorf("unsupported type %s", r.Type) + return errs.BadRequest("unsupported type '%s'", r.Type) case r.Principal == "": - return errors.New("missing or empty principal") + return errs.BadRequest("missing or empty principal") default: return nil } @@ -232,7 +232,7 @@ type SSHBastionRequest struct { // Validate checks the values of the SSHBastionRequest. func (r *SSHBastionRequest) Validate() error { if r.Hostname == "" { - return errors.New("missing or empty hostname") + return errs.BadRequest("missing or empty hostname") } return nil } @@ -256,7 +256,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -398,7 +398,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -430,7 +430,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } @@ -469,7 +469,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/api/sshRekey.go b/api/sshRekey.go index 9d9e17cf..4e29b043 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -4,7 +4,6 @@ import ( "net/http" "time" - "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/errs" "golang.org/x/crypto/ssh" @@ -20,9 +19,9 @@ type SSHRekeyRequest struct { func (s *SSHRekeyRequest) Validate() error { switch { case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") case len(s.PublicKey) == 0: - return errors.New("missing or empty public key") + return errs.BadRequest("missing or empty public key") default: return nil } @@ -46,7 +45,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/api/sshRenew.go b/api/sshRenew.go index d0633ecf..d28b57b5 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -19,7 +19,7 @@ type SSHRenewRequest struct { func (s *SSHRenewRequest) Validate() error { switch { case s.OTT == "": - return errors.New("missing or empty ott") + return errs.BadRequest("missing or empty ott") default: return nil } @@ -43,7 +43,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { logOtt(w, body.OTT) if err := body.Validate(); err != nil { - WriteError(w, errs.BadRequestErr(err)) + WriteError(w, err) return } diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 850a698d..da036864 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."), + err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"), } }, "ok": func(t *testing.T) test { @@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), + err: errors.New("sshpop certificate must be a host ssh certificate"), } }, "ok": func(t *testing.T) test { @@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { p: p, token: tok, code: http.StatusBadRequest, - err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."), + err: errors.New("sshpop certificate must be a host ssh certificate"), } }, "ok": func(t *testing.T) test { diff --git a/authority/ssh.go b/authority/ssh.go index eba48297..5e03ee9e 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin // Check for required variables. if err := t.ValidateRequiredData(data); err != nil { - return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set ` flag", err)) + return nil, errs.BadRequestErr(err, "%v, please use `--set ` flag", err) } o, err := t.Output(mergedData) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index a62c9e54..b0907a79 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), + err: errors.New("cannot rekey a certificate without validity period"), code: http.StatusBadRequest, } }, @@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."), + err: errors.New("cannot rekey a certificate without validity period"), code: http.StatusBadRequest, } }, @@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) { cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, key: pub, signOpts: []provisioner.SignOption{}, - err: errors.New("The request could not be completed: unexpected certificate type '0'."), + err: errors.New("unexpected certificate type '0'"), code: http.StatusBadRequest, } }, diff --git a/authority/tls_test.go b/authority/tls_test.go index 1796c4a3..409c0582 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) { Reason: reason, OTT: raw, }, - err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"), + err: errors.New("certificate with serial number 'sn' is already revoked"), code: http.StatusBadRequest, checkErrDetails: func(err *errs.Error) { assert.Equals(t, err.Details["token"], raw) diff --git a/ca/ca_test.go b/ca/ca_test.go index 64371ac3..1271659a 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -115,7 +115,7 @@ func TestCASign(t *testing.T) { ca: ca, body: "invalid json", status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "fail invalid-csr-sig": func(t *testing.T) *signTest { @@ -153,7 +153,7 @@ ZEp7knvU2psWRw== ca: ca, body: string(body), status: http.StatusBadRequest, - errMsg: errs.BadRequestDefaultMsg, + errMsg: errs.BadRequestPrefix, } }, "fail unauthorized-ott": func(t *testing.T) *signTest { diff --git a/ca/client.go b/ca/client.go index 74a3b7df..6bc48a42 100644 --- a/ca/client.go +++ b/ca/client.go @@ -1108,8 +1108,7 @@ retry: retried = true goto retry } - - return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body)) + return nil, readError(resp.Body) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { diff --git a/errs/error.go b/errs/error.go index ab488af1..3e40b3f3 100644 --- a/errs/error.go +++ b/errs/error.go @@ -25,7 +25,7 @@ type Option func(e *Error) error // message only if it is empty. func withDefaultMessage(format string, args ...interface{}) Option { return func(e *Error) error { - if len(e.Msg) > 0 { + if e.Msg != "" { return e } e.Msg = fmt.Sprintf(format, args...) @@ -164,7 +164,8 @@ type Messenger interface { func StatusCodeError(code int, e error, opts ...Option) error { switch code { case http.StatusBadRequest: - return BadRequestErr(e, opts...) + opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) + return NewErr(http.StatusBadRequest, e, opts...) case http.StatusUnauthorized: return UnauthorizedErr(e, opts...) case http.StatusForbidden: @@ -200,6 +201,15 @@ var ( BadRequestPrefix = "The request could not be completed: " ) +func formatMessage(status int, msg string) string { + switch status { + case http.StatusBadRequest: + return BadRequestPrefix + msg + "." + default: + return msg + } +} + // splitOptionArgs splits the variadic length args into string formatting args // and Option(s) to apply to an Error. func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { @@ -229,11 +239,24 @@ func New(status int, format string, args ...interface{}) error { msg := fmt.Sprintf(format, args...) return &Error{ Status: status, - Msg: msg, + Msg: formatMessage(status, msg), Err: errors.New(msg), } } +// NewError creates a new http error with the given error and message. +func NewError(status int, err error, format string, args ...interface{}) error { + msg := fmt.Sprintf(format, args...) + if _, ok := err.(StackTracer); !ok { + err = errors.Wrap(err, msg) + } + return &Error{ + Status: status, + Msg: formatMessage(status, msg), + Err: err, + } +} + // NewErr returns a new Error. If the given error implements the StatusCoder // interface we will ignore the given status. func NewErr(status int, err error, opts ...Option) error { @@ -308,14 +331,12 @@ func NotImplementedErr(err error, opts ...Option) error { // BadRequest creates a 400 error with the given format and arguments. func BadRequest(format string, args ...interface{}) error { - format = BadRequestPrefix + format + "." return New(http.StatusBadRequest, format, args...) } // BadRequestErr returns an 400 error with the given error. -func BadRequestErr(err error, opts ...Option) error { - opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) - return NewErr(http.StatusBadRequest, err, opts...) +func BadRequestErr(err error, format string, args ...interface{}) error { + return NewError(http.StatusBadRequest, err, format, args...) } // Unauthorized creates a 401 error with the given format and arguments.