// certificates/authority/provisioner/keystore_test.go
// 2019-03-08 23:08:18 +00:00

package provisioner
import (
"encoding/json"
"fmt"
"reflect"
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
// Test_newKeyStore checks that newKeyStore loads the key set served at the
// given URI and that it returns an error when the URI cannot be loaded.
func Test_newKeyStore(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()
	ks, err := newKeyStore(srv.URL)
	assert.FatalError(t, err)
	defer ks.Close()

	// Snapshot the reference key set under the read lock, as the other
	// keyStore tests do, instead of reading the field unguarded.
	ks.RLock()
	wantKeySet := ks.keySet
	ks.RUnlock()

	type args struct {
		uri string
	}
	tests := []struct {
		name    string
		args    args
		want    jose.JSONWebKeySet
		wantErr bool
	}{
		{"ok", args{srv.URL}, wantKeySet, false},
		{"fail", args{srv.URL + "/error"}, jose.JSONWebKeySet{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := newKeyStore(tt.args.uri)
			if err == nil {
				// Close the store even if the call unexpectedly succeeded, so
				// whatever it holds does not outlive this subtest.
				defer got.Close()
			}
			if (err != nil) != tt.wantErr {
				t.Errorf("newKeyStore() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				return
			}
			got.RLock()
			gotKeySet := got.keySet
			got.RUnlock()
			if !reflect.DeepEqual(gotKeySet, tt.want) {
				t.Errorf("newKeyStore() = %v, want %v", gotKeySet, tt.want)
			}
		})
	}
}
// Test_keyStore checks a keyStore backed by the test server's /random
// endpoint. It verifies three things:
//   - lookups succeed for the currently cached keys;
//   - the cached key set is eventually replaced by a different one;
//   - the server was hit more than once.
func Test_keyStore(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()
	ks, err := newKeyStore(srv.URL + "/random")
	assert.FatalError(t, err)
	defer ks.Close()

	// snapshot returns the store's current key set, read under its lock.
	snapshot := func() jose.JSONWebKeySet {
		ks.RLock()
		defer ks.RUnlock()
		return ks.keySet
	}

	keySet1 := snapshot()
	// Check contents. Stop on a wrong key count: the lookups below index
	// Keys[0] and Keys[1], and assert.Len alone does not halt the test.
	if len(keySet1.Keys) != 2 {
		t.Fatalf("unexpected number of keys: got %d, want 2", len(keySet1.Keys))
	}
	assert.Len(t, 1, ks.Get(keySet1.Keys[0].KeyID))
	assert.Len(t, 1, ks.Get(keySet1.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	// Wait for rotation. Poll against a deadline instead of sleeping a fixed
	// amount, so the test finishes as soon as the key set changes and does
	// not fail just because the reload landed slightly late.
	deadline := time.Now().Add(10 * time.Second)
	keySet2 := snapshot()
	for reflect.DeepEqual(keySet1, keySet2) {
		if time.Now().After(deadline) {
			t.Fatal("keyStore did not rotate")
		}
		time.Sleep(100 * time.Millisecond)
		keySet2 = snapshot()
	}

	// Check contents
	if len(keySet2.Keys) != 2 {
		t.Fatalf("unexpected number of keys: got %d, want 2", len(keySet2.Keys))
	}
	assert.Len(t, 1, ks.Get(keySet2.Keys[0].KeyID))
	assert.Len(t, 1, ks.Get(keySet2.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	// Check hits: the initial load plus at least one reload.
	resp, err := srv.Client().Get(srv.URL + "/hits")
	assert.FatalError(t, err)
	defer resp.Body.Close()
	hits := struct {
		Hits int `json:"hits"`
	}{}
	err = json.NewDecoder(resp.Body).Decode(&hits)
	assert.FatalError(t, err)
	assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}
// Test_keyStore_noCache checks a keyStore backed by the test server's
// /no-cache endpoint. Because each Get reloads the keys, a lookup for a
// previously snapshotted kid finds nothing, and the cached set keeps
// changing.
func Test_keyStore_noCache(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()
	ks, err := newKeyStore(srv.URL + "/no-cache")
	assert.FatalError(t, err)
	defer ks.Close()

	// snapshot returns the store's current key set, read under its lock.
	snapshot := func() jose.JSONWebKeySet {
		ks.RLock()
		defer ks.RUnlock()
		return ks.keySet
	}

	keySet1 := snapshot()
	// The keys will rotate on Get, so we won't be able to find the cached
	// ones. Stop on a wrong key count: Keys[0] and Keys[1] are indexed below
	// and assert.Len alone does not halt the test.
	if len(keySet1.Keys) != 2 {
		t.Fatalf("unexpected number of keys: got %d, want 2", len(keySet1.Keys))
	}
	assert.Len(t, 0, ks.Get(keySet1.Keys[0].KeyID))
	assert.Len(t, 0, ks.Get(keySet1.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	keySet2 := snapshot()
	if reflect.DeepEqual(keySet1, keySet2) {
		t.Error("keyStore did not rotate")
	}

	// Same again for the newer snapshot: it is already stale on the next Get.
	if len(keySet2.Keys) != 2 {
		t.Fatalf("unexpected number of keys: got %d, want 2", len(keySet2.Keys))
	}
	assert.Len(t, 0, ks.Get(keySet2.Keys[0].KeyID))
	assert.Len(t, 0, ks.Get(keySet2.Keys[1].KeyID))
	assert.Len(t, 0, ks.Get("foobar"))

	// Check hits
	resp, err := srv.Client().Get(srv.URL + "/hits")
	assert.FatalError(t, err)
	defer resp.Body.Close()
	hits := struct {
		Hits int `json:"hits"`
	}{}
	err = json.NewDecoder(resp.Body).Decode(&hits)
	assert.FatalError(t, err)
	assert.True(t, hits.Hits > 1, fmt.Sprintf("invalid number of hits: %d is not greater than 1", hits.Hits))
}
// 2019-03-08 23:08:18 +00:00

// Test_keyStore_Get checks that Get returns exactly the matching key for
// each loaded kid, and nil for an unknown kid.
func Test_keyStore_Get(t *testing.T) {
	srv := generateJWKServer(2)
	defer srv.Close()
	ks, err := newKeyStore(srv.URL)
	assert.FatalError(t, err)
	defer ks.Close()

	// Read the loaded keys under the lock, consistent with the other keyStore
	// tests, and make sure there are enough of them to index below.
	ks.RLock()
	keys := ks.keySet.Keys
	ks.RUnlock()
	if len(keys) < 2 {
		t.Fatalf("unexpected number of keys: got %d, want at least 2", len(keys))
	}

	type args struct {
		kid string
	}
	tests := []struct {
		name     string
		ks       *keyStore
		args     args
		wantKeys []jose.JSONWebKey
	}{
		{"ok1", ks, args{keys[0].KeyID}, []jose.JSONWebKey{keys[0]}},
		{"ok2", ks, args{keys[1].KeyID}, []jose.JSONWebKey{keys[1]}},
		{"fail", ks, args{"fail"}, []jose.JSONWebKey(nil)},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if gotKeys := tt.ks.Get(tt.args.kid); !reflect.DeepEqual(gotKeys, tt.wantKeys) {
				t.Errorf("keyStore.Get() = %v, want %v", gotKeys, tt.wantKeys)
			}
		})
	}
}
// Test_abs checks abs on zero, positive and negative durations and on the
// int64 extremes. The absolute value of minInt64 is not representable, so
// the expected result for it is minInt64 itself.
func Test_abs(t *testing.T) {
	maxInt64 := time.Duration(1<<63 - 1)
	minInt64 := time.Duration(-1 << 63)
	type args struct {
		n time.Duration
	}
	tests := []struct {
		name string
		args args
		want time.Duration
	}{
		// Every case has a unique name so it can be selected with -run and is
		// reported under its own name rather than an auto-suffixed duplicate.
		{"ok zero", args{0}, 0},
		{"ok negative", args{-time.Hour}, time.Hour},
		{"ok positive", args{time.Hour}, time.Hour},
		{"ok maxInt64", args{maxInt64}, maxInt64},
		{"ok minInt64 + 1", args{minInt64 + 1}, maxInt64},
		{"overflow on minInt64", args{minInt64}, minInt64},
		// minInt64 is a variable, so -minInt64 is evaluated at run time and
		// wraps back to minInt64; this case pins that equivalence.
		{"overflow on minInt64 equals -minInt64", args{minInt64}, -minInt64},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := abs(tt.args.n); got != tt.want {
				t.Errorf("abs() = %v, want %v", got, tt.want)
			}
		})
	}
}