Use provisioner.Collection to store and request the provisioners.
This commit is contained in:
parent
34833d4fd5
commit
c776ca3bd6
2 changed files with 22 additions and 126 deletions
|
@ -8,6 +8,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/cli/crypto/pemutil"
|
||||
"github.com/smallstep/cli/crypto/x509util"
|
||||
)
|
||||
|
@ -16,18 +17,15 @@ const legacyAuthority = "step-certificate-authority"
|
|||
|
||||
// Authority implements the Certificate Authority internal interface.
|
||||
type Authority struct {
|
||||
config *Config
|
||||
rootX509Certs []*x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
ottMap *sync.Map
|
||||
startTime time.Time
|
||||
provisionerIDIndex *sync.Map
|
||||
encryptedKeyIndex *sync.Map
|
||||
provisionerKeySetIndex *sync.Map
|
||||
sortedProvisioners provisionerSlice
|
||||
audiences []string
|
||||
config *Config
|
||||
rootX509Certs []*x509.Certificate
|
||||
intermediateIdentity *x509util.Identity
|
||||
validateOnce bool
|
||||
certificates *sync.Map
|
||||
ottMap *sync.Map
|
||||
startTime time.Time
|
||||
provisioners *provisioner.Collection
|
||||
audiences []string
|
||||
// Do not re-initialize
|
||||
initOnce bool
|
||||
}
|
||||
|
@ -39,15 +37,6 @@ func New(config *Config) (*Authority, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Get sorted provisioners
|
||||
var sorted provisionerSlice
|
||||
if config.AuthorityConfig != nil {
|
||||
sorted, err = newSortedProvisioners(config.AuthorityConfig.Provisioners)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Define audiences: legacy + possible urls without the ports.
|
||||
// The CA might have proxies in front so we cannot rely on the port.
|
||||
audiences := []string{legacyAuthority}
|
||||
|
@ -56,14 +45,11 @@ func New(config *Config) (*Authority, error) {
|
|||
}
|
||||
|
||||
var a = &Authority{
|
||||
config: config,
|
||||
certificates: new(sync.Map),
|
||||
ottMap: new(sync.Map),
|
||||
provisionerIDIndex: new(sync.Map),
|
||||
encryptedKeyIndex: new(sync.Map),
|
||||
provisionerKeySetIndex: new(sync.Map),
|
||||
sortedProvisioners: sorted,
|
||||
audiences: audiences,
|
||||
config: config,
|
||||
certificates: new(sync.Map),
|
||||
ottMap: new(sync.Map),
|
||||
provisioners: provisioner.NewCollection(audiences),
|
||||
audiences: audiences,
|
||||
}
|
||||
if err := a.init(); err != nil {
|
||||
return nil, err
|
||||
|
@ -120,10 +106,10 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Store all the provisioners
|
||||
for _, p := range a.config.AuthorityConfig.Provisioners {
|
||||
a.provisionerIDIndex.Store(p.ID(), p)
|
||||
if len(p.EncryptedKey) != 0 {
|
||||
a.encryptedKeyIndex.Store(p.Key.KeyID, p.EncryptedKey)
|
||||
if err := a.provisioners.Store(p); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,115 +1,25 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
const DefaultProvisionersLimit = 20
|
||||
|
||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||
const DefaultProvisionersMax = 100
|
||||
|
||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||
val, ok := a.encryptedKeyIndex.Load(kid)
|
||||
key, ok := a.provisioners.LoadEncryptedKey(kid)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("encrypted key with kid %s was not found", kid),
|
||||
http.StatusNotFound, context{}}
|
||||
}
|
||||
|
||||
key, ok := val.(string)
|
||||
if !ok {
|
||||
return "", &apiError{errors.Errorf("stored value is not a string"),
|
||||
http.StatusInternalServerError, context{}}
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetProvisioners returns a map listing each provisioner and the JWK Key Set
|
||||
// with their public keys.
|
||||
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*Provisioner, string, error) {
|
||||
provisioners, nextCursor := a.sortedProvisioners.Find(cursor, limit)
|
||||
func (a *Authority) GetProvisioners(cursor string, limit int) ([]*provisioner.Provisioner, string, error) {
|
||||
provisioners, nextCursor := a.provisioners.Find(cursor, limit)
|
||||
return provisioners, nextCursor, nil
|
||||
}
|
||||
|
||||
type uidProvisioner struct {
|
||||
provisioner *Provisioner
|
||||
uid string
|
||||
}
|
||||
|
||||
func newSortedProvisioners(provisioners []*Provisioner) (provisionerSlice, error) {
|
||||
if len(provisioners) > math.MaxInt32 {
|
||||
return nil, errors.New("too many provisioners")
|
||||
}
|
||||
|
||||
var slice provisionerSlice
|
||||
bi := make([]byte, 4)
|
||||
for i, p := range provisioners {
|
||||
sum, err := provisionerSum(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
binary.BigEndian.PutUint32(bi, uint32(i))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
||||
slice = append(slice, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
}
|
||||
sort.Sort(slice)
|
||||
return slice, nil
|
||||
}
|
||||
|
||||
type provisionerSlice []uidProvisioner
|
||||
|
||||
func (p provisionerSlice) Len() int { return len(p) }
|
||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
func (p provisionerSlice) Find(cursor string, limit int) ([]*Provisioner, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultProvisionersLimit
|
||||
case limit > DefaultProvisionersMax:
|
||||
limit = DefaultProvisionersMax
|
||||
}
|
||||
|
||||
n := len(p)
|
||||
cursor = fmt.Sprintf("%040s", cursor)
|
||||
i := sort.Search(n, func(i int) bool { return p[i].uid >= cursor })
|
||||
|
||||
var slice []*Provisioner
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, p[i].provisioner)
|
||||
}
|
||||
if i < n {
|
||||
return slice, strings.TrimLeft(p[i].uid, "0")
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
// provisionerSum returns the SHA1 of the json representation of the
|
||||
// provisioner. From this we will create the unique and sorted id.
|
||||
func provisionerSum(p *Provisioner) ([]byte, error) {
|
||||
b, err := json.Marshal(p.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshalling provisioner")
|
||||
}
|
||||
sum := sha1.Sum(b)
|
||||
return sum[:], nil
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue