Add backend support for provisioners with cursors.

Fixes #83
This commit is contained in:
Mariano Cano 2018-10-25 15:40:12 -07:00
parent d2872564b4
commit 1db177b80d
3 changed files with 216 additions and 16 deletions

View file

@ -28,6 +28,7 @@ type Authority struct {
provisionerIDIndex *sync.Map
encryptedKeyIndex *sync.Map
provisionerKeySetIndex *sync.Map
sortedProvisioners provisionerSlice
audiences []string
// Do not re-initialize
initOnce bool
@ -35,9 +36,31 @@ type Authority struct {
// New creates and initiates a new Authority type.
func New(config *Config) (*Authority, error) {
if err := config.Validate(); err != nil {
err := config.Validate()
if err != nil {
return nil, err
}
// Get sorted provisioners
var sorted provisionerSlice
if config.AuthorityConfig != nil {
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
}
// Define audiences: legacy + possible urls
_, port, err := net.SplitHostPort(config.Address)
if err != nil {
return nil, errors.Wrapf(err, "error parsing %s", config.Address)
}
audiences := []string{legacyAuthority}
for _, name := range config.DNSNames {
if port == "443" {
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
}
audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port))
}
var a = &Authority{
config: config,
certificates: new(sync.Map),
@ -45,6 +68,8 @@ func New(config *Config) (*Authority, error) {
provisionerIDIndex: new(sync.Map),
encryptedKeyIndex: new(sync.Map),
provisionerKeySetIndex: new(sync.Map),
sortedProvisioners: sorted,
audiences: audiences,
}
if err := a.init(); err != nil {
return nil, err
@ -70,21 +95,6 @@ func (a *Authority) init() error {
sum := sha256.Sum256(a.rootX509Crt.Raw)
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
// Define audiences: legacy + possible urls
_, port, err := net.SplitHostPort(a.config.Address)
if err != nil {
return errors.Wrapf(err, "error parsing %s", a.config.Address)
}
audiences := []string{legacyAuthority}
for _, name := range a.config.DNSNames {
if port == "443" {
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
}
audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port))
}
a.audiences = audiences
// Decrypt and load intermediate public / private key pair.
if len(a.config.Password) > 0 {
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(

View file

@ -1,11 +1,25 @@
package authority
import (
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math"
"net/http"
"sort"
"strings"
"github.com/pkg/errors"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
val, ok := a.encryptedKeyIndex.Load(kid)
@ -27,3 +41,74 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) {
func (a *Authority) GetProvisioners() ([]*Provisioner, error) {
return a.config.AuthorityConfig.Provisioners, nil
}
type uidProvisioner struct {
provisioner *provisioner.Provisioner
uid string
}
func newSortedProvisioners(provisioners []*provisioner.Provisioner) (provisionerSlice, error) {
if len(provisioners) > math.MaxUint32 {
return nil, errors.New("too many provisioners")
}
var slice provisionerSlice
bi := make([]byte, 4)
for i, p := range provisioners {
sum, err := provisionerSum(p)
if err != nil {
return nil, err
}
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
binary.BigEndian.PutUint32(bi, uint32(i))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
slice = append(slice, uidProvisioner{
provisioner: p,
uid: hex.EncodeToString(sum),
})
}
sort.Sort(slice)
return slice, nil
}
type provisionerSlice []uidProvisioner
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p provisionerSlice) Find(cursor string, limit int) ([]*provisioner.Provisioner, string) {
switch {
case limit <= 0:
limit = DefaultProvisionersLimit
case limit > DefaultProvisionersMax:
limit = DefaultProvisionersMax
}
n := len(p)
cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
var slice []*provisioner.Provisioner
for ; i < n && len(slice) < limit; i++ {
slice = append(slice, p[i].provisioner)
}
if i < n {
return slice, strings.TrimLeft(p[i].uid, "0")
}
return slice, ""
}
// provisionerSum returns the SHA1 of the json representation of the
// provisioner. From this we will create the unique and sorted id.
func provisionerSum(p *provisioner.Provisioner) ([]byte, error) {
b, err := json.Marshal(p.Key)
if err != nil {
return nil, errors.Wrap(err, "error marshalling provisioner")
}
sum := sha1.Sum(b)
return sum[:], nil
}

View file

@ -1,11 +1,17 @@
package authority
import (
"encoding/json"
"net/http"
"reflect"
"strings"
"testing"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/ca-component/provisioner"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
)
func TestGetEncryptedKey(t *testing.T) {
@ -120,3 +126,102 @@ func TestGetProvisioners(t *testing.T) {
})
}
}
func generateProvisioner(t *testing.T) *provisioner.Provisioner {
issuer, err := randutil.Alphanumeric(10)
assert.FatalError(t, err)
// Create a new JWK
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
// Encrypt JWK
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
assert.FatalError(t, err)
b, err := json.Marshal(jwk)
assert.FatalError(t, err)
recipient := jose.Recipient{
Algorithm: jose.PBES2_HS256_A128KW,
Key: []byte("password"),
PBES2Count: jose.PBKDF2Iterations,
PBES2Salt: salt,
}
opts := new(jose.EncrypterOptions)
opts.WithContentType(jose.ContentType("jwk+json"))
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
assert.FatalError(t, err)
jwe, err := encrypter.Encrypt(b)
assert.FatalError(t, err)
// get public and encrypted keys
public := jwk.Public()
encrypted, err := jwe.CompactSerialize()
assert.FatalError(t, err)
return &provisioner.Provisioner{
Issuer: issuer,
Type: "JWT",
Key: &public,
EncryptedKey: encrypted,
}
}
func Test_newSortedProvisioners(t *testing.T) {
provisioners := make([]*provisioner.Provisioner, 20)
for i := range provisioners {
provisioners[i] = generateProvisioner(t)
}
ps, err := newSortedProvisioners(provisioners)
assert.FatalError(t, err)
prev := ""
for i, p := range ps {
if p.uid < prev {
t.Errorf("%s should be less that %s", p.uid, prev)
}
if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
}
prev = p.uid
}
}
func Test_provisionerSlice_Find(t *testing.T) {
trim := func(s string) string {
return strings.TrimLeft(s, "0")
}
provisioners := make([]*provisioner.Provisioner, 20)
for i := range provisioners {
provisioners[i] = generateProvisioner(t)
}
ps, err := newSortedProvisioners(provisioners)
assert.FatalError(t, err)
type args struct {
cursor string
limit int
}
tests := []struct {
name string
p provisionerSlice
args args
want []*provisioner.Provisioner
want1 string
}{
{"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
{"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
{"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
{"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
{"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
{"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
{"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
{"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
}
})
}
}