Use provisioner.Collection to store and request the provisioners.

This commit is contained in:
Mariano Cano 2019-03-06 15:00:23 -08:00
parent 34833d4fd5
commit c776ca3bd6
2 changed files with 22 additions and 126 deletions

View file

@ -8,6 +8,7 @@ import (
"sync"
"time"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util"
)
@ -16,18 +17,15 @@ const legacyAuthority = "step-certificate-authority"
// Authority implements the Certificate Authority internal interface.
type Authority struct {
config *Config
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
ottMap *sync.Map
startTime time.Time
provisionerIDIndex *sync.Map
encryptedKeyIndex *sync.Map
provisionerKeySetIndex *sync.Map
sortedProvisioners provisionerSlice
audiences []string
config *Config
rootX509Certs []*x509.Certificate
intermediateIdentity *x509util.Identity
validateOnce bool
certificates *sync.Map
ottMap *sync.Map
startTime time.Time
provisioners *provisioner.Collection
audiences []string
// Do not re-initialize
initOnce bool
}
@ -39,15 +37,6 @@ func New(config *Config) (*Authority, error) {
return nil, err
}
// Get sorted provisioners
var sorted provisionerSlice
if config.AuthorityConfig != nil {
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
if err != nil {
return nil, err
}
}
// Define audiences: legacy + possible urls without the ports.
// The CA might have proxies in front so we cannot rely on the port.
audiences := []string{legacyAuthority}
@ -56,14 +45,11 @@ func New(config *Config) (*Authority, error) {
}
var a = &Authority{
config: config,
certificates: new(sync.Map),
ottMap: new(sync.Map),
provisionerIDIndex: new(sync.Map),
encryptedKeyIndex: new(sync.Map),
provisionerKeySetIndex: new(sync.Map),
sortedProvisioners: sorted,
audiences: audiences,
config: config,
certificates: new(sync.Map),
ottMap: new(sync.Map),
provisioners: provisioner.NewCollection(audiences),
audiences: audiences,
}
if err := a.init(); err != nil {
return nil, err
@ -120,10 +106,10 @@ func (a *Authority) init() error {
}
}
// Store all the provisioners
for _, p := range a.config.AuthorityConfig.Provisioners {
a.provisionerIDIndex.Store(p.ID(), p)
if len(p.EncryptedKey) != 0 {
a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey)
if err := a.provisioners.Store(p); err != nil {
return err
}
}

View file

@ -1,115 +1,25 @@
package authority
import (
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"math"
"net/http"
"sort"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
val, ok := a.encryptedKeyIndex.Load(kid)
key, ok := a.provisioners.LoadEncryptedKey(kid)
if !ok {
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
http.StatusNotFound, context{}}
}
key, ok := val.(string)
if !ok {
return "", &apiError{errors.Errorf("stored value is not a string"),
http.StatusInternalServerError, context{}}
}
return key, nil
}
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
// with their public keys.
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) {
provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit)
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*provisioner.Provisioner, string, error) {
provisioners, nextCursor := a.provisioners.Find(cursor, limit)
return provisioners, nextCursor, nil
}
type uidProvisioner struct {
provisioner *Provisioner
uid string
}
func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) {
if len(provisioners) > math.MaxInt32 {
return nil, errors.New("too many provisioners")
}
var slice provisionerSlice
bi := make([]byte, 4)
for i, p := range provisioners {
sum, err := provisionerSum(p)
if err != nil {
return nil, err
}
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
binary.BigEndian.PutUint32(bi, uint32(i))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
slice = append(slice, uidProvisioner{
provisioner: p,
uid: hex.EncodeToString(sum),
})
}
sort.Sort(slice)
return slice, nil
}
type provisionerSlice []uidProvisioner
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) {
switch {
case limit <= 0:
limit = DefaultProvisionersLimit
case limit > DefaultProvisionersMax:
limit = DefaultProvisionersMax
}
n := len(p)
cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
var slice []*Provisioner
for ; i < n && len(slice) < limit; i++ {
slice = append(slice, p[i].provisioner)
}
if i < n {
return slice, strings.TrimLeft(p[i].uid, "0")
}
return slice, ""
}
// provisionerSum returns the SHA1 of the json representation of the
// provisioner. From this we will create the unique and sorted id.
func provisionerSum(p *Provisioner) ([]byte, error) {
b, err := json.Marshal(p.Key)
if err != nil {
return nil, errors.Wrap(err, "error marshalling provisioner")
}
sum := sha1.Sum(b)
return sum[:], nil
}