Add method to get all the certs.

This commit is contained in:
Mariano Cano 2019-01-04 16:51:37 -08:00
parent 10978630e5
commit 37149ed3ea
2 changed files with 86 additions and 0 deletions

View file

@ -392,6 +392,7 @@ type mockAuthority struct {
renew func(cert *x509.Certificate) (*x509.Certificate, *x509.Certificate, error)
getProvisioners func(nextCursor string, limit int) ([]*authority.Provisioner, string, error)
getEncryptedKey func(kid string) (string, error)
getFederation func(cert *x509.Certificate) ([]*x509.Certificate, error)
}
func (m *mockAuthority) Authorize(ott string) ([]interface{}, error) {
@ -443,6 +444,13 @@ func (m *mockAuthority) GetEncryptedKey(kid string) (string, error) {
return m.ret1.(string), m.err
}
func (m *mockAuthority) GetFederation(cert *x509.Certificate) ([]*x509.Certificate, error) {
if m.getFederation != nil {
return m.getFederation(cert)
}
return m.ret1.([]*x509.Certificate), m.err
}
func Test_caHandler_Route(t *testing.T) {
type fields struct {
Authority Authority
@ -812,3 +820,50 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
})
}
}
func Test_caHandler_Federation(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
tests := []struct {
name string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
}
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler)
req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Federation(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
}
}
})
}
}