Add context parameter to all SSH methods.

This commit is contained in:
Mariano Cano 2020-03-10 19:01:45 -07:00
parent 164e4ef2d0
commit c49a9d5e33
9 changed files with 43 additions and 46 deletions

View file

@ -19,16 +19,16 @@ import (
// SSHAuthority is the interface implemented by a SSH CA authority.
type SSHAuthority interface {
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error)
RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
GetSSHRoots() (*authority.SSHKeys, error)
GetSSHFederation() (*authority.SSHKeys, error)
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error)
GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error)
GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error)
GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
}
// SSHSignRequest is the request body of an SSH certificate request.
@ -282,14 +282,14 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
ValidAfter: body.ValidAfter,
}
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
WriteError(w, errs.UnauthorizedErr(err))
return
}
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil {
WriteError(w, errs.ForbiddenErr(err))
return
@ -297,7 +297,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var addUserCertificate *SSHCertificate
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil {
WriteError(w, errs.ForbiddenErr(err))
return
@ -316,7 +316,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)),
}
}
ctx := authority.NewContextWithSkipTokenReuse(context.Background())
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
if err != nil {
@ -341,7 +341,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
// certificates.
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots()
keys, err := h.Authority.GetSSHRoots(r.Context())
if err != nil {
WriteError(w, errs.InternalServerErr(err))
return
@ -366,7 +366,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
// SSHFederation is an HTTP handler that returns the federated SSH public keys
// for user and host certificates.
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation()
keys, err := h.Authority.GetSSHFederation(r.Context())
if err != nil {
WriteError(w, errs.InternalServerErr(err))
return
@ -401,7 +401,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return
}
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
return
@ -450,7 +450,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
cert = r.TLS.PeerCertificates[0]
}
hosts, err := h.Authority.GetSSHHosts(cert)
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
return
@ -472,7 +472,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return
}
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
if err != nil {
WriteError(w, errs.InternalServerErr(err))
return