forked from TrueCloudLab/certificates
Fix unit tests on the api package
This commit is contained in:
parent
a93653ea8e
commit
817af3d696
4 changed files with 77 additions and 68 deletions
|
@ -54,7 +54,8 @@ type Authority interface {
|
||||||
|
|
||||||
var errAuthority = errors.New("authority is not in context")
|
var errAuthority = errors.New("authority is not in context")
|
||||||
|
|
||||||
func mustAuthority(ctx context.Context) Authority {
|
// mustAuthority will be replaced on unit tests.
|
||||||
|
var mustAuthority = func(ctx context.Context) Authority {
|
||||||
a, ok := authority.FromContext(ctx)
|
a, ok := authority.FromContext(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic(errAuthority)
|
panic(errAuthority)
|
||||||
|
@ -249,7 +250,9 @@ type FederationResponse struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// caHandler is the type used to implement the different CA HTTP endpoints.
|
// caHandler is the type used to implement the different CA HTTP endpoints.
|
||||||
type caHandler struct{}
|
type caHandler struct {
|
||||||
|
Authority Authority
|
||||||
|
}
|
||||||
|
|
||||||
// Route configures the http request router.
|
// Route configures the http request router.
|
||||||
func (h *caHandler) Route(r Router) {
|
func (h *caHandler) Route(r Router) {
|
||||||
|
|
|
@ -171,6 +171,17 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
||||||
return csr
|
return csr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockMustAuthority(t *testing.T, a Authority) {
|
||||||
|
t.Helper()
|
||||||
|
fn := mustAuthority
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mustAuthority = fn
|
||||||
|
})
|
||||||
|
mustAuthority = func(ctx context.Context) Authority {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type mockAuthority struct {
|
type mockAuthority struct {
|
||||||
ret1, ret2 interface{}
|
ret1, ret2 interface{}
|
||||||
err error
|
err error
|
||||||
|
@ -789,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Health(t *testing.T) {
|
func Test_Health(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
req := httptest.NewRequest("GET", "http://example.com/health", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h := New(&mockAuthority{}).(*caHandler)
|
Health(w, req)
|
||||||
h.Health(w, req)
|
|
||||||
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
if res.StatusCode != 200 {
|
if res.StatusCode != 200 {
|
||||||
|
@ -811,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Root(t *testing.T) {
|
func Test_Root(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
root *x509.Certificate
|
root *x509.Certificate
|
||||||
|
@ -832,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Root(w, req)
|
Root(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -855,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Sign(t *testing.T) {
|
func Test_Sign(t *testing.T) {
|
||||||
csr := parseCertificateRequest(csrPEM)
|
csr := parseCertificateRequest(csrPEM)
|
||||||
valid, err := json.Marshal(SignRequest{
|
valid, err := json.Marshal(SignRequest{
|
||||||
CsrPEM: CertificateRequest{csr},
|
CsrPEM: CertificateRequest{csr},
|
||||||
|
@ -896,7 +906,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
ret1: tt.cert, ret2: tt.root, err: tt.signErr,
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||||
return tt.certAttrOpts, tt.autherr
|
return tt.certAttrOpts, tt.autherr
|
||||||
|
@ -904,10 +914,10 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Sign(logging.NewResponseLogger(w), req)
|
Sign(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -928,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Renew(t *testing.T) {
|
func Test_Renew(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1018,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||||
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||||
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||||
|
@ -1039,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
req.Header = tt.header
|
req.Header = tt.header
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Renew(logging.NewResponseLogger(w), req)
|
Renew(logging.NewResponseLogger(w), req)
|
||||||
|
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
@ -1073,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Rekey(t *testing.T) {
|
func Test_Rekey(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1104,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||||
getTLSOptions: func() *authority.TLSOptions {
|
getTLSOptions: func() *authority.TLSOptions {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Rekey(logging.NewResponseLogger(w), req)
|
Rekey(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -1134,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Provisioners(t *testing.T) {
|
func Test_Provisioners(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Authority Authority
|
Authority Authority
|
||||||
}
|
}
|
||||||
|
@ -1200,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := &caHandler{
|
mockMustAuthority(t, tt.fields.Authority)
|
||||||
Authority: tt.fields.Authority,
|
Provisioners(tt.args.w, tt.args.r)
|
||||||
}
|
|
||||||
h.Provisioners(tt.args.w, tt.args.r)
|
|
||||||
|
|
||||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||||
res := rec.Result()
|
res := rec.Result()
|
||||||
|
@ -1238,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_ProvisionerKey(t *testing.T) {
|
func Test_ProvisionerKey(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Authority Authority
|
Authority Authority
|
||||||
}
|
}
|
||||||
|
@ -1270,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := &caHandler{
|
mockMustAuthority(t, tt.fields.Authority)
|
||||||
Authority: tt.fields.Authority,
|
ProvisionerKey(tt.args.w, tt.args.r)
|
||||||
}
|
|
||||||
h.ProvisionerKey(tt.args.w, tt.args.r)
|
|
||||||
|
|
||||||
rec := tt.args.w.(*httptest.ResponseRecorder)
|
rec := tt.args.w.(*httptest.ResponseRecorder)
|
||||||
res := rec.Result()
|
res := rec.Result()
|
||||||
|
@ -1298,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Roots(t *testing.T) {
|
func Test_Roots(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1319,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||||
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
req := httptest.NewRequest("GET", "http://example.com/roots", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Roots(w, req)
|
Roots(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -1360,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
|
||||||
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
req := httptest.NewRequest("GET", "https://example.com/roots", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.RootsPEM(w, req)
|
RootsPEM(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -1384,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_Federation(t *testing.T) {
|
func Test_Federation(t *testing.T) {
|
||||||
cs := &tls.ConnectionState{
|
cs := &tls.ConnectionState{
|
||||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||||
}
|
}
|
||||||
|
@ -1405,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
|
mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
|
||||||
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
|
||||||
req.TLS = tt.tls
|
req.TLS = tt.tls
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Federation(w, req)
|
Federation(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
|
|
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
|
||||||
for name, _tc := range tests {
|
for name, _tc := range tests {
|
||||||
tc := _tc(t)
|
tc := _tc(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := New(tc.auth).(*caHandler)
|
mockMustAuthority(t, tc.auth)
|
||||||
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
|
||||||
if tc.tls != nil {
|
if tc.tls != nil {
|
||||||
req.TLS = tc.tls
|
req.TLS = tc.tls
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.Revoke(logging.NewResponseLogger(w), req)
|
Revoke(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
assert.Equals(t, tc.statusCode, res.StatusCode)
|
assert.Equals(t, tc.statusCode, res.StatusCode)
|
||||||
|
|
|
@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHSign(t *testing.T) {
|
func Test_SSHSign(t *testing.T) {
|
||||||
user, err := getSignedUserCertificate()
|
user, err := getSignedUserCertificate()
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
host, err := getSignedHostCertificate()
|
host, err := getSignedHostCertificate()
|
||||||
|
@ -315,7 +315,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
authorizeSign: func(ott string) ([]provisioner.SignOption, error) {
|
||||||
return []provisioner.SignOption{}, tt.authErr
|
return []provisioner.SignOption{}, tt.authErr
|
||||||
},
|
},
|
||||||
|
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
|
||||||
return tt.tlsSignCerts, tt.tlsSignErr
|
return tt.tlsSignCerts, tt.tlsSignErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHSign(logging.NewResponseLogger(w), req)
|
SSHSign(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHRoots(t *testing.T) {
|
func Test_SSHRoots(t *testing.T) {
|
||||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
return tt.keys, tt.keysErr
|
return tt.keys, tt.keysErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHRoots(logging.NewResponseLogger(w), req)
|
SSHRoots(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHFederation(t *testing.T) {
|
func Test_SSHFederation(t *testing.T) {
|
||||||
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
user, err := ssh.NewPublicKey(sshUserKey.Public())
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
|
||||||
|
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
|
||||||
return tt.keys, tt.keysErr
|
return tt.keys, tt.keysErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHFederation(logging.NewResponseLogger(w), req)
|
SSHFederation(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHConfig(t *testing.T) {
|
func Test_SSHConfig(t *testing.T) {
|
||||||
userOutput := []templates.Output{
|
userOutput := []templates.Output{
|
||||||
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
|
||||||
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
|
||||||
|
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
|
||||||
return tt.output, tt.err
|
return tt.output, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHConfig(logging.NewResponseLogger(w), req)
|
SSHConfig(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHCheckHost(t *testing.T) {
|
func Test_SSHCheckHost(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
req string
|
req string
|
||||||
|
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
|
||||||
return tt.exists, tt.err
|
return tt.exists, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHCheckHost(logging.NewResponseLogger(w), req)
|
SSHCheckHost(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHGetHosts(t *testing.T) {
|
func Test_SSHGetHosts(t *testing.T) {
|
||||||
hosts := []authority.Host{
|
hosts := []authority.Host{
|
||||||
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
|
||||||
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
|
||||||
|
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
|
||||||
return tt.hosts, tt.err
|
return tt.hosts, tt.err
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHGetHosts(logging.NewResponseLogger(w), req)
|
SSHGetHosts(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_caHandler_SSHBastion(t *testing.T) {
|
func Test_SSHBastion(t *testing.T) {
|
||||||
bastion := &authority.Bastion{
|
bastion := &authority.Bastion{
|
||||||
Hostname: "bastion.local",
|
Hostname: "bastion.local",
|
||||||
}
|
}
|
||||||
|
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
h := New(&mockAuthority{
|
mockMustAuthority(t, &mockAuthority{
|
||||||
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||||
return tt.bastion, tt.bastionErr
|
return tt.bastion, tt.bastionErr
|
||||||
},
|
},
|
||||||
}).(*caHandler)
|
})
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.SSHBastion(logging.NewResponseLogger(w), req)
|
SSHBastion(logging.NewResponseLogger(w), req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
|
||||||
if res.StatusCode != tt.statusCode {
|
if res.StatusCode != tt.statusCode {
|
||||||
|
|
Loading…
Reference in a new issue