Add tests for keyStore.
This commit is contained in:
parent
2a5430fee1
commit
cf2dba3efb
4 changed files with 202 additions and 8 deletions
|
@ -120,7 +120,7 @@ func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
|||
return errors.New("duration cannot be nil")
|
||||
}
|
||||
if err = json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
||||
return errors.Wrapf(err, "error unmarshaling %s", data)
|
||||
}
|
||||
if _d, err = time.ParseDuration(s); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||
|
|
|
@ -26,6 +26,7 @@ type keyStore struct {
|
|||
keySet jose.JSONWebKeySet
|
||||
timer *time.Timer
|
||||
expiry time.Time
|
||||
jitter time.Duration
|
||||
}
|
||||
|
||||
func newKeyStore(uri string) (*keyStore, error) {
|
||||
|
@ -37,8 +38,10 @@ func newKeyStore(uri string) (*keyStore, error) {
|
|||
uri: uri,
|
||||
keySet: keys,
|
||||
expiry: getExpirationTime(age),
|
||||
jitter: getCacheJitter(age),
|
||||
}
|
||||
ks.timer = time.AfterFunc(age, ks.reload)
|
||||
next := ks.nextReloadDuration(age)
|
||||
ks.timer = time.AfterFunc(next, ks.reload)
|
||||
return ks, nil
|
||||
}
|
||||
|
||||
|
@ -63,13 +66,14 @@ func (ks *keyStore) reload() {
|
|||
var next time.Duration
|
||||
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||
if err != nil {
|
||||
next = ks.nextReloadDuration(defaultCacheJitter / 2)
|
||||
next = ks.nextReloadDuration(ks.jitter / 2)
|
||||
} else {
|
||||
ks.Lock()
|
||||
ks.keySet = keys
|
||||
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||
ks.Unlock()
|
||||
ks.expiry = getExpirationTime(age)
|
||||
ks.jitter = getCacheJitter(age)
|
||||
next = ks.nextReloadDuration(age)
|
||||
ks.Unlock()
|
||||
}
|
||||
|
||||
ks.Lock()
|
||||
|
@ -78,7 +82,7 @@ func (ks *keyStore) reload() {
|
|||
}
|
||||
|
||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||
n := rand.Int63n(int64(defaultCacheJitter))
|
||||
n := rand.Int63n(int64(ks.jitter))
|
||||
age -= time.Duration(n)
|
||||
if age < 0 {
|
||||
age = 0
|
||||
|
@ -117,6 +121,15 @@ func getCacheAge(cacheControl string) time.Duration {
|
|||
return age
|
||||
}
|
||||
|
||||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
||||
func getCacheJitter(age time.Duration) time.Duration {
|
||||
switch {
|
||||
case age > time.Hour:
|
||||
return defaultCacheJitter
|
||||
default:
|
||||
return age / 3
|
||||
}
|
||||
}
|
||||
|
||||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Round(time.Second).Add(age)
|
||||
}
|
||||
|
|
122
authority/provisioner/keystore_test.go
Normal file
122
authority/provisioner/keystore_test.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func Test_newKeyStore(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
uri string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want jose.JSONWebKeySet
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{srv.URL}, ks.keySet, false},
|
||||
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := newKeyStore(tt.args.uri)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
if !reflect.DeepEqual(got.keySet, tt.want) {
|
||||
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
|
||||
}
|
||||
got.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_keyStore(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
ks, err := newKeyStore(srv.URL + "/random")
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
ks.RLock()
|
||||
keySet1 := ks.keySet
|
||||
ks.RUnlock()
|
||||
// Check contents
|
||||
assert.Len(t, 2, keySet1.Keys)
|
||||
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
|
||||
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Wait for rotation
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
ks.RLock()
|
||||
keySet2 := ks.keySet
|
||||
ks.RUnlock()
|
||||
if reflect.DeepEqual(keySet1, keySet2) {
|
||||
t.Error("keyStore did not rotated")
|
||||
}
|
||||
|
||||
// Check contents
|
||||
assert.Len(t, 2, keySet2.Keys)
|
||||
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
|
||||
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Check hits
|
||||
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||
assert.FatalError(t, err)
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
defer resp.Body.Close()
|
||||
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||
assert.FatalError(t, err)
|
||||
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||
}
|
||||
|
||||
func Test_keyStore_Get(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
ks, err := newKeyStore(srv.URL)
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
|
||||
type args struct {
|
||||
kid string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
ks *keyStore
|
||||
args args
|
||||
wantKeys []jose.JSONWebKey
|
||||
}{
|
||||
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
|
||||
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
|
||||
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
println(tt.name)
|
||||
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
|
||||
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,6 +4,9 @@ import (
|
|||
"crypto"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
|
@ -17,6 +20,15 @@ var testAudiences = []string{
|
|||
"https://ca.smallsteomcom/1.0/sign",
|
||||
}
|
||||
|
||||
func must(args ...interface{}) []interface{} {
|
||||
if l := len(args); l > 0 && args[l-1] != nil {
|
||||
if err, ok := args[l-1].(error); ok {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
if err != nil {
|
||||
|
@ -30,6 +42,18 @@ func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
|||
return jwk, nil
|
||||
}
|
||||
|
||||
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
|
||||
var keySet jose.JSONWebKeySet
|
||||
for i := 0; i < n; i++ {
|
||||
key, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return jose.JSONWebKeySet{}, err
|
||||
}
|
||||
keySet.Keys = append(keySet.Keys, key.Public())
|
||||
}
|
||||
return keySet, nil
|
||||
}
|
||||
|
||||
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
||||
b, err := json.Marshal(jwk)
|
||||
if err != nil {
|
||||
|
@ -206,3 +230,38 @@ func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
|||
}
|
||||
return tok, claims, nil
|
||||
}
|
||||
|
||||
func generateJWKServer(n int) *httptest.Server {
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(b)
|
||||
}
|
||||
// keySet, err := generateJSONWebKeySet(n)
|
||||
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits.Hits++
|
||||
switch r.RequestURI {
|
||||
case "/error":
|
||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||
case "/hits":
|
||||
writeJSON(w, hits)
|
||||
case "/random":
|
||||
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
fmt.Println(keySet)
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, keySet)
|
||||
default:
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, defaultKeySet)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue