Add tests for keyStore.

This commit is contained in:
Mariano Cano 2019-03-08 15:08:18 -08:00
parent 2a5430fee1
commit cf2dba3efb
4 changed files with 202 additions and 8 deletions

View file

@ -120,7 +120,7 @@ func (d *Duration) UnmarshalJSON(data []byte) (err error) {
return errors.New("duration cannot be nil") return errors.New("duration cannot be nil")
} }
if err = json.Unmarshal(data, &s); err != nil { if err = json.Unmarshal(data, &s); err != nil {
return errors.Wrapf(err, "error unmarshalling %s", data) return errors.Wrapf(err, "error unmarshaling %s", data)
} }
if _d, err = time.ParseDuration(s); err != nil { if _d, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s) return errors.Wrapf(err, "error parsing %s as duration", s)

View file

@ -26,6 +26,7 @@ type keyStore struct {
keySet jose.JSONWebKeySet keySet jose.JSONWebKeySet
timer *time.Timer timer *time.Timer
expiry time.Time expiry time.Time
jitter time.Duration
} }
func newKeyStore(uri string) (*keyStore, error) { func newKeyStore(uri string) (*keyStore, error) {
@ -37,8 +38,10 @@ func newKeyStore(uri string) (*keyStore, error) {
uri: uri, uri: uri,
keySet: keys, keySet: keys,
expiry: getExpirationTime(age), expiry: getExpirationTime(age),
jitter: getCacheJitter(age),
} }
ks.timer = time.AfterFunc(age, ks.reload) next := ks.nextReloadDuration(age)
ks.timer = time.AfterFunc(next, ks.reload)
return ks, nil return ks, nil
} }
@ -63,13 +66,14 @@ func (ks *keyStore) reload() {
var next time.Duration var next time.Duration
keys, age, err := getKeysFromJWKsURI(ks.uri) keys, age, err := getKeysFromJWKsURI(ks.uri)
if err != nil { if err != nil {
next = ks.nextReloadDuration(defaultCacheJitter / 2) next = ks.nextReloadDuration(ks.jitter / 2)
} else { } else {
ks.Lock() ks.Lock()
ks.keySet = keys ks.keySet = keys
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() ks.expiry = getExpirationTime(age)
ks.Unlock() ks.jitter = getCacheJitter(age)
next = ks.nextReloadDuration(age) next = ks.nextReloadDuration(age)
ks.Unlock()
} }
ks.Lock() ks.Lock()
@ -78,7 +82,7 @@ func (ks *keyStore) reload() {
} }
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration { func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
n := rand.Int63n(int64(defaultCacheJitter)) n := rand.Int63n(int64(ks.jitter))
age -= time.Duration(n) age -= time.Duration(n)
if age < 0 { if age < 0 {
age = 0 age = 0
@ -117,6 +121,15 @@ func getCacheAge(cacheControl string) time.Duration {
return age return age
} }
func getExpirationTime(age time.Duration) time.Time { func getCacheJitter(age time.Duration) time.Duration {
return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC() switch {
case age > time.Hour:
return defaultCacheJitter
default:
return age / 3
}
}
func getExpirationTime(age time.Duration) time.Time {
return time.Now().Round(time.Second).Add(age)
} }

View file

@ -0,0 +1,122 @@
package provisioner
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
func Test_newKeyStore(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL)
assert.FatalError(t, err)
defer ks.Close()
type args struct {
uri string
}
tests := []struct {
name string
args args
want jose.JSONWebKeySet
wantErr bool
}{
{"ok", args{srv.URL}, ks.keySet, false},
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newKeyStore(tt.args.uri)
if (err != nil) != tt.wantErr {
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
if !reflect.DeepEqual(got.keySet, tt.want) {
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
}
got.Close()
}
})
}
}
func Test_keyStore(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL + "/random")
assert.FatalError(t, err)
defer ks.Close()
ks.RLock()
keySet1 := ks.keySet
ks.RUnlock()
// Check contents
assert.Len(t, 2, keySet1.Keys)
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
assert.Len(t, 0, ks.Get("foobar"))
// Wait for rotation
time.Sleep(5 * time.Second)
ks.RLock()
keySet2 := ks.keySet
ks.RUnlock()
if reflect.DeepEqual(keySet1, keySet2) {
t.Error("keyStore did not rotated")
}
// Check contents
assert.Len(t, 2, keySet2.Keys)
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
assert.Len(t, 0, ks.Get("foobar"))
// Check hits
resp, err := srv.Client().Get(srv.URL + "/hits")
assert.FatalError(t, err)
hits := struct {
Hits int `json:"hits"`
}{}
defer resp.Body.Close()
err = json.NewDecoder(resp.Body).Decode(&hits)
assert.FatalError(t, err)
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}
func Test_keyStore_Get(t *testing.T) {
srv := generateJWKServer(2)
defer srv.Close()
ks, err := newKeyStore(srv.URL)
assert.FatalError(t, err)
defer ks.Close()
type args struct {
kid string
}
tests := []struct {
name string
ks *keyStore
args args
wantKeys []jose.JSONWebKey
}{
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
println(tt.name)
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
}
})
}
}

View file

@ -4,6 +4,9 @@ import (
"crypto" "crypto"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"time" "time"
"github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/crypto/randutil"
@ -17,6 +20,15 @@ var testAudiences = []string{
"https://ca.smallsteomcom/1.0/sign", "https://ca.smallsteomcom/1.0/sign",
} }
func must(args ...interface{}) []interface{} {
if l := len(args); l > 0 && args[l-1] != nil {
if err, ok := args[l-1].(error); ok {
panic(err)
}
}
return args
}
func generateJSONWebKey() (*jose.JSONWebKey, error) { func generateJSONWebKey() (*jose.JSONWebKey, error) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
if err != nil { if err != nil {
@ -30,6 +42,18 @@ func generateJSONWebKey() (*jose.JSONWebKey, error) {
return jwk, nil return jwk, nil
} }
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
var keySet jose.JSONWebKeySet
for i := 0; i < n; i++ {
key, err := generateJSONWebKey()
if err != nil {
return jose.JSONWebKeySet{}, err
}
keySet.Keys = append(keySet.Keys, key.Public())
}
return keySet, nil
}
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) { func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
b, err := json.Marshal(jwk) b, err := json.Marshal(jwk)
if err != nil { if err != nil {
@ -206,3 +230,38 @@ func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
} }
return tok, claims, nil return tok, claims, nil
} }
func generateJWKServer(n int) *httptest.Server {
hits := struct {
Hits int `json:"hits"`
}{}
writeJSON := func(w http.ResponseWriter, v interface{}) {
b, err := json.Marshal(v)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(b)
}
// keySet, err := generateJSONWebKeySet(n)
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hits.Hits++
switch r.RequestURI {
case "/error":
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
case "/hits":
writeJSON(w, hits)
case "/random":
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
fmt.Println(keySet)
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, keySet)
default:
w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, defaultKeySet)
}
}))
}