From c49a9d5e33bf2092335fc6689a41efca9612a466 Mon Sep 17 00:00:00 2001
From: Mariano Cano <mariano@smallstep.com>
Date: Tue, 10 Mar 2020 19:01:45 -0700
Subject: [PATCH] Add context parameter to all SSH methods.

---
 api/ssh.go                           | 36 ++++++++++++++--------------
 api/sshRekey.go                      |  5 ++--
 api/sshRenew.go                      |  5 ++--
 api/sshRevoke.go                     |  3 +--
 authority/authority.go               |  6 ++---
 authority/options.go                 |  6 ++---
 authority/provisioner/oidc.go        |  2 +-
 authority/provisioner/provisioner.go |  4 ++--
 authority/ssh.go                     | 22 ++++++++---------
 9 files changed, 43 insertions(+), 46 deletions(-)

diff --git a/api/ssh.go b/api/ssh.go
index f0b090d1..fc598502 100644
--- a/api/ssh.go
+++ b/api/ssh.go
@@ -19,16 +19,16 @@ import (
 
 // SSHAuthority is the interface implemented by a SSH CA authority.
 type SSHAuthority interface {
-	SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
-	RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error)
-	RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
-	SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
-	GetSSHRoots() (*authority.SSHKeys, error)
-	GetSSHFederation() (*authority.SSHKeys, error)
-	GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
+	SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
+	RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
+	RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
+	SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
+	GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error)
+	GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error)
+	GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
 	CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
-	GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error)
-	GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
+	GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
+	GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
 }
 
 // SSHSignRequest is the request body of an SSH certificate request.
@@ -282,14 +282,14 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
 		ValidAfter:  body.ValidAfter,
 	}
 
-	ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
+	ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
 	signOpts, err := h.Authority.Authorize(ctx, body.OTT)
 	if err != nil {
 		WriteError(w, errs.UnauthorizedErr(err))
 		return
 	}
 
-	cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
+	cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
 	if err != nil {
 		WriteError(w, errs.ForbiddenErr(err))
 		return
@@ -297,7 +297,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
 
 	var addUserCertificate *SSHCertificate
 	if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
-		addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
+		addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
 		if err != nil {
 			WriteError(w, errs.ForbiddenErr(err))
 			return
@@ -316,7 +316,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
 				NotAfter:  provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)),
 			}
 		}
-		ctx := authority.NewContextWithSkipTokenReuse(context.Background())
+		ctx := authority.NewContextWithSkipTokenReuse(r.Context())
 		ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
 		signOpts, err := h.Authority.Authorize(ctx, body.OTT)
 		if err != nil {
@@ -341,7 +341,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
 // SSHRoots is an HTTP handler that returns the SSH public keys for user and host
 // certificates.
 func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
-	keys, err := h.Authority.GetSSHRoots()
+	keys, err := h.Authority.GetSSHRoots(r.Context())
 	if err != nil {
 		WriteError(w, errs.InternalServerErr(err))
 		return
@@ -366,7 +366,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
 // SSHFederation is an HTTP handler that returns the federated SSH public keys
 // for user and host certificates.
 func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
-	keys, err := h.Authority.GetSSHFederation()
+	keys, err := h.Authority.GetSSHFederation(r.Context())
 	if err != nil {
 		WriteError(w, errs.InternalServerErr(err))
 		return
@@ -401,7 +401,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
+	ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
 	if err != nil {
 		WriteError(w, errs.InternalServerErr(err))
 		return
@@ -450,7 +450,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
 		cert = r.TLS.PeerCertificates[0]
 	}
 
-	hosts, err := h.Authority.GetSSHHosts(cert)
+	hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
 	if err != nil {
 		WriteError(w, errs.InternalServerErr(err))
 		return
@@ -472,7 +472,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
+	bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
 	if err != nil {
 		WriteError(w, errs.InternalServerErr(err))
 		return
diff --git a/api/sshRekey.go b/api/sshRekey.go
index a5cc1f06..285422f9 100644
--- a/api/sshRekey.go
+++ b/api/sshRekey.go
@@ -1,7 +1,6 @@
 package api
 
 import (
-	"context"
 	"net/http"
 
 	"github.com/pkg/errors"
@@ -56,7 +55,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
+	ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
 	signOpts, err := h.Authority.Authorize(ctx, body.OTT)
 	if err != nil {
 		WriteError(w, errs.UnauthorizedErr(err))
@@ -67,7 +66,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
 		WriteError(w, errs.InternalServerErr(err))
 	}
 
-	newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
+	newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
 	if err != nil {
 		WriteError(w, errs.ForbiddenErr(err))
 		return
diff --git a/api/sshRenew.go b/api/sshRenew.go
index 11a9d8e8..048c83a3 100644
--- a/api/sshRenew.go
+++ b/api/sshRenew.go
@@ -1,7 +1,6 @@
 package api
 
 import (
-	"context"
 	"net/http"
 
 	"github.com/pkg/errors"
@@ -46,7 +45,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
+	ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
 	_, err := h.Authority.Authorize(ctx, body.OTT)
 	if err != nil {
 		WriteError(w, errs.UnauthorizedErr(err))
@@ -57,7 +56,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
 		WriteError(w, errs.InternalServerErr(err))
 	}
 
-	newCert, err := h.Authority.RenewSSH(oldCert)
+	newCert, err := h.Authority.RenewSSH(ctx, oldCert)
 	if err != nil {
 		WriteError(w, errs.ForbiddenErr(err))
 		return
diff --git a/api/sshRevoke.go b/api/sshRevoke.go
index b8d1dadd..5a1c858c 100644
--- a/api/sshRevoke.go
+++ b/api/sshRevoke.go
@@ -1,7 +1,6 @@
 package api
 
 import (
-	"context"
 	"net/http"
 
 	"github.com/smallstep/certificates/authority"
@@ -65,7 +64,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
 		PassiveOnly: body.Passive,
 	}
 
-	ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod)
+	ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
 	// A token indicates that we are using the api via a provisioner token,
 	// otherwise it is assumed that the certificate is revoking itself over mTLS.
 	logOtt(w, body.OTT)
diff --git a/authority/authority.go b/authority/authority.go
index 0730fe5e..8cf4cfc1 100644
--- a/authority/authority.go
+++ b/authority/authority.go
@@ -51,9 +51,9 @@ type Authority struct {
 	startTime time.Time
 
 	// Custom functions
-	sshBastionFunc   func(user, hostname string) (*Bastion, error)
+	sshBastionFunc   func(ctx context.Context, user, hostname string) (*Bastion, error)
 	sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
-	sshGetHostsFunc  func(cert *x509.Certificate) ([]sshutil.Host, error)
+	sshGetHostsFunc  func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
 	getIdentityFunc  provisioner.GetIdentityFunc
 }
 
@@ -227,7 +227,7 @@ func (a *Authority) init() error {
 	// TODO: should we also be combining the ssh federated roots here?
 	// If we rotate ssh roots keys, sshpop provisioner will lose ability to
 	// validate old SSH certificates, unless they are added as federated certs.
-	sshKeys, err := a.GetSSHRoots()
+	sshKeys, err := a.GetSSHRoots(context.Background())
 	if err != nil {
 		return err
 	}
diff --git a/authority/options.go b/authority/options.go
index 2d655a2b..04cd7bef 100644
--- a/authority/options.go
+++ b/authority/options.go
@@ -28,7 +28,7 @@ func WithDatabase(db db.AuthDB) Option {
 
 // WithGetIdentityFunc sets a custom function to retrieve the identity from
 // an external resource.
-func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
+func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
 	return func(a *Authority) error {
 		a.getIdentityFunc = fn
 		return nil
@@ -37,7 +37,7 @@ func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provis
 
 // WithSSHBastionFunc sets a custom function to get the bastion for a
 // given user-host pair.
-func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
+func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastion, error)) Option {
 	return func(a *Authority) error {
 		a.sshBastionFunc = fn
 		return nil
@@ -46,7 +46,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
 
 // WithSSHGetHosts sets a custom function to get the bastion for a
 // given user-host pair.
-func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]sshutil.Host, error)) Option {
+func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)) Option {
 	return func(a *Authority) error {
 		a.sshGetHostsFunc = fn
 		return nil
diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go
index 0b5448af..f90d96b5 100644
--- a/authority/provisioner/oidc.go
+++ b/authority/provisioner/oidc.go
@@ -339,7 +339,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
 
 	// Get the identity using either the default identityFunc or one injected
 	// externally.
-	iden, err := o.getIdentityFunc(o, claims.Email)
+	iden, err := o.getIdentityFunc(ctx, o, claims.Email)
 	if err != nil {
 		return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
 	}
diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go
index 4cf1c5f9..885e7cf0 100644
--- a/authority/provisioner/provisioner.go
+++ b/authority/provisioner/provisioner.go
@@ -330,10 +330,10 @@ type Identity struct {
 }
 
 // GetIdentityFunc is a function that returns an identity.
-type GetIdentityFunc func(p Interface, email string) (*Identity, error)
+type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
 
 // DefaultIdentityFunc return a default identity depending on the provisioner type.
-func DefaultIdentityFunc(p Interface, email string) (*Identity, error) {
+func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
 	switch k := p.(type) {
 	case *OIDC:
 		name := SanitizeSSHUserPrincipal(email)
diff --git a/authority/ssh.go b/authority/ssh.go
index f47447d5..d28205e2 100644
--- a/authority/ssh.go
+++ b/authority/ssh.go
@@ -104,7 +104,7 @@ type SSHKeys struct {
 }
 
 // GetSSHRoots returns the SSH User and Host public keys.
-func (a *Authority) GetSSHRoots() (*SSHKeys, error) {
+func (a *Authority) GetSSHRoots(context.Context) (*SSHKeys, error) {
 	return &SSHKeys{
 		HostKeys: a.sshCAHostCerts,
 		UserKeys: a.sshCAUserCerts,
@@ -112,7 +112,7 @@ func (a *Authority) GetSSHRoots() (*SSHKeys, error) {
 }
 
 // GetSSHFederation returns the public keys for federated SSH signers.
-func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
+func (a *Authority) GetSSHFederation(context.Context) (*SSHKeys, error) {
 	return &SSHKeys{
 		HostKeys: a.sshCAHostFederatedCerts,
 		UserKeys: a.sshCAUserFederatedCerts,
@@ -120,7 +120,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
 }
 
 // GetSSHConfig returns rendered templates for clients (user) or servers (host).
-func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
+func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
 	if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
 		return nil, errs.NotFound("getSSHConfig: ssh is not configured")
 	}
@@ -166,9 +166,9 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
 
 // GetSSHBastion returns the bastion configuration, for the given pair user,
 // hostname.
-func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) {
+func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*Bastion, error) {
 	if a.sshBastionFunc != nil {
-		bs, err := a.sshBastionFunc(user, hostname)
+		bs, err := a.sshBastionFunc(ctx, user, hostname)
 		return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
 	}
 	if a.config.SSH != nil {
@@ -181,7 +181,7 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
 }
 
 // SignSSH creates a signed SSH certificate with the given public key and options.
-func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
+func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
 	var mods []provisioner.SSHCertModifier
 	var validators []provisioner.SSHCertValidator
 
@@ -282,7 +282,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
 }
 
 // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
-func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
+func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
 	nonce, err := randutil.ASCII(32)
 	if err != nil {
 		return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
@@ -353,7 +353,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
 }
 
 // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
-func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
+func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
 	var validators []provisioner.SSHCertValidator
 
 	for _, op := range signOpts {
@@ -443,7 +443,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
 }
 
 // SignSSHAddUser signs a certificate that provisions a new user in a server.
-func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
+func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
 	if a.sshCAUserCertSignKey == nil {
 		return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
 	}
@@ -527,9 +527,9 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
 }
 
 // GetSSHHosts returns a list of valid host principals.
-func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
+func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
 	if a.sshGetHostsFunc != nil {
-		hosts, err := a.sshGetHostsFunc(cert)
+		hosts, err := a.sshGetHostsFunc(ctx, cert)
 		return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
 	}
 	hostnames, err := a.db.GetSSHHostPrincipals()