Add tests for OIDC and complete some JWK tests.

This commit is contained in:
Mariano Cano 2019-03-11 12:48:46 -07:00
parent dce3100cfb
commit 4ceb88fbae
4 changed files with 354 additions and 22 deletions

View file

@ -111,11 +111,16 @@ func TestJWK_Authorize(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2) t2, err := generateSimpleToken(p2.Name, testAudiences[1], key2)
assert.FatalError(t, err) assert.FatalError(t, err)
t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], []string{}, key1) t3, err := generateToken("test.smallstep.com", p1.Name, testAudiences[0], []string{}, time.Now(), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// Invalid tokens // Invalid tokens
parts := strings.Split(t1, ".") parts := strings.Split(t1, ".")
key3, err := generateJSONWebKey()
assert.FatalError(t, err)
// missing key
failKey, err := generateSimpleToken(p1.Name, testAudiences[0], key3)
assert.FatalError(t, err)
// invalid token // invalid token
failTok := "foo." + parts[1] + "." + parts[2] failTok := "foo." + parts[1] + "." + parts[2]
// invalid claims // invalid claims
@ -129,7 +134,13 @@ func TestJWK_Authorize(t *testing.T) {
// invalid signature // invalid signature
failSig := t1[0 : len(t1)-2] failSig := t1[0 : len(t1)-2]
// no subject // no subject
failSub, err := generateToken("", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, key1) failSub, err := generateToken("", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, time.Now(), key1)
assert.FatalError(t, err)
// expired
failExp, err := generateToken("subject", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, time.Now().Add(-360*time.Second), key1)
assert.FatalError(t, err)
// not before
failNbf, err := generateToken("subject", p1.Name, testAudiences[0], []string{"test.smallstep.com"}, time.Now().Add(360*time.Second), key1)
assert.FatalError(t, err) assert.FatalError(t, err)
// Remove encrypted key for p2 // Remove encrypted key for p2
@ -147,12 +158,15 @@ func TestJWK_Authorize(t *testing.T) {
{"ok", p1, args{t1}, false}, {"ok", p1, args{t1}, false},
{"ok-no-encrypted-key", p2, args{t2}, false}, {"ok-no-encrypted-key", p2, args{t2}, false},
{"ok-no-sans", p1, args{t3}, false}, {"ok-no-sans", p1, args{t3}, false},
{"fail-key", p1, args{failKey}, true},
{"fail-token", p1, args{failTok}, true}, {"fail-token", p1, args{failTok}, true},
{"fail-claims", p1, args{failClaims}, true}, {"fail-claims", p1, args{failClaims}, true},
{"fail-issuer", p1, args{failIss}, true}, {"fail-issuer", p1, args{failIss}, true},
{"fail-audience", p1, args{failAud}, true}, {"fail-audience", p1, args{failAud}, true},
{"fail-signature", p1, args{failSig}, true}, {"fail-signature", p1, args{failSig}, true},
{"fail-subject", p1, args{failSub}, true}, {"fail-subject", p1, args{failSub}, true},
{"fail-expired", p1, args{failExp}, true},
{"fail-not-before", p1, args{failNbf}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -17,6 +17,18 @@ type openIDConfiguration struct {
JWKSetURI string `json:"jwks_uri"` JWKSetURI string `json:"jwks_uri"`
} }
// Validate validates the values in a well-known OpenID configuration endpoint.
func (c openIDConfiguration) Validate() error {
switch {
case c.Issuer == "":
return errors.New("issuer cannot be empty")
case c.JWKSetURI == "":
return errors.New("jwks_uri cannot be empty")
default:
return nil
}
}
// openIDPayload represents the fields on the id_token JWT payload. // openIDPayload represents the fields on the id_token JWT payload.
type openIDPayload struct { type openIDPayload struct {
jose.Claims jose.Claims
@ -87,12 +99,12 @@ func (o *OIDC) Init(config Config) (err error) {
if o.Claims, err = o.Claims.Init(&config.Claims); err != nil { if o.Claims, err = o.Claims.Init(&config.Claims); err != nil {
return err return err
} }
// Decode openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil { if err := getAndDecode(o.ConfigurationEndpoint, &o.configuration); err != nil {
return err return err
} }
if o.configuration.JWKSetURI == "" { if err := o.configuration.Validate(); err != nil {
return errors.Errorf("error parsing %s: jwks_uri cannot be empty", o.ConfigurationEndpoint) return errors.Wrapf(err, "error parsing %s", o.ConfigurationEndpoint)
} }
// Get JWK key set // Get JWK key set
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI) o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
@ -103,8 +115,6 @@ func (o *OIDC) Init(config Config) (err error) {
} }
// ValidatePayload validates the given token payload. // ValidatePayload validates the given token payload.
//
// TODO(mariano): avoid reply attacks validating nonce.
func (o *OIDC) ValidatePayload(p openIDPayload) error { func (o *OIDC) ValidatePayload(p openIDPayload) error {
// According to "rfc7519 JSON Web Token" acceptable skew should be no more // According to "rfc7519 JSON Web Token" acceptable skew should be no more
// than a few minutes. // than a few minutes.
@ -151,8 +161,13 @@ func (o *OIDC) Authorize(token string) ([]SignOption, error) {
return nil, err return nil, err
} }
// Admins should be able to authorize any SAN
if o.IsAdmin(claims.Email) { if o.IsAdmin(claims.Email) {
return []SignOption{}, nil return []SignOption{
profileDefaultDuration(o.Claims.DefaultTLSCertDuration()),
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
newValidityValidator(o.Claims.MinTLSCertDuration(), o.Claims.MaxTLSCertDuration()),
}, nil
} }
return []SignOption{ return []SignOption{

View file

@ -0,0 +1,287 @@
package provisioner
import (
"crypto/x509"
"strings"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
func Test_openIDConfiguration_Validate(t *testing.T) {
type fields struct {
Issuer string
JWKSetURI string
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{"ok", fields{"the-issuer", "the-jwks-uri"}, false},
{"no-issuer", fields{"", "the-jwks-uri"}, true},
{"no-jwks-uri", fields{"the-issuer", ""}, true},
{"empty", fields{"", ""}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := openIDConfiguration{
Issuer: tt.fields.Issuer,
JWKSetURI: tt.fields.JWKSetURI,
}
if err := c.Validate(); (err != nil) != tt.wantErr {
t.Errorf("openIDConfiguration.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestOIDC_Getters(t *testing.T) {
p, err := generateOIDC()
assert.FatalError(t, err)
if got := p.GetID(); got != p.ClientID {
t.Errorf("OIDC.GetID() = %v, want %v", got, p.ClientID)
}
if got := p.GetName(); got != p.Name {
t.Errorf("OIDC.GetName() = %v, want %v", got, p.Name)
}
if got := p.GetType(); got != TypeOIDC {
t.Errorf("OIDC.GetType() = %v, want %v", got, TypeOIDC)
}
kid, key, ok := p.GetEncryptedKey()
if kid != "" || key != "" || ok == true {
t.Errorf("OIDC.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
kid, key, ok, "", "", false)
}
}
func TestOIDC_Init(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
config := Config{
Claims: globalProvisionerClaims,
}
type fields struct {
Type string
Name string
ClientID string
ConfigurationEndpoint string
Claims *Claims
Admins []string
}
type args struct {
config Config
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{"oidc", "name", "client-id", srv.URL + "/openid-configuration", nil, nil}, args{config}, false},
{"ok-admins", fields{"oidc", "name", "client-id", srv.URL + "/openid-configuration", nil, []string{"foo@smallstep.com"}}, args{config}, false},
{"no-name", fields{"oidc", "", "client-id", srv.URL + "/openid-configuration", nil, nil}, args{config}, true},
{"no-client-id", fields{"oidc", "name", "", srv.URL + "/openid-configuration", nil, nil}, args{config}, true},
{"no-configuration", fields{"oidc", "name", "client-id", "", nil, nil}, args{config}, true},
{"bad-configuration", fields{"oidc", "name", "client-id", srv.URL, nil, nil}, args{config}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &OIDC{
Type: tt.fields.Type,
Name: tt.fields.Name,
ClientID: tt.fields.ClientID,
ConfigurationEndpoint: tt.fields.ConfigurationEndpoint,
Claims: tt.fields.Claims,
Admins: tt.fields.Admins,
}
if err := p.Init(tt.args.config); (err != nil) != tt.wantErr {
t.Errorf("OIDC.Init() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr == false {
assert.Len(t, 2, p.keyStore.keySet.Keys)
assert.Equals(t, openIDConfiguration{
Issuer: "the-issuer",
JWKSetURI: srv.URL + "/jwks_uri",
}, p.configuration)
}
})
}
}
func TestOIDC_Authorize(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
p3, err := generateOIDC()
assert.FatalError(t, err)
// Admin
p3.Admins = []string{"name@smallstep.com"}
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p2.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
p3.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
assert.FatalError(t, p2.Init(config))
assert.FatalError(t, p3.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
t2, err := generateSimpleToken("the-issuer", p2.ClientID, &keys.Keys[1])
assert.FatalError(t, err)
t3, err := generateSimpleToken("the-issuer", p3.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// Invalid tokens
parts := strings.Split(t1, ".")
key, err := generateJSONWebKey()
assert.FatalError(t, err)
// missing key
failKey, err := generateSimpleToken("the-issuer", p1.ClientID, key)
assert.FatalError(t, err)
// invalid token
failTok := "foo." + parts[1] + "." + parts[2]
// invalid claims
failClaims := parts[0] + ".foo." + parts[1]
// invalid issuer
failIss, err := generateSimpleToken("bad-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
// invalid audience
failAud, err := generateSimpleToken("the-issuer", "foobar", &keys.Keys[0])
assert.FatalError(t, err)
// invalid signature
failSig := t1[0 : len(t1)-2]
// expired
failExp, err := generateToken("subject", "the-issuer", p1.ClientID, []string{}, time.Now().Add(-360*time.Second), &keys.Keys[0])
assert.FatalError(t, err)
// not before
failNbf, err := generateToken("subject", "the-issuer", p1.ClientID, []string{}, time.Now().Add(360*time.Second), &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok1", p1, args{t1}, false},
{"ok2", p2, args{t2}, false},
{"admin", p3, args{t3}, false},
{"fail-key", p1, args{failKey}, true},
{"fail-token", p1, args{failTok}, true},
{"fail-claims", p1, args{failClaims}, true},
{"fail-issuer", p1, args{failIss}, true},
{"fail-audience", p1, args{failAud}, true},
{"fail-signature", p1, args{failSig}, true},
{"fail-expired", p1, args{failExp}, true},
{"fail-not-before", p1, args{failNbf}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.prov.Authorize(tt.args.token)
if (err != nil) != tt.wantErr {
t.Errorf("OIDC.Authorize() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil {
assert.Nil(t, got)
} else {
assert.NotNil(t, got)
if tt.name == "admin" {
assert.Len(t, 3, got)
} else {
assert.Len(t, 4, got)
}
}
})
}
}
func TestOIDC_AuthorizeRenewal(t *testing.T) {
p1, err := generateOIDC()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
// disable renewal
disable := true
p2.Claims = &Claims{
globalClaims: &globalProvisionerClaims,
DisableRenewal: &disable,
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"ok", p1, args{nil}, false},
{"fail", p2, args{nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestOIDC_AuthorizeRevoke(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
var keys jose.JSONWebKeySet
assert.FatalError(t, getAndDecode(srv.URL+"/private", &keys))
// Create test provisioners
p1, err := generateOIDC()
assert.FatalError(t, err)
// Update configuration endpoints and initialize
config := Config{Claims: globalProvisionerClaims}
p1.ConfigurationEndpoint = srv.URL + "/.well-known/openid-configuration"
assert.FatalError(t, p1.Init(config))
t1, err := generateSimpleToken("the-issuer", p1.ClientID, &keys.Keys[0])
assert.FatalError(t, err)
type args struct {
token string
}
tests := []struct {
name string
prov *OIDC
args args
wantErr bool
}{
{"disabled", p1, args{t1}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.prov.AuthorizeRevoke(tt.args.token); (err != nil) != tt.wantErr {
t.Errorf("OIDC.AuthorizeRevoke() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -46,7 +46,7 @@ func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
if err != nil { if err != nil {
return jose.JSONWebKeySet{}, err return jose.JSONWebKeySet{}, err
} }
keySet.Keys = append(keySet.Keys, key.Public()) keySet.Keys = append(keySet.Keys, *key)
} }
return keySet, nil return keySet, nil
} }
@ -173,11 +173,10 @@ func generateCollection(nJWK, nOIDC int) (*Collection, error) {
} }
func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) { func generateSimpleToken(iss, aud string, jwk *jose.JSONWebKey) (string, error) {
return generateToken("subject", iss, aud, []string{"test.smallstep.com"}, jwk) return generateToken("subject", iss, aud, []string{"test.smallstep.com"}, time.Now(), jwk)
// return generateToken("the-sub", []string{"test.smallstep.com"}, jwk.KeyID, iss, aud, "testdata/root_ca.crt", now, now.Add(5*time.Minute), jwk)
} }
func generateToken(sub, iss, aud string, sans []string, jwk *jose.JSONWebKey) (string, error) { func generateToken(sub, iss, aud string, sans []string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner( sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -191,20 +190,22 @@ func generateToken(sub, iss, aud string, sans []string, jwk *jose.JSONWebKey) (s
return "", err return "", err
} }
now := time.Now()
claims := struct { claims := struct {
jose.Claims jose.Claims
Email string `json:"email"`
SANS []string `json:"sans"` SANS []string `json:"sans"`
}{ }{
Claims: jose.Claims{ Claims: jose.Claims{
ID: id, ID: id,
Subject: sub, Subject: sub,
Issuer: iss, Issuer: iss,
NotBefore: jose.NewNumericDate(now), IssuedAt: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), NotBefore: jose.NewNumericDate(iat),
Expiry: jose.NewNumericDate(iat.Add(5 * time.Minute)),
Audience: []string{aud}, Audience: []string{aud},
}, },
SANS: sans, SANS: sans,
Email: "name@smallstep.com",
} }
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
@ -235,22 +236,37 @@ func generateJWKServer(n int) *httptest.Server {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write(b) w.Write(b)
} }
// keySet, err := generateJSONWebKeySet(n) getPublic := func(ks jose.JSONWebKeySet) jose.JSONWebKeySet {
var ret jose.JSONWebKeySet
for _, k := range ks.Keys {
ret.Keys = append(ret.Keys, k.Public())
}
return ret
}
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits.Hits++ hits.Hits++
switch r.RequestURI { switch r.RequestURI {
case "/error": case "/error":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/hits": case "/hits":
writeJSON(w, hits) writeJSON(w, hits)
case "/openid-configuration", "/.well-known/openid-configuration":
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
case "/random": case "/random":
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, keySet) writeJSON(w, getPublic(keySet))
case "/private":
writeJSON(w, defaultKeySet)
default: default:
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, defaultKeySet) writeJSON(w, getPublic(defaultKeySet))
} }
})) })
srv.Start()
return srv
} }