change func def for getSSHHosts
* continue to return all hosts if injection method not specified
This commit is contained in:
parent
3fda081e42
commit
35912cc906
4 changed files with 21 additions and 19 deletions
17
api/ssh.go
17
api/ssh.go
|
@ -2,6 +2,7 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -23,7 +24,7 @@ type SSHAuthority interface {
|
||||||
GetSSHFederation() (*authority.SSHKeys, error)
|
GetSSHFederation() (*authority.SSHKeys, error)
|
||||||
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
||||||
CheckSSHHost(principal string) (bool, error)
|
CheckSSHHost(principal string) (bool, error)
|
||||||
GetSSHHosts(user string) ([]string, error)
|
GetSSHHosts(cert *x509.Certificate) ([]string, error)
|
||||||
GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
|
GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -436,18 +437,12 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
var cert *x509.Certificate
|
||||||
WriteError(w, BadRequest(errors.New("missing peer certificate")))
|
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||||
return
|
cert = r.TLS.PeerCertificates[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
cert := r.TLS.PeerCertificates[0]
|
hosts, err := h.Authority.GetSSHHosts(cert)
|
||||||
email := cert.EmailAddresses[0]
|
|
||||||
if len(email) == 0 {
|
|
||||||
WriteError(w, BadRequest(errors.New("client certificate missing email SAN")))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hosts, err := h.Authority.GetSSHHosts(email)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
WriteError(w, InternalServerError(err))
|
WriteError(w, InternalServerError(err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -41,7 +41,7 @@ type Authority struct {
|
||||||
initOnce bool
|
initOnce bool
|
||||||
// Custom functions
|
// Custom functions
|
||||||
sshBastionFunc func(user, hostname string) (*Bastion, error)
|
sshBastionFunc func(user, hostname string) (*Bastion, error)
|
||||||
sshGetHostsFunc func(user string) ([]string, error)
|
sshGetHostsFunc func(cert *x509.Certificate) ([]string, error)
|
||||||
getIdentityFunc provisioner.GetIdentityFunc
|
getIdentityFunc provisioner.GetIdentityFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
)
|
)
|
||||||
|
@ -34,7 +36,7 @@ func WithSSHBastionFunc(fn func(user, host string) (*Bastion, error)) Option {
|
||||||
|
|
||||||
// WithSSHGetHosts sets a custom function to get the bastion for a
|
// WithSSHGetHosts sets a custom function to get the bastion for a
|
||||||
// given user-host pair.
|
// given user-host pair.
|
||||||
func WithSSHGetHosts(fn func(user string) ([]string, error)) Option {
|
func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]string, error)) Option {
|
||||||
return func(a *Authority) {
|
return func(a *Authority) {
|
||||||
a.sshGetHostsFunc = fn
|
a.sshGetHostsFunc = fn
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package authority
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -673,14 +674,18 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSSHHosts returns a list of valid host principals.
|
// GetSSHHosts returns a list of valid host principals.
|
||||||
func (a *Authority) GetSSHHosts(email string) ([]string, error) {
|
func (a *Authority) GetSSHHosts(cert *x509.Certificate) ([]string, error) {
|
||||||
if a.sshBastionFunc != nil {
|
if a.sshGetHostsFunc != nil {
|
||||||
return a.sshGetHostsFunc(email)
|
return a.sshGetHostsFunc(cert)
|
||||||
}
|
}
|
||||||
return nil, &apiError{
|
hosts, err := a.db.GetSSHHostPrincipals()
|
||||||
err: errors.New("getSSHHosts is not configured"),
|
if err != nil {
|
||||||
code: http.StatusNotFound,
|
return nil, &apiError{
|
||||||
|
err: errors.Wrap(err, "getSSHHosts"),
|
||||||
|
code: http.StatusInternalServerError,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return hosts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||||
|
|
Loading…
Reference in a new issue