forked from TrueCloudLab/certificates
Merge pull request #207 from smallstep/add-context
Add context to ssh methods
This commit is contained in:
commit
ee1c8dd0cd
14 changed files with 104 additions and 107 deletions
|
@ -550,8 +550,6 @@ type mockAuthority struct {
|
||||||
getTLSOptions func() *tlsutil.TLSOptions
|
getTLSOptions func() *tlsutil.TLSOptions
|
||||||
root func(shasum string) (*x509.Certificate, error)
|
root func(shasum string) (*x509.Certificate, error)
|
||||||
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||||
signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
|
||||||
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
|
||||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||||
loadProvisionerByID func(provID string) (provisioner.Interface, error)
|
loadProvisionerByID func(provID string) (provisioner.Interface, error)
|
||||||
|
@ -560,14 +558,16 @@ type mockAuthority struct {
|
||||||
getEncryptedKey func(kid string) (string, error)
|
getEncryptedKey func(kid string) (string, error)
|
||||||
getRoots func() ([]*x509.Certificate, error)
|
getRoots func() ([]*x509.Certificate, error)
|
||||||
getFederation func() ([]*x509.Certificate, error)
|
getFederation func() ([]*x509.Certificate, error)
|
||||||
renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error)
|
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
getSSHHosts func(*x509.Certificate) ([]sshutil.Host, error)
|
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
getSSHRoots func() (*authority.SSHKeys, error)
|
rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
getSSHFederation func() (*authority.SSHKeys, error)
|
getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
|
||||||
getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error)
|
getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error)
|
||||||
|
getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error)
|
||||||
|
getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||||
checkSSHHost func(ctx context.Context, principal, token string) (bool, error)
|
checkSSHHost func(ctx context.Context, principal, token string) (bool, error)
|
||||||
getSSHBastion func(user string, hostname string) (*authority.Bastion, error)
|
getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
|
||||||
version func() authority.Version
|
version func() authority.Version
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -604,20 +604,6 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio
|
||||||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
|
||||||
if m.signSSH != nil {
|
|
||||||
return m.signSSH(key, opts, signOpts...)
|
|
||||||
}
|
|
||||||
return m.ret1.(*ssh.Certificate), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
|
||||||
if m.signSSHAddUser != nil {
|
|
||||||
return m.signSSHAddUser(key, cert)
|
|
||||||
}
|
|
||||||
return m.ret1.(*ssh.Certificate), m.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||||
if m.renew != nil {
|
if m.renew != nil {
|
||||||
return m.renew(cert)
|
return m.renew(cert)
|
||||||
|
@ -674,44 +660,58 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
|
||||||
return m.ret1.([]*x509.Certificate), m.err
|
return m.ret1.([]*x509.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) {
|
func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
|
if m.signSSH != nil {
|
||||||
|
return m.signSSH(ctx, key, opts, signOpts...)
|
||||||
|
}
|
||||||
|
return m.ret1.(*ssh.Certificate), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
|
if m.signSSHAddUser != nil {
|
||||||
|
return m.signSSHAddUser(ctx, key, cert)
|
||||||
|
}
|
||||||
|
return m.ret1.(*ssh.Certificate), m.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
if m.renewSSH != nil {
|
if m.renewSSH != nil {
|
||||||
return m.renewSSH(cert)
|
return m.renewSSH(ctx, cert)
|
||||||
}
|
}
|
||||||
return m.ret1.(*ssh.Certificate), m.err
|
return m.ret1.(*ssh.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
if m.rekeySSH != nil {
|
if m.rekeySSH != nil {
|
||||||
return m.rekeySSH(cert, key, signOpts...)
|
return m.rekeySSH(ctx, cert, key, signOpts...)
|
||||||
}
|
}
|
||||||
return m.ret1.(*ssh.Certificate), m.err
|
return m.ret1.(*ssh.Certificate), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
|
func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||||
if m.getSSHHosts != nil {
|
if m.getSSHHosts != nil {
|
||||||
return m.getSSHHosts(cert)
|
return m.getSSHHosts(ctx, cert)
|
||||||
}
|
}
|
||||||
return m.ret1.([]sshutil.Host), m.err
|
return m.ret1.([]sshutil.Host), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) {
|
func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
if m.getSSHRoots != nil {
|
if m.getSSHRoots != nil {
|
||||||
return m.getSSHRoots()
|
return m.getSSHRoots(ctx)
|
||||||
}
|
}
|
||||||
return m.ret1.(*authority.SSHKeys), m.err
|
return m.ret1.(*authority.SSHKeys), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHFederation() (*authority.SSHKeys, error) {
|
func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
if m.getSSHFederation != nil {
|
if m.getSSHFederation != nil {
|
||||||
return m.getSSHFederation()
|
return m.getSSHFederation(ctx)
|
||||||
}
|
}
|
||||||
return m.ret1.(*authority.SSHKeys), m.err
|
return m.ret1.(*authority.SSHKeys), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
|
func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||||
if m.getSSHConfig != nil {
|
if m.getSSHConfig != nil {
|
||||||
return m.getSSHConfig(typ, data)
|
return m.getSSHConfig(ctx, typ, data)
|
||||||
}
|
}
|
||||||
return m.ret1.([]templates.Output), m.err
|
return m.ret1.([]templates.Output), m.err
|
||||||
}
|
}
|
||||||
|
@ -723,9 +723,9 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
|
||||||
return m.ret1.(bool), m.err
|
return m.ret1.(bool), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) {
|
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) {
|
||||||
if m.getSSHBastion != nil {
|
if m.getSSHBastion != nil {
|
||||||
return m.getSSHBastion(user, hostname)
|
return m.getSSHBastion(ctx, user, hostname)
|
||||||
}
|
}
|
||||||
return m.ret1.(*authority.Bastion), m.err
|
return m.ret1.(*authority.Bastion), m.err
|
||||||
}
|
}
|
||||||
|
|
36
api/ssh.go
36
api/ssh.go
|
@ -19,16 +19,16 @@ import (
|
||||||
|
|
||||||
// SSHAuthority is the interface implemented by a SSH CA authority.
|
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||||
type SSHAuthority interface {
|
type SSHAuthority interface {
|
||||||
SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error)
|
RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||||
SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
GetSSHRoots() (*authority.SSHKeys, error)
|
GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error)
|
||||||
GetSSHFederation() (*authority.SSHKeys, error)
|
GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error)
|
||||||
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||||
CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
|
CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
|
||||||
GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error)
|
GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
|
||||||
GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
|
GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHSignRequest is the request body of an SSH certificate request.
|
// SSHSignRequest is the request body of an SSH certificate request.
|
||||||
|
@ -282,14 +282,14 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
ValidAfter: body.ValidAfter,
|
ValidAfter: body.ValidAfter,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
WriteError(w, errs.UnauthorizedErr(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...)
|
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err))
|
||||||
return
|
return
|
||||||
|
@ -297,7 +297,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
var addUserCertificate *SSHCertificate
|
var addUserCertificate *SSHCertificate
|
||||||
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 {
|
||||||
addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert)
|
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err))
|
||||||
return
|
return
|
||||||
|
@ -316,7 +316,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
NotAfter: provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)),
|
NotAfter: provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx := authority.NewContextWithSkipTokenReuse(context.Background())
|
ctx := authority.NewContextWithSkipTokenReuse(r.Context())
|
||||||
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -341,7 +341,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||||
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host
|
||||||
// certificates.
|
// certificates.
|
||||||
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHRoots()
|
keys, err := h.Authority.GetSSHRoots(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -366,7 +366,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
|
||||||
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
// SSHFederation is an HTTP handler that returns the federated SSH public keys
|
||||||
// for user and host certificates.
|
// for user and host certificates.
|
||||||
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
||||||
keys, err := h.Authority.GetSSHFederation()
|
keys, err := h.Authority.GetSSHFederation(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -401,7 +401,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ts, err := h.Authority.GetSSHConfig(body.Type, body.Data)
|
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -450,7 +450,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
cert = r.TLS.PeerCertificates[0]
|
cert = r.TLS.PeerCertificates[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
hosts, err := h.Authority.GetSSHHosts(cert)
|
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
@ -472,7 +472,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname)
|
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -56,7 +55,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
|
||||||
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
signOpts, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
WriteError(w, errs.UnauthorizedErr(err))
|
||||||
|
@ -67,7 +66,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...)
|
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -46,7 +45,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
|
||||||
_, err := h.Authority.Authorize(ctx, body.OTT)
|
_, err := h.Authority.Authorize(ctx, body.OTT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.UnauthorizedErr(err))
|
WriteError(w, errs.UnauthorizedErr(err))
|
||||||
|
@ -57,7 +56,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||||
WriteError(w, errs.InternalServerErr(err))
|
WriteError(w, errs.InternalServerErr(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
newCert, err := h.Authority.RenewSSH(oldCert)
|
newCert, err := h.Authority.RenewSSH(ctx, oldCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, errs.ForbiddenErr(err))
|
WriteError(w, errs.ForbiddenErr(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/authority"
|
"github.com/smallstep/certificates/authority"
|
||||||
|
@ -65,7 +64,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||||
PassiveOnly: body.Passive,
|
PassiveOnly: body.Passive,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod)
|
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
|
||||||
// A token indicates that we are using the api via a provisioner token,
|
// A token indicates that we are using the api via a provisioner token,
|
||||||
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
// otherwise it is assumed that the certificate is revoking itself over mTLS.
|
||||||
logOtt(w, body.OTT)
|
logOtt(w, body.OTT)
|
||||||
|
|
|
@ -319,10 +319,10 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||||
return []provisioner.SignOption{}, tt.authErr
|
return []provisioner.SignOption{}, tt.authErr
|
||||||
},
|
},
|
||||||
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
return tt.signCert, tt.signErr
|
return tt.signCert, tt.signErr
|
||||||
},
|
},
|
||||||
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
return tt.addUserCert, tt.addUserErr
|
return tt.addUserCert, tt.addUserErr
|
||||||
},
|
},
|
||||||
sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||||
|
@ -379,7 +379,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
getSSHRoots: func() (*authority.SSHKeys, error) {
|
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
return tt.keys, tt.keysErr
|
return tt.keys, tt.keysErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
@ -433,7 +433,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
getSSHFederation: func() (*authority.SSHKeys, error) {
|
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
return tt.keys, tt.keysErr
|
return tt.keys, tt.keysErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
@ -493,7 +493,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) {
|
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||||
return tt.output, tt.err
|
return tt.output, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
@ -591,7 +591,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) {
|
getSSHHosts: func(context.Context, *x509.Certificate) ([]sshutil.Host, error) {
|
||||||
return tt.hosts, tt.err
|
return tt.hosts, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
@ -646,7 +646,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
h := New(&mockAuthority{
|
||||||
getSSHBastion: func(user, hostname string) (*authority.Bastion, error) {
|
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||||
return tt.bastion, tt.bastionErr
|
return tt.bastion, tt.bastionErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
}).(*caHandler)
|
||||||
|
|
|
@ -51,9 +51,9 @@ type Authority struct {
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
|
||||||
// Custom functions
|
// Custom functions
|
||||||
sshBastionFunc func(user, hostname string) (*Bastion, error)
|
sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
|
||||||
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
|
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
|
||||||
sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error)
|
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
|
||||||
getIdentityFunc provisioner.GetIdentityFunc
|
getIdentityFunc provisioner.GetIdentityFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,7 +227,7 @@ func (a *Authority) init() error {
|
||||||
// TODO: should we also be combining the ssh federated roots here?
|
// TODO: should we also be combining the ssh federated roots here?
|
||||||
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
|
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
|
||||||
// validate old SSH certificates, unless they are added as federated certs.
|
// validate old SSH certificates, unless they are added as federated certs.
|
||||||
sshKeys, err := a.GetSSHRoots()
|
sshKeys, err := a.GetSSHRoots(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ func WithDatabase(db db.AuthDB) Option {
|
||||||
|
|
||||||
// WithGetIdentityFunc sets a custom function to retrieve the identity from
|
// WithGetIdentityFunc sets a custom function to retrieve the identity from
|
||||||
// an external resource.
|
// an external resource.
|
||||||
func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
|
func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.getIdentityFunc = fn
|
a.getIdentityFunc = fn
|
||||||
return nil
|
return nil
|
||||||
|
@ -37,7 +37,7 @@ func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provis
|
||||||
|
|
||||||
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
||||||
// given user-host pair.
|
// given user-host pair.
|
||||||
func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
|
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastion, error)) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.sshBastionFunc = fn
|
a.sshBastionFunc = fn
|
||||||
return nil
|
return nil
|
||||||
|
@ -46,7 +46,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
|
||||||
|
|
||||||
// WithSSHGetHosts sets a custom function to get the bastion for a
|
// WithSSHGetHosts sets a custom function to get the bastion for a
|
||||||
// given user-host pair.
|
// given user-host pair.
|
||||||
func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]sshutil.Host, error)) Option {
|
func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.sshGetHostsFunc = fn
|
a.sshGetHostsFunc = fn
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -339,7 +339,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
||||||
|
|
||||||
// Get the identity using either the default identityFunc or one injected
|
// Get the identity using either the default identityFunc or one injected
|
||||||
// externally.
|
// externally.
|
||||||
iden, err := o.getIdentityFunc(o, claims.Email)
|
iden, err := o.getIdentityFunc(ctx, o, claims.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
||||||
}
|
}
|
||||||
|
|
|
@ -485,10 +485,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
||||||
assert.FatalError(t, p4.Init(config))
|
assert.FatalError(t, p4.Init(config))
|
||||||
assert.FatalError(t, p5.Init(config))
|
assert.FatalError(t, p5.Init(config))
|
||||||
|
|
||||||
p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
|
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
return &Identity{Usernames: []string{"max", "mariano"}}, nil
|
return &Identity{Usernames: []string{"max", "mariano"}}, nil
|
||||||
}
|
}
|
||||||
p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
|
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -330,10 +330,10 @@ type Identity struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIdentityFunc is a function that returns an identity.
|
// GetIdentityFunc is a function that returns an identity.
|
||||||
type GetIdentityFunc func(p Interface, email string) (*Identity, error)
|
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
|
||||||
|
|
||||||
// DefaultIdentityFunc return a default identity depending on the provisioner type.
|
// DefaultIdentityFunc return a default identity depending on the provisioner type.
|
||||||
func DefaultIdentityFunc(p Interface, email string) (*Identity, error) {
|
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||||
switch k := p.(type) {
|
switch k := p.(type) {
|
||||||
case *OIDC:
|
case *OIDC:
|
||||||
name := SanitizeSSHUserPrincipal(email)
|
name := SanitizeSSHUserPrincipal(email)
|
||||||
|
|
|
@ -92,7 +92,7 @@ func TestDefaultIdentityFunc(t *testing.T) {
|
||||||
for name, get := range tests {
|
for name, get := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
tc := get(t)
|
tc := get(t)
|
||||||
identity, err := DefaultIdentityFunc(tc.p, tc.email)
|
identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.Equals(t, tc.err.Error(), err.Error())
|
assert.Equals(t, tc.err.Error(), err.Error())
|
||||||
|
|
|
@ -104,7 +104,7 @@ type SSHKeys struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSHRoots returns the SSH User and Host public keys.
|
// GetSSHRoots returns the SSH User and Host public keys.
|
||||||
func (a *Authority) GetSSHRoots() (*SSHKeys, error) {
|
func (a *Authority) GetSSHRoots(context.Context) (*SSHKeys, error) {
|
||||||
return &SSHKeys{
|
return &SSHKeys{
|
||||||
HostKeys: a.sshCAHostCerts,
|
HostKeys: a.sshCAHostCerts,
|
||||||
UserKeys: a.sshCAUserCerts,
|
UserKeys: a.sshCAUserCerts,
|
||||||
|
@ -112,7 +112,7 @@ func (a *Authority) GetSSHRoots() (*SSHKeys, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSHFederation returns the public keys for federated SSH signers.
|
// GetSSHFederation returns the public keys for federated SSH signers.
|
||||||
func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
|
func (a *Authority) GetSSHFederation(context.Context) (*SSHKeys, error) {
|
||||||
return &SSHKeys{
|
return &SSHKeys{
|
||||||
HostKeys: a.sshCAHostFederatedCerts,
|
HostKeys: a.sshCAHostFederatedCerts,
|
||||||
UserKeys: a.sshCAUserFederatedCerts,
|
UserKeys: a.sshCAUserFederatedCerts,
|
||||||
|
@ -120,7 +120,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSHConfig returns rendered templates for clients (user) or servers (host).
|
// GetSSHConfig returns rendered templates for clients (user) or servers (host).
|
||||||
func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
|
func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||||
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
|
if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil {
|
||||||
return nil, errs.NotFound("getSSHConfig: ssh is not configured")
|
return nil, errs.NotFound("getSSHConfig: ssh is not configured")
|
||||||
}
|
}
|
||||||
|
@ -166,9 +166,9 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template
|
||||||
|
|
||||||
// GetSSHBastion returns the bastion configuration, for the given pair user,
|
// GetSSHBastion returns the bastion configuration, for the given pair user,
|
||||||
// hostname.
|
// hostname.
|
||||||
func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) {
|
func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*Bastion, error) {
|
||||||
if a.sshBastionFunc != nil {
|
if a.sshBastionFunc != nil {
|
||||||
bs, err := a.sshBastionFunc(user, hostname)
|
bs, err := a.sshBastionFunc(ctx, user, hostname)
|
||||||
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
||||||
}
|
}
|
||||||
if a.config.SSH != nil {
|
if a.config.SSH != nil {
|
||||||
|
@ -181,7 +181,7 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignSSH creates a signed SSH certificate with the given public key and options.
|
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||||
func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
var mods []provisioner.SSHCertModifier
|
var mods []provisioner.SSHCertModifier
|
||||||
var validators []provisioner.SSHCertValidator
|
var validators []provisioner.SSHCertValidator
|
||||||
|
|
||||||
|
@ -282,7 +282,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
|
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
|
||||||
func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) {
|
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
nonce, err := randutil.ASCII(32)
|
nonce, err := randutil.ASCII(32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH")
|
||||||
|
@ -353,7 +353,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
|
// RekeySSH creates a signed SSH certificate using the old SSH certificate as a template.
|
||||||
func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
var validators []provisioner.SSHCertValidator
|
var validators []provisioner.SSHCertValidator
|
||||||
|
|
||||||
for _, op := range signOpts {
|
for _, op := range signOpts {
|
||||||
|
@ -443,7 +443,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
// SignSSHAddUser signs a certificate that provisions a new user in a server.
|
||||||
func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) {
|
||||||
if a.sshCAUserCertSignKey == nil {
|
if a.sshCAUserCertSignKey == nil {
|
||||||
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
|
return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled")
|
||||||
}
|
}
|
||||||
|
@ -527,9 +527,9 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSHHosts returns a list of valid host principals.
|
// GetSSHHosts returns a list of valid host principals.
|
||||||
func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
|
func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||||
if a.sshGetHostsFunc != nil {
|
if a.sshGetHostsFunc != nil {
|
||||||
hosts, err := a.sshGetHostsFunc(cert)
|
hosts, err := a.sshGetHostsFunc(ctx, cert)
|
||||||
return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
|
return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts")
|
||||||
}
|
}
|
||||||
hostnames, err := a.db.GetSSHHostPrincipals()
|
hostnames, err := a.db.GetSSHHostPrincipals()
|
||||||
|
|
|
@ -153,7 +153,7 @@ func TestAuthority_SignSSH(t *testing.T) {
|
||||||
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||||
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||||
|
|
||||||
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
|
got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -242,7 +242,7 @@ func TestAuthority_SignSSHAddUser(t *testing.T) {
|
||||||
AddUserPrincipal: tt.fields.addUserPrincipal,
|
AddUserPrincipal: tt.fields.addUserPrincipal,
|
||||||
AddUserCommand: tt.fields.addUserCommand,
|
AddUserCommand: tt.fields.addUserCommand,
|
||||||
}
|
}
|
||||||
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
|
got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -295,7 +295,7 @@ func TestAuthority_GetSSHRoots(t *testing.T) {
|
||||||
a.sshCAUserCerts = tt.fields.sshCAUserCerts
|
a.sshCAUserCerts = tt.fields.sshCAUserCerts
|
||||||
a.sshCAHostCerts = tt.fields.sshCAHostCerts
|
a.sshCAHostCerts = tt.fields.sshCAHostCerts
|
||||||
|
|
||||||
got, err := a.GetSSHRoots()
|
got, err := a.GetSSHRoots(context.Background())
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -337,7 +337,7 @@ func TestAuthority_GetSSHFederation(t *testing.T) {
|
||||||
a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts
|
a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts
|
||||||
a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts
|
a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts
|
||||||
|
|
||||||
got, err := a.GetSSHFederation()
|
got, err := a.GetSSHFederation(context.Background())
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -463,7 +463,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
|
||||||
a.sshCAUserCertSignKey = tt.fields.userSigner
|
a.sshCAUserCertSignKey = tt.fields.userSigner
|
||||||
a.sshCAHostCertSignKey = tt.fields.hostSigner
|
a.sshCAHostCertSignKey = tt.fields.hostSigner
|
||||||
|
|
||||||
got, err := a.GetSSHConfig(tt.args.typ, tt.args.data)
|
got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -614,7 +614,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
||||||
}
|
}
|
||||||
type fields struct {
|
type fields struct {
|
||||||
config *Config
|
config *Config
|
||||||
sshBastionFunc func(user, hostname string) (*Bastion, error)
|
sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
|
||||||
}
|
}
|
||||||
type args struct {
|
type args struct {
|
||||||
user string
|
user string
|
||||||
|
@ -630,8 +630,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
||||||
{"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false},
|
{"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false},
|
||||||
{"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false},
|
{"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false},
|
||||||
{"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false},
|
{"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false},
|
||||||
{"func", fields{&Config{}, func(_, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
|
{"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
|
||||||
{"func err", fields{&Config{}, func(_, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
|
{"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
|
||||||
{"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true},
|
{"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -640,7 +640,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
||||||
config: tt.fields.config,
|
config: tt.fields.config,
|
||||||
sshBastionFunc: tt.fields.sshBastionFunc,
|
sshBastionFunc: tt.fields.sshBastionFunc,
|
||||||
}
|
}
|
||||||
got, err := a.GetSSHBastion(tt.args.user, tt.args.hostname)
|
got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -659,7 +659,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
||||||
a := testAuthority(t)
|
a := testAuthority(t)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error)
|
getHostsFunc func(context.Context, *x509.Certificate) ([]sshutil.Host, error)
|
||||||
auth *Authority
|
auth *Authority
|
||||||
cert *x509.Certificate
|
cert *x509.Certificate
|
||||||
cmp func(got []sshutil.Host)
|
cmp func(got []sshutil.Host)
|
||||||
|
@ -669,7 +669,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
||||||
tests := map[string]func(t *testing.T) *test{
|
tests := map[string]func(t *testing.T) *test{
|
||||||
"fail/getHostsFunc-fail": func(t *testing.T) *test {
|
"fail/getHostsFunc-fail": func(t *testing.T) *test {
|
||||||
return &test{
|
return &test{
|
||||||
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
|
getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||||
return nil, errors.New("force")
|
return nil, errors.New("force")
|
||||||
},
|
},
|
||||||
cert: &x509.Certificate{},
|
cert: &x509.Certificate{},
|
||||||
|
@ -684,7 +684,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &test{
|
return &test{
|
||||||
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
|
getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||||
return hosts, nil
|
return hosts, nil
|
||||||
},
|
},
|
||||||
cert: &x509.Certificate{},
|
cert: &x509.Certificate{},
|
||||||
|
@ -732,7 +732,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
auth.sshGetHostsFunc = tc.getHostsFunc
|
auth.sshGetHostsFunc = tc.getHostsFunc
|
||||||
|
|
||||||
hosts, err := auth.GetSSHHosts(tc.cert)
|
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
|
@ -901,7 +901,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
||||||
a.sshCAUserCertSignKey = tc.userSigner
|
a.sshCAUserCertSignKey = tc.userSigner
|
||||||
a.sshCAHostCertSignKey = tc.hostSigner
|
a.sshCAHostCertSignKey = tc.hostSigner
|
||||||
|
|
||||||
cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...)
|
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
|
|
Loading…
Reference in a new issue