forked from TrueCloudLab/certificates
Add new options to locate or list provisioners.
This commit is contained in:
parent
34ff388828
commit
fb77397fc7
1 changed files with 169 additions and 14 deletions
|
@ -1,22 +1,53 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/jose"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
const DefaultProvisionersLimit = 20
|
||||
|
||||
// DefaultProvisionersMax is the maximum limit for listing provisioners.
|
||||
const DefaultProvisionersMax = 100
|
||||
|
||||
type uidProvisioner struct {
|
||||
provisioner *Provisioner
|
||||
uid string
|
||||
}
|
||||
|
||||
type provisionerSlice []uidProvisioner
|
||||
|
||||
func (p provisionerSlice) Len() int { return len(p) }
|
||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
// Collection is a memory map of provisioners.
|
||||
type Collection struct {
|
||||
byID *sync.Map
|
||||
byKey *sync.Map
|
||||
byID *sync.Map
|
||||
byKey *sync.Map
|
||||
sorted provisionerSlice
|
||||
audiences []string
|
||||
}
|
||||
|
||||
// NewCollection initializes a collection of provisioners.
|
||||
func NewCollection() *Collection {
|
||||
// NewCollection initializes a collection of provisioners. The given list of
|
||||
// audiences are the audiences used by the JWT provisioner.
|
||||
func NewCollection(audiences []string) *Collection {
|
||||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
audiences: audiences,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,27 +56,117 @@ func (c *Collection) Load(id string) (*Provisioner, bool) {
|
|||
return loadProvisioner(c.byID, id)
|
||||
}
|
||||
|
||||
// LoadByToken parses the token claims and loads the provisioner associated.
|
||||
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (*Provisioner, bool) {
|
||||
// match with server audiences
|
||||
if matchesAudience(claims.Audience, c.audiences) {
|
||||
// If matches with stored audiences it will be a JWT token (default), and
|
||||
// the id would be <issuer>:<kid>.
|
||||
return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
|
||||
}
|
||||
|
||||
// The ID will be just the clientID stored in azp or aud.
|
||||
var payload openIDPayload
|
||||
if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
// audience is required
|
||||
if len(payload.Audience) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
if len(payload.AuthorizedParty) > 0 {
|
||||
return c.Load(payload.AuthorizedParty)
|
||||
}
|
||||
return c.Load(payload.Audience[0])
|
||||
}
|
||||
|
||||
// LoadByCertificate lookds for the provisioner extension and extracts the
|
||||
// proper id to load the provisioner.
|
||||
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (*Provisioner, bool) {
|
||||
for _, e := range cert.Extensions {
|
||||
if e.Id.Equal(stepOIDProvisioner) {
|
||||
var provisioner stepProvisionerASN1
|
||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
if provisioner.Type == int(TypeJWK) {
|
||||
return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
|
||||
}
|
||||
return c.Load(string(provisioner.CredentialID))
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// LoadEncryptedKey returns a the encrypted key by KeyID. At this moment only
|
||||
// JWK encrypted keys are indexed by KeyID.
|
||||
func (c *Collection) LoadEncryptedKey(keyID string) (*Provisioner, bool) {
|
||||
return loadProvisioner(c.byKey, keyID)
|
||||
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
||||
p, ok := loadProvisioner(c.byKey, keyID)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
_, key, ok := p.GetEncryptedKey()
|
||||
return key, ok
|
||||
}
|
||||
|
||||
// Store adds a provisioner to the collection, it makes sure two provisioner
|
||||
// does not have the same ID.
|
||||
func (c *Collection) Store(p *Provisioner) error {
|
||||
if _, loaded := c.byID.LoadOrStore(p.ID(), p); loaded == false {
|
||||
// Store provisioner always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
|
||||
return errors.New("cannot add multiple provisioners with the same id")
|
||||
}
|
||||
// Store EncryptedKey if defined
|
||||
if kid, key, ok := p.GetEncryptedKey(); ok {
|
||||
c.byKey.Store(kid, key)
|
||||
|
||||
// Store provisioner in byKey in EncryptedKey is defined.
|
||||
if kid, _, ok := p.GetEncryptedKey(); ok {
|
||||
c.byKey.Store(kid, p)
|
||||
}
|
||||
|
||||
// Store sorted provisioners.
|
||||
// Use the first 4 bytes (32bit) of the sum to insert the order
|
||||
// Using big endian format to get the strings sorted:
|
||||
// 0x00000000, 0x00000001, 0x00000002, ...
|
||||
sum, err := provisionerSum(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bi := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
|
||||
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
|
||||
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
|
||||
c.sorted = append(c.sorted, uidProvisioner{
|
||||
provisioner: p,
|
||||
uid: hex.EncodeToString(sum),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadProvisioner(m *sync.Map, id string) (*Provisioner, bool) {
|
||||
i, ok := m.Load(id)
|
||||
// Find implements pagination on a list of sorted provisioners.
|
||||
func (c *Collection) Find(cursor string, limit int) ([]*Provisioner, string) {
|
||||
switch {
|
||||
case limit <= 0:
|
||||
limit = DefaultProvisionersLimit
|
||||
case limit > DefaultProvisionersMax:
|
||||
limit = DefaultProvisionersMax
|
||||
}
|
||||
|
||||
n := c.sorted.Len()
|
||||
cursor = fmt.Sprintf("%040s", cursor)
|
||||
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
|
||||
|
||||
slice := []*Provisioner{}
|
||||
for ; i < n && len(slice) < limit; i++ {
|
||||
slice = append(slice, c.sorted[i].provisioner)
|
||||
}
|
||||
|
||||
if i < n {
|
||||
return slice, strings.TrimLeft(c.sorted[i].uid, "0")
|
||||
}
|
||||
return slice, ""
|
||||
}
|
||||
|
||||
func loadProvisioner(m *sync.Map, key string) (*Provisioner, bool) {
|
||||
i, ok := m.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
@ -55,3 +176,37 @@ func loadProvisioner(m *sync.Map, id string) (*Provisioner, bool) {
|
|||
}
|
||||
return p, true
|
||||
}
|
||||
|
||||
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
|
||||
// create the unique and sorted id.
|
||||
func provisionerSum(p *Provisioner) ([]byte, error) {
|
||||
sum := sha1.Sum([]byte(p.GetID()))
|
||||
return sum[:], nil
|
||||
}
|
||||
|
||||
// matchesAudience returns true if A and B share at least one element.
|
||||
func matchesAudience(as, bs []string) bool {
|
||||
if len(bs) == 0 || len(as) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, b := range bs {
|
||||
for _, a := range as {
|
||||
if b == a || stripPort(a) == stripPort(b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||
// produces errors it will just return the passed argument.
|
||||
func stripPort(rawurl string) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return rawurl
|
||||
}
|
||||
u.Host = u.Hostname()
|
||||
return u.String()
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue