diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 3ea4e98d..9b403132 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -476,7 +477,9 @@ func TestAuthority_CheckSSHHost(t *testing.T) { err error } type args struct { + ctx context.Context principal string + token string } tests := []struct { name string @@ -485,12 +488,12 @@ func TestAuthority_CheckSSHHost(t *testing.T) { want bool wantErr bool }{ - {"true", fields{true, nil}, args{"foo.internal.com"}, true, false}, - {"false", fields{false, nil}, args{"foo.internal.com"}, false, false}, - {"notImplemented", fields{false, db.ErrNotImplemented}, args{"foo.internal.com"}, false, true}, - {"notImplemented", fields{true, db.ErrNotImplemented}, args{"foo.internal.com"}, false, true}, - {"internal", fields{false, fmt.Errorf("an error")}, args{"foo.internal.com"}, false, true}, - {"internal", fields{true, fmt.Errorf("an error")}, args{"foo.internal.com"}, false, true}, + {"true", fields{true, nil}, args{context.TODO(), "foo.internal.com", ""}, true, false}, + {"false", fields{false, nil}, args{context.TODO(), "foo.internal.com", ""}, false, false}, + {"notImplemented", fields{false, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, + {"notImplemented", fields{true, db.ErrNotImplemented}, args{context.TODO(), "foo.internal.com", ""}, false, true}, + {"internal", fields{false, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, + {"internal", fields{true, fmt.Errorf("an error")}, args{context.TODO(), "foo.internal.com", ""}, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -500,7 +503,7 @@ func TestAuthority_CheckSSHHost(t *testing.T) { return tt.fields.exists, tt.fields.err }, } - got, err := a.CheckSSHHost(tt.args.principal) + got, err := a.CheckSSHHost(tt.args.ctx, tt.args.principal, tt.args.token) if (err != nil) != tt.wantErr { t.Errorf("Authority.CheckSSHHost() error = %v, wantErr %v", err, tt.wantErr) return