diff --git a/api/api_test.go b/api/api_test.go index edefbd47..8350a41d 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,6 +3,7 @@ package api import ( "bytes" "context" + "crypto" "crypto/dsa" "crypto/ecdsa" "crypto/elliptic" @@ -550,7 +551,8 @@ type mockAuthority struct { getTLSOptions func() *tlsutil.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) - renew func(cert *x509.Certificate) ([]*x509.Certificate, error) + renew func(cert *x509.Certificate) ([]*x509.Certificate, error) + renewOrRekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByID func(provID string) (provisioner.Interface, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) @@ -611,6 +613,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockAuthority) RenewOrRekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) { + if m.renewOrRekey != nil { + return m.renewOrRekey(oldcert, pk) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { if m.getProvisioners != nil { return m.getProvisioners(nextCursor, limit) @@ -952,6 +961,67 @@ func Test_caHandler_Renew(t *testing.T) { } } +func Test_caHandler_Rekey(t *testing.T) { + cs := &tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, + } + csr := parseCertificateRequest(csrPEM) + valid, err := json.Marshal(RekeyRequest{ + CsrPEM: CertificateRequest{csr}, + }) + if err != nil { + t.Fatal(err) + } + tests := []struct { + name string + input string + tls *tls.ConnectionState + cert *x509.Certificate + root *x509.Certificate + err error + statusCode int + }{ + {"ok", string(valid), cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, + {"no tls", string(valid), nil, nil, nil, nil, http.StatusBadRequest}, + {"no peer certificates", string(valid), &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, + {"rekey error", string(valid), cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, + {"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest}, + } + + expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + ret1: tt.cert, ret2: tt.root, err: tt.err, + getTLSOptions: func() *tlsutil.TLSOptions { + return nil + }, + }).(*caHandler) + req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) + req.TLS = tt.tls + w := httptest.NewRecorder() + h.Rekey(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Rekey unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), expected) { + t.Errorf("caHandler.Rekey Body = %s, wants %s", body, expected) + } + } + }) + } +} + func Test_caHandler_Provisioners(t *testing.T) { type fields struct { Authority Authority