scep: minor cleanup (#867)
* api, scep: removed scep.Error * scep/api: replaced nextHTTP with http.HandlerFunc * scep/api: renamed writeSCEPResponse to writeResponse * scep/api: renamed decodeSCEPRequest to decodeRequest * scep/api: renamed writeError to fail * scep/api: replaced pkg/errors with errors * scep/api: formatted imports * scep/api: do not export SCEPRequest & SCEPResponse * scep/api: do not export Handler * api: flush errors better
This commit is contained in:
parent
37207793f9
commit
bca74cb6a7
3 changed files with 86 additions and 108 deletions
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/scep"
|
||||
)
|
||||
|
||||
// WriteError writes to w a JSON representation of the given error.
|
||||
|
@ -25,22 +24,9 @@ func WriteError(w http.ResponseWriter, err error) {
|
|||
case *admin.Error:
|
||||
admin.WriteError(w, k)
|
||||
return
|
||||
case *scep.Error:
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
cause := errors.Cause(err)
|
||||
if sc, ok := err.(errs.StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
} else {
|
||||
if sc, ok := cause.(errs.StatusCoder); ok {
|
||||
w.WriteHeader(sc.StatusCode())
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// Write errors in the response writer
|
||||
if rl, ok := w.(logging.ResponseLogger); ok {
|
||||
|
@ -60,6 +46,16 @@ func WriteError(w http.ResponseWriter, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
code := http.StatusInternalServerError
|
||||
if sc, ok := err.(errs.StatusCoder); ok {
|
||||
code = sc.StatusCode()
|
||||
} else if sc, ok := cause.(errs.StatusCoder); ok {
|
||||
code = sc.StatusCode()
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
log.Error(w, err)
|
||||
}
|
||||
|
|
158
scep/api/api.go
158
scep/api/api.go
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -11,7 +13,6 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
microscep "github.com/micromdm/scep/v2/scep"
|
||||
"github.com/pkg/errors"
|
||||
"go.mozilla.org/pkcs7"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
|
@ -30,22 +31,20 @@ const (
|
|||
|
||||
const maxPayloadSize = 2 << 20
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
||||
const (
|
||||
certChainHeader = "application/x-x509-ca-ra-cert"
|
||||
leafHeader = "application/x-x509-ca-cert"
|
||||
pkiOperationHeader = "application/x-pki-message"
|
||||
)
|
||||
|
||||
// SCEPRequest is a SCEP server request.
|
||||
type SCEPRequest struct {
|
||||
// request is a SCEP server request.
|
||||
type request struct {
|
||||
Operation string
|
||||
Message []byte
|
||||
}
|
||||
|
||||
// SCEPResponse is a SCEP server response.
|
||||
type SCEPResponse struct {
|
||||
// response is a SCEP server response.
|
||||
type response struct {
|
||||
Operation string
|
||||
CACertNum int
|
||||
Data []byte
|
||||
|
@ -53,18 +52,18 @@ type SCEPResponse struct {
|
|||
Error error
|
||||
}
|
||||
|
||||
// Handler is the SCEP request handler.
|
||||
type Handler struct {
|
||||
// handler is the SCEP request handler.
|
||||
type handler struct {
|
||||
Auth scep.Interface
|
||||
}
|
||||
|
||||
// New returns a new SCEP API router.
|
||||
func New(scepAuth scep.Interface) api.RouterHandler {
|
||||
return &Handler{scepAuth}
|
||||
return &handler{scepAuth}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
func (h *handler) Route(r api.Router) {
|
||||
getLink := h.Auth.GetLinkExplicit
|
||||
r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get))
|
||||
r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get))
|
||||
|
@ -73,64 +72,64 @@ func (h *Handler) Route(r api.Router) {
|
|||
}
|
||||
|
||||
// Get handles all SCEP GET requests
|
||||
func (h *Handler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) Get(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
request, err := decodeSCEPRequest(r)
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
writeError(w, errors.Wrap(err, "invalid scep get request"))
|
||||
fail(w, fmt.Errorf("invalid scep get request: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var response SCEPResponse
|
||||
var res response
|
||||
|
||||
switch request.Operation {
|
||||
switch req.Operation {
|
||||
case opnGetCACert:
|
||||
response, err = h.GetCACert(ctx)
|
||||
res, err = h.GetCACert(ctx)
|
||||
case opnGetCACaps:
|
||||
response, err = h.GetCACaps(ctx)
|
||||
res, err = h.GetCACaps(ctx)
|
||||
case opnPKIOperation:
|
||||
// TODO: implement the GET for PKI operation? Default CACAPS doesn't specify this is in use, though
|
||||
default:
|
||||
err = errors.Errorf("unknown operation: %s", request.Operation)
|
||||
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
writeError(w, errors.Wrap(err, "scep get request failed"))
|
||||
fail(w, fmt.Errorf("scep get request failed: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
writeSCEPResponse(w, response)
|
||||
writeResponse(w, res)
|
||||
}
|
||||
|
||||
// Post handles all SCEP POST requests
|
||||
func (h *Handler) Post(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) Post(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
request, err := decodeSCEPRequest(r)
|
||||
req, err := decodeRequest(r)
|
||||
if err != nil {
|
||||
writeError(w, errors.Wrap(err, "invalid scep post request"))
|
||||
fail(w, fmt.Errorf("invalid scep post request: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var response SCEPResponse
|
||||
var res response
|
||||
|
||||
switch request.Operation {
|
||||
switch req.Operation {
|
||||
case opnPKIOperation:
|
||||
response, err = h.PKIOperation(ctx, request)
|
||||
res, err = h.PKIOperation(ctx, req)
|
||||
default:
|
||||
err = errors.Errorf("unknown operation: %s", request.Operation)
|
||||
err = fmt.Errorf("unknown operation: %s", req.Operation)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
writeError(w, errors.Wrap(err, "scep post request failed"))
|
||||
fail(w, fmt.Errorf("scep post request failed: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
writeSCEPResponse(w, response)
|
||||
writeResponse(w, res)
|
||||
}
|
||||
|
||||
func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
|
||||
func decodeRequest(r *http.Request) (request, error) {
|
||||
|
||||
defer r.Body.Close()
|
||||
|
||||
|
@ -146,7 +145,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
|
|||
case http.MethodGet:
|
||||
switch operation {
|
||||
case opnGetCACert, opnGetCACaps:
|
||||
return SCEPRequest{
|
||||
return request{
|
||||
Operation: operation,
|
||||
Message: []byte{},
|
||||
}, nil
|
||||
|
@ -158,50 +157,50 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
|
|||
// TODO: verify this; it seems like it should be StdEncoding instead of URLEncoding
|
||||
decodedMessage, err := base64.URLEncoding.DecodeString(message)
|
||||
if err != nil {
|
||||
return SCEPRequest{}, err
|
||||
return request{}, err
|
||||
}
|
||||
return SCEPRequest{
|
||||
return request{
|
||||
Operation: operation,
|
||||
Message: decodedMessage,
|
||||
}, nil
|
||||
default:
|
||||
return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation)
|
||||
return request{}, fmt.Errorf("unsupported operation: %s", operation)
|
||||
}
|
||||
case http.MethodPost:
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize))
|
||||
if err != nil {
|
||||
return SCEPRequest{}, err
|
||||
return request{}, err
|
||||
}
|
||||
return SCEPRequest{
|
||||
return request{
|
||||
Operation: operation,
|
||||
Message: body,
|
||||
}, nil
|
||||
default:
|
||||
return SCEPRequest{}, errors.Errorf("unsupported method: %s", method)
|
||||
return request{}, fmt.Errorf("unsupported method: %s", method)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupProvisioner loads the provisioner associated with the request.
|
||||
// Responds 404 if the provisioner does not exist.
|
||||
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||
func (h *handler) lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
name := chi.URLParam(r, "provisionerName")
|
||||
provisionerName, err := url.PathUnescape(name)
|
||||
if err != nil {
|
||||
api.WriteError(w, errors.Errorf("error url unescaping provisioner name '%s'", name))
|
||||
fail(w, fmt.Errorf("error url unescaping provisioner name '%s'", name))
|
||||
return
|
||||
}
|
||||
|
||||
p, err := h.Auth.LoadProvisionerByName(provisionerName)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
fail(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov, ok := p.(*provisioner.SCEP)
|
||||
if !ok {
|
||||
api.WriteError(w, errors.New("provisioner must be of type SCEP"))
|
||||
fail(w, errors.New("provisioner must be of type SCEP"))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -212,59 +211,59 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
|||
}
|
||||
|
||||
// GetCACert returns the CA certificates in a SCEP response
|
||||
func (h *Handler) GetCACert(ctx context.Context) (SCEPResponse, error) {
|
||||
func (h *handler) GetCACert(ctx context.Context) (response, error) {
|
||||
|
||||
certs, err := h.Auth.GetCACertificates(ctx)
|
||||
if err != nil {
|
||||
return SCEPResponse{}, err
|
||||
return response{}, err
|
||||
}
|
||||
|
||||
if len(certs) == 0 {
|
||||
return SCEPResponse{}, errors.New("missing CA cert")
|
||||
return response{}, errors.New("missing CA cert")
|
||||
}
|
||||
|
||||
response := SCEPResponse{
|
||||
res := response{
|
||||
Operation: opnGetCACert,
|
||||
CACertNum: len(certs),
|
||||
}
|
||||
|
||||
if len(certs) == 1 {
|
||||
response.Data = certs[0].Raw
|
||||
res.Data = certs[0].Raw
|
||||
} else {
|
||||
// create degenerate pkcs7 certificate structure, according to
|
||||
// https://tools.ietf.org/html/rfc8894#section-4.2.1.2, because
|
||||
// not signed or encrypted data has to be returned.
|
||||
data, err := microscep.DegenerateCertificates(certs)
|
||||
if err != nil {
|
||||
return SCEPResponse{}, err
|
||||
return response{}, err
|
||||
}
|
||||
response.Data = data
|
||||
res.Data = data
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetCACaps returns the CA capabilities in a SCEP response
|
||||
func (h *Handler) GetCACaps(ctx context.Context) (SCEPResponse, error) {
|
||||
func (h *handler) GetCACaps(ctx context.Context) (response, error) {
|
||||
|
||||
caps := h.Auth.GetCACaps(ctx)
|
||||
|
||||
response := SCEPResponse{
|
||||
res := response{
|
||||
Operation: opnGetCACaps,
|
||||
Data: formatCapabilities(caps),
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// PKIOperation performs PKI operations and returns a SCEP response
|
||||
func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPResponse, error) {
|
||||
func (h *handler) PKIOperation(ctx context.Context, req request) (response, error) {
|
||||
|
||||
// parse the message using microscep implementation
|
||||
microMsg, err := microscep.ParsePKIMessage(request.Message)
|
||||
microMsg, err := microscep.ParsePKIMessage(req.Message)
|
||||
if err != nil {
|
||||
// return the error, because we can't use the msg for creating a CertRep
|
||||
return SCEPResponse{}, err
|
||||
return response{}, err
|
||||
}
|
||||
|
||||
// this is essentially doing the same as microscep.ParsePKIMessage, but
|
||||
|
@ -272,7 +271,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe
|
|||
// wrapper for the microscep implementation.
|
||||
p7, err := pkcs7.Parse(microMsg.Raw)
|
||||
if err != nil {
|
||||
return SCEPResponse{}, err
|
||||
return response{}, err
|
||||
}
|
||||
|
||||
// copy over properties to our internal PKIMessage
|
||||
|
@ -285,7 +284,7 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe
|
|||
}
|
||||
|
||||
if err := h.Auth.DecryptPKIEnvelope(ctx, msg); err != nil {
|
||||
return SCEPResponse{}, err
|
||||
return response{}, err
|
||||
}
|
||||
|
||||
// NOTE: at this point we have sufficient information for returning nicely signed CertReps
|
||||
|
@ -317,61 +316,56 @@ func (h *Handler) PKIOperation(ctx context.Context, request SCEPRequest) (SCEPRe
|
|||
|
||||
certRep, err := h.Auth.SignCSR(ctx, csr, msg)
|
||||
if err != nil {
|
||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, errors.Wrap(err, "error when signing new certificate"))
|
||||
return h.createFailureResponse(ctx, csr, msg, microscep.BadRequest, fmt.Errorf("error when signing new certificate: %w", err))
|
||||
}
|
||||
|
||||
response := SCEPResponse{
|
||||
res := response{
|
||||
Operation: opnPKIOperation,
|
||||
Data: certRep.Raw,
|
||||
Certificate: certRep.Certificate,
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func formatCapabilities(caps []string) []byte {
|
||||
return []byte(strings.Join(caps, "\r\n"))
|
||||
}
|
||||
|
||||
// writeSCEPResponse writes a SCEP response back to the SCEP client.
|
||||
func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) {
|
||||
// writeResponse writes a SCEP response back to the SCEP client.
|
||||
func writeResponse(w http.ResponseWriter, res response) {
|
||||
|
||||
if response.Error != nil {
|
||||
log.Error(w, response.Error)
|
||||
if res.Error != nil {
|
||||
log.Error(w, res.Error)
|
||||
}
|
||||
|
||||
if response.Certificate != nil {
|
||||
api.LogCertificate(w, response.Certificate)
|
||||
if res.Certificate != nil {
|
||||
api.LogCertificate(w, res.Certificate)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentHeader(response))
|
||||
_, err := w.Write(response.Data)
|
||||
if err != nil {
|
||||
writeError(w, errors.Wrap(err, "error when writing scep response")) // This could end up as an error again
|
||||
}
|
||||
w.Header().Set("Content-Type", contentHeader(res))
|
||||
_, _ = w.Write(res.Data)
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, err error) {
|
||||
scepError := &scep.Error{
|
||||
Message: err.Error(),
|
||||
Status: http.StatusInternalServerError, // TODO: make this a param?
|
||||
}
|
||||
api.WriteError(w, scepError)
|
||||
func fail(w http.ResponseWriter, err error) {
|
||||
log.Error(w, err)
|
||||
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func (h *Handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (SCEPResponse, error) {
|
||||
func (h *handler) createFailureResponse(ctx context.Context, csr *x509.CertificateRequest, msg *scep.PKIMessage, info microscep.FailInfo, failError error) (response, error) {
|
||||
certRepMsg, err := h.Auth.CreateFailureResponse(ctx, csr, msg, scep.FailInfoName(info), failError.Error())
|
||||
if err != nil {
|
||||
return SCEPResponse{}, err
|
||||
return response{}, err
|
||||
}
|
||||
return SCEPResponse{
|
||||
return response{
|
||||
Operation: opnPKIOperation,
|
||||
Data: certRepMsg.Raw,
|
||||
Error: failError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func contentHeader(r SCEPResponse) string {
|
||||
func contentHeader(r response) string {
|
||||
switch r.Operation {
|
||||
case opnGetCACert:
|
||||
if r.CACertNum > 1 {
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
package scep
|
||||
|
||||
// Error is an SCEP error type
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Status int `json:"-"`
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *Error) Error() string {
|
||||
return e.Message
|
||||
}
|
Loading…
Reference in a new issue