Added renewOrRekey to mockAuthority. Added Test_caHandler_Rekey

This commit is contained in:
dharanikumar-s 2020-07-05 22:05:00 +05:30
parent 01a6469d25
commit 954fda657b

View file

@ -3,6 +3,7 @@ package api
import ( import (
"bytes" "bytes"
"context" "context"
"crypto"
"crypto/dsa" "crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@ -550,7 +551,8 @@ type mockAuthority struct {
getTLSOptions func() *tlsutil.TLSOptions getTLSOptions func() *tlsutil.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.Options, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
renewOrRekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error) loadProvisionerByCertificate func(cert *x509.Certificate) (provisioner.Interface, error)
loadProvisionerByID func(provID string) (provisioner.Interface, error) loadProvisionerByID func(provID string) (provisioner.Interface, error)
getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error) getProvisioners func(nextCursor string, limit int) (provisioner.List, string, error)
@ -611,6 +613,13 @@ func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, erro
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }
func (m *mockAuthority) RenewOrRekey(oldcert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) {
if m.renewOrRekey != nil {
return m.renewOrRekey(oldcert, pk)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) { func (m *mockAuthority) GetProvisioners(nextCursor string, limit int) (provisioner.List, string, error) {
if m.getProvisioners != nil { if m.getProvisioners != nil {
return m.getProvisioners(nextCursor, limit) return m.getProvisioners(nextCursor, limit)
@ -952,6 +961,67 @@ func Test_caHandler_Renew(t *testing.T) {
} }
} }
func Test_caHandler_Rekey(t *testing.T) {
cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
}
csr := parseCertificateRequest(csrPEM)
valid, err := json.Marshal(RekeyRequest{
CsrPEM: CertificateRequest{csr},
})
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
input string
tls *tls.ConnectionState
cert *x509.Certificate
root *x509.Certificate
err error
statusCode int
}{
{"ok", string(valid), cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", string(valid), nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", string(valid), &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
{"rekey error", string(valid), cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
}
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err,
getTLSOptions: func() *tlsutil.TLSOptions {
return nil
},
}).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
req.TLS = tt.tls
w := httptest.NewRecorder()
h.Rekey(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
}
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Errorf("caHandler.Rekey unexpected error = %v", err)
}
if tt.statusCode < http.StatusBadRequest {
if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Rekey Body = %s, wants %s", body, expected)
}
}
})
}
}
func Test_caHandler_Provisioners(t *testing.T) { func Test_caHandler_Provisioners(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority