Simplify SCEP provisioner context handling

This commit is contained in:
Herman Slatman 2023-06-01 16:22:00 +02:00
parent 8fc3a46387
commit b2bf2c330b
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
4 changed files with 21 additions and 42 deletions

View file

@ -858,8 +858,8 @@ func (a *Authority) IsRevoked(sn string) (bool, error) {
return a.db.IsRevoked(sn)
}
// requiresSCEPService iterates over the configured provisioners
// and determines if one of them is a SCEP provisioner.
// requiresSCEP iterates over the configured provisioners
// and determines if at least one of them is a SCEP provisioner.
func (a *Authority) requiresSCEP() bool {
for _, p := range a.config.AuthorityConfig.Provisioners {
if p.GetType() == provisioner.TypeSCEP {

View file

@ -221,7 +221,7 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
return
}
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
ctx = scep.NewProvisionerContext(ctx, scep.Provisioner(prov))
next(w, r.WithContext(ctx))
}
}

View file

@ -136,10 +136,7 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e
// Using an RA does not seem to exist in https://tools.ietf.org/html/rfc8894, but is mentioned in
// https://tools.ietf.org/id/draft-nourse-scep-21.html.
func (a *Authority) GetCACertificates(ctx context.Context) (certs []*x509.Certificate, err error) {
p, err := provisionerFromContext(ctx)
if err != nil {
return
}
p := provisionerFromContext(ctx)
// if a provisioner specific RSA decrypter is available, it is returned as
// the first certificate.
@ -214,10 +211,7 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err
}
func (a *Authority) selectDecrypter(ctx context.Context) (cert *x509.Certificate, pkey crypto.PrivateKey, err error) {
p, err := provisionerFromContext(ctx)
if err != nil {
return nil, nil, err
}
p := provisionerFromContext(ctx)
// return provisioner specific decrypter, if available
if cert, pkey = p.GetDecrypter(); cert != nil && pkey != nil {
@ -239,10 +233,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m
// poll for the status. It seems to be similar as what can happen in ACME, so might want to model
// the implementation after the one in the ACME authority. Requires storage, etc.
p, err := provisionerFromContext(ctx)
if err != nil {
return nil, err
}
p := provisionerFromContext(ctx)
// check if CSRReqMessage has already been decrypted
if msg.CSRReqMessage.CSR == nil {
@ -463,10 +454,7 @@ func (a *Authority) CreateFailureResponse(_ context.Context, _ *x509.Certificate
// GetCACaps returns the CA capabilities
func (a *Authority) GetCACaps(ctx context.Context) []string {
p, err := provisionerFromContext(ctx)
if err != nil {
return defaultCapabilities
}
p := provisionerFromContext(ctx)
caps := p.GetCapabilities()
if len(caps) == 0 {
@ -483,9 +471,6 @@ func (a *Authority) GetCACaps(ctx context.Context) []string {
}
func (a *Authority) ValidateChallenge(ctx context.Context, challenge, transactionID string) error {
p, err := provisionerFromContext(ctx)
if err != nil {
return err
}
p := provisionerFromContext(ctx)
return p.ValidateChallenge(ctx, challenge, transactionID)
}

View file

@ -4,7 +4,6 @@ import (
"context"
"crypto"
"crypto/x509"
"errors"
"time"
"github.com/smallstep/certificates/authority/provisioner"
@ -24,25 +23,20 @@ type Provisioner interface {
ValidateChallenge(ctx context.Context, challenge, transactionID string) error
}
// ContextKey is the key type for storing and searching for SCEP request
// essentials in the context of a request.
type ContextKey string
const (
// ProvisionerContextKey provisioner key
ProvisionerContextKey = ContextKey("provisioner")
)
// provisionerKey is the key type for storing and searching a
// SCEP provisioner in the context.
type provisionerKey struct{}
// provisionerFromContext searches the context for a SCEP provisioner.
// Returns the provisioner or an error.
func provisionerFromContext(ctx context.Context) (Provisioner, error) {
val := ctx.Value(ProvisionerContextKey)
if val == nil {
return nil, errors.New("provisioner expected in request context")
// Returns the provisioner or panics if no SCEP provisioner is found.
func provisionerFromContext(ctx context.Context) Provisioner {
p, ok := ctx.Value(provisionerKey{}).(Provisioner)
if !ok {
panic("SCEP provisioner expected in request context")
}
p, ok := val.(Provisioner)
if !ok || p == nil {
return nil, errors.New("provisioner in context is not a SCEP provisioner")
return p
}
return p, nil
func NewProvisionerContext(ctx context.Context, p Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, p)
}