From 9408d0f24b3f08b037b691cac16dc906b724632c Mon Sep 17 00:00:00 2001
From: Mariano Cano <mariano.cano@gmail.com>
Date: Tue, 2 Aug 2022 19:28:49 -0700
Subject: [PATCH] Send RA provisioner information to the CA

---
 authority/authority.go               |  5 ++++-
 authority/provisioner/jwk.go         | 14 ++++++++++++--
 authority/provisioner/provisioner.go | 19 +++++++++++++++++++
 authority/provisioner/x5c.go         | 11 ++++++++++-
 authority/tls.go                     | 19 ++++++++++++++-----
 cas/apiv1/options.go                 |  4 ++++
 cas/apiv1/requests.go                | 18 +++++++++++++-----
 cas/stepcas/issuer.go                |  8 +++++++-
 cas/stepcas/jwk_issuer.go            | 15 +++++++++++----
 cas/stepcas/stepcas.go               | 19 +++++++++++++++----
 cas/stepcas/x5c_issuer.go            | 15 +++++++++++----
 11 files changed, 120 insertions(+), 27 deletions(-)

diff --git a/authority/authority.go b/authority/authority.go
index 933ceb14..3c74c037 100644
--- a/authority/authority.go
+++ b/authority/authority.go
@@ -312,6 +312,7 @@ func (a *Authority) init() error {
 		if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
 			return errors.New("error initializing linkedca: token authority and configured authority do not match")
 		}
+		a.config.AuthorityConfig.AuthorityID = linkedcaClient.authorityID
 		linkedcaClient.Run()
 	}
 
@@ -322,6 +323,9 @@ func (a *Authority) init() error {
 			options = *a.config.AuthorityConfig.Options
 		}
 
+		// AuthorityID might be empty. It's always available linked CAs/RAs.
+		options.AuthorityID = a.config.AuthorityConfig.AuthorityID
+
 		// Configure linked RA
 		if linkedcaClient != nil && options.CertificateAuthority == "" {
 			conf, err := linkedcaClient.GetConfiguration(ctx)
@@ -357,7 +361,6 @@ func (a *Authority) init() error {
 				return err
 			}
 		}
-
 		a.x509CAService, err = cas.New(ctx, options)
 		if err != nil {
 			return err
diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go
index de592941..966fa155 100644
--- a/authority/provisioner/jwk.go
+++ b/authority/provisioner/jwk.go
@@ -23,7 +23,8 @@ type jwtPayload struct {
 }
 
 type stepPayload struct {
-	SSH *SignSSHOptions `json:"ssh,omitempty"`
+	SSH    *SignSSHOptions `json:"ssh,omitempty"`
+	RAInfo *RAInfo         `json:"ra,omitempty"`
 }
 
 // JWK is the default provisioner, an entity that can sign tokens necessary for
@@ -172,8 +173,17 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
 		return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
 	}
 
+	// Wrap provisioner if the token is an RA token.
+	var self Interface = p
+	if claims.Step != nil && claims.Step.RAInfo != nil {
+		self = &raProvisioner{
+			Interface: p,
+			raInfo:    claims.Step.RAInfo,
+		}
+	}
+
 	return []SignOption{
-		p,
+		self,
 		templateOptions,
 		// modifiers / withOptions
 		newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go
index 0d5cd41a..ba3153a3 100644
--- a/authority/provisioner/provisioner.go
+++ b/authority/provisioner/provisioner.go
@@ -340,6 +340,25 @@ type Permissions struct {
 	CriticalOptions map[string]string `json:"criticalOptions"`
 }
 
+// RAInfo is the information about a provisioner present in RA tokens generated
+// by StepCAS.
+type RAInfo struct {
+	AuthorityID     string `json:"authorityId"`
+	ProvisionerID   string `json:"provisionerId"`
+	ProvisionerType string `json:"provisionerType"`
+}
+
+// raProvisioner wraps a provisioner with RA data.
+type raProvisioner struct {
+	Interface
+	raInfo *RAInfo
+}
+
+// RAInfo returns the RAInfo in the wrapped provisioner.
+func (p *raProvisioner) RAInfo() *RAInfo {
+	return p.raInfo
+}
+
 // MockProvisioner for testing
 type MockProvisioner struct {
 	Mret1, Mret2, Mret3 interface{}
diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go
index b9ae24c5..183d40ea 100644
--- a/authority/provisioner/x5c.go
+++ b/authority/provisioner/x5c.go
@@ -221,8 +221,17 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
 		return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
 	}
 
+	// Wrap provisioner if the token is an RA token.
+	var self Interface = p
+	if claims.Step != nil && claims.Step.RAInfo != nil {
+		self = &raProvisioner{
+			Interface: p,
+			raInfo:    claims.Step.RAInfo,
+		}
+	}
+
 	return []SignOption{
-		p,
+		self,
 		templateOptions,
 		// modifiers / withOptions
 		newProvisionerExtensionOption(TypeX5C, p.Name, ""),
diff --git a/authority/tls.go b/authority/tls.go
index 4c29ca15..0eaace82 100644
--- a/authority/tls.go
+++ b/authority/tls.go
@@ -72,6 +72,10 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
 	}
 }
 
+type raProvisioner interface {
+	RAInfo() *provisioner.RAInfo
+}
+
 // Sign creates a signed certificate from a certificate signing request.
 func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
 	var (
@@ -93,12 +97,16 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
 	signOpts.Backdate = a.config.AuthorityConfig.Backdate.Duration
 
 	var prov provisioner.Interface
+	var pInfo *casapi.ProvisionerInfo
 	for _, op := range extraOpts {
 		switch k := op.(type) {
 		// Capture current provisioner
 		case provisioner.Interface:
 			prov = k
-
+			pInfo = &casapi.ProvisionerInfo{
+				ProvisionerID:   prov.GetID(),
+				ProvisionerType: prov.GetType().String(),
+			}
 		// Adds new options to NewCertificate
 		case provisioner.CertificateOptions:
 			certOptions = append(certOptions, k.Options(signOpts)...)
@@ -221,10 +229,11 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
 	// Sign certificate
 	lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate))
 	resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
-		Template: leaf,
-		CSR:      csr,
-		Lifetime: lifetime,
-		Backdate: signOpts.Backdate,
+		Template:    leaf,
+		CSR:         csr,
+		Lifetime:    lifetime,
+		Backdate:    signOpts.Backdate,
+		Provisioner: pInfo,
 	})
 	if err != nil {
 		return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Sign; error creating certificate", opts...)
diff --git a/cas/apiv1/options.go b/cas/apiv1/options.go
index f69f933b..01c38efd 100644
--- a/cas/apiv1/options.go
+++ b/cas/apiv1/options.go
@@ -12,6 +12,10 @@ import (
 // Options represents the configuration options used to select and configure the
 // CertificateAuthorityService (CAS) to use.
 type Options struct {
+	// AuthorityID is the the id oc the current authority. This is used on
+	// StepCAS to add information about the origin of a certificate.
+	AuthorityID string `json:"-"`
+
 	// The type of the CAS to use.
 	Type string `json:"type"`
 
diff --git a/cas/apiv1/requests.go b/cas/apiv1/requests.go
index bf745c17..2d1b0784 100644
--- a/cas/apiv1/requests.go
+++ b/cas/apiv1/requests.go
@@ -52,11 +52,19 @@ const (
 
 // CreateCertificateRequest is the request used to sign a new certificate.
 type CreateCertificateRequest struct {
-	Template  *x509.Certificate
-	CSR       *x509.CertificateRequest
-	Lifetime  time.Duration
-	Backdate  time.Duration
-	RequestID string
+	Template    *x509.Certificate
+	CSR         *x509.CertificateRequest
+	Lifetime    time.Duration
+	Backdate    time.Duration
+	RequestID   string
+	Provisioner *ProvisionerInfo
+}
+
+// ProvisionerInfo contains information of the provisioner used to authorize an
+// certificate.
+type ProvisionerInfo struct {
+	ProvisionerID   string
+	ProvisionerType string
 }
 
 // CreateCertificateResponse is the response to a create certificate request.
diff --git a/cas/stepcas/issuer.go b/cas/stepcas/issuer.go
index be395e33..394489bc 100644
--- a/cas/stepcas/issuer.go
+++ b/cas/stepcas/issuer.go
@@ -10,8 +10,14 @@ import (
 	"github.com/smallstep/certificates/cas/apiv1"
 )
 
+type raInfo struct {
+	AuthorityID     string `json:"authorityId,omitempty"`
+	ProvisionerID   string `json:"provisionerId"`
+	ProvisionerType string `json:"provisionerType"`
+}
+
 type stepIssuer interface {
-	SignToken(subject string, sans []string) (string, error)
+	SignToken(subject string, sans []string, info *raInfo) (string, error)
 	RevokeToken(subject string) (string, error)
 	Lifetime(d time.Duration) time.Duration
 }
diff --git a/cas/stepcas/jwk_issuer.go b/cas/stepcas/jwk_issuer.go
index db45ef48..4ef4f541 100644
--- a/cas/stepcas/jwk_issuer.go
+++ b/cas/stepcas/jwk_issuer.go
@@ -53,25 +53,25 @@ func newJWKIssuer(caURL *url.URL, client *ca.Client, cfg *apiv1.CertificateIssue
 	}, nil
 }
 
-func (i *jwkIssuer) SignToken(subject string, sans []string) (string, error) {
+func (i *jwkIssuer) SignToken(subject string, sans []string, info *raInfo) (string, error) {
 	aud := i.caURL.ResolveReference(&url.URL{
 		Path: "/1.0/sign",
 	}).String()
-	return i.createToken(aud, subject, sans)
+	return i.createToken(aud, subject, sans, info)
 }
 
 func (i *jwkIssuer) RevokeToken(subject string) (string, error) {
 	aud := i.caURL.ResolveReference(&url.URL{
 		Path: "/1.0/revoke",
 	}).String()
-	return i.createToken(aud, subject, nil)
+	return i.createToken(aud, subject, nil, nil)
 }
 
 func (i *jwkIssuer) Lifetime(d time.Duration) time.Duration {
 	return d
 }
 
-func (i *jwkIssuer) createToken(aud, sub string, sans []string) (string, error) {
+func (i *jwkIssuer) createToken(aud, sub string, sans []string, info *raInfo) (string, error) {
 	id, err := randutil.Hex(64) // 256 bits
 	if err != nil {
 		return "", err
@@ -84,6 +84,13 @@ func (i *jwkIssuer) createToken(aud, sub string, sans []string) (string, error)
 			"sans": sans,
 		})
 	}
+	if info != nil {
+		builder = builder.Claims(map[string]interface{}{
+			"step": map[string]interface{}{
+				"ra": info,
+			},
+		})
+	}
 
 	tok, err := builder.CompactSerialize()
 	if err != nil {
diff --git a/cas/stepcas/stepcas.go b/cas/stepcas/stepcas.go
index 9fcbd36c..2ab48c7a 100644
--- a/cas/stepcas/stepcas.go
+++ b/cas/stepcas/stepcas.go
@@ -23,6 +23,7 @@ func init() {
 type StepCAS struct {
 	iss         stepIssuer
 	client      *ca.Client
+	authorityID string
 	fingerprint string
 }
 
@@ -59,6 +60,7 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) {
 	return &StepCAS{
 		iss:         iss,
 		client:      client,
+		authorityID: opts.AuthorityID,
 		fingerprint: opts.CertificateAuthorityFingerprint,
 	}, nil
 }
@@ -73,7 +75,16 @@ func (s *StepCAS) CreateCertificate(req *apiv1.CreateCertificateRequest) (*apiv1
 		return nil, errors.New("createCertificateRequest `lifetime` cannot be 0")
 	}
 
-	cert, chain, err := s.createCertificate(req.CSR, req.Lifetime)
+	var info *raInfo
+	if p := req.Provisioner; p != nil {
+		info = &raInfo{
+			AuthorityID:     s.authorityID,
+			ProvisionerID:   p.ProvisionerID,
+			ProvisionerType: p.ProvisionerType,
+		}
+	}
+
+	cert, chain, err := s.createCertificate(req.CSR, req.Lifetime, info)
 	if err != nil {
 		return nil, err
 	}
@@ -135,7 +146,7 @@ func (s *StepCAS) GetCertificateAuthority(req *apiv1.GetCertificateAuthorityRequ
 	}, nil
 }
 
-func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration) (*x509.Certificate, []*x509.Certificate, error) {
+func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.Duration, raInfo *raInfo) (*x509.Certificate, []*x509.Certificate, error) {
 	sans := make([]string, 0, len(cr.DNSNames)+len(cr.EmailAddresses)+len(cr.IPAddresses)+len(cr.URIs))
 	sans = append(sans, cr.DNSNames...)
 	sans = append(sans, cr.EmailAddresses...)
@@ -151,11 +162,11 @@ func (s *StepCAS) createCertificate(cr *x509.CertificateRequest, lifetime time.D
 		commonName = sans[0]
 	}
 
-	token, err := s.iss.SignToken(commonName, sans)
+	token, err := s.iss.SignToken(commonName, sans, raInfo)
 	if err != nil {
 		return nil, nil, err
 	}
-
+	println(token)
 	resp, err := s.client.Sign(&api.SignRequest{
 		CsrPEM:   api.CertificateRequest{CertificateRequest: cr},
 		OTT:      token,
diff --git a/cas/stepcas/x5c_issuer.go b/cas/stepcas/x5c_issuer.go
index 76ed9c3c..a005e501 100644
--- a/cas/stepcas/x5c_issuer.go
+++ b/cas/stepcas/x5c_issuer.go
@@ -46,13 +46,13 @@ func newX5CIssuer(caURL *url.URL, cfg *apiv1.CertificateIssuer) (*x5cIssuer, err
 	}, nil
 }
 
-func (i *x5cIssuer) SignToken(subject string, sans []string) (string, error) {
+func (i *x5cIssuer) SignToken(subject string, sans []string, info *raInfo) (string, error) {
 	aud := i.caURL.ResolveReference(&url.URL{
 		Path:     "/1.0/sign",
 		Fragment: "x5c/" + i.issuer,
 	}).String()
 
-	return i.createToken(aud, subject, sans)
+	return i.createToken(aud, subject, sans, info)
 }
 
 func (i *x5cIssuer) RevokeToken(subject string) (string, error) {
@@ -61,7 +61,7 @@ func (i *x5cIssuer) RevokeToken(subject string) (string, error) {
 		Fragment: "x5c/" + i.issuer,
 	}).String()
 
-	return i.createToken(aud, subject, nil)
+	return i.createToken(aud, subject, nil, nil)
 }
 
 func (i *x5cIssuer) Lifetime(d time.Duration) time.Duration {
@@ -76,7 +76,7 @@ func (i *x5cIssuer) Lifetime(d time.Duration) time.Duration {
 	return d
 }
 
-func (i *x5cIssuer) createToken(aud, sub string, sans []string) (string, error) {
+func (i *x5cIssuer) createToken(aud, sub string, sans []string, info *raInfo) (string, error) {
 	signer, err := newX5CSigner(i.certFile, i.keyFile, i.password)
 	if err != nil {
 		return "", err
@@ -94,6 +94,13 @@ func (i *x5cIssuer) createToken(aud, sub string, sans []string) (string, error)
 			"sans": sans,
 		})
 	}
+	if info != nil {
+		builder = builder.Claims(map[string]interface{}{
+			"step": map[string]interface{}{
+				"ra": info,
+			},
+		})
+	}
 
 	tok, err := builder.CompactSerialize()
 	if err != nil {