Simplify SCEP provisioner context handling

This commit is contained in:
Herman Slatman 2023-06-01 16:22:00 +02:00
parent 8fc3a46387
commit b2bf2c330b
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
4 changed files with 21 additions and 42 deletions

View file

@ -858,8 +858,8 @@ func (a *Authority) IsRevoked(sn string) (bool, error) {
return a.db.IsRevoked(sn) return a.db.IsRevoked(sn)
} }
// requiresSCEPService iterates over the configured provisioners // requiresSCEP iterates over the configured provisioners
// and determines if one of them is a SCEP provisioner. // and determines if at least one of them is a SCEP provisioner.
func (a *Authority) requiresSCEP() bool { func (a *Authority) requiresSCEP() bool {
for _, p := range a.config.AuthorityConfig.Provisioners { for _, p := range a.config.AuthorityConfig.Provisioners {
if p.GetType() == provisioner.TypeSCEP { if p.GetType() == provisioner.TypeSCEP {

View file

@ -221,7 +221,7 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc {
return return
} }
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) ctx = scep.NewProvisionerContext(ctx, scep.Provisioner(prov))
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }

View file

@ -136,10 +136,7 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e
// Using an RA does not seem to exist in https://tools.ietf.org/html/rfc8894, but is mentioned in // Using an RA does not seem to exist in https://tools.ietf.org/html/rfc8894, but is mentioned in
// https://tools.ietf.org/id/draft-nourse-scep-21.html. // https://tools.ietf.org/id/draft-nourse-scep-21.html.
func (a *Authority) GetCACertificates(ctx context.Context) (certs []*x509.Certificate, err error) { func (a *Authority) GetCACertificates(ctx context.Context) (certs []*x509.Certificate, err error) {
p, err := provisionerFromContext(ctx) p := provisionerFromContext(ctx)
if err != nil {
return
}
// if a provisioner specific RSA decrypter is available, it is returned as // if a provisioner specific RSA decrypter is available, it is returned as
// the first certificate. // the first certificate.
@ -214,10 +211,7 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err
} }
func (a *Authority) selectDecrypter(ctx context.Context) (cert *x509.Certificate, pkey crypto.PrivateKey, err error) { func (a *Authority) selectDecrypter(ctx context.Context) (cert *x509.Certificate, pkey crypto.PrivateKey, err error) {
p, err := provisionerFromContext(ctx) p := provisionerFromContext(ctx)
if err != nil {
return nil, nil, err
}
// return provisioner specific decrypter, if available // return provisioner specific decrypter, if available
if cert, pkey = p.GetDecrypter(); cert != nil && pkey != nil { if cert, pkey = p.GetDecrypter(); cert != nil && pkey != nil {
@ -239,10 +233,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m
// poll for the status. It seems to be similar as what can happen in ACME, so might want to model // poll for the status. It seems to be similar as what can happen in ACME, so might want to model
// the implementation after the one in the ACME authority. Requires storage, etc. // the implementation after the one in the ACME authority. Requires storage, etc.
p, err := provisionerFromContext(ctx) p := provisionerFromContext(ctx)
if err != nil {
return nil, err
}
// check if CSRReqMessage has already been decrypted // check if CSRReqMessage has already been decrypted
if msg.CSRReqMessage.CSR == nil { if msg.CSRReqMessage.CSR == nil {
@ -463,10 +454,7 @@ func (a *Authority) CreateFailureResponse(_ context.Context, _ *x509.Certificate
// GetCACaps returns the CA capabilities // GetCACaps returns the CA capabilities
func (a *Authority) GetCACaps(ctx context.Context) []string { func (a *Authority) GetCACaps(ctx context.Context) []string {
p, err := provisionerFromContext(ctx) p := provisionerFromContext(ctx)
if err != nil {
return defaultCapabilities
}
caps := p.GetCapabilities() caps := p.GetCapabilities()
if len(caps) == 0 { if len(caps) == 0 {
@ -483,9 +471,6 @@ func (a *Authority) GetCACaps(ctx context.Context) []string {
} }
func (a *Authority) ValidateChallenge(ctx context.Context, challenge, transactionID string) error { func (a *Authority) ValidateChallenge(ctx context.Context, challenge, transactionID string) error {
p, err := provisionerFromContext(ctx) p := provisionerFromContext(ctx)
if err != nil {
return err
}
return p.ValidateChallenge(ctx, challenge, transactionID) return p.ValidateChallenge(ctx, challenge, transactionID)
} }

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"errors"
"time" "time"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -24,25 +23,20 @@ type Provisioner interface {
ValidateChallenge(ctx context.Context, challenge, transactionID string) error ValidateChallenge(ctx context.Context, challenge, transactionID string) error
} }
// ContextKey is the key type for storing and searching for SCEP request // provisionerKey is the key type for storing and searching a
// essentials in the context of a request. // SCEP provisioner in the context.
type ContextKey string type provisionerKey struct{}
const (
// ProvisionerContextKey provisioner key
ProvisionerContextKey = ContextKey("provisioner")
)
// provisionerFromContext searches the context for a SCEP provisioner. // provisionerFromContext searches the context for a SCEP provisioner.
// Returns the provisioner or an error. // Returns the provisioner or panics if no SCEP provisioner is found.
func provisionerFromContext(ctx context.Context) (Provisioner, error) { func provisionerFromContext(ctx context.Context) Provisioner {
val := ctx.Value(ProvisionerContextKey) p, ok := ctx.Value(provisionerKey{}).(Provisioner)
if val == nil { if !ok {
return nil, errors.New("provisioner expected in request context") panic("SCEP provisioner expected in request context")
} }
p, ok := val.(Provisioner) return p
if !ok || p == nil { }
return nil, errors.New("provisioner in context is not a SCEP provisioner")
} func NewProvisionerContext(ctx context.Context, p Provisioner) context.Context {
return p, nil return context.WithValue(ctx, provisionerKey{}, p)
} }