diff --git a/api/ssh.go b/api/ssh.go index b559c27a..546c8f1e 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -25,7 +25,7 @@ type SSHAuthority interface { GetSSHRoots() (*authority.SSHKeys, error) GetSSHFederation() (*authority.SSHKeys, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) - CheckSSHHost(principal string) (bool, error) + CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) GetSSHBastion(user string, hostname string) (*authority.Bastion, error) } @@ -199,6 +199,7 @@ type SSHConfigResponse struct { type SSHCheckPrincipalRequest struct { Type string `json:"type"` Principal string `json:"principal"` + Token string `json:"token,omitempty"` } // Validate checks the check principal request. @@ -431,7 +432,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } - exists, err := h.Authority.CheckSSHHost(body.Principal) + exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) if err != nil { WriteError(w, InternalServerError(err)) return diff --git a/authority/authority.go b/authority/authority.go index 9d04f339..25b40350 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto" "crypto/sha256" "crypto/x509" @@ -40,9 +41,10 @@ type Authority struct { // Do not re-initialize initOnce bool // Custom functions - sshBastionFunc func(user, hostname string) (*Bastion, error) - sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error) - getIdentityFunc provisioner.GetIdentityFunc + sshBastionFunc func(user, hostname string) (*Bastion, error) + sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) + sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error) + getIdentityFunc provisioner.GetIdentityFunc } // New creates and initiates a new Authority type. diff --git a/authority/options.go b/authority/options.go index a2e19edb..10f0ec1a 100644 --- a/authority/options.go +++ b/authority/options.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto/x509" "github.com/smallstep/certificates/authority/provisioner" @@ -42,3 +43,12 @@ func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]sshutil.Host, error)) Op a.sshGetHostsFunc = fn } } + +// WithSSHCheckHost sets a custom function to check whether a given host is +// step ssh enabled. The token is used to validate the request, while the roots +// are used to validate the token. +func WithSSHCheckHost(fn func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)) Option { + return func(a *Authority) { + a.sshCheckHostFunc = fn + } +} diff --git a/authority/ssh.go b/authority/ssh.go index 232527a8..fbf97545 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -656,7 +656,17 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate) } // CheckSSHHost checks the given principal has been registered before. -func (a *Authority) CheckSSHHost(principal string) (bool, error) { +func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) { + if a.sshCheckHostFunc != nil { + exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates()) + if err != nil { + return false, &apiError{ + err: errors.Wrap(err, "checkSSHHost: error from injected checkSSHHost func"), + code: http.StatusInternalServerError, + } + } + return exists, nil + } exists, err := a.db.IsSSHHost(principal) if err != nil { if err == db.ErrNotImplemented { diff --git a/ca/client.go b/ca/client.go index a5d88808..66e97275 100644 --- a/ca/client.go +++ b/ca/client.go @@ -952,11 +952,12 @@ retry: // SSHCheckHost performs the POST /ssh/check-host request to the CA with the // given principal. -func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) { +func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) { var retried bool body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ Type: provisioner.SSHHostCert, Principal: principal, + Token: token, }) if err != nil { return nil, errors.Wrap(err, "error marshaling request") diff --git a/go.mod b/go.mod index 807f8bfc..67464278 100644 --- a/go.mod +++ b/go.mod @@ -18,4 +18,4 @@ require ( gopkg.in/square/go-jose.v2 v2.4.0 ) -// replace github.com/smallstep/cli => ../cli +//replace github.com/smallstep/cli => ../cli