forked from TrueCloudLab/certificates
parent
f12e2dedd5
commit
e66272d6f0
3 changed files with 96 additions and 7 deletions
|
@ -18,7 +18,7 @@ const (
|
|||
defaultCacheJitter = 1 * time.Hour
|
||||
)
|
||||
|
||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)")
|
||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)")
|
||||
|
||||
type keyStore struct {
|
||||
sync.RWMutex
|
||||
|
@ -81,13 +81,13 @@ func (ks *keyStore) reload() {
|
|||
ks.Unlock()
|
||||
}
|
||||
|
||||
// nextReloadDuration would return the duration for the next rotation. If age is
|
||||
// 0 it will randomly rotate between 0-12 hours, but every time we call to Get
|
||||
// it will automatically rotate.
|
||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||
n := rand.Int63n(int64(ks.jitter))
|
||||
age -= time.Duration(n)
|
||||
if age < 0 {
|
||||
age = 0
|
||||
}
|
||||
return age
|
||||
return abs(age)
|
||||
}
|
||||
|
||||
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
|
||||
|
@ -125,6 +125,10 @@ func getCacheJitter(age time.Duration) time.Duration {
|
|||
switch {
|
||||
case age > time.Hour:
|
||||
return defaultCacheJitter
|
||||
case age == 0:
|
||||
// Avoids a 0 jitter. The duration is not important as it will rotate
|
||||
// automatically on each Get request.
|
||||
return defaultCacheJitter
|
||||
default:
|
||||
return age / 3
|
||||
}
|
||||
|
@ -133,3 +137,13 @@ func getCacheJitter(age time.Duration) time.Duration {
|
|||
func getExpirationTime(age time.Duration) time.Time {
|
||||
return time.Now().Truncate(time.Second).Add(age)
|
||||
}
|
||||
|
||||
// abs returns the absolute value of n using the two's complement form.
|
||||
//
|
||||
// It will overflow with math.MinInt64 in the same way a branching version
|
||||
// would, this is not a problem because maxAgeRegex will block negative numbers
|
||||
// and the logic will never produce that number.
|
||||
func abs(n time.Duration) time.Duration {
|
||||
y := n >> 63
|
||||
return (n ^ y) - y
|
||||
}
|
||||
|
|
|
@ -91,6 +91,49 @@ func Test_keyStore(t *testing.T) {
|
|||
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||
}
|
||||
|
||||
func Test_keyStore_noCache(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
||||
ks, err := newKeyStore(srv.URL + "/no-cache")
|
||||
assert.FatalError(t, err)
|
||||
defer ks.Close()
|
||||
ks.RLock()
|
||||
keySet1 := ks.keySet
|
||||
ks.RUnlock()
|
||||
// The keys will rotate on Get.
|
||||
// So we won't be able to find the cached ones
|
||||
assert.Len(t, 2, keySet1.Keys)
|
||||
assert.Len(t, 0, ks.Get(keySet1.Keys[0].KeyID))
|
||||
assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
ks.RLock()
|
||||
keySet2 := ks.keySet
|
||||
ks.RUnlock()
|
||||
if reflect.DeepEqual(keySet1, keySet2) {
|
||||
t.Error("keyStore did not rotated")
|
||||
}
|
||||
|
||||
// The keys will rotate on Get.
|
||||
// So we won't be able to find the cached ones
|
||||
assert.Len(t, 2, keySet2.Keys)
|
||||
assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID))
|
||||
assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID))
|
||||
assert.Len(t, 0, ks.Get("foobar"))
|
||||
|
||||
// Check hits
|
||||
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||
assert.FatalError(t, err)
|
||||
hits := struct {
|
||||
Hits int `json:"hits"`
|
||||
}{}
|
||||
defer resp.Body.Close()
|
||||
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||
assert.FatalError(t, err)
|
||||
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||
}
|
||||
|
||||
func Test_keyStore_Get(t *testing.T) {
|
||||
srv := generateJWKServer(2)
|
||||
defer srv.Close()
|
||||
|
@ -119,3 +162,31 @@ func Test_keyStore_Get(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_abs(t *testing.T) {
|
||||
maxInt64 := time.Duration(1<<63 - 1)
|
||||
minInt64 := time.Duration(-1 << 63)
|
||||
type args struct {
|
||||
n time.Duration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want time.Duration
|
||||
}{
|
||||
{"ok", args{0}, 0},
|
||||
{"ok", args{-time.Hour}, time.Hour},
|
||||
{"ok", args{time.Hour}, time.Hour},
|
||||
{"ok maxInt64", args{maxInt64}, maxInt64},
|
||||
{"ok minInt64 + 1", args{minInt64 + 1}, maxInt64},
|
||||
{"overflow on minInt64", args{minInt64}, minInt64},
|
||||
{"overflow on minInt64", args{minInt64}, -minInt64},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := abs(tt.args.n); got != tt.want {
|
||||
t.Errorf("abs() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -658,7 +658,7 @@ func generateJWKServer(n int) *httptest.Server {
|
|||
return ret
|
||||
}
|
||||
|
||||
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
defaultKeySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||
srv := httptest.NewUnstartedServer(nil)
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hits.Hits++
|
||||
|
@ -670,9 +670,13 @@ func generateJWKServer(n int) *httptest.Server {
|
|||
case "/openid-configuration", "/.well-known/openid-configuration":
|
||||
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
|
||||
case "/random":
|
||||
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(keySet))
|
||||
case "/no-cache":
|
||||
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
|
||||
w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
|
||||
writeJSON(w, getPublic(keySet))
|
||||
case "/private":
|
||||
writeJSON(w, defaultKeySet)
|
||||
default:
|
||||
|
|
Loading…
Reference in a new issue