Add domains and check emails properly.

This commit is contained in:
Mariano Cano 2019-03-15 13:49:50 -07:00
parent 5edbce017f
commit 60880d1f0a
4 changed files with 91 additions and 21 deletions

View file

@ -119,7 +119,7 @@ func TestJWK_Authorize(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2) t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2)
assert.FatalError(t, err) assert.FatalError(t, err)
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], []string{}, time.Now(), key1) t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], "", []string{}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// Invalid tokens // Invalid tokens
@ -142,13 +142,13 @@ func TestJWK_Authorize(t *testing.T) {
// invalid signature // invalid signature
failSig := t1[0 : len(t1)-2] failSig := t1[0 : len(t1)-2]
// no subject // no subject
failSub, err := generateToken("", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, time.Now(), key1) failSub, err := generateToken("", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// expired // expired
failExp, err := generateToken("subject", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1) failExp, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// not before // not before
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1) failNbf, err := generateToken("subject", p1.Name, testAudiences[0], "", []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// Remove encrypted key for p2 // Remove encrypted key for p2

View file

@ -4,6 +4,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -49,8 +50,9 @@ type OIDC struct {
ClientID string `json:"clientID"` ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"` ClientSecret string `json:"clientSecret"`
ConfigurationEndpoint string `json:"configurationEndpoint"` ConfigurationEndpoint string `json:"configurationEndpoint"`
Admins []string `json:"admins"`
Domains []string `json:"domains"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Admins []string `json:"admins,omitempty"`
configuration openIDConfiguration configuration openIDConfiguration
keyStore *keyStore keyStore *keyStore
} }
@ -58,14 +60,22 @@ type OIDC struct {
// IsAdmin returns true if the given email is in the Admins whitelist, false // IsAdmin returns true if the given email is in the Admins whitelist, false
// otherwise. // otherwise.
func (o *OIDC) IsAdmin(email string) bool { func (o *OIDC) IsAdmin(email string) bool {
email = sanitizeEmail(email)
for _, e := range o.Admins { for _, e := range o.Admins {
if e == email { if email == sanitizeEmail(e) {
return true return true
} }
} }
return false return false
} }
func sanitizeEmail(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i] + strings.ToLower(email[i:])
}
return email
}
// GetID returns the provisioner unique identifier, the OIDC provisioner the // GetID returns the provisioner unique identifier, the OIDC provisioner the
// uses the clientID for this. // uses the clientID for this.
func (o *OIDC) GetID() string { func (o *OIDC) GetID() string {
@ -130,9 +140,32 @@ func (o *OIDC) ValidatePayload(p openIDPayload) error {
}, time.Minute); err != nil { }, time.Minute); err != nil {
return errors.Wrap(err, "failed to validate payload") return errors.Wrap(err, "failed to validate payload")
} }
// Validate azp if present
if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID { if p.AuthorizedParty != "" && p.AuthorizedParty != o.ClientID {
return errors.New("failed to validate payload: invalid azp") return errors.New("failed to validate payload: invalid azp")
} }
// Enforce an email claim
if p.Email == "" {
return errors.New("failed to validate payload: email not found")
}
// Validate domains (case-insensitive)
if !o.IsAdmin(p.Email) && len(o.Domains) > 0 {
email := sanitizeEmail(p.Email)
var found bool
for _, d := range o.Domains {
if strings.HasSuffix(email, "@"+strings.ToLower(d)) {
found = true
break
}
}
if !found {
return errors.New("failed to validate payload: email is not allowed")
}
}
return nil return nil
} }

View file

@ -2,6 +2,7 @@ package provisioner
import ( import (
"crypto/x509" "crypto/x509"
"fmt"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -72,6 +73,7 @@ func TestOIDC_Init(t *testing.T) {
ConfigurationEndpoint string ConfigurationEndpoint string
Claims *Claims Claims *Claims
Admins []string Admins []string
Domains []string
} }
type args struct { type args struct {
config Config config Config
@ -82,14 +84,15 @@ func TestOIDC_Init(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil}, args{config}, false}, {"ok", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
{"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}}, args{config}, false}, {"ok-admins", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}, nil}, args{config}, false},
{"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil}, args{config}, false}, {"ok-domains", fields{"oidc", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, []string{"smallstep.com"}}, args{config}, false},
{"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil}, args{config}, true}, {"ok-no-secret", fields{"oidc", "name", "client-id", "", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, false},
{"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil}, args{config}, true}, {"no-name", fields{"oidc", "", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
{"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil}, args{config}, true}, {"no-type", fields{"", "name", "client-id", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
{"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil}, args{config}, true}, {"no-client-id", fields{"oidc", "name", "", "client-secret", srv.URL + "/openid-configuration", nil, nil, nil}, args{config}, true},
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil}, args{config}, true}, {"no-configuration", fields{"oidc", "name", "client-id", "client-secret", "", nil, nil, nil}, args{config}, true},
{"bad-configuration", fields{"oidc", "name", "client-id", "client-secret", srv.URL, nil, nil, nil}, args{config}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -129,8 +132,9 @@ func TestOIDC_Authorize(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
p3, err := generateOIDC() p3, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
// Admin // Admin + Domains
p3.Admins = []string{"name@smallstep.com"} p3.Admins = []string{"name@smallstep.com", "root@example.com"}
p3.Domains = []string{"smallstep.com"}
// Update configuration endpoints and initialize // Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims} config := Config{Claims: globalProvisionerClaims}
@ -148,6 +152,15 @@ func TestOIDC_Authorize(t *testing.T) {
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0]) t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
// Admin email not in domains
okAdmin, err := generateToken("subject", "the-issuer", p3.ClientID, "root@example.com", []string{"test.smallstep.com"}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid email
failEmail, err := generateToken("subject", "the-issuer", p3.ClientID, "", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
failDomain, err := generateToken("subject", "the-issuer", p3.ClientID, "name@example.com", []string{}, time.Now(), &keys.Keys[0])
assert.FatalError(t, err)
// Invalid tokens // Invalid tokens
parts := strings.Split(t1, ".") parts := strings.Split(t1, ".")
key, err := generateJSONWebKey() key, err := generateJSONWebKey()
@ -168,10 +181,10 @@ func TestOIDC_Authorize(t *testing.T) {
// invalid signature // invalid signature
failSig := t1[0 : len(t1)-2] failSig := t1[0 : len(t1)-2]
// expired // expired
failExp, err := generateToken("subject", "the-issuer", p1.ClientID, []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0]) failExp, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
// not before // not before
failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, []string{}, time.Now().Add(360*time.Second), &keys.Keys[0]) failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, "name@smallstep.com", []string{}, time.Now().Add(360*time.Second), &keys.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -186,6 +199,9 @@ func TestOIDC_Authorize(t *testing.T) {
{"ok1", p1, args{t1}, false}, {"ok1", p1, args{t1}, false},
{"ok2", p2, args{t2}, false}, {"ok2", p2, args{t2}, false},
{"admin", p3, args{t3}, false}, {"admin", p3, args{t3}, false},
{"admin", p3, args{okAdmin}, false},
{"fail-email", p3, args{failEmail}, true},
{"fail-domain", p3, args{failDomain}, true},
{"fail-key", p1, args{failKey}, true}, {"fail-key", p1, args{failKey}, true},
{"fail-token", p1, args{failTok}, true}, {"fail-token", p1, args{failTok}, true},
{"fail-claims", p1, args{failClaims}, true}, {"fail-claims", p1, args{failClaims}, true},
@ -199,6 +215,7 @@ func TestOIDC_Authorize(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.Authorize(tt.args.token) got, err := tt.prov.Authorize(tt.args.token)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
fmt.Println(tt)
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
@ -288,3 +305,23 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
}) })
} }
} }
func Test_sanitizeEmail(t *testing.T) {
tests := []struct {
name string
email string
want string
}{
{"equal", "name@smallstep.com", "name@smallstep.com"},
{"domain-insensitive", "name@SMALLSTEP.COM", "name@smallstep.com"},
{"local-sensitive", "NaMe@smallSTEP.CoM", "NaMe@smallstep.com"},
{"multiple-@", "NaMe@NaMe@smallSTEP.CoM", "NaMe@NaMe@smallstep.com"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := sanitizeEmail(tt.email); got != tt.want {
t.Errorf("sanitizeEmail() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -173,10 +173,10 @@ func generateCollection(nJWK, nOIDC int) (*Collection, error) {
} }
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
return generateToken("subject", iss, aud, []string{"test.smallstep.com"}, time.Now(), jwk) return generateToken("subject", iss, aud, "name@smallstep.com", []string{"test.smallstep.com"}, time.Now(), jwk)
} }
func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner( sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -204,8 +204,8 @@ func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud}, Audience: []string{aud},
}, },
Email: email,
SANS: sans, SANS: sans,
Email: "name@smallstep.com",
} }
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }