From bf364f0a5f31a7fea35b72cbc9a4759debbf2c7d Mon Sep 17 00:00:00 2001 From: Cristian Le Date: Fri, 30 Apr 2021 09:14:28 +0900 Subject: [PATCH] Draft: adding usernames to GetIdentityFunc --- authority/options.go | 2 +- authority/provisioner/oidc.go | 7 +------ authority/provisioner/oidc_test.go | 5 +++-- authority/provisioner/provisioner.go | 7 ++++--- 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/authority/options.go b/authority/options.go index 9594f989..9626f48e 100644 --- a/authority/options.go +++ b/authority/options.go @@ -47,7 +47,7 @@ func WithDatabase(db db.AuthDB) Option { // WithGetIdentityFunc sets a custom function to retrieve the identity from // an external resource. -func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string) (*provisioner.Identity, error)) Option { +func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, email string, usernames ...string) (*provisioner.Identity, error)) Option { return func(a *Authority) error { a.getIdentityFunc = fn return nil diff --git a/authority/provisioner/oidc.go b/authority/provisioner/oidc.go index 0a85875e..787de317 100644 --- a/authority/provisioner/oidc.go +++ b/authority/provisioner/oidc.go @@ -389,15 +389,10 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption // Get the identity using either the default identityFunc or one injected // externally. - iden, err := o.getIdentityFunc(ctx, o, claims.Email) + iden, err := o.getIdentityFunc(ctx, o, claims.Email, claims.PreferredUsername) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") } - // Reuse the contains function provided for simplicity - if !containsAllMembers(iden.Usernames, []string{claims.PreferredUsername}) { - // Add preferred_username to the identity's Username - iden.Usernames = append(iden.Usernames, claims.PreferredUsername) - } // Certificate templates. data := sshutil.CreateTemplateData(sshutil.UserCert, claims.Email, iden.Usernames) diff --git a/authority/provisioner/oidc_test.go b/authority/provisioner/oidc_test.go index b0e2f2f4..d203516c 100644 --- a/authority/provisioner/oidc_test.go +++ b/authority/provisioner/oidc_test.go @@ -500,12 +500,13 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) { assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p5.Init(config)) - p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p4.getIdentityFunc = func(ctx context.Context, p Interface, email string, usernames ...string) (*Identity, error) { return &Identity{Usernames: []string{"max", "mariano"}}, nil } - p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { + p5.getIdentityFunc = func(ctx context.Context, p Interface, email string, usernames ...string) (*Identity, error) { return nil, errors.New("force") } + // Additional test needed for empty usernames and duplicate email and usernames t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0]) assert.FatalError(t, err) diff --git a/authority/provisioner/provisioner.go b/authority/provisioner/provisioner.go index aed1900a..8cf42953 100644 --- a/authority/provisioner/provisioner.go +++ b/authority/provisioner/provisioner.go @@ -337,10 +337,10 @@ type Permissions struct { } // GetIdentityFunc is a function that returns an identity. -type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error) +type GetIdentityFunc func(ctx context.Context, p Interface, email string, usernames ...string) (*Identity, error) // DefaultIdentityFunc return a default identity depending on the provisioner type. -func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) { +func DefaultIdentityFunc(ctx context.Context, p Interface, email string, usernames ...string) (*Identity, error) { switch k := p.(type) { case *OIDC: // OIDC principals would be: @@ -351,13 +351,14 @@ func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Ident if !sshUserRegex.MatchString(name) { return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email) } - usernames := []string{name} + usernames := append(usernames, name) if i := strings.LastIndex(email, "@"); i >= 0 { if local := email[:i]; !strings.EqualFold(local, name) { usernames = append(usernames, local) } } usernames = append(usernames, email) + // Some remove duplicate function should be added return &Identity{ Usernames: usernames, }, nil