diff --git a/authority/provisioner/keystore.go b/authority/provisioner/keystore.go index 2f11114a..2cc2562c 100644 --- a/authority/provisioner/keystore.go +++ b/authority/provisioner/keystore.go @@ -18,7 +18,7 @@ const ( defaultCacheJitter = 1 * time.Hour ) -var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)") +var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)") type keyStore struct { sync.RWMutex @@ -81,13 +81,13 @@ func (ks *keyStore) reload() { ks.Unlock() } +// nextReloadDuration would return the duration for the next rotation. If age is +// 0 it will randomly rotate between 0-12 hours, but every time we call to Get +// it will automatically rotate. func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { n := rand.Int63n(int64(ks.jitter)) age -= time.Duration(n) - if age < 0 { - age = 0 - } - return age + return abs(age) } func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) { @@ -125,6 +125,10 @@ func getCacheJitter(age time.Duration) time.Duration { switch { case age > time.Hour: return defaultCacheJitter + case age == 0: + // Avoids a 0 jitter. The duration is not important as it will rotate + // automatically on each Get request. + return defaultCacheJitter default: return age / 3 } @@ -133,3 +137,13 @@ func getCacheJitter(age time.Duration) time.Duration { func getExpirationTime(age time.Duration) time.Time { return time.Now().Truncate(time.Second).Add(age) } + +// abs returns the absolute value of n using the two's complement form. +// +// It will overflow with math.MinInt64 in the same way a branching version +// would, this is not a problem because maxAgeRegex will block negative numbers +// and the logic will never produce that number. +func abs(n time.Duration) time.Duration { + y := n >> 63 + return (n ^ y) - y +} diff --git a/authority/provisioner/keystore_test.go b/authority/provisioner/keystore_test.go index 22d5be75..63c29a3b 100644 --- a/authority/provisioner/keystore_test.go +++ b/authority/provisioner/keystore_test.go @@ -91,6 +91,49 @@ func Test_keyStore(t *testing.T) { assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits)) } +func Test_keyStore_noCache(t *testing.T) { + srv := generateJWKServer(2) + defer srv.Close() + + ks, err := newKeyStore(srv.URL + "/no-cache") + assert.FatalError(t, err) + defer ks.Close() + ks.RLock() + keySet1 := ks.keySet + ks.RUnlock() + // The keys will rotate on Get. + // So we won't be able to find the cached ones + assert.Len(t, 2, keySet1.Keys) + assert.Len(t, 0, ks.Get(keySet1.Keys[0].KeyID)) + assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID)) + assert.Len(t, 0, ks.Get("foobar")) + + ks.RLock() + keySet2 := ks.keySet + ks.RUnlock() + if reflect.DeepEqual(keySet1, keySet2) { + t.Error("keyStore did not rotated") + } + + // The keys will rotate on Get. + // So we won't be able to find the cached ones + assert.Len(t, 2, keySet2.Keys) + assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID)) + assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID)) + assert.Len(t, 0, ks.Get("foobar")) + + // Check hits + resp, err := srv.Client().Get(srv.URL + "/hits") + assert.FatalError(t, err) + hits := struct { + Hits int `json:"hits"` + }{} + defer resp.Body.Close() + err = json.NewDecoder(resp.Body).Decode(&hits) + assert.FatalError(t, err) + assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits)) +} + func Test_keyStore_Get(t *testing.T) { srv := generateJWKServer(2) defer srv.Close() @@ -119,3 +162,31 @@ func Test_keyStore_Get(t *testing.T) { }) } } + +func Test_abs(t *testing.T) { + maxInt64 := time.Duration(1<<63 - 1) + minInt64 := time.Duration(-1 << 63) + type args struct { + n time.Duration + } + tests := []struct { + name string + args args + want time.Duration + }{ + {"ok", args{0}, 0}, + {"ok", args{-time.Hour}, time.Hour}, + {"ok", args{time.Hour}, time.Hour}, + {"ok maxInt64", args{maxInt64}, maxInt64}, + {"ok minInt64 + 1", args{minInt64 + 1}, maxInt64}, + {"overflow on minInt64", args{minInt64}, minInt64}, + {"overflow on minInt64", args{minInt64}, -minInt64}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := abs(tt.args.n); got != tt.want { + t.Errorf("abs() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/provisioner/utils_test.go b/authority/provisioner/utils_test.go index 40f4ab05..d74db2b6 100644 --- a/authority/provisioner/utils_test.go +++ b/authority/provisioner/utils_test.go @@ -658,7 +658,7 @@ func generateJWKServer(n int) *httptest.Server { return ret } - defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + defaultKeySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) srv := httptest.NewUnstartedServer(nil) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { hits.Hits++ @@ -670,9 +670,13 @@ func generateJWKServer(n int) *httptest.Server { case "/openid-configuration", "/.well-known/openid-configuration": writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"}) case "/random": - keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) + keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) w.Header().Add("Cache-Control", "max-age=5") writeJSON(w, getPublic(keySet)) + case "/no-cache": + keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet) + w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + writeJSON(w, getPublic(keySet)) case "/private": writeJSON(w, defaultKeySet) default: