Add backdate validation to sshCertValidityValidator.

This commit is contained in:
max furman 2020-01-24 13:42:00 -08:00
parent 3d6a18180e
commit 397a181d10
7 changed files with 47 additions and 23 deletions

View file

@ -290,7 +290,7 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o Options) error {
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v",
d, v.max) d, v.max+o.Backdate)
} }
return nil return nil
} }

View file

@ -36,7 +36,7 @@ type SSHCertOptionModifier interface {
// SSHCertValidator is the interface used to validate an SSH certificate. // SSHCertValidator is the interface used to validate an SSH certificate.
type SSHCertValidator interface { type SSHCertValidator interface {
SignOption SignOption
Valid(cert *ssh.Certificate) error Valid(cert *ssh.Certificate, opts SSHOptions) error
} }
// SSHCertOptionsValidator is the interface used to validate the custom // SSHCertOptionsValidator is the interface used to validate the custom
@ -310,7 +310,7 @@ type sshCertValidityValidator struct {
*Claimer *Claimer
} }
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SSHOptions) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return errors.New("ssh certificate validAfter cannot be 0")
@ -336,20 +336,15 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error {
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
// calculate the duration if ValidAfter is in the past. // calculate the duration if ValidAfter is in the past.
var dur time.Duration dur := time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
if t := now().Unix(); t > int64(cert.ValidAfter) {
dur = time.Duration(int64(cert.ValidBefore)-t) * time.Second
} else {
dur = time.Duration(cert.ValidBefore-cert.ValidAfter) * time.Second
}
switch { switch {
case dur < min: case dur < min:
return errors.Errorf("requested duration of %s is less than minimum "+ return errors.Errorf("requested duration of %s is less than minimum "+
"accepted duration for selected provisioner of %s", dur, min) "accepted duration for selected provisioner of %s", dur, min)
case dur > max: case dur > max+opts.Backdate:
return errors.Errorf("requested duration of %s is greater than maximum "+ return errors.Errorf("requested duration of %s is greater than maximum "+
"accepted duration for selected provisioner of %s", dur, max) "accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }
@ -360,7 +355,7 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate) error {
type sshCertDefaultValidator struct{} type sshCertDefaultValidator struct{}
// Valid returns an error if the given certificate does not contain the necessary fields. // Valid returns an error if the given certificate does not contain the necessary fields.
func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate) error { func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
switch { switch {
case len(cert.Nonce) == 0: case len(cert.Nonce) == 0:
return errors.New("ssh certificate nonce cannot be empty") return errors.New("ssh certificate nonce cannot be empty")
@ -395,7 +390,7 @@ func (v *sshCertDefaultValidator) Valid(cert *ssh.Certificate) error {
type sshDefaultPublicKeyValidator struct{} type sshDefaultPublicKeyValidator struct{}
// Valid checks that certificate request common name matches the one configured. // Valid checks that certificate request common name matches the one configured.
func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error { func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
if cert.Key == nil { if cert.Key == nil {
return errors.New("ssh certificate key cannot be nil") return errors.New("ssh certificate key cannot be nil")
} }
@ -425,7 +420,7 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate) error {
type sshCertKeyIDValidator string type sshCertKeyIDValidator string
// Valid returns an error if the given certificate does not contain the necessary fields. // Valid returns an error if the given certificate does not contain the necessary fields.
func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate) error { func (v sshCertKeyIDValidator) Valid(cert *ssh.Certificate, o SSHOptions) error {
if string(v) != cert.KeyId { if string(v) != cert.KeyId {
return errors.Errorf("invalid ssh certificate KeyId; want %s, but got %s", string(v), cert.KeyId) return errors.Errorf("invalid ssh certificate KeyId; want %s, but got %s", string(v), cert.KeyId)
} }

View file

@ -659,7 +659,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil { if err := v.Valid(tt.cert, SSHOptions{}); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
@ -678,26 +678,31 @@ func Test_sshCertValidityValidator(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cert *ssh.Certificate cert *ssh.Certificate
opts SSHOptions
err error err error
}{ }{
{ {
"fail/validAfter-0", "fail/validAfter-0",
&ssh.Certificate{CertType: ssh.UserCert}, &ssh.Certificate{CertType: ssh.UserCert},
SSHOptions{},
errors.New("ssh certificate validAfter cannot be 0"), errors.New("ssh certificate validAfter cannot be 0"),
}, },
{ {
"fail/validBefore-in-past", "fail/validBefore-in-past",
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate validBefore cannot be in the past"), errors.New("ssh certificate validBefore cannot be in the past"),
}, },
{ {
"fail/validBefore-before-validAfter", "fail/validBefore-before-validAfter",
&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate validBefore cannot be before validAfter"), errors.New("ssh certificate validBefore cannot be before validAfter"),
}, },
{ {
"fail/cert-type-not-set", "fail/cert-type-not-set",
&ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())}, &ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())},
SSHOptions{},
errors.New("ssh certificate type has not been set"), errors.New("ssh certificate type has not been set"),
}, },
{ {
@ -707,6 +712,7 @@ func Test_sshCertValidityValidator(t *testing.T) {
ValidAfter: uint64(now().Unix()), ValidAfter: uint64(now().Unix()),
ValidBefore: uint64(now().Add(10 * time.Minute).Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
}, },
SSHOptions{},
errors.New("unknown ssh certificate type 3"), errors.New("unknown ssh certificate type 3"),
}, },
{ {
@ -716,8 +722,19 @@ func Test_sshCertValidityValidator(t *testing.T) {
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(4 * time.Minute).Unix()), ValidBefore: uint64(n.Add(4 * time.Minute).Unix()),
}, },
SSHOptions{Backdate: time.Second},
errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"), errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"),
}, },
{
"ok/duration-exactly-min",
&ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(5 * time.Minute).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
},
{ {
"fail/duration>max", "fail/duration>max",
&ssh.Certificate{ &ssh.Certificate{
@ -725,7 +742,18 @@ func Test_sshCertValidityValidator(t *testing.T) {
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(48 * time.Hour).Unix()), ValidBefore: uint64(n.Add(48 * time.Hour).Unix()),
}, },
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m0s"), SSHOptions{Backdate: time.Second},
errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"),
},
{
"ok/duration-exactly-max",
&ssh.Certificate{
CertType: 1,
ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()),
},
SSHOptions{Backdate: time.Second},
nil,
}, },
{ {
"ok", "ok",
@ -734,12 +762,13 @@ func Test_sshCertValidityValidator(t *testing.T) {
ValidAfter: uint64(now().Unix()), ValidAfter: uint64(now().Unix()),
ValidBefore: uint64(now().Add(8 * time.Hour).Unix()), ValidBefore: uint64(now().Add(8 * time.Hour).Unix()),
}, },
SSHOptions{Backdate: time.Second},
nil, nil,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := v.Valid(tt.cert); err != nil { if err := v.Valid(tt.cert, tt.opts); err != nil {
if assert.NotNil(t, tt.err) { if assert.NotNil(t, tt.err) {
assert.HasPrefix(t, err.Error(), tt.err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }

View file

@ -116,7 +116,7 @@ func signSSHCertificate(key crypto.PublicKey, opts SSHOptions, signOpts []SignOp
// User provisioners validators // User provisioners validators
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert, opts); err != nil {
return nil, err return nil, err
} }
} }

View file

@ -269,7 +269,7 @@ func (a *Authority) SignSSH(key ssh.PublicKey, opts provisioner.SSHOptions, sign
// User provisioners validators // User provisioners validators
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert, opts); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "signSSH") return nil, errs.Wrap(http.StatusForbidden, err, "signSSH")
} }
} }
@ -428,9 +428,9 @@ func (a *Authority) RekeySSH(oldCert *ssh.Certificate, pub ssh.PublicKey, signOp
} }
cert.Signature = sig cert.Signature = sig
// Apply validators from provisioner.. // Apply validators from provisioner.
for _, v := range validators { for _, v := range validators {
if err := v.Valid(cert); err != nil { if err := v.Valid(cert, provisioner.SSHOptions{Backdate: backdate}); err != nil {
return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH") return nil, errs.Wrap(http.StatusForbidden, err, "rekeySSH")
} }
} }

View file

@ -62,7 +62,7 @@ func (m sshTestCertModifier) Modify(cert *ssh.Certificate) error {
type sshTestCertValidator string type sshTestCertValidator string
func (v sshTestCertValidator) Valid(crt *ssh.Certificate) error { func (v sshTestCertValidator) Valid(crt *ssh.Certificate, opts provisioner.SSHOptions) error {
if v == "" { if v == "" {
return nil return nil
} }

View file

@ -178,7 +178,7 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: _signOpts, signOpts: _signOpts,
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h0m0s"), err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },