forked from TrueCloudLab/certificates
Merge pull request #562 from smallstep/renew-db-interface
Renew DB interface and Rekey
This commit is contained in:
commit
b9b1ac04d1
3 changed files with 113 additions and 1 deletions
|
@ -263,7 +263,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5
|
||||||
}
|
}
|
||||||
|
|
||||||
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...)
|
||||||
if err = a.storeCertificate(fullchain); err != nil {
|
if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil {
|
||||||
if err != db.ErrNotImplemented {
|
if err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...)
|
||||||
}
|
}
|
||||||
|
@ -287,6 +287,19 @@ func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error {
|
||||||
return a.db.StoreCertificate(fullchain[0])
|
return a.db.StoreCertificate(fullchain[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// storeRenewedCertificate allows to use an extension of the db.AuthDB interface
|
||||||
|
// that can log if a certificate has been renewed or rekeyed.
|
||||||
|
//
|
||||||
|
// TODO: at some point we should implement this in the standard implementation.
|
||||||
|
func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error {
|
||||||
|
if s, ok := a.db.(interface {
|
||||||
|
StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error
|
||||||
|
}); ok {
|
||||||
|
return s.StoreRenewedCertificate(oldCert, fullchain...)
|
||||||
|
}
|
||||||
|
return a.db.StoreCertificate(fullchain[0])
|
||||||
|
}
|
||||||
|
|
||||||
// RevokeOptions are the options for the Revoke API.
|
// RevokeOptions are the options for the Revoke API.
|
||||||
type RevokeOptions struct {
|
type RevokeOptions struct {
|
||||||
Serial string
|
Serial string
|
||||||
|
|
30
ca/client.go
30
ca/client.go
|
@ -616,6 +616,36 @@ retry:
|
||||||
return &sign, nil
|
return &sign, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rekey performs the rekey request to the CA and returns the api.SignResponse
|
||||||
|
// struct.
|
||||||
|
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
|
||||||
|
var retried bool
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error marshaling request")
|
||||||
|
}
|
||||||
|
|
||||||
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"})
|
||||||
|
client := &http.Client{Transport: tr}
|
||||||
|
retry:
|
||||||
|
resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; client POST %s failed", u)
|
||||||
|
}
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
if !retried && c.retryOnError(resp) {
|
||||||
|
retried = true
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
return nil, readError(resp.Body)
|
||||||
|
}
|
||||||
|
var sign api.SignResponse
|
||||||
|
if err := readJSON(resp.Body, &sign); err != nil {
|
||||||
|
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; error reading %s", u)
|
||||||
|
}
|
||||||
|
return &sign, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
|
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
|
||||||
// struct.
|
// struct.
|
||||||
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
|
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
|
||||||
|
|
|
@ -529,6 +529,75 @@ func TestClient_Renew(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_Rekey(t *testing.T) {
|
||||||
|
ok := &api.SignResponse{
|
||||||
|
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
||||||
|
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
||||||
|
CertChainPEM: []api.Certificate{
|
||||||
|
{Certificate: parseCertificate(certPEM)},
|
||||||
|
{Certificate: parseCertificate(rootPEM)},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &api.RekeyRequest{
|
||||||
|
CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
request *api.RekeyRequest
|
||||||
|
response interface{}
|
||||||
|
responseCode int
|
||||||
|
wantErr bool
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{"ok", request, ok, 200, false, nil},
|
||||||
|
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||||
|
{"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||||
|
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := httptest.NewServer(nil)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewClient() error = %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||||
|
})
|
||||||
|
|
||||||
|
got, err := c.Rekey(tt.request, nil)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
fmt.Printf("%+v", err)
|
||||||
|
t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if got != nil {
|
||||||
|
t.Errorf("Client.Renew() = %v, want nil", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc, ok := err.(errs.StatusCoder)
|
||||||
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
|
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||||
|
assert.HasPrefix(t, tt.err.Error(), err.Error())
|
||||||
|
default:
|
||||||
|
if !reflect.DeepEqual(got, tt.response) {
|
||||||
|
t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient_Provisioners(t *testing.T) {
|
func TestClient_Provisioners(t *testing.T) {
|
||||||
ok := &api.ProvisionersResponse{
|
ok := &api.ProvisionersResponse{
|
||||||
Provisioners: provisioner.List{},
|
Provisioners: provisioner.List{},
|
||||||
|
|
Loading…
Reference in a new issue