From cf2dba3efbef5c9ce091f884cd89373e4392da0a Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 8 Mar 2019 15:08:18 -0800 Subject: [PATCH] Add tests for keyStore. --- authority/provisioner/claims.go | 2 +- authority/provisioner/keystore.go | 27 ++++-- authority/provisioner/keystore_test.go | 122 +++++++++++++++++++++++++ authority/provisioner/utils_test.go | 59 ++++++++++++ 4 files changed, 202 insertions(+), 8 deletions(-) create mode 100644 authority/provisioner/keystore_test.go diff --git a/authority/provisioner/claims.go b/authority/provisioner/claims.go index 119ed77b..09317529 100644 --- a/authority/provisioner/claims.go +++ b/authority/provisioner/claims.go @@ -120,7 +120,7 @@ func (d *Duration) UnmarshalJSON(data []byte) (err error) { return errors.New("duration cannot be nil") } if err = json.Unmarshal(data, &s); err != nil { - return errors.Wrapf(err, "error unmarshalling %s", data) + return errors.Wrapf(err, "error unmarshaling %s", data) } if _d, err = time.ParseDuration(s); err != nil { return errors.Wrapf(err, "error parsing %s as duration", s) diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index f00bf772..2c03b7ba 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -26,6 +26,7 @@ type keyStore struct { keySet jose.JSONWebKeySet timer *time.Timer expiry time.Time + jitter time.Duration } func newKeyStore(uri string) (*keyStore, error) { @@ -37,8 +38,10 @@ func newKeyStore(uri string) (*keyStore, error) { uri: uri, keySet: keys, expiry: getExpirationTime(age), + jitter: getCacheJitter(age), } - ks.timer = time.AfterFunc(age, ks.reload) + next := ks.nextReloadDuration(age) + ks.timer = time.AfterFunc(next, ks.reload) return ks, nil } @@ -63,13 +66,14 @@ func (ks *keyStore) reload() { var next time.Duration keys, age, err := getKeysFromJWKsURI(ks.uri) if err != nil { - next = ks.nextReloadDuration(defaultCacheJitter / 2) + next = ks.nextReloadDuration(ks.jitter / 2) } else { ks.Lock() ks.keySet = keys - ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() - ks.Unlock() + ks.expiry = getExpirationTime(age) + ks.jitter = getCacheJitter(age) next = ks.nextReloadDuration(age) + ks.Unlock() } ks.Lock() @@ -78,7 +82,7 @@ func (ks *keyStore) reload() { } func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { - n := rand.Int63n(int64(defaultCacheJitter)) + n := rand.Int63n(int64(ks.jitter)) age -= time.Duration(n) if age < 0 { age = 0 @@ -117,6 +121,15 @@ func getCacheAge(cacheControl string) time.Duration { return age } -func getExpirationTime(age time.Duration) time.Time { - return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() +func getCacheJitter(age time.Duration) time.Duration { + switch { + case age > time.Hour: + return defaultCacheJitter + default: + return age / 3 + } +} + +func getExpirationTime(age time.Duration) time.Time { + return time.Now().Round(time.Second).Add(age) } diff --git a/authority/provisioner/keystore_test.go b/authority/provisioner/keystore_test.go new file mode 100644 index 00000000..f392c01d --- /dev/null +++ b/authority/provisioner/keystore_test.go @@ -0,0 +1,122 @@ +package provisioner + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "github.com/smallstep/assert" + "github.com/smallstep/cli/jose" +) + +func Test_newKeyStore(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + ks, err := newKeyStore(srv.URL) + assert.FatalError(t, err) + defer ks.Close() + + type args struct { + uri string + } + tests := []struct { + name string + args args + want jose.JSONWebKeySet + wantErr bool + }{ + {"ok", args{srv.URL}, ks.keySet, false}, + {"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newKeyStore(tt.args.uri) + if (err != nil) != tt.wantErr { + t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err == nil { + if !reflect.DeepEqual(got.keySet, tt.want) { + t.Errorf("newKeyStore() = %v, want %v", got, tt.want) + } + got.Close() + } + }) + } +} + +func Test_keyStore(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + + ks, err := newKeyStore(srv.URL + "/random") + assert.FatalError(t, err) + defer ks.Close() + ks.RLock() + keySet1 := ks.keySet + ks.RUnlock() + // Check contents + assert.Len(t, 2, keySet1.Keys) + assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID)) + assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID)) + assert.Len(t, 0, ks.Get("foobar")) + + // Wait for rotation + time.Sleep(5 * time.Second) + + ks.RLock() + keySet2 := ks.keySet + ks.RUnlock() + if reflect.DeepEqual(keySet1, keySet2) { + t.Error("keyStore did not rotated") + } + + // Check contents + assert.Len(t, 2, keySet2.Keys) + assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID)) + assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID)) + assert.Len(t, 0, ks.Get("foobar")) + + // Check hits + resp, err := srv.Client().Get(srv.URL + "/hits") + assert.FatalError(t, err) + hits := struct { + Hits int `json:"hits"` + }{} + defer resp.Body.Close() + err = json.NewDecoder(resp.Body).Decode(&hits) + assert.FatalError(t, err) + assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits)) +} + +func Test_keyStore_Get(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + ks, err := newKeyStore(srv.URL) + assert.FatalError(t, err) + defer ks.Close() + + type args struct { + kid string + } + tests := []struct { + name string + ks *keyStore + args args + wantKeys []jose.JSONWebKey + }{ + {"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}}, + {"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}}, + {"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + println(tt.name) + if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) { + t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys) + } + }) + } +} diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index b8a66de2..458f8111 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -4,6 +4,9 @@ import ( "crypto" "encoding/hex" "encoding/json" + "fmt" + "net/http" + "net/http/httptest" "time" "github.com/smallstep/cli/crypto/randutil" @@ -17,6 +20,15 @@ var testAudiences = []string{ "https://ca.smallsteomcom/1.0/sign", } +func must(args ...interface{}) []interface{} { + if l := len(args); l > 0 && args[l-1] != nil { + if err, ok := args[l-1].(error); ok { + panic(err) + } + } + return args +} + func generateJSONWebKey() (*jose.JSONWebKey, error) { jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) if err != nil { @@ -30,6 +42,18 @@ func generateJSONWebKey() (*jose.JSONWebKey, error) { return jwk, nil } +func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) { + var keySet jose.JSONWebKeySet + for i := 0; i < n; i++ { + key, err := generateJSONWebKey() + if err != nil { + return jose.JSONWebKeySet{}, err + } + keySet.Keys = append(keySet.Keys, key.Public()) + } + return keySet, nil +} + func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) { b, err := json.Marshal(jwk) if err != nil { @@ -206,3 +230,38 @@ func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) { } return tok, claims, nil } + +func generateJWKServer(n int) *httptest.Server { + hits := struct { + Hits int `json:"hits"` + }{} + writeJSON := func(w http.ResponseWriter, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) + } + // keySet, err := generateJSONWebKeySet(n) + defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits.Hits++ + switch r.RequestURI { + case "/error": + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + case "/hits": + writeJSON(w, hits) + case "/random": + keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + fmt.Println(keySet) + w.Header().Add("Cache-Control", "max-age=5") + writeJSON(w, keySet) + default: + w.Header().Add("Cache-Control", "max-age=5") + writeJSON(w, defaultKeySet) + } + })) +}