forked from TrueCloudLab/certificates
Simplify SSH modifiers with options.
It also changes the behavior of the request options to modify only the validity of the certificate.
This commit is contained in:
parent
df1f7e5a2e
commit
c1fc45c872
2 changed files with 103 additions and 98 deletions
|
@ -24,14 +24,7 @@ const (
|
||||||
// certificate.
|
// certificate.
|
||||||
type SSHCertModifier interface {
|
type SSHCertModifier interface {
|
||||||
SignOption
|
SignOption
|
||||||
Modify(cert *ssh.Certificate) error
|
Modify(cert *ssh.Certificate, opts SignSSHOptions) error
|
||||||
}
|
|
||||||
|
|
||||||
// SSHCertOptionModifier is the interface used to add custom options used
|
|
||||||
// to modify the SSH certificate.
|
|
||||||
type SSHCertOptionModifier interface {
|
|
||||||
SignOption
|
|
||||||
Option(o SignSSHOptions) SSHCertModifier
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCertValidator is the interface used to validate an SSH certificate.
|
// SSHCertValidator is the interface used to validate an SSH certificate.
|
||||||
|
@ -47,14 +40,6 @@ type SSHCertOptionsValidator interface {
|
||||||
Valid(got SignSSHOptions) error
|
Valid(got SignSSHOptions) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// sshModifierFunc is an adapter to allow the use of ordinary functions as SSH
|
|
||||||
// certificate modifiers.
|
|
||||||
type sshModifierFunc func(cert *ssh.Certificate) error
|
|
||||||
|
|
||||||
func (f sshModifierFunc) Modify(cert *ssh.Certificate) error {
|
|
||||||
return f(cert)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignSSHOptions contains the options that can be passed to the SignSSH method.
|
// SignSSHOptions contains the options that can be passed to the SignSSH method.
|
||||||
type SignSSHOptions struct {
|
type SignSSHOptions struct {
|
||||||
CertType string `json:"certType"`
|
CertType string `json:"certType"`
|
||||||
|
@ -72,7 +57,7 @@ func (o SignSSHOptions) Type() uint32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
|
// Modify implements SSHCertModifier and sets the SSHOption in the ssh.Certificate.
|
||||||
func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
|
func (o SignSSHOptions) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
switch o.CertType {
|
switch o.CertType {
|
||||||
case "": // ignore
|
case "": // ignore
|
||||||
case SSHUserCert:
|
case SSHUserCert:
|
||||||
|
@ -86,6 +71,12 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
|
||||||
cert.KeyId = o.KeyID
|
cert.KeyId = o.KeyID
|
||||||
cert.ValidPrincipals = o.Principals
|
cert.ValidPrincipals = o.Principals
|
||||||
|
|
||||||
|
return o.ModifyValidity(cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyValidity modifies only the ValidAfter and ValidBefore on the given
|
||||||
|
// ssh.Certificate.
|
||||||
|
func (o SignSSHOptions) ModifyValidity(cert *ssh.Certificate) error {
|
||||||
t := now()
|
t := now()
|
||||||
if !o.ValidAfter.IsZero() {
|
if !o.ValidAfter.IsZero() {
|
||||||
cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix())
|
cert.ValidAfter = uint64(o.ValidAfter.RelativeTime(t).Unix())
|
||||||
|
@ -96,7 +87,6 @@ func (o SignSSHOptions) Modify(cert *ssh.Certificate) error {
|
||||||
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
if cert.ValidAfter > 0 && cert.ValidBefore > 0 && cert.ValidAfter > cert.ValidBefore {
|
||||||
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
return errors.New("ssh certificate valid after cannot be greater than valid before")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,7 +113,7 @@ func (o SignSSHOptions) match(got SignSSHOptions) error {
|
||||||
type sshCertPrincipalsModifier []string
|
type sshCertPrincipalsModifier []string
|
||||||
|
|
||||||
// Modify the ValidPrincipals value of the cert.
|
// Modify the ValidPrincipals value of the cert.
|
||||||
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
cert.ValidPrincipals = []string(o)
|
cert.ValidPrincipals = []string(o)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -132,7 +122,7 @@ func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error {
|
||||||
// Key ID in the SSH certificate.
|
// Key ID in the SSH certificate.
|
||||||
type sshCertKeyIDModifier string
|
type sshCertKeyIDModifier string
|
||||||
|
|
||||||
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
cert.KeyId = string(m)
|
cert.KeyId = string(m)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -142,7 +132,7 @@ func (m sshCertKeyIDModifier) Modify(cert *ssh.Certificate) error {
|
||||||
type sshCertTypeModifier string
|
type sshCertTypeModifier string
|
||||||
|
|
||||||
// Modify sets the CertType for the ssh certificate.
|
// Modify sets the CertType for the ssh certificate.
|
||||||
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
func (m sshCertTypeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
cert.CertType = sshCertTypeUInt32(string(m))
|
cert.CertType = sshCertTypeUInt32(string(m))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -151,7 +141,7 @@ func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error {
|
||||||
// ValidAfter in the SSH certificate.
|
// ValidAfter in the SSH certificate.
|
||||||
type sshCertValidAfterModifier uint64
|
type sshCertValidAfterModifier uint64
|
||||||
|
|
||||||
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
cert.ValidAfter = uint64(m)
|
cert.ValidAfter = uint64(m)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -160,7 +150,7 @@ func (m sshCertValidAfterModifier) Modify(cert *ssh.Certificate) error {
|
||||||
// ValidBefore in the SSH certificate.
|
// ValidBefore in the SSH certificate.
|
||||||
type sshCertValidBeforeModifier uint64
|
type sshCertValidBeforeModifier uint64
|
||||||
|
|
||||||
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate) error {
|
func (m sshCertValidBeforeModifier) Modify(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||||
cert.ValidBefore = uint64(m)
|
cert.ValidBefore = uint64(m)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -217,27 +207,27 @@ type sshDefaultDuration struct {
|
||||||
*Claimer
|
*Claimer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *sshDefaultDuration) Option(o SignSSHOptions) SSHCertModifier {
|
// Modify implements SSHCertModifier and sets the validity if it has not been
|
||||||
return sshModifierFunc(func(cert *ssh.Certificate) error {
|
// set, but it always applies the backdate.
|
||||||
d, err := m.DefaultSSHCertDuration(cert.CertType)
|
func (m *sshDefaultDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error {
|
||||||
if err != nil {
|
d, err := m.DefaultSSHCertDuration(cert.CertType)
|
||||||
return err
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var backdate uint64
|
var backdate uint64
|
||||||
if cert.ValidAfter == 0 {
|
if cert.ValidAfter == 0 {
|
||||||
backdate = uint64(o.Backdate / time.Second)
|
backdate = uint64(o.Backdate / time.Second)
|
||||||
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
|
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
|
||||||
}
|
}
|
||||||
if cert.ValidBefore == 0 {
|
if cert.ValidBefore == 0 {
|
||||||
cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second)
|
cert.ValidBefore = cert.ValidAfter + uint64(d/time.Second)
|
||||||
}
|
}
|
||||||
// Apply backdate safely
|
// Apply backdate safely
|
||||||
if cert.ValidAfter > backdate {
|
if cert.ValidAfter > backdate {
|
||||||
cert.ValidAfter -= backdate
|
cert.ValidAfter -= backdate
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sshLimitDuration adjusts the duration to min(default, remaining provisioning
|
// sshLimitDuration adjusts the duration to min(default, remaining provisioning
|
||||||
|
@ -250,51 +240,52 @@ type sshLimitDuration struct {
|
||||||
NotAfter time.Time
|
NotAfter time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *sshLimitDuration) Option(o SignSSHOptions) SSHCertModifier {
|
// Modify implements SSHCertModifier and modifies the validity of the
|
||||||
|
// certificate to expire before the configured limit.
|
||||||
|
func (m *sshLimitDuration) Modify(cert *ssh.Certificate, o SignSSHOptions) error {
|
||||||
if m.NotAfter.IsZero() {
|
if m.NotAfter.IsZero() {
|
||||||
defaultDuration := &sshDefaultDuration{m.Claimer}
|
defaultDuration := &sshDefaultDuration{m.Claimer}
|
||||||
return defaultDuration.Option(o)
|
return defaultDuration.Modify(cert, o)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sshModifierFunc(func(cert *ssh.Certificate) error {
|
// Make sure the duration is within the limits.
|
||||||
d, err := m.DefaultSSHCertDuration(cert.CertType)
|
d, err := m.DefaultSSHCertDuration(cert.CertType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var backdate uint64
|
var backdate uint64
|
||||||
if cert.ValidAfter == 0 {
|
if cert.ValidAfter == 0 {
|
||||||
backdate = uint64(o.Backdate / time.Second)
|
backdate = uint64(o.Backdate / time.Second)
|
||||||
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
|
cert.ValidAfter = uint64(now().Truncate(time.Second).Unix())
|
||||||
}
|
}
|
||||||
|
|
||||||
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
|
certValidAfter := time.Unix(int64(cert.ValidAfter), 0)
|
||||||
if certValidAfter.After(m.NotAfter) {
|
if certValidAfter.After(m.NotAfter) {
|
||||||
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
|
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validAfter (%s)",
|
||||||
m.NotAfter, certValidAfter)
|
m.NotAfter, certValidAfter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cert.ValidBefore == 0 {
|
if cert.ValidBefore == 0 {
|
||||||
certValidBefore := certValidAfter.Add(d)
|
certValidBefore := certValidAfter.Add(d)
|
||||||
if m.NotAfter.Before(certValidBefore) {
|
if m.NotAfter.Before(certValidBefore) {
|
||||||
certValidBefore = m.NotAfter
|
certValidBefore = m.NotAfter
|
||||||
}
|
|
||||||
cert.ValidBefore = uint64(certValidBefore.Unix())
|
|
||||||
} else {
|
|
||||||
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
|
|
||||||
if m.NotAfter.Before(certValidBefore) {
|
|
||||||
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
|
|
||||||
m.NotAfter, certValidBefore)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
cert.ValidBefore = uint64(certValidBefore.Unix())
|
||||||
// Apply backdate safely
|
} else {
|
||||||
if cert.ValidAfter > backdate {
|
certValidBefore := time.Unix(int64(cert.ValidBefore), 0)
|
||||||
cert.ValidAfter -= backdate
|
if m.NotAfter.Before(certValidBefore) {
|
||||||
|
return errors.Errorf("provisioning credential expiration (%s) is before requested certificate validBefore (%s)",
|
||||||
|
m.NotAfter, certValidBefore)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
// Apply backdate safely
|
||||||
})
|
if cert.ValidAfter > backdate {
|
||||||
|
cert.ValidAfter -= backdate
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sshCertOptionsValidator validates the user SSHOptions with the ones
|
// sshCertOptionsValidator validates the user SSHOptions with the ones
|
||||||
|
|
|
@ -206,6 +206,8 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname str
|
||||||
// SignSSH creates a signed SSH certificate with the given public key and options.
|
// SignSSH creates a signed SSH certificate with the given public key and options.
|
||||||
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
|
||||||
var (
|
var (
|
||||||
|
err error
|
||||||
|
certType sshutil.CertType
|
||||||
certOptions []sshutil.Option
|
certOptions []sshutil.Option
|
||||||
mods []provisioner.SSHCertModifier
|
mods []provisioner.SSHCertModifier
|
||||||
validators []provisioner.SSHCertValidator
|
validators []provisioner.SSHCertValidator
|
||||||
|
@ -214,6 +216,14 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
// Set backdate with the configured value
|
// Set backdate with the configured value
|
||||||
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
|
opts.Backdate = a.config.AuthorityConfig.Backdate.Duration
|
||||||
|
|
||||||
|
// Validate certificate type.
|
||||||
|
if opts.CertType != "" {
|
||||||
|
certType, err = sshutil.CertTypeFromString(opts.CertType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, op := range signOpts {
|
for _, op := range signOpts {
|
||||||
switch o := op.(type) {
|
switch o := op.(type) {
|
||||||
// add options to NewCertificate
|
// add options to NewCertificate
|
||||||
|
@ -224,10 +234,6 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
case provisioner.SSHCertModifier:
|
case provisioner.SSHCertModifier:
|
||||||
mods = append(mods, o)
|
mods = append(mods, o)
|
||||||
|
|
||||||
// modify the ssh.Certificate given the SSHOptions
|
|
||||||
case provisioner.SSHCertOptionModifier:
|
|
||||||
mods = append(mods, o.Option(opts))
|
|
||||||
|
|
||||||
// validate the ssh.Certificate
|
// validate the ssh.Certificate
|
||||||
case provisioner.SSHCertValidator:
|
case provisioner.SSHCertValidator:
|
||||||
validators = append(validators, o)
|
validators = append(validators, o)
|
||||||
|
@ -235,16 +241,24 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
// validate the given SSHOptions
|
// validate the given SSHOptions
|
||||||
case provisioner.SSHCertOptionsValidator:
|
case provisioner.SSHCertOptionsValidator:
|
||||||
if err := o.Valid(opts); err != nil {
|
if err := o.Valid(opts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, errs.InternalServer("signSSH: invalid extra option type %T", o)
|
return nil, errs.InternalServer("authority.SignSSH: invalid extra option type %T", o)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simulated certificate request with request options.
|
||||||
|
cr := sshutil.CertificateRequest{
|
||||||
|
Type: certType,
|
||||||
|
KeyID: opts.KeyID,
|
||||||
|
Principals: opts.Principals,
|
||||||
|
Key: key,
|
||||||
|
}
|
||||||
|
|
||||||
// Create certificate from template.
|
// Create certificate from template.
|
||||||
certificate, err := sshutil.NewCertificate(key, certOptions...)
|
certificate, err := sshutil.NewCertificate(cr, certOptions...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(*sshutil.TemplateError); ok {
|
if _, ok := err.(*sshutil.TemplateError); ok {
|
||||||
return nil, errs.NewErr(http.StatusBadRequest, err,
|
return nil, errs.NewErr(http.StatusBadRequest, err,
|
||||||
|
@ -255,19 +269,19 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get actual *ssh.Certificate and continue with user and provisioner
|
// Get actual *ssh.Certificate and continue with provisioner modifiers.
|
||||||
// modifiers.
|
|
||||||
cert := certificate.GetCertificate()
|
cert := certificate.GetCertificate()
|
||||||
|
|
||||||
// Use SignSSHOptions to modify the certificate.
|
// Use SignSSHOptions to modify the certificate validity. It will be later
|
||||||
if err := opts.Modify(cert); err != nil {
|
// checked or set if not defined.
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
if err := opts.ModifyValidity(cert); err != nil {
|
||||||
|
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use provisioner modifiers.
|
// Use provisioner modifiers.
|
||||||
for _, m := range mods {
|
for _, m := range mods {
|
||||||
if err := m.Modify(cert); err != nil {
|
if err := m.Modify(cert, opts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -276,33 +290,33 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
||||||
switch cert.CertType {
|
switch cert.CertType {
|
||||||
case ssh.UserCert:
|
case ssh.UserCert:
|
||||||
if a.sshCAUserCertSignKey == nil {
|
if a.sshCAUserCertSignKey == nil {
|
||||||
return nil, errs.NotImplemented("signSSH: user certificate signing is not enabled")
|
return nil, errs.NotImplemented("authority.SignSSH: user certificate signing is not enabled")
|
||||||
}
|
}
|
||||||
signer = a.sshCAUserCertSignKey
|
signer = a.sshCAUserCertSignKey
|
||||||
case ssh.HostCert:
|
case ssh.HostCert:
|
||||||
if a.sshCAHostCertSignKey == nil {
|
if a.sshCAHostCertSignKey == nil {
|
||||||
return nil, errs.NotImplemented("signSSH: host certificate signing is not enabled")
|
return nil, errs.NotImplemented("authority.SignSSH: host certificate signing is not enabled")
|
||||||
}
|
}
|
||||||
signer = a.sshCAHostCertSignKey
|
signer = a.sshCAHostCertSignKey
|
||||||
default:
|
default:
|
||||||
return nil, errs.InternalServer("signSSH: unexpected ssh certificate type: %d", cert.CertType)
|
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", cert.CertType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign certificate.
|
// Sign certificate.
|
||||||
cert, err = sshutil.CreateCertificate(cert, signer)
|
cert, err = sshutil.CreateCertificate(cert, signer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error signing certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
// User provisioners validators.
|
// User provisioners validators.
|
||||||
for _, v := range validators {
|
for _, v := range validators {
|
||||||
if err := v.Valid(cert, opts); err != nil {
|
if err := v.Valid(cert, opts); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
|
return nil, errs.Wrap(http.StatusForbidden, err, "authority.SignSSH")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
if err = a.db.StoreSSHCertificate(cert); err != nil && err != db.ErrNotImplemented {
|
||||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error storing certificate in db")
|
return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.SignSSH: error storing certificate in db")
|
||||||
}
|
}
|
||||||
|
|
||||||
return cert, nil
|
return cert, nil
|
||||||
|
|
Loading…
Reference in a new issue