diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index a62009d8..7488c7cd 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -1,22 +1,53 @@ package provisioner import ( + "crypto/sha1" + "crypto/x509" + "encoding/asn1" + "encoding/binary" + "encoding/hex" + "fmt" + "net/url" + "sort" + "strings" "sync" "github.com/pkg/errors" + "github.com/smallstep/cli/jose" ) +// DefaultProvisionersLimit is the default limit for listing provisioners. +const DefaultProvisionersLimit = 20 + +// DefaultProvisionersMax is the maximum limit for listing provisioners. +const DefaultProvisionersMax = 100 + +type uidProvisioner struct { + provisioner *Provisioner + uid string +} + +type provisionerSlice []uidProvisioner + +func (p provisionerSlice) Len() int { return len(p) } +func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } +func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } + // Collection is a memory map of provisioners. type Collection struct { - byID *sync.Map - byKey *sync.Map + byID *sync.Map + byKey *sync.Map + sorted provisionerSlice + audiences []string } -// NewCollection initializes a collection of provisioners. -func NewCollection() *Collection { +// NewCollection initializes a collection of provisioners. The given list of +// audiences are the audiences used by the JWT provisioner. +func NewCollection(audiences []string) *Collection { return &Collection{ - byID: new(sync.Map), - byKey: new(sync.Map), + byID: new(sync.Map), + byKey: new(sync.Map), + audiences: audiences, } } @@ -25,27 +56,117 @@ func (c *Collection) Load(id string) (*Provisioner, bool) { return loadProvisioner(c.byID, id) } +// LoadByToken parses the token claims and loads the provisioner associated. +func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (*Provisioner, bool) { + // match with server audiences + if matchesAudience(claims.Audience, c.audiences) { + // If matches with stored audiences it will be a JWT token (default), and + // the id would be :. + return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID) + } + + // The ID will be just the clientID stored in azp or aud. + var payload openIDPayload + if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil { + return nil, false + } + // audience is required + if len(payload.Audience) == 0 { + return nil, false + } + if len(payload.AuthorizedParty) > 0 { + return c.Load(payload.AuthorizedParty) + } + return c.Load(payload.Audience[0]) +} + +// LoadByCertificate lookds for the provisioner extension and extracts the +// proper id to load the provisioner. +func (c *Collection) LoadByCertificate(cert *x509.Certificate) (*Provisioner, bool) { + for _, e := range cert.Extensions { + if e.Id.Equal(stepOIDProvisioner) { + var provisioner stepProvisionerASN1 + if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { + return nil, false + } + if provisioner.Type == int(TypeJWK) { + return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID)) + } + return c.Load(string(provisioner.CredentialID)) + } + } + return nil, false +} + // LoadEncryptedKey returns a the encrypted key by KeyID. At this moment only // JWK encrypted keys are indexed by KeyID. -func (c *Collection) LoadEncryptedKey(keyID string) (*Provisioner, bool) { - return loadProvisioner(c.byKey, keyID) +func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) { + p, ok := loadProvisioner(c.byKey, keyID) + if !ok { + return "", false + } + _, key, ok := p.GetEncryptedKey() + return key, ok } // Store adds a provisioner to the collection, it makes sure two provisioner // does not have the same ID. func (c *Collection) Store(p *Provisioner) error { - if _, loaded := c.byID.LoadOrStore(p.ID(), p); loaded == false { + // Store provisioner always in byID. ID must be unique. + if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded == true { return errors.New("cannot add multiple provisioners with the same id") } - // Store EncryptedKey if defined - if kid, key, ok := p.GetEncryptedKey(); ok { - c.byKey.Store(kid, key) + + // Store provisioner in byKey in EncryptedKey is defined. + if kid, _, ok := p.GetEncryptedKey(); ok { + c.byKey.Store(kid, p) } + + // Store sorted provisioners. + // Use the first 4 bytes (32bit) of the sum to insert the order + // Using big endian format to get the strings sorted: + // 0x00000000, 0x00000001, 0x00000002, ... + sum, err := provisionerSum(p) + if err != nil { + return err + } + bi := make([]byte, 4) + binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len())) + sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3] + bi[0], bi[1], bi[2], bi[3] = 0, 0, 0, 0 + c.sorted = append(c.sorted, uidProvisioner{ + provisioner: p, + uid: hex.EncodeToString(sum), + }) return nil } -func loadProvisioner(m *sync.Map, id string) (*Provisioner, bool) { - i, ok := m.Load(id) +// Find implements pagination on a list of sorted provisioners. +func (c *Collection) Find(cursor string, limit int) ([]*Provisioner, string) { + switch { + case limit <= 0: + limit = DefaultProvisionersLimit + case limit > DefaultProvisionersMax: + limit = DefaultProvisionersMax + } + + n := c.sorted.Len() + cursor = fmt.Sprintf("%040s", cursor) + i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor }) + + slice := []*Provisioner{} + for ; i < n && len(slice) < limit; i++ { + slice = append(slice, c.sorted[i].provisioner) + } + + if i < n { + return slice, strings.TrimLeft(c.sorted[i].uid, "0") + } + return slice, "" +} + +func loadProvisioner(m *sync.Map, key string) (*Provisioner, bool) { + i, ok := m.Load(key) if !ok { return nil, false } @@ -55,3 +176,37 @@ func loadProvisioner(m *sync.Map, id string) (*Provisioner, bool) { } return p, true } + +// provisionerSum returns the SHA1 of the provisioners ID. From this we will +// create the unique and sorted id. +func provisionerSum(p *Provisioner) ([]byte, error) { + sum := sha1.Sum([]byte(p.GetID())) + return sum[:], nil +} + +// matchesAudience returns true if A and B share at least one element. +func matchesAudience(as, bs []string) bool { + if len(bs) == 0 || len(as) == 0 { + return false + } + + for _, b := range bs { + for _, a := range as { + if b == a || stripPort(a) == stripPort(b) { + return true + } + } + } + return false +} + +// stripPort attempts to strip the port from the given url. If parsing the url +// produces errors it will just return the passed argument. +func stripPort(rawurl string) string { + u, err := url.Parse(rawurl) + if err != nil { + return rawurl + } + u.Host = u.Hostname() + return u.String() +}