Fix unit tests on the api package

This commit is contained in:
Mariano Cano 2022-04-27 10:38:53 -07:00
parent a93653ea8e
commit 817af3d696
4 changed files with 77 additions and 68 deletions

View file

@ -54,7 +54,8 @@ type Authority interface {
var errAuthority = errors.New("authority is not in context") var errAuthority = errors.New("authority is not in context")
func mustAuthority(ctx context.Context) Authority { // mustAuthority will be replaced on unit tests.
var mustAuthority = func(ctx context.Context) Authority {
a, ok := authority.FromContext(ctx) a, ok := authority.FromContext(ctx)
if !ok { if !ok {
panic(errAuthority) panic(errAuthority)
@ -249,7 +250,9 @@ type FederationResponse struct {
} }
// caHandler is the type used to implement the different CA HTTP endpoints. // caHandler is the type used to implement the different CA HTTP endpoints.
type caHandler struct{} type caHandler struct {
Authority Authority
}
// Route configures the http request router. // Route configures the http request router.
func (h *caHandler) Route(r Router) { func (h *caHandler) Route(r Router) {

View file

@ -171,6 +171,17 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
return csr return csr
} }
func mockMustAuthority(t *testing.T, a Authority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) Authority {
return a
}
}
type mockAuthority struct { type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
@ -789,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) {
} }
} }
func Test_caHandler_Health(t *testing.T) { func Test_Health(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/health", nil) req := httptest.NewRequest("GET", "http://example.com/health", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := New(&mockAuthority{}).(*caHandler) Health(w, req)
h.Health(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != 200 { if res.StatusCode != 200 {
@ -811,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) {
} }
} }
func Test_caHandler_Root(t *testing.T) { func Test_Root(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
root *x509.Certificate root *x509.Certificate
@ -832,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Root(w, req) Root(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -855,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) {
} }
} }
func Test_caHandler_Sign(t *testing.T) { func Test_Sign(t *testing.T) {
csr := parseCertificateRequest(csrPEM) csr := parseCertificateRequest(csrPEM)
valid, err := json.Marshal(SignRequest{ valid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr}, CsrPEM: CertificateRequest{csr},
@ -896,7 +906,7 @@ func Test_caHandler_Sign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr, ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
return tt.certAttrOpts, tt.autherr return tt.certAttrOpts, tt.autherr
@ -904,10 +914,10 @@ func Test_caHandler_Sign(t *testing.T) {
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Sign(logging.NewResponseLogger(w), req) Sign(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -928,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) {
} }
} }
func Test_caHandler_Renew(t *testing.T) { func Test_Renew(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1018,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
@ -1039,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) {
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/renew", nil) req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls req.TLS = tt.tls
req.Header = tt.header req.Header = tt.header
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req) Renew(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
@ -1073,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) {
} }
} }
func Test_caHandler_Rekey(t *testing.T) { func Test_Rekey(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1104,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Rekey(logging.NewResponseLogger(w), req) Rekey(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1134,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) {
} }
} }
func Test_caHandler_Provisioners(t *testing.T) { func Test_Provisioners(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
} }
@ -1200,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ mockMustAuthority(t, tt.fields.Authority)
Authority: tt.fields.Authority, Provisioners(tt.args.w, tt.args.r)
}
h.Provisioners(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result() res := rec.Result()
@ -1238,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
} }
} }
func Test_caHandler_ProvisionerKey(t *testing.T) { func Test_ProvisionerKey(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
} }
@ -1270,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ mockMustAuthority(t, tt.fields.Authority)
Authority: tt.fields.Authority, ProvisionerKey(tt.args.w, tt.args.r)
}
h.ProvisionerKey(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result() res := rec.Result()
@ -1298,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
} }
} }
func Test_caHandler_Roots(t *testing.T) { func Test_Roots(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1319,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/roots", nil) req := httptest.NewRequest("GET", "http://example.com/roots", nil)
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Roots(w, req) Roots(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1360,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
req := httptest.NewRequest("GET", "https://example.com/roots", nil) req := httptest.NewRequest("GET", "https://example.com/roots", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.RootsPEM(w, req) RootsPEM(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1384,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
} }
} }
func Test_caHandler_Federation(t *testing.T) { func Test_Federation(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1405,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/federation", nil) req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Federation(w, req) Federation(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {

View file

@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
for name, _tc := range tests { for name, _tc := range tests {
tc := _tc(t) tc := _tc(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*caHandler) mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input)) req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
if tc.tls != nil { if tc.tls != nil {
req.TLS = tc.tls req.TLS = tc.tls
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Revoke(logging.NewResponseLogger(w), req) Revoke(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

View file

@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
} }
} }
func Test_caHandler_SSHSign(t *testing.T) { func Test_SSHSign(t *testing.T) {
user, err := getSignedUserCertificate() user, err := getSignedUserCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
host, err := getSignedHostCertificate() host, err := getSignedHostCertificate()
@ -315,7 +315,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
return []provisioner.SignOption{}, tt.authErr return []provisioner.SignOption{}, tt.authErr
}, },
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr return tt.tlsSignCerts, tt.tlsSignErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHSign(logging.NewResponseLogger(w), req) SSHSign(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
} }
} }
func Test_caHandler_SSHRoots(t *testing.T) { func Test_SSHRoots(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public()) user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err) assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr return tt.keys, tt.keysErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHRoots(logging.NewResponseLogger(w), req) SSHRoots(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
} }
} }
func Test_caHandler_SSHFederation(t *testing.T) { func Test_SSHFederation(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public()) user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err) assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr return tt.keys, tt.keysErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHFederation(logging.NewResponseLogger(w), req) SSHFederation(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
} }
} }
func Test_caHandler_SSHConfig(t *testing.T) { func Test_SSHConfig(t *testing.T) {
userOutput := []templates.Output{ userOutput := []templates.Output{
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
return tt.output, tt.err return tt.output, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHConfig(logging.NewResponseLogger(w), req) SSHConfig(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
} }
} }
func Test_caHandler_SSHCheckHost(t *testing.T) { func Test_SSHCheckHost(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
req string req string
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
return tt.exists, tt.err return tt.exists, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHCheckHost(logging.NewResponseLogger(w), req) SSHCheckHost(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
} }
} }
func Test_caHandler_SSHGetHosts(t *testing.T) { func Test_SSHGetHosts(t *testing.T) {
hosts := []authority.Host{ hosts := []authority.Host{
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"}, {HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
return tt.hosts, tt.err return tt.hosts, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHGetHosts(logging.NewResponseLogger(w), req) SSHGetHosts(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
} }
} }
func Test_caHandler_SSHBastion(t *testing.T) { func Test_SSHBastion(t *testing.T) {
bastion := &authority.Bastion{ bastion := &authority.Bastion{
Hostname: "bastion.local", Hostname: "bastion.local",
} }
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
return tt.bastion, tt.bastionErr return tt.bastion, tt.bastionErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHBastion(logging.NewResponseLogger(w), req) SSHBastion(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {