Add getSSHHosts injection func
This commit is contained in:
parent
414a94b210
commit
d940ab7c20
4 changed files with 33 additions and 12 deletions
15
api/ssh.go
15
api/ssh.go
|
@ -23,7 +23,7 @@ type SSHAuthority interface {
|
|||
GetSSHFederation() (*authority.SSHKeys, error)
|
||||
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
||||
CheckSSHHost(principal string) (bool, error)
|
||||
GetSSHHosts() ([]string, error)
|
||||
GetSSHHosts(user string) ([]string, error)
|
||||
GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
|
||||
}
|
||||
|
||||
|
@ -406,7 +406,18 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
hosts, err := h.Authority.GetSSHHosts()
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
||||
return
|
||||
}
|
||||
|
||||
cert := r.TLS.PeerCertificates[0]
|
||||
email := cert.EmailAddresses[0]
|
||||
if len(email) == 0 {
|
||||
WriteError(w, BadRequest(errors.New("client certificate missing email SAN")))
|
||||
return
|
||||
}
|
||||
hosts, err := h.Authority.GetSSHHosts(email)
|
||||
if err != nil {
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
|
|
|
@ -41,6 +41,7 @@ type Authority struct {
|
|||
initOnce bool
|
||||
// Custom functions
|
||||
sshBastionFunc func(user, hostname string) (*Bastion, error)
|
||||
sshGetHostsFunc func(user string) ([]string, error)
|
||||
getIdentityFunc provisioner.GetIdentityFunc
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,14 @@ func WithDatabase(db db.AuthDB) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithGetIdentityFunc sets a custom function to retrieve the identity from
|
||||
// an external resource.
|
||||
func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
|
||||
return func(a *Authority) {
|
||||
a.getIdentityFunc = fn
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
||||
// given user-host pair.
|
||||
func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
|
||||
|
@ -24,10 +32,10 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithGetIdentityFunc sets a custom function to retrieve the identity from
|
||||
// an external resource.
|
||||
func WithGetIdentityFunc(fn func(p provisioner.Interface, email string) (*provisioner.Identity, error)) Option {
|
||||
// WithSSHGetHosts sets a custom function to get the bastion for a
|
||||
// given user-host pair.
|
||||
func WithSSHGetHosts(fn func(user string) ([]string, error)) Option {
|
||||
return func(a *Authority) {
|
||||
a.getIdentityFunc = fn
|
||||
a.sshGetHostsFunc = fn
|
||||
}
|
||||
}
|
||||
|
|
|
@ -673,13 +673,14 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
|
|||
}
|
||||
|
||||
// GetSSHHosts returns a list of valid host principals.
|
||||
func (a *Authority) GetSSHHosts() ([]string, error) {
|
||||
ps, err := a.db.GetSSHHostPrincipals()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (a *Authority) GetSSHHosts(email string) ([]string, error) {
|
||||
if a.sshBastionFunc != nil {
|
||||
return a.sshGetHostsFunc(email)
|
||||
}
|
||||
return nil, &apiError{
|
||||
err: errors.New("getSSHHosts is not configured"),
|
||||
code: http.StatusNotFound,
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||
|
|
Loading…
Reference in a new issue