diff --git a/authority/authority.go b/authority/authority.go index 77c887a2..3177efd9 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -41,7 +41,7 @@ type Authority struct { initOnce bool // Custom functions sshBastionFunc func(user, hostname string) (*Bastion, error) - getIdentityFunc func(p provisioner.Interface, email string) (*provisioner.Identity, error) + getIdentityFunc provisioner.GetIdentityFunc } // New creates and initiates a new Authority type. @@ -192,6 +192,7 @@ func (a *Authority) init() error { UserKeys: sshKeys.UserKeys, HostKeys: sshKeys.HostKeys, }, + GetIdentityFunc: a.getIdentityFunc, } // Store all the provisioners for _, p := range a.config.AuthorityConfig.Provisioners { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 05c079d7..c47960f9 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -197,10 +197,10 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Add modifiers from custom claims // FIXME: this is also set in the sign method using SSHOptions.Modify. if opts.CertType != "" { - signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType)) + signOptions = append(signOptions, sshCertTypeModifier(opts.CertType)) } if len(opts.Principals) > 0 { - signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals)) + signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals)) } if !opts.ValidAfter.IsZero() { signOptions = append(signOptions, sshCertificateValidAfterModifier(opts.ValidAfter.RelativeTime(t).Unix())) diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index d97f96f2..4538ef81 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -64,6 +64,7 @@ type OIDC struct { configuration openIDConfiguration keyStore *keyStore claimer *Claimer + getIdentityFunc GetIdentityFunc } // IsAdmin returns true if the given email is in the Admins whitelist, false @@ -169,6 +170,13 @@ func (o *OIDC) Init(config Config) (err error) { if err != nil { return err } + + // Set the identity getter if it exists, otherwise use the default. + if config.GetIdentityFunc == nil { + o.getIdentityFunc = DefaultIdentityFunc + } else { + o.getIdentityFunc = config.GetIdentityFunc + } return nil } @@ -326,23 +334,26 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption sshCertificateKeyIDModifier(claims.Email), } - name := SanitizeSSHUserPrincipal(claims.Email) - if !sshUserRegex.MatchString(name) { - return nil, errors.Errorf("invalid principal '%s' from email address '%s'", name, claims.Email) + // Get the identity using either the default identityFunc or one injected + // externally. + iden, err := o.getIdentityFunc(o, claims.Email) + if err != nil { + return nil, errors.Wrap(err, "authorizeSSHSign") } - - // Admin users will default to user + name but they can be changed by the - // user options. Non-admins are only able to sign user certificates. defaults := SSHOptions{ CertType: SSHUserCert, - Principals: []string{name}, + Principals: iden.Usernames, } + // Admin users can use any principal, and can sign user and host certificates. + // Non-admin users can only use principals returned by the identityFunc, and + // can only sign user certificates. if !o.IsAdmin(claims.Email) { signOptions = append(signOptions, sshCertificateOptionsValidator(defaults)) } - // Default to a user with name as principal if not set + // Default to a user certificate with usernames as principals if those options + // are not set. signOptions = append(signOptions, sshCertificateDefaultsModifier(defaults)) return append(signOptions, diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index 8e0c823c..cbb7b2a2 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -347,6 +347,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, err) p3, err := generateOIDC() assert.FatalError(t, err) + p4, err := generateOIDC() + assert.FatalError(t, err) + p5, err := generateOIDC() + assert.FatalError(t, err) // Admin + Domains p3.Admins = []string{"name@smallstep.com", "root@example.com"} p3.Domains = []string{"smallstep.com"} @@ -356,12 +360,27 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p4.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" + p5.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration" assert.FatalError(t, p1.Init(config)) assert.FatalError(t, p2.Init(config)) assert.FatalError(t, p3.Init(config)) + assert.FatalError(t, p4.Init(config)) + assert.FatalError(t, p5.Init(config)) + + p4.getIdentityFunc = func(p Interface, email string) (*Identity, error) { + return &Identity{Usernames: []string{"max", "mariano"}}, nil + } + p5.getIdentityFunc = func(p Interface, email string) (*Identity, error) { + return nil, errors.New("force") + } t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) + okGetIdentityToken, err := generateSimpleToken("the-issuer", p4.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) + failGetIdentityToken, err := generateSimpleToken("the-issuer", p5.ClientID, &keys.Keys[0]) + assert.FatalError(t, err) // Admin email not in domains okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{}, time.Now(), &keys.Keys[0]) assert.FatalError(t, err) @@ -384,11 +403,11 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { userDuration := p1.claimer.DefaultUserSSHCertDuration() hostDuration := p1.claimer.DefaultHostSSHCertDuration() expectedUserOptions := &SSHOptions{ - CertType: "user", Principals: []string{"name"}, + CertType: "user", Principals: []string{"name", "name@smallstep.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), } expectedAdminOptions := &SSHOptions{ - CertType: "user", Principals: []string{"root"}, + CertType: "user", Principals: []string{"root", "root@example.com"}, ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), } expectedHostOptions := &SSHOptions{ @@ -412,17 +431,32 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { {"ok", p1, args{t1, SSHOptions{}, pub}, expectedUserOptions, false, false}, {"ok-rsa2048", p1, args{t1, SSHOptions{}, rsa2048.Public()}, expectedUserOptions, false, false}, {"ok-user", p1, args{t1, SSHOptions{CertType: "user"}, pub}, expectedUserOptions, false, false}, - {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, - {"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, + {"ok-principals", p1, args{t1, SSHOptions{Principals: []string{"name"}}, pub}, + &SSHOptions{CertType: "user", Principals: []string{"name"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + {"ok-principals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{Principals: []string{"mariano"}}, pub}, + &SSHOptions{CertType: "user", Principals: []string{"mariano"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + {"ok-emptyPrincipals-getIdentity", p4, args{okGetIdentityToken, SSHOptions{}, pub}, + &SSHOptions{CertType: "user", Principals: []string{"max", "mariano"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + {"ok-options", p1, args{t1, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, + &SSHOptions{CertType: "user", Principals: []string{"name"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, {"admin", p3, args{okAdmin, SSHOptions{}, pub}, expectedAdminOptions, false, false}, {"admin-user", p3, args{okAdmin, SSHOptions{CertType: "user"}, pub}, expectedAdminOptions, false, false}, - {"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, expectedAdminOptions, false, false}, - {"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, expectedUserOptions, false, false}, + {"admin-principals", p3, args{okAdmin, SSHOptions{Principals: []string{"root"}}, pub}, + &SSHOptions{CertType: "user", Principals: []string{"root"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, + {"admin-options", p3, args{okAdmin, SSHOptions{CertType: "user", Principals: []string{"name"}}, pub}, + &SSHOptions{CertType: "user", Principals: []string{"name"}, + ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration))}, false, false}, {"admin-host", p3, args{okAdmin, SSHOptions{CertType: "host", Principals: []string{"smallstep.com"}}, pub}, expectedHostOptions, false, false}, {"fail-rsa1024", p1, args{t1, SSHOptions{}, rsa1024.Public()}, expectedUserOptions, false, true}, {"fail-user-host", p1, args{t1, SSHOptions{CertType: "host"}, pub}, nil, false, true}, {"fail-user-principals", p1, args{t1, SSHOptions{Principals: []string{"root"}}, pub}, nil, false, true}, {"fail-email", p3, args{failEmail, SSHOptions{}, pub}, nil, true, false}, + {"fail-getIdentity", p5, args{failGetIdentityToken, SSHOptions{}, pub}, nil, true, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index 8d0673a3..4b4200f5 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -185,6 +185,9 @@ type Config struct { DB db.AuthDB // SSHKeys are the root SSH public keys SSHKeys *SSHKeys + // GetIdentityFunc is a function that returns an identity that will be + // used by the provisioner to populate certificate attributes. + GetIdentityFunc GetIdentityFunc } type provisioner struct { @@ -314,7 +317,7 @@ func (b *base) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certif } // AuthorizeSSHRekey returns an unimplmented error. Provisioners should overwrite -// this method if they will support authorizing tokens for renewing SSH Certificates. +// this method if they will support authorizing tokens for rekeying SSH Certificates. func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { return nil, nil, errors.New("not implemented; provisioner does not implement AuthorizeSSHRekey") } @@ -325,6 +328,23 @@ type Identity struct { Usernames []string `json:"usernames"` } +// GetIdentityFunc is a function that returns an identity. +type GetIdentityFunc func(p Interface, email string) (*Identity, error) + +// DefaultIdentityFunc return a default identity depending on the provisioner type. +func DefaultIdentityFunc(p Interface, email string) (*Identity, error) { + switch k := p.(type) { + case *OIDC: + name := SanitizeSSHUserPrincipal(email) + if !sshUserRegex.MatchString(name) { + return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) + } + return &Identity{Usernames: []string{name, email}}, nil + default: + return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k) + } +} + // MockProvisioner for testing type MockProvisioner struct { Mret1, Mret2, Mret3 interface{} @@ -335,9 +355,13 @@ type MockProvisioner struct { MgetType func() Type MgetEncryptedKey func() (string, string, bool) Minit func(Config) error - MauthorizeRevoke func(ott string) error MauthorizeSign func(ctx context.Context, ott string) ([]SignOption, error) - MauthorizeRenewal func(*x509.Certificate) error + MauthorizeRenew func(ctx context.Context, cert *x509.Certificate) error + MauthorizeRevoke func(ctx context.Context, ott string) error + MauthorizeSSHSign func(ctx context.Context, ott string) ([]SignOption, error) + MauthorizeSSHRenew func(ctx context.Context, ott string) (*ssh.Certificate, error) + MauthorizeSSHRekey func(ctx context.Context, ott string) (*ssh.Certificate, []SignOption, error) + MauthorizeSSHRevoke func(ctx context.Context, ott string) error } // GetID mock @@ -391,14 +415,6 @@ func (m *MockProvisioner) Init(c Config) error { return m.Merr } -// AuthorizeRevoke mock -func (m *MockProvisioner) AuthorizeRevoke(ott string) error { - if m.MauthorizeRevoke != nil { - return m.MauthorizeRevoke(ott) - } - return m.Merr -} - // AuthorizeSign mock func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]SignOption, error) { if m.MauthorizeSign != nil { @@ -407,10 +423,50 @@ func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]Sign return m.Mret1.([]SignOption), m.Merr } -// AuthorizeRenewal mock -func (m *MockProvisioner) AuthorizeRenewal(c *x509.Certificate) error { - if m.MauthorizeRenewal != nil { - return m.MauthorizeRenewal(c) +// AuthorizeRevoke mock +func (m *MockProvisioner) AuthorizeRevoke(ctx context.Context, ott string) error { + if m.MauthorizeRevoke != nil { + return m.MauthorizeRevoke(ctx, ott) + } + return m.Merr +} + +// AuthorizeRenew mock +func (m *MockProvisioner) AuthorizeRenew(ctx context.Context, c *x509.Certificate) error { + if m.MauthorizeRenew != nil { + return m.MauthorizeRenew(ctx, c) + } + return m.Merr +} + +// AuthorizeSSHSign mock +func (m *MockProvisioner) AuthorizeSSHSign(ctx context.Context, ott string) ([]SignOption, error) { + if m.MauthorizeSign != nil { + return m.MauthorizeSign(ctx, ott) + } + return m.Mret1.([]SignOption), m.Merr +} + +// AuthorizeSSHRenew mock +func (m *MockProvisioner) AuthorizeSSHRenew(ctx context.Context, ott string) (*ssh.Certificate, error) { + if m.MauthorizeRenew != nil { + return m.MauthorizeSSHRenew(ctx, ott) + } + return m.Mret1.(*ssh.Certificate), m.Merr +} + +// AuthorizeSSHRekey mock +func (m *MockProvisioner) AuthorizeSSHRekey(ctx context.Context, ott string) (*ssh.Certificate, []SignOption, error) { + if m.MauthorizeSSHRekey != nil { + return m.MauthorizeSSHRekey(ctx, ott) + } + return m.Mret1.(*ssh.Certificate), m.Mret2.([]SignOption), m.Merr +} + +// AuthorizeSSHRevoke mock +func (m *MockProvisioner) AuthorizeSSHRevoke(ctx context.Context, ott string) error { + if m.MauthorizeSSHRevoke != nil { + return m.MauthorizeSSHRevoke(ctx, ott) } return m.Merr } diff --git a/authority/provisioner/provisioner_test.go b/authority/provisioner/provisioner_test.go index d79c2b69..14e62769 100644 --- a/authority/provisioner/provisioner_test.go +++ b/authority/provisioner/provisioner_test.go @@ -2,6 +2,9 @@ package provisioner import ( "testing" + + "github.com/pkg/errors" + "github.com/smallstep/assert" ) func TestType_String(t *testing.T) { @@ -52,3 +55,49 @@ func TestSanitizeSSHUserPrincipal(t *testing.T) { }) } } + +func TestDefaultIdentityFunc(t *testing.T) { + type test struct { + p Interface + email string + err error + identity *Identity + } + tests := map[string]func(*testing.T) test{ + "fail/unsupported-provisioner": func(t *testing.T) test { + return test{ + p: &X5C{}, + err: errors.New("provisioner type '*provisioner.X5C' not supported by identity function"), + } + }, + "fail/bad-ssh-regex": func(t *testing.T) test { + return test{ + p: &OIDC{}, + email: "$%^#_>@smallstep.com", + err: errors.New("invalid principal '______' from email '$%^#_>@smallstep.com'"), + } + }, + "ok": func(t *testing.T) test { + return test{ + p: &OIDC{}, + email: "max.furman@smallstep.com", + identity: &Identity{Usernames: []string{"maxfurman", "max.furman@smallstep.com"}}, + } + }, + } + for name, get := range tests { + t.Run(name, func(t *testing.T) { + tc := get(t) + identity, err := DefaultIdentityFunc(tc.p, tc.email) + if err != nil { + if assert.NotNil(t, tc.err) { + assert.Equals(t, tc.err.Error(), err.Error()) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, identity.Usernames, tc.identity.Usernames) + } + } + }) + } +} diff --git a/authority/provisioner/sign_ssh_options.go b/authority/provisioner/sign_ssh_options.go index 06ddf697..ceb57105 100644 --- a/authority/provisioner/sign_ssh_options.go +++ b/authority/provisioner/sign_ssh_options.go @@ -107,6 +107,16 @@ func (o SSHOptions) match(got SSHOptions) error { return nil } +// sshCertPrincipalsModifier is an SSHCertificateModifier that sets the +// principals to the SSH certificate. +type sshCertPrincipalsModifier []string + +// Modify the ValidPrincipals value of the cert. +func (o sshCertPrincipalsModifier) Modify(cert *ssh.Certificate) error { + cert.ValidPrincipals = []string(o) + return nil +} + // sshCertificateKeyIDModifier is an SSHCertificateModifier that sets the given // Key ID in the SSH certificate. type sshCertificateKeyIDModifier string @@ -116,24 +126,16 @@ func (m sshCertificateKeyIDModifier) Modify(cert *ssh.Certificate) error { return nil } -// sshCertificateCertTypeModifier is an SSHCertificateModifier that sets the -// certificate type to the SSH certificate. -type sshCertificateCertTypeModifier string +// sshCertTypeModifier is an SSHCertificateModifier that sets the +// certificate type. +type sshCertTypeModifier string -func (m sshCertificateCertTypeModifier) Modify(cert *ssh.Certificate) error { +// Modify sets the CertType for the ssh certificate. +func (m sshCertTypeModifier) Modify(cert *ssh.Certificate) error { cert.CertType = sshCertTypeUInt32(string(m)) return nil } -// sshCertificatePrincipalsModifier is an SSHCertificateModifier that sets the -// principals to the SSH certificate. -type sshCertificatePrincipalsModifier []string - -func (m sshCertificatePrincipalsModifier) Modify(cert *ssh.Certificate) error { - cert.ValidPrincipals = []string(m) - return nil -} - // sshCertificateValidAfterModifier is an SSHCertificateModifier that sets the // ValidAfter in the SSH certificate. type sshCertificateValidAfterModifier uint64 diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index 4fa15a44..651cd136 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -237,10 +237,10 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, // Add modifiers from custom claims // FIXME: this is also set in the sign method using SSHOptions.Modify. if opts.CertType != "" { - signOptions = append(signOptions, sshCertificateCertTypeModifier(opts.CertType)) + signOptions = append(signOptions, sshCertTypeModifier(opts.CertType)) } if len(opts.Principals) > 0 { - signOptions = append(signOptions, sshCertificatePrincipalsModifier(opts.Principals)) + signOptions = append(signOptions, sshCertPrincipalsModifier(opts.Principals)) } t := now() if !opts.ValidAfter.IsZero() { diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index 4fc4dbe0..94018b55 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -636,9 +636,9 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) { assert.Equals(t, SSHOptions(v), *tc.claims.Step.SSH) case sshCertificateKeyIDModifier: assert.Equals(t, string(v), "foo") - case sshCertificateCertTypeModifier: + case sshCertTypeModifier: assert.Equals(t, string(v), tc.claims.Step.SSH.CertType) - case sshCertificatePrincipalsModifier: + case sshCertPrincipalsModifier: assert.Equals(t, []string(v), tc.claims.Step.SSH.Principals) case sshCertificateValidAfterModifier: assert.Equals(t, int64(v), tc.claims.Step.SSH.ValidAfter.RelativeTime(nw).Unix())