Simplify SSH modifiers with options.

It also changes the behavior of the request options to modify only
the validity of the certificate.
This commit is contained in:
Mariano Cano 2020-07-29 16:06:39 -07:00
parent df1f7e5a2e
commit c1fc45c872
2 changed files with 103 additions and 98 deletions

View file

@ -24,14 +24,7 @@ const (
// certificate. // certificate.
type SSHCertModifier interface { type SSHCertModifier interface {
SignOption SignOption
Modify(cert *ssh.Certificate) error Modify(cert *ssh.Certificate, opts SignSSHOptions) error
}
// SSHCertOptionModifier is the interface used to add custom options used
// to modify the SSH certificate.
type SSHCertOptionModifier interface {
SignOption
Option(o SignSSHOptions) SSHCertModifier
} }
// SSHCertValidator is the interface used to validate an SSH certificate. // SSHCertValidator is the interface used to validate an SSH certificate.
@ -47,14 +40,6 @@ type SSHCertOptionsValidator interface {
Valid(got SignSSHOptions) error Valid(got SignSSHOptions) error
} }
// sshModifierFunc is an adapter to allow the use of ordinary functions as SSH
// certificate modifiers.
type sshModifierFunc func(cert *ssh.Certificate) error
func (f sshModifierFunc) Modify(cert *ssh.Certificate) error {
return f(cert)
}
// SignSSHOptions contains the options that can be passed to the SignSSH method. // SignSSHOptions contains the options that can be passed to the SignSSH method.
type SignSSHOptions struct { type SignSSHOptions struct {
CertType string `json:"certType"` CertType string `json:"certType"`
@ -72,7 +57,7 @@ func (o SignSSHOptions) Type() uint32 {
} }
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate. // Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
func (o SignSSHOptions) Modify(cert *ssh.Certificate) error { func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
switch o.CertType { switch o.CertType {
case "": // ignore case "": // ignore
case SSHUserCert: case SSHUserCert:
@ -86,6 +71,12 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
cert.KeyId = o.KeyID cert.KeyId = o.KeyID
cert.ValidPrincipals = o.Principals cert.ValidPrincipals = o.Principals
return o.ModifyValidity(cert)
}
// ModifyValidity modifies only the ValidAfter and ValidBefore on the given
// ssh.Certificate.
func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
t := now() t := now()
if !o.ValidAfter.IsZero() { if !o.ValidAfter.IsZero() {
cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix()) cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix())
@ -96,7 +87,6 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
return errors.New("ssh certificate valid after cannot be greater than valid before") return errors.New("ssh certificate valid after cannot be greater than valid before")
} }
return nil return nil
} }
@ -123,7 +113,7 @@ func (o SignSSHOptions) match(got SignSSHOptions) error {
type sshCertPrincipalsModifier []string type sshCertPrincipalsModifier []string
// Modify the ValidPrincipals value of the cert. // Modify the ValidPrincipals value of the cert.
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error { func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidPrincipals = []string(o) cert.ValidPrincipals = []string(o)
return nil return nil
} }
@ -132,7 +122,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
// Key ID in the SSH certificate. // Key ID in the SSH certificate.
type sshCertKeyIDModifier string type sshCertKeyIDModifier string
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error { func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.KeyId = string(m) cert.KeyId = string(m)
return nil return nil
} }
@ -142,7 +132,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
type sshCertTypeModifier string type sshCertTypeModifier string
// Modify sets the CertType for the ssh certificate. // Modify sets the CertType for the ssh certificate.
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error { func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.CertType = sshCertTypeUInt32(string(m)) cert.CertType = sshCertTypeUInt32(string(m))
return nil return nil
} }
@ -151,7 +141,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
// ValidAfter in the SSH certificate. // ValidAfter in the SSH certificate.
type sshCertValidAfterModifier uint64 type sshCertValidAfterModifier uint64
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error { func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidAfter = uint64(m) cert.ValidAfter = uint64(m)
return nil return nil
} }
@ -160,7 +150,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
// ValidBefore in the SSH certificate. // ValidBefore in the SSH certificate.
type sshCertValidBeforeModifier uint64 type sshCertValidBeforeModifier uint64
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error { func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
cert.ValidBefore = uint64(m) cert.ValidBefore = uint64(m)
return nil return nil
} }
@ -217,27 +207,27 @@ type sshDefaultDuration struct {
*Claimer *Claimer
} }
func (m *sshDefaultDuration) Option(o SignSSHOptions) SSHCertModifier { // Modify implements SSHCertModifier and sets the validity if it has not been
return sshModifierFunc(func(cert *ssh.Certificate) error { // set, but it always applies the backdate.
d, err := m.DefaultSSHCertDuration(cert.CertType) func (m *sshDefaultDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error {
if err != nil { d, err := m.DefaultSSHCertDuration(cert.CertType)
return err if err != nil {
} return err
}
var backdate uint64 var backdate uint64
if cert.ValidAfter == 0 { if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second) backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix()) cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
} }
if cert.ValidBefore == 0 { if cert.ValidBefore == 0 {
cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second) cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second)
} }
// Apply backdate safely // Apply backdate safely
if cert.ValidAfter > backdate { if cert.ValidAfter > backdate {
cert.ValidAfter -= backdate cert.ValidAfter -= backdate
} }
return nil return nil
})
} }
// sshLimitDuration adjusts the duration to min(default, remaining provisioning // sshLimitDuration adjusts the duration to min(default, remaining provisioning
@ -250,51 +240,52 @@ type sshLimitDuration struct {
NotAfter time.Time NotAfter time.Time
} }
func (m *sshLimitDuration) Option(o SignSSHOptions) SSHCertModifier { // Modify implements SSHCertModifier and modifies the validity of the
// certificate to expire before the configured limit.
func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error {
if m.NotAfter.IsZero() { if m.NotAfter.IsZero() {
defaultDuration := &sshDefaultDuration{m.Claimer} defaultDuration := &sshDefaultDuration{m.Claimer}
return defaultDuration.Option(o) return defaultDuration.Modify(cert, o)
} }
return sshModifierFunc(func(cert *ssh.Certificate) error { // Make sure the duration is within the limits.
d, err := m.DefaultSSHCertDuration(cert.CertType) d, err := m.DefaultSSHCertDuration(cert.CertType)
if err != nil { if err != nil {
return err return err
} }
var backdate uint64 var backdate uint64
if cert.ValidAfter == 0 { if cert.ValidAfter == 0 {
backdate = uint64(o.Backdate / time.Second) backdate = uint64(o.Backdate / time.Second)
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix()) cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
} }
certValidAfter := time.Unix(int64(cert.ValidAfter), 0) certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
if certValidAfter.After(m.NotAfter) { if certValidAfter.After(m.NotAfter) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
m.NotAfter, certValidAfter) m.NotAfter, certValidAfter)
} }
if cert.ValidBefore == 0 { if cert.ValidBefore == 0 {
certValidBefore := certValidAfter.Add(d) certValidBefore := certValidAfter.Add(d)
if m.NotAfter.Before(certValidBefore) { if m.NotAfter.Before(certValidBefore) {
certValidBefore = m.NotAfter certValidBefore = m.NotAfter
}
cert.ValidBefore = uint64(certValidBefore.Unix())
} else {
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
if m.NotAfter.Before(certValidBefore) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore)
}
} }
cert.ValidBefore = uint64(certValidBefore.Unix())
// Apply backdate safely } else {
if cert.ValidAfter > backdate { certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
cert.ValidAfter -= backdate if m.NotAfter.Before(certValidBefore) {
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
m.NotAfter, certValidBefore)
} }
}
return nil // Apply backdate safely
}) if cert.ValidAfter > backdate {
cert.ValidAfter -= backdate
}
return nil
} }
// sshCertOptionsValidator validates the user SSHOptions with the ones // sshCertOptionsValidator validates the user SSHOptions with the ones

