Add collections interface and play around with collections.
This commit is contained in:
parent
00634fb648
commit
76f54f33d5
7 changed files with 189 additions and 9 deletions
|
@ -8,6 +8,19 @@ import (
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Admins is the interface used by the admin collection.
|
||||||
|
type Admins interface {
|
||||||
|
Store(adm *linkedca.Admin, prov provisioner.Interface) error
|
||||||
|
Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error)
|
||||||
|
Remove(id string) error
|
||||||
|
LoadByID(id string) (*linkedca.Admin, bool)
|
||||||
|
LoadBySubProv(sub, provName string) (*linkedca.Admin, bool)
|
||||||
|
LoadByProvisioner(provName string) ([]*linkedca.Admin, bool)
|
||||||
|
Find(cursor string, limit int) ([]*linkedca.Admin, string)
|
||||||
|
SuperCount() int
|
||||||
|
SuperCountByProvisioner(provName string) int
|
||||||
|
}
|
||||||
|
|
||||||
// LoadAdminByID returns an *linkedca.Admin with the given ID.
|
// LoadAdminByID returns an *linkedca.Admin with the given ID.
|
||||||
func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||||
a.adminMutex.RLock()
|
a.adminMutex.RLock()
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||||
"github.com/smallstep/certificates/authority/administrator"
|
"github.com/smallstep/certificates/authority/administrator"
|
||||||
|
"github.com/smallstep/certificates/authority/cache"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/cas"
|
"github.com/smallstep/certificates/cas"
|
||||||
|
@ -35,8 +36,8 @@ import (
|
||||||
type Authority struct {
|
type Authority struct {
|
||||||
config *config.Config
|
config *config.Config
|
||||||
keyManager kms.KeyManager
|
keyManager kms.KeyManager
|
||||||
provisioners *provisioner.Collection
|
provisioners Provisioners
|
||||||
admins *administrator.Collection
|
admins Admins
|
||||||
db db.AuthDB
|
db db.AuthDB
|
||||||
adminDB admin.DB
|
adminDB admin.DB
|
||||||
templates *templates.Templates
|
templates *templates.Templates
|
||||||
|
@ -76,6 +77,7 @@ type Authority struct {
|
||||||
getIdentityFunc provisioner.GetIdentityFunc
|
getIdentityFunc provisioner.GetIdentityFunc
|
||||||
authorizeRenewFunc provisioner.AuthorizeRenewFunc
|
authorizeRenewFunc provisioner.AuthorizeRenewFunc
|
||||||
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
|
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
|
||||||
|
cachePool cache.Pool
|
||||||
|
|
||||||
adminMutex sync.RWMutex
|
adminMutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
@ -175,7 +177,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create provisioner collection.
|
// Create provisioner collection.
|
||||||
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
|
provClxn := provisioner.NewCollection(provisionerConfig)
|
||||||
for _, p := range provList {
|
for _, p := range provList {
|
||||||
if err := p.Init(provisionerConfig); err != nil {
|
if err := p.Init(provisionerConfig); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -502,6 +504,11 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize the default cache pool.
|
||||||
|
if a.cachePool == nil {
|
||||||
|
a.cachePool = cache.DefaultPool()
|
||||||
|
}
|
||||||
|
|
||||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
provs, err := a.adminDB.GetProvisioners(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||||
|
|
89
authority/cache/cache.go
vendored
Normal file
89
authority/cache/cache.go
vendored
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
type Cache interface {
|
||||||
|
Get(context.Context, string) ([]byte, error)
|
||||||
|
Set(context.Context, string, []byte) error
|
||||||
|
Delete(context.Context, string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Getter interface {
|
||||||
|
Get(ctx context.Context, key string) ([]byte, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A GetterFunc implements Getter with a function.
|
||||||
|
type GetterFunc func(ctx context.Context, key string) ([]byte, error)
|
||||||
|
|
||||||
|
func (f GetterFunc) Get(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
return f(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Pool interface {
|
||||||
|
New(name string, getter Getter) Cache
|
||||||
|
Get(name string) (Cache, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultPool() Pool {
|
||||||
|
return &defaultPool{
|
||||||
|
caches: make(map[string]Cache),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultPool struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
caches map[string]Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *defaultPool) New(name string, getter Getter) Cache {
|
||||||
|
c := &mapCache{
|
||||||
|
m: new(sync.Map),
|
||||||
|
getter: getter,
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
p.caches[name] = c
|
||||||
|
p.mu.Unlock()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *defaultPool) Get(name string) (Cache, bool) {
|
||||||
|
p.mu.RLock()
|
||||||
|
c, ok := p.caches[name]
|
||||||
|
p.mu.RUnlock()
|
||||||
|
return c, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
type mapCache struct {
|
||||||
|
name string
|
||||||
|
m *sync.Map
|
||||||
|
getter Getter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mapCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||||
|
v, ok := m.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
b, err := m.getter.Get(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m.m.Store(key, b)
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
return v.([]byte), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mapCache) Set(ctx context.Context, key string, value []byte) error {
|
||||||
|
m.m.Store(key, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mapCache) Delete(ctx context.Context, key string) error {
|
||||||
|
m.m.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
|
"github.com/smallstep/certificates/authority/cache"
|
||||||
"github.com/smallstep/certificates/authority/config"
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/cas"
|
"github.com/smallstep/certificates/cas"
|
||||||
|
@ -284,6 +285,14 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithCachePool is an options that allows to define a custom cache pool.
|
||||||
|
func WithCachePool(pool cache.Pool) Option {
|
||||||
|
return func(a *Authority) error {
|
||||||
|
a.cachePool = pool
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
|
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
|
||||||
var block *pem.Block
|
var block *pem.Block
|
||||||
var certs []*x509.Certificate
|
var certs []*x509.Certificate
|
||||||
|
|
|
@ -1,19 +1,25 @@
|
||||||
package provisioner
|
package provisioner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/smallstep/certificates/authority/admin"
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
|
"github.com/smallstep/certificates/authority/cache"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
"go.step.sm/linkedca"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||||
|
@ -33,6 +39,10 @@ func (p provisionerSlice) Len() int { return len(p) }
|
||||||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
|
|
||||||
|
func defaultContext() (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
// loadByTokenPayload is a payload used to extract the id used to load the
|
// loadByTokenPayload is a payload used to extract the id used to load the
|
||||||
// provisioner.
|
// provisioner.
|
||||||
type loadByTokenPayload struct {
|
type loadByTokenPayload struct {
|
||||||
|
@ -50,22 +60,53 @@ type Collection struct {
|
||||||
byTokenID *sync.Map
|
byTokenID *sync.Map
|
||||||
sorted provisionerSlice
|
sorted provisionerSlice
|
||||||
audiences Audiences
|
audiences Audiences
|
||||||
|
|
||||||
|
// new
|
||||||
|
byIDCache cache.Cache
|
||||||
|
byNameCache cache.Cache
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCollection initializes a collection of provisioners. The given list of
|
// NewCollection initializes a collection of provisioners. The given list of
|
||||||
// audiences are the audiences used by the JWT provisioner.
|
// audiences are the audiences used by the JWT provisioner.
|
||||||
func NewCollection(audiences Audiences) *Collection {
|
func NewCollection(config Config) *Collection {
|
||||||
|
byID := config.CachePool.New("provisioner_by_id", cache.GetterFunc(func(ctx context.Context, id string) ([]byte, error) {
|
||||||
|
p, err := config.AdminDB.GetProvisioner(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return proto.Marshal(p)
|
||||||
|
}))
|
||||||
|
// byName maps a name with a provisioner id, we will manually fill this cache.
|
||||||
|
byName := config.CachePool.New("provisioner_by_name", cache.GetterFunc(func(ctx context.Context, name string) ([]byte, error) {
|
||||||
|
return nil, cache.ErrNotFound
|
||||||
|
}))
|
||||||
|
|
||||||
return &Collection{
|
return &Collection{
|
||||||
byID: new(sync.Map),
|
byID: new(sync.Map),
|
||||||
byKey: new(sync.Map),
|
byKey: new(sync.Map),
|
||||||
byName: new(sync.Map),
|
byName: new(sync.Map),
|
||||||
byTokenID: new(sync.Map),
|
byTokenID: new(sync.Map),
|
||||||
audiences: audiences,
|
audiences: config.Audiences,
|
||||||
|
byIDCache: byID,
|
||||||
|
byNameCache: byName,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load a provisioner by the ID.
|
// Load a provisioner by the ID.
|
||||||
func (c *Collection) Load(id string) (Interface, bool) {
|
func (c *Collection) Load(id string) (Interface, bool) {
|
||||||
|
ctx, cancel := defaultContext()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
b, err := c.byIDCache.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var p linkedca.Provisioner
|
||||||
|
if err := proto.Unmarshal(b, &p); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
log.Printf("Provisioner.Load(%s): %v", id, p)
|
||||||
return loadProvisioner(c.byID, id)
|
return loadProvisioner(c.byID, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -180,6 +221,7 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
||||||
// Store adds a provisioner to the collection and enforces the uniqueness of
|
// Store adds a provisioner to the collection and enforces the uniqueness of
|
||||||
// provisioner IDs.
|
// provisioner IDs.
|
||||||
func (c *Collection) Store(p Interface) error {
|
func (c *Collection) Store(p Interface) error {
|
||||||
|
|
||||||
// Store provisioner always in byID. ID must be unique.
|
// Store provisioner always in byID. ID must be unique.
|
||||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
|
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
|
||||||
return admin.NewError(admin.ErrorBadRequestType,
|
return admin.NewError(admin.ErrorBadRequestType,
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/admin"
|
||||||
|
"github.com/smallstep/certificates/authority/cache"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/certificates/errs"
|
"github.com/smallstep/certificates/errs"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
@ -214,6 +216,8 @@ type Config struct {
|
||||||
Audiences Audiences
|
Audiences Audiences
|
||||||
// DB is the interface to the authority DB client.
|
// DB is the interface to the authority DB client.
|
||||||
DB db.AuthDB
|
DB db.AuthDB
|
||||||
|
// AdminDB is the interface to the administration DB client.
|
||||||
|
AdminDB admin.DB
|
||||||
// SSHKeys are the root SSH public keys
|
// SSHKeys are the root SSH public keys
|
||||||
SSHKeys *SSHKeys
|
SSHKeys *SSHKeys
|
||||||
// GetIdentityFunc is a function that returns an identity that will be
|
// GetIdentityFunc is a function that returns an identity that will be
|
||||||
|
@ -225,6 +229,8 @@ type Config struct {
|
||||||
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
|
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
|
||||||
// certificate can be renewed.
|
// certificate can be renewed.
|
||||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||||
|
// CachePool is a type that allows to create new caches.
|
||||||
|
CachePool cache.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
type provisioner struct {
|
type provisioner struct {
|
||||||
|
|
|
@ -21,6 +21,20 @@ import (
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Provisioners is the interface used by the provisioners collection.
|
||||||
|
type Provisioners interface {
|
||||||
|
Load(id string) (provisioner.Interface, bool)
|
||||||
|
Store(p provisioner.Interface) error
|
||||||
|
Update(p provisioner.Interface) error
|
||||||
|
Remove(id string) error
|
||||||
|
LoadByName(name string) (provisioner.Interface, bool)
|
||||||
|
LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (provisioner.Interface, bool)
|
||||||
|
LoadByTokenID(tokenProvisionerID string) (provisioner.Interface, bool)
|
||||||
|
LoadByCertificate(cert *x509.Certificate) (provisioner.Interface, bool)
|
||||||
|
Find(cursor string, limit int) (provisioner.List, string)
|
||||||
|
LoadEncryptedKey(keyID string) (string, bool)
|
||||||
|
}
|
||||||
|
|
||||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||||
a.adminMutex.RLock()
|
a.adminMutex.RLock()
|
||||||
|
|
Loading…
Reference in a new issue