forked from TrueCloudLab/certificates
Add tests for keyStore.
This commit is contained in:
parent
2a5430fee1
commit
cf2dba3efb
4 changed files with 202 additions and 8 deletions
|
@ -120,7 +120,7 @@ func (d *Duration) UnmarshalJSON(data []byte) (err error) {
|
||||||
return errors.New("duration cannot be nil")
|
return errors.New("duration cannot be nil")
|
||||||
}
|
}
|
||||||
if err = json.Unmarshal(data, &s); err != nil {
|
if err = json.Unmarshal(data, &s); err != nil {
|
||||||
return errors.Wrapf(err, "error unmarshalling %s", data)
|
return errors.Wrapf(err, "error unmarshaling %s", data)
|
||||||
}
|
}
|
||||||
if _d, err = time.ParseDuration(s); err != nil {
|
if _d, err = time.ParseDuration(s); err != nil {
|
||||||
return errors.Wrapf(err, "error parsing %s as duration", s)
|
return errors.Wrapf(err, "error parsing %s as duration", s)
|
||||||
|
|
|
@ -26,6 +26,7 @@ type keyStore struct {
|
||||||
keySet jose.JSONWebKeySet
|
keySet jose.JSONWebKeySet
|
||||||
timer *time.Timer
|
timer *time.Timer
|
||||||
expiry time.Time
|
expiry time.Time
|
||||||
|
jitter time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newKeyStore(uri string) (*keyStore, error) {
|
func newKeyStore(uri string) (*keyStore, error) {
|
||||||
|
@ -37,8 +38,10 @@ func newKeyStore(uri string) (*keyStore, error) {
|
||||||
uri: uri,
|
uri: uri,
|
||||||
keySet: keys,
|
keySet: keys,
|
||||||
expiry: getExpirationTime(age),
|
expiry: getExpirationTime(age),
|
||||||
|
jitter: getCacheJitter(age),
|
||||||
}
|
}
|
||||||
ks.timer = time.AfterFunc(age, ks.reload)
|
next := ks.nextReloadDuration(age)
|
||||||
|
ks.timer = time.AfterFunc(next, ks.reload)
|
||||||
return ks, nil
|
return ks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,13 +66,14 @@ func (ks *keyStore) reload() {
|
||||||
var next time.Duration
|
var next time.Duration
|
||||||
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
keys, age, err := getKeysFromJWKsURI(ks.uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
next = ks.nextReloadDuration(defaultCacheJitter / 2)
|
next = ks.nextReloadDuration(ks.jitter / 2)
|
||||||
} else {
|
} else {
|
||||||
ks.Lock()
|
ks.Lock()
|
||||||
ks.keySet = keys
|
ks.keySet = keys
|
||||||
ks.expiry = time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
ks.expiry = getExpirationTime(age)
|
||||||
ks.Unlock()
|
ks.jitter = getCacheJitter(age)
|
||||||
next = ks.nextReloadDuration(age)
|
next = ks.nextReloadDuration(age)
|
||||||
|
ks.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
ks.Lock()
|
ks.Lock()
|
||||||
|
@ -78,7 +82,7 @@ func (ks *keyStore) reload() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
func (ks *keyStore) nextReloadDuration(age time.Duration) time.Duration {
|
||||||
n := rand.Int63n(int64(defaultCacheJitter))
|
n := rand.Int63n(int64(ks.jitter))
|
||||||
age -= time.Duration(n)
|
age -= time.Duration(n)
|
||||||
if age < 0 {
|
if age < 0 {
|
||||||
age = 0
|
age = 0
|
||||||
|
@ -117,6 +121,15 @@ func getCacheAge(cacheControl string) time.Duration {
|
||||||
return age
|
return age
|
||||||
}
|
}
|
||||||
|
|
||||||
func getExpirationTime(age time.Duration) time.Time {
|
func getCacheJitter(age time.Duration) time.Duration {
|
||||||
return time.Now().Round(time.Second).Add(age - 1*time.Minute).UTC()
|
switch {
|
||||||
|
case age > time.Hour:
|
||||||
|
return defaultCacheJitter
|
||||||
|
default:
|
||||||
|
return age / 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExpirationTime(age time.Duration) time.Time {
|
||||||
|
return time.Now().Round(time.Second).Add(age)
|
||||||
}
|
}
|
||||||
|
|
122
authority/provisioner/keystore_test.go
Normal file
122
authority/provisioner/keystore_test.go
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
package provisioner
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_newKeyStore(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
ks, err := newKeyStore(srv.URL)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
uri string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want jose.JSONWebKeySet
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{srv.URL}, ks.keySet, false},
|
||||||
|
{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := newKeyStore(tt.args.uri)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
if !reflect.DeepEqual(got.keySet, tt.want) {
|
||||||
|
t.Errorf("newKeyStore() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
got.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_keyStore(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
ks, err := newKeyStore(srv.URL + "/random")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
ks.RLock()
|
||||||
|
keySet1 := ks.keySet
|
||||||
|
ks.RUnlock()
|
||||||
|
// Check contents
|
||||||
|
assert.Len(t, 2, keySet1.Keys)
|
||||||
|
assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
|
||||||
|
assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get("foobar"))
|
||||||
|
|
||||||
|
// Wait for rotation
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
ks.RLock()
|
||||||
|
keySet2 := ks.keySet
|
||||||
|
ks.RUnlock()
|
||||||
|
if reflect.DeepEqual(keySet1, keySet2) {
|
||||||
|
t.Error("keyStore did not rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check contents
|
||||||
|
assert.Len(t, 2, keySet2.Keys)
|
||||||
|
assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
|
||||||
|
assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
|
||||||
|
assert.Len(t, 0, ks.Get("foobar"))
|
||||||
|
|
||||||
|
// Check hits
|
||||||
|
resp, err := srv.Client().Get(srv.URL + "/hits")
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
hits := struct {
|
||||||
|
Hits int `json:"hits"`
|
||||||
|
}{}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&hits)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_keyStore_Get(t *testing.T) {
|
||||||
|
srv := generateJWKServer(2)
|
||||||
|
defer srv.Close()
|
||||||
|
ks, err := newKeyStore(srv.URL)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
defer ks.Close()
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
kid string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ks *keyStore
|
||||||
|
args args
|
||||||
|
wantKeys []jose.JSONWebKey
|
||||||
|
}{
|
||||||
|
{"ok1", ks, args{ks.keySet.Keys[0].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[0]}},
|
||||||
|
{"ok2", ks, args{ks.keySet.Keys[1].KeyID}, []jose.JSONWebKey{ks.keySet.Keys[1]}},
|
||||||
|
{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
println(tt.name)
|
||||||
|
if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
|
||||||
|
t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,6 +4,9 @@ import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/cli/crypto/randutil"
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
@ -17,6 +20,15 @@ var testAudiences = []string{
|
||||||
"https://ca.smallsteomcom/1.0/sign",
|
"https://ca.smallsteomcom/1.0/sign",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func must(args ...interface{}) []interface{} {
|
||||||
|
if l := len(args); l > 0 && args[l-1] != nil {
|
||||||
|
if err, ok := args[l-1].(error); ok {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -30,6 +42,18 @@ func generateJSONWebKey() (*jose.JSONWebKey, error) {
|
||||||
return jwk, nil
|
return jwk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateJSONWebKeySet(n int) (jose.JSONWebKeySet, error) {
|
||||||
|
var keySet jose.JSONWebKeySet
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
key, err := generateJSONWebKey()
|
||||||
|
if err != nil {
|
||||||
|
return jose.JSONWebKeySet{}, err
|
||||||
|
}
|
||||||
|
keySet.Keys = append(keySet.Keys, key.Public())
|
||||||
|
}
|
||||||
|
return keySet, nil
|
||||||
|
}
|
||||||
|
|
||||||
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
func encryptJSONWebKey(jwk *jose.JSONWebKey) (*jose.JSONWebEncryption, error) {
|
||||||
b, err := json.Marshal(jwk)
|
b, err := json.Marshal(jwk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -206,3 +230,38 @@ func parseToken(token string) (*jose.JSONWebToken, *jose.Claims, error) {
|
||||||
}
|
}
|
||||||
return tok, claims, nil
|
return tok, claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateJWKServer(n int) *httptest.Server {
|
||||||
|
hits := struct {
|
||||||
|
Hits int `json:"hits"`
|
||||||
|
}{}
|
||||||
|
writeJSON := func(w http.ResponseWriter, v interface{}) {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Add("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(b)
|
||||||
|
}
|
||||||
|
// keySet, err := generateJSONWebKeySet(n)
|
||||||
|
defaultKeySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hits.Hits++
|
||||||
|
switch r.RequestURI {
|
||||||
|
case "/error":
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||||
|
case "/hits":
|
||||||
|
writeJSON(w, hits)
|
||||||
|
case "/random":
|
||||||
|
keySet := must(generateJSONWebKeySet(2))[0].(jose.JSONWebKeySet)
|
||||||
|
fmt.Println(keySet)
|
||||||
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
|
writeJSON(w, keySet)
|
||||||
|
default:
|
||||||
|
w.Header().Add("Cache-Control", "max-age=5")
|
||||||
|
writeJSON(w, defaultKeySet)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue