forked from TrueCloudLab/certificates
parent
d2872564b4
commit
1db177b80d
3 changed files with 216 additions and 16 deletions
|
@ -28,6 +28,7 @@ type Authority struct {
|
||||||
provisionerIDIndex *sync.Map
|
provisionerIDIndex *sync.Map
|
||||||
encryptedKeyIndex *sync.Map
|
encryptedKeyIndex *sync.Map
|
||||||
provisionerKeySetIndex *sync.Map
|
provisionerKeySetIndex *sync.Map
|
||||||
|
sortedProvisioners provisionerSlice
|
||||||
audiences []string
|
audiences []string
|
||||||
// Do not re-initialize
|
// Do not re-initialize
|
||||||
initOnce bool
|
initOnce bool
|
||||||
|
@ -35,9 +36,31 @@ type Authority struct {
|
||||||
|
|
||||||
// New creates and initiates a new Authority type.
|
// New creates and initiates a new Authority type.
|
||||||
func New(config *Config) (*Authority, error) {
|
func New(config *Config) (*Authority, error) {
|
||||||
if err := config.Validate(); err != nil {
|
err := config.Validate()
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get sorted provisioners
|
||||||
|
var sorted provisionerSlice
|
||||||
|
if config.AuthorityConfig != nil {
|
||||||
|
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define audiences: legacy + possible urls
|
||||||
|
_, port, err := net.SplitHostPort(config.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "error parsing %s", config.Address)
|
||||||
|
}
|
||||||
|
audiences := []string{legacyAuthority}
|
||||||
|
for _, name := range config.DNSNames {
|
||||||
|
if port == "443" {
|
||||||
|
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
||||||
|
}
|
||||||
|
audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
var a = &Authority{
|
var a = &Authority{
|
||||||
config: config,
|
config: config,
|
||||||
certificates: new(sync.Map),
|
certificates: new(sync.Map),
|
||||||
|
@ -45,6 +68,8 @@ func New(config *Config) (*Authority, error) {
|
||||||
provisionerIDIndex: new(sync.Map),
|
provisionerIDIndex: new(sync.Map),
|
||||||
encryptedKeyIndex: new(sync.Map),
|
encryptedKeyIndex: new(sync.Map),
|
||||||
provisionerKeySetIndex: new(sync.Map),
|
provisionerKeySetIndex: new(sync.Map),
|
||||||
|
sortedProvisioners: sorted,
|
||||||
|
audiences: audiences,
|
||||||
}
|
}
|
||||||
if err := a.init(); err != nil {
|
if err := a.init(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -70,21 +95,6 @@ func (a *Authority) init() error {
|
||||||
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
sum := sha256.Sum256(a.rootX509Crt.Raw)
|
||||||
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt)
|
||||||
|
|
||||||
// Define audiences: legacy + possible urls
|
|
||||||
_, port, err := net.SplitHostPort(a.config.Address)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrapf(err, "error parsing %s", a.config.Address)
|
|
||||||
}
|
|
||||||
audiences := []string{legacyAuthority}
|
|
||||||
for _, name := range a.config.DNSNames {
|
|
||||||
if port == "443" {
|
|
||||||
audiences = append(audiences, fmt.Sprintf("https://%s/sign", name), fmt.Sprintf("https://%s/1.0/sign", name))
|
|
||||||
}
|
|
||||||
audiences = append(audiences, fmt.Sprintf("https://%s:%s/sign", name, port), fmt.Sprintf("https://%s:%s/1.0/sign", name, port))
|
|
||||||
|
|
||||||
}
|
|
||||||
a.audiences = audiences
|
|
||||||
|
|
||||||
// Decrypt and load intermediate public / private key pair.
|
// Decrypt and load intermediate public / private key pair.
|
||||||
if len(a.config.Password) > 0 {
|
if len(a.config.Password) > 0 {
|
||||||
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
|
a.intermediateIdentity, err = x509util.LoadIdentityFromDisk(
|
||||||
|
|
|
@ -1,11 +1,25 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||||
|
const DefaultProvisionersLimit = 20
|
||||||
|
|
||||||
|
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||||
|
const DefaultProvisionersMax = 100
|
||||||
|
|
||||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||||
val, ok := a.encryptedKeyIndex.Load(kid)
|
val, ok := a.encryptedKeyIndex.Load(kid)
|
||||||
|
@ -27,3 +41,74 @@ func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||||
func (a *Authority) GetProvisioners() ([]*Provisioner, error) {
|
func (a *Authority) GetProvisioners() ([]*Provisioner, error) {
|
||||||
return a.config.AuthorityConfig.Provisioners, nil
|
return a.config.AuthorityConfig.Provisioners, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type uidProvisioner struct {
|
||||||
|
provisioner *provisioner.Provisioner
|
||||||
|
uid string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSortedProvisioners(provisioners []*provisioner.Provisioner) (provisionerSlice, error) {
|
||||||
|
if len(provisioners) > math.MaxUint32 {
|
||||||
|
return nil, errors.New("too many provisioners")
|
||||||
|
}
|
||||||
|
|
||||||
|
var slice provisionerSlice
|
||||||
|
bi := make([]byte, 4)
|
||||||
|
for i, p := range provisioners {
|
||||||
|
sum, err := provisionerSum(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||||
|
// Using big endian format to get the strings sorted:
|
||||||
|
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||||
|
binary.BigEndian.PutUint32(bi, uint32(i))
|
||||||
|
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||||
|
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
||||||
|
slice = append(slice, uidProvisioner{
|
||||||
|
provisioner: p,
|
||||||
|
uid: hex.EncodeToString(sum),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
sort.Sort(slice)
|
||||||
|
return slice, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type provisionerSlice []uidProvisioner
|
||||||
|
|
||||||
|
func (p provisionerSlice) Len() int { return len(p) }
|
||||||
|
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||||
|
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
|
|
||||||
|
func (p provisionerSlice) Find(cursor string, limit int) ([]*provisioner.Provisioner, string) {
|
||||||
|
switch {
|
||||||
|
case limit <= 0:
|
||||||
|
limit = DefaultProvisionersLimit
|
||||||
|
case limit > DefaultProvisionersMax:
|
||||||
|
limit = DefaultProvisionersMax
|
||||||
|
}
|
||||||
|
|
||||||
|
n := len(p)
|
||||||
|
cursor = fmt.Sprintf("%040s", cursor)
|
||||||
|
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
|
||||||
|
|
||||||
|
var slice []*provisioner.Provisioner
|
||||||
|
for ; i < n && len(slice) < limit; i++ {
|
||||||
|
slice = append(slice, p[i].provisioner)
|
||||||
|
}
|
||||||
|
if i < n {
|
||||||
|
return slice, strings.TrimLeft(p[i].uid, "0")
|
||||||
|
}
|
||||||
|
return slice, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// provisionerSum returns the SHA1 of the json representation of the
|
||||||
|
// provisioner. From this we will create the unique and sorted id.
|
||||||
|
func provisionerSum(p *provisioner.Provisioner) ([]byte, error) {
|
||||||
|
b, err := json.Marshal(p.Key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error marshalling provisioner")
|
||||||
|
}
|
||||||
|
sum := sha1.Sum(b)
|
||||||
|
return sum[:], nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,11 +1,17 @@
|
||||||
package authority
|
package authority
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/ca-component/provisioner"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetEncryptedKey(t *testing.T) {
|
func TestGetEncryptedKey(t *testing.T) {
|
||||||
|
@ -120,3 +126,102 @@ func TestGetProvisioners(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateProvisioner(t *testing.T) *provisioner.Provisioner {
|
||||||
|
issuer, err := randutil.Alphanumeric(10)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Create a new JWK
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// Encrypt JWK
|
||||||
|
salt, err := randutil.Salt(jose.PBKDF2SaltSize)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
b, err := json.Marshal(jwk)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
recipient := jose.Recipient{
|
||||||
|
Algorithm: jose.PBES2_HS256_A128KW,
|
||||||
|
Key: []byte("password"),
|
||||||
|
PBES2Count: jose.PBKDF2Iterations,
|
||||||
|
PBES2Salt: salt,
|
||||||
|
}
|
||||||
|
opts := new(jose.EncrypterOptions)
|
||||||
|
opts.WithContentType(jose.ContentType("jwk+json"))
|
||||||
|
encrypter, err := jose.NewEncrypter(jose.DefaultEncAlgorithm, recipient, opts)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jwe, err := encrypter.Encrypt(b)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
// get public and encrypted keys
|
||||||
|
public := jwk.Public()
|
||||||
|
encrypted, err := jwe.CompactSerialize()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
return &provisioner.Provisioner{
|
||||||
|
Issuer: issuer,
|
||||||
|
Type: "JWT",
|
||||||
|
Key: &public,
|
||||||
|
EncryptedKey: encrypted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_newSortedProvisioners(t *testing.T) {
|
||||||
|
provisioners := make([]*provisioner.Provisioner, 20)
|
||||||
|
for i := range provisioners {
|
||||||
|
provisioners[i] = generateProvisioner(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
ps, err := newSortedProvisioners(provisioners)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
prev := ""
|
||||||
|
for i, p := range ps {
|
||||||
|
if p.uid < prev {
|
||||||
|
t.Errorf("%s should be less that %s", p.uid, prev)
|
||||||
|
}
|
||||||
|
if p.provisioner.Key.KeyID != provisioners[i].Key.KeyID {
|
||||||
|
t.Errorf("provisioner order is not the same: %s != %s", p.provisioner.Key.KeyID, provisioners[i].Key.KeyID)
|
||||||
|
}
|
||||||
|
prev = p.uid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_provisionerSlice_Find(t *testing.T) {
|
||||||
|
trim := func(s string) string {
|
||||||
|
return strings.TrimLeft(s, "0")
|
||||||
|
}
|
||||||
|
provisioners := make([]*provisioner.Provisioner, 20)
|
||||||
|
for i := range provisioners {
|
||||||
|
provisioners[i] = generateProvisioner(t)
|
||||||
|
}
|
||||||
|
ps, err := newSortedProvisioners(provisioners)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
cursor string
|
||||||
|
limit int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
p provisionerSlice
|
||||||
|
args args
|
||||||
|
want []*provisioner.Provisioner
|
||||||
|
want1 string
|
||||||
|
}{
|
||||||
|
{"all", ps, args{"", DefaultProvisionersMax}, provisioners[0:20], ""},
|
||||||
|
{"0 to 19", ps, args{"", 20}, provisioners[0:20], ""},
|
||||||
|
{"0 to 9", ps, args{"", 10}, provisioners[0:10], trim(ps[10].uid)},
|
||||||
|
{"9 to 19", ps, args{trim(ps[10].uid), 10}, provisioners[10:20], ""},
|
||||||
|
{"1", ps, args{trim(ps[1].uid), 1}, provisioners[1:2], trim(ps[2].uid)},
|
||||||
|
{"1 to 5", ps, args{trim(ps[1].uid), 4}, provisioners[1:5], trim(ps[5].uid)},
|
||||||
|
{"defaultLimit", ps, args{"", 0}, provisioners[0:20], ""},
|
||||||
|
{"overTheLimit", ps, args{"", DefaultProvisionersMax + 1}, provisioners[0:20], ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, got1 := tt.p.Find(tt.args.cursor, tt.args.limit)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("provisionerSlice.Find() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
if got1 != tt.want1 {
|
||||||
|
t.Errorf("provisionerSlice.Find() got1 = %v, want %v", got1, tt.want1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue