Modify errs.BadRequestErr() to always return an error to the client.
This commit is contained in:
parent
8ce807a6cb
commit
8c8db0d4b7
14 changed files with 65 additions and 46 deletions
|
@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
|
|||
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
|
||||
cursor, limit, err := ParseCursor(r)
|
||||
if err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
|
|||
if v := q.Get("limit"); len(v) > 0 {
|
||||
limit, err = strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrapf(err, "error converting %s to integer", v)
|
||||
return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
|
|
@ -403,9 +403,9 @@ func TestSignRequest_Validate(t *testing.T) {
|
|||
fields fields
|
||||
err error
|
||||
}{
|
||||
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing csr.")},
|
||||
{"missing csr", fields{CertificateRequest{}, "foobarzar", time.Time{}, time.Time{}}, errors.New("missing csr")},
|
||||
{"invalid csr", fields{CertificateRequest{bad}, "foobarzar", time.Time{}, time.Time{}}, errors.New("invalid csr")},
|
||||
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("The request could not be completed: missing ott.")},
|
||||
{"missing ott", fields{CertificateRequest{csr}, "", time.Time{}, time.Time{}}, errors.New("missing ott")},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedError400 := errs.BadRequestErr(errors.New("force"))
|
||||
expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
|
||||
expectedError400Bytes, err := json.Marshal(expectedError400)
|
||||
assert.FatalError(t, err)
|
||||
expectedError500 := errs.InternalServer("force")
|
||||
|
|
|
@ -28,7 +28,7 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
tests := map[string]test{
|
||||
"error/missing serial": {
|
||||
rr: &RevokeRequest{},
|
||||
err: &errs.Error{Err: errors.New("The request could not be completed: missing serial."), Status: http.StatusBadRequest},
|
||||
err: &errs.Error{Err: errors.New("missing serial"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/bad reasonCode": {
|
||||
rr: &RevokeRequest{
|
||||
|
@ -36,7 +36,7 @@ func TestRevokeRequestValidate(t *testing.T) {
|
|||
ReasonCode: 15,
|
||||
Passive: true,
|
||||
},
|
||||
err: &errs.Error{Err: errors.New("The request could not be completed: reasonCode out of bounds."), Status: http.StatusBadRequest},
|
||||
err: &errs.Error{Err: errors.New("reasonCode out of bounds"), Status: http.StatusBadRequest},
|
||||
},
|
||||
"error/non-passive not implemented": {
|
||||
rr: &RevokeRequest{
|
||||
|
|
|
@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
|
|||
return errs.BadRequest("missing csr")
|
||||
}
|
||||
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errs.Wrap(http.StatusBadRequest, err, "invalid csr")
|
||||
return errs.BadRequestErr(err, "invalid csr")
|
||||
}
|
||||
if s.OTT == "" {
|
||||
return errs.BadRequest("missing ott")
|
||||
|
@ -50,7 +50,7 @@ type SignResponse struct {
|
|||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
||||
|
|
24
api/ssh.go
24
api/ssh.go
|
@ -49,16 +49,16 @@ type SSHSignRequest struct {
|
|||
func (s *SSHSignRequest) Validate() error {
|
||||
switch {
|
||||
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
|
||||
return errors.Errorf("unknown certType %s", s.CertType)
|
||||
return errs.BadRequest("invalid certType '%s'", s.CertType)
|
||||
case len(s.PublicKey) == 0:
|
||||
return errors.New("missing or empty publicKey")
|
||||
return errs.BadRequest("missing or empty publicKey")
|
||||
case s.OTT == "":
|
||||
return errors.New("missing or empty ott")
|
||||
return errs.BadRequest("missing or empty ott")
|
||||
default:
|
||||
// Validate identity signature if provided
|
||||
if s.IdentityCSR.CertificateRequest != nil {
|
||||
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
|
||||
return errors.Wrap(err, "invalid identityCSR")
|
||||
return errs.BadRequestErr(err, "invalid identityCSR")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
|
|||
case provisioner.SSHUserCert, provisioner.SSHHostCert:
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("unsupported type %s", r.Type)
|
||||
return errs.BadRequest("invalid type '%s'", r.Type)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
|
|||
func (r *SSHCheckPrincipalRequest) Validate() error {
|
||||
switch {
|
||||
case r.Type != provisioner.SSHHostCert:
|
||||
return errors.Errorf("unsupported type %s", r.Type)
|
||||
return errs.BadRequest("unsupported type '%s'", r.Type)
|
||||
case r.Principal == "":
|
||||
return errors.New("missing or empty principal")
|
||||
return errs.BadRequest("missing or empty principal")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
@ -232,7 +232,7 @@ type SSHBastionRequest struct {
|
|||
// Validate checks the values of the SSHBastionRequest.
|
||||
func (r *SSHBastionRequest) Validate() error {
|
||||
if r.Hostname == "" {
|
||||
return errors.New("missing or empty hostname")
|
||||
return errs.BadRequest("missing or empty hostname")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -256,7 +256,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -398,7 +398,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -430,7 +430,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -469,7 +469,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
|
|||
func (s *SSHRekeyRequest) Validate() error {
|
||||
switch {
|
||||
case s.OTT == "":
|
||||
return errors.New("missing or empty ott")
|
||||
return errs.BadRequest("missing or empty ott")
|
||||
case len(s.PublicKey) == 0:
|
||||
return errors.New("missing or empty public key")
|
||||
return errs.BadRequest("missing or empty public key")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
@ -46,7 +45,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ type SSHRenewRequest struct {
|
|||
func (s *SSHRenewRequest) Validate() error {
|
||||
switch {
|
||||
case s.OTT == "":
|
||||
return errors.New("missing or empty ott")
|
||||
return errs.BadRequest("missing or empty ott")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
logOtt(w, body.OTT)
|
||||
if err := body.Validate(); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err))
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
|
|||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("The request could not be completed: sshpop token subject must be equivalent to sshpop certificate serial number."),
|
||||
err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
|
|||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."),
|
||||
err: errors.New("sshpop certificate must be a host ssh certificate"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
|||
p: p,
|
||||
token: tok,
|
||||
code: http.StatusBadRequest,
|
||||
err: errors.New("The request could not be completed: sshpop certificate must be a host ssh certificate."),
|
||||
err: errors.New("sshpop certificate must be a host ssh certificate"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
|
|
@ -94,7 +94,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
|
|||
|
||||
// Check for required variables.
|
||||
if err := t.ValidateRequiredData(data); err != nil {
|
||||
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err))
|
||||
return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
|
||||
}
|
||||
|
||||
o, err := t.Output(mergedData)
|
||||
|
|
|
@ -912,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
|||
cert: &ssh.Certificate{},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."),
|
||||
err: errors.New("cannot rekey a certificate without validity period"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
|
@ -923,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
|||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("The request could not be completed: cannot rekey a certificate without validity period."),
|
||||
err: errors.New("cannot rekey a certificate without validity period"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
|
@ -956,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
|||
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
|
||||
key: pub,
|
||||
signOpts: []provisioner.SignOption{},
|
||||
err: errors.New("The request could not be completed: unexpected certificate type '0'."),
|
||||
err: errors.New("unexpected certificate type '0'"),
|
||||
code: http.StatusBadRequest,
|
||||
}
|
||||
},
|
||||
|
|
|
@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
|
|||
Reason: reason,
|
||||
OTT: raw,
|
||||
},
|
||||
err: errors.New("The request could not be completed: certificate with serial number 'sn' is already revoked"),
|
||||
err: errors.New("certificate with serial number 'sn' is already revoked"),
|
||||
code: http.StatusBadRequest,
|
||||
checkErrDetails: func(err *errs.Error) {
|
||||
assert.Equals(t, err.Details["token"], raw)
|
||||
|
|
|
@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
|
|||
ca: ca,
|
||||
body: "invalid json",
|
||||
status: http.StatusBadRequest,
|
||||
errMsg: errs.BadRequestDefaultMsg,
|
||||
errMsg: errs.BadRequestPrefix,
|
||||
}
|
||||
},
|
||||
"fail invalid-csr-sig": func(t *testing.T) *signTest {
|
||||
|
@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
|
|||
ca: ca,
|
||||
body: string(body),
|
||||
status: http.StatusBadRequest,
|
||||
errMsg: errs.BadRequestDefaultMsg,
|
||||
errMsg: errs.BadRequestPrefix,
|
||||
}
|
||||
},
|
||||
"fail unauthorized-ott": func(t *testing.T) *signTest {
|
||||
|
|
|
@ -1108,8 +1108,7 @@ retry:
|
|||
retried = true
|
||||
goto retry
|
||||
}
|
||||
|
||||
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var check api.SSHCheckPrincipalResponse
|
||||
if err := readJSON(resp.Body, &check); err != nil {
|
||||
|
|
|
@ -25,7 +25,7 @@ type Option func(e *Error) error
|
|||
// message only if it is empty.
|
||||
func withDefaultMessage(format string, args ...interface{}) Option {
|
||||
return func(e *Error) error {
|
||||
if len(e.Msg) > 0 {
|
||||
if e.Msg != "" {
|
||||
return e
|
||||
}
|
||||
e.Msg = fmt.Sprintf(format, args...)
|
||||
|
@ -164,7 +164,8 @@ type Messenger interface {
|
|||
func StatusCodeError(code int, e error, opts ...Option) error {
|
||||
switch code {
|
||||
case http.StatusBadRequest:
|
||||
return BadRequestErr(e, opts...)
|
||||
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
|
||||
return NewErr(http.StatusBadRequest, e, opts...)
|
||||
case http.StatusUnauthorized:
|
||||
return UnauthorizedErr(e, opts...)
|
||||
case http.StatusForbidden:
|
||||
|
@ -200,6 +201,15 @@ var (
|
|||
BadRequestPrefix = "The request could not be completed: "
|
||||
)
|
||||
|
||||
func formatMessage(status int, msg string) string {
|
||||
switch status {
|
||||
case http.StatusBadRequest:
|
||||
return BadRequestPrefix + msg + "."
|
||||
default:
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// splitOptionArgs splits the variadic length args into string formatting args
|
||||
// and Option(s) to apply to an Error.
|
||||
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
|
||||
|
@ -229,11 +239,24 @@ func New(status int, format string, args ...interface{}) error {
|
|||
msg := fmt.Sprintf(format, args...)
|
||||
return &Error{
|
||||
Status: status,
|
||||
Msg: msg,
|
||||
Msg: formatMessage(status, msg),
|
||||
Err: errors.New(msg),
|
||||
}
|
||||
}
|
||||
|
||||
// NewError creates a new http error with the given error and message.
|
||||
func NewError(status int, err error, format string, args ...interface{}) error {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
if _, ok := err.(StackTracer); !ok {
|
||||
err = errors.Wrap(err, msg)
|
||||
}
|
||||
return &Error{
|
||||
Status: status,
|
||||
Msg: formatMessage(status, msg),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErr returns a new Error. If the given error implements the StatusCoder
|
||||
// interface we will ignore the given status.
|
||||
func NewErr(status int, err error, opts ...Option) error {
|
||||
|
@ -308,14 +331,12 @@ func NotImplementedErr(err error, opts ...Option) error {
|
|||
|
||||
// BadRequest creates a 400 error with the given format and arguments.
|
||||
func BadRequest(format string, args ...interface{}) error {
|
||||
format = BadRequestPrefix + format + "."
|
||||
return New(http.StatusBadRequest, format, args...)
|
||||
}
|
||||
|
||||
// BadRequestErr returns an 400 error with the given error.
|
||||
func BadRequestErr(err error, opts ...Option) error {
|
||||
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
|
||||
return NewErr(http.StatusBadRequest, err, opts...)
|
||||
func BadRequestErr(err error, format string, args ...interface{}) error {
|
||||
return NewError(http.StatusBadRequest, err, format, args...)
|
||||
}
|
||||
|
||||
// Unauthorized creates a 401 error with the given format and arguments.
|
||||
|
|
Loading…
Reference in a new issue