Fix validity check for sshpop provisioner.

This commit is contained in:
Mariano Cano 2022-03-14 17:31:21 -07:00
parent c903f00cd4
commit 6d532045dc
2 changed files with 16 additions and 12 deletions

View file

@ -104,7 +104,7 @@ func (p *SSHPOP) Init(config Config) (err error) {
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
// //
// Checking for certificate revocation has been moved to the authority package. // Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
sshCert, jwt, err := ExtractSSHPOPCert(token) sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, return nil, errs.Wrap(http.StatusUnauthorized, err,
@ -112,13 +112,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
} }
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() //
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { // Controller.AuthorizeSSHRenew will validate this on the renewal flow.
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") if checkValidity {
} unixNow := time.Now().Unix()
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
} }
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok { if !ok {
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
@ -181,7 +186,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// AuthorizeSSHRevoke validates the authorization token and extracts/validates // AuthorizeSSHRevoke validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
@ -194,7 +199,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// AuthorizeSSHRenew validates the authorization token and extracts/validates // AuthorizeSSHRenew validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
@ -207,7 +212,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
// AuthorizeSSHRekey validates the authorization token and extracts/validates // AuthorizeSSHRekey validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
if err != nil { if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
@ -222,7 +227,6 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
}, nil }, nil
} }
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate

View file

@ -214,7 +214,7 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)