Add context to tests.
This commit is contained in:
parent
c49a9d5e33
commit
fa416336a8
5 changed files with 61 additions and 61 deletions
|
@ -550,8 +550,6 @@ type mockAuthority struct {
|
|||
getTLSOptions func() *tlsutil.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
signSSH func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
signSSHAddUser func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
|
||||
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
|
||||
loadProvisionerByID func(provID string) (provisioner.Interface, error)
|
||||
|
@ -560,14 +558,16 @@ type mockAuthority struct {
|
|||
getEncryptedKey func(kid string) (string, error)
|
||||
getRoots func() ([]*x509.Certificate, error)
|
||||
getFederation func() ([]*x509.Certificate, error)
|
||||
renewSSH func(cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
rekeySSH func(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
getSSHHosts func(*x509.Certificate) ([]sshutil.Host, error)
|
||||
getSSHRoots func() (*authority.SSHKeys, error)
|
||||
getSSHFederation func() (*authority.SSHKeys, error)
|
||||
getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error)
|
||||
signSSH func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
signSSHAddUser func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
renewSSH func(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error)
|
||||
rekeySSH func(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error)
|
||||
getSSHHosts func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error)
|
||||
getSSHRoots func(ctx context.Context) (*authority.SSHKeys, error)
|
||||
getSSHFederation func(ctx context.Context) (*authority.SSHKeys, error)
|
||||
getSSHConfig func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error)
|
||||
checkSSHHost func(ctx context.Context, principal, token string) (bool, error)
|
||||
getSSHBastion func(user string, hostname string) (*authority.Bastion, error)
|
||||
getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error)
|
||||
version func() authority.Version
|
||||
}
|
||||
|
||||
|
@ -604,20 +604,6 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.Optio
|
|||
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.signSSH != nil {
|
||||
return m.signSSH(key, opts, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.signSSHAddUser != nil {
|
||||
return m.signSSHAddUser(key, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
if m.renew != nil {
|
||||
return m.renew(cert)
|
||||
|
@ -674,44 +660,58 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) {
|
|||
return m.ret1.([]*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RenewSSH(cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
func (m *mockAuthority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.signSSH != nil {
|
||||
return m.signSSH(ctx, key, opts, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.signSSHAddUser != nil {
|
||||
return m.signSSHAddUser(ctx, key, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RenewSSH(ctx context.Context, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
if m.renewSSH != nil {
|
||||
return m.renewSSH(cert)
|
||||
return m.renewSSH(ctx, cert)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) RekeySSH(cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
func (m *mockAuthority) RekeySSH(ctx context.Context, cert *ssh.Certificate, key ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
if m.rekeySSH != nil {
|
||||
return m.rekeySSH(cert, key, signOpts...)
|
||||
return m.rekeySSH(ctx, cert, key, signOpts...)
|
||||
}
|
||||
return m.ret1.(*ssh.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
func (m *mockAuthority) GetSSHHosts(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
if m.getSSHHosts != nil {
|
||||
return m.getSSHHosts(cert)
|
||||
return m.getSSHHosts(ctx, cert)
|
||||
}
|
||||
return m.ret1.([]sshutil.Host), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) {
|
||||
func (m *mockAuthority) GetSSHRoots(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
if m.getSSHRoots != nil {
|
||||
return m.getSSHRoots()
|
||||
return m.getSSHRoots(ctx)
|
||||
}
|
||||
return m.ret1.(*authority.SSHKeys), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHFederation() (*authority.SSHKeys, error) {
|
||||
func (m *mockAuthority) GetSSHFederation(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
if m.getSSHFederation != nil {
|
||||
return m.getSSHFederation()
|
||||
return m.getSSHFederation(ctx)
|
||||
}
|
||||
return m.ret1.(*authority.SSHKeys), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) {
|
||||
func (m *mockAuthority) GetSSHConfig(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
if m.getSSHConfig != nil {
|
||||
return m.getSSHConfig(typ, data)
|
||||
return m.getSSHConfig(ctx, typ, data)
|
||||
}
|
||||
return m.ret1.([]templates.Output), m.err
|
||||
}
|
||||
|
@ -723,9 +723,9 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
|
|||
return m.ret1.(bool), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) {
|
||||
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) {
|
||||
if m.getSSHBastion != nil {
|
||||
return m.getSSHBastion(user, hostname)
|
||||
return m.getSSHBastion(ctx, user, hostname)
|
||||
}
|
||||
return m.ret1.(*authority.Bastion), m.err
|
||||
}
|
||||
|
|
|
@ -319,10 +319,10 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
|||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||
return []provisioner.SignOption{}, tt.authErr
|
||||
},
|
||||
signSSH: func(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||
return tt.signCert, tt.signErr
|
||||
},
|
||||
signSSHAddUser: func(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) {
|
||||
return tt.addUserCert, tt.addUserErr
|
||||
},
|
||||
sign: func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||
|
@ -379,7 +379,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHRoots: func() (*authority.SSHKeys, error) {
|
||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
|
@ -433,7 +433,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHFederation: func() (*authority.SSHKeys, error) {
|
||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||
return tt.keys, tt.keysErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
|
@ -493,7 +493,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHConfig: func(typ string, data map[string]string) ([]templates.Output, error) {
|
||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||
return tt.output, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
|
@ -591,7 +591,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHHosts: func(*x509.Certificate) ([]sshutil.Host, error) {
|
||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]sshutil.Host, error) {
|
||||
return tt.hosts, tt.err
|
||||
},
|
||||
}).(*caHandler)
|
||||
|
@ -646,7 +646,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
getSSHBastion: func(user, hostname string) (*authority.Bastion, error) {
|
||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||
return tt.bastion, tt.bastionErr
|
||||
},
|
||||
}).(*caHandler)
|
||||
|
|
|
@ -485,10 +485,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
assert.FatalError(t, p4.Init(config))
|
||||
assert.FatalError(t, p5.Init(config))
|
||||
|
||||
p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
|
||||
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return &Identity{Usernames: []string{"max", "mariano"}}, nil
|
||||
}
|
||||
p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) {
|
||||
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ func TestDefaultIdentityFunc(t *testing.T) {
|
|||
for name, get := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := get(t)
|
||||
identity, err := DefaultIdentityFunc(tc.p, tc.email)
|
||||
identity, err := DefaultIdentityFunc(context.Background(), tc.p, tc.email)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
assert.Equals(t, tc.err.Error(), err.Error())
|
||||
|
|
|
@ -153,7 +153,7 @@ func TestAuthority_SignSSH(t *testing.T) {
|
|||
a.sshCAUserCertSignKey = tt.fields.sshCAUserCertSignKey
|
||||
a.sshCAHostCertSignKey = tt.fields.sshCAHostCertSignKey
|
||||
|
||||
got, err := a.SignSSH(tt.args.key, tt.args.opts, tt.args.signOpts...)
|
||||
got, err := a.SignSSH(context.Background(), tt.args.key, tt.args.opts, tt.args.signOpts...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.SignSSH() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -242,7 +242,7 @@ func TestAuthority_SignSSHAddUser(t *testing.T) {
|
|||
AddUserPrincipal: tt.fields.addUserPrincipal,
|
||||
AddUserCommand: tt.fields.addUserCommand,
|
||||
}
|
||||
got, err := a.SignSSHAddUser(tt.args.key, tt.args.subject)
|
||||
got, err := a.SignSSHAddUser(context.Background(), tt.args.key, tt.args.subject)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.SignSSHAddUser() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -295,7 +295,7 @@ func TestAuthority_GetSSHRoots(t *testing.T) {
|
|||
a.sshCAUserCerts = tt.fields.sshCAUserCerts
|
||||
a.sshCAHostCerts = tt.fields.sshCAHostCerts
|
||||
|
||||
got, err := a.GetSSHRoots()
|
||||
got, err := a.GetSSHRoots(context.Background())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetSSHRoots() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -337,7 +337,7 @@ func TestAuthority_GetSSHFederation(t *testing.T) {
|
|||
a.sshCAUserFederatedCerts = tt.fields.sshCAUserFederatedCerts
|
||||
a.sshCAHostFederatedCerts = tt.fields.sshCAHostFederatedCerts
|
||||
|
||||
got, err := a.GetSSHFederation()
|
||||
got, err := a.GetSSHFederation(context.Background())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetSSHFederation() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -463,7 +463,7 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
|
|||
a.sshCAUserCertSignKey = tt.fields.userSigner
|
||||
a.sshCAHostCertSignKey = tt.fields.hostSigner
|
||||
|
||||
got, err := a.GetSSHConfig(tt.args.typ, tt.args.data)
|
||||
got, err := a.GetSSHConfig(context.Background(), tt.args.typ, tt.args.data)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetSSHConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -614,7 +614,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
|||
}
|
||||
type fields struct {
|
||||
config *Config
|
||||
sshBastionFunc func(user, hostname string) (*Bastion, error)
|
||||
sshBastionFunc func(ctx context.Context, user, hostname string) (*Bastion, error)
|
||||
}
|
||||
type args struct {
|
||||
user string
|
||||
|
@ -630,8 +630,8 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
|||
{"config", fields{&Config{SSH: &SSHConfig{Bastion: bastion}}, nil}, args{"user", "host.local"}, bastion, false},
|
||||
{"nil", fields{&Config{SSH: &SSHConfig{Bastion: nil}}, nil}, args{"user", "host.local"}, nil, false},
|
||||
{"empty", fields{&Config{SSH: &SSHConfig{Bastion: &Bastion{}}}, nil}, args{"user", "host.local"}, nil, false},
|
||||
{"func", fields{&Config{}, func(_, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
|
||||
{"func err", fields{&Config{}, func(_, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
|
||||
{"func", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return bastion, nil }}, args{"user", "host.local"}, bastion, false},
|
||||
{"func err", fields{&Config{}, func(_ context.Context, _, _ string) (*Bastion, error) { return nil, errors.New("foo") }}, args{"user", "host.local"}, nil, true},
|
||||
{"error", fields{&Config{SSH: nil}, nil}, args{"user", "host.local"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -640,7 +640,7 @@ func TestAuthority_GetSSHBastion(t *testing.T) {
|
|||
config: tt.fields.config,
|
||||
sshBastionFunc: tt.fields.sshBastionFunc,
|
||||
}
|
||||
got, err := a.GetSSHBastion(tt.args.user, tt.args.hostname)
|
||||
got, err := a.GetSSHBastion(context.Background(), tt.args.user, tt.args.hostname)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.GetSSHBastion() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -659,7 +659,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
|||
a := testAuthority(t)
|
||||
|
||||
type test struct {
|
||||
getHostsFunc func(*x509.Certificate) ([]sshutil.Host, error)
|
||||
getHostsFunc func(context.Context, *x509.Certificate) ([]sshutil.Host, error)
|
||||
auth *Authority
|
||||
cert *x509.Certificate
|
||||
cmp func(got []sshutil.Host)
|
||||
|
@ -669,7 +669,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
|||
tests := map[string]func(t *testing.T) *test{
|
||||
"fail/getHostsFunc-fail": func(t *testing.T) *test {
|
||||
return &test{
|
||||
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
cert: &x509.Certificate{},
|
||||
|
@ -684,7 +684,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
|||
}
|
||||
|
||||
return &test{
|
||||
getHostsFunc: func(cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
getHostsFunc: func(ctx context.Context, cert *x509.Certificate) ([]sshutil.Host, error) {
|
||||
return hosts, nil
|
||||
},
|
||||
cert: &x509.Certificate{},
|
||||
|
@ -732,7 +732,7 @@ func TestAuthority_GetSSHHosts(t *testing.T) {
|
|||
}
|
||||
auth.sshGetHostsFunc = tc.getHostsFunc
|
||||
|
||||
hosts, err := auth.GetSSHHosts(tc.cert)
|
||||
hosts, err := auth.GetSSHHosts(context.Background(), tc.cert)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
@ -901,7 +901,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
|
|||
a.sshCAUserCertSignKey = tc.userSigner
|
||||
a.sshCAHostCertSignKey = tc.hostSigner
|
||||
|
||||
cert, err := auth.RekeySSH(tc.cert, tc.key, tc.signOpts...)
|
||||
cert, err := auth.RekeySSH(context.Background(), tc.cert, tc.key, tc.signOpts...)
|
||||
if err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
|
|
Loading…
Reference in a new issue