diff --git a/ca/client.go b/ca/client.go index 89a27de0..3eef6602 100644 --- a/ca/client.go +++ b/ca/client.go @@ -286,6 +286,43 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { return &sign, nil } +// Provisioners performs the provisioners request to the CA and returns the +// api.ProvisionersResponse struct with a map of provisioners. +func (c *Client) Provisioners() (*api.ProvisionersResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var provisioners api.ProvisionersResponse + if err := readJSON(resp.Body, &provisioners); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &provisioners, nil +} + +// ProvisionerKey performs the request to the CA to get the encrypted key for +// the given provisioner kid and returns the api.ProvisionerKeyResponse struct +// with the encrypted key. +func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var key api.ProvisionerKeyResponse + if err := readJSON(resp.Body, &key); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &key, nil +} + // CreateSignRequest is a helper function that given an x509 OTT returns a // simple but secure sign request as well as the private key used. func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { diff --git a/ca/client_test.go b/ca/client_test.go index fbca2976..8540356c 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/smallstep/ca-component/api" + "github.com/smallstep/cli/jose" ) const ( @@ -386,3 +387,122 @@ func TestClient_Renew(t *testing.T) { }) } } + +func TestClient_Provisioners(t *testing.T) { + ok := &api.ProvisionersResponse{ + Provisioners: map[string]*jose.JSONWebKeySet{}, + } + internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"fail", internalServerError, 500, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + expected := "/provisioners" + if req.RequestURI != expected { + t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) + } + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.Provisioners() + if (err != nil) != tt.wantErr { + t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Provisioners() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) + } + } + }) + } +} + +func TestClient_ProvisionerKey(t *testing.T) { + ok := &api.ProvisionerKeyResponse{ + Key: "an encrypted key", + } + notFound := api.NotFound(fmt.Errorf("Not Found")) + + tests := []struct { + name string + kid string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", "kid", ok, 200, false}, + {"fail", "invalid", notFound, 500, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + expected := "/provisioners/" + tt.kid + "/encrypted-key" + if req.RequestURI != expected { + t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) + } + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.ProvisionerKey(tt.kid) + if (err != nil) != tt.wantErr { + t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.ProvisionerKey() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.ProvisionerKey() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) + } + } + }) + } +}