Merge pull request #82 from smallstep/fix-max-age-0

Fix panic when max-age is set to zero.
This commit is contained in:
Mariano Cano 2019-06-25 11:14:07 -07:00 committed by GitHub
commit 0c3e0088cf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 7 deletions

View file

@ -18,7 +18,7 @@ const (
defaultCacheJitter = 1 * time.Hour defaultCacheJitter = 1 * time.Hour
) )
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]*)") var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)")
type keyStore struct { type keyStore struct {
sync.RWMutex sync.RWMutex
@ -81,13 +81,13 @@ func (ks *keyStore) reload() {
ks.Unlock() ks.Unlock()
} }
// nextReloadDuration would return the duration for the next rotation. If age is
// 0 it will randomly rotate between 0-12 hours, but every time we call to Get
// it will automatically rotate.
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
n := rand.Int63n(int64(ks.jitter)) n := rand.Int63n(int64(ks.jitter))
age -= time.Duration(n) age -= time.Duration(n)
if age < 0 { return abs(age)
age = 0
}
return age
} }
func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) { func getKeysFromJWKsURI(uri string) (jose.JSONWebKeySet, time.Duration, error) {
@ -125,6 +125,10 @@ func getCacheJitter(age time.Duration) time.Duration {
switch { switch {
case age > time.Hour: case age > time.Hour:
return defaultCacheJitter return defaultCacheJitter
case age == 0:
// Avoids a 0 jitter. The duration is not important as it will rotate
// automatically on each Get request.
return defaultCacheJitter
default: default:
return age / 3 return age / 3
} }
@ -133,3 +137,11 @@ func getCacheJitter(age time.Duration) time.Duration {
func getExpirationTime(age time.Duration) time.Time { func getExpirationTime(age time.Duration) time.Time {
return time.Now().Truncate(time.Second).Add(age) return time.Now().Truncate(time.Second).Add(age)
} }
// abs returns the absolute value of n.
func abs(n time.Duration) time.Duration {
if n < 0 {
return -n
}
return n
}

View file

@ -91,6 +91,49 @@ func Test_keyStore(t *testing.T) {
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits)) assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
} }
func Test_keyStore_noCache(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL + "/no-cache")
assert.FatalError(t, err)
defer ks.Close()
ks.RLock()
keySet1 := ks.keySet
ks.RUnlock()
// The keys will rotate on Get.
// So we won't be able to find the cached ones
assert.Len(t, 2, keySet1.Keys)
assert.Len(t, 0, ks.Get(keySet1.Keys[0].KeyID))
assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID))
assert.Len(t, 0, ks.Get("foobar"))
ks.RLock()
keySet2 := ks.keySet
ks.RUnlock()
if reflect.DeepEqual(keySet1, keySet2) {
t.Error("keyStore did not rotated")
}
// The keys will rotate on Get.
// So we won't be able to find the cached ones
assert.Len(t, 2, keySet2.Keys)
assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID))
assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID))
assert.Len(t, 0, ks.Get("foobar"))
// Check hits
resp, err := srv.Client().Get(srv.URL + "/hits")
assert.FatalError(t, err)
hits := struct {
Hits int `json:"hits"`
}{}
defer resp.Body.Close()
err = json.NewDecoder(resp.Body).Decode(&hits)
assert.FatalError(t, err)
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}
func Test_keyStore_Get(t *testing.T) { func Test_keyStore_Get(t *testing.T) {
srv := generateJWKServer(2) srv := generateJWKServer(2)
defer srv.Close() defer srv.Close()
@ -119,3 +162,31 @@ func Test_keyStore_Get(t *testing.T) {
}) })
} }
} }
func Test_abs(t *testing.T) {
maxInt64 := time.Duration(1<<63 - 1)
minInt64 := time.Duration(-1 << 63)
type args struct {
n time.Duration
}
tests := []struct {
name string
args args
want time.Duration
}{
{"ok", args{0}, 0},
{"ok", args{-time.Hour}, time.Hour},
{"ok", args{time.Hour}, time.Hour},
{"ok maxInt64", args{maxInt64}, maxInt64},
{"ok minInt64 + 1", args{minInt64 + 1}, maxInt64},
{"overflow on minInt64", args{minInt64}, minInt64},
{"overflow on minInt64", args{minInt64}, -minInt64},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := abs(tt.args.n); got != tt.want {
t.Errorf("abs() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -658,7 +658,7 @@ func generateJWKServer(n int) *httptest.Server {
return ret return ret
} }
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) defaultKeySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
srv := httptest.NewUnstartedServer(nil) srv := httptest.NewUnstartedServer(nil)
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits.Hits++ hits.Hits++
@ -670,9 +670,13 @@ func generateJWKServer(n int) *httptest.Server {
case "/openid-configuration", "/.well-known/openid-configuration": case "/openid-configuration", "/.well-known/openid-configuration":
writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"}) writeJSON(w, openIDConfiguration{Issuer: "the-issuer", JWKSetURI: srv.URL + "/jwks_uri"})
case "/random": case "/random":
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet) keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(keySet)) writeJSON(w, getPublic(keySet))
case "/no-cache":
keySet := must(generateJSONWebKeySet(n))[0].(jose.JSONWebKeySet)
w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate")
writeJSON(w, getPublic(keySet))
case "/private": case "/private":
writeJSON(w, defaultKeySet) writeJSON(w, defaultKeySet)
default: default: