diff --git a/ca/client.go b/ca/client.go index a2451000..c574a1eb 100644 --- a/ca/client.go +++ b/ca/client.go @@ -18,6 +18,7 @@ import ( "net" "net/http" "net/url" + "strconv" "strings" "github.com/pkg/errors" @@ -164,6 +165,50 @@ func parseEndpoint(endpoint string) (*url.URL, error) { return u, nil } +// ProvisionerOption is the type of options passed to the Provisioner method. +type ProvisionerOption func(o *provisionerOptions) error + +type provisionerOptions struct { + cursor string + limit int +} + +func (o *provisionerOptions) apply(opts []ProvisionerOption) (err error) { + for _, fn := range opts { + if err = fn(o); err != nil { + return + } + } + return +} + +func (o *provisionerOptions) rawQuery() string { + v := url.Values{} + if len(o.cursor) > 0 { + v.Set("cursor", o.cursor) + } + if o.limit > 0 { + v.Set("limit", strconv.Itoa(o.limit)) + } + return v.Encode() +} + +// WithProvisionerCursor will request the provisioners starting with the given cursor. +func WithProvisionerCursor(cursor string) ProvisionerOption { + return func(o *provisionerOptions) error { + o.cursor = cursor + return nil + } +} + +// WithProvisionerLimit will request the given number of provisioners. +func WithProvisionerLimit(limit int) ProvisionerOption { + return func(o *provisionerOptions) error { + o.limit = limit + return nil + } +} + // Client implements an HTTP client for the CA server. type Client struct { client *http.Client @@ -297,8 +342,18 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { // Provisioners performs the provisioners request to the CA and returns the // api.ProvisionersResponse struct with a map of provisioners. -func (c *Client) Provisioners() (*api.ProvisionersResponse, error) { - u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners"}) +// +// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to +// paginate the provisioners. +func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) { + o := new(provisionerOptions) + if err := o.apply(opts); err != nil { + return nil, err + } + u := c.endpoint.ResolveReference(&url.URL{ + Path: "/provisioners", + RawQuery: o.rawQuery(), + }) resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) diff --git a/ca/client_test.go b/ca/client_test.go index 0e231b81..a41baa20 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -396,12 +396,17 @@ func TestClient_Provisioners(t *testing.T) { tests := []struct { name string + args []ProvisionerOption + expectedURI string response interface{} responseCode int wantErr bool }{ - {"ok", ok, 200, false}, - {"fail", internalServerError, 500, true}, + {"ok", nil, "/provisioners", ok, 200, false}, + {"ok with cursor", []ProvisionerOption{WithProvisionerCursor("abc")}, "/provisioners?cursor=abc", ok, 200, false}, + {"ok with limit", []ProvisionerOption{WithProvisionerLimit(10)}, "/provisioners?limit=10", ok, 200, false}, + {"ok with cursor+limit", []ProvisionerOption{WithProvisionerCursor("abc"), WithProvisionerLimit(10)}, "/provisioners?cursor=abc&limit=10", ok, 200, false}, + {"fail", nil, "/provisioners", internalServerError, 500, true}, } srv := httptest.NewServer(nil) @@ -416,15 +421,14 @@ func TestClient_Provisioners(t *testing.T) { } srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - expected := "/provisioners" - if req.RequestURI != expected { - t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) + if req.RequestURI != tt.expectedURI { + t.Errorf("RequestURI = %s, want %s", req.RequestURI, tt.expectedURI) } w.WriteHeader(tt.responseCode) api.JSON(w, tt.response) }) - got, err := c.Provisioners() + got, err := c.Provisioners(tt.args...) if (err != nil) != tt.wantErr { t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) return @@ -447,67 +451,6 @@ func TestClient_Provisioners(t *testing.T) { } } -/* -func TestClient_Provisioners(t *testing.T) { - ok := &api.ProvisionersResponse{ - Provisioners: map[string]*jose.JSONWebKeySet{}, - } - internalServerError := api.InternalServerError(fmt.Errorf("Internal Server Error")) - - tests := []struct { - name string - response interface{} - responseCode int - wantErr bool - }{ - {"ok", ok, 200, false}, - {"fail", internalServerError, 500, true}, - } - - srv := httptest.NewServer(nil) - defer srv.Close() - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } - - srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - expected := "/provisioners" - if req.RequestURI != expected { - t.Errorf("RequestURI = %s, want %s", req.RequestURI, expected) - } - w.WriteHeader(tt.responseCode) - api.JSON(w, tt.response) - }) - - got, err := c.Provisioners() - if (err != nil) != tt.wantErr { - t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) - return - } - - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Provisioners() = %v, want nil", got) - } - if !reflect.DeepEqual(err, tt.response) { - t.Errorf("Client.Provisioners() error = %v, want %v", err, tt.response) - } - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) - } - } - }) - } -} -*/ - func TestClient_ProvisionerKey(t *testing.T) { ok := &api.ProvisionerKeyResponse{ Key: "an encrypted key",