Improve SCEP API logic and error handling

This commit is contained in:
Herman Slatman 2021-02-27 00:34:50 +01:00 committed by max furman
parent 30d3a26c20
commit a191319da9
4 changed files with 145 additions and 103 deletions

View file

@ -11,6 +11,7 @@ import (
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/scep"
) )
// WriteError writes to w a JSON representation of the given error. // WriteError writes to w a JSON representation of the given error.
@ -22,6 +23,8 @@ func WriteError(w http.ResponseWriter, err error) {
case *mgmt.Error: case *mgmt.Error:
mgmt.WriteError(w, k) mgmt.WriteError(w, k)
return return
case *scep.Error:
w.Header().Set("Content-Type", "text/plain")
default: default:
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
} }

View file

@ -11,7 +11,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/scep"
@ -27,6 +26,30 @@ const (
// TODO: add other (more optional) operations and handling // TODO: add other (more optional) operations and handling
) )
const maxPayloadSize = 2 << 20
type nextHTTP = func(http.ResponseWriter, *http.Request)
var (
// TODO: check the default capabilities; https://tools.ietf.org/html/rfc8894#section-3.5.2
// TODO: move capabilities to Authority or Provisioner, so that they can be configured?
defaultCapabilities = []string{
"Renewal",
"SHA-1",
"SHA-256",
"AES",
"DES3",
"SCEPStandard",
"POSTPKIOperation",
}
)
const (
certChainHeader = "application/x-x509-ca-ra-cert"
leafHeader = "application/x-x509-ca-cert"
pkiOpHeader = "application/x-pki-message"
)
// SCEPRequest is a SCEP server request. // SCEPRequest is a SCEP server request.
type SCEPRequest struct { type SCEPRequest struct {
Operation string Operation string
@ -38,7 +61,7 @@ type SCEPResponse struct {
Operation string Operation string
CACertNum int CACertNum int
Data []byte Data []byte
Err error Certificate *x509.Certificate
} }
// Handler is the SCEP request handler. // Handler is the SCEP request handler.
@ -64,60 +87,65 @@ func (h *Handler) Route(r api.Router) {
} }
// Get handles all SCEP GET requests
func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { func (h *Handler) Get(w http.ResponseWriter, r *http.Request) {
scepRequest, err := decodeSCEPRequest(r) request, err := decodeSCEPRequest(r)
if err != nil { if err != nil {
fmt.Println(err) writeError(w, fmt.Errorf("not a scep get request: %w", err))
fmt.Println("not a scep get request") return
w.WriteHeader(500)
} }
scepResponse := SCEPResponse{Operation: scepRequest.Operation} ctx := r.Context()
var response SCEPResponse
switch scepRequest.Operation { switch request.Operation {
case opnGetCACert: case opnGetCACert:
err := h.GetCACert(w, r, scepResponse) response, err = h.GetCACert(ctx)
if err != nil {
fmt.Println(err)
}
case opnGetCACaps: case opnGetCACaps:
err := h.GetCACaps(w, r, scepResponse) response, err = h.GetCACaps(ctx)
if err != nil {
fmt.Println(err)
}
case opnPKIOperation: case opnPKIOperation:
// TODO: implement the GET for PKI operation
default: default:
err = fmt.Errorf("unknown operation: %s", request.Operation)
} }
if err != nil {
writeError(w, fmt.Errorf("get request failed: %w", err))
return
}
writeSCEPResponse(w, response)
} }
// Post handles all SCEP POST requests
func (h *Handler) Post(w http.ResponseWriter, r *http.Request) { func (h *Handler) Post(w http.ResponseWriter, r *http.Request) {
scepRequest, err := decodeSCEPRequest(r) request, err := decodeSCEPRequest(r)
if err != nil { if err != nil {
fmt.Println(err) writeError(w, fmt.Errorf("not a scep post request: %w", err))
fmt.Println("not a scep post request") return
w.WriteHeader(500)
} }
scepResponse := SCEPResponse{Operation: scepRequest.Operation} ctx := r.Context()
var response SCEPResponse
switch scepRequest.Operation { switch request.Operation {
case opnPKIOperation: case opnPKIOperation:
err := h.PKIOperation(w, r, scepRequest, scepResponse) response, err = h.PKIOperation(ctx, request)
if err != nil {
fmt.Println(err)
}
default: default:
err = fmt.Errorf("unknown operation: %s", request.Operation)
} }
} if err != nil {
writeError(w, fmt.Errorf("post request failed: %w", err))
return
}
const maxPayloadSize = 2 << 20 api.LogCertificate(w, response.Certificate)
writeSCEPResponse(w, response)
}
func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
@ -169,8 +197,6 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
} }
} }
type nextHTTP = func(http.ResponseWriter, *http.Request)
// lookupProvisioner loads the provisioner associated with the request. // lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist. // Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
@ -189,67 +215,69 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
p, err := h.Auth.LoadProvisionerByID("scep/" + provisionerID) p, err := h.Auth.LoadProvisionerByID("scep/" + provisionerID)
if err != nil { if err != nil {
api.WriteError(w, err) writeError(w, err)
return return
} }
scepProvisioner, ok := p.(*provisioner.SCEP) provisioner, ok := p.(*provisioner.SCEP)
if !ok { if !ok {
api.WriteError(w, errors.New("provisioner must be of type SCEP")) writeError(w, errors.New("provisioner must be of type SCEP"))
return return
} }
ctx := r.Context() ctx := r.Context()
ctx = context.WithValue(ctx, acme.ProvisionerContextKey, scep.Provisioner(scepProvisioner)) ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(provisioner))
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
func (h *Handler) GetCACert(w http.ResponseWriter, r *http.Request, scepResponse SCEPResponse) error { // GetCACert returns the CA certificates in a SCEP response
func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) {
certs, err := h.Auth.GetCACertificates() certs, err := h.Auth.GetCACertificates()
if err != nil { if err != nil {
return err return SCEPResponse{}, err
} }
if len(certs) == 0 { if len(certs) == 0 {
scepResponse.CACertNum = 0 return SCEPResponse{}, errors.New("missing CA cert")
scepResponse.Err = errors.New("missing CA Cert")
} else if len(certs) == 1 {
scepResponse.Data = certs[0].Raw
scepResponse.CACertNum = 1
} else {
data, err := microscep.DegenerateCertificates(certs)
scepResponse.CACertNum = len(certs)
scepResponse.Data = data
scepResponse.Err = err
} }
return writeSCEPResponse(w, scepResponse) response := SCEPResponse{Operation: opnGetCACert}
response.CACertNum = len(certs)
if len(certs) == 1 {
response.Data = certs[0].Raw
} else {
data, err := microscep.DegenerateCertificates(certs)
if err != nil {
return SCEPResponse{}, err
}
response.Data = data
}
return response, nil
} }
func (h *Handler) GetCACaps(w http.ResponseWriter, r *http.Request, scepResponse SCEPResponse) error { // GetCACaps returns the CA capabilities in a SCEP response
func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) {
//ctx := r.Context() response := SCEPResponse{Operation: opnGetCACaps}
// _, err := ProvisionerFromContext(ctx)
// if err != nil {
// return err
// }
// TODO: get the actual capabilities from provisioner config // TODO: get the actual capabilities from provisioner config
scepResponse.Data = formatCapabilities(defaultCapabilities) response.Data = formatCapabilities(defaultCapabilities)
return writeSCEPResponse(w, scepResponse) return response, nil
} }
func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepRequest SCEPRequest, scepResponse SCEPResponse) error { // PKIOperation performs PKI operations and returns a SCEP response
func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) {
ctx := r.Context() response := SCEPResponse{Operation: opnPKIOperation}
microMsg, err := microscep.ParsePKIMessage(scepRequest.Message) microMsg, err := microscep.ParsePKIMessage(request.Message)
if err != nil { if err != nil {
return err return SCEPResponse{}, err
} }
msg := &scep.PKIMessage{ msg := &scep.PKIMessage{
@ -260,7 +288,7 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque
} }
if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil { if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil {
return err return SCEPResponse{}, err
} }
if msg.MessageType == microscep.PKCSReq { if msg.MessageType == microscep.PKCSReq {
@ -271,7 +299,7 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque
certRep, err := h.Auth.SignCSR(ctx, csr, msg) certRep, err := h.Auth.SignCSR(ctx, csr, msg)
if err != nil { if err != nil {
return err return SCEPResponse{}, err
} }
// //cert := certRep.CertRepMessage.Certificate // //cert := certRep.CertRepMessage.Certificate
@ -280,11 +308,10 @@ func (h *Handler) PKIOperation(w http.ResponseWriter, r *http.Request, scepReque
// // TODO: check if CN already exists, if renewal is allowed and if existing should be revoked; fail if not // // TODO: check if CN already exists, if renewal is allowed and if existing should be revoked; fail if not
// // TODO: store the new cert for CN locally; should go into the DB // // TODO: store the new cert for CN locally; should go into the DB
scepResponse.Data = certRep.Raw response.Data = certRep.Raw
response.Certificate = certRep.Certificate
api.LogCertificate(w, certRep.Certificate) return response, nil
return writeSCEPResponse(w, scepResponse)
} }
func certName(cert *x509.Certificate) string { func certName(cert *x509.Certificate) string {
@ -299,40 +326,26 @@ func formatCapabilities(caps []string) []byte {
} }
// writeSCEPResponse writes a SCEP response back to the SCEP client. // writeSCEPResponse writes a SCEP response back to the SCEP client.
func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) error { func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) {
if response.Err != nil { w.Header().Set("Content-Type", contentHeader(response))
http.Error(w, response.Err.Error(), http.StatusInternalServerError) _, err := w.Write(response.Data)
return nil if err != nil {
writeError(w, fmt.Errorf("error when writing scep response: %w", err)) // This could end up as an error again
} }
w.Header().Set("Content-Type", contentHeader(response.Operation, response.CACertNum))
w.Write(response.Data)
return nil
} }
var ( func writeError(w http.ResponseWriter, err error) {
// TODO: check the default capabilities; https://tools.ietf.org/html/rfc8894#section-3.5.2 scepError := &scep.Error{
// TODO: move capabilities to Authority or Provisioner, so that they can be configured? Err: fmt.Errorf("post request failed: %w", err),
defaultCapabilities = []string{ Status: http.StatusInternalServerError, // TODO: make this a param?
"Renewal",
"SHA-1",
"SHA-256",
"AES",
"DES3",
"SCEPStandard",
"POSTPKIOperation",
} }
) api.WriteError(w, scepError)
}
const ( func contentHeader(r SCEPResponse) string {
certChainHeader = "application/x-x509-ca-ra-cert" switch r.Operation {
leafHeader = "application/x-x509-ca-cert"
pkiOpHeader = "application/x-pki-message"
)
func contentHeader(operation string, certNum int) string {
switch operation {
case opnGetCACert: case opnGetCACert:
if certNum > 1 { if r.CACertNum > 1 {
return certChainHeader return certChainHeader
} }
return leafHeader return leafHeader

View file

@ -3,14 +3,21 @@ package scep
import ( import (
"context" "context"
"errors" "errors"
)
"github.com/smallstep/certificates/acme" // ContextKey is the key type for storing and searching for SCEP request
// essentials in the context of a request.
type ContextKey string
const (
// ProvisionerContextKey provisioner key
ProvisionerContextKey = ContextKey("provisioner")
) )
// ProvisionerFromContext searches the context for a SCEP provisioner. // ProvisionerFromContext searches the context for a SCEP provisioner.
// Returns the provisioner or an error. // Returns the provisioner or an error.
func ProvisionerFromContext(ctx context.Context) (Provisioner, error) { func ProvisionerFromContext(ctx context.Context) (Provisioner, error) {
val := ctx.Value(acme.ProvisionerContextKey) val := ctx.Value(ProvisionerContextKey)
if val == nil { if val == nil {
return nil, errors.New("provisioner expected in request context") return nil, errors.New("provisioner expected in request context")
} }

19
scep/errors.go Normal file
View file

@ -0,0 +1,19 @@
package scep
// Error is an SCEP error type
type Error struct {
// Type ProbType
// Detail string
Err error
Status int
// Sub []*Error
// Identifier *Identifier
}
// Error implements the error interface.
func (e *Error) Error() string {
// if e.Err == nil {
// return e.Detail
// }
return e.Err.Error()
}