diff --git a/api/api_test.go b/api/api_test.go index cbaf806f..edefbd47 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -550,8 +550,6 @@ type mockAuthority struct { getTLSOptions func() *tlsutil.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByID func(provID string) (provisioner.Interface, error) @@ -560,14 +558,16 @@ type mockAuthority struct { getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) - renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error) - rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - getSSHHosts func(*x509.Certificate) ([]sshutil.Host, error) - getSSHRoots func() (*authority.SSHKeys, error) - getSSHFederation func() (*authority.SSHKeys, error) - getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) + signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) + getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error) + getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error) + getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) checkSSHHost func(ctx context.Context, principal, token string) (bool, error) - getSSHBastion func(user string, hostname string) (*authority.Bastion, error) + getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) version func() authority.Version } @@ -604,20 +604,6 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } -func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { - if m.signSSH != nil { - return m.signSSH(key, opts, signOpts...) - } - return m.ret1.(*ssh.Certificate), m.err -} - -func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { - if m.signSSHAddUser != nil { - return m.signSSHAddUser(key, cert) - } - return m.ret1.(*ssh.Certificate), m.err -} - func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.renew != nil { return m.renew(cert) @@ -674,44 +660,58 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { return m.ret1.([]*x509.Certificate), m.err } -func (m *mockAuthority) RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) { +func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + if m.signSSH != nil { + return m.signSSH(ctx, key, opts, signOpts...) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + if m.signSSHAddUser != nil { + return m.signSSHAddUser(ctx, key, cert) + } + return m.ret1.(*ssh.Certificate), m.err +} + +func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) { if m.renewSSH != nil { - return m.renewSSH(cert) + return m.renewSSH(ctx, cert) } return m.ret1.(*ssh.Certificate), m.err } -func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { if m.rekeySSH != nil { - return m.rekeySSH(cert, key, signOpts...) + return m.rekeySSH(ctx, cert, key, signOpts...) } return m.ret1.(*ssh.Certificate), m.err } -func (m *mockAuthority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { +func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { if m.getSSHHosts != nil { - return m.getSSHHosts(cert) + return m.getSSHHosts(ctx, cert) } return m.ret1.([]sshutil.Host), m.err } -func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) { +func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) { if m.getSSHRoots != nil { - return m.getSSHRoots() + return m.getSSHRoots(ctx) } return m.ret1.(*authority.SSHKeys), m.err } -func (m *mockAuthority) GetSSHFederation() (*authority.SSHKeys, error) { +func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) { if m.getSSHFederation != nil { - return m.getSSHFederation() + return m.getSSHFederation(ctx) } return m.ret1.(*authority.SSHKeys), m.err } -func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { +func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { if m.getSSHConfig != nil { - return m.getSSHConfig(typ, data) + return m.getSSHConfig(ctx, typ, data) } return m.ret1.([]templates.Output), m.err } @@ -723,9 +723,9 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin return m.ret1.(bool), m.err } -func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) { +func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) { if m.getSSHBastion != nil { - return m.getSSHBastion(user, hostname) + return m.getSSHBastion(ctx, user, hostname) } return m.ret1.(*authority.Bastion), m.err } diff --git a/api/ssh.go b/api/ssh.go index f0b090d1..fc598502 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -19,16 +19,16 @@ import ( // SSHAuthority is the interface implemented by a SSH CA authority. type SSHAuthority interface { - SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) - RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) - SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - GetSSHRoots() (*authority.SSHKeys, error) - GetSSHFederation() (*authority.SSHKeys, error) - GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) + SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) + RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) + SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) + GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) + GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) + GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) - GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) - GetSSHBastion(user string, hostname string) (*authority.Bastion, error) + GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) + GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) } // SSHSignRequest is the request body of an SSH certificate request. @@ -282,14 +282,14 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { ValidAfter: body.ValidAfter, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { WriteError(w, errs.UnauthorizedErr(err)) return } - cert, err := h.Authority.SignSSH(publicKey, opts, signOpts...) + cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { WriteError(w, errs.ForbiddenErr(err)) return @@ -297,7 +297,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var addUserCertificate *SSHCertificate if addUserPublicKey != nil && cert.CertType == ssh.UserCert && len(cert.ValidPrincipals) == 1 { - addUserCert, err := h.Authority.SignSSHAddUser(addUserPublicKey, cert) + addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { WriteError(w, errs.ForbiddenErr(err)) return @@ -316,7 +316,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: provisioner.NewTimeDuration(time.Unix(int64(cert.ValidBefore), 0)), } } - ctx := authority.NewContextWithSkipTokenReuse(context.Background()) + ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { @@ -341,7 +341,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { // SSHRoots is an HTTP handler that returns the SSH public keys for user and host // certificates. func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHRoots() + keys, err := h.Authority.GetSSHRoots(r.Context()) if err != nil { WriteError(w, errs.InternalServerErr(err)) return @@ -366,7 +366,7 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { // SSHFederation is an HTTP handler that returns the federated SSH public keys // for user and host certificates. func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHFederation() + keys, err := h.Authority.GetSSHFederation(r.Context()) if err != nil { WriteError(w, errs.InternalServerErr(err)) return @@ -401,7 +401,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - ts, err := h.Authority.GetSSHConfig(body.Type, body.Data) + ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) if err != nil { WriteError(w, errs.InternalServerErr(err)) return @@ -450,7 +450,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { cert = r.TLS.PeerCertificates[0] } - hosts, err := h.Authority.GetSSHHosts(cert) + hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) if err != nil { WriteError(w, errs.InternalServerErr(err)) return @@ -472,7 +472,7 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } - bastion, err := h.Authority.GetSSHBastion(body.User, body.Hostname) + bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) if err != nil { WriteError(w, errs.InternalServerErr(err)) return diff --git a/api/sshRekey.go b/api/sshRekey.go index a5cc1f06..285422f9 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "github.com/pkg/errors" @@ -56,7 +55,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) signOpts, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { WriteError(w, errs.UnauthorizedErr(err)) @@ -67,7 +66,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { WriteError(w, errs.InternalServerErr(err)) } - newCert, err := h.Authority.RekeySSH(oldCert, publicKey, signOpts...) + newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { WriteError(w, errs.ForbiddenErr(err)) return diff --git a/api/sshRenew.go b/api/sshRenew.go index 11a9d8e8..048c83a3 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "github.com/pkg/errors" @@ -46,7 +45,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) _, err := h.Authority.Authorize(ctx, body.OTT) if err != nil { WriteError(w, errs.UnauthorizedErr(err)) @@ -57,7 +56,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { WriteError(w, errs.InternalServerErr(err)) } - newCert, err := h.Authority.RenewSSH(oldCert) + newCert, err := h.Authority.RenewSSH(ctx, oldCert) if err != nil { WriteError(w, errs.ForbiddenErr(err)) return diff --git a/api/sshRevoke.go b/api/sshRevoke.go index b8d1dadd..5a1c858c 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "github.com/smallstep/certificates/authority" @@ -65,7 +64,7 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { PassiveOnly: body.Passive, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) diff --git a/api/ssh_test.go b/api/ssh_test.go index cb5c7904..874c00b7 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -319,10 +319,10 @@ func Test_caHandler_SSHSign(t *testing.T) { authorizeSign: func(ott string) ([]provisioner.SignOption, error) { return []provisioner.SignOption{}, tt.authErr }, - signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { + signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { return tt.signCert, tt.signErr }, - signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { + signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { return tt.addUserCert, tt.addUserErr }, sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { @@ -379,7 +379,7 @@ func Test_caHandler_SSHRoots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHRoots: func() (*authority.SSHKeys, error) { + getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, }).(*caHandler) @@ -433,7 +433,7 @@ func Test_caHandler_SSHFederation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHFederation: func() (*authority.SSHKeys, error) { + getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, }).(*caHandler) @@ -493,7 +493,7 @@ func Test_caHandler_SSHConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) { + getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { return tt.output, tt.err }, }).(*caHandler) @@ -591,7 +591,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) { + getSSHHosts: func(context.Context, *x509.Certificate) ([]sshutil.Host, error) { return tt.hosts, tt.err }, }).(*caHandler) @@ -646,7 +646,7 @@ func Test_caHandler_SSHBastion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHBastion: func(user, hostname string) (*authority.Bastion, error) { + getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { return tt.bastion, tt.bastionErr }, }).(*caHandler) diff --git a/authority/authority.go b/authority/authority.go index 0730fe5e..8cf4cfc1 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -51,9 +51,9 @@ type Authority struct { startTime time.Time // Custom functions - sshBastionFunc func(user, hostname string) (*Bastion, error) + sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error) sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) - sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error) + sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) getIdentityFunc provisioner.GetIdentityFunc } @@ -227,7 +227,7 @@ func (a *Authority) init() error { // TODO: should we also be combining the ssh federated roots here? // If we rotate ssh roots keys, sshpop provisioner will lose ability to // validate old SSH certificates, unless they are added as federated certs. - sshKeys, err := a.GetSSHRoots() + sshKeys, err := a.GetSSHRoots(context.Background()) if err != nil { return err } diff --git a/authority/options.go b/authority/options.go index 2d655a2b..04cd7bef 100644 --- a/authority/options.go +++ b/authority/options.go @@ -28,7 +28,7 @@ func WithDatabase(db db.AuthDB) Option { // WithGetIdentityFunc sets a custom function to retrieve the identity from // an external resource. -func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { +func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { return func(a *Authority) error { a.getIdentityFunc = fn return nil @@ -37,7 +37,7 @@ func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provis // WithSSHBastionFunc sets a custom function to get the bastion for a // given user-host pair. -func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { +func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*Bastion, error)) Option { return func(a *Authority) error { a.sshBastionFunc = fn return nil @@ -46,7 +46,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { // WithSSHGetHosts sets a custom function to get the bastion for a // given user-host pair. -func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]sshutil.Host, error)) Option { +func WithSSHGetHosts(fn func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)) Option { return func(a *Authority) error { a.sshGetHostsFunc = fn return nil diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 0b5448af..f90d96b5 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -339,7 +339,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Get the identity using either the default identityFunc or one injected // externally. - iden, err := o.getIdentityFunc(o, claims.Email) + iden, err := o.getIdentityFunc(ctx, o, claims.Email) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index d0782c1e..fbf71f4b 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -485,10 +485,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) { + p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) { + p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { return nil, errors.New("force") } diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 4cf1c5f9..885e7cf0 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -330,10 +330,10 @@ type Identity struct { } // GetIdentityFunc is a function that returns an identity. -type GetIdentityFunc func(p Interface, email string) (*Identity, error) +type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) // DefaultIdentityFunc return a default identity depending on the provisioner type. -func DefaultIdentityFunc(p Interface, email string) (*Identity, error) { +func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { switch k := p.(type) { case *OIDC: name := SanitizeSSHUserPrincipal(email) diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index 2577c62f..238e21a3 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -92,7 +92,7 @@ func TestDefaultIdentityFunc(t *testing.T) { for name, get := range tests { t.Run(name, func(t *testing.T) { tc := get(t) - identity, err := DefaultIdentityFunc(tc.p, tc.email) + identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email) if err != nil { if assert.NotNil(t, tc.err) { assert.Equals(t, tc.err.Error(), err.Error()) diff --git a/authority/ssh.go b/authority/ssh.go index f47447d5..d28205e2 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -104,7 +104,7 @@ type SSHKeys struct { } // GetSSHRoots returns the SSH User and Host public keys. -func (a *Authority) GetSSHRoots() (*SSHKeys, error) { +func (a *Authority) GetSSHRoots(context.Context) (*SSHKeys, error) { return &SSHKeys{ HostKeys: a.sshCAHostCerts, UserKeys: a.sshCAUserCerts, @@ -112,7 +112,7 @@ func (a *Authority) GetSSHRoots() (*SSHKeys, error) { } // GetSSHFederation returns the public keys for federated SSH signers. -func (a *Authority) GetSSHFederation() (*SSHKeys, error) { +func (a *Authority) GetSSHFederation(context.Context) (*SSHKeys, error) { return &SSHKeys{ HostKeys: a.sshCAHostFederatedCerts, UserKeys: a.sshCAUserFederatedCerts, @@ -120,7 +120,7 @@ func (a *Authority) GetSSHFederation() (*SSHKeys, error) { } // GetSSHConfig returns rendered templates for clients (user) or servers (host). -func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) { +func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { if a.sshCAUserCertSignKey == nil && a.sshCAHostCertSignKey == nil { return nil, errs.NotFound("getSSHConfig: ssh is not configured") } @@ -166,9 +166,9 @@ func (a *Authority) GetSSHConfig(typ string, data map[string]string) ([]template // GetSSHBastion returns the bastion configuration, for the given pair user, // hostname. -func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error) { +func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*Bastion, error) { if a.sshBastionFunc != nil { - bs, err := a.sshBastionFunc(user, hostname) + bs, err := a.sshBastionFunc(ctx, user, hostname) return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") } if a.config.SSH != nil { @@ -181,7 +181,7 @@ func (a *Authority) GetSSHBastion(user string, hostname string) (*Bastion, error } // SignSSH creates a signed SSH certificate with the given public key and options. -func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var mods []provisioner.SSHCertModifier var validators []provisioner.SSHCertValidator @@ -282,7 +282,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign } // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. -func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) { +func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { nonce, err := randutil.ASCII(32) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH") @@ -353,7 +353,7 @@ func (a *Authority) RenewSSH(oldCert *ssh.Certificate) (*ssh.Certificate, error) } // RekeySSH creates a signed SSH certificate using the old SSH certificate as a template. -func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var validators []provisioner.SSHCertValidator for _, op := range signOpts { @@ -443,7 +443,7 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp } // SignSSHAddUser signs a certificate that provisions a new user in a server. -func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { +func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subject *ssh.Certificate) (*ssh.Certificate, error) { if a.sshCAUserCertSignKey == nil { return nil, errs.NotImplemented("signSSHAddUser: user certificate signing is not enabled") } @@ -527,9 +527,9 @@ func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token st } // GetSSHHosts returns a list of valid host principals. -func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) { +func (a *Authority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { if a.sshGetHostsFunc != nil { - hosts, err := a.sshGetHostsFunc(cert) + hosts, err := a.sshGetHostsFunc(ctx, cert) return hosts, errs.Wrap(http.StatusInternalServerError, err, "getSSHHosts") } hostnames, err := a.db.GetSSHHostPrincipals() diff --git a/authority/ssh_test.go b/authority/ssh_test.go index b581740f..6d05e1a9 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -153,7 +153,7 @@ func TestAuthority_SignSSH(t *testing.T) { a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey - got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...) + got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr) return @@ -242,7 +242,7 @@ func TestAuthority_SignSSHAddUser(t *testing.T) { AddUserPrincipal: tt.fields.addUserPrincipal, AddUserCommand: tt.fields.addUserCommand, } - got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject) + got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject) if (err != nil) != tt.wantErr { t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr) return @@ -295,7 +295,7 @@ func TestAuthority_GetSSHRoots(t *testing.T) { a.sshCAUserCerts = tt.fields.sshCAUserCerts a.sshCAHostCerts = tt.fields.sshCAHostCerts - got, err := a.GetSSHRoots() + got, err := a.GetSSHRoots(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr) return @@ -337,7 +337,7 @@ func TestAuthority_GetSSHFederation(t *testing.T) { a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts - got, err := a.GetSSHFederation() + got, err := a.GetSSHFederation(context.Background()) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr) return @@ -463,7 +463,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) { a.sshCAUserCertSignKey = tt.fields.userSigner a.sshCAHostCertSignKey = tt.fields.hostSigner - got, err := a.GetSSHConfig(tt.args.typ, tt.args.data) + got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr) return @@ -614,7 +614,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) { } type fields struct { config *Config - sshBastionFunc func(user, hostname string) (*Bastion, error) + sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error) } type args struct { user string @@ -630,8 +630,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) { {"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false}, {"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false}, {"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false}, - {"func", fields{&Config{}, func(_, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false}, - {"func err", fields{&Config{}, func(_, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true}, + {"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false}, + {"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true}, {"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true}, } for _, tt := range tests { @@ -640,7 +640,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) { config: tt.fields.config, sshBastionFunc: tt.fields.sshBastionFunc, } - got, err := a.GetSSHBastion(tt.args.user, tt.args.hostname) + got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname) if (err != nil) != tt.wantErr { t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr) return @@ -659,7 +659,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { a := testAuthority(t) type test struct { - getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error) + getHostsFunc func(context.Context, *x509.Certificate) ([]sshutil.Host, error) auth *Authority cert *x509.Certificate cmp func(got []sshutil.Host) @@ -669,7 +669,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { tests := map[string]func(t *testing.T) *test{ "fail/getHostsFunc-fail": func(t *testing.T) *test { return &test{ - getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) { + getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { return nil, errors.New("force") }, cert: &x509.Certificate{}, @@ -684,7 +684,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { } return &test{ - getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) { + getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) { return hosts, nil }, cert: &x509.Certificate{}, @@ -732,7 +732,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) { } auth.sshGetHostsFunc = tc.getHostsFunc - hosts, err := auth.GetSSHHosts(tc.cert) + hosts, err := auth.GetSSHHosts(context.Background(), tc.cert) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder) @@ -901,7 +901,7 @@ func TestAuthority_RekeySSH(t *testing.T) { a.sshCAUserCertSignKey = tc.userSigner a.sshCAHostCertSignKey = tc.hostSigner - cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...) + cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...) if err != nil { if assert.NotNil(t, tc.err) { sc, ok := err.(errs.StatusCoder)