From 7c53fe32c310faf14856957788e3c4b4911cf7fd Mon Sep 17 00:00:00 2001 From: max furman Date: Wed, 23 Nov 2022 15:29:28 -0800 Subject: [PATCH] Add capabilities endpoint and client integration --- api/api.go | 13 +++++++++++ api/api_test.go | 37 ++++++++++++++++++++++++++++++ ca/client.go | 31 +++++++++++++++++++++++++ ca/client_test.go | 58 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 139 insertions(+) diff --git a/api/api.go b/api/api.go index 9c2f1f31..68fca3bb 100644 --- a/api/api.go +++ b/api/api.go @@ -50,6 +50,7 @@ type Authority interface { GetRoots() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error) Version() authority.Version + Capabilities() authority.Capabilities GetCertificateRevocationList() ([]byte, error) } @@ -211,6 +212,10 @@ type VersionResponse struct { RequireClientAuthentication bool `json:"requireClientAuthentication,omitempty"` } +// CapabilitiesResponse is the response object that returns the version of the +// server. +type CapabilitiesResponse authority.Capabilities + // HealthResponse is the response object that returns the health of the server. type HealthResponse struct { Status string `json:"status"` @@ -261,8 +266,10 @@ func New(auth Authority) RouterHandler { return &caHandler{} } +// Route defines routing for the API. func Route(r Router) { r.MethodFunc("GET", "/version", Version) + r.MethodFunc("GET", "/capabilities", Capabilities) r.MethodFunc("GET", "/health", Health) r.MethodFunc("GET", "/root/{sha}", Root) r.MethodFunc("POST", "/sign", Sign) @@ -303,6 +310,12 @@ func Version(w http.ResponseWriter, r *http.Request) { }) } +// Capabilities is an HTTP handler that returns the capabilities of the authority +// server. +func Capabilities(w http.ResponseWriter, r *http.Request) { + render.JSON(w, CapabilitiesResponse(mustAuthority(r.Context()).Capabilities())) +} + // Health is an HTTP handler that returns the status of the server. func Health(w http.ResponseWriter, r *http.Request) { render.JSON(w, HealthResponse{Status: "ok"}) diff --git a/api/api_test.go b/api/api_test.go index e24751b3..4049a0d6 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -212,6 +212,7 @@ type mockAuthority struct { checkSSHHost func(ctx context.Context, principal, token string) (bool, error) getSSHBastion func(ctx context.Context, user string, hostname string) (*authority.Bastion, error) version func() authority.Version + capabilities func() authority.Capabilities } func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) { @@ -405,6 +406,13 @@ func (m *mockAuthority) Version() authority.Version { return m.ret1.(authority.Version) } +func (m *mockAuthority) Capabilities() authority.Capabilities { + if m.version != nil { + return m.capabilities() + } + return m.ret1.(authority.Capabilities) +} + func TestNewCertificate(t *testing.T) { cert := parseCertificate(rootPEM) if !reflect.DeepEqual(Certificate{Certificate: cert}, NewCertificate(cert)) { @@ -873,6 +881,35 @@ func Test_Health(t *testing.T) { } } +func Test_Capabilities(t *testing.T) { + capResp := CapabilitiesResponse{ + RequireClientAuthentication: false, + RemoteConfigurationManagement: true, + } + mockMustAuthority(t, &mockAuthority{ret1: authority.Capabilities(capResp)}) + req := httptest.NewRequest("GET", "http://example.com/capabilities", nil) + w := httptest.NewRecorder() + Capabilities(w, req) + + res := w.Result() + if res.StatusCode != 200 { + t.Errorf("caHandler.Capabilities StatusCode = %d, wants 200", res.StatusCode) + } + + body, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.Capabilities unexpected error = %v", err) + } + wantBytes, err := json.Marshal(capResp) + if err != nil { + assert.FatalError(t, err) + } + if !bytes.Equal(bytes.TrimSpace(body), wantBytes) { + t.Errorf("caHandler.Capabilities Body = %s, wants %s", body, wantBytes) + } +} + func Test_Root(t *testing.T) { tests := []struct { name string diff --git a/ca/client.go b/ca/client.go index bbafcfee..e143da61 100644 --- a/ca/client.go +++ b/ca/client.go @@ -45,6 +45,9 @@ var DisableIdentity = false // UserAgent will set the User-Agent header in the client requests. var UserAgent = "step-http-client/1.0" +// ErrNotFound is a standard not-found error. +var ErrNotFound = errors.New("not found") + type uaClient struct { Client *http.Client } @@ -605,6 +608,34 @@ retry: return &version, nil } +// Capabilities performs the capabilities request to the CA and returns the +// api.Capabilities struct. +func (c *Client) Capabilities() (*api.CapabilitiesResponse, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: "/capabilities"}) +retry: + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, clientError(err) + } + switch { + case resp.StatusCode == http.StatusNotFound: + return nil, ErrNotFound + case resp.StatusCode >= 400: + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readError(resp.Body) + default: + var capabilities api.CapabilitiesResponse + if err := readJSON(resp.Body, &capabilities); err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Capabilities; error reading %s", u) + } + return &capabilities, nil + } +} + // Health performs the health request to the CA and returns the // api.HealthResponse struct. func (c *Client) Health() (*api.HealthResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index dff7fd41..836b5835 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -206,6 +206,64 @@ func TestClient_Version(t *testing.T) { } } +func TestClient_Capabilities(t *testing.T) { + ok := &api.CapabilitiesResponse{ + RequireClientAuthentication: false, + RemoteConfigurationManagement: true, + } + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + expectedErr error + }{ + {"ok", ok, 200, false, nil}, + {"500", errs.InternalServer("force"), 500, true, errors.New(errs.InternalServerErrorDefaultMsg)}, + {"404", errs.NotFound("force"), 404, true, ErrNotFound}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + render.JSONStatus(w, tt.response, tt.responseCode) + }) + + got, err := c.Capabilities() + if (err != nil) != tt.wantErr { + t.Errorf("Client.Capabilities() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Capabilities() = %v, want nil", got) + } + if tt.responseCode == http.StatusNotFound { + assert.True(t, errors.Is(err, ErrNotFound)) + } else { + assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Version() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Health(t *testing.T) { ok := &api.HealthResponse{Status: "ok"}