change func def for getSSHHosts

* continue to return all hosts if injection method not specified
This commit is contained in:
max furman 2019-11-20 12:59:48 -08:00
parent 11c8639782
commit f92bb06b6c
4 changed files with 21 additions and 19 deletions

View file

@ -2,6 +2,7 @@ package api
import ( import (
"context" "context"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net/http" "net/http"
@ -23,7 +24,7 @@ type SSHAuthority interface {
GetSSHFederation() (*authority.SSHKeys, error) GetSSHFederation() (*authority.SSHKeys, error)
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
CheckSSHHost(principal string) (bool, error) CheckSSHHost(principal string) (bool, error)
GetSSHHosts(user string) ([]string, error) GetSSHHosts(cert *x509.Certificate) ([]string, error)
GetSSHBastion(user string, hostname string) (*authority.Bastion, error) GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
} }
@ -436,18 +437,12 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { var cert *x509.Certificate
WriteError(w, BadRequest(errors.New("missing peer certificate"))) if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return cert = r.TLS.PeerCertificates[0]
} }
cert := r.TLS.PeerCertificates[0] hosts, err := h.Authority.GetSSHHosts(cert)
email := cert.EmailAddresses[0]
if len(email) == 0 {
WriteError(w, BadRequest(errors.New("client certificate missing email SAN")))
return
}
hosts, err := h.Authority.GetSSHHosts(email)
if err != nil { if err != nil {
WriteError(w, InternalServerError(err)) WriteError(w, InternalServerError(err))
return return

View file

@ -41,7 +41,7 @@ type Authority struct {
initOnce bool initOnce bool
// Custom functions // Custom functions
sshBastionFunc func(user, hostname string) (*Bastion, error) sshBastionFunc func(user, hostname string) (*Bastion, error)
sshGetHostsFunc func(user string) ([]string, error) sshGetHostsFunc func(cert *x509.Certificate) ([]string, error)
getIdentityFunc provisioner.GetIdentityFunc getIdentityFunc provisioner.GetIdentityFunc
} }

View file

@ -1,6 +1,8 @@
package authority package authority
import ( import (
"crypto/x509"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
) )
@ -34,7 +36,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
// WithSSHGetHosts sets a custom function to get the bastion for a // WithSSHGetHosts sets a custom function to get the bastion for a
// given user-host pair. // given user-host pair.
func WithSSHGetHosts(fn func(user string) ([]string, error)) Option { func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]string, error)) Option {
return func(a *Authority) { return func(a *Authority) {
a.sshGetHostsFunc = fn a.sshGetHostsFunc = fn
} }

View file

@ -3,6 +3,7 @@ package authority
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/x509"
"encoding/binary" "encoding/binary"
"net/http" "net/http"
"strings" "strings"
@ -673,14 +674,18 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
} }
// GetSSHHosts returns a list of valid host principals. // GetSSHHosts returns a list of valid host principals.
func (a *Authority) GetSSHHosts(email string) ([]string, error) { func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]string, error) {
if a.sshBastionFunc != nil { if a.sshGetHostsFunc != nil {
return a.sshGetHostsFunc(email) return a.sshGetHostsFunc(cert)
} }
return nil, &apiError{ hosts, err := a.db.GetSSHHostPrincipals()
err: errors.New("getSSHHosts is not configured"), if err != nil {
code: http.StatusNotFound, return nil, &apiError{
err: errors.Wrap(err, "getSSHHosts"),
code: http.StatusInternalServerError,
}
} }
return hosts, nil
} }
func (a *Authority) getAddUserPrincipal() (cmd string) { func (a *Authority) getAddUserPrincipal() (cmd string) {