View file

@ -206,6 +206,8 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname str
// SignSSH creates a signed SSH certificate with the given public key and options. // SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var ( var (
err error
certType sshutil.CertType
certOptions []sshutil.Option certOptions []sshutil.Option
mods []provisioner.SSHCertModifier mods []provisioner.SSHCertModifier
validators []provisioner.SSHCertValidator validators []provisioner.SSHCertValidator
@ -214,6 +216,14 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Set backdate with the configured value // Set backdate with the configured value
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
// Validate certificate type.
if opts.CertType != "" {
certType, err = sshutil.CertTypeFromString(opts.CertType)
if err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
}
}
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// add options to NewCertificate // add options to NewCertificate
@ -224,10 +234,6 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
case provisioner.SSHCertModifier: case provisioner.SSHCertModifier:
mods = append(mods, o) mods = append(mods, o)
// modify the ssh.Certificate given the SSHOptions
case provisioner.SSHCertOptionModifier:
mods = append(mods, o.Option(opts))
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
@ -235,16 +241,24 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// validate the given SSHOptions // validate the given SSHOptions
case provisioner.SSHCertOptionsValidator: case provisioner.SSHCertOptionsValidator:
if err := o.Valid(opts); err != nil { if err := o.Valid(opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
} }
default: default:
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o) return nil, errs.InternalServer("authority.SignSSH: invalid extra option type %T", o)
} }
} }
// Simulated certificate request with request options.
cr := sshutil.CertificateRequest{
Type: certType,
KeyID: opts.KeyID,
Principals: opts.Principals,
Key: key,
}
// Create certificate from template. // Create certificate from template.
certificate, err := sshutil.NewCertificate(key, certOptions...) certificate, err := sshutil.NewCertificate(cr, certOptions...)
if err != nil { if err != nil {
if _, ok := err.(*sshutil.TemplateError); ok { if _, ok := err.(*sshutil.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err, return nil, errs.NewErr(http.StatusBadRequest, err,
@ -255,19 +269,19 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH") return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH")
} }
// Get actual *ssh.Certificate and continue with user and provisioner // Get actual *ssh.Certificate and continue with provisioner modifiers.
// modifiers.
cert := certificate.GetCertificate() cert := certificate.GetCertificate()
// Use SignSSHOptions to modify the certificate. // Use SignSSHOptions to modify the certificate validity. It will be later
if err := opts.Modify(cert); err != nil { // checked or set if not defined.
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") if err := opts.ModifyValidity(cert); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
} }
// Use provisioner modifiers. // Use provisioner modifiers.
for _, m := range mods { for _, m := range mods {
if err := m.Modify(cert); err != nil { if err := m.Modify(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
} }
} }
@ -276,33 +290,33 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
switch cert.CertType { switch cert.CertType {
case ssh.UserCert: case ssh.UserCert:
if a.sshCAUserCertSignKey == nil { if a.sshCAUserCertSignKey == nil {
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled") return nil, errs.NotImplemented("authority.SignSSH: user certificate signing is not enabled")
} }
signer = a.sshCAUserCertSignKey signer = a.sshCAUserCertSignKey
case ssh.HostCert: case ssh.HostCert:
if a.sshCAHostCertSignKey == nil { if a.sshCAHostCertSignKey == nil {
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled") return nil, errs.NotImplemented("authority.SignSSH: host certificate signing is not enabled")
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType) return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", cert.CertType)
} }
// Sign certificate. // Sign certificate.
cert, err = sshutil.CreateCertificate(cert, signer) cert, err = sshutil.CreateCertificate(cert, signer)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error signing certificate")
} }
// User provisioners validators. // User provisioners validators.
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert, opts); err != nil { if err := v.Valid(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
} }
} }
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db")
} }
return cert, nil return cert, nil