parent
f12e2dedd5
commit
e66272d6f0
3 changed files with 96 additions and 7 deletions
|
@ -18,7 +18,7 @@ const (
|
||||||
defaultCacheJitter = 1 * time.Hour
|
defaultCacheJitter = 1 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)")
|
||||||
|
|
||||||
type keyStore struct {
|
type keyStore struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
@ -81,13 +81,13 @@ func (ks *keyStore) reload() {
|
||||||
ks.Unlock()
|
ks.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nextReloadDuration would return the duration for the next rotation. If age is
|
||||||
|
// 0 it will randomly rotate between 0-12 hours, but every time we call to Get
|
||||||
|
// it will automatically rotate.
|
||||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||||
n := rand.Int63n(int64(ks.jitter))
|
n := rand.Int63n(int64(ks.jitter))
|
||||||
age -= time.Duration(n)
|
age -= time.Duration(n)
|
||||||
if age < 0 {
|
return abs(age)
|
||||||
age = 0
|
|
||||||
}
|
|
||||||
return age
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||||
|
@ -125,6 +125,10 @@ func getCacheJitter(age time.Duration) time.Duration {
|
||||||
switch {
|
switch {
|
||||||
case age > time.Hour:
|
case age > time.Hour:
|
||||||
return defaultCacheJitter
|
return defaultCacheJitter
|
||||||
|
case age == 0:
|
||||||
|
// Avoids a 0 jitter. The duration is not important as it will rotate
|
||||||
|
// automatically on each Get request.
|
||||||
|
return defaultCacheJitter
|
||||||
default:
|
default:
|
||||||
return age / 3
|
return age / 3
|
||||||
}
|
}
|
||||||
|
@ -133,3 +137,13 @@ func getCacheJitter(age time.Duration) time.Duration {
|
||||||
func getExpirationTime(age time.Duration) time.Time {
|
func getExpirationTime(age time.Duration) time.Time {
|
||||||
return time.Now().Truncate(time.Second).Add(age)
|
return time.Now().Truncate(time.Second).Add(age)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// abs returns the absolute value of n using the two's complement form.
|
||||||
|
//
|
||||||
|
// It will overflow with math.MinInt64 in the same way a branching version
|
||||||
|
// would, this is not a problem because maxAgeRegex will block negative numbers
|
||||||
|
// and the logic will never produce that number.
|
||||||
|
func abs(n time.Duration) time.Duration {
|
||||||
|
y := n >> 63
|
||||||
|
return (n ^ y) - y
|
||||||
|
}
|
||||||
|
|
|
@ -91,6 +91,49 @@ func Test_keyStore(t *testing.T) {
|
||||||
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_keyStore_noCache(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
ks, err := newKeyStore(srv.URL + "/no-cache")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
ks.RLock()
|
||||||
|
keySet1 := ks.keySet
|
||||||
|
ks.RUnlock()
|
||||||
|
// The keys will rotate on Get.
|
||||||
|
// So we won't be able to find the cached ones
|
||||||
|
assert.Len(t, 2, keySet1.Keys)
|
||||||
|
assert.Len(t, 0, ks.Get(keySet1.Keys[0].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get("foobar"))
|
||||||
|
|
||||||
|
ks.RLock()
|
||||||
|
keySet2 := ks.keySet
|
||||||
|
ks.RUnlock()
|
||||||
|
if reflect.DeepEqual(keySet1, keySet2) {
|
||||||
|
t.Error("keyStore did not rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The keys will rotate on Get.
|
||||||
|
// So we won't be able to find the cached ones
|
||||||
|
assert.Len(t, 2, keySet2.Keys)
|
||||||
|
assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get("foobar"))
|
||||||
|
|
||||||
|
// Check hits
|
||||||
|
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hits := struct {
|
||||||
|
Hits int `json:"hits"`
|
||||||
|
}{}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||||
|
}
|
||||||
|
|
||||||
func Test_keyStore_Get(t *testing.T) {
|
func Test_keyStore_Get(t *testing.T) {
|
||||||
srv := generateJWKServer(2)
|
srv := generateJWKServer(2)
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
@ -119,3 +162,31 @@ func Test_keyStore_Get(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_abs(t *testing.T) {
|
||||||
|
maxInt64 := time.Duration(1<<63 - 1)
|
||||||
|
minInt64 := time.Duration(-1 << 63)
|
||||||
|
type args struct {
|
||||||
|
n time.Duration
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{"ok", args{0}, 0},
|
||||||
|
{"ok", args{-time.Hour}, time.Hour},
|
||||||
|
{"ok", args{time.Hour}, time.Hour},
|
||||||
|
{"ok maxInt64", args{maxInt64}, maxInt64},
|
||||||
|
{"ok minInt64 + 1", args{minInt64 + 1}, maxInt64},
|
||||||
|
{"overflow on minInt64", args{minInt64}, minInt64},
|
||||||
|
{"overflow on minInt64", args{minInt64}, -minInt64},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := abs(tt.args.n); got != tt.want {
|
||||||
|
t.Errorf("abs() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -658,7 +658,7 @@ func generateJWKServer(n int) *httptest.Server {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
defaultKeySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||||
srv := httptest.NewUnstartedServer(nil)
|
srv := httptest.NewUnstartedServer(nil)
|
||||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
hits.Hits++
|
hits.Hits++
|
||||||
|
@ -670,9 +670,13 @@ func generateJWKServer(n int) *httptest.Server {
|
||||||
case "/openid-configuration", "/.well-known/openid-configuration":
|
case "/openid-configuration", "/.well-known/openid-configuration":
|
||||||
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
||||||
case "/random":
|
case "/random":
|
||||||
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||||
w.Header().Add("Cache-Control", "max-age=5")
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
writeJSON(w, getPublic(keySet))
|
writeJSON(w, getPublic(keySet))
|
||||||
|
case "/no-cache":
|
||||||
|
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||||
|
w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
|
||||||
|
writeJSON(w, getPublic(keySet))
|
||||||
case "/private":
|
case "/private":
|
||||||
writeJSON(w, defaultKeySet)
|
writeJSON(w, defaultKeySet)
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Reference in a new issue