Complete tests for collection.

This commit is contained in:
Mariano Cano 2019-03-08 12:19:44 -08:00
parent 54d86ca1c1
commit 2a5430fee1
5 changed files with 223 additions and 89 deletions

View file

@ -129,18 +129,15 @@ func (c *Collection) Store(p Interface) error {
// Use the first 4 bytes (32bit) of the sum to insert the order // Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted: // Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ... // 0x00000000, 0x00000001, 0x00000002, ...
sum, err := provisionerSum(p)
if err != nil {
return err
}
bi := make([]byte, 4) bi := make([]byte, 4)
sum := provisionerSum(p)
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len())) binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
c.sorted = append(c.sorted, uidProvisioner{ c.sorted = append(c.sorted, uidProvisioner{
provisioner: p, provisioner: p,
uid: hex.EncodeToString(sum), uid: hex.EncodeToString(sum),
}) })
sort.Sort(c.sorted)
return nil return nil
} }
@ -182,9 +179,9 @@ func loadProvisioner(m *sync.Map, key string) (Interface, bool) {
// provisionerSum returns the SHA1 of the provisioners ID. From this we will // provisionerSum returns the SHA1 of the provisioners ID. From this we will
// create the unique and sorted id. // create the unique and sorted id.
func provisionerSum(p Interface) ([]byte, error) { func provisionerSum(p Interface) []byte {
sum := sha1.Sum([]byte(p.GetID())) sum := sha1.Sum([]byte(p.GetID()))
return sum[:], nil return sum[:]
} }
// matchesAudience returns true if A and B share at least one element. // matchesAudience returns true if A and B share at least one element.

View file

@ -1,79 +1,17 @@
package provisioner package provisioner
import ( import (
"crypto/x509"
"crypto/x509/pkix"
"reflect" "reflect"
"strings"
"sync" "sync"
"testing" "testing"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
// func Test_newSortedProvisioners(t *testing.T) {
// provisioners := make(List, 20)
// for i := range provisioners {
// provisioners[i] = generateProvisioner(t)
// }
// ps, err := newSortedProvisioners(provisioners)
// assert.FatalError(t, err)
// prev := ""
// for i, p := range ps {
// if p.uid < prev {
// t.Errorf("%s should be less that %s", p.uid, prev)
// }
// if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
// t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
// }
// prev = p.uid
// }
// }
// func Test_provisionerSlice_Find(t *testing.T) {
// trim := func(s string) string {
// return strings.TrimLeft(s, "0")
// }
// provisioners := make([]*Provisioner, 20)
// for i := range provisioners {
// provisioners[i] = generateProvisioner(t)
// }
// ps, err := newSortedProvisioners(provisioners)
// assert.FatalError(t, err)
// type args struct {
// cursor string
// limit int
// }
// tests := []struct {
// name string
// p provisionerSlice
// args args
// want []*JWK
// want1 string
// }{
// {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
// {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
// {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
// {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
// {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
// {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
// {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
// {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
// }
// for _, tt := range tests {
// t.Run(tt.name, func(t *testing.T) {
// got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
// if !reflect.DeepEqual(got, tt.want) {
// t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
// }
// if got1 != tt.want1 {
// t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
// }
// })
// }
// }
func TestCollection_Load(t *testing.T) { func TestCollection_Load(t *testing.T) {
p, err := generateJWK() p, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -141,11 +79,16 @@ func TestCollection_LoadByToken(t *testing.T) {
t2, c2, err := parseToken(token) t2, c2, err := parseToken(token)
assert.FatalError(t, err) assert.FatalError(t, err)
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keys.Keys[0]) token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
t3, c3, err := parseToken(token) t3, c3, err := parseToken(token)
assert.FatalError(t, err) assert.FatalError(t, err)
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
assert.FatalError(t, err)
t4, c4, err := parseToken(token)
assert.FatalError(t, err)
type fields struct { type fields struct {
byID *sync.Map byID *sync.Map
audiences []string audiences []string
@ -164,6 +107,7 @@ func TestCollection_LoadByToken(t *testing.T) {
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true}, {"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true}, {"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true}, {"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false}, {"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -182,3 +126,188 @@ func TestCollection_LoadByToken(t *testing.T) {
}) })
} }
} }
func TestCollection_LoadByCertificate(t *testing.T) {
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
byID := new(sync.Map)
byID.Store(p1.GetID(), p1)
byID.Store(p2.GetID(), p2)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err)
ok1Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok1Ext},
}
ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext},
}
notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt},
}
badCert := &x509.Certificate{
Extensions: []pkix.Extension{
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
},
}
type fields struct {
byID *sync.Map
audiences []string
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
want Interface
want1 bool
}{
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Collection{
byID: tt.fields.byID,
audiences: tt.fields.audiences,
}
got, got1 := c.LoadByCertificate(tt.args.cert)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_LoadEncryptedKey(t *testing.T) {
c := NewCollection(testAudiences)
p1, err := generateJWK()
assert.FatalError(t, err)
assert.FatalError(t, c.Store(p1))
p2, err := generateOIDC()
assert.FatalError(t, err)
assert.FatalError(t, c.Store(p2))
// Add oidc in byKey.
// It should not happen.
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
c.byKey.Store(p2KeyID, p2)
type args struct {
keyID string
}
tests := []struct {
name string
args args
want string
want1 bool
}{
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
{"oidc", args{p2KeyID}, "", false},
{"notFound", args{"not-found"}, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
if got != tt.want {
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
}
})
}
}
func TestCollection_Store(t *testing.T) {
c := NewCollection(testAudiences)
p1, err := generateJWK()
assert.FatalError(t, err)
p2, err := generateOIDC()
assert.FatalError(t, err)
type args struct {
p Interface
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok1", args{p1}, false},
{"ok2", args{p2}, false},
{"fail1", args{p1}, true},
{"fail2", args{p2}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestCollection_Find(t *testing.T) {
c, err := generateCollection(10, 10)
assert.FatalError(t, err)
trim := func(s string) string {
return strings.TrimLeft(s, "0")
}
toList := func(ps provisionerSlice) List {
l := List{}
for _, p := range ps {
l = append(l, p.provisioner)
}
return l
}
type args struct {
cursor string
limit int
}
tests := []struct {
name string
args args
want List
want1 string
}{
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

View file

@ -23,7 +23,7 @@ var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
type keyStore struct { type keyStore struct {
sync.RWMutex sync.RWMutex
uri string uri string
keys jose.JSONWebKeySet keySet jose.JSONWebKeySet
timer *time.Timer timer *time.Timer
expiry time.Time expiry time.Time
} }
@ -35,7 +35,7 @@ func newKeyStore(uri string) (*keyStore, error) {
} }
ks := &keyStore{ ks := &keyStore{
uri: uri, uri: uri,
keys: keys, keySet: keys,
expiry: getExpirationTime(age), expiry: getExpirationTime(age),
} }
ks.timer = time.AfterFunc(age, ks.reload) ks.timer = time.AfterFunc(age, ks.reload)
@ -54,7 +54,7 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
ks.reload() ks.reload()
ks.RLock() ks.RLock()
} }
keys = ks.keys.Key(kid) keys = ks.keySet.Key(kid)
ks.RUnlock() ks.RUnlock()
return return
} }
@ -66,7 +66,7 @@ func (ks *keyStore) reload() {
next = ks.nextReloadDuration(defaultCacheJitter / 2) next = ks.nextReloadDuration(defaultCacheJitter / 2)
} else { } else {
ks.Lock() ks.Lock()
ks.keys = keys ks.keySet = keys
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
ks.Unlock() ks.Unlock()
next = ks.nextReloadDuration(age) next = ks.nextReloadDuration(age)

View file

@ -198,19 +198,27 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisi
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption { func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
return func(p x509util.Profile) error { return func(p x509util.Profile) error {
crt := p.Subject() crt := p.Subject()
b, err := asn1.Marshal(stepProvisionerASN1{ ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
Type: o.Type,
Name: []byte(o.Name),
CredentialID: []byte(o.CredentialID),
})
if err != nil { if err != nil {
return errors.Wrapf(err, "error marshaling provisioner extension") return err
} }
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{ crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
Id: stepOIDProvisioner,
Critical: false,
Value: b,
})
return nil return nil
} }
} }
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
b, err := asn1.Marshal(stepProvisionerASN1{
Type: typ,
Name: []byte(name),
CredentialID: []byte(credentialID),
})
if err != nil {
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
}
return pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
}, nil
}

View file

@ -124,7 +124,7 @@ func generateOIDC() (*OIDC, error) {
JWKSetURI: "https://example.com/.well-known/jwks", JWKSetURI: "https://example.com/.well-known/jwks",
}, },
keyStore: &keyStore{ keyStore: &keyStore{
keys: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
}, nil }, nil