From bca74cb6a734794284439a63219d6ece301b9a82 Mon Sep 17 00:00:00 2001 From: Panagiotis Siatras Date: Thu, 24 Mar 2022 14:58:50 +0200 Subject: [PATCH] scep: minor cleanup (#867) * api, scep: removed scep.Error * scep/api: replaced nextHTTP with http.HandlerFunc * scep/api: renamed writeSCEPResponse to writeResponse * scep/api: renamed decodeSCEPRequest to decodeRequest * scep/api: renamed writeError to fail * scep/api: replaced pkg/errors with errors * scep/api: formatted imports * scep/api: do not export SCEPRequest & SCEPResponse * scep/api: do not export Handler * api: flush errors better --- api/errors.go | 24 +++----- scep/api/api.go | 158 +++++++++++++++++++++++------------------------- scep/errors.go | 12 ---- 3 files changed, 86 insertions(+), 108 deletions(-) delete mode 100644 scep/errors.go diff --git a/api/errors.go b/api/errors.go index 49efd486..680e6578 100644 --- a/api/errors.go +++ b/api/errors.go @@ -13,7 +13,6 @@ import ( "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/scep" ) // WriteError writes to w a JSON representation of the given error. @@ -25,22 +24,9 @@ func WriteError(w http.ResponseWriter, err error) { case *admin.Error: admin.WriteError(w, k) return - case *scep.Error: - w.Header().Set("Content-Type", "text/plain") - default: - w.Header().Set("Content-Type", "application/json") } cause := errors.Cause(err) - if sc, ok := err.(errs.StatusCoder); ok { - w.WriteHeader(sc.StatusCode()) - } else { - if sc, ok := cause.(errs.StatusCoder); ok { - w.WriteHeader(sc.StatusCode()) - } else { - w.WriteHeader(http.StatusInternalServerError) - } - } // Write errors in the response writer if rl, ok := w.(logging.ResponseLogger); ok { @@ -60,6 +46,16 @@ func WriteError(w http.ResponseWriter, err error) { } } + code := http.StatusInternalServerError + if sc, ok := err.(errs.StatusCoder); ok { + code = sc.StatusCode() + } else if sc, ok := cause.(errs.StatusCoder); ok { + code = sc.StatusCode() + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + if err := json.NewEncoder(w).Encode(err); err != nil { log.Error(w, err) } diff --git a/scep/api/api.go b/scep/api/api.go index a326ea92..91d337fe 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -4,6 +4,8 @@ import ( "context" "crypto/x509" "encoding/base64" + "errors" + "fmt" "io" "net/http" "net/url" @@ -11,7 +13,6 @@ import ( "github.com/go-chi/chi" microscep "github.com/micromdm/scep/v2/scep" - "github.com/pkg/errors" "go.mozilla.org/pkcs7" "github.com/smallstep/certificates/api" @@ -30,22 +31,20 @@ const ( const maxPayloadSize = 2 << 20 -type nextHTTP = func(http.ResponseWriter, *http.Request) - const ( certChainHeader = "application/x-x509-ca-ra-cert" leafHeader = "application/x-x509-ca-cert" pkiOperationHeader = "application/x-pki-message" ) -// SCEPRequest is a SCEP server request. -type SCEPRequest struct { +// request is a SCEP server request. +type request struct { Operation string Message []byte } -// SCEPResponse is a SCEP server response. -type SCEPResponse struct { +// response is a SCEP server response. +type response struct { Operation string CACertNum int Data []byte @@ -53,18 +52,18 @@ type SCEPResponse struct { Error error } -// Handler is the SCEP request handler. -type Handler struct { +// handler is the SCEP request handler. +type handler struct { Auth scep.Interface } // New returns a new SCEP API router. func New(scepAuth scep.Interface) api.RouterHandler { - return &Handler{scepAuth} + return &handler{scepAuth} } // Route traffic and implement the Router interface. -func (h *Handler) Route(r api.Router) { +func (h *handler) Route(r api.Router) { getLink := h.Auth.GetLinkExplicit r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) @@ -73,64 +72,64 @@ func (h *Handler) Route(r api.Router) { } // Get handles all SCEP GET requests -func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { +func (h *handler) Get(w http.ResponseWriter, r *http.Request) { - request, err := decodeSCEPRequest(r) + req, err := decodeRequest(r) if err != nil { - writeError(w, errors.Wrap(err, "invalid scep get request")) + fail(w, fmt.Errorf("invalid scep get request: %w", err)) return } ctx := r.Context() - var response SCEPResponse + var res response - switch request.Operation { + switch req.Operation { case opnGetCACert: - response, err = h.GetCACert(ctx) + res, err = h.GetCACert(ctx) case opnGetCACaps: - response, err = h.GetCACaps(ctx) + res, err = h.GetCACaps(ctx) case opnPKIOperation: // TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though default: - err = errors.Errorf("unknown operation: %s", request.Operation) + err = fmt.Errorf("unknown operation: %s", req.Operation) } if err != nil { - writeError(w, errors.Wrap(err, "scep get request failed")) + fail(w, fmt.Errorf("scep get request failed: %w", err)) return } - writeSCEPResponse(w, response) + writeResponse(w, res) } // Post handles all SCEP POST requests -func (h *Handler) Post(w http.ResponseWriter, r *http.Request) { +func (h *handler) Post(w http.ResponseWriter, r *http.Request) { - request, err := decodeSCEPRequest(r) + req, err := decodeRequest(r) if err != nil { - writeError(w, errors.Wrap(err, "invalid scep post request")) + fail(w, fmt.Errorf("invalid scep post request: %w", err)) return } ctx := r.Context() - var response SCEPResponse + var res response - switch request.Operation { + switch req.Operation { case opnPKIOperation: - response, err = h.PKIOperation(ctx, request) + res, err = h.PKIOperation(ctx, req) default: - err = errors.Errorf("unknown operation: %s", request.Operation) + err = fmt.Errorf("unknown operation: %s", req.Operation) } if err != nil { - writeError(w, errors.Wrap(err, "scep post request failed")) + fail(w, fmt.Errorf("scep post request failed: %w", err)) return } - writeSCEPResponse(w, response) + writeResponse(w, res) } -func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { +func decodeRequest(r *http.Request) (request, error) { defer r.Body.Close() @@ -146,7 +145,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { case http.MethodGet: switch operation { case opnGetCACert, opnGetCACaps: - return SCEPRequest{ + return request{ Operation: operation, Message: []byte{}, }, nil @@ -158,50 +157,50 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { // TODO: verify this; it seems like it should be StdEncoding instead of URLEncoding decodedMessage, err := base64.URLEncoding.DecodeString(message) if err != nil { - return SCEPRequest{}, err + return request{}, err } - return SCEPRequest{ + return request{ Operation: operation, Message: decodedMessage, }, nil default: - return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation) + return request{}, fmt.Errorf("unsupported operation: %s", operation) } case http.MethodPost: body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) if err != nil { - return SCEPRequest{}, err + return request{}, err } - return SCEPRequest{ + return request{ Operation: operation, Message: body, }, nil default: - return SCEPRequest{}, errors.Errorf("unsupported method: %s", method) + return request{}, fmt.Errorf("unsupported method: %s", method) } } // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. -func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { +func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "provisionerName") provisionerName, err := url.PathUnescape(name) if err != nil { - api.WriteError(w, errors.Errorf("error url unescaping provisioner name '%s'", name)) + fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name)) return } p, err := h.Auth.LoadProvisionerByName(provisionerName) if err != nil { - api.WriteError(w, err) + fail(w, err) return } prov, ok := p.(*provisioner.SCEP) if !ok { - api.WriteError(w, errors.New("provisioner must be of type SCEP")) + fail(w, errors.New("provisioner must be of type SCEP")) return } @@ -212,59 +211,59 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { } // GetCACert returns the CA certificates in a SCEP response -func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) { +func (h *handler) GetCACert(ctx context.Context) (response, error) { certs, err := h.Auth.GetCACertificates(ctx) if err != nil { - return SCEPResponse{}, err + return response{}, err } if len(certs) == 0 { - return SCEPResponse{}, errors.New("missing CA cert") + return response{}, errors.New("missing CA cert") } - response := SCEPResponse{ + res := response{ Operation: opnGetCACert, CACertNum: len(certs), } if len(certs) == 1 { - response.Data = certs[0].Raw + res.Data = certs[0].Raw } else { // create degenerate pkcs7 certificate structure, according to // https://tools.ietf.org/html/rfc8894#section-4.2.1.2, because // not signed or encrypted data has to be returned. data, err := microscep.DegenerateCertificates(certs) if err != nil { - return SCEPResponse{}, err + return response{}, err } - response.Data = data + res.Data = data } - return response, nil + return res, nil } // GetCACaps returns the CA capabilities in a SCEP response -func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) { +func (h *handler) GetCACaps(ctx context.Context) (response, error) { caps := h.Auth.GetCACaps(ctx) - response := SCEPResponse{ + res := response{ Operation: opnGetCACaps, Data: formatCapabilities(caps), } - return response, nil + return res, nil } // PKIOperation performs PKI operations and returns a SCEP response -func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) { +func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) { // parse the message using microscep implementation - microMsg, err := microscep.ParsePKIMessage(request.Message) + microMsg, err := microscep.ParsePKIMessage(req.Message) if err != nil { // return the error, because we can't use the msg for creating a CertRep - return SCEPResponse{}, err + return response{}, err } // this is essentially doing the same as microscep.ParsePKIMessage, but @@ -272,7 +271,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe // wrapper for the microscep implementation. p7, err := pkcs7.Parse(microMsg.Raw) if err != nil { - return SCEPResponse{}, err + return response{}, err } // copy over properties to our internal PKIMessage @@ -285,7 +284,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe } if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil { - return SCEPResponse{}, err + return response{}, err } // NOTE: at this point we have sufficient information for returning nicely signed CertReps @@ -317,61 +316,56 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe certRep, err := h.Auth.SignCSR(ctx, csr, msg) if err != nil { - return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.Wrap(err, "error when signing new certificate")) + return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err)) } - response := SCEPResponse{ + res := response{ Operation: opnPKIOperation, Data: certRep.Raw, Certificate: certRep.Certificate, } - return response, nil + return res, nil } func formatCapabilities(caps []string) []byte { return []byte(strings.Join(caps, "\r\n")) } -// writeSCEPResponse writes a SCEP response back to the SCEP client. -func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) { +// writeResponse writes a SCEP response back to the SCEP client. +func writeResponse(w http.ResponseWriter, res response) { - if response.Error != nil { - log.Error(w, response.Error) + if res.Error != nil { + log.Error(w, res.Error) } - if response.Certificate != nil { - api.LogCertificate(w, response.Certificate) + if res.Certificate != nil { + api.LogCertificate(w, res.Certificate) } - w.Header().Set("Content-Type", contentHeader(response)) - _, err := w.Write(response.Data) - if err != nil { - writeError(w, errors.Wrap(err, "error when writing scep response")) // This could end up as an error again - } + w.Header().Set("Content-Type", contentHeader(res)) + _, _ = w.Write(res.Data) } -func writeError(w http.ResponseWriter, err error) { - scepError := &scep.Error{ - Message: err.Error(), - Status: http.StatusInternalServerError, // TODO: make this a param? - } - api.WriteError(w, scepError) +func fail(w http.ResponseWriter, err error) { + log.Error(w, err) + + http.Error(w, err.Error(), http.StatusInternalServerError) } -func (h *Handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (SCEPResponse, error) { +func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) { certRepMsg, err := h.Auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) if err != nil { - return SCEPResponse{}, err + return response{}, err } - return SCEPResponse{ + return response{ Operation: opnPKIOperation, Data: certRepMsg.Raw, Error: failError, }, nil } -func contentHeader(r SCEPResponse) string { +func contentHeader(r response) string { switch r.Operation { case opnGetCACert: if r.CACertNum > 1 { diff --git a/scep/errors.go b/scep/errors.go deleted file mode 100644 index 4287403b..00000000 --- a/scep/errors.go +++ /dev/null @@ -1,12 +0,0 @@ -package scep - -// Error is an SCEP error type -type Error struct { - Message string `json:"message"` - Status int `json:"-"` -} - -// Error implements the error interface. -func (e *Error) Error() string { - return e.Message -}