Merge pull request #943 from smallstep/ssh-renew-provisioner

Add provisioner to SSH renewals
This commit is contained in:
Mariano Cano 2022-05-23 17:21:55 -07:00 committed by GitHub
commit 911cec21da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 102 additions and 57 deletions

View file

@ -288,6 +288,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx) a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)

View file

@ -59,6 +59,7 @@ func SSHRekey(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx) a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)

View file

@ -51,6 +51,8 @@ func SSHRenew(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx) a := mustAuthority(ctx)
_, err := a.Authorize(ctx, body.OTT) _, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {

View file

@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
return m return m
} }
// authorizeToken parses the token and returns the provisioner used to generate // getProvisionerFromToken extracts a provisioner from the given token without
// the token. This method enforces the One-Time use policy (tokens can only be // doing any token validation.
// used once). func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) {
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
// Validate payload
tok, err := jose.ParseSigned(token) tok, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token") return nil, nil, fmt.Errorf("error parsing token: %w", err)
} }
// Get claims w/out verification. We need to look up the provisioner // Get claims w/out verification. We need to look up the provisioner
@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// before we can look up the provisioner. // before we can look up the provisioner.
var claims Claims var claims Claims
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") return nil, nil, fmt.Errorf("error unmarshaling token: %w", err)
}
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
return p, &claims, nil
}
// authorizeToken parses the token and returns the provisioner used to generate
// the token. This method enforces the One-Time use policy (tokens can only be
// used once).
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
p, claims, err := a.getProvisionerFromToken(token)
if err != nil {
return nil, errs.UnauthorizedErr(err)
} }
// TODO: use new persistence layer abstraction. // TODO: use new persistence layer abstraction.
@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// This check is meant as a stopgap solution to the current lack of a persistence layer. // This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority") return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority")
} }
} }
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// Store the token to protect against reuse unless it's skipped. // Store the token to protect against reuse unless it's skipped.
// If we cannot get a token id from the provisioner, just hash the token. // If we cannot get a token id from the provisioner, just hash the token.
if !SkipTokenReuseFromContext(ctx) { if !SkipTokenReuseFromContext(ctx) {
@ -188,11 +198,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
} }
ok, err := a.db.UseToken(reuseKey, token) ok, err := a.db.UseToken(reuseKey, token)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
"authority.authorizeToken: failed when attempting to store token")
} }
if !ok { if !ok {
return errs.Unauthorized("authority.authorizeToken: token already used") return errs.Unauthorized("token already used")
} }
} }
return nil return nil

View file

@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeToken: error parsing token"), err: errors.New("error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"), err: errors.New("token issued before the bootstrap of certificate authority"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"), err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token already used"), err: errors.New("token already used"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token already used"), err: errors.New("token already used"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: failed when attempting to store token: force"), err: errors.New("failed when attempting to store token: force"),
code: http.StatusInternalServerError, code: http.StatusInternalServerError,
} }
}, },
@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token already used"), err: errors.New("token already used"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: context.Background(), ctx: context.Background(),
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod),
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod),
err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -988,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1082,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHRenew: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1190,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1282,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHRekey: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1345,7 +1345,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Serial, cert.Serial) assert.Equals(t, tc.cert.Serial, cert.Serial)
assert.Len(t, 3, signOpts) assert.Len(t, 4, signOpts)
} }
} }
}) })

View file

@ -270,13 +270,13 @@ func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData,
}, nil }, nil
} }
func (c *linkedCaClient) StoreCertificateChain(prov provisioner.Interface, fullchain ...*x509.Certificate) error { func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...), PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
Provisioner: createProvisionerIdentity(prov), Provisioner: createProvisionerIdentity(p),
}) })
return errors.Wrap(err, "error posting certificate") return errors.Wrap(err, "error posting certificate")
} }
@ -292,22 +292,23 @@ func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullc
return errors.Wrap(err, "error posting renewed certificate") return errors.Wrap(err, "error posting renewed certificate")
} }
func (c *linkedCaClient) StoreSSHCertificate(prov provisioner.Interface, crt *ssh.Certificate) error { func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)), Certificate: string(ssh.MarshalAuthorizedKey(crt)),
Provisioner: createProvisionerIdentity(prov), Provisioner: createProvisionerIdentity(p),
}) })
return errors.Wrap(err, "error posting ssh certificate") return errors.Wrap(err, "error posting ssh certificate")
} }
func (c *linkedCaClient) StoreRenewedSSHCertificate(parent, crt *ssh.Certificate) error { func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)), Certificate: string(ssh.MarshalAuthorizedKey(crt)),
ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)), ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)),
Provisioner: createProvisionerIdentity(p),
}) })
return errors.Wrap(err, "error posting renewed ssh certificate") return errors.Wrap(err, "error posting renewed ssh certificate")
} }
@ -380,14 +381,14 @@ func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet") return errors.New("not implemented yet")
} }
func createProvisionerIdentity(prov provisioner.Interface) *linkedca.ProvisionerIdentity { func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity {
if prov == nil { if p == nil {
return nil return nil
} }
return &linkedca.ProvisionerIdentity{ return &linkedca.ProvisionerIdentity{
Id: prov.GetID(), Id: p.GetID(),
Type: linkedca.Provisioner_Type(prov.GetType()), Type: linkedca.Provisioner_Type(p.GetType()),
Name: prov.GetName(), Name: p.GetName(),
} }
} }

View file

@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method {
m, _ := ctx.Value(methodKey{}).(Method) m, _ := ctx.Value(methodKey{}).(Method)
return m return m
} }
type tokenKey struct{}
// NewContextWithToken creates a new context with the given token.
func NewContextWithToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, tokenKey{}, token)
}
// TokenFromContext returns the token stored in the given context.
func TokenFromContext(ctx context.Context) (string, bool) {
token, ok := ctx.Value(tokenKey{}).(string)
return token, ok
}

View file

@ -222,6 +222,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, []SignOption{ return claims.sshCert, []SignOption{
p,
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.

View file

@ -459,9 +459,10 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Len(t, 3, opts) assert.Len(t, 4, opts)
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case Interface:
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshCertValidityValidator: case *sshCertValidityValidator:

View file

@ -303,6 +303,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
return nil, err return nil, err
} }
// Attempt to extract the provisioner from the token.
var prov provisioner.Interface
if token, ok := provisioner.TokenFromContext(ctx); ok {
prov, _, _ = a.getProvisionerFromToken(token)
}
backdate := a.config.AuthorityConfig.Backdate.Duration backdate := a.config.AuthorityConfig.Backdate.Duration
duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second duration := time.Duration(oldCert.ValidBefore-oldCert.ValidAfter) * time.Second
now := time.Now() now := time.Now()
@ -345,7 +351,7 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate") return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSH: error signing certificate")
} }
if err = a.storeRenewedSSHCertificate(oldCert, cert); err != nil && err != db.ErrNotImplemented { if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "renewSSH: error storing certificate in db")
} }
@ -356,8 +362,12 @@ func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ss
func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub ssh.PublicKey, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var validators []provisioner.SSHCertValidator var validators []provisioner.SSHCertValidator
var prov provisioner.Interface
for _, op := range signOpts { for _, op := range signOpts {
switch o := op.(type) { switch o := op.(type) {
// Capture current provisioner
case provisioner.Interface:
prov = o
// validate the ssh.Certificate // validate the ssh.Certificate
case provisioner.SSHCertValidator: case provisioner.SSHCertValidator:
validators = append(validators, o) validators = append(validators, o)
@ -424,7 +434,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
} }
} }
if err = a.storeRenewedSSHCertificate(oldCert, cert); err != nil && err != db.ErrNotImplemented { if err = a.storeRenewedSSHCertificate(prov, oldCert, cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "rekeySSH; error storing certificate in db")
} }
@ -455,15 +465,15 @@ func (a *Authority) storeSSHCertificate(prov provisioner.Interface, cert *ssh.Ce
} }
} }
func (a *Authority) storeRenewedSSHCertificate(parent, cert *ssh.Certificate) error { func (a *Authority) storeRenewedSSHCertificate(prov provisioner.Interface, parent, cert *ssh.Certificate) error {
type sshRenewerCertificateStorer interface { type sshRenewerCertificateStorer interface {
StoreRenewedSSHCertificate(parent, cert *ssh.Certificate) error StoreRenewedSSHCertificate(p provisioner.Interface, parent, cert *ssh.Certificate) error
} }
// Store certificate in admindb or linkedca // Store certificate in admindb or linkedca
switch s := a.adminDB.(type) { switch s := a.adminDB.(type) {
case sshRenewerCertificateStorer: case sshRenewerCertificateStorer:
return s.StoreRenewedSSHCertificate(parent, cert) return s.StoreRenewedSSHCertificate(prov, parent, cert)
case db.CertificateStorer: case db.CertificateStorer:
return s.StoreSSHCertificate(cert) return s.StoreSSHCertificate(cert)
} }
@ -471,7 +481,7 @@ func (a *Authority) storeRenewedSSHCertificate(parent, cert *ssh.Certificate) er
// Store certificate in localdb // Store certificate in localdb
switch s := a.db.(type) { switch s := a.db.(type) {
case sshRenewerCertificateStorer: case sshRenewerCertificateStorer:
return s.StoreRenewedSSHCertificate(parent, cert) return s.StoreRenewedSSHCertificate(prov, parent, cert)
case db.CertificateStorer: case db.CertificateStorer:
return s.StoreSSHCertificate(cert) return s.StoreSSHCertificate(cert)
default: default:
@ -522,6 +532,12 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number") return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error reading random number")
} }
// Attempt to extract the provisioner from the token.
var prov provisioner.Interface
if token, ok := provisioner.TokenFromContext(ctx); ok {
prov, _, _ = a.getProvisionerFromToken(token)
}
signer := a.sshCAUserCertSignKey signer := a.sshCAUserCertSignKey
principal := subject.ValidPrincipals[0] principal := subject.ValidPrincipals[0]
addUserPrincipal := a.getAddUserPrincipal() addUserPrincipal := a.getAddUserPrincipal()
@ -554,7 +570,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
} }
cert.Signature = sig cert.Signature = sig
if err = a.storeRenewedSSHCertificate(subject, cert); err != nil && err != db.ErrNotImplemented { if err = a.storeRenewedSSHCertificate(prov, subject, cert); err != nil && err != db.ErrNotImplemented {
return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db") return nil, errs.Wrap(http.StatusInternalServerError, err, "signSSHAddUser: error storing certificate in db")
} }