package provisioner

import (
	"crypto/sha1"
	"crypto/x509"
	"encoding/asn1"
	"encoding/binary"
	"encoding/hex"
	"net/url"
	"sort"
	"strings"
	"sync"

	"github.com/pkg/errors"
	"github.com/smallstep/cli/jose"
)

// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20

// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100

// uidProvisioner pairs a provisioner with the hex-encoded uid that orders it
// inside Collection.sorted.
type uidProvisioner struct {
	provisioner *Provisioner
	uid         string
}

// provisionerSlice implements sort.Interface over uidProvisioner values,
// ordering them by uid.
type provisionerSlice []uidProvisioner

func (p provisionerSlice) Len() int           { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }

// Collection is a memory map of provisioners.
//
// Lookups by ID and by encrypted-key ID go through sync.Map. However, the
// sorted slice used by Find is appended to by Store without any locking, so
// Store must not run concurrently with another Store or with Find.
type Collection struct {
	byID      *sync.Map
	byKey     *sync.Map
	sorted    provisionerSlice
	audiences []string
}

// NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner.
func NewCollection(audiences []string) *Collection {
	return &Collection{
		byID:      new(sync.Map),
		byKey:     new(sync.Map),
		audiences: audiences,
	}
}

// Load returns the provisioner stored under the given ID, and false if there
// is none.
func (c *Collection) Load(id string) (*Provisioner, bool) {
	return loadProvisioner(c.byID, id)
}

// LoadByToken parses the token claims and loads the associated provisioner.
//
// If any token audience matches one of the collection audiences (ports are
// ignored, see matchesAudience), the provisioner ID is "<issuer>:<kid>". The
// kid is the key ID of the first JOSE header.
//
// Otherwise the ID is the client ID, taken from the azp (authorized party)
// claim if present, or else from the first aud value.
//
// NOTE(review): the payload is decoded with UnsafeClaimsWithoutVerification
// only to select the provisioner; signature verification presumably happens
// later with the loaded provisioner — confirm at the call sites.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (*Provisioner, bool) {
	// Match with server audiences.
	if matchesAudience(claims.Audience, c.audiences) {
		// A token without headers carries no key ID; indexing Headers[0]
		// would panic.
		if len(token.Headers) == 0 {
			return nil, false
		}
		// If it matches the stored audiences it is a JWT token (default), and
		// the id is <issuer>:<kid>.
		return c.Load(claims.Issuer + ":" + token.Headers[0].KeyID)
	}

	// The ID will be just the clientID stored in azp or aud.
	var payload openIDPayload
	if err := token.UnsafeClaimsWithoutVerification(&payload); err != nil {
		return nil, false
	}
	// The audience is required.
	if len(payload.Audience) == 0 {
		return nil, false
	}
	if len(payload.AuthorizedParty) > 0 {
		return c.Load(payload.AuthorizedParty)
	}
	return c.Load(payload.Audience[0])
}

// LoadByCertificate looks for the step provisioner extension in the
// certificate and loads the provisioner recorded in it.
//
// JWK provisioners are looked up by "<name>:<credential-id>"; every other
// type is looked up by its credential ID alone. Only the first matching
// extension is considered. It returns false if the extension is missing,
// cannot be ASN.1-decoded, or names an unknown provisioner.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (*Provisioner, bool) {
	for _, e := range cert.Extensions {
		if !e.Id.Equal(stepOIDProvisioner) {
			continue
		}
		var provisioner stepProvisionerASN1
		if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
			return nil, false
		}
		if provisioner.Type == int(TypeJWK) {
			return c.Load(string(provisioner.Name) + ":" + string(provisioner.CredentialID))
		}
		return c.Load(string(provisioner.CredentialID))
	}
	return nil, false
}

// LoadEncryptedKey returns the encrypted key of the provisioner indexed under
// keyID. At this moment only JWK encrypted keys are indexed by KeyID.
func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
	p, ok := loadProvisioner(c.byKey, keyID)
	if !ok {
		return "", false
	}
	_, key, ok := p.GetEncryptedKey()
	return key, ok
}

// Store adds a provisioner to the collection. It returns an error if a
// provisioner with the same ID has already been stored.
//
// Store is not safe for concurrent use: it appends to c.sorted without
// synchronization.
func (c *Collection) Store(p *Provisioner) error {
	// Compute the sort key before touching any index. A failure here then
	// cannot leave the provisioner present in byID/byKey but missing from
	// sorted.
	sum, err := provisionerSum(p)
	if err != nil {
		return err
	}

	// Store the provisioner always in byID. The ID must be unique.
	if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
		return errors.New("cannot add multiple provisioners with the same id")
	}

	// Store the provisioner in byKey if an EncryptedKey is defined.
	if kid, _, ok := p.GetEncryptedKey(); ok {
		c.byKey.Store(kid, p)
	}

	// Store sorted provisioners. The first 4 bytes (32 bits) of the sum are
	// overwritten with the insertion index in big-endian order. The
	// hex-encoded uids therefore sort in insertion order:
	// 00000000..., 00000001..., 00000002..., ...
	binary.BigEndian.PutUint32(sum[:4], uint32(c.sorted.Len()))
	c.sorted = append(c.sorted, uidProvisioner{
		provisioner: p,
		uid:         hex.EncodeToString(sum),
	})
	return nil
}

// Find implements pagination on a list of sorted provisioners.
//
// It returns at most limit provisioners whose uid is >= cursor, together with
// the cursor of the next page. The next cursor is "" when there are no more
// results. A non-positive limit means DefaultProvisionersLimit, and limits
// above DefaultProvisionersMax are capped to it.
func (c *Collection) Find(cursor string, limit int) ([]*Provisioner, string) {
	// uids are the hex encoding of a SHA-1 sized sum (see Store).
	const uidLen = 2 * sha1.Size

	switch {
	case limit <= 0:
		limit = DefaultProvisionersLimit
	case limit > DefaultProvisionersMax:
		limit = DefaultProvisionersMax
	}

	// Returned cursors have their leading zeros trimmed. Restore them so the
	// lexical comparison against the fixed-width uids is correct. Pad
	// explicitly rather than relying on fmt's '0' flag applied to a %s verb.
	if pad := uidLen - len(cursor); pad > 0 {
		cursor = strings.Repeat("0", pad) + cursor
	}

	n := c.sorted.Len()
	i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })

	// A non-nil empty slice keeps "no results" distinct from nil for callers.
	slice := []*Provisioner{}
	for ; i < n && len(slice) < limit; i++ {
		slice = append(slice, c.sorted[i].provisioner)
	}
	if i < n {
		return slice, strings.TrimLeft(c.sorted[i].uid, "0")
	}
	return slice, ""
}

// loadProvisioner returns the *Provisioner stored under key in m. It returns
// false if the key is absent or holds a value of another type.
func loadProvisioner(m *sync.Map, key string) (*Provisioner, bool) {
	i, ok := m.Load(key)
	if !ok {
		return nil, false
	}
	p, ok := i.(*Provisioner)
	return p, ok
}

// provisionerSum returns the SHA1 of the provisioner's ID. From this we
// create the unique and sorted id.
func provisionerSum(p *Provisioner) ([]byte, error) {
	sum := sha1.Sum([]byte(p.GetID()))
	return sum[:], nil
}

// matchesAudience returns true if as and bs share at least one element. Two
// elements also match when they are equal after their ports are stripped
// (see stripPort). Empty inputs never match.
func matchesAudience(as, bs []string) bool {
	if len(bs) == 0 || len(as) == 0 {
		return false
	}
	for _, b := range bs {
		// Hoisted: stripPort parses a URL and does not depend on a.
		sb := stripPort(b)
		for _, a := range as {
			if b == a || stripPort(a) == sb {
				return true
			}
		}
	}
	return false
}

// stripPort attempts to strip the port from the given URL. If parsing the URL
// fails, it returns the argument unchanged.
//
// NOTE(review): url.URL.Hostname also removes the square brackets around IPv6
// literals, so the result is meant for comparison only, not as a URL to dial.
func stripPort(rawurl string) string {
	u, err := url.Parse(rawurl)
	if err != nil {
		return rawurl
	}
	u.Host = u.Hostname()
	return u.String()
}