From c1fc45c87234adf19b4656318aec4cb463347757 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 29 Jul 2020 16:06:39 -0700 Subject: [PATCH] Simplify SSH modifiers with options. It also changes the behavior of the request options to modify only the validity of the certificate. --- authority/provisioner/sign_ssh_options.go | 147 ++++++++++------------ authority/ssh.go | 54 +++++--- 2 files changed, 103 insertions(+), 98 deletions(-) diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 6352204f..a9638d24 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -24,14 +24,7 @@ const ( // certificate. type SSHCertModifier interface { SignOption - Modify(cert *ssh.Certificate) error -} - -// SSHCertOptionModifier is the interface used to add custom options used -// to modify the SSH certificate. -type SSHCertOptionModifier interface { - SignOption - Option(o SignSSHOptions) SSHCertModifier + Modify(cert *ssh.Certificate, opts SignSSHOptions) error } // SSHCertValidator is the interface used to validate an SSH certificate. @@ -47,14 +40,6 @@ type SSHCertOptionsValidator interface { Valid(got SignSSHOptions) error } -// sshModifierFunc is an adapter to allow the use of ordinary functions as SSH -// certificate modifiers. -type sshModifierFunc func(cert *ssh.Certificate) error - -func (f sshModifierFunc) Modify(cert *ssh.Certificate) error { - return f(cert) -} - // SignSSHOptions contains the options that can be passed to the SignSSH method. type SignSSHOptions struct { CertType string `json:"certType"` @@ -72,7 +57,7 @@ func (o SignSSHOptions) Type() uint32 { } // Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate. -func (o SignSSHOptions) Modify(cert *ssh.Certificate) error { +func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { switch o.CertType { case "": // ignore case SSHUserCert: @@ -86,6 +71,12 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error { cert.KeyId = o.KeyID cert.ValidPrincipals = o.Principals + return o.ModifyValidity(cert) +} + +// ModifyValidity modifies only the ValidAfter and ValidBefore on the given +// ssh.Certificate. +func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error { t := now() if !o.ValidAfter.IsZero() { cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix()) @@ -96,7 +87,6 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error { if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore { return errors.New("ssh certificate valid after cannot be greater than valid before") } - return nil } @@ -123,7 +113,7 @@ func (o SignSSHOptions) match(got SignSSHOptions) error { type sshCertPrincipalsModifier []string // Modify the ValidPrincipals value of the cert. -func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error { +func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.ValidPrincipals = []string(o) return nil } @@ -132,7 +122,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error { // Key ID in the SSH certificate. type sshCertKeyIDModifier string -func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.KeyId = string(m) return nil } @@ -142,7 +132,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error { type sshCertTypeModifier string // Modify sets the CertType for the ssh certificate. -func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.CertType = sshCertTypeUInt32(string(m)) return nil } @@ -151,7 +141,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error { // ValidAfter in the SSH certificate. type sshCertValidAfterModifier uint64 -func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.ValidAfter = uint64(m) return nil } @@ -160,7 +150,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error { // ValidBefore in the SSH certificate. type sshCertValidBeforeModifier uint64 -func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error { +func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error { cert.ValidBefore = uint64(m) return nil } @@ -217,27 +207,27 @@ type sshDefaultDuration struct { *Claimer } -func (m *sshDefaultDuration) Option(o SignSSHOptions) SSHCertModifier { - return sshModifierFunc(func(cert *ssh.Certificate) error { - d, err := m.DefaultSSHCertDuration(cert.CertType) - if err != nil { - return err - } +// Modify implements SSHCertModifier and sets the validity if it has not been +// set, but it always applies the backdate. +func (m *sshDefaultDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error { + d, err := m.DefaultSSHCertDuration(cert.CertType) + if err != nil { + return err + } - var backdate uint64 - if cert.ValidAfter == 0 { - backdate = uint64(o.Backdate / time.Second) - cert.ValidAfter = uint64(now().Truncate(time.Second).Unix()) - } - if cert.ValidBefore == 0 { - cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second) - } - // Apply backdate safely - if cert.ValidAfter > backdate { - cert.ValidAfter -= backdate - } - return nil - }) + var backdate uint64 + if cert.ValidAfter == 0 { + backdate = uint64(o.Backdate / time.Second) + cert.ValidAfter = uint64(now().Truncate(time.Second).Unix()) + } + if cert.ValidBefore == 0 { + cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second) + } + // Apply backdate safely + if cert.ValidAfter > backdate { + cert.ValidAfter -= backdate + } + return nil } // sshLimitDuration adjusts the duration to min(default, remaining provisioning @@ -250,51 +240,52 @@ type sshLimitDuration struct { NotAfter time.Time } -func (m *sshLimitDuration) Option(o SignSSHOptions) SSHCertModifier { +// Modify implements SSHCertModifier and modifies the validity of the +// certificate to expire before the configured limit. +func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error { if m.NotAfter.IsZero() { defaultDuration := &sshDefaultDuration{m.Claimer} - return defaultDuration.Option(o) + return defaultDuration.Modify(cert, o) } - return sshModifierFunc(func(cert *ssh.Certificate) error { - d, err := m.DefaultSSHCertDuration(cert.CertType) - if err != nil { - return err - } + // Make sure the duration is within the limits. + d, err := m.DefaultSSHCertDuration(cert.CertType) + if err != nil { + return err + } - var backdate uint64 - if cert.ValidAfter == 0 { - backdate = uint64(o.Backdate / time.Second) - cert.ValidAfter = uint64(now().Truncate(time.Second).Unix()) - } + var backdate uint64 + if cert.ValidAfter == 0 { + backdate = uint64(o.Backdate / time.Second) + cert.ValidAfter = uint64(now().Truncate(time.Second).Unix()) + } - certValidAfter := time.Unix(int64(cert.ValidAfter), 0) - if certValidAfter.After(m.NotAfter) { - return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", - m.NotAfter, certValidAfter) - } + certValidAfter := time.Unix(int64(cert.ValidAfter), 0) + if certValidAfter.After(m.NotAfter) { + return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)", + m.NotAfter, certValidAfter) + } - if cert.ValidBefore == 0 { - certValidBefore := certValidAfter.Add(d) - if m.NotAfter.Before(certValidBefore) { - certValidBefore = m.NotAfter - } - cert.ValidBefore = uint64(certValidBefore.Unix()) - } else { - certValidBefore := time.Unix(int64(cert.ValidBefore), 0) - if m.NotAfter.Before(certValidBefore) { - return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", - m.NotAfter, certValidBefore) - } + if cert.ValidBefore == 0 { + certValidBefore := certValidAfter.Add(d) + if m.NotAfter.Before(certValidBefore) { + certValidBefore = m.NotAfter } - - // Apply backdate safely - if cert.ValidAfter > backdate { - cert.ValidAfter -= backdate + cert.ValidBefore = uint64(certValidBefore.Unix()) + } else { + certValidBefore := time.Unix(int64(cert.ValidBefore), 0) + if m.NotAfter.Before(certValidBefore) { + return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)", + m.NotAfter, certValidBefore) } + } - return nil - }) + // Apply backdate safely + if cert.ValidAfter > backdate { + cert.ValidAfter -= backdate + } + + return nil } // sshCertOptionsValidator validates the user SSHOptions with the ones diff --git a/authority/ssh.go b/authority/ssh.go index 61050029..1d449b80 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -206,6 +206,8 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname str // SignSSH creates a signed SSH certificate with the given public key and options. func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var ( + err error + certType sshutil.CertType certOptions []sshutil.Option mods []provisioner.SSHCertModifier validators []provisioner.SSHCertValidator @@ -214,6 +216,14 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // Set backdate with the configured value opts.Backdate = a.config.AuthorityConfig.Backdate.Duration + // Validate certificate type. + if opts.CertType != "" { + certType, err = sshutil.CertTypeFromString(opts.CertType) + if err != nil { + return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") + } + } + for _, op := range signOpts { switch o := op.(type) { // add options to NewCertificate @@ -224,10 +234,6 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi case provisioner.SSHCertModifier: mods = append(mods, o) - // modify the ssh.Certificate given the SSHOptions - case provisioner.SSHCertOptionModifier: - mods = append(mods, o.Option(opts)) - // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) @@ -235,16 +241,24 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi // validate the given SSHOptions case provisioner.SSHCertOptionsValidator: if err := o.Valid(opts); err != nil { - return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") + return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH") } default: - return nil, errs.InternalServer("signSSH: invalid extra option type %T", o) + return nil, errs.InternalServer("authority.SignSSH: invalid extra option type %T", o) } } + // Simulated certificate request with request options. + cr := sshutil.CertificateRequest{ + Type: certType, + KeyID: opts.KeyID, + Principals: opts.Principals, + Key: key, + } + // Create certificate from template. - certificate, err := sshutil.NewCertificate(key, certOptions...) + certificate, err := sshutil.NewCertificate(cr, certOptions...) if err != nil { if _, ok := err.(*sshutil.TemplateError); ok { return nil, errs.NewErr(http.StatusBadRequest, err, @@ -255,19 +269,19 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH") } - // Get actual *ssh.Certificate and continue with user and provisioner - // modifiers. + // Get actual *ssh.Certificate and continue with provisioner modifiers. cert := certificate.GetCertificate() - // Use SignSSHOptions to modify the certificate. - if err := opts.Modify(cert); err != nil { - return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") + // Use SignSSHOptions to modify the certificate validity. It will be later + // checked or set if not defined. + if err := opts.ModifyValidity(cert); err != nil { + return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") } // Use provisioner modifiers. for _, m := range mods { - if err := m.Modify(cert); err != nil { - return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") + if err := m.Modify(cert, opts); err != nil { + return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH") } } @@ -276,33 +290,33 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi switch cert.CertType { case ssh.UserCert: if a.sshCAUserCertSignKey == nil { - return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled") + return nil, errs.NotImplemented("authority.SignSSH: user certificate signing is not enabled") } signer = a.sshCAUserCertSignKey case ssh.HostCert: if a.sshCAHostCertSignKey == nil { - return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled") + return nil, errs.NotImplemented("authority.SignSSH: host certificate signing is not enabled") } signer = a.sshCAHostCertSignKey default: - return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType) + return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", cert.CertType) } // Sign certificate. cert, err = sshutil.CreateCertificate(cert, signer) if err != nil { - return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error signing certificate") } // User provisioners validators. for _, v := range validators { if err := v.Valid(cert, opts); err != nil { - return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") + return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH") } } if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented { - return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db") + return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db") } return cert, nil