Complete tests for collection.
This commit is contained in:
parent
54d86ca1c1
commit
2a5430fee1
5 changed files with 223 additions and 89 deletions
|
@ -129,18 +129,15 @@ func (c *Collection) Store(p Interface) error {
|
|||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
sum, err := provisionerSum(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bi := make([]byte, 4)
|
||||
sum := provisionerSum(p)
|
||||
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
||||
c.sorted = append(c.sorted, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
sort.Sort(c.sorted)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -182,9 +179,9 @@ func loadProvisioner(m *sync.Map, key string) (Interface, bool) {
|
|||
|
||||
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
|
||||
// create the unique and sorted id.
|
||||
func provisionerSum(p Interface) ([]byte, error) {
|
||||
func provisionerSum(p Interface) []byte {
|
||||
sum := sha1.Sum([]byte(p.GetID()))
|
||||
return sum[:], nil
|
||||
return sum[:]
|
||||
}
|
||||
|
||||
// matchesAudience returns true if A and B share at least one element.
|
||||
|
|
|
@ -1,79 +1,17 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// func Test_newSortedProvisioners(t *testing.T) {
|
||||
// provisioners := make(List, 20)
|
||||
// for i := range provisioners {
|
||||
// provisioners[i] = generateProvisioner(t)
|
||||
// }
|
||||
|
||||
// ps, err := newSortedProvisioners(provisioners)
|
||||
// assert.FatalError(t, err)
|
||||
// prev := ""
|
||||
// for i, p := range ps {
|
||||
// if p.uid < prev {
|
||||
// t.Errorf("%s should be less that %s", p.uid, prev)
|
||||
// }
|
||||
// if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
|
||||
// t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
|
||||
// }
|
||||
// prev = p.uid
|
||||
// }
|
||||
// }
|
||||
|
||||
// func Test_provisionerSlice_Find(t *testing.T) {
|
||||
// trim := func(s string) string {
|
||||
// return strings.TrimLeft(s, "0")
|
||||
// }
|
||||
// provisioners := make([]*Provisioner, 20)
|
||||
// for i := range provisioners {
|
||||
// provisioners[i] = generateProvisioner(t)
|
||||
// }
|
||||
// ps, err := newSortedProvisioners(provisioners)
|
||||
// assert.FatalError(t, err)
|
||||
|
||||
// type args struct {
|
||||
// cursor string
|
||||
// limit int
|
||||
// }
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// p provisionerSlice
|
||||
// args args
|
||||
// want []*JWK
|
||||
// want1 string
|
||||
// }{
|
||||
// {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
|
||||
// {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
|
||||
// {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
|
||||
// {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
|
||||
// {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
|
||||
// {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
|
||||
// {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
|
||||
// {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
|
||||
// }
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
|
||||
// if !reflect.DeepEqual(got, tt.want) {
|
||||
// t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
|
||||
// }
|
||||
// if got1 != tt.want1 {
|
||||
// t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
|
||||
func TestCollection_Load(t *testing.T) {
|
||||
p, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
|
@ -141,11 +79,16 @@ func TestCollection_LoadByToken(t *testing.T) {
|
|||
t2, c2, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keys.Keys[0])
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t3, c3, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
t4, c4, err := parseToken(token)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
|
@ -164,6 +107,7 @@ func TestCollection_LoadByToken(t *testing.T) {
|
|||
{"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true},
|
||||
{"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true},
|
||||
{"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false},
|
||||
{"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -182,3 +126,188 @@ func TestCollection_LoadByToken(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadByCertificate(t *testing.T) {
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
byID := new(sync.Map)
|
||||
byID.Store(p1.GetID(), p1)
|
||||
byID.Store(p2.GetID(), p2)
|
||||
|
||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||
assert.FatalError(t, err)
|
||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||
assert.FatalError(t, err)
|
||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ok1Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok1Ext},
|
||||
}
|
||||
ok2Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok2Ext},
|
||||
}
|
||||
notFoundCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{notFoundExt},
|
||||
}
|
||||
badCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{
|
||||
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||
},
|
||||
}
|
||||
|
||||
type fields struct {
|
||||
byID *sync.Map
|
||||
audiences []string
|
||||
}
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want Interface
|
||||
want1 bool
|
||||
}{
|
||||
{"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true},
|
||||
{"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true},
|
||||
{"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true},
|
||||
{"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false},
|
||||
{"badCert", fields{byID, testAudiences}, args{badCert}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Collection{
|
||||
byID: tt.fields.byID,
|
||||
audiences: tt.fields.audiences,
|
||||
}
|
||||
got, got1 := c.LoadByCertificate(tt.args.cert)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_LoadEncryptedKey(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p1))
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, c.Store(p2))
|
||||
|
||||
// Add oidc in byKey.
|
||||
// It should not happen.
|
||||
p2KeyID := p2.keyStore.keySet.Keys[0].KeyID
|
||||
c.byKey.Store(p2KeyID, p2)
|
||||
|
||||
type args struct {
|
||||
keyID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true},
|
||||
{"oidc", args{p2KeyID}, "", false},
|
||||
{"notFound", args{"not-found"}, "", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := c.LoadEncryptedKey(tt.args.keyID)
|
||||
if got != tt.want {
|
||||
t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_Store(t *testing.T) {
|
||||
c := NewCollection(testAudiences)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
p Interface
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok1", args{p1}, false},
|
||||
{"ok2", args{p2}, false},
|
||||
{"fail1", args{p1}, true},
|
||||
{"fail2", args{p2}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := c.Store(tt.args.p); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollection_Find(t *testing.T) {
|
||||
c, err := generateCollection(10, 10)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
trim := func(s string) string {
|
||||
return strings.TrimLeft(s, "0")
|
||||
}
|
||||
toList := func(ps provisionerSlice) List {
|
||||
l := List{}
|
||||
for _, p := range ps {
|
||||
l = append(l, p.provisioner)
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cursor string
|
||||
limit int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want List
|
||||
want1 string
|
||||
}{
|
||||
{"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""},
|
||||
{"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""},
|
||||
{"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)},
|
||||
{"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""},
|
||||
{"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)},
|
||||
{"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)},
|
||||
{"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""},
|
||||
{"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := c.Find(tt.args.cursor, tt.args.limit)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Collection.Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
|||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
uri string
|
||||
keys jose.JSONWebKeySet
|
||||
keySet jose.JSONWebKeySet
|
||||
timer *time.Timer
|
||||
expiry time.Time
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func newKeyStore(uri string) (*keyStore, error) {
|
|||
}
|
||||
ks := &keyStore{
|
||||
uri: uri,
|
||||
keys: keys,
|
||||
keySet: keys,
|
||||
expiry: getExpirationTime(age),
|
||||
}
|
||||
ks.timer = time.AfterFunc(age, ks.reload)
|
||||
|
@ -54,7 +54,7 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
|
|||
ks.reload()
|
||||
ks.RLock()
|
||||
}
|
||||
keys = ks.keys.Key(kid)
|
||||
keys = ks.keySet.Key(kid)
|
||||
ks.RUnlock()
|
||||
return
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func (ks *keyStore) reload() {
|
|||
next = ks.nextReloadDuration(defaultCacheJitter / 2)
|
||||
} else {
|
||||
ks.Lock()
|
||||
ks.keys = keys
|
||||
ks.keySet = keys
|
||||
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||
ks.Unlock()
|
||||
next = ks.nextReloadDuration(age)
|
||||
|
|
|
@ -198,19 +198,27 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisi
|
|||
func (o *provisionerExtensionOption) Option(Options) x509util.WithOption {
|
||||
return func(p x509util.Profile) error {
|
||||
crt := p.Subject()
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: o.Type,
|
||||
Name: []byte(o.Name),
|
||||
CredentialID: []byte(o.CredentialID),
|
||||
})
|
||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error marshaling provisioner extension")
|
||||
return err
|
||||
}
|
||||
crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
})
|
||||
crt.ExtraExtensions = append(crt.ExtraExtensions, ext)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) {
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: typ,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(credentialID),
|
||||
})
|
||||
if err != nil {
|
||||
return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension")
|
||||
}
|
||||
return pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -124,7 +124,7 @@ func generateOIDC() (*OIDC, error) {
|
|||
JWKSetURI: "https://example.com/.well-known/jwks",
|
||||
},
|
||||
keyStore: &keyStore{
|
||||
keys: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}, nil
|
||||
|
|
Loading…
Reference in a new issue