diff --git a/api/ssh.go b/api/ssh.go index 43b24d52..cec2dcb7 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -2,6 +2,7 @@ package api import ( "context" + "crypto/x509" "encoding/base64" "encoding/json" "net/http" @@ -23,7 +24,7 @@ type SSHAuthority interface { GetSSHFederation() (*authority.SSHKeys, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) CheckSSHHost(principal string) (bool, error) - GetSSHHosts(user string) ([]string, error) + GetSSHHosts(cert *x509.Certificate) ([]string, error) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) } @@ -436,18 +437,12 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { - WriteError(w, BadRequest(errors.New("missing peer certificate"))) - return + var cert *x509.Certificate + if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { + cert = r.TLS.PeerCertificates[0] } - cert := r.TLS.PeerCertificates[0] - email := cert.EmailAddresses[0] - if len(email) == 0 { - WriteError(w, BadRequest(errors.New("client certificate missing email SAN"))) - return - } - hosts, err := h.Authority.GetSSHHosts(email) + hosts, err := h.Authority.GetSSHHosts(cert) if err != nil { WriteError(w, InternalServerError(err)) return diff --git a/authority/authority.go b/authority/authority.go index e00d978c..44fe3fc7 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -41,7 +41,7 @@ type Authority struct { initOnce bool // Custom functions sshBastionFunc func(user, hostname string) (*Bastion, error) - sshGetHostsFunc func(user string) ([]string, error) + sshGetHostsFunc func(cert *x509.Certificate) ([]string, error) getIdentityFunc provisioner.GetIdentityFunc } diff --git a/authority/options.go b/authority/options.go index 5a161118..f1738e68 100644 --- a/authority/options.go +++ b/authority/options.go @@ -1,6 +1,8 @@ package authority import ( + "crypto/x509" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" ) @@ -34,7 +36,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option { // WithSSHGetHosts sets a custom function to get the bastion for a // given user-host pair. -func WithSSHGetHosts(fn func(user string) ([]string, error)) Option { +func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]string, error)) Option { return func(a *Authority) { a.sshGetHostsFunc = fn } diff --git a/authority/ssh.go b/authority/ssh.go index 4f34d81a..779a6da9 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -3,6 +3,7 @@ package authority import ( "context" "crypto/rand" + "crypto/x509" "encoding/binary" "net/http" "strings" @@ -673,14 +674,18 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) { } // GetSSHHosts returns a list of valid host principals. -func (a *Authority) GetSSHHosts(email string) ([]string, error) { - if a.sshBastionFunc != nil { - return a.sshGetHostsFunc(email) +func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]string, error) { + if a.sshGetHostsFunc != nil { + return a.sshGetHostsFunc(cert) } - return nil, &apiError{ - err: errors.New("getSSHHosts is not configured"), - code: http.StatusNotFound, + hosts, err := a.db.GetSSHHostPrincipals() + if err != nil { + return nil, &apiError{ + err: errors.Wrap(err, "getSSHHosts"), + code: http.StatusInternalServerError, + } } + return hosts, nil } func (a *Authority) getAddUserPrincipal() (cmd string) {