143 lines
3.6 KiB
Go
143 lines
3.6 KiB
Go
|
package api
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
"os"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
"github.com/smallstep/ca-component/logging"
|
||
|
)
|
||
|
|
||
|
// StatusCoder is implemented by errors that carry the HTTP status code to
// use when the error is written as a response.
type StatusCoder interface {
	// StatusCode reports the HTTP status code associated with the error.
	StatusCode() int
}
|
||
|
|
||
|
// StackTracer must be by those errors that return an stack trace.
|
||
|
type StackTracer interface {
|
||
|
StackTrace() errors.StackTrace
|
||
|
}
|
||
|
|
||
|
// Error represents the CA API errors: an underlying error paired with the
// HTTP status code that should be reported for it.
type Error struct {
	Status int   // HTTP status code returned by StatusCode
	Err    error // underlying error returned by Cause
}
|
||
|
|
||
|
// ErrorResponse is the JSON wire format of an Error: a status code and a
// human-readable message.
type ErrorResponse struct {
	Status  int    `json:"status"`  // HTTP status code
	Message string `json:"message"` // human-readable error message
}
|
||
|
|
||
|
// Cause implements the errors.Causer interface and returns the original error.
|
||
|
func (e *Error) Cause() error {
|
||
|
return e.Err
|
||
|
}
|
||
|
|
||
|
// Error implements the error interface and returns the error string.
|
||
|
func (e *Error) Error() string {
|
||
|
return e.Err.Error()
|
||
|
}
|
||
|
|
||
|
// StatusCode implements the StatusCoder interface and returns the HTTP response
|
||
|
// code.
|
||
|
func (e *Error) StatusCode() int {
|
||
|
return e.Status
|
||
|
}
|
||
|
|
||
|
// MarshalJSON implements json.Marshaller interface for the Error struct.
|
||
|
func (e *Error) MarshalJSON() ([]byte, error) {
|
||
|
return json.Marshal(&ErrorResponse{Status: e.Status, Message: http.StatusText(e.Status)})
|
||
|
}
|
||
|
|
||
|
// UnmarshalJSON implements json.Unmarshaler interface for the Error struct.
|
||
|
func (e *Error) UnmarshalJSON(data []byte) error {
|
||
|
var er ErrorResponse
|
||
|
if err := json.Unmarshal(data, &er); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
e.Status = er.Status
|
||
|
e.Err = fmt.Errorf(er.Message)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// NewError returns a new Error. If the given error implements the StatusCoder
|
||
|
// interface we will ignore the given status.
|
||
|
func NewError(status int, err error) error {
|
||
|
if sc, ok := err.(StatusCoder); ok {
|
||
|
return &Error{Status: sc.StatusCode(), Err: err}
|
||
|
}
|
||
|
cause := errors.Cause(err)
|
||
|
if sc, ok := cause.(StatusCoder); ok {
|
||
|
return &Error{Status: sc.StatusCode(), Err: err}
|
||
|
}
|
||
|
return &Error{Status: status, Err: err}
|
||
|
}
|
||
|
|
||
|
// InternalServerError returns a 500 error with the given error.
|
||
|
func InternalServerError(err error) error {
|
||
|
return NewError(http.StatusInternalServerError, err)
|
||
|
}
|
||
|
|
||
|
// BadRequest returns an 400 error with the given error.
|
||
|
func BadRequest(err error) error {
|
||
|
return NewError(http.StatusBadRequest, err)
|
||
|
}
|
||
|
|
||
|
// Unauthorized returns an 401 error with the given error.
|
||
|
func Unauthorized(err error) error {
|
||
|
return NewError(http.StatusUnauthorized, err)
|
||
|
}
|
||
|
|
||
|
// Forbidden returns an 403 error with the given error.
|
||
|
func Forbidden(err error) error {
|
||
|
return NewError(http.StatusForbidden, err)
|
||
|
}
|
||
|
|
||
|
// NotFound returns an 404 error with the given error.
|
||
|
func NotFound(err error) error {
|
||
|
return NewError(http.StatusNotFound, err)
|
||
|
}
|
||
|
|
||
|
// WriteError writes to w a JSON representation of the given error.
|
||
|
func WriteError(w http.ResponseWriter, err error) {
|
||
|
w.Header().Set("Content-Type", "application/json")
|
||
|
cause := errors.Cause(err)
|
||
|
if sc, ok := err.(StatusCoder); ok {
|
||
|
w.WriteHeader(sc.StatusCode())
|
||
|
} else {
|
||
|
if sc, ok := cause.(StatusCoder); ok {
|
||
|
w.WriteHeader(sc.StatusCode())
|
||
|
} else {
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Write errors in the response writer
|
||
|
if rl, ok := w.(logging.ResponseLogger); ok {
|
||
|
rl.WithFields(map[string]interface{}{
|
||
|
"error": err,
|
||
|
})
|
||
|
if os.Getenv("STEPDEBUG") == "1" {
|
||
|
if e, ok := err.(StackTracer); ok {
|
||
|
rl.WithFields(map[string]interface{}{
|
||
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||
|
})
|
||
|
} else {
|
||
|
if e, ok := cause.(StackTracer); ok {
|
||
|
rl.WithFields(map[string]interface{}{
|
||
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||
|
LogError(w, err)
|
||
|
}
|
||
|
}
|