diff --git a/api/ssh.go b/api/ssh.go index f3056fc5..4bd20495 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -288,6 +288,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) a := mustAuthority(ctx) signOpts, err := a.Authorize(ctx, body.OTT) diff --git a/api/sshRekey.go b/api/sshRekey.go index 184f208a..6c0a5064 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -59,6 +59,7 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) a := mustAuthority(ctx) signOpts, err := a.Authorize(ctx, body.OTT) diff --git a/api/sshRenew.go b/api/sshRenew.go index 606b45bb..4e4d0b04 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -51,6 +51,8 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) + ctx = provisioner.NewContextWithToken(ctx, body.OTT) + a := mustAuthority(ctx) _, err := a.Authorize(ctx, body.OTT) if err != nil { diff --git a/authority/authorize.go b/authority/authorize.go index 56b53658..91f1b3cb 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" + "fmt" "net/http" "net/url" "strconv" @@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool { return m } -// authorizeToken parses the token and returns the provisioner used to generate -// the token. This method enforces the One-Time use policy (tokens can only be -// used once). -func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { - // Validate payload +// getProvisionerFromToken extracts a provisioner from the given token without +// doing any token validation. +func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) { tok, err := jose.ParseSigned(token) if err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token") + return nil, nil, fmt.Errorf("error parsing token: %w", err) } // Get claims w/out verification. We need to look up the provisioner @@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // before we can look up the provisioner. var claims Claims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { - return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") + return nil, nil, fmt.Errorf("error unmarshaling token: %w", err) + } + + // This method will also validate the audiences for JWK provisioners. + p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) + if !ok { + return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) + } + + return p, &claims, nil +} + +// authorizeToken parses the token and returns the provisioner used to generate +// the token. This method enforces the One-Time use policy (tokens can only be +// used once). +func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) { + p, claims, err := a.getProvisionerFromToken(token) + if err != nil { + return nil, errs.UnauthorizedErr(err) } // TODO: use new persistence layer abstraction. @@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision // This check is meant as a stopgap solution to the current lack of a persistence layer. if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { - return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority") + return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority") } } - // This method will also validate the audiences for JWK provisioners. - p, ok := a.provisioners.LoadByToken(tok, &claims.Claims) - if !ok { - return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+ - "not found or invalid audience (%s)", strings.Join(claims.Audience, ", ")) - } - // Store the token to protect against reuse unless it's skipped. // If we cannot get a token id from the provisioner, just hash the token. if !SkipTokenReuseFromContext(ctx) { @@ -188,11 +198,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error { } ok, err := a.db.UseToken(reuseKey, token) if err != nil { - return errs.Wrap(http.StatusInternalServerError, err, - "authority.authorizeToken: failed when attempting to store token") + return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token") } if !ok { - return errs.Unauthorized("authority.authorizeToken: token already used") + return errs.Unauthorized("token already used") } } return nil diff --git a/authority/authorize_test.go b/authority/authorize_test.go index b221d0de..af80d3d3 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeToken: error parsing token"), + err: errors.New("error parsing token"), code: http.StatusUnauthorized, } }, @@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: raw, - err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"), + err: errors.New("token issued before the bootstrap of certificate authority"), code: http.StatusUnauthorized, } }, @@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: a, token: raw, - err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"), + err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"), code: http.StatusUnauthorized, } }, @@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: failed when attempting to store token: force"), + err: errors.New("failed when attempting to store token: force"), code: http.StatusInternalServerError, } }, @@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) { return &authorizeTest{ auth: _a, token: raw, - err: errors.New("authority.authorizeToken: token already used"), + err: errors.New("token already used"), code: http.StatusUnauthorized, } }, @@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: context.Background(), - err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), - err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), - err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, @@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) { auth: a, token: "foo", ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), - err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, @@ -988,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHSign: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1082,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRenew: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1190,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRevoke: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1282,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { return &authorizeTest{ auth: a, token: "foo", - err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), + err: errors.New("authority.authorizeSSHRekey: error parsing token"), code: http.StatusUnauthorized, } }, @@ -1345,7 +1345,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) { } else { if assert.Nil(t, tc.err) { assert.Equals(t, tc.cert.Serial, cert.Serial) - assert.Len(t, 3, signOpts) + assert.Len(t, 4, signOpts) } } }) diff --git a/authority/linkedca.go b/authority/linkedca.go index fd5c0a81..0b98f877 100644 --- a/authority/linkedca.go +++ b/authority/linkedca.go @@ -270,13 +270,13 @@ func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, }, nil } -func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error { +func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificateChain: serializeCertificateChain(fullchain[1:]...), - Provisioner: createProvisionerIdentity(prov), + Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting certificate") } @@ -292,22 +292,23 @@ func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullc return errors.Wrap(err, "error posting renewed certificate") } -func (c *linkedCaClient) StoreSSHCertificate(prov provisioner.Interface, crt *ssh.Certificate) error { +func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ Certificate: string(ssh.MarshalAuthorizedKey(crt)), - Provisioner: createProvisionerIdentity(prov), + Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting ssh certificate") } -func (c *linkedCaClient) StoreRenewedSSHCertificate(parent, crt *ssh.Certificate) error { +func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ Certificate: string(ssh.MarshalAuthorizedKey(crt)), ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)), + Provisioner: createProvisionerIdentity(p), }) return errors.Wrap(err, "error posting renewed ssh certificate") } @@ -380,14 +381,14 @@ func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error { return errors.New("not implemented yet") } -func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity { - if prov == nil { +func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity { + if p == nil { return nil } return &linkedca.ProvisionerIdentity{ - Id: prov.GetID(), - Type: linkedca.Provisioner_Type(prov.GetType()), - Name: prov.GetName(), + Id: p.GetID(), + Type: linkedca.Provisioner_Type(p.GetType()), + Name: p.GetName(), } } diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index f5cd5221..01dda2ed 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method { m, _ := ctx.Value(methodKey{}).(Method) return m } + +type tokenKey struct{} + +// NewContextWithToken creates a new context with the given token. +func NewContextWithToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, tokenKey{}, token) +} + +// TokenFromContext returns the token stored in the given context. +func TokenFromContext(ctx context.Context) (string, bool) { + token, ok := ctx.Value(tokenKey{}).(string) + return token, ok +} diff --git a/authority/provisioner/sshpop.go b/authority/provisioner/sshpop.go index c3a1a639..c0246729 100644 --- a/authority/provisioner/sshpop.go +++ b/authority/provisioner/sshpop.go @@ -222,6 +222,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") } return claims.sshCert, []SignOption{ + p, // Validate public key &sshDefaultPublicKeyValidator{}, // Validate the validity period. diff --git a/authority/provisioner/sshpop_test.go b/authority/provisioner/sshpop_test.go index 13294866..1e026883 100644 --- a/authority/provisioner/sshpop_test.go +++ b/authority/provisioner/sshpop_test.go @@ -459,9 +459,10 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Len(t, 3, opts) + assert.Len(t, 4, opts) for _, o := range opts { switch v := o.(type) { + case Interface: case *sshDefaultPublicKeyValidator: case *sshCertDefaultValidator: case *sshCertValidityValidator: diff --git a/authority/ssh.go b/authority/ssh.go index bb69f9db..1fd7f2e8 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -303,6 +303,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, err } + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + backdate := a.config.AuthorityConfig.Backdate.Duration duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second now := time.Now() @@ -345,7 +351,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") } - if err = a.storeRenewedSSHCertificate(oldCert, cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") } @@ -356,8 +362,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var validators []provisioner.SSHCertValidator + var prov provisioner.Interface for _, op := range signOpts { switch o := op.(type) { + // Capture current provisioner + case provisioner.Interface: + prov = o // validate the ssh.Certificate case provisioner.SSHCertValidator: validators = append(validators, o) @@ -424,7 +434,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub } } - if err = a.storeRenewedSSHCertificate(oldCert, cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") } @@ -455,15 +465,15 @@ func (a *Authority) storeSSHCertificate(prov provisioner.Interface, cert *ssh.Ce } } -func (a *Authority) storeRenewedSSHCertificate(parent, cert *ssh.Certificate) error { +func (a *Authority) storeRenewedSSHCertificate(prov provisioner.Interface, parent, cert *ssh.Certificate) error { type sshRenewerCertificateStorer interface { - StoreRenewedSSHCertificate(parent, cert *ssh.Certificate) error + StoreRenewedSSHCertificate(p provisioner.Interface, parent, cert *ssh.Certificate) error } // Store certificate in admindb or linkedca switch s := a.adminDB.(type) { case sshRenewerCertificateStorer: - return s.StoreRenewedSSHCertificate(parent, cert) + return s.StoreRenewedSSHCertificate(prov, parent, cert) case db.CertificateStorer: return s.StoreSSHCertificate(cert) } @@ -471,7 +481,7 @@ func (a *Authority) storeRenewedSSHCertificate(parent, cert *ssh.Certificate) er // Store certificate in localdb switch s := a.db.(type) { case sshRenewerCertificateStorer: - return s.StoreRenewedSSHCertificate(parent, cert) + return s.StoreRenewedSSHCertificate(prov, parent, cert) case db.CertificateStorer: return s.StoreSSHCertificate(cert) default: @@ -522,6 +532,12 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number") } + // Attempt to extract the provisioner from the token. + var prov provisioner.Interface + if token, ok := provisioner.TokenFromContext(ctx); ok { + prov, _, _ = a.getProvisionerFromToken(token) + } + signer := a.sshCAUserCertSignKey principal := subject.ValidPrincipals[0] addUserPrincipal := a.getAddUserPrincipal() @@ -554,7 +570,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje } cert.Signature = sig - if err = a.storeRenewedSSHCertificate(subject, cert); err != nil && err != db.ErrNotImplemented { + if err = a.storeRenewedSSHCertificate(prov, subject, cert); err != nil && err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") }