Move api errors to their own package and modify the typedef

This commit is contained in:
max furman 2019-12-15 23:54:25 -08:00
parent f033422ffa
commit b9f6aacb0f
13 changed files with 350 additions and 190 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"github.com/smallstep/cli/crypto/tlsutil"
)
@ -233,13 +234,13 @@ type ProvisionerKeyResponse struct {
// or an error if something is wrong.
func (s *SignRequest) Validate() error {
if s.CsrPEM.CertificateRequest == nil {
return BadRequest(errors.New("missing csr"))
return errs.BadRequest(errors.New("missing csr"))
}
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return BadRequest(errors.Wrap(err, "invalid csr"))
return errs.BadRequest(errors.Wrap(err, "invalid csr"))
}
if s.OTT == "" {
return BadRequest(errors.New("missing ott"))
return errs.BadRequest(errors.New("missing ott"))
}
return nil
@ -328,7 +329,7 @@ func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
// Load root certificate with the
cert, err := h.Authority.Root(sum)
if err != nil {
WriteError(w, NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
WriteError(w, errs.NotFound(errors.Wrapf(err, "%s was not found", r.RequestURI)))
return
}
@ -349,7 +350,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
@ -366,13 +367,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
signOpts, err := h.Authority.AuthorizeSign(body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}
certChainPEM := certChainToPEM(certChain)
@ -393,13 +394,13 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
// new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, BadRequest(errors.New("missing peer certificate")))
WriteError(w, errs.BadRequest(errors.New("missing peer certificate")))
return
}
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}
certChainPEM := certChainToPEM(certChain)
@ -421,13 +422,13 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r)
if err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
p, next, err := h.Authority.GetProvisioners(cursor, limit)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}
JSON(w, &ProvisionersResponse{
@ -441,7 +442,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid)
if err != nil {
WriteError(w, NotFound(err))
WriteError(w, errs.NotFound(err))
return
}
JSON(w, &ProvisionerKeyResponse{key})
@ -451,7 +452,7 @@ func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots()
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}
@ -469,7 +470,7 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation()
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}

View file

@ -8,106 +8,10 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
// StatusCoder interface is used by errors that returns the HTTP response code.
type StatusCoder interface {
StatusCode() int
}
// StackTracer must be by those errors that return an stack trace.
type StackTracer interface {
StackTrace() errors.StackTrace
}
// Error represents the CA API errors.
type Error struct {
Status int
Err error
}
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
Message string `json:"message"`
}
// Cause implements the errors.Causer interface and returns the original error.
func (e *Error) Cause() error {
return e.Err
}
// Error implements the error interface and returns the error string.
func (e *Error) Error() string {
return e.Err.Error()
}
// StatusCode implements the StatusCoder interface and returns the HTTP response
// code.
func (e *Error) StatusCode() int {
return e.Status
}
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *Error) MarshalJSON() ([]byte, error) {
return json.Marshal(&ErrorResponse{Status: e.Status, Message: http.StatusText(e.Status)})
}
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
func (e *Error) UnmarshalJSON(data []byte) error {
var er ErrorResponse
if err := json.Unmarshal(data, &er); err != nil {
return err
}
e.Status = er.Status
e.Err = fmt.Errorf(er.Message)
return nil
}
// NewError returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func NewError(status int, err error) error {
if sc, ok := err.(StatusCoder); ok {
return &Error{Status: sc.StatusCode(), Err: err}
}
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
return &Error{Status: sc.StatusCode(), Err: err}
}
return &Error{Status: status, Err: err}
}
// InternalServerError returns a 500 error with the given error.
func InternalServerError(err error) error {
return NewError(http.StatusInternalServerError, err)
}
// NotImplemented returns a 500 error with the given error.
func NotImplemented(err error) error {
return NewError(http.StatusNotImplemented, err)
}
// BadRequest returns an 400 error with the given error.
func BadRequest(err error) error {
return NewError(http.StatusBadRequest, err)
}
// Unauthorized returns an 401 error with the given error.
func Unauthorized(err error) error {
return NewError(http.StatusUnauthorized, err)
}
// Forbidden returns an 403 error with the given error.
func Forbidden(err error) error {
return NewError(http.StatusForbidden, err)
}
// NotFound returns an 404 error with the given error.
func NotFound(err error) error {
return NewError(http.StatusNotFound, err)
}
// WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err error) {
switch k := err.(type) {
@ -118,10 +22,10 @@ func WriteError(w http.ResponseWriter, err error) {
w.Header().Set("Content-Type", "application/json")
}
cause := errors.Cause(err)
if sc, ok := err.(StatusCoder); ok {
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(StatusCoder); ok {
if sc, ok := cause.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
@ -134,12 +38,12 @@ func WriteError(w http.ResponseWriter, err error) {
"error": err,
})
if os.Getenv("STEPDEBUG") == "1" {
if e, ok := err.(StackTracer); ok {
if e, ok := err.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})
} else {
if e, ok := cause.(StackTracer); ok {
if e, ok := cause.(errs.StackTracer); ok {
rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e),
})

View file

@ -7,6 +7,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
)
@ -29,13 +30,13 @@ type RevokeRequest struct {
// or an error if something is wrong.
func (r *RevokeRequest) Validate() (err error) {
if r.Serial == "" {
return BadRequest(errors.New("missing serial"))
return errs.BadRequest(errors.New("missing serial"))
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return BadRequest(errors.New("reasonCode out of bounds"))
return errs.BadRequest(errors.New("reasonCode out of bounds"))
}
if !r.Passive {
return NotImplemented(errors.New("non-passive revocation not implemented"))
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
}
return
@ -49,7 +50,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
@ -71,7 +72,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
if len(body.OTT) > 0 {
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
opts.OTT = body.OTT
@ -80,12 +81,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number
// being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, BadRequest(errors.New("missing ott or peer certificate")))
WriteError(w, errs.BadRequest(errors.New("missing ott or peer certificate")))
return
}
opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, BadRequest(errors.New("revoke: serial number in mtls certificate different than body")))
WriteError(w, errs.BadRequest(errors.New("revoke: serial number in mtls certificate different than body")))
return
}
// TODO: should probably be checking if the certificate was revoked here.
@ -96,7 +97,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
}
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}

View file

@ -11,6 +11,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/sshutil"
"github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
@ -248,19 +249,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
return
}
@ -268,7 +269,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing addUserPublicKey")))
return
}
}
@ -284,13 +285,13 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignSSHMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}
@ -298,7 +299,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}
addUserCertificate = &SSHCertificate{addUserCert}
@ -319,12 +320,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
certChain, err := h.Authority.Sign(cr, opts, signOpts...)
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}
identityCertificate = certChainToPEM(certChain)
@ -342,12 +343,12 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots()
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, NotFound(errors.New("no keys found")))
WriteError(w, errs.NotFound(errors.New("no keys found")))
return
}
@ -367,12 +368,12 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation()
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}
if len(keys.HostKeys) == 0 && len(keys.UserKeys) == 0 {
WriteError(w, NotFound(errors.New("no keys found")))
WriteError(w, errs.NotFound(errors.New("no keys found")))
return
}
@ -392,17 +393,17 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}
@ -413,7 +414,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
case provisioner.SSHHostCert:
config.HostTemplates = ts
default:
WriteError(w, InternalServerError(errors.New("it should hot get here")))
WriteError(w, errs.InternalServerError(errors.New("it should hot get here")))
return
}
@ -424,17 +425,17 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body"))
return
}
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}
JSON(w, &SSHCheckPrincipalResponse{
@ -451,7 +452,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
hosts, err := h.Authority.GetSSHHosts(cert)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}
JSON(w, &SSHGetHostsResponse{
@ -463,17 +464,17 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
return
}

View file

@ -6,6 +6,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
@ -38,36 +39,36 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error parsing publicKey")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error parsing publicKey")))
return
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RekeySSHMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
}
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}

View file

@ -6,6 +6,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
)
// SSHRenewRequest is the request body of an SSH certificate request.
@ -34,30 +35,30 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
logOtt(w, body.OTT)
if err := body.Validate(); err != nil {
WriteError(w, BadRequest(err))
WriteError(w, errs.BadRequest(err))
return
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RenewSSHMethod)
_, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
oldCert, err := provisioner.ExtractSSHPOPCert(body.OTT)
if err != nil {
WriteError(w, InternalServerError(err))
WriteError(w, errs.InternalServerError(err))
}
newCert, err := h.Authority.RenewSSH(oldCert)
if err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}

View file

@ -7,6 +7,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
)
@ -29,16 +30,16 @@ type SSHRevokeRequest struct {
// or an error if something is wrong.
func (r *SSHRevokeRequest) Validate() (err error) {
if r.Serial == "" {
return BadRequest(errors.New("missing serial"))
return errs.BadRequest(errors.New("missing serial"))
}
if r.ReasonCode < ocsp.Unspecified || r.ReasonCode > ocsp.AACompromise {
return BadRequest(errors.New("reasonCode out of bounds"))
return errs.BadRequest(errors.New("reasonCode out of bounds"))
}
if !r.Passive {
return NotImplemented(errors.New("non-passive revocation not implemented"))
return errs.NotImplemented(errors.New("non-passive revocation not implemented"))
}
if len(r.OTT) == 0 {
return BadRequest(errors.New("missing ott"))
return errs.BadRequest(errors.New("missing ott"))
}
return
}
@ -49,7 +50,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, BadRequest(errors.Wrap(err, "error reading request body")))
WriteError(w, errs.BadRequest(errors.Wrap(err, "error reading request body")))
return
}
@ -70,13 +71,13 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
// otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
WriteError(w, Unauthorized(err))
WriteError(w, errs.Unauthorized(err))
return
}
opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil {
WriteError(w, Forbidden(err))
WriteError(w, errs.Forbidden(err))
return
}

View file

@ -7,6 +7,7 @@ import (
"net/http"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
)
@ -68,7 +69,7 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return BadRequest(errors.Wrap(err, "error decoding json"))
return errs.BadRequest(errors.Wrap(err, "error decoding json"))
}
return nil
}

View file

@ -12,6 +12,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/sshutil"
"github.com/smallstep/certificates/templates"
"github.com/smallstep/cli/crypto/randutil"
@ -660,25 +661,19 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
if a.sshCheckHostFunc != nil {
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
if err != nil {
return false, &apiError{
err: errors.Wrap(err, "checkSSHHost: error from injected checkSSHHost func"),
code: http.StatusInternalServerError,
}
return false, errs.Wrap(http.StatusInternalServerError, err,
"checkSSHHost: error from injected checkSSHHost func")
}
return exists, nil
}
exists, err := a.db.IsSSHHost(principal)
if err != nil {
if err == db.ErrNotImplemented {
return false, &apiError{
err: errors.Wrap(err, "checkSSHHost: isSSHHost is not implemented"),
code: http.StatusNotImplemented,
}
}
return false, &apiError{
err: errors.Wrap(err, "checkSSHHost: error checking if hosts exists"),
code: http.StatusInternalServerError,
return false, errs.Wrap(http.StatusNotImplemented, err,
"checkSSHHost: isSSHHost is not implemented")
}
return false, errs.Wrap(http.StatusInternalServerError, err,
"checkSSHHost: error checking if hosts exists")
}
return exists, nil

View file

@ -26,6 +26,7 @@ import (
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/keys"
"github.com/smallstep/cli/crypto/pemutil"
@ -134,7 +135,7 @@ func (o *clientOptions) applyDefaultIdentity() error {
}
crt, err := i.TLSCertificate()
if err != nil {
return nil
return err
}
o.certificate = crt
return nil
@ -472,11 +473,6 @@ func (c *Client) GetRootCAs() *x509.CertPool {
}
}
// GetTransport returns the transport of the internal HTTP client.
func (c *Client) GetTransport() http.RoundTripper {
return c.client.GetTransport()
}
// SetTransport updates the transport of the internal HTTP client.
func (c *Client) SetTransport(tr http.RoundTripper) {
c.client.SetTransport(tr)
@ -958,24 +954,27 @@ func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrin
Token: token,
})
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
return nil, errs.Wrap(http.StatusInternalServerError, err,
"error marshaling check-host request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client POST %s failed", u,
errs.WithMessage("Failed to perform POST request to %s", u))
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body)
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
}
var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
return nil, errs.Wrapf(http.StatusInternalServerError, err, "error reading %s response", u)
}
return &check, nil
}
@ -1174,7 +1173,7 @@ func readJSON(r io.ReadCloser, v interface{}) error {
func readError(r io.ReadCloser) error {
defer r.Close()
apiErr := new(api.Error)
apiErr := new(errs.Error)
if err := json.NewDecoder(r).Decode(apiErr); err != nil {
return err
}

250
errs/error.go Normal file
View file

@ -0,0 +1,250 @@
package errs
import (
"encoding/json"
"fmt"
"net/http"
"github.com/pkg/errors"
)
// StatusCoder interface is used by errors that returns the HTTP response code.
type StatusCoder interface {
StatusCode() int
}
// StackTracer must be by those errors that return an stack trace.
type StackTracer interface {
StackTrace() errors.StackTrace
}
// Option modifies the Error type.
type Option func(e *Error) error
// WithMessage returns an Option that modifies the error by overwriting the
// message only if it is empty.
func WithMessage(format string, args ...interface{}) Option {
return func(e *Error) error {
if len(e.Msg) > 0 {
return e
}
e.Msg = fmt.Sprintf(format, args...)
return e
}
}
// Error represents the CA API errors.
type Error struct {
Status int
Err error
Msg string
}
// New returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status.
func New(status int, err error, opts ...Option) error {
var e *Error
if sc, ok := err.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
cause := errors.Cause(err)
if sc, ok := cause.(StatusCoder); ok {
e = &Error{Status: sc.StatusCode(), Err: err}
} else {
e = &Error{Status: status, Err: err}
}
}
for _, o := range opts {
o(e)
}
return e
}
// ErrorResponse represents an error in JSON format.
type ErrorResponse struct {
Status int `json:"status"`
Message string `json:"message"`
}
// Cause implements the errors.Causer interface and returns the original error.
func (e *Error) Cause() error {
return e.Err
}
// Error implements the error interface and returns the error string.
func (e *Error) Error() string {
return e.Err.Error()
}
// StatusCode implements the StatusCoder interface and returns the HTTP response
// code.
func (e *Error) StatusCode() int {
return e.Status
}
// Message returns a user friendly error, if one is set.
func (e *Error) Message() string {
if len(e.Msg) > 0 {
return e.Msg
}
return e.Err.Error()
}
// Wrap returns an error annotating err with a stack trace at the point Wrap is
// called, and the supplied message. If err is nil, Wrap returns nil.
func Wrap(status int, e error, m string, opts ...Option) error {
if e == nil {
return nil
}
if err, ok := e.(*Error); ok {
err.Err = errors.Wrap(err.Err, m)
e = err
} else {
e = errors.Wrap(e, m)
}
return StatusCodeError(status, e, opts...)
}
// Wrapf returns an error annotating err with a stack trace at the point Wrap is
// called, and the supplied message. If err is nil, Wrap returns nil.
func Wrapf(status int, e error, format string, args ...interface{}) error {
if e == nil {
return nil
}
var opts []Option
for i, arg := range args {
// Once we find the first Option, assume that all further arguments are Options.
if _, ok := arg.(Option); ok {
for _, a := range args[i:] {
// Ignore any arguments after the first Option that are not Options.
if opt, ok := a.(Option); ok {
opts = append(opts, opt)
}
}
args = args[:i]
break
}
}
if err, ok := e.(*Error); ok {
err.Err = errors.Wrapf(err.Err, format, args...)
e = err
} else {
e = errors.Wrapf(e, format, args...)
}
return StatusCodeError(status, e, opts...)
}
// MarshalJSON implements json.Marshaller interface for the Error struct.
func (e *Error) MarshalJSON() ([]byte, error) {
var msg string
if len(e.Msg) > 0 {
msg = e.Msg
} else {
msg = http.StatusText(e.Status)
}
return json.Marshal(&ErrorResponse{Status: e.Status, Message: msg})
}
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
func (e *Error) UnmarshalJSON(data []byte) error {
var er ErrorResponse
if err := json.Unmarshal(data, &er); err != nil {
return err
}
e.Status = er.Status
e.Err = fmt.Errorf(er.Message)
return nil
}
// Format implements the fmt.Formatter interface.
func (e *Error) Format(f fmt.State, c rune) {
if err, ok := e.Err.(fmt.Formatter); ok {
err.Format(f, c)
return
}
fmt.Fprint(f, e.Err.Error())
}
// Messenger is a friendly message interface that errors can implement.
type Messenger interface {
Message() string
}
// StatusCodeError selects the proper error based on the status code.
func StatusCodeError(code int, e error, opts ...Option) error {
switch code {
case http.StatusBadRequest:
return BadRequest(e, opts...)
case http.StatusUnauthorized:
return Unauthorized(e, opts...)
case http.StatusForbidden:
return Forbidden(e, opts...)
case http.StatusInternalServerError:
return InternalServerError(e, opts...)
case http.StatusNotImplemented:
return NotImplemented(e, opts...)
default:
return UnexpectedError(code, e, opts...)
}
}
var seeLogs = "Please see the certificate authority logs for more info."
// InternalServerError returns a 500 error with the given error.
func InternalServerError(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The certificate authority encountered an Internal Server Error. "+seeLogs))
}
return New(http.StatusInternalServerError, err, opts...)
}
// NotImplemented returns a 501 error with the given error.
func NotImplemented(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The requested method is not implemented by the certificate authority. "+seeLogs))
}
return New(http.StatusNotImplemented, err, opts...)
}
// BadRequest returns an 400 error with the given error.
func BadRequest(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The request could not be completed due to being poorly formatted or "+
"missing critical data. "+seeLogs))
}
return New(http.StatusBadRequest, err, opts...)
}
// Unauthorized returns an 401 error with the given error.
func Unauthorized(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The request lacked necessary authorization to be completed. "+seeLogs))
}
return New(http.StatusUnauthorized, err, opts...)
}
// Forbidden returns an 403 error with the given error.
func Forbidden(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The request was Forbidden by the certificate authority. "+seeLogs))
}
return New(http.StatusForbidden, err, opts...)
}
// NotFound returns an 404 error with the given error.
func NotFound(err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The requested resource could not be found. "+seeLogs))
}
return New(http.StatusNotFound, err, opts...)
}
// UnexpectedError will be used when the certificate authority makes an outgoing
// request and receives an unhandled status code.
func UnexpectedError(code int, err error, opts ...Option) error {
if len(opts) == 0 {
opts = append(opts, WithMessage("The certificate authority received an "+
"unexpected HTTP status code - '%d'. "+seeLogs, code))
}
return New(code, err, opts...)
}

3
go.mod
View file

@ -5,6 +5,7 @@ go 1.13
require (
github.com/Masterminds/sprig/v3 v3.0.0
github.com/go-chi/chi v4.0.2+incompatible
github.com/juju/ansiterm v0.0.0-20180109212912-720a0952cc2a // indirect
github.com/newrelic/go-agent v2.15.0+incompatible
github.com/pkg/errors v0.8.1
github.com/rs/xid v1.2.1
@ -18,4 +19,4 @@ require (
gopkg.in/square/go-jose.v2 v2.4.0
)
//replace github.com/smallstep/cli => ../cli
replace github.com/smallstep/cli => ../cli

4
go.sum
View file

@ -79,6 +79,8 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.10 h1:qxFzApOv4WsAL965uUPIsXzAKCZxN2p9UqdhFS4ZW10=
github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84=
github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM=
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
@ -177,6 +179,8 @@ golang.org/x/sys v0.0.0-20190424175732-18eb32c0e2f0/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191008105621-543471e840be h1:QAcqgptGM8IQBC9K/RC4o+O9YmqEm0diQn9QmZw/0mU=
golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c=