diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index 98b69f9f..0998adb7 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -129,18 +129,15 @@ func (c *Collection) Store(p Interface) error { // Use the first 4 bytes (32bit) of the sum to insert the order // Using big endian format to get the strings sorted: // 0x00000000, 0x00000001, 0x00000002, ... - sum, err := provisionerSum(p) - if err != nil { - return err - } bi := make([]byte, 4) + sum := provisionerSum(p) binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len())) sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] - bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0 c.sorted = append(c.sorted, uidProvisioner{ provisioner: p, uid: hex.EncodeToString(sum), }) + sort.Sort(c.sorted) return nil } @@ -182,9 +179,9 @@ func loadProvisioner(m *sync.Map, key string) (Interface, bool) { // provisionerSum returns the SHA1 of the provisioners ID. From this we will // create the unique and sorted id. -func provisionerSum(p Interface) ([]byte, error) { +func provisionerSum(p Interface) []byte { sum := sha1.Sum([]byte(p.GetID())) - return sum[:], nil + return sum[:] } // matchesAudience returns true if A and B share at least one element. diff --git a/authority/provisioner/collection_test.go b/authority/provisioner/collection_test.go index 2948fd96..d4ff338e 100644 --- a/authority/provisioner/collection_test.go +++ b/authority/provisioner/collection_test.go @@ -1,79 +1,17 @@ package provisioner import ( + "crypto/x509" + "crypto/x509/pkix" "reflect" + "strings" "sync" "testing" "github.com/smallstep/assert" - "github.com/smallstep/cli/jose" ) -// func Test_newSortedProvisioners(t *testing.T) { -// provisioners := make(List, 20) -// for i := range provisioners { -// provisioners[i] = generateProvisioner(t) -// } - -// ps, err := newSortedProvisioners(provisioners) -// assert.FatalError(t, err) -// prev := "" -// for i, p := range ps { -// if p.uid < prev { -// t.Errorf("%s should be less that %s", p.uid, prev) -// } -// if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID { -// t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID) -// } -// prev = p.uid -// } -// } - -// func Test_provisionerSlice_Find(t *testing.T) { -// trim := func(s string) string { -// return strings.TrimLeft(s, "0") -// } -// provisioners := make([]*Provisioner, 20) -// for i := range provisioners { -// provisioners[i] = generateProvisioner(t) -// } -// ps, err := newSortedProvisioners(provisioners) -// assert.FatalError(t, err) - -// type args struct { -// cursor string -// limit int -// } -// tests := []struct { -// name string -// p provisionerSlice -// args args -// want []*JWK -// want1 string -// }{ -// {"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""}, -// {"0 to 19", ps, args{"", 20}, provisioners[0:20], ""}, -// {"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)}, -// {"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""}, -// {"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)}, -// {"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)}, -// {"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""}, -// {"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""}, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit) -// if !reflect.DeepEqual(got, tt.want) { -// t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want) -// } -// if got1 != tt.want1 { -// t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1) -// } -// }) -// } -// } - func TestCollection_Load(t *testing.T) { p, err := generateJWK() assert.FatalError(t, err) @@ -141,11 +79,16 @@ func TestCollection_LoadByToken(t *testing.T) { t2, c2, err := parseToken(token) assert.FatalError(t, err) - token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keys.Keys[0]) + token, err = generateSimpleToken(p3.configuration.Issuer, p3.ClientID, &p3.keyStore.keySet.Keys[0]) assert.FatalError(t, err) t3, c3, err := parseToken(token) assert.FatalError(t, err) + token, err = generateSimpleToken(p3.configuration.Issuer, "string", &p3.keyStore.keySet.Keys[0]) + assert.FatalError(t, err) + t4, c4, err := parseToken(token) + assert.FatalError(t, err) + type fields struct { byID *sync.Map audiences []string @@ -164,6 +107,7 @@ func TestCollection_LoadByToken(t *testing.T) { {"ok1", fields{byID, testAudiences}, args{t1, c1}, p1, true}, {"ok2", fields{byID, testAudiences}, args{t2, c2}, p2, true}, {"ok3", fields{byID, testAudiences}, args{t3, c3}, p3, true}, + {"bad", fields{byID, testAudiences}, args{t4, c4}, nil, false}, {"fail", fields{byID, []string{"https://foo"}}, args{t1, c1}, nil, false}, } for _, tt := range tests { @@ -182,3 +126,188 @@ func TestCollection_LoadByToken(t *testing.T) { }) } } + +func TestCollection_LoadByCertificate(t *testing.T) { + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + + byID := new(sync.Map) + byID.Store(p1.GetID(), p1) + byID.Store(p2.GetID(), p2) + + ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID) + assert.FatalError(t, err) + ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID) + assert.FatalError(t, err) + notFoundExt, err := createProvisionerExtension(1, "foo", "bar") + assert.FatalError(t, err) + + ok1Cert := &x509.Certificate{ + Extensions: []pkix.Extension{ok1Ext}, + } + ok2Cert := &x509.Certificate{ + Extensions: []pkix.Extension{ok2Ext}, + } + notFoundCert := &x509.Certificate{ + Extensions: []pkix.Extension{notFoundExt}, + } + badCert := &x509.Certificate{ + Extensions: []pkix.Extension{ + {Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, + }, + } + + type fields struct { + byID *sync.Map + audiences []string + } + type args struct { + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + want Interface + want1 bool + }{ + {"ok1", fields{byID, testAudiences}, args{ok1Cert}, p1, true}, + {"ok2", fields{byID, testAudiences}, args{ok2Cert}, p2, true}, + {"noExtension", fields{byID, testAudiences}, args{&x509.Certificate{}}, &noop{}, true}, + {"notFound", fields{byID, testAudiences}, args{notFoundCert}, nil, false}, + {"badCert", fields{byID, testAudiences}, args{badCert}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Collection{ + byID: tt.fields.byID, + audiences: tt.fields.audiences, + } + got, got1 := c.LoadByCertificate(tt.args.cert) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.LoadByCertificate() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.LoadByCertificate() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_LoadEncryptedKey(t *testing.T) { + c := NewCollection(testAudiences) + p1, err := generateJWK() + assert.FatalError(t, err) + assert.FatalError(t, c.Store(p1)) + p2, err := generateOIDC() + assert.FatalError(t, err) + assert.FatalError(t, c.Store(p2)) + + // Add oidc in byKey. + // It should not happen. + p2KeyID := p2.keyStore.keySet.Keys[0].KeyID + c.byKey.Store(p2KeyID, p2) + + type args struct { + keyID string + } + tests := []struct { + name string + args args + want string + want1 bool + }{ + {"ok", args{p1.Key.KeyID}, p1.EncryptedKey, true}, + {"oidc", args{p2KeyID}, "", false}, + {"notFound", args{"not-found"}, "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := c.LoadEncryptedKey(tt.args.keyID) + if got != tt.want { + t.Errorf("Collection.LoadEncryptedKey() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.LoadEncryptedKey() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestCollection_Store(t *testing.T) { + c := NewCollection(testAudiences) + p1, err := generateJWK() + assert.FatalError(t, err) + p2, err := generateOIDC() + assert.FatalError(t, err) + + type args struct { + p Interface + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok1", args{p1}, false}, + {"ok2", args{p2}, false}, + {"fail1", args{p1}, true}, + {"fail2", args{p2}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := c.Store(tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Collection.Store() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCollection_Find(t *testing.T) { + c, err := generateCollection(10, 10) + assert.FatalError(t, err) + + trim := func(s string) string { + return strings.TrimLeft(s, "0") + } + toList := func(ps provisionerSlice) List { + l := List{} + for _, p := range ps { + l = append(l, p.provisioner) + } + return l + } + + type args struct { + cursor string + limit int + } + tests := []struct { + name string + args args + want List + want1 string + }{ + {"all", args{"", DefaultProvisionersMax}, toList(c.sorted[0:20]), ""}, + {"0 to 19", args{"", 20}, toList(c.sorted[0:20]), ""}, + {"0 to 9", args{"", 10}, toList(c.sorted[0:10]), trim(c.sorted[10].uid)}, + {"9 to 19", args{trim(c.sorted[10].uid), 10}, toList(c.sorted[10:20]), ""}, + {"1", args{trim(c.sorted[1].uid), 1}, toList(c.sorted[1:2]), trim(c.sorted[2].uid)}, + {"1 to 5", args{trim(c.sorted[1].uid), 4}, toList(c.sorted[1:5]), trim(c.sorted[5].uid)}, + {"defaultLimit", args{"", 0}, toList(c.sorted[0:20]), ""}, + {"overTheLimit", args{"", DefaultProvisionersMax + 1}, toList(c.sorted[0:20]), ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := c.Find(tt.args.cursor, tt.args.limit) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection.Find() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Collection.Find() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index 7e49b8d7..f00bf772 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -23,7 +23,7 @@ var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)") type keyStore struct { sync.RWMutex uri string - keys jose.JSONWebKeySet + keySet jose.JSONWebKeySet timer *time.Timer expiry time.Time } @@ -35,7 +35,7 @@ func newKeyStore(uri string) (*keyStore, error) { } ks := &keyStore{ uri: uri, - keys: keys, + keySet: keys, expiry: getExpirationTime(age), } ks.timer = time.AfterFunc(age, ks.reload) @@ -54,7 +54,7 @@ func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) { ks.reload() ks.RLock() } - keys = ks.keys.Key(kid) + keys = ks.keySet.Key(kid) ks.RUnlock() return } @@ -66,7 +66,7 @@ func (ks *keyStore) reload() { next = ks.nextReloadDuration(defaultCacheJitter / 2) } else { ks.Lock() - ks.keys = keys + ks.keySet = keys ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() ks.Unlock() next = ks.nextReloadDuration(age) diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 33c4991b..d3448ef8 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -198,19 +198,27 @@ func newProvisionerExtensionOption(typ Type, name, credentialID string) *provisi func (o *provisionerExtensionOption) Option(Options) x509util.WithOption { return func(p x509util.Profile) error { crt := p.Subject() - b, err := asn1.Marshal(stepProvisionerASN1{ - Type: o.Type, - Name: []byte(o.Name), - CredentialID: []byte(o.CredentialID), - }) + ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID) if err != nil { - return errors.Wrapf(err, "error marshaling provisioner extension") + return err } - crt.ExtraExtensions = append(crt.ExtraExtensions, pkix.Extension{ - Id: stepOIDProvisioner, - Critical: false, - Value: b, - }) + crt.ExtraExtensions = append(crt.ExtraExtensions, ext) return nil } } + +func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extension, error) { + b, err := asn1.Marshal(stepProvisionerASN1{ + Type: typ, + Name: []byte(name), + CredentialID: []byte(credentialID), + }) + if err != nil { + return pkix.Extension{}, errors.Wrapf(err, "error marshaling provisioner extension") + } + return pkix.Extension{ + Id: stepOIDProvisioner, + Critical: false, + Value: b, + }, nil +} diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 364555fc..b8a66de2 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -124,7 +124,7 @@ func generateOIDC() (*OIDC, error) { JWKSetURI: "https://example.com/.well-known/jwks", }, keyStore: &keyStore{ - keys: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, + keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, expiry: time.Now().Add(24 * time.Hour), }, }, nil