forked from TrueCloudLab/certificates
Add collections interface and play around with collections.
This commit is contained in:
parent
00634fb648
commit
76f54f33d5
7 changed files with 189 additions and 9 deletions
|
@ -8,6 +8,19 @@ import (
|
|||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// Admins is the interface used by the admin collection.
|
||||
type Admins interface {
|
||||
Store(adm *linkedca.Admin, prov provisioner.Interface) error
|
||||
Update(id string, nu *linkedca.Admin) (*linkedca.Admin, error)
|
||||
Remove(id string) error
|
||||
LoadByID(id string) (*linkedca.Admin, bool)
|
||||
LoadBySubProv(sub, provName string) (*linkedca.Admin, bool)
|
||||
LoadByProvisioner(provName string) ([]*linkedca.Admin, bool)
|
||||
Find(cursor string, limit int) ([]*linkedca.Admin, string)
|
||||
SuperCount() int
|
||||
SuperCountByProvisioner(provName string) int
|
||||
}
|
||||
|
||||
// LoadAdminByID returns an *linkedca.Admin with the given ID.
|
||||
func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||
a.adminMutex.RLock()
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/smallstep/certificates/authority/admin"
|
||||
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/administrator"
|
||||
"github.com/smallstep/certificates/authority/cache"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/cas"
|
||||
|
@ -35,8 +36,8 @@ import (
|
|||
type Authority struct {
|
||||
config *config.Config
|
||||
keyManager kms.KeyManager
|
||||
provisioners *provisioner.Collection
|
||||
admins *administrator.Collection
|
||||
provisioners Provisioners
|
||||
admins Admins
|
||||
db db.AuthDB
|
||||
adminDB admin.DB
|
||||
templates *templates.Templates
|
||||
|
@ -76,6 +77,7 @@ type Authority struct {
|
|||
getIdentityFunc provisioner.GetIdentityFunc
|
||||
authorizeRenewFunc provisioner.AuthorizeRenewFunc
|
||||
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
|
||||
cachePool cache.Pool
|
||||
|
||||
adminMutex sync.RWMutex
|
||||
}
|
||||
|
@ -175,7 +177,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Create provisioner collection.
|
||||
provClxn := provisioner.NewCollection(provisionerConfig.Audiences)
|
||||
provClxn := provisioner.NewCollection(provisionerConfig)
|
||||
for _, p := range provList {
|
||||
if err := p.Init(provisionerConfig); err != nil {
|
||||
return err
|
||||
|
@ -502,6 +504,11 @@ func (a *Authority) init() error {
|
|||
}
|
||||
}
|
||||
|
||||
// Initialize the default cache pool.
|
||||
if a.cachePool == nil {
|
||||
a.cachePool = cache.DefaultPool()
|
||||
}
|
||||
|
||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||
|
|
89
authority/cache/cache.go
vendored
Normal file
89
authority/cache/cache.go
vendored
Normal file
|
@ -0,0 +1,89 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
type Cache interface {
|
||||
Get(context.Context, string) ([]byte, error)
|
||||
Set(context.Context, string, []byte) error
|
||||
Delete(context.Context, string) error
|
||||
}
|
||||
|
||||
type Getter interface {
|
||||
Get(ctx context.Context, key string) ([]byte, error)
|
||||
}
|
||||
|
||||
// A GetterFunc implements Getter with a function.
|
||||
type GetterFunc func(ctx context.Context, key string) ([]byte, error)
|
||||
|
||||
func (f GetterFunc) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
return f(ctx, key)
|
||||
}
|
||||
|
||||
type Pool interface {
|
||||
New(name string, getter Getter) Cache
|
||||
Get(name string) (Cache, bool)
|
||||
}
|
||||
|
||||
func DefaultPool() Pool {
|
||||
return &defaultPool{
|
||||
caches: make(map[string]Cache),
|
||||
}
|
||||
}
|
||||
|
||||
type defaultPool struct {
|
||||
mu sync.RWMutex
|
||||
caches map[string]Cache
|
||||
}
|
||||
|
||||
func (p *defaultPool) New(name string, getter Getter) Cache {
|
||||
c := &mapCache{
|
||||
m: new(sync.Map),
|
||||
getter: getter,
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.caches[name] = c
|
||||
p.mu.Unlock()
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *defaultPool) Get(name string) (Cache, bool) {
|
||||
p.mu.RLock()
|
||||
c, ok := p.caches[name]
|
||||
p.mu.RUnlock()
|
||||
return c, ok
|
||||
}
|
||||
|
||||
type mapCache struct {
|
||||
name string
|
||||
m *sync.Map
|
||||
getter Getter
|
||||
}
|
||||
|
||||
func (m *mapCache) Get(ctx context.Context, key string) ([]byte, error) {
|
||||
v, ok := m.m.Load(key)
|
||||
if !ok {
|
||||
b, err := m.getter.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.m.Store(key, b)
|
||||
return b, nil
|
||||
}
|
||||
return v.([]byte), nil
|
||||
}
|
||||
|
||||
func (m *mapCache) Set(ctx context.Context, key string, value []byte) error {
|
||||
m.m.Store(key, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mapCache) Delete(ctx context.Context, key string) error {
|
||||
m.m.Delete(key)
|
||||
return nil
|
||||
}
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/cache"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/cas"
|
||||
|
@ -284,6 +285,14 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithCachePool is an options that allows to define a custom cache pool.
|
||||
func WithCachePool(pool cache.Pool) Option {
|
||||
return func(a *Authority) error {
|
||||
a.cachePool = pool
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
|
||||
var block *pem.Block
|
||||
var certs []*x509.Certificate
|
||||
|
|
|
@ -1,19 +1,25 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/cache"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// DefaultProvisionersLimit is the default limit for listing provisioners.
|
||||
|
@ -33,6 +39,10 @@ func (p provisionerSlice) Len() int { return len(p) }
|
|||
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
|
||||
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
|
||||
func defaultContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), 5*time.Second)
|
||||
}
|
||||
|
||||
// loadByTokenPayload is a payload used to extract the id used to load the
|
||||
// provisioner.
|
||||
type loadByTokenPayload struct {
|
||||
|
@ -50,22 +60,53 @@ type Collection struct {
|
|||
byTokenID *sync.Map
|
||||
sorted provisionerSlice
|
||||
audiences Audiences
|
||||
|
||||
// new
|
||||
byIDCache cache.Cache
|
||||
byNameCache cache.Cache
|
||||
}
|
||||
|
||||
// NewCollection initializes a collection of provisioners. The given list of
|
||||
// audiences are the audiences used by the JWT provisioner.
|
||||
func NewCollection(audiences Audiences) *Collection {
|
||||
func NewCollection(config Config) *Collection {
|
||||
byID := config.CachePool.New("provisioner_by_id", cache.GetterFunc(func(ctx context.Context, id string) ([]byte, error) {
|
||||
p, err := config.AdminDB.GetProvisioner(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return proto.Marshal(p)
|
||||
}))
|
||||
// byName maps a name with a provisioner id, we will manually fill this cache.
|
||||
byName := config.CachePool.New("provisioner_by_name", cache.GetterFunc(func(ctx context.Context, name string) ([]byte, error) {
|
||||
return nil, cache.ErrNotFound
|
||||
}))
|
||||
|
||||
return &Collection{
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
byName: new(sync.Map),
|
||||
byTokenID: new(sync.Map),
|
||||
audiences: audiences,
|
||||
byID: new(sync.Map),
|
||||
byKey: new(sync.Map),
|
||||
byName: new(sync.Map),
|
||||
byTokenID: new(sync.Map),
|
||||
audiences: config.Audiences,
|
||||
byIDCache: byID,
|
||||
byNameCache: byName,
|
||||
}
|
||||
}
|
||||
|
||||
// Load a provisioner by the ID.
|
||||
func (c *Collection) Load(id string) (Interface, bool) {
|
||||
ctx, cancel := defaultContext()
|
||||
defer cancel()
|
||||
|
||||
b, err := c.byIDCache.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var p linkedca.Provisioner
|
||||
if err := proto.Unmarshal(b, &p); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
log.Printf("Provisioner.Load(%s): %v", id, p)
|
||||
return loadProvisioner(c.byID, id)
|
||||
}
|
||||
|
||||
|
@ -180,6 +221,7 @@ func (c *Collection) LoadEncryptedKey(keyID string) (string, bool) {
|
|||
// Store adds a provisioner to the collection and enforces the uniqueness of
|
||||
// provisioner IDs.
|
||||
func (c *Collection) Store(p Interface) error {
|
||||
|
||||
// Store provisioner always in byID. ID must be unique.
|
||||
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
|
||||
return admin.NewError(admin.ErrorBadRequestType,
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/cache"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
@ -214,6 +216,8 @@ type Config struct {
|
|||
Audiences Audiences
|
||||
// DB is the interface to the authority DB client.
|
||||
DB db.AuthDB
|
||||
// AdminDB is the interface to the administration DB client.
|
||||
AdminDB admin.DB
|
||||
// SSHKeys are the root SSH public keys
|
||||
SSHKeys *SSHKeys
|
||||
// GetIdentityFunc is a function that returns an identity that will be
|
||||
|
@ -225,6 +229,8 @@ type Config struct {
|
|||
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
|
||||
// certificate can be renewed.
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
// CachePool is a type that allows to create new caches.
|
||||
CachePool cache.Pool
|
||||
}
|
||||
|
||||
type provisioner struct {
|
||||
|
|
|
@ -21,6 +21,20 @@ import (
|
|||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
// Provisioners is the interface used by the provisioners collection.
|
||||
type Provisioners interface {
|
||||
Load(id string) (provisioner.Interface, bool)
|
||||
Store(p provisioner.Interface) error
|
||||
Update(p provisioner.Interface) error
|
||||
Remove(id string) error
|
||||
LoadByName(name string) (provisioner.Interface, bool)
|
||||
LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (provisioner.Interface, bool)
|
||||
LoadByTokenID(tokenProvisionerID string) (provisioner.Interface, bool)
|
||||
LoadByCertificate(cert *x509.Certificate) (provisioner.Interface, bool)
|
||||
Find(cursor string, limit int) (provisioner.List, string)
|
||||
LoadEncryptedKey(keyID string) (string, bool)
|
||||
}
|
||||
|
||||
// GetEncryptedKey returns the JWE key corresponding to the given kid argument.
|
||||
func (a *Authority) GetEncryptedKey(kid string) (string, error) {
|
||||
a.adminMutex.RLock()
|
||||
|
|
Loading…
Reference in a new issue