From 2fc5a7f22e1f0c95ad8afcd07e795cba39ff643c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Sat, 27 Feb 2021 00:34:50 +0100 Subject: [PATCH] Improve SCEP API logic and error handling --- api/errors.go | 4 + scep/api/api.go | 215 +++++++++++++++++++++++++----------------------- scep/common.go | 11 ++- scep/errors.go | 19 +++++ 4 files changed, 146 insertions(+), 103 deletions(-) create mode 100644 scep/errors.go diff --git a/api/errors.go b/api/errors.go index 93057ed2..3e5dec47 100644 --- a/api/errors.go +++ b/api/errors.go @@ -10,6 +10,7 @@ import ( "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/scep" ) // WriteError writes to w a JSON representation of the given error. @@ -18,6 +19,9 @@ func WriteError(w http.ResponseWriter, err error) { case *acme.Error: w.Header().Set("Content-Type", "application/problem+json") err = k.ToACME() + case *scep.Error: + // TODO: check if this is correct; and should we do some more processing? + w.Header().Set("Content-Type", "text/plain") default: w.Header().Set("Content-Type", "application/json") } diff --git a/scep/api/api.go b/scep/api/api.go index 16278075..f2d11fb7 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -11,7 +11,6 @@ import ( "net/http" "strings" - "github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/scep" @@ -27,6 +26,30 @@ const ( // TODO: add other (more optional) operations and handling ) +const maxPayloadSize = 2 << 20 + +type nextHTTP = func(http.ResponseWriter, *http.Request) + +var ( + // TODO: check the default capabilities; https://tools.ietf.org/html/rfc8894#section-3.5.2 + // TODO: move capabilities to Authority or Provisioner, so that they can be configured? + defaultCapabilities = []string{ + "Renewal", + "SHA-1", + "SHA-256", + "AES", + "DES3", + "SCEPStandard", + "POSTPKIOperation", + } +) + +const ( + certChainHeader = "application/x-x509-ca-ra-cert" + leafHeader = "application/x-x509-ca-cert" + pkiOpHeader = "application/x-pki-message" +) + // SCEPRequest is a SCEP server request. type SCEPRequest struct { Operation string @@ -35,10 +58,10 @@ type SCEPRequest struct { // SCEPResponse is a SCEP server response. type SCEPResponse struct { - Operation string - CACertNum int - Data []byte - Err error + Operation string + CACertNum int + Data []byte + Certificate *x509.Certificate } // Handler is the SCEP request handler. @@ -64,60 +87,65 @@ func (h *Handler) Route(r api.Router) { } +// Get handles all SCEP GET requests func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { - scepRequest, err := decodeSCEPRequest(r) + request, err := decodeSCEPRequest(r) if err != nil { - fmt.Println(err) - fmt.Println("not a scep get request") - w.WriteHeader(500) + writeError(w, fmt.Errorf("not a scep get request: %w", err)) + return } - scepResponse := SCEPResponse{Operation: scepRequest.Operation} + ctx := r.Context() + var response SCEPResponse - switch scepRequest.Operation { + switch request.Operation { case opnGetCACert: - err := h.GetCACert(w, r, scepResponse) - if err != nil { - fmt.Println(err) - } - + response, err = h.GetCACert(ctx) case opnGetCACaps: - err := h.GetCACaps(w, r, scepResponse) - if err != nil { - fmt.Println(err) - } + response, err = h.GetCACaps(ctx) case opnPKIOperation: - + // TODO: implement the GET for PKI operation default: - + err = fmt.Errorf("unknown operation: %s", request.Operation) } + + if err != nil { + writeError(w, fmt.Errorf("get request failed: %w", err)) + return + } + + writeSCEPResponse(w, response) } +// Post handles all SCEP POST requests func (h *Handler) Post(w http.ResponseWriter, r *http.Request) { - scepRequest, err := decodeSCEPRequest(r) + request, err := decodeSCEPRequest(r) if err != nil { - fmt.Println(err) - fmt.Println("not a scep post request") - w.WriteHeader(500) + writeError(w, fmt.Errorf("not a scep post request: %w", err)) + return } - scepResponse := SCEPResponse{Operation: scepRequest.Operation} + ctx := r.Context() + var response SCEPResponse - switch scepRequest.Operation { + switch request.Operation { case opnPKIOperation: - err := h.PKIOperation(w, r, scepRequest, scepResponse) - if err != nil { - fmt.Println(err) - } + response, err = h.PKIOperation(ctx, request) default: - + err = fmt.Errorf("unknown operation: %s", request.Operation) } -} + if err != nil { + writeError(w, fmt.Errorf("post request failed: %w", err)) + return + } -const maxPayloadSize = 2 << 20 + api.LogCertificate(w, response.Certificate) + + writeSCEPResponse(w, response) +} func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { @@ -169,8 +197,6 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { } } -type nextHTTP = func(http.ResponseWriter, *http.Request) - // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { @@ -189,67 +215,69 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { p, err := h.Auth.LoadProvisionerByID("scep/" + provisionerID) if err != nil { - api.WriteError(w, err) + writeError(w, err) return } - scepProvisioner, ok := p.(*provisioner.SCEP) + provisioner, ok := p.(*provisioner.SCEP) if !ok { - api.WriteError(w, errors.New("provisioner must be of type SCEP")) + writeError(w, errors.New("provisioner must be of type SCEP")) return } ctx := r.Context() - ctx = context.WithValue(ctx, acme.ProvisionerContextKey, scep.Provisioner(scepProvisioner)) + ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(provisioner)) next(w, r.WithContext(ctx)) } } -func (h *Handler) GetCACert(w http.ResponseWriter, r *http.Request, scepResponse SCEPResponse) error { +// GetCACert returns the CA certificates in a SCEP response +func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) { certs, err := h.Auth.GetCACertificates() if err != nil { - return err + return SCEPResponse{}, err } if len(certs) == 0 { - scepResponse.CACertNum = 0 - scepResponse.Err = errors.New("missing CA Cert") - } else if len(certs) == 1 { - scepResponse.Data = certs[0].Raw - scepResponse.CACertNum = 1 - } else { - data, err := microscep.DegenerateCertificates(certs) - scepResponse.CACertNum = len(certs) - scepResponse.Data = data - scepResponse.Err = err + return SCEPResponse{}, errors.New("missing CA cert") } - return writeSCEPResponse(w, scepResponse) + response := SCEPResponse{Operation: opnGetCACert} + response.CACertNum = len(certs) + + if len(certs) == 1 { + response.Data = certs[0].Raw + } else { + data, err := microscep.DegenerateCertificates(certs) + if err != nil { + return SCEPResponse{}, err + } + response.Data = data + } + + return response, nil } -func (h *Handler) GetCACaps(w http.ResponseWriter, r *http.Request, scepResponse SCEPResponse) error { +// GetCACaps returns the CA capabilities in a SCEP response +func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) { - //ctx := r.Context() - - // _, err := ProvisionerFromContext(ctx) - // if err != nil { - // return err - // } + response := SCEPResponse{Operation: opnGetCACaps} // TODO: get the actual capabilities from provisioner config - scepResponse.Data = formatCapabilities(defaultCapabilities) + response.Data = formatCapabilities(defaultCapabilities) - return writeSCEPResponse(w, scepResponse) + return response, nil } -func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepRequest SCEPRequest, scepResponse SCEPResponse) error { +// PKIOperation performs PKI operations and returns a SCEP response +func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) { - ctx := r.Context() + response := SCEPResponse{Operation: opnPKIOperation} - microMsg, err := microscep.ParsePKIMessage(scepRequest.Message) + microMsg, err := microscep.ParsePKIMessage(request.Message) if err != nil { - return err + return SCEPResponse{}, err } msg := &scep.PKIMessage{ @@ -260,7 +288,7 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque } if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil { - return err + return SCEPResponse{}, err } if msg.MessageType == microscep.PKCSReq { @@ -271,7 +299,7 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque certRep, err := h.Auth.SignCSR(ctx, csr, msg) if err != nil { - return err + return SCEPResponse{}, err } // //cert := certRep.CertRepMessage.Certificate @@ -280,11 +308,10 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque // // TODO: check if CN already exists, if renewal is allowed and if existing should be revoked; fail if not // // TODO: store the new cert for CN locally; should go into the DB - scepResponse.Data = certRep.Raw + response.Data = certRep.Raw + response.Certificate = certRep.Certificate - api.LogCertificate(w, certRep.Certificate) - - return writeSCEPResponse(w, scepResponse) + return response, nil } func certName(cert *x509.Certificate) string { @@ -299,40 +326,26 @@ func formatCapabilities(caps []string) []byte { } // writeSCEPResponse writes a SCEP response back to the SCEP client. -func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) error { - if response.Err != nil { - http.Error(w, response.Err.Error(), http.StatusInternalServerError) - return nil +func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) { + w.Header().Set("Content-Type", contentHeader(response)) + _, err := w.Write(response.Data) + if err != nil { + writeError(w, fmt.Errorf("error when writing scep response: %w", err)) // This could end up as an error again } - w.Header().Set("Content-Type", contentHeader(response.Operation, response.CACertNum)) - w.Write(response.Data) - return nil } -var ( - // TODO: check the default capabilities; https://tools.ietf.org/html/rfc8894#section-3.5.2 - // TODO: move capabilities to Authority or Provisioner, so that they can be configured? - defaultCapabilities = []string{ - "Renewal", - "SHA-1", - "SHA-256", - "AES", - "DES3", - "SCEPStandard", - "POSTPKIOperation", +func writeError(w http.ResponseWriter, err error) { + scepError := &scep.Error{ + Err: fmt.Errorf("post request failed: %w", err), + Status: http.StatusInternalServerError, // TODO: make this a param? } -) + api.WriteError(w, scepError) +} -const ( - certChainHeader = "application/x-x509-ca-ra-cert" - leafHeader = "application/x-x509-ca-cert" - pkiOpHeader = "application/x-pki-message" -) - -func contentHeader(operation string, certNum int) string { - switch operation { +func contentHeader(r SCEPResponse) string { + switch r.Operation { case opnGetCACert: - if certNum > 1 { + if r.CACertNum > 1 { return certChainHeader } return leafHeader diff --git a/scep/common.go b/scep/common.go index 123c8d82..ca87841f 100644 --- a/scep/common.go +++ b/scep/common.go @@ -3,14 +3,21 @@ package scep import ( "context" "errors" +) - "github.com/smallstep/certificates/acme" +// ContextKey is the key type for storing and searching for SCEP request +// essentials in the context of a request. +type ContextKey string + +const ( + // ProvisionerContextKey provisioner key + ProvisionerContextKey = ContextKey("provisioner") ) // ProvisionerFromContext searches the context for a SCEP provisioner. // Returns the provisioner or an error. func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { - val := ctx.Value(acme.ProvisionerContextKey) + val := ctx.Value(ProvisionerContextKey) if val == nil { return nil, errors.New("provisioner expected in request context") } diff --git a/scep/errors.go b/scep/errors.go new file mode 100644 index 00000000..52fff8ae --- /dev/null +++ b/scep/errors.go @@ -0,0 +1,19 @@ +package scep + +// Error is an SCEP error type +type Error struct { + // Type ProbType + // Detail string + Err error + Status int + // Sub []*Error + // Identifier *Identifier +} + +// Error implements the error interface. +func (e *Error) Error() string { + // if e.Err == nil { + // return e.Detail + // } + return e.Err.Error() +}