Use x5cInsecure token for /ssh/check-host endpoint

This commit is contained in:
max furman 2019-12-09 23:14:56 -08:00
parent ab126d6405
commit 3ac388612a
6 changed files with 32 additions and 8 deletions

View file

@ -25,7 +25,7 @@ type SSHAuthority interface {
GetSSHRoots() (*authority.SSHKeys, error) GetSSHRoots() (*authority.SSHKeys, error)
GetSSHFederation() (*authority.SSHKeys, error) GetSSHFederation() (*authority.SSHKeys, error)
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error) GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
CheckSSHHost(principal string) (bool, error) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error)
GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error) GetSSHHosts(cert *x509.Certificate) ([]sshutil.Host, error)
GetSSHBastion(user string, hostname string) (*authority.Bastion, error) GetSSHBastion(user string, hostname string) (*authority.Bastion, error)
} }
@ -199,6 +199,7 @@ type SSHConfigResponse struct {
type SSHCheckPrincipalRequest struct { type SSHCheckPrincipalRequest struct {
Type string `json:"type"` Type string `json:"type"`
Principal string `json:"principal"` Principal string `json:"principal"`
Token string `json:"token,omitempty"`
} }
// Validate checks the check principal request. // Validate checks the check principal request.
@ -431,7 +432,7 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
exists, err := h.Authority.CheckSSHHost(body.Principal) exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token)
if err != nil { if err != nil {
WriteError(w, InternalServerError(err)) WriteError(w, InternalServerError(err))
return return

View file

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"context"
"crypto" "crypto"
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
@ -41,6 +42,7 @@ type Authority struct {
initOnce bool initOnce bool
// Custom functions // Custom functions
sshBastionFunc func(user, hostname string) (*Bastion, error) sshBastionFunc func(user, hostname string) (*Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error) sshGetHostsFunc func(cert *x509.Certificate) ([]sshutil.Host, error)
getIdentityFunc provisioner.GetIdentityFunc getIdentityFunc provisioner.GetIdentityFunc
} }

View file

@ -1,6 +1,7 @@
package authority package authority
import ( import (
"context"
"crypto/x509" "crypto/x509"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -42,3 +43,12 @@ func WithSSHGetHosts(fn func(cert *x509.Certificate) ([]sshutil.Host, error)) Op
a.sshGetHostsFunc = fn a.sshGetHostsFunc = fn
} }
} }
// WithSSHCheckHost sets a custom function to check whether a given host is
// step ssh enabled. The token is used to validate the request, while the roots
// are used to validate the token.
func WithSSHCheckHost(fn func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)) Option {
return func(a *Authority) {
a.sshCheckHostFunc = fn
}
}

View file

@ -656,7 +656,17 @@ func (a *Authority) SignSSHAddUser(key ssh.PublicKey, subject *ssh.Certificate)
} }
// CheckSSHHost checks the given principal has been registered before. // CheckSSHHost checks the given principal has been registered before.
func (a *Authority) CheckSSHHost(principal string) (bool, error) { func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) {
if a.sshCheckHostFunc != nil {
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
if err != nil {
return false, &apiError{
err: errors.Wrap(err, "checkSSHHost: error from injected checkSSHHost func"),
code: http.StatusInternalServerError,
}
}
return exists, nil
}
exists, err := a.db.IsSSHHost(principal) exists, err := a.db.IsSSHHost(principal)
if err != nil { if err != nil {
if err == db.ErrNotImplemented { if err == db.ErrNotImplemented {

View file

@ -952,11 +952,12 @@ retry:
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the // SSHCheckHost performs the POST /ssh/check-host request to the CA with the
// given principal. // given principal.
func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse, error) { func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) {
var retried bool var retried bool
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
Type: provisioner.SSHHostCert, Type: provisioner.SSHHostCert,
Principal: principal, Principal: principal,
Token: token,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshaling request") return nil, errors.Wrap(err, "error marshaling request")

2
go.mod
View file

@ -18,4 +18,4 @@ require (
gopkg.in/square/go-jose.v2 v2.4.0 gopkg.in/square/go-jose.v2 v2.4.0
) )
// replace github.com/smallstep/cli => ../cli //replace github.com/smallstep/cli => ../cli