This commit is contained in:
max furman 2021-05-18 16:50:54 -07:00
parent 5d09d04d14
commit 4f3e5ef64d
18 changed files with 277 additions and 401 deletions

View file

@ -316,7 +316,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
// Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r)
cursor, limit, err := ParseCursor(r)
if err != nil {
WriteError(w, errs.BadRequestErr(err))
return
@ -427,7 +427,8 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
}
}
func parseCursor(r *http.Request) (cursor string, limit int, err error) {
// ParseCursor parses the cursor and limit from the request query params.
func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
q := r.URL.Query()
cursor = q.Get("cursor")
if v := q.Get("limit"); len(v) > 0 {

View file

@ -1,8 +1,17 @@
package admin
// Type specifies the type of administrator privileges the admin has.
import "github.com/smallstep/certificates/authority/status"
// Type specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR
type Type string
var (
// TypeSuper superadmin
TypeSuper = Type("SUPER_ADMIN")
// TypeRegular regular
TypeRegular = Type("REGULAR")
)
// Admin type.
type Admin struct {
ID string `json:"id"`
@ -12,4 +21,5 @@ type Admin struct {
ProvisionerType string `json:"provisionerType"`
ProvisionerID string `json:"provisionerID"`
Type Type `json:"type"`
Status status.Type `json:"status"`
}

View file

@ -2,56 +2,54 @@ package admin
import (
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"fmt"
"sort"
"strings"
"sync"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/certificates/authority/provisioner"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20
// DefaultAdminLimit is the default limit for listing provisioners.
const DefaultAdminLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100
// DefaultAdminMax is the maximum limit for listing provisioners.
const DefaultAdminMax = 100
/*
type uidProvisioner struct {
provisioner Interface
type uidAdmin struct {
admin *Admin
uid string
}
type provisionerSlice []uidProvisioner
type adminSlice []uidAdmin
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
*/
// loadByTokenPayload is a payload used to extract the id used to load the
// provisioner.
type loadByTokenPayload struct {
jose.Claims
AuthorizedParty string `json:"azp"` // OIDC client id
TenantID string `json:"tid"` // Microsoft Azure tenant id
}
func (p adminSlice) Len() int { return len(p) }
func (p adminSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Collection is a memory map of admins.
type Collection struct {
byID *sync.Map
bySubProv *sync.Map
byProv *sync.Map
sorted adminSlice
provisioners *provisioner.Collection
count int
countByProvisioner map[string]int
}
// NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner.
func NewCollection() *Collection {
func NewCollection(provisioners *provisioner.Collection) *Collection {
return &Collection{
byID: new(sync.Map),
byProv: new(sync.Map),
bySubProv: new(sync.Map),
countByProvisioner: map[string]int{},
provisioners: provisioners,
}
}
@ -88,12 +86,18 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*Admin, bool) {
// Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos.
func (c *Collection) Store(adm *Admin) error {
provName := adm.ProvisionerName
p, ok := c.provisioners.Load(adm.ProvisionerID)
if !ok {
return fmt.Errorf("provisioner %s not found", adm.ProvisionerID)
}
adm.ProvisionerName = p.GetName()
adm.ProvisionerType = p.GetType().String()
// Store admin always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(adm.ID, adm); loaded {
return errors.New("cannot add multiple admins with the same id")
}
provName := adm.ProvisionerName
// Store admin alwasy in bySubProv. Subject <-> ProvisionerName must be unique.
if _, loaded := c.bySubProv.LoadOrStore(subProvNameHash(adm.Subject, provName), adm); loaded {
c.byID.Delete(adm.ID)
@ -109,6 +113,21 @@ func (c *Collection) Store(adm *Admin) error {
}
c.count++
// Store sorted admins.
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
bi := make([]byte, 4)
_sum := sha1.Sum([]byte(adm.ID))
sum := _sum[:]
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
c.sorted = append(c.sorted, uidAdmin{
admin: adm,
uid: hex.EncodeToString(sum),
})
sort.Sort(c.sorted)
return nil
}
@ -125,23 +144,22 @@ func (c *Collection) CountByProvisioner(provName string) int {
return 0
}
/*
// Find implements pagination on a list of sorted provisioners.
func (c *Collection) Find(cursor string, limit int) (List, string) {
func (c *Collection) Find(cursor string, limit int) ([]*Admin, string) {
switch {
case limit <= 0:
limit = DefaultProvisionersLimit
case limit > DefaultProvisionersMax:
limit = DefaultProvisionersMax
limit = DefaultAdminLimit
case limit > DefaultAdminMax:
limit = DefaultAdminMax
}
n := c.sorted.Len()
cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
slice := List{}
slice := []*Admin{}
for ; i < n && len(slice) < limit; i++ {
slice = append(slice, c.sorted[i].provisioner)
slice = append(slice, c.sorted[i].admin)
}
if i < n {
@ -149,7 +167,6 @@ func (c *Collection) Find(cursor string, limit int) (List, string) {
}
return slice, ""
}
*/
func loadAdmin(m *sync.Map, key string) (*Admin, bool) {
a, ok := m.Load(key)

View file

@ -175,7 +175,7 @@ func (a *Authority) ReloadAuthConfig() error {
}
}
// Store all the admins
a.admins = admin.NewCollection()
a.admins = admin.NewCollection(a.provisioners)
for _, adm := range a.config.AuthorityConfig.Admins {
if err := a.admins.Store(adm); err != nil {
return err
@ -431,7 +431,7 @@ func (a *Authority) init() error {
}
}
// Store all the admins
a.admins = admin.NewCollection()
a.admins = admin.NewCollection(a.provisioners)
for _, adm := range a.config.AuthorityConfig.Admins {
if err := a.admins.Store(adm); err != nil {
return err

View file

@ -1,55 +1,23 @@
package mgmt
import (
"context"
"github.com/smallstep/certificates/authority/admin"
)
// AdminType specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR
type AdminType string
type AdminType admin.Type
var (
// AdminTypeSuper superadmin
AdminTypeSuper = AdminType("SUPER_ADMIN")
AdminTypeSuper = admin.TypeSuper
// AdminTypeRegular regular
AdminTypeRegular = AdminType("REGULAR")
AdminTypeRegular = admin.TypeRegular
)
// Admin type.
type Admin struct {
ID string `json:"id"`
AuthorityID string `json:"-"`
ProvisionerID string `json:"provisionerID"`
Subject string `json:"subject"`
ProvisionerName string `json:"provisionerName"`
ProvisionerType string `json:"provisionerType"`
Type AdminType `json:"type"`
Status StatusType `json:"status"`
}
// CreateAdmin builds and stores an admin type in the DB.
func CreateAdmin(ctx context.Context, db DB, provName, sub string, typ AdminType) (*Admin, error) {
adm := &Admin{
Subject: sub,
ProvisionerName: provName,
Type: typ,
Status: StatusActive,
}
if err := db.CreateAdmin(ctx, adm); err != nil {
return nil, WrapErrorISE(err, "error creating admin")
}
return adm, nil
}
type Admin admin.Admin
// ToCertificates converts an Admin to the Admin type expected by the authority.
func (adm *Admin) ToCertificates() (*admin.Admin, error) {
return &admin.Admin{
ID: adm.ID,
Subject: adm.Subject,
ProvisionerID: adm.ProvisionerID,
ProvisionerName: adm.ProvisionerName,
ProvisionerType: adm.ProvisionerType,
Type: admin.Type(adm.Type),
}, nil
return (*admin.Admin)(adm), nil
}

View file

@ -8,27 +8,34 @@ import (
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
)
// CreateAdminRequest represents the body for a CreateAdmin request.
type CreateAdminRequest struct {
Subject string `json:"subject"`
Provisioner string `json:"provisioner"`
Type mgmt.AdminType `json:"type"`
Type admin.Type `json:"type"`
}
// Validate validates a new-admin request body.
func (car *CreateAdminRequest) Validate(c *admin.Collection) error {
if _, ok := c.LoadBySubProv(car.Subject, car.Provisioner); ok {
return mgmt.NewError(mgmt.ErrorBadRequestType,
"admin with subject %s and provisioner name %s already exists", car.Subject, car.Provisioner)
"admin with subject: '%s' and provisioner: '%s' already exists", car.Subject, car.Provisioner)
}
return nil
}
// GetAdminsResponse for returning a list of admins.
type GetAdminsResponse struct {
Admins []*admin.Admin `json:"admins"`
NextCursor string `json:"nextCursor"`
}
// UpdateAdminRequest represents the body for a UpdateAdmin request.
type UpdateAdminRequest struct {
Type mgmt.AdminType `json:"type"`
Type admin.Type `json:"type"`
}
// Validate validates a new-admin request body.
@ -43,27 +50,31 @@ type DeleteResponse struct {
// GetAdmin returns the requested admin, or an error.
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id")
prov, err := h.db.GetAdmin(ctx, id)
if err != nil {
api.WriteError(w, err)
adm, ok := h.auth.GetAdminCollection().LoadByID(id)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType,
"admin %s not found", id))
return
}
api.JSON(w, prov)
api.JSON(w, adm)
}
// GetAdmins returns all admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
admins, err := h.db.GetAdmins(ctx)
cursor, limit, err := api.ParseCursor(r)
if err != nil {
api.WriteError(w, err)
api.WriteError(w, mgmt.WrapError(mgmt.ErrorBadRequestType, err,
"error parsing cursor and limit from query params"))
return
}
api.JSON(w, admins)
admins, nextCursor := h.auth.GetAdminCollection().Find(cursor, limit)
api.JSON(w, &GetAdminsResponse{
Admins: admins,
NextCursor: nextCursor,
})
}
// CreateAdmin creates a new admin.
@ -81,16 +92,24 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return
}
p, ok := h.auth.GetProvisionerCollection().LoadByName(body.Provisioner)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", body.Provisioner))
return
}
adm := &mgmt.Admin{
ProvisionerName: body.Provisioner,
ProvisionerID: p.GetID(),
Subject: body.Subject,
Type: body.Type,
Status: mgmt.StatusActive,
Status: status.Active,
}
if err := h.db.CreateAdmin(ctx, adm); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error creating admin"))
return
}
adm.ProvisionerName = p.GetName()
adm.ProvisionerType = p.GetType().String()
api.JSON(w, adm)
if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err)
@ -112,7 +131,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id))
return
}
adm.Status = mgmt.StatusDeleted
adm.Status = status.Deleted
if err := h.db.UpdateAdmin(ctx, adm); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id))
return
@ -135,17 +154,19 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")
adm, err := h.db.GetAdmin(ctx, id)
if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id))
adm, ok := h.auth.GetAdminCollection().LoadByID(id)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "admin %s not found", id))
return
}
if adm.Type == body.Type {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, "admin %s already has type %s", id, adm.Type))
return
}
// TODO validate
adm.Type = body.Type
if err := h.db.UpdateAdmin(ctx, adm); err != nil {
if err := h.db.UpdateAdmin(ctx, (*mgmt.Admin)(adm)); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id))
return
}

View file

@ -8,6 +8,7 @@ import (
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/status"
)
// CreateProvisionerRequest represents the body for a CreateProvisioner request.
@ -55,7 +56,13 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := chi.URLParam(r, "name")
prov, err := h.db.GetProvisionerByName(ctx, name)
p, ok := h.auth.GetProvisionerCollection().LoadByName(name)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name))
return
}
prov, err := h.db.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
return
@ -123,6 +130,8 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name")
c := h.auth.GetAdminCollection()
fmt.Printf("c.Count() = %+v\n", c.Count())
fmt.Printf("c.CountByProvisioner() = %+v\n", c.CountByProvisioner(name))
if c.Count() == c.CountByProvisioner(name) {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType,
"cannot remove provisioner %s because no admins will remain", name))
@ -130,14 +139,18 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
}
ctx := r.Context()
prov, err := h.db.GetProvisionerByName(ctx, name)
if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving provisioner %s", name))
p, ok := h.auth.GetProvisionerCollection().LoadByName(name)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name))
return
}
fmt.Printf("prov = %+v\n", prov)
prov.Status = mgmt.StatusDeleted
if err := h.db.UpdateProvisioner(ctx, name, prov); err != nil {
prov, err := h.db.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error loading provisioner %s from db", name))
return
}
prov.Status = status.Deleted
if err := h.db.UpdateProvisioner(ctx, prov); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating provisioner %s", name))
return
}
@ -150,8 +163,8 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
ID: adm.ID,
ProvisionerID: adm.ProvisionerID,
Subject: adm.Subject,
Type: mgmt.AdminType(adm.Type),
Status: mgmt.StatusDeleted,
Type: adm.Type,
Status: status.Deleted,
}); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, name))
return

View file

@ -4,6 +4,7 @@ import (
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/status"
)
// AuthConfig represents the Authority Configuration.
@ -15,7 +16,7 @@ type AuthConfig struct {
Admins []*Admin `json:"-"`
Claims *Claims `json:"claims,omitempty"`
Backdate string `json:"backdate,omitempty"`
Status StatusType `json:"status,omitempty"`
Status status.Type `json:"status,omitempty"`
}
func NewDefaultAuthConfig() *AuthConfig {
@ -23,7 +24,7 @@ func NewDefaultAuthConfig() *AuthConfig {
Claims: NewDefaultClaims(),
ASN1DN: &config.ASN1DN{},
Backdate: config.DefaultBackdate.String(),
Status: StatusActive,
Status: status.Active,
}
}

View file

@ -14,16 +14,6 @@ const (
DefaultAuthorityID = "00000000-0000-0000-0000-000000000000"
)
// StatusType is the type for status.
type StatusType string
var (
// StatusActive active
StatusActive = StatusType("active")
// StatusDeleted deleted
StatusDeleted = StatusType("deleted")
)
// Claims encapsulates all x509 and ssh claims applied to the authority
// configuration. E.g. maxTLSCertDuration, defaultSSHCertDuration, etc.
type Claims struct {

View file

@ -14,9 +14,8 @@ var ErrNotFound = errors.New("not found")
type DB interface {
CreateProvisioner(ctx context.Context, prov *Provisioner) error
GetProvisioner(ctx context.Context, id string) (*Provisioner, error)
GetProvisionerByName(ctx context.Context, name string) (*Provisioner, error)
GetProvisioners(ctx context.Context) ([]*Provisioner, error)
UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error
UpdateProvisioner(ctx context.Context, prov *Provisioner) error
CreateAdmin(ctx context.Context, admin *Admin) error
GetAdmin(ctx context.Context, id string) (*Admin, error)
@ -33,9 +32,8 @@ type DB interface {
type MockDB struct {
MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error
MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error)
MockGetProvisionerByName func(ctx context.Context, name string) (*Provisioner, error)
MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error)
MockUpdateProvisioner func(ctx context.Context, name string, prov *Provisioner) error
MockUpdateProvisioner func(ctx context.Context, prov *Provisioner) error
MockCreateAdmin func(ctx context.Context, adm *Admin) error
MockGetAdmin func(ctx context.Context, id string) (*Admin, error)
@ -70,16 +68,6 @@ func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*Provisioner, e
return m.MockRet1.(*Provisioner), m.MockError
}
// GetProvisionerByName mock.
func (m *MockDB) GetProvisionerByName(ctx context.Context, id string) (*Provisioner, error) {
if m.MockGetProvisionerByName != nil {
return m.MockGetProvisionerByName(ctx, id)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*Provisioner), m.MockError
}
// GetProvisioners mock
func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) {
if m.MockGetProvisioners != nil {
@ -91,9 +79,9 @@ func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) {
}
// UpdateProvisioner mock
func (m *MockDB) UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error {
func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *Provisioner) error {
if m.MockUpdateProvisioner != nil {
return m.MockUpdateProvisioner(ctx, name, prov)
return m.MockUpdateProvisioner(ctx, prov)
}
return m.MockError
}

View file

@ -6,7 +6,9 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
"github.com/smallstep/nosql"
)
@ -16,7 +18,7 @@ type dbAdmin struct {
AuthorityID string `json:"authorityID"`
ProvisionerID string `json:"provisionerID"`
Subject string `json:"subject"`
Type mgmt.AdminType `json:"type"`
Type admin.Type `json:"type"`
CreatedAt time.Time `json:"createdAt"`
DeletedAt time.Time `json:"deletedAt"`
}
@ -71,10 +73,10 @@ func unmarshalAdmin(data []byte, id string) (*mgmt.Admin, error) {
ProvisionerID: dba.ProvisionerID,
Subject: dba.Subject,
Type: dba.Type,
Status: mgmt.StatusActive,
Status: status.Active,
}
if !dba.DeletedAt.IsZero() {
adm.Status = mgmt.StatusDeleted
adm.Status = status.Deleted
}
return adm, nil
}
@ -89,19 +91,13 @@ func (db *DB) GetAdmin(ctx context.Context, id string) (*mgmt.Admin, error) {
if err != nil {
return nil, err
}
if adm.Status == mgmt.StatusDeleted {
if adm.Status == status.Deleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID)
}
if adm.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"admin %s is not owned by authority %s", adm.ID, db.authorityID)
}
prov, err := db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return nil, err
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
return adm, nil
}
@ -114,34 +110,18 @@ func (db *DB) GetAdmins(ctx context.Context) ([]*mgmt.Admin, error) {
if err != nil {
return nil, errors.Wrap(err, "error loading admins")
}
var (
provCache = map[string]*mgmt.Provisioner{}
admins []*mgmt.Admin
)
var admins = []*mgmt.Admin{}
for _, entry := range dbEntries {
adm, err := unmarshalAdmin(entry.Value, string(entry.Key))
if err != nil {
return nil, err
}
if adm.Status == mgmt.StatusDeleted {
if adm.Status == status.Deleted {
continue
}
if adm.AuthorityID != db.authorityID {
continue
}
var (
prov *mgmt.Provisioner
ok bool
)
if prov, ok = provCache[adm.ProvisionerID]; !ok {
prov, err = db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return nil, err
}
provCache[adm.ProvisionerID] = prov
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
admins = append(admins, adm)
}
return admins, nil
@ -156,24 +136,6 @@ func (db *DB) CreateAdmin(ctx context.Context, adm *mgmt.Admin) error {
}
adm.AuthorityID = db.authorityID
// If provisionerID is set, then use it, otherwise load the provisioner
// to get the name.
if adm.ProvisionerID == "" {
prov, err := db.GetProvisionerByName(ctx, adm.ProvisionerName)
if err != nil {
return err
}
adm.ProvisionerID = prov.ID
adm.ProvisionerType = prov.Type
} else {
prov, err := db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return err
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
}
dba := &dbAdmin{
ID: adm.ID,
AuthorityID: db.authorityID,
@ -196,7 +158,7 @@ func (db *DB) UpdateAdmin(ctx context.Context, adm *mgmt.Admin) error {
nu := old.clone()
// If the admin was active but is now deleted ...
if old.DeletedAt.IsZero() && adm.Status == mgmt.StatusDeleted {
if old.DeletedAt.IsZero() && adm.Status == status.Deleted {
nu.DeletedAt = clock.Now()
}
nu.Type = adm.Type

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
"github.com/smallstep/nosql"
)
@ -106,7 +107,7 @@ func (db *DB) UpdateAuthConfig(ctx context.Context, ac *mgmt.AuthConfig) error {
nu := old.clone()
// If the authority was active but is now deleted ...
if old.DeletedAt.IsZero() && ac.Status == mgmt.StatusDeleted {
if old.DeletedAt.IsZero() && ac.Status == status.Deleted {
nu.DeletedAt = clock.Now()
}
nu.Claims = ac.Claims

View file

@ -3,7 +3,6 @@ package nosql
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
@ -60,7 +59,6 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
}
}
fmt.Printf("oldB = %+v\n", oldB)
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
switch {

View file

@ -3,13 +3,12 @@ package nosql
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
// dbProvisioner is the database representation of a Provisioner type.
@ -39,20 +38,6 @@ func (dbp *dbProvisioner) clone() *dbProvisioner {
return &u
}
func (db *DB) getProvisionerIDByName(ctx context.Context, name string) (string, error) {
data, err := db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if nosql.IsErrNotFound(err) {
return "", mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name)
} else if err != nil {
return "", mgmt.WrapErrorISE(err, "error loading provisioner %s", name)
}
ni := new(provisionerNameID)
if err := json.Unmarshal(data, ni); err != nil {
return "", mgmt.WrapErrorISE(err, "error unmarshaling provisionerNameID for provisioner %s", name)
}
return ni.ID, nil
}
func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
data, err := db.db.Get(authorityProvisionersTable, []byte(id))
if nosql.IsErrNotFound(err) {
@ -82,25 +67,6 @@ func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner,
return dbp, nil
}
func (db *DB) getDBProvisionerByName(ctx context.Context, name string) (*dbProvisioner, error) {
id, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
dbp, err := db.getDBProvisioner(ctx, id)
if err != nil {
return nil, err
}
if !dbp.DeletedAt.IsZero() {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", name)
}
if dbp.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", name, db.authorityID)
}
return dbp, nil
}
// GetProvisioner retrieves and unmarshals a provisioner from the database.
func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) {
data, err := db.getDBProvisionerBytes(ctx, id)
@ -112,7 +78,7 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner,
if err != nil {
return nil, err
}
if prov.Status == mgmt.StatusDeleted {
if prov.Status == status.Deleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", prov.ID)
}
if prov.AuthorityID != db.authorityID {
@ -122,15 +88,6 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner,
return prov, nil
}
// GetProvisionerByName retrieves a provisioner from the database by name.
func (db *DB) GetProvisionerByName(ctx context.Context, name string) (*mgmt.Provisioner, error) {
p, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
return db.GetProvisioner(ctx, id)
}
func unmarshalDBProvisioner(data []byte, name string) (*dbProvisioner, error) {
var dbp = new(dbProvisioner)
if err := json.Unmarshal(data, dbp); err != nil {
@ -157,14 +114,14 @@ func unmarshalProvisioner(data []byte, name string) (*mgmt.Provisioner, error) {
Name: dbp.Name,
Claims: dbp.Claims,
Details: details,
Status: mgmt.StatusActive,
Status: status.Active,
X509Template: dbp.X509Template,
X509TemplateData: dbp.X509TemplateData,
SSHTemplate: dbp.SSHTemplate,
SSHTemplateData: dbp.SSHTemplateData,
}
if !dbp.DeletedAt.IsZero() {
prov.Status = mgmt.StatusDeleted
prov.Status = status.Deleted
}
return prov, nil
}
@ -182,7 +139,7 @@ func (db *DB) GetProvisioners(ctx context.Context) ([]*mgmt.Provisioner, error)
if err != nil {
return nil, err
}
if prov.Status == mgmt.StatusDeleted {
if prov.Status == status.Deleted {
continue
}
if prov.AuthorityID != db.authorityID {
@ -219,37 +176,8 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err
SSHTemplateData: prov.SSHTemplateData,
CreatedAt: clock.Now(),
}
dbpBytes, err := json.Marshal(dbp)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameIndex %s", prov.Name)
}
if err := db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersTable,
Key: []byte(dbp.ID),
Cmd: database.CmpAndSwap,
Value: dbpBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(dbp.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
},
}); err != nil {
if err := db.save(ctx, prov.ID, dbp, nil, "provisioner", authorityProvisionersTable); err != nil {
return mgmt.WrapErrorISE(err, "error creating provisioner %s", prov.Name)
}
@ -257,22 +185,11 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err
}
// UpdateProvisioner saves an updated provisioner to the database.
func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Provisioner) error {
id, err := db.getProvisionerIDByName(ctx, name)
func (db *DB) UpdateProvisioner(ctx context.Context, prov *mgmt.Provisioner) error {
old, err := db.getDBProvisioner(ctx, prov.ID)
if err != nil {
return err
}
prov.ID = id
oldBytes, err := db.getDBProvisionerBytes(ctx, id)
if err != nil {
return err
}
fmt.Printf("oldBytes = %+v\n", oldBytes)
old, err := unmarshalDBProvisioner(oldBytes, id)
if err != nil {
return err
}
fmt.Printf("old = %+v\n", old)
nu := old.clone()
@ -281,91 +198,15 @@ func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Pro
nu.Claims = prov.Claims
nu.Details, err = json.Marshal(prov.Details)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", name)
return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name)
}
nu.X509Template = prov.X509Template
nu.X509TemplateData = prov.X509TemplateData
nu.SSHTemplateData = prov.SSHTemplateData
var txs = []*database.TxEntry{}
// If the provisioner was active but is now deleted ...
if old.DeletedAt.IsZero() && prov.Status == mgmt.StatusDeleted {
nu.DeletedAt = clock.Now()
txs = append(txs, &database.TxEntry{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
})
}
if prov.Name != name {
// If the new name does not match the old name then:
// 1) check that the new name is not already taken
// 2) delete the old name-id index resource
// 3) create a new name-id index resource
// 4) update the provisioner resource
nuBytes, err := json.Marshal(nu)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameID for provisioner %s", prov.Name)
}
_, err = db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if err == nil {
return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", prov.Name)
} else if !nosql.IsErrNotFound(err) {
return mgmt.WrapErrorISE(err, "error loading provisionerNameID %s", prov.Name)
}
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(prov.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
} else {
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
}
if err != nil {
if err := db.save(ctx, prov.ID, nu, old, "provisioner", authorityProvisionersTable); err != nil {
return mgmt.WrapErrorISE(err, "error updating provisioner %s", prov.Name)
}
return nil
}

View file

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/status"
"go.step.sm/crypto/jose"
)
@ -69,7 +70,7 @@ type Provisioner struct {
X509TemplateData []byte `json:"x509TemplateData"`
SSHTemplate string `json:"sshTemplate"`
SSHTemplateData []byte `json:"sshTemplateData"`
Status StatusType `json:"status"`
Status status.Type `json:"status"`
}
func (p *Provisioner) GetOptions() *provisioner.Options {
@ -101,7 +102,7 @@ func CreateProvisioner(ctx context.Context, db DB, typ, name string, opts ...Pro
X509TemplateData: pc.X509TemplateData,
SSHTemplate: pc.SSHTemplate,
SSHTemplateData: pc.SSHTemplateData,
Status: StatusActive,
Status: status.Active,
}
if err := db.CreateProvisioner(ctx, p); err != nil {

View file

@ -183,7 +183,7 @@ func (c *Collection) Store(p Interface) error {
// Store provisioner always by name.
if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded {
c.byID.Delete(p.GetID())
return errors.New("cannot add multiple provisioners with the same id")
return errors.New("cannot add multiple provisioners with the same name")
}
// Store provisioner in byKey if EncryptedKey is defined.

View file

@ -0,0 +1,11 @@
package status
// Type is the type for status.
type Type string
var (
// Active active
Active = Type("active")
// Deleted deleted
Deleted = Type("deleted")
)

View file

@ -7,8 +7,10 @@ import (
"net/http"
"net/url"
"path"
"strconv"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt"
mgmtAPI "github.com/smallstep/certificates/authority/mgmt/api"
"github.com/smallstep/certificates/errs"
@ -88,6 +90,80 @@ retry:
return adm, nil
}
// AdminOption is the type of options passed to the Provisioner method.
type AdminOption func(o *adminOptions) error
type adminOptions struct {
cursor string
limit int
}
func (o *adminOptions) apply(opts []AdminOption) (err error) {
for _, fn := range opts {
if err = fn(o); err != nil {
return
}
}
return
}
func (o *adminOptions) rawQuery() string {
v := url.Values{}
if len(o.cursor) > 0 {
v.Set("cursor", o.cursor)
}
if o.limit > 0 {
v.Set("limit", strconv.Itoa(o.limit))
}
return v.Encode()
}
// WithAdminCursor will request the admins starting with the given cursor.
func WithAdminCursor(cursor string) AdminOption {
return func(o *adminOptions) error {
o.cursor = cursor
return nil
}
}
// WithAdminLimit will request the given number of admins.
func WithAdminLimit(limit int) AdminOption {
return func(o *adminOptions) error {
o.limit = limit
return nil
}
}
// GetAdmins performs the GET /mgmt/admins request to the CA.
func (c *MgmtClient) GetAdmins(opts ...AdminOption) (*mgmtAPI.GetAdminsResponse, error) {
var retried bool
o := new(adminOptions)
if err := o.apply(opts); err != nil {
return nil, err
}
u := c.endpoint.ResolveReference(&url.URL{
Path: "/mgmt/admins",
RawQuery: o.rawQuery(),
})
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var body = new(mgmtAPI.GetAdminsResponse)
if err := readJSON(resp.Body, body); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return body, nil
}
// CreateAdmin performs the POST /mgmt/admin request to the CA.
func (c *MgmtClient) CreateAdmin(req *mgmtAPI.CreateAdminRequest) (*mgmt.Admin, error) {
var retried bool
@ -139,7 +215,7 @@ retry:
}
// UpdateAdmin performs the PUT /mgmt/admin/{id} request to the CA.
func (c *MgmtClient) UpdateAdmin(id string, uar *mgmtAPI.UpdateAdminRequest) (*mgmt.Admin, error) {
func (c *MgmtClient) UpdateAdmin(id string, uar *mgmtAPI.UpdateAdminRequest) (*admin.Admin, error) {
var retried bool
body, err := json.Marshal(uar)
if err != nil {
@ -162,36 +238,13 @@ retry:
}
return nil, readMgmtError(resp.Body)
}
var adm = new(mgmt.Admin)
var adm = new(admin.Admin)
if err := readJSON(resp.Body, adm); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return adm, nil
}
// GetAdmins performs the GET /mgmt/admins request to the CA.
func (c *MgmtClient) GetAdmins() ([]*mgmt.Admin, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/mgmt/admins"})
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var admins = new([]*mgmt.Admin)
if err := readJSON(resp.Body, admins); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return *admins, nil
}
// GetProvisioner performs the GET /mgmt/provisioner/{id} request to the CA.
func (c *MgmtClient) GetProvisioner(id string) (*mgmt.Provisioner, error) {
var retried bool