From a191319da9c8f36249bf343e6e048a0bec88f78c Mon Sep 17 00:00:00 2001
From: Herman Slatman <hermanslatman@hotmail.com>
Date: Sat, 27 Feb 2021 00:34:50 +0100
Subject: [PATCH] Improve SCEP API logic and error handling

---
 api/errors.go   |   3 +
 scep/api/api.go | 215 +++++++++++++++++++++++++-----------------------
 scep/common.go  |  11 ++-
 scep/errors.go  |  19 +++++
 4 files changed, 145 insertions(+), 103 deletions(-)
 create mode 100644 scep/errors.go

diff --git a/api/errors.go b/api/errors.go
index 085d05cf..f9bcb199 100644
--- a/api/errors.go
+++ b/api/errors.go
@@ -11,6 +11,7 @@ import (
 	"github.com/smallstep/certificates/authority/mgmt"
 	"github.com/smallstep/certificates/errs"
 	"github.com/smallstep/certificates/logging"
+	"github.com/smallstep/certificates/scep"
 )
 
 // WriteError writes to w a JSON representation of the given error.
@@ -22,6 +23,8 @@ func WriteError(w http.ResponseWriter, err error) {
 	case *mgmt.Error:
 		mgmt.WriteError(w, k)
 		return
+	case *scep.Error:
+		w.Header().Set("Content-Type", "text/plain")
 	default:
 		w.Header().Set("Content-Type", "application/json")
 	}
diff --git a/scep/api/api.go b/scep/api/api.go
index 16278075..f2d11fb7 100644
--- a/scep/api/api.go
+++ b/scep/api/api.go
@@ -11,7 +11,6 @@ import (
 	"net/http"
 	"strings"
 
-	"github.com/smallstep/certificates/acme"
 	"github.com/smallstep/certificates/api"
 	"github.com/smallstep/certificates/authority/provisioner"
 	"github.com/smallstep/certificates/scep"
@@ -27,6 +26,30 @@ const (
 	// TODO: add other (more optional) operations and handling
 )
 
+const maxPayloadSize = 2 << 20
+
+type nextHTTP = func(http.ResponseWriter, *http.Request)
+
+var (
+	// TODO: check the default capabilities; https://tools.ietf.org/html/rfc8894#section-3.5.2
+	// TODO: move capabilities to Authority or Provisioner, so that they can be configured?
+	defaultCapabilities = []string{
+		"Renewal",
+		"SHA-1",
+		"SHA-256",
+		"AES",
+		"DES3",
+		"SCEPStandard",
+		"POSTPKIOperation",
+	}
+)
+
+const (
+	certChainHeader = "application/x-x509-ca-ra-cert"
+	leafHeader      = "application/x-x509-ca-cert"
+	pkiOpHeader     = "application/x-pki-message"
+)
+
 // SCEPRequest is a SCEP server request.
 type SCEPRequest struct {
 	Operation string
@@ -35,10 +58,10 @@ type SCEPRequest struct {
 
 // SCEPResponse is a SCEP server response.
 type SCEPResponse struct {
-	Operation string
-	CACertNum int
-	Data      []byte
-	Err       error
+	Operation   string
+	CACertNum   int
+	Data        []byte
+	Certificate *x509.Certificate
 }
 
 // Handler is the SCEP request handler.
@@ -64,60 +87,65 @@ func (h *Handler) Route(r api.Router) {
 
 }
 
+// Get handles all SCEP GET requests
 func (h *Handler) Get(w http.ResponseWriter, r *http.Request) {
 
-	scepRequest, err := decodeSCEPRequest(r)
+	request, err := decodeSCEPRequest(r)
 	if err != nil {
-		fmt.Println(err)
-		fmt.Println("not a scep get request")
-		w.WriteHeader(500)
+		writeError(w, fmt.Errorf("not a scep get request: %w", err))
+		return
 	}
 
-	scepResponse := SCEPResponse{Operation: scepRequest.Operation}
+	ctx := r.Context()
+	var response SCEPResponse
 
-	switch scepRequest.Operation {
+	switch request.Operation {
 	case opnGetCACert:
-		err := h.GetCACert(w, r, scepResponse)
-		if err != nil {
-			fmt.Println(err)
-		}
-
+		response, err = h.GetCACert(ctx)
 	case opnGetCACaps:
-		err := h.GetCACaps(w, r, scepResponse)
-		if err != nil {
-			fmt.Println(err)
-		}
+		response, err = h.GetCACaps(ctx)
 	case opnPKIOperation:
-
+		// TODO: implement the GET for PKI operation
 	default:
-
+		err = fmt.Errorf("unknown operation: %s", request.Operation)
 	}
+
+	if err != nil {
+		writeError(w, fmt.Errorf("get request failed: %w", err))
+		return
+	}
+
+	writeSCEPResponse(w, response)
 }
 
+// Post handles all SCEP POST requests
 func (h *Handler) Post(w http.ResponseWriter, r *http.Request) {
 
-	scepRequest, err := decodeSCEPRequest(r)
+	request, err := decodeSCEPRequest(r)
 	if err != nil {
-		fmt.Println(err)
-		fmt.Println("not a scep post request")
-		w.WriteHeader(500)
+		writeError(w, fmt.Errorf("not a scep post request: %w", err))
+		return
 	}
 
-	scepResponse := SCEPResponse{Operation: scepRequest.Operation}
+	ctx := r.Context()
+	var response SCEPResponse
 
-	switch scepRequest.Operation {
+	switch request.Operation {
 	case opnPKIOperation:
-		err := h.PKIOperation(w, r, scepRequest, scepResponse)
-		if err != nil {
-			fmt.Println(err)
-		}
+		response, err = h.PKIOperation(ctx, request)
 	default:
-
+		err = fmt.Errorf("unknown operation: %s", request.Operation)
 	}
 
-}
+	if err != nil {
+		writeError(w, fmt.Errorf("post request failed: %w", err))
+		return
+	}
 
-const maxPayloadSize = 2 << 20
+	api.LogCertificate(w, response.Certificate)
+
+	writeSCEPResponse(w, response)
+}
 
 func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
 
@@ -169,8 +197,6 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
 	}
 }
 
-type nextHTTP = func(http.ResponseWriter, *http.Request)
-
 // lookupProvisioner loads the provisioner associated with the request.
 // Responds 404 if the provisioner does not exist.
 func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
@@ -189,67 +215,69 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
 
 		p, err := h.Auth.LoadProvisionerByID("scep/" + provisionerID)
 		if err != nil {
-			api.WriteError(w, err)
+			writeError(w, err)
 			return
 		}
 
-		scepProvisioner, ok := p.(*provisioner.SCEP)
+		provisioner, ok := p.(*provisioner.SCEP)
 		if !ok {
-			api.WriteError(w, errors.New("provisioner must be of type SCEP"))
+			writeError(w, errors.New("provisioner must be of type SCEP"))
 			return
 		}
 
 		ctx := r.Context()
-		ctx = context.WithValue(ctx, acme.ProvisionerContextKey, scep.Provisioner(scepProvisioner))
+		ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(provisioner))
 		next(w, r.WithContext(ctx))
 	}
 }
 
-func (h *Handler) GetCACert(w http.ResponseWriter, r *http.Request, scepResponse SCEPResponse) error {
+// GetCACert returns the CA certificates in a SCEP response
+func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) {
 
 	certs, err := h.Auth.GetCACertificates()
 	if err != nil {
-		return err
+		return SCEPResponse{}, err
 	}
 
 	if len(certs) == 0 {
-		scepResponse.CACertNum = 0
-		scepResponse.Err = errors.New("missing CA Cert")
-	} else if len(certs) == 1 {
-		scepResponse.Data = certs[0].Raw
-		scepResponse.CACertNum = 1
-	} else {
-		data, err := microscep.DegenerateCertificates(certs)
-		scepResponse.CACertNum = len(certs)
-		scepResponse.Data = data
-		scepResponse.Err = err
+		return SCEPResponse{}, errors.New("missing CA cert")
 	}
 
-	return writeSCEPResponse(w, scepResponse)
+	response := SCEPResponse{Operation: opnGetCACert}
+	response.CACertNum = len(certs)
+
+	if len(certs) == 1 {
+		response.Data = certs[0].Raw
+	} else {
+		data, err := microscep.DegenerateCertificates(certs)
+		if err != nil {
+			return SCEPResponse{}, err
+		}
+		response.Data = data
+	}
+
+	return response, nil
 }
 
-func (h *Handler) GetCACaps(w http.ResponseWriter, r *http.Request, scepResponse SCEPResponse) error {
+// GetCACaps returns the CA capabilities in a SCEP response
+func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) {
 
-	//ctx := r.Context()
-
-	// _, err := ProvisionerFromContext(ctx)
-	// if err != nil {
-	// 	return err
-	// }
+	response := SCEPResponse{Operation: opnGetCACaps}
 
 	// TODO: get the actual capabilities from provisioner config
-	scepResponse.Data = formatCapabilities(defaultCapabilities)
+	response.Data = formatCapabilities(defaultCapabilities)
 
-	return writeSCEPResponse(w, scepResponse)
+	return response, nil
 }
 
-func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepRequest SCEPRequest, scepResponse SCEPResponse) error {
+// PKIOperation performs PKI operations and returns a SCEP response
+func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) {
 
-	ctx := r.Context()
+	response := SCEPResponse{Operation: opnPKIOperation}
 
-	microMsg, err := microscep.ParsePKIMessage(scepRequest.Message)
+	microMsg, err := microscep.ParsePKIMessage(request.Message)
 	if err != nil {
-		return err
+		return SCEPResponse{}, err
 	}
 
 	msg := &scep.PKIMessage{
@@ -260,7 +288,7 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque
 	}
 
 	if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil {
-		return err
+		return SCEPResponse{}, err
 	}
 
 	if msg.MessageType == microscep.PKCSReq {
@@ -271,7 +299,7 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque
 
 	certRep, err := h.Auth.SignCSR(ctx, csr, msg)
 	if err != nil {
-		return err
+		return SCEPResponse{}, err
 	}
 
 	// //cert := certRep.CertRepMessage.Certificate
@@ -280,11 +308,10 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque
 	// // TODO: check if CN already exists, if renewal is allowed and if existing should be revoked; fail if not
 	// // TODO: store the new cert for CN locally; should go into the DB
 
-	scepResponse.Data = certRep.Raw
+	response.Data = certRep.Raw
+	response.Certificate = certRep.Certificate
 
-	api.LogCertificate(w, certRep.Certificate)
-
-	return writeSCEPResponse(w, scepResponse)
+	return response, nil
 }
 
 func certName(cert *x509.Certificate) string {
@@ -299,40 +326,26 @@ func formatCapabilities(caps []string) []byte {
 }
 
 // writeSCEPResponse writes a SCEP response back to the SCEP client.
-func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) error {
-	if response.Err != nil {
-		http.Error(w, response.Err.Error(), http.StatusInternalServerError)
-		return nil
+func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) {
+	w.Header().Set("Content-Type", contentHeader(response))
+	_, err := w.Write(response.Data)
+	if err != nil {
+		writeError(w, fmt.Errorf("error when writing scep response: %w", err)) // This could end up as an error again
 	}
-	w.Header().Set("Content-Type", contentHeader(response.Operation, response.CACertNum))
-	w.Write(response.Data)
-	return nil
 }
 
-var (
-	// TODO: check the default capabilities; https://tools.ietf.org/html/rfc8894#section-3.5.2
-	// TODO: move capabilities to Authority or Provisioner, so that they can be configured?
-	defaultCapabilities = []string{
-		"Renewal",
-		"SHA-1",
-		"SHA-256",
-		"AES",
-		"DES3",
-		"SCEPStandard",
-		"POSTPKIOperation",
+func writeError(w http.ResponseWriter, err error) {
+	scepError := &scep.Error{
+		Err:    fmt.Errorf("post request failed: %w", err),
+		Status: http.StatusInternalServerError, // TODO: make this a param?
 	}
-)
+	api.WriteError(w, scepError)
+}
 
-const (
-	certChainHeader = "application/x-x509-ca-ra-cert"
-	leafHeader      = "application/x-x509-ca-cert"
-	pkiOpHeader     = "application/x-pki-message"
-)
-
-func contentHeader(operation string, certNum int) string {
-	switch operation {
+func contentHeader(r SCEPResponse) string {
+	switch r.Operation {
 	case opnGetCACert:
-		if certNum > 1 {
+		if r.CACertNum > 1 {
 			return certChainHeader
 		}
 		return leafHeader
diff --git a/scep/common.go b/scep/common.go
index 123c8d82..ca87841f 100644
--- a/scep/common.go
+++ b/scep/common.go
@@ -3,14 +3,21 @@ package scep
 import (
 	"context"
 	"errors"
+)
 
-	"github.com/smallstep/certificates/acme"
+// ContextKey is the key type for storing and searching for SCEP request
+// essentials in the context of a request.
+type ContextKey string
+
+const (
+	// ProvisionerContextKey provisioner key
+	ProvisionerContextKey = ContextKey("provisioner")
 )
 
 // ProvisionerFromContext searches the context for a SCEP provisioner.
 // Returns the provisioner or an error.
 func ProvisionerFromContext(ctx context.Context) (Provisioner, error) {
-	val := ctx.Value(acme.ProvisionerContextKey)
+	val := ctx.Value(ProvisionerContextKey)
 	if val == nil {
 		return nil, errors.New("provisioner expected in request context")
 	}
diff --git a/scep/errors.go b/scep/errors.go
new file mode 100644
index 00000000..52fff8ae
--- /dev/null
+++ b/scep/errors.go
@@ -0,0 +1,19 @@
+package scep
+
+// Error is an SCEP error type
+type Error struct {
+	// Type       ProbType
+	// Detail string
+	Err    error
+	Status int
+	// Sub    []*Error
+	// Identifier *Identifier
+}
+
+// Error implements the error interface.
+func (e *Error) Error() string {
+	// if e.Err == nil {
+	// 	return e.Detail
+	// }
+	return e.Err.Error()
+}