scep: minor cleanup (#867)

* api, scep: removed scep.Error

* scep/api: replaced nextHTTP with http.HandlerFunc

* scep/api: renamed writeSCEPResponse to writeResponse

* scep/api: renamed decodeSCEPRequest to decodeRequest

* scep/api: renamed writeError to fail

* scep/api: replaced pkg/errors with errors

* scep/api: formatted imports

* scep/api: do not export SCEPRequest & SCEPResponse

* scep/api: do not export Handler

* api: flush errors better
This commit is contained in:
Panagiotis Siatras 2022-03-24 14:58:50 +02:00 committed by GitHub
parent 082734474b
commit b98f86a515
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 108 deletions

View file

@ -13,7 +13,6 @@ import (
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/scep"
) )
// WriteError writes to w a JSON representation of the given error. // WriteError writes to w a JSON representation of the given error.
@ -25,22 +24,9 @@ func WriteError(w http.ResponseWriter, err error) {
case *admin.Error: case *admin.Error:
admin.WriteError(w, k) admin.WriteError(w, k)
return return
case *scep.Error:
w.Header().Set("Content-Type", "text/plain")
default:
w.Header().Set("Content-Type", "application/json")
} }
cause := errors.Cause(err) cause := errors.Cause(err)
if sc, ok := err.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
if sc, ok := cause.(errs.StatusCoder); ok {
w.WriteHeader(sc.StatusCode())
} else {
w.WriteHeader(http.StatusInternalServerError)
}
}
// Write errors in the response writer // Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok { if rl, ok := w.(logging.ResponseLogger); ok {
@ -60,6 +46,16 @@ func WriteError(w http.ResponseWriter, err error) {
} }
} }
code := http.StatusInternalServerError
if sc, ok := err.(errs.StatusCoder); ok {
code = sc.StatusCode()
} else if sc, ok := cause.(errs.StatusCoder); ok {
code = sc.StatusCode()
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(err); err != nil { if err := json.NewEncoder(w).Encode(err); err != nil {
log.Error(w, err) log.Error(w, err)
} }

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"errors"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -11,7 +13,6 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
microscep "github.com/micromdm/scep/v2/scep" microscep "github.com/micromdm/scep/v2/scep"
"github.com/pkg/errors"
"go.mozilla.org/pkcs7" "go.mozilla.org/pkcs7"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
@ -30,22 +31,20 @@ const (
const maxPayloadSize = 2 << 20 const maxPayloadSize = 2 << 20
type nextHTTP = func(http.ResponseWriter, *http.Request)
const ( const (
certChainHeader = "application/x-x509-ca-ra-cert" certChainHeader = "application/x-x509-ca-ra-cert"
leafHeader = "application/x-x509-ca-cert" leafHeader = "application/x-x509-ca-cert"
pkiOperationHeader = "application/x-pki-message" pkiOperationHeader = "application/x-pki-message"
) )
// SCEPRequest is a SCEP server request. // request is a SCEP server request.
type SCEPRequest struct { type request struct {
Operation string Operation string
Message []byte Message []byte
} }
// SCEPResponse is a SCEP server response. // response is a SCEP server response.
type SCEPResponse struct { type response struct {
Operation string Operation string
CACertNum int CACertNum int
Data []byte Data []byte
@ -53,18 +52,18 @@ type SCEPResponse struct {
Error error Error error
} }
// Handler is the SCEP request handler. // handler is the SCEP request handler.
type Handler struct { type handler struct {
Auth scep.Interface Auth scep.Interface
} }
// New returns a new SCEP API router. // New returns a new SCEP API router.
func New(scepAuth scep.Interface) api.RouterHandler { func New(scepAuth scep.Interface) api.RouterHandler {
return &Handler{scepAuth} return &handler{scepAuth}
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *handler) Route(r api.Router) {
getLink := h.Auth.GetLinkExplicit getLink := h.Auth.GetLinkExplicit
r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get))
r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get))
@ -73,64 +72,64 @@ func (h *Handler) Route(r api.Router) {
} }
// Get handles all SCEP GET requests // Get handles all SCEP GET requests
func (h *Handler) Get(w http.ResponseWriter, r *http.Request) { func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
request, err := decodeSCEPRequest(r) req, err := decodeRequest(r)
if err != nil { if err != nil {
writeError(w, errors.Wrap(err, "invalid scep get request")) fail(w, fmt.Errorf("invalid scep get request: %w", err))
return return
} }
ctx := r.Context() ctx := r.Context()
var response SCEPResponse var res response
switch request.Operation { switch req.Operation {
case opnGetCACert: case opnGetCACert:
response, err = h.GetCACert(ctx) res, err = h.GetCACert(ctx)
case opnGetCACaps: case opnGetCACaps:
response, err = h.GetCACaps(ctx) res, err = h.GetCACaps(ctx)
case opnPKIOperation: case opnPKIOperation:
// TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though // TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though
default: default:
err = errors.Errorf("unknown operation: %s", request.Operation) err = fmt.Errorf("unknown operation: %s", req.Operation)
} }
if err != nil { if err != nil {
writeError(w, errors.Wrap(err, "scep get request failed")) fail(w, fmt.Errorf("scep get request failed: %w", err))
return return
} }
writeSCEPResponse(w, response) writeResponse(w, res)
} }
// Post handles all SCEP POST requests // Post handles all SCEP POST requests
func (h *Handler) Post(w http.ResponseWriter, r *http.Request) { func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
request, err := decodeSCEPRequest(r) req, err := decodeRequest(r)
if err != nil { if err != nil {
writeError(w, errors.Wrap(err, "invalid scep post request")) fail(w, fmt.Errorf("invalid scep post request: %w", err))
return return
} }
ctx := r.Context() ctx := r.Context()
var response SCEPResponse var res response
switch request.Operation { switch req.Operation {
case opnPKIOperation: case opnPKIOperation:
response, err = h.PKIOperation(ctx, request) res, err = h.PKIOperation(ctx, req)
default: default:
err = errors.Errorf("unknown operation: %s", request.Operation) err = fmt.Errorf("unknown operation: %s", req.Operation)
} }
if err != nil { if err != nil {
writeError(w, errors.Wrap(err, "scep post request failed")) fail(w, fmt.Errorf("scep post request failed: %w", err))
return return
} }
writeSCEPResponse(w, response) writeResponse(w, res)
} }
func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) { func decodeRequest(r *http.Request) (request, error) {
defer r.Body.Close() defer r.Body.Close()
@ -146,7 +145,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
case http.MethodGet: case http.MethodGet:
switch operation { switch operation {
case opnGetCACert, opnGetCACaps: case opnGetCACert, opnGetCACaps:
return SCEPRequest{ return request{
Operation: operation, Operation: operation,
Message: []byte{}, Message: []byte{},
}, nil }, nil
@ -158,50 +157,50 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
// TODO: verify this; it seems like it should be StdEncoding instead of URLEncoding // TODO: verify this; it seems like it should be StdEncoding instead of URLEncoding
decodedMessage, err := base64.URLEncoding.DecodeString(message) decodedMessage, err := base64.URLEncoding.DecodeString(message)
if err != nil { if err != nil {
return SCEPRequest{}, err return request{}, err
} }
return SCEPRequest{ return request{
Operation: operation, Operation: operation,
Message: decodedMessage, Message: decodedMessage,
}, nil }, nil
default: default:
return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation) return request{}, fmt.Errorf("unsupported operation: %s", operation)
} }
case http.MethodPost: case http.MethodPost:
body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize))
if err != nil { if err != nil {
return SCEPRequest{}, err return request{}, err
} }
return SCEPRequest{ return request{
Operation: operation, Operation: operation,
Message: body, Message: body,
}, nil }, nil
default: default:
return SCEPRequest{}, errors.Errorf("unsupported method: %s", method) return request{}, fmt.Errorf("unsupported method: %s", method)
} }
} }
// lookupProvisioner loads the provisioner associated with the request. // lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist. // Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP { func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "provisionerName") name := chi.URLParam(r, "provisionerName")
provisionerName, err := url.PathUnescape(name) provisionerName, err := url.PathUnescape(name)
if err != nil { if err != nil {
api.WriteError(w, errors.Errorf("error url unescaping provisioner name '%s'", name)) fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name))
return return
} }
p, err := h.Auth.LoadProvisionerByName(provisionerName) p, err := h.Auth.LoadProvisionerByName(provisionerName)
if err != nil { if err != nil {
api.WriteError(w, err) fail(w, err)
return return
} }
prov, ok := p.(*provisioner.SCEP) prov, ok := p.(*provisioner.SCEP)
if !ok { if !ok {
api.WriteError(w, errors.New("provisioner must be of type SCEP")) fail(w, errors.New("provisioner must be of type SCEP"))
return return
} }
@ -212,59 +211,59 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
} }
// GetCACert returns the CA certificates in a SCEP response // GetCACert returns the CA certificates in a SCEP response
func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) { func (h *handler) GetCACert(ctx context.Context) (response, error) {
certs, err := h.Auth.GetCACertificates(ctx) certs, err := h.Auth.GetCACertificates(ctx)
if err != nil { if err != nil {
return SCEPResponse{}, err return response{}, err
} }
if len(certs) == 0 { if len(certs) == 0 {
return SCEPResponse{}, errors.New("missing CA cert") return response{}, errors.New("missing CA cert")
} }
response := SCEPResponse{ res := response{
Operation: opnGetCACert, Operation: opnGetCACert,
CACertNum: len(certs), CACertNum: len(certs),
} }
if len(certs) == 1 { if len(certs) == 1 {
response.Data = certs[0].Raw res.Data = certs[0].Raw
} else { } else {
// create degenerate pkcs7 certificate structure, according to // create degenerate pkcs7 certificate structure, according to
// https://tools.ietf.org/html/rfc8894#section-4.2.1.2, because // https://tools.ietf.org/html/rfc8894#section-4.2.1.2, because
// not signed or encrypted data has to be returned. // not signed or encrypted data has to be returned.
data, err := microscep.DegenerateCertificates(certs) data, err := microscep.DegenerateCertificates(certs)
if err != nil { if err != nil {
return SCEPResponse{}, err return response{}, err
} }
response.Data = data res.Data = data
} }
return response, nil return res, nil
} }
// GetCACaps returns the CA capabilities in a SCEP response // GetCACaps returns the CA capabilities in a SCEP response
func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) { func (h *handler) GetCACaps(ctx context.Context) (response, error) {
caps := h.Auth.GetCACaps(ctx) caps := h.Auth.GetCACaps(ctx)
response := SCEPResponse{ res := response{
Operation: opnGetCACaps, Operation: opnGetCACaps,
Data: formatCapabilities(caps), Data: formatCapabilities(caps),
} }
return response, nil return res, nil
} }
// PKIOperation performs PKI operations and returns a SCEP response // PKIOperation performs PKI operations and returns a SCEP response
func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) { func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) {
// parse the message using microscep implementation // parse the message using microscep implementation
microMsg, err := microscep.ParsePKIMessage(request.Message) microMsg, err := microscep.ParsePKIMessage(req.Message)
if err != nil { if err != nil {
// return the error, because we can't use the msg for creating a CertRep // return the error, because we can't use the msg for creating a CertRep
return SCEPResponse{}, err return response{}, err
} }
// this is essentially doing the same as microscep.ParsePKIMessage, but // this is essentially doing the same as microscep.ParsePKIMessage, but
@ -272,7 +271,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe
// wrapper for the microscep implementation. // wrapper for the microscep implementation.
p7, err := pkcs7.Parse(microMsg.Raw) p7, err := pkcs7.Parse(microMsg.Raw)
if err != nil { if err != nil {
return SCEPResponse{}, err return response{}, err
} }
// copy over properties to our internal PKIMessage // copy over properties to our internal PKIMessage
@ -285,7 +284,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe
} }
if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil { if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil {
return SCEPResponse{}, err return response{}, err
} }
// NOTE: at this point we have sufficient information for returning nicely signed CertReps // NOTE: at this point we have sufficient information for returning nicely signed CertReps
@ -317,61 +316,56 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe
certRep, err := h.Auth.SignCSR(ctx, csr, msg) certRep, err := h.Auth.SignCSR(ctx, csr, msg)
if err != nil { if err != nil {
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.Wrap(err, "error when signing new certificate")) return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
} }
response := SCEPResponse{ res := response{
Operation: opnPKIOperation, Operation: opnPKIOperation,
Data: certRep.Raw, Data: certRep.Raw,
Certificate: certRep.Certificate, Certificate: certRep.Certificate,
} }
return response, nil return res, nil
} }
func formatCapabilities(caps []string) []byte { func formatCapabilities(caps []string) []byte {
return []byte(strings.Join(caps, "\r\n")) return []byte(strings.Join(caps, "\r\n"))
} }
// writeSCEPResponse writes a SCEP response back to the SCEP client. // writeResponse writes a SCEP response back to the SCEP client.
func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) { func writeResponse(w http.ResponseWriter, res response) {
if response.Error != nil { if res.Error != nil {
log.Error(w, response.Error) log.Error(w, res.Error)
} }
if response.Certificate != nil { if res.Certificate != nil {
api.LogCertificate(w, response.Certificate) api.LogCertificate(w, res.Certificate)
} }
w.Header().Set("Content-Type", contentHeader(response)) w.Header().Set("Content-Type", contentHeader(res))
_, err := w.Write(response.Data) _, _ = w.Write(res.Data)
if err != nil {
writeError(w, errors.Wrap(err, "error when writing scep response")) // This could end up as an error again
}
} }
func writeError(w http.ResponseWriter, err error) { func fail(w http.ResponseWriter, err error) {
scepError := &scep.Error{ log.Error(w, err)
Message: err.Error(),
Status: http.StatusInternalServerError, // TODO: make this a param? http.Error(w, err.Error(), http.StatusInternalServerError)
}
api.WriteError(w, scepError)
} }
func (h *Handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (SCEPResponse, error) { func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) {
certRepMsg, err := h.Auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error()) certRepMsg, err := h.Auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
if err != nil { if err != nil {
return SCEPResponse{}, err return response{}, err
} }
return SCEPResponse{ return response{
Operation: opnPKIOperation, Operation: opnPKIOperation,
Data: certRepMsg.Raw, Data: certRepMsg.Raw,
Error: failError, Error: failError,
}, nil }, nil
} }
func contentHeader(r SCEPResponse) string { func contentHeader(r response) string {
switch r.Operation { switch r.Operation {
case opnGetCACert: case opnGetCACert:
if r.CACertNum > 1 { if r.CACertNum > 1 {

View file

@ -1,12 +0,0 @@
package scep
// Error is an SCEP error type
type Error struct {
Message string `json:"message"`
Status int `json:"-"`
}
// Error implements the error interface.
func (e *Error) Error() string {
return e.Message
}