diff --git a/api/api.go b/api/api.go index 3850d921..23f18d35 100644 --- a/api/api.go +++ b/api/api.go @@ -249,9 +249,12 @@ func (h *caHandler) Route(r Router) { r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("GET", "/federation", h.Federation) + // SSH CA + r.MethodFunc("GET", "/ssh/sign", h.SignSSH) + r.MethodFunc("GET", "/ssh/keys", h.SSHKeys) + // For compatibility with old code: r.MethodFunc("POST", "/re-sign", h.Renew) - // SSH CA r.MethodFunc("POST", "/sign-ssh", h.SignSSH) } diff --git a/api/api_test.go b/api/api_test.go index d141247c..8c3034d0 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -512,6 +512,7 @@ type mockAuthority struct { getEncryptedKey func(kid string) (string, error) getRoots func() ([]*x509.Certificate, error) getFederation func() ([]*x509.Certificate, error) + getSSHKeys func() (*authority.SSHKeys, error) } // TODO: remove once Authorize is deprecated. @@ -617,6 +618,13 @@ func (m *mockAuthority) GetFederation() ([]*x509.Certificate, error) { return m.ret1.([]*x509.Certificate), m.err } +func (m *mockAuthority) GetSSHKeys() (*authority.SSHKeys, error) { + if m.getSSHKeys != nil { + return m.getSSHKeys() + } + return m.ret1.(*authority.SSHKeys), m.err +} + func Test_caHandler_Route(t *testing.T) { type fields struct { Authority Authority diff --git a/api/ssh.go b/api/ssh.go index 456f239a..edc49a10 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -16,7 +16,7 @@ import ( type SSHAuthority interface { SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) SignSSHAddUser(key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) - SSHConfig() (*authority.SSHConfiguration, error) + GetSSHKeys() (*authority.SSHKeys, error) } // SignSSHRequest is the request body of an SSH certificate request. @@ -30,15 +30,29 @@ type SignSSHRequest struct { AddUserPublicKey []byte `json:"addUserPublicKey,omitempty"` } +// Validate validates the SignSSHRequest. +func (s *SignSSHRequest) Validate() error { + switch { + case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: + return errors.Errorf("unknown certType %s", s.CertType) + case len(s.PublicKey) == 0: + return errors.New("missing or empty publicKey") + case len(s.OTT) == 0: + return errors.New("missing or empty ott") + default: + return nil + } +} + // SignSSHResponse is the response object that returns the SSH certificate. type SignSSHResponse struct { Certificate SSHCertificate `json:"crt"` AddUserCertificate *SSHCertificate `json:"addUserCrt,omitempty"` } -// SSHConfigResponse represents the response object that returns the SSH user -// and host keys. -type SSHConfigResponse struct { +// SSHKeysResponse represents the response object that returns the SSH user and +// host keys. +type SSHKeysResponse struct { UserKey *SSHPublicKey `json:"userKey,omitempty"` HostKey *SSHPublicKey `json:"hostKey,omitempty"` } @@ -58,21 +72,6 @@ func (c SSHCertificate) MarshalJSON() ([]byte, error) { return []byte(`"` + s + `"`), nil } -// SSHPublicKey represents a public key in a response object. -type SSHPublicKey struct { - ssh.PublicKey -} - -// MarshalJSON implements the json.Marshaler interface. Returns a quoted, -// base64 encoded, openssh wire format version of the public key. -func (p *SSHPublicKey) MarshalJSON() ([]byte, error) { - if p == nil || p.PublicKey == nil { - return []byte("null"), nil - } - s := base64.StdEncoding.EncodeToString(p.PublicKey.Marshal()) - return []byte(`"` + s + `"`), nil -} - // UnmarshalJSON implements the json.Unmarshaler interface. The certificate is // expected to be a quoted, base64 encoded, openssh wire formatted block of bytes. func (c *SSHCertificate) UnmarshalJSON(data []byte) error { @@ -100,18 +99,43 @@ func (c *SSHCertificate) UnmarshalJSON(data []byte) error { return nil } -// Validate validates the SignSSHRequest. -func (s *SignSSHRequest) Validate() error { - switch { - case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: - return errors.Errorf("unknown certType %s", s.CertType) - case len(s.PublicKey) == 0: - return errors.New("missing or empty publicKey") - case len(s.OTT) == 0: - return errors.New("missing or empty ott") - default: +// SSHPublicKey represents a public key in a response object. +type SSHPublicKey struct { + ssh.PublicKey +} + +// MarshalJSON implements the json.Marshaler interface. Returns a quoted, +// base64 encoded, openssh wire format version of the public key. +func (p *SSHPublicKey) MarshalJSON() ([]byte, error) { + if p == nil || p.PublicKey == nil { + return []byte("null"), nil + } + s := base64.StdEncoding.EncodeToString(p.PublicKey.Marshal()) + return []byte(`"` + s + `"`), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. The public key is +// expected to be a quoted, base64 encoded, openssh wire formatted block of +// bytes. +func (p *SSHPublicKey) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrap(err, "error decoding ssh public key") + } + if s == "" { + p.PublicKey = nil return nil } + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return errors.Wrap(err, "error decoding ssh public key") + } + pub, err := ssh.ParsePublicKey(data) + if err != nil { + return errors.Wrap(err, "error parsing ssh public key") + } + p.PublicKey = pub + return nil } // SignSSH is an HTTP handler that reads an SignSSHRequest with a one-time-token @@ -182,24 +206,24 @@ func (h *caHandler) SignSSH(w http.ResponseWriter, r *http.Request) { }) } -// SSHConfig is an HTTP handler that returns the SSH public keys for user and -// host certificates. -func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { - config, err := h.Authority.SSHConfig() +// SSHKeys is an HTTP handler that returns the SSH public keys for user and host +// certificates. +func (h *caHandler) SSHKeys(w http.ResponseWriter, r *http.Request) { + keys, err := h.Authority.GetSSHKeys() if err != nil { WriteError(w, NotFound(err)) return } var host, user *SSHPublicKey - if config.HostKey != nil { - host = &SSHPublicKey{config.HostKey} + if keys.HostKey != nil { + host = &SSHPublicKey{PublicKey: keys.HostKey} } - if config.UserKey != nil { - user = &SSHPublicKey{config.UserKey} + if keys.UserKey != nil { + user = &SSHPublicKey{PublicKey: keys.UserKey} } - JSON(w, &SSHConfigResponse{ + JSON(w, &SSHKeysResponse{ HostKey: host, UserKey: user, }) diff --git a/api/ssh_test.go b/api/ssh_test.go index 9deb5c88..55a0db90 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/logging" "golang.org/x/crypto/ssh" @@ -296,23 +297,75 @@ func Test_caHandler_SignSSH(t *testing.T) { }, }).(*caHandler) - req := httptest.NewRequest("POST", "http://example.com/sign-ssh", bytes.NewReader(tt.req)) + req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) w := httptest.NewRecorder() h.SignSSH(logging.NewResponseLogger(w), req) res := w.Result() if res.StatusCode != tt.statusCode { - t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) } body, err := ioutil.ReadAll(res.Body) res.Body.Close() if err != nil { - t.Errorf("caHandler.Root unexpected error = %v", err) + t.Errorf("caHandler.SignSSH unexpected error = %v", err) } if tt.statusCode < http.StatusBadRequest { if !bytes.Equal(bytes.TrimSpace(body), tt.body) { - t.Errorf("caHandler.Root Body = %s, wants %s", body, tt.body) + t.Errorf("caHandler.SignSSH Body = %s, wants %s", body, tt.body) + } + } + }) + } +} + +func Test_caHandler_SSHKeys(t *testing.T) { + user, err := ssh.NewPublicKey(sshUserKey.Public()) + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + + host, err := ssh.NewPublicKey(sshHostKey.Public()) + assert.FatalError(t, err) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + tests := []struct { + name string + keys *authority.SSHKeys + keysErr error + body []byte + statusCode int + }{ + {"ok", &authority.SSHKeys{HostKey: host, UserKey: user}, nil, []byte(fmt.Sprintf(`{"userKey":"%s","hostKey":"%s"}`, userB64, hostB64)), http.StatusOK}, + {"user", &authority.SSHKeys{UserKey: user}, nil, []byte(fmt.Sprintf(`{"userKey":"%s"}`, userB64)), http.StatusOK}, + {"host", &authority.SSHKeys{HostKey: host}, nil, []byte(fmt.Sprintf(`{"hostKey":"%s"}`, hostB64)), http.StatusOK}, + {"error", nil, fmt.Errorf("an error"), nil, http.StatusNotFound}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := New(&mockAuthority{ + getSSHKeys: func() (*authority.SSHKeys, error) { + return tt.keys, tt.keysErr + }, + }).(*caHandler) + + req := httptest.NewRequest("GET", "http://example.com/ssh/keys", http.NoBody) + w := httptest.NewRecorder() + h.SSHKeys(logging.NewResponseLogger(w), req) + res := w.Result() + + if res.StatusCode != tt.statusCode { + t.Errorf("caHandler.SSHKeys StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Errorf("caHandler.SSHKeys unexpected error = %v", err) + } + if tt.statusCode < http.StatusBadRequest { + if !bytes.Equal(bytes.TrimSpace(body), tt.body) { + t.Errorf("caHandler.SSHKeys Body = %s, wants %s", body, tt.body) } } }) diff --git a/authority/ssh.go b/authority/ssh.go index dc7ebe0c..c83ce88b 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -24,28 +24,28 @@ const ( SSHAddUserCommand = "sudo useradd -m ; nc -q0 localhost 22" ) -// SSHConfiguration is the return type for SSHConfig. -type SSHConfiguration struct { +// SSHKeys represents the SSH User and Host public keys. +type SSHKeys struct { UserKey ssh.PublicKey HostKey ssh.PublicKey } -// SSHConfig returns the SSH User and Host public keys. -func (a *Authority) SSHConfig() (*SSHConfiguration, error) { - var config SSHConfiguration +// GetSSHKeys returns the SSH User and Host public keys. +func (a *Authority) GetSSHKeys() (*SSHKeys, error) { + var keys SSHKeys if a.sshCAUserCertSignKey != nil { - config.UserKey = a.sshCAUserCertSignKey.PublicKey() + keys.UserKey = a.sshCAUserCertSignKey.PublicKey() } if a.sshCAHostCertSignKey != nil { - config.HostKey = a.sshCAHostCertSignKey.PublicKey() + keys.HostKey = a.sshCAHostCertSignKey.PublicKey() } - if config.UserKey == nil && config.HostKey == nil { + if keys.UserKey == nil && keys.HostKey == nil { return nil, &apiError{ err: errors.New("sshConfig: ssh is not configured"), code: http.StatusNotFound, } } - return &config, nil + return &keys, nil } // SignSSH creates a signed SSH certificate with the given public key and options. diff --git a/ca/client.go b/ca/client.go index 826bee7f..fc964b89 100644 --- a/ca/client.go +++ b/ca/client.go @@ -380,7 +380,7 @@ func (c *Client) SignSSH(req *api.SignSSHRequest) (*api.SignSSHResponse, error) if err != nil { return nil, errors.Wrap(err, "error marshaling request") } - u := c.endpoint.ResolveReference(&url.URL{Path: "/sign-ssh"}) + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"}) resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", u) @@ -527,6 +527,24 @@ func (c *Client) Federation() (*api.FederationResponse, error) { return &federation, nil } +// SSHKeys performs the get ssh keys request to the CA and returns the +// api.SSHKeysResponse struct. +func (c *Client) SSHKeys() (*api.SSHKeysResponse, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/keys"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + return nil, readError(resp.Body) + } + var keys api.SSHKeysResponse + if err := readJSON(resp.Body, &keys); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return &keys, nil +} + // RootFingerprint is a helper method that returns the current root fingerprint. // It does an health connection and gets the fingerprint from the TLS verified // chains. diff --git a/ca/client_test.go b/ca/client_test.go index 1c90e52b..8bb83fe1 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -2,6 +2,9 @@ package ca import ( "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" "encoding/json" "encoding/pem" @@ -17,6 +20,7 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/cli/crypto/x509util" + "golang.org/x/crypto/ssh" ) const ( @@ -96,6 +100,14 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` ) +func mustKey() *ecdsa.PrivateKey { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + return priv +} + func parseCertificate(data string) *x509.Certificate { block, _ := pem.Decode([]byte(data)) if block == nil { @@ -702,6 +714,67 @@ func TestClient_Federation(t *testing.T) { } } +func TestClient_SSHKeys(t *testing.T) { + key, err := ssh.NewPublicKey(mustKey().Public()) + if err != nil { + t.Fatal(err) + } + + ok := &api.SSHKeysResponse{ + HostKey: &api.SSHPublicKey{PublicKey: key}, + UserKey: &api.SSHPublicKey{PublicKey: key}, + } + notFound := api.NotFound(fmt.Errorf("Not Found")) + + tests := []struct { + name string + response interface{} + responseCode int + wantErr bool + }{ + {"ok", ok, 200, false}, + {"not found", notFound, 404, true}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + api.JSONStatus(w, tt.response, tt.responseCode) + }) + + got, err := c.SSHKeys() + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.SSHKeys() = %v, want nil", got) + } + if !reflect.DeepEqual(err, tt.response) { + t.Errorf("Client.SSHKeys() error = %v, want %v", err, tt.response) + } + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) + } + } + }) + } +} + func Test_parseEndpoint(t *testing.T) { expected1 := &url.URL{Scheme: "https", Host: "ca.smallstep.com"} expected2 := &url.URL{Scheme: "https", Host: "ca.smallstep.com", Path: "/1.0/sign"}