Add context parameter to all SSH methods.

This commit is contained in:
Mariano Cano 2020-03-10 19:01:45 -07:00
parent 164e4ef2d0
commit c49a9d5e33
9 changed files with 43 additions and 46 deletions

View file

@ -19,16 +19,16 @@ import (
// SSHAuthority is the interface implemented by a SSH CA authority. // SSHAuthority is the interface implemented by a SSH CA authority.
type SSHAuthority interface { type SSHAuthority interface {
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
GetSSHRoots() (*authority.SSHKeys, error) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error)
GetSSHFederation() (*authority.SSHKeys, error) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error)
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
GetSSHBastion(user string, hostname string) (*authority.Bastion, error) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
} }
// SSHSignRequest is the request body of an SSH certificate request. // SSHSignRequest is the request body of an SSH certificate request.
@ -282,14 +282,14 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ValidAfter: body.ValidAfter, ValidAfter: body.ValidAfter,
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) WriteError(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err))
return return
@ -297,7 +297,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var addUserCertificate *SSHCertificate var addUserCertificate *SSHCertificate
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err))
return return
@ -316,7 +316,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)), NotAfter: provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)),
} }
} }
ctx := authority.NewContextWithSkipTokenReuse(context.Background()) ctx := authority.NewContextWithSkipTokenReuse(r.Context())
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
@ -341,7 +341,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host // SSHRoots is an HTTP handler that returns the SSH public keys for user and host
// certificates. // certificates.
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots() keys, err := h.Authority.GetSSHRoots(r.Context())
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
return return
@ -366,7 +366,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
// SSHFederation is an HTTP handler that returns the federated SSH public keys // SSHFederation is an HTTP handler that returns the federated SSH public keys
// for user and host certificates. // for user and host certificates.
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation() keys, err := h.Authority.GetSSHFederation(r.Context())
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
return return
@ -401,7 +401,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data) ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
return return
@ -450,7 +450,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
cert = r.TLS.PeerCertificates[0] cert = r.TLS.PeerCertificates[0]
} }
hosts, err := h.Authority.GetSSHHosts(cert) hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
return return
@ -472,7 +472,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return return
} }
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
if err != nil { if err != nil {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
return return

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -56,7 +55,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
return return
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) WriteError(w, errs.UnauthorizedErr(err))
@ -67,7 +66,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
} }
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...) newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err))
return return

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -46,7 +45,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
return return
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT) _, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
WriteError(w, errs.UnauthorizedErr(err)) WriteError(w, errs.UnauthorizedErr(err))
@ -57,7 +56,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
WriteError(w, errs.InternalServerErr(err)) WriteError(w, errs.InternalServerErr(err))
} }
newCert, err := h.Authority.RenewSSH(oldCert) newCert, err := h.Authority.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
WriteError(w, errs.ForbiddenErr(err)) WriteError(w, errs.ForbiddenErr(err))
return return

View file

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
@ -65,7 +64,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
PassiveOnly: body.Passive, PassiveOnly: body.Passive,
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
// A token indicates that we are using the api via a provisioner token, // A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)

View file

@ -51,9 +51,9 @@ type Authority struct {
startTime time.Time startTime time.Time
// Custom functions // Custom functions
sshBastionFunc func(user, hostname string) (*Bastion, error) sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error) sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
getIdentityFunc provisioner.GetIdentityFunc getIdentityFunc provisioner.GetIdentityFunc
} }
@ -227,7 +227,7 @@ func (a *Authority) init() error {
// TODO: should we also be combining the ssh federated roots here? // TODO: should we also be combining the ssh federated roots here?
// If we rotate ssh roots keys, sshpop provisioner will lose ability to // If we rotate ssh roots keys, sshpop provisioner will lose ability to
// validate old SSH certificates, unless they are added as federated certs. // validate old SSH certificates, unless they are added as federated certs.
sshKeys, err := a.GetSSHRoots() sshKeys, err := a.GetSSHRoots(context.Background())
if err != nil { if err != nil {
return err return err
} }

View file

@ -28,7 +28,7 @@ func WithDatabase(db db.AuthDB) Option {
// WithGetIdentityFunc sets a custom function to retrieve the identity from // WithGetIdentityFunc sets a custom function to retrieve the identity from
// an external resource. // an external resource.
func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.getIdentityFunc = fn a.getIdentityFunc = fn
return nil return nil
@ -37,7 +37,7 @@ func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provis
// WithSSHBastionFunc sets a custom function to get the bastion for a // WithSSHBastionFunc sets a custom function to get the bastion for a
// given user-host pair. // given user-host pair.
func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastion, error)) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.sshBastionFunc = fn a.sshBastionFunc = fn
return nil return nil
@ -46,7 +46,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
// WithSSHGetHosts sets a custom function to get the bastion for a // WithSSHGetHosts sets a custom function to get the bastion for a
// given user-host pair. // given user-host pair.
func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]sshutil.Host, error)) Option { func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.sshGetHostsFunc = fn a.sshGetHostsFunc = fn
return nil return nil

View file

@ -339,7 +339,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Get the identity using either the default identityFunc or one injected // Get the identity using either the default identityFunc or one injected
// externally. // externally.
iden, err := o.getIdentityFunc(o, claims.Email) iden, err := o.getIdentityFunc(ctx, o, claims.Email)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
} }

View file

@ -330,10 +330,10 @@ type Identity struct {
} }
// GetIdentityFunc is a function that returns an identity. // GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(p Interface, email string) (*Identity, error) type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// DefaultIdentityFunc return a default identity depending on the provisioner type. // DefaultIdentityFunc return a default identity depending on the provisioner type.
func DefaultIdentityFunc(p Interface, email string) (*Identity, error) { func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) { switch k := p.(type) {
case *OIDC: case *OIDC:
name := SanitizeSSHUserPrincipal(email) name := SanitizeSSHUserPrincipal(email)

View file

@ -104,7 +104,7 @@ type SSHKeys struct {
} }
// GetSSHRoots returns the SSH User and Host public keys. // GetSSHRoots returns the SSH User and Host public keys.
func (a *Authority) GetSSHRoots() (*SSHKeys, error) { func (a *Authority) GetSSHRoots(context.Context) (*SSHKeys, error) {
return &SSHKeys{ return &SSHKeys{
HostKeys: a.sshCAHostCerts, HostKeys: a.sshCAHostCerts,
UserKeys: a.sshCAUserCerts, UserKeys: a.sshCAUserCerts,
@ -112,7 +112,7 @@ func (a *Authority) GetSSHRoots() (*SSHKeys, error) {
} }
// GetSSHFederation returns the public keys for federated SSH signers. // GetSSHFederation returns the public keys for federated SSH signers.
func (a *Authority) GetSSHFederation() (*SSHKeys, error) { func (a *Authority) GetSSHFederation(context.Context) (*SSHKeys, error) {
return &SSHKeys{ return &SSHKeys{
HostKeys: a.sshCAHostFederatedCerts, HostKeys: a.sshCAHostFederatedCerts,
UserKeys: a.sshCAUserFederatedCerts, UserKeys: a.sshCAUserFederatedCerts,
@ -120,7 +120,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
} }
// GetSSHConfig returns rendered templates for clients (user) or servers (host). // GetSSHConfig returns rendered templates for clients (user) or servers (host).
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
return nil, errs.NotFound("getSSHConfig: ssh is not configured") return nil, errs.NotFound("getSSHConfig: ssh is not configured")
} }
@ -166,9 +166,9 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
// GetSSHBastion returns the bastion configuration, for the given pair user, // GetSSHBastion returns the bastion configuration, for the given pair user,
// hostname. // hostname.
func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) { func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*Bastion, error) {
if a.sshBastionFunc != nil { if a.sshBastionFunc != nil {
bs, err := a.sshBastionFunc(user, hostname) bs, err := a.sshBastionFunc(ctx, user, hostname)
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
} }
if a.config.SSH != nil { if a.config.SSH != nil {
@ -181,7 +181,7 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
} }
// SignSSH creates a signed SSH certificate with the given public key and options. // SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var mods []provisioner.SSHCertModifier var mods []provisioner.SSHCertModifier
var validators []provisioner.SSHCertValidator var validators []provisioner.SSHCertValidator
@ -282,7 +282,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
} }
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
nonce, err := randutil.ASCII(32) nonce, err := randutil.ASCII(32)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH") return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
@ -353,7 +353,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
} }
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var validators []provisioner.SSHCertValidator var validators []provisioner.SSHCertValidator
for _, op := range signOpts { for _, op := range signOpts {
@ -443,7 +443,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
} }
// SignSSHAddUser signs a certificate that provisions a new user in a server. // SignSSHAddUser signs a certificate that provisions a new user in a server.
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
} }
@ -527,9 +527,9 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
} }
// GetSSHHosts returns a list of valid host principals. // GetSSHHosts returns a list of valid host principals.
func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
if a.sshGetHostsFunc != nil { if a.sshGetHostsFunc != nil {
hosts, err := a.sshGetHostsFunc(cert) hosts, err := a.sshGetHostsFunc(ctx, cert)
return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
} }
hostnames, err := a.db.GetSSHHostPrincipals() hostnames, err := a.db.GetSSHHostPrincipals()