parent
d2872564b4
commit
1db177b80d
3 changed files with 216 additions and 16 deletions
|
@ -28,6 +28,7 @@ type Authority struct {
|
|||
provisionerIDIndex *sync.Map
|
||||
encryptedKeyIndex *sync.Map
|
||||
provisionerKeySetIndex *sync.Map
|
||||
sortedProvisioners provisionerSlice
|
||||
audiences []string
|
||||
// Do not re-initialize
|
||||
initOnce bool
|
||||
|
@ -35,9 +36,31 @@ type Authority struct {
|
|||
|
||||
// New creates and initiates a new Authority type.
|
||||
func New(config *Config) (*Authority, error) {
|
||||
if err := config.Validate(); err != nil {
|
||||
err := config.Validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get sorted provisioners
|
||||
var sorted provisionerSlice
|
||||
if config.AuthorityConfig != nil {
|
||||
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
|
||||
}
|
||||
|
||||
// Define audiences: legacy + possible urls
|
||||
_, port, err := net.SplitHostPort(config.Address)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error parsing %s", config.Address)
|
||||
}
|
||||
audiences := []string{legacyAuthority}
|
||||
for _, name := range config.DNSNames {
|
||||
if port == "443" {
|
||||
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
||||
}
|
||||
audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port))
|
||||
|
||||
}
|
||||
|
||||
var a = &Authority{
|
||||
config: config,
|
||||
certificates: new(sync.Map),
|
||||
|
@ -45,6 +68,8 @@ func New(config *Config) (*Authority, error) {
|
|||
provisionerIDIndex: new(sync.Map),
|
||||
encryptedKeyIndex: new(sync.Map),
|
||||
provisionerKeySetIndex: new(sync.Map),
|
||||
sortedProvisioners: sorted,
|
||||
audiences: audiences,
|
||||
}
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
|
@ -70,21 +95,6 @@ func (a *Authority) init() error {
|
|||
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
||||
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
||||
|
||||
// Define audiences: legacy + possible urls
|
||||
_, port, err := net.SplitHostPort(a.config.Address)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s", a.config.Address)
|
||||
}
|
||||
audiences := []string{legacyAuthority}
|
||||
for _, name := range a.config.DNSNames {
|
||||
if port == "443" {
|
||||
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
||||
}
|
||||
audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port))
|
||||
|
||||
}
|
||||
a.audiences = audiences
|
||||
|
||||
// Decrypt and load intermediate public / private key pair.
|
||||
if len(a.config.Password) > 0 {
|
||||
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
|
||||
|
|
|
@ -1,11 +1,25 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
const DefaultProvisionersLimit = 20
|
||||
|
||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||
const DefaultProvisionersMax = 100
|
||||
|
||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||
val, ok := a.encryptedKeyIndex.Load(kid)
|
||||
|
@ -27,3 +41,74 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
|||
func (a *Authority) GetProvisioners() ([]*Provisioner, error) {
|
||||
return a.config.AuthorityConfig.Provisioners, nil
|
||||
}
|
||||
|
||||
type uidProvisioner struct {
|
||||
provisioner *provisioner.Provisioner
|
||||
uid string
|
||||
}
|
||||
|
||||
func newSortedProvisioners(provisioners []*provisioner.Provisioner) (provisionerSlice, error) {
|
||||
if len(provisioners) > math.MaxUint32 {
|
||||
return nil, errors.New("too many provisioners")
|
||||
}
|
||||
|
||||
var slice provisionerSlice
|
||||
bi := make([]byte, 4)
|
||||
for i, p := range provisioners {
|
||||
sum, err := provisionerSum(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
binary.BigEndian.PutUint32(bi, uint32(i))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
||||
slice = append(slice, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
}
|
||||
sort.Sort(slice)
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
type provisionerSlice []uidProvisioner
|
||||
|
||||
func (p provisionerSlice) Len() int { return len(p) }
|
||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
func (p provisionerSlice) Find(cursor string, limit int) ([]*provisioner.Provisioner, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultProvisionersLimit
|
||||
case limit > DefaultProvisionersMax:
|
||||
limit = DefaultProvisionersMax
|
||||
}
|
||||
|
||||
n := len(p)
|
||||
cursor = fmt.Sprintf("%040s", cursor)
|
||||
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
|
||||
|
||||
var slice []*provisioner.Provisioner
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, p[i].provisioner)
|
||||
}
|
||||
if i < n {
|
||||
return slice, strings.TrimLeft(p[i].uid, "0")
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
// provisionerSum returns the SHA1 of the json representation of the
|
||||
// provisioner. From this we will create the unique and sorted id.
|
||||
func provisionerSum(p *provisioner.Provisioner) ([]byte, error) {
|
||||
b, err := json.Marshal(p.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshalling provisioner")
|
||||
}
|
||||
sum := sha1.Sum(b)
|
||||
return sum[:], nil
|
||||
}
|
||||
|
|
|
@ -1,11 +1,17 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/ca-component/provisioner"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
func TestGetEncryptedKey(t *testing.T) {
|
||||
|
@ -120,3 +126,102 @@ func TestGetProvisioners(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateProvisioner(t *testing.T) *provisioner.Provisioner {
|
||||
issuer, err := randutil.Alphanumeric(10)
|
||||
assert.FatalError(t, err)
|
||||
// Create a new JWK
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
// Encrypt JWK
|
||||
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||
assert.FatalError(t, err)
|
||||
b, err := json.Marshal(jwk)
|
||||
assert.FatalError(t, err)
|
||||
recipient := jose.Recipient{
|
||||
Algorithm: jose.PBES2_HS256_A128KW,
|
||||
Key: []byte("password"),
|
||||
PBES2Count: jose.PBKDF2Iterations,
|
||||
PBES2Salt: salt,
|
||||
}
|
||||
opts := new(jose.EncrypterOptions)
|
||||
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||
assert.FatalError(t, err)
|
||||
jwe, err := encrypter.Encrypt(b)
|
||||
assert.FatalError(t, err)
|
||||
// get public and encrypted keys
|
||||
public := jwk.Public()
|
||||
encrypted, err := jwe.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
return &provisioner.Provisioner{
|
||||
Issuer: issuer,
|
||||
Type: "JWT",
|
||||
Key: &public,
|
||||
EncryptedKey: encrypted,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_newSortedProvisioners(t *testing.T) {
|
||||
provisioners := make([]*provisioner.Provisioner, 20)
|
||||
for i := range provisioners {
|
||||
provisioners[i] = generateProvisioner(t)
|
||||
}
|
||||
|
||||
ps, err := newSortedProvisioners(provisioners)
|
||||
assert.FatalError(t, err)
|
||||
prev := ""
|
||||
for i, p := range ps {
|
||||
if p.uid < prev {
|
||||
t.Errorf("%s should be less that %s", p.uid, prev)
|
||||
}
|
||||
if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
|
||||
t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
|
||||
}
|
||||
prev = p.uid
|
||||
}
|
||||
}
|
||||
|
||||
func Test_provisionerSlice_Find(t *testing.T) {
|
||||
trim := func(s string) string {
|
||||
return strings.TrimLeft(s, "0")
|
||||
}
|
||||
provisioners := make([]*provisioner.Provisioner, 20)
|
||||
for i := range provisioners {
|
||||
provisioners[i] = generateProvisioner(t)
|
||||
}
|
||||
ps, err := newSortedProvisioners(provisioners)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
cursor string
|
||||
limit int
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
p provisionerSlice
|
||||
args args
|
||||
want []*provisioner.Provisioner
|
||||
want1 string
|
||||
}{
|
||||
{"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
|
||||
{"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
|
||||
{"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
|
||||
{"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
|
||||
{"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
|
||||
{"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
|
||||
{"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
|
||||
{"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue