diff --git a/ca/client.go b/ca/client.go index 68c6be22..a7cd0a7a 100644 --- a/ca/client.go +++ b/ca/client.go @@ -694,6 +694,27 @@ func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) { return &hosts, nil } +// SSHBastion performs the POST /ssh/bastion request to the CA. +func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"}) + resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var bastion api.SSHBastionResponse + if err := readJSON(resp.Body, &bastion); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &bastion, nil +} + // RootFingerprint is a helper method that returns the current root fingerprint. // It does an health connection and gets the fingerprint from the TLS verified // chains. diff --git a/ca/client_test.go b/ca/client_test.go index fc3a5049..f9a968c0 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -18,6 +18,7 @@ import ( "github.com/smallstep/assert" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/x509util" "golang.org/x/crypto/ssh" @@ -882,3 +883,64 @@ func TestClient_RootFingerprintWithServer(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } + +func TestClient_SSHBastion(t *testing.T) { + ok := &api.SSHBastionResponse{ + Hostname: "host.local", + Bastion: &authority.Bastion{ + Hostname: "bastion.local", + }, + } + badRequest := api.BadRequest(fmt.Errorf("Bad Request")) + + tests := []struct { + name string + request *api.SSHBastionRequest + response interface{} + responseCode int + wantErr bool + }{ + {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false}, + {"bad response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true}, + {"empty request", &api.SSHBastionRequest{}, badRequest, 403, true}, + {"nil request", nil, badRequest, 403, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + api.JSONStatus(w, tt.response, tt.responseCode) + }) + + got, err := c.SSHBastion(tt.request) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.SSHBastion() = %v, want nil", got) + } + if tt.responseCode != 200 && !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.SSHBastion() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response) + } + } + }) + } +}