diff --git a/api/api.go b/api/api.go index 1ce44d77..d9beeb5c 100644 --- a/api/api.go +++ b/api/api.go @@ -250,15 +250,15 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/federation", h.Federation) // SSH CA - r.MethodFunc("POST", "/ssh/sign", h.SignSSH) - r.MethodFunc("GET", "/ssh/keys", h.SSHKeys) - r.MethodFunc("GET", "/ssh/federation", h.SSHFederatedKeys) + r.MethodFunc("POST", "/ssh/sign", h.SSHSign) + r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) + r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) r.MethodFunc("POST", "/ssh/config", h.SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) - r.MethodFunc("POST", "/sign-ssh", h.SignSSH) + r.MethodFunc("POST", "/sign-ssh", h.SSHSign) } // Health is an HTTP handler that returns the status of the server. diff --git a/api/api_test.go b/api/api_test.go index eb4de053..f1d90db7 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -513,7 +513,8 @@ type mockAuthority struct { getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) - getSSHKeys func() (*authority.SSHKeys, error) + getSSHRoots func() (*authority.SSHKeys, error) + getSSHFederation func() (*authority.SSHKeys, error) getSSHConfig func(typ string, data map[string]string) ([]templates.Output, error) } @@ -620,9 +621,16 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { return m.ret1.([]*x509.Certificate), m.err } -func (m *mockAuthority) GetSSHKeys() (*authority.SSHKeys, error) { - if m.getSSHKeys != nil { - return m.getSSHKeys() +func (m *mockAuthority) GetSSHRoots() (*authority.SSHKeys, error) { + if m.getSSHRoots != nil { + return m.getSSHRoots() + } + return m.ret1.(*authority.SSHKeys), m.err +} + +func (m *mockAuthority) GetSSHFederation() (*authority.SSHKeys, error) { + if m.getSSHFederation != nil { + return m.getSSHFederation() } return m.ret1.(*authority.SSHKeys), m.err } diff --git a/api/ssh.go b/api/ssh.go index 8d0b421d..e34174db 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -17,13 +17,13 @@ import ( type SSHAuthority interface { SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - GetSSHKeys() (*authority.SSHKeys, error) - GetSSHFederatedKeys() (*authority.SSHKeys, error) + GetSSHRoots() (*authority.SSHKeys, error) + GetSSHFederation() (*authority.SSHKeys, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) } -// SignSSHRequest is the request body of an SSH certificate request. -type SignSSHRequest struct { +// SSHSignRequest is the request body of an SSH certificate request. +type SSHSignRequest struct { PublicKey []byte `json:"publicKey"` //base64 encoded OTT string `json:"ott"` CertType string `json:"certType,omitempty"` @@ -33,8 +33,8 @@ type SignSSHRequest struct { AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"` } -// Validate validates the SignSSHRequest. -func (s *SignSSHRequest) Validate() error { +// Validate validates the SSHSignRequest. +func (s *SSHSignRequest) Validate() error { switch { case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: return errors.Errorf("unknown certType %s", s.CertType) @@ -47,15 +47,15 @@ func (s *SignSSHRequest) Validate() error { } } -// SignSSHResponse is the response object that returns the SSH certificate. -type SignSSHResponse struct { +// SSHSignResponse is the response object that returns the SSH certificate. +type SSHSignResponse struct { Certificate SSHCertificate `json:"crt"` AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"` } -// SSHKeysResponse represents the response object that returns the SSH user and +// SSHRootsResponse represents the response object that returns the SSH user and // host keys. -type SSHKeysResponse struct { +type SSHRootsResponse struct { UserKeys []SSHPublicKey `json:"userKey,omitempty"` HostKeys []SSHPublicKey `json:"hostKey,omitempty"` } @@ -170,11 +170,11 @@ type SSHConfigResponse struct { HostTemplates []Template `json:"hostTemplates,omitempty"` } -// SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token +// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { - var body SignSSHRequest +func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { + var body SSHSignRequest if err := ReadJSON(r.Body, &body); err != nil { WriteError(w, BadRequest(errors.Wrap(err, "error reading request body"))) return @@ -232,16 +232,16 @@ func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusCreated) - JSON(w, &SignSSHResponse{ + JSON(w, &SSHSignResponse{ Certificate: SSHCertificate{cert}, AddUserCertificate: addUserCertificate, }) } -// SSHKeys is an HTTP handler that returns the SSH public keys for user and host +// SSHRoots is an HTTP handler that returns the SSH public keys for user and host // certificates. -func (h *caHandler) SSHKeys(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHKeys() +func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { + keys, err := h.Authority.GetSSHRoots() if err != nil { WriteError(w, InternalServerError(err)) return @@ -252,7 +252,7 @@ func (h *caHandler) SSHKeys(w http.ResponseWriter, r *http.Request) { return } - resp := new(SSHKeysResponse) + resp := new(SSHRootsResponse) for _, k := range keys.HostKeys { resp.HostKeys = append(resp.HostKeys, SSHPublicKey{PublicKey: k}) } @@ -263,10 +263,10 @@ func (h *caHandler) SSHKeys(w http.ResponseWriter, r *http.Request) { JSON(w, resp) } -// SSHFederatedKeys is an HTTP handler that returns the federated SSH public -// keys for user and host certificates. -func (h *caHandler) SSHFederatedKeys(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHFederatedKeys() +// SSHFederation is an HTTP handler that returns the federated SSH public keys +// for user and host certificates. +func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { + keys, err := h.Authority.GetSSHFederation() if err != nil { WriteError(w, NotFound(err)) return @@ -277,7 +277,7 @@ func (h *caHandler) SSHFederatedKeys(w http.ResponseWriter, r *http.Request) { return } - resp := new(SSHKeysResponse) + resp := new(SSHRootsResponse) for _, k := range keys.HostKeys { resp.HostKeys = append(resp.HostKeys, SSHPublicKey{PublicKey: k}) } diff --git a/api/ssh_test.go b/api/ssh_test.go index 55a0db90..3f5fcdbb 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -219,7 +219,7 @@ func TestSignSSHRequest_Validate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &SignSSHRequest{ + s := &SSHSignRequest{ PublicKey: tt.fields.PublicKey, OTT: tt.fields.OTT, CertType: tt.fields.CertType, @@ -235,7 +235,7 @@ func TestSignSSHRequest_Validate(t *testing.T) { } } -func Test_caHandler_SignSSH(t *testing.T) { +func Test_caHandler_SSHSign(t *testing.T) { user, err := getSignedUserCertificate() assert.FatalError(t, err) host, err := getSignedHostCertificate() @@ -244,17 +244,17 @@ func Test_caHandler_SignSSH(t *testing.T) { userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) - userReq, err := json.Marshal(SignSSHRequest{ + userReq, err := json.Marshal(SSHSignRequest{ PublicKey: user.Key.Marshal(), OTT: "ott", }) assert.FatalError(t, err) - hostReq, err := json.Marshal(SignSSHRequest{ + hostReq, err := json.Marshal(SSHSignRequest{ PublicKey: host.Key.Marshal(), OTT: "ott", }) assert.FatalError(t, err) - userAddReq, err := json.Marshal(SignSSHRequest{ + userAddReq, err := json.Marshal(SSHSignRequest{ PublicKey: user.Key.Marshal(), OTT: "ott", AddUserPublicKey: user.Key.Marshal(), @@ -299,7 +299,7 @@ func Test_caHandler_SignSSH(t *testing.T) { req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) w := httptest.NewRecorder() - h.SignSSH(logging.NewResponseLogger(w), req) + h.SSHSign(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { @@ -320,7 +320,7 @@ func Test_caHandler_SignSSH(t *testing.T) { } } -func Test_caHandler_SSHKeys(t *testing.T) { +func Test_caHandler_SSHRoots(t *testing.T) { user, err := ssh.NewPublicKey(sshUserKey.Public()) assert.FatalError(t, err) userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) @@ -336,22 +336,23 @@ func Test_caHandler_SSHKeys(t *testing.T) { body []byte statusCode int }{ - {"ok", &authority.SSHKeys{HostKey: host, UserKey: user}, nil, []byte(fmt.Sprintf(`{"userKey":"%s","hostKey":"%s"}`, userB64, hostB64)), http.StatusOK}, - {"user", &authority.SSHKeys{UserKey: user}, nil, []byte(fmt.Sprintf(`{"userKey":"%s"}`, userB64)), http.StatusOK}, - {"host", &authority.SSHKeys{HostKey: host}, nil, []byte(fmt.Sprintf(`{"hostKey":"%s"}`, hostB64)), http.StatusOK}, - {"error", nil, fmt.Errorf("an error"), nil, http.StatusNotFound}, + {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, + {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, + {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, + {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, + {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := New(&mockAuthority{ - getSSHKeys: func() (*authority.SSHKeys, error) { + getSSHRoots: func() (*authority.SSHKeys, error) { return tt.keys, tt.keysErr }, }).(*caHandler) req := httptest.NewRequest("GET", "http://example.com/ssh/keys", http.NoBody) w := httptest.NewRecorder() - h.SSHKeys(logging.NewResponseLogger(w), req) + h.SSHRoots(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { diff --git a/authority/ssh.go b/authority/ssh.go index 33f00cec..741d57cf 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -87,16 +87,16 @@ type SSHKeys struct { HostKeys []ssh.PublicKey } -// GetSSHKeys returns the SSH User and Host public keys. -func (a *Authority) GetSSHKeys() (*SSHKeys, error) { +// GetSSHRoots returns the SSH User and Host public keys. +func (a *Authority) GetSSHRoots() (*SSHKeys, error) { return &SSHKeys{ HostKeys: a.sshCAHostCerts, UserKeys: a.sshCAUserCerts, }, nil } -// GetSSHFederatedKeys returns the public keys for federated SSH signers. -func (a *Authority) GetSSHFederatedKeys() (*SSHKeys, error) { +// GetSSHFederation returns the public keys for federated SSH signers. +func (a *Authority) GetSSHFederation() (*SSHKeys, error) { return &SSHKeys{ HostKeys: a.sshCAHostFederatedCerts, UserKeys: a.sshCAUserFederatedCerts, diff --git a/ca/client.go b/ca/client.go index 45e1e7ce..7f34a92c 100644 --- a/ca/client.go +++ b/ca/client.go @@ -373,28 +373,6 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) { return &sign, nil } -// SignSSH performs the SSH certificate sign request to the CA and returns the -// api.SignSSHResponse struct. -func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) { - body, err := json.Marshal(req) - if err != nil { - return nil, errors.Wrap(err, "error marshaling request") - } - u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) - resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) - if err != nil { - return nil, errors.Wrapf(err, "client POST %s failed", u) - } - if resp.StatusCode >= 400 { - return nil, readError(resp.Body) - } - var sign api.SignSSHResponse - if err := readJSON(resp.Body, &sign); err != nil { - return nil, errors.Wrapf(err, "error reading %s", u) - } - return &sign, nil -} - // Renew performs the renew request to the CA and returns the api.SignResponse // struct. func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { @@ -527,10 +505,32 @@ func (c *Client) Federation() (*api.FederationResponse, error) { return &federation, nil } -// SSHKeys performs the get /ssh/keys request to the CA and returns the -// api.SSHKeysResponse struct. -func (c *Client) SSHKeys() (*api.SSHKeysResponse, error) { - u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/keys"}) +// SSHSign performs the POST /ssh/sign request to the CA and returns the +// api.SSHSignResponse struct. +func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) + resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var sign api.SSHSignResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &sign, nil +} + +// SSHRoots performs the GET /ssh/roots request to the CA and returns the +// api.SSHRootsResponse struct. +func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"}) resp, err := c.client.Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) @@ -538,7 +538,7 @@ func (c *Client) SSHKeys() (*api.SSHKeysResponse, error) { if resp.StatusCode >= 400 { return nil, readError(resp.Body) } - var keys api.SSHKeysResponse + var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } @@ -546,8 +546,8 @@ func (c *Client) SSHKeys() (*api.SSHKeysResponse, error) { } // SSHFederation performs the get /ssh/federation request to the CA and returns -// the api.SSHKeysResponse struct. -func (c *Client) SSHFederation() (*api.SSHKeysResponse, error) { +// the api.SSHRootsResponse struct. +func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) { u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"}) resp, err := c.client.Get(u.String()) if err != nil { @@ -556,15 +556,15 @@ func (c *Client) SSHFederation() (*api.SSHKeysResponse, error) { if resp.StatusCode >= 400 { return nil, readError(resp.Body) } - var keys api.SSHKeysResponse + var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { return nil, errors.Wrapf(err, "error reading %s", u) } return &keys, nil } -// SSHConfig performs the POST request to the CA to get the ssh configuration -// templates. +// SSHConfig performs the POST /ssh/config request to the CA to get the ssh +// configuration templates. func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) { body, err := json.Marshal(req) if err != nil { diff --git a/ca/client_test.go b/ca/client_test.go index 8bb83fe1..a655dce0 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -714,15 +714,15 @@ func TestClient_Federation(t *testing.T) { } } -func TestClient_SSHKeys(t *testing.T) { +func TestClient_SSHRoots(t *testing.T) { key, err := ssh.NewPublicKey(mustKey().Public()) if err != nil { t.Fatal(err) } - ok := &api.SSHKeysResponse{ - HostKey: &api.SSHPublicKey{PublicKey: key}, - UserKey: &api.SSHPublicKey{PublicKey: key}, + ok := &api.SSHRootsResponse{ + HostKeys: []api.SSHPublicKey{{PublicKey: key}}, + UserKeys: []api.SSHPublicKey{{PublicKey: key}}, } notFound := api.NotFound(fmt.Errorf("Not Found")) @@ -751,7 +751,7 @@ func TestClient_SSHKeys(t *testing.T) { api.JSONStatus(w, tt.response, tt.responseCode) }) - got, err := c.SSHKeys() + got, err := c.SSHRoots() if (err != nil) != tt.wantErr { fmt.Printf("%+v", err) t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr)