certificates/authority/provisioner/collection.go

213 lines
5.7 KiB
Go
Raw Normal View History

2019-03-05 22:28:32 +00:00
package provisioner
import (
"crypto/sha1"
"crypto/x509"
"encoding/asn1"
"encoding/binary"
"encoding/hex"
"fmt"
"net/url"
"sort"
"strings"
2019-03-05 22:28:32 +00:00
"sync"
"github.com/pkg/errors"
"github.com/smallstep/cli/jose"
2019-03-05 22:28:32 +00:00
)
// Pagination bounds used by Collection.Find.
const (
	// DefaultProvisionersLimit is the default limit for listing provisioners.
	DefaultProvisionersLimit = 20
	// DefaultProvisionersMax is the maximum limit for listing provisioners.
	DefaultProvisionersMax = 100
)
// uidProvisioner pairs a provisioner with the uid used to order it in the
// collection's sorted list.
type uidProvisioner struct {
provisioner *Provisioner
// uid is the hex-encoded key built by Collection.Store: a 4-byte big-endian
// insertion index followed by the remaining bytes of the SHA1 of the ID.
uid string
}
type provisionerSlice []uidProvisioner
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
2019-03-05 22:28:32 +00:00
// Collection is a memory map of provisioners.
type Collection struct {
byID *sync.Map
byKey *sync.Map
sorted provisionerSlice
audiences []string
2019-03-05 22:28:32 +00:00
}
// NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner.
func NewCollection(audiences []string) *Collection {
2019-03-05 22:28:32 +00:00
return &Collection{
byID: new(sync.Map),
byKey: new(sync.Map),
audiences: audiences,
2019-03-05 22:28:32 +00:00
}
}
// Load returns the provisioner stored under the given ID, and whether one was
// found.
func (c *Collection) Load(id string) (*Provisioner, bool) {
	p, found := loadProvisioner(c.byID, id)
	return p, found
}
// LoadByToken parses the token claims and loads the provisioner associated.
//
// If one of the token audiences matches one of the collection audiences
// (compared with and without the URL port), the token is taken to belong to
// the default JWT provisioner and the lookup id is "<issuer>:<kid>", where kid
// comes from the first token header. Otherwise the unverified payload is
// decoded and the id is the authorized party (azp) when present, or else the
// first audience.
//
// It returns false when the token carries no headers to read the kid from,
// when the payload cannot be decoded, when the payload has no audience, or
// when no provisioner is stored under the computed id.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (*Provisioner, bool) {
	// match with server audiences
	if matchesAudience(claims.Audience, c.audiences) {
		// Indexing Headers[0] unconditionally would panic on a token without
		// headers; treat such a token as unknown instead.
		if len(token.Headers) == 0 {
			return nil, false
		}
		// If matches with stored audiences it will be a JWT token (default), and
		// the id would be <issuer>:<kid>.
		return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
	}

	// The ID will be just the clientID stored in azp or aud.
	var payload openIDPayload
	if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
		return nil, false
	}
	// audience is required
	if len(payload.Audience) == 0 {
		return nil, false
	}
	if len(payload.AuthorizedParty) > 0 {
		return c.Load(payload.AuthorizedParty)
	}
	return c.Load(payload.Audience[0])
}
// LoadByCertificate looks for the provisioner extension in the certificate and
// extracts the proper id to load the provisioner.
//
// Only the first extension whose OID equals stepOIDProvisioner is considered.
// For JWK provisioners the lookup id is "<name>:<credentialID>"; for any other
// type it is the credential ID alone. It returns false if no such extension
// exists, if it cannot be ASN.1-decoded, or if no provisioner matches the id.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (*Provisioner, bool) {
	for _, ext := range cert.Extensions {
		if !ext.Id.Equal(stepOIDProvisioner) {
			continue
		}

		var info stepProvisionerASN1
		if _, err := asn1.Unmarshal(ext.Value, &info); err != nil {
			return nil, false
		}

		id := string(info.CredentialID)
		if info.Type == int(TypeJWK) {
			id = string(info.Name) + ":" + id
		}
		return c.Load(id)
	}
	return nil, false
}
// LoadEncryptedKey returns a the encrypted key by KeyID. At this moment only
// JWK encrypted keys are indexed by KeyID.
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
p, ok := loadProvisioner(c.byKey, keyID)
if !ok {
return "", false
}
_, key, ok := p.GetEncryptedKey()
return key, ok
2019-03-05 22:28:32 +00:00
}
// Store adds a provisioner to the collection, it makes sure two provisioner
// does not have the same ID.
func (c *Collection) Store(p *Provisioner) error {
// Store provisioner always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true {
2019-03-05 22:28:32 +00:00
return errors.New("cannot add multiple provisioners with the same id")
}
// Store provisioner in byKey in EncryptedKey is defined.
if kid, _, ok := p.GetEncryptedKey(); ok {
c.byKey.Store(kid, p)
}
// Store sorted provisioners.
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
sum, err := provisionerSum(p)
if err != nil {
return err
}
bi := make([]byte, 4)
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0
c.sorted = append(c.sorted, uidProvisioner{
provisioner: p,
uid: hex.EncodeToString(sum),
})
2019-03-05 22:28:32 +00:00
return nil
}
// Find implements pagination on the list of sorted provisioners.
//
// cursor is the uid (with leading zeros trimmed) of the first provisioner to
// return, as handed back by a previous call; an empty cursor starts from the
// beginning. limit is clamped: values <= 0 become DefaultProvisionersLimit and
// values above DefaultProvisionersMax become DefaultProvisionersMax.
//
// It returns a non-nil slice of at most limit provisioners, plus the cursor
// for the next page, or "" when there are no more provisioners.
func (c *Collection) Find(cursor string, limit int) ([]*Provisioner, string) {
	if limit <= 0 {
		limit = DefaultProvisionersLimit
	} else if limit > DefaultProvisionersMax {
		limit = DefaultProvisionersMax
	}

	total := c.sorted.Len()
	// uids are 40 hex characters; left-pad the cursor with zeros so the
	// trimmed form handed out as a cursor compares correctly.
	padded := fmt.Sprintf("%040s", cursor)
	start := sort.Search(total, func(k int) bool { return c.sorted[k].uid >= padded })

	end := start + limit
	if end > total {
		end = total
	}

	page := make([]*Provisioner, 0, end-start)
	for _, item := range c.sorted[start:end] {
		page = append(page, item.provisioner)
	}

	if end == total {
		return page, ""
	}
	return page, strings.TrimLeft(c.sorted[end].uid, "0")
}
// loadProvisioner looks up key in m and returns the stored value if it is a
// *Provisioner. It returns false when the key is missing or when the stored
// value has a different type.
func loadProvisioner(m *sync.Map, key string) (*Provisioner, bool) {
	if v, found := m.Load(key); found {
		if p, isProvisioner := v.(*Provisioner); isProvisioner {
			return p, true
		}
	}
	return nil, false
}
// provisionerSum returns the 20-byte SHA1 digest of the provisioner's ID, from
// which Store derives the unique, sortable uid. The error result is always nil
// with the current implementation.
func provisionerSum(p *Provisioner) ([]byte, error) {
	h := sha1.New()
	h.Write([]byte(p.GetID())) // hash.Hash.Write never returns an error
	return h.Sum(nil), nil
}
// matchesAudience returns true if as and bs share at least one element. Two
// entries match when they are equal, or when they are equal after the port
// has been removed from their URL host (see stripPort). It returns false if
// either list is empty.
func matchesAudience(as, bs []string) bool {
	if len(bs) == 0 || len(as) == 0 {
		return false
	}
	// Strip each a once up front rather than re-parsing it for every b in
	// the nested loop below.
	strippedAs := make([]string, len(as))
	for i, a := range as {
		strippedAs[i] = stripPort(a)
	}
	for _, b := range bs {
		strippedB := stripPort(b)
		for i, a := range as {
			if b == a || strippedAs[i] == strippedB {
				return true
			}
		}
	}
	return false
}

// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
//
// NOTE(review): url.URL.Hostname also removes the brackets around IPv6
// literals, so for such hosts the result may not be a well-formed URL. This
// seems acceptable for comparison purposes, since both sides are transformed
// the same way.
func stripPort(rawurl string) string {
	u, err := url.Parse(rawurl)
	if err != nil {
		return rawurl
	}
	u.Host = u.Hostname()
	return u.String()
}