This commit is contained in:
max furman 2021-05-18 16:50:54 -07:00
parent 5d09d04d14
commit 4f3e5ef64d
18 changed files with 277 additions and 401 deletions

View file

@ -316,7 +316,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
// Provisioners returns the list of provisioners configured in the authority. // Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := parseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, errs.BadRequestErr(err))
return return
@ -427,7 +427,8 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
} }
} }
func parseCursor(r *http.Request) (cursor string, limit int, err error) { // ParseCursor parses the cursor and limit from the request query params.
func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
q := r.URL.Query() q := r.URL.Query()
cursor = q.Get("cursor") cursor = q.Get("cursor")
if v := q.Get("limit"); len(v) > 0 { if v := q.Get("limit"); len(v) > 0 {

View file

@ -1,15 +1,25 @@
package admin package admin
// Type specifies the type of administrator privileges the admin has. import "github.com/smallstep/certificates/authority/status"
// Type specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR
type Type string type Type string
var (
// TypeSuper superadmin
TypeSuper = Type("SUPER_ADMIN")
// TypeRegular regular
TypeRegular = Type("REGULAR")
)
// Admin type. // Admin type.
type Admin struct { type Admin struct {
ID string `json:"id"` ID string `json:"id"`
AuthorityID string `json:"-"` AuthorityID string `json:"-"`
Subject string `json:"subject"` Subject string `json:"subject"`
ProvisionerName string `json:"provisionerName"` ProvisionerName string `json:"provisionerName"`
ProvisionerType string `json:"provisionerType"` ProvisionerType string `json:"provisionerType"`
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Type Type `json:"type"` Type Type `json:"type"`
Status status.Type `json:"status"`
} }

View file

@ -2,56 +2,54 @@ package admin
import ( import (
"crypto/sha1" "crypto/sha1"
"encoding/binary"
"encoding/hex"
"fmt"
"sort"
"strings"
"sync" "sync"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose" "github.com/smallstep/certificates/authority/provisioner"
) )
// DefaultProvisionersLimit is the default limit for listing provisioners. // DefaultAdminLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20 const DefaultAdminLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners. // DefaultAdminMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100 const DefaultAdminMax = 100
/* type uidAdmin struct {
type uidProvisioner struct { admin *Admin
provisioner Interface uid string
uid string
} }
type provisionerSlice []uidProvisioner type adminSlice []uidAdmin
func (p provisionerSlice) Len() int { return len(p) } func (p adminSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } func (p adminSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
*/
// loadByTokenPayload is a payload used to extract the id used to load the
// provisioner.
type loadByTokenPayload struct {
jose.Claims
AuthorizedParty string `json:"azp"` // OIDC client id
TenantID string `json:"tid"` // Microsoft Azure tenant id
}
// Collection is a memory map of admins. // Collection is a memory map of admins.
type Collection struct { type Collection struct {
byID *sync.Map byID *sync.Map
bySubProv *sync.Map bySubProv *sync.Map
byProv *sync.Map byProv *sync.Map
sorted adminSlice
provisioners *provisioner.Collection
count int count int
countByProvisioner map[string]int countByProvisioner map[string]int
} }
// NewCollection initializes a collection of provisioners. The given list of // NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner. // audiences are the audiences used by the JWT provisioner.
func NewCollection() *Collection { func NewCollection(provisioners *provisioner.Collection) *Collection {
return &Collection{ return &Collection{
byID: new(sync.Map), byID: new(sync.Map),
byProv: new(sync.Map), byProv: new(sync.Map),
bySubProv: new(sync.Map), bySubProv: new(sync.Map),
countByProvisioner: map[string]int{}, countByProvisioner: map[string]int{},
provisioners: provisioners,
} }
} }
@ -88,12 +86,18 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*Admin, bool) {
// Store adds an admin to the collection and enforces the uniqueness of // Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos. // admin IDs and amdin subject <-> provisioner name combos.
func (c *Collection) Store(adm *Admin) error { func (c *Collection) Store(adm *Admin) error {
provName := adm.ProvisionerName p, ok := c.provisioners.Load(adm.ProvisionerID)
if !ok {
return fmt.Errorf("provisioner %s not found", adm.ProvisionerID)
}
adm.ProvisionerName = p.GetName()
adm.ProvisionerType = p.GetType().String()
// Store admin always in byID. ID must be unique. // Store admin always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(adm.ID, adm); loaded { if _, loaded := c.byID.LoadOrStore(adm.ID, adm); loaded {
return errors.New("cannot add multiple admins with the same id") return errors.New("cannot add multiple admins with the same id")
} }
provName := adm.ProvisionerName
// Store admin alwasy in bySubProv. Subject <-> ProvisionerName must be unique. // Store admin alwasy in bySubProv. Subject <-> ProvisionerName must be unique.
if _, loaded := c.bySubProv.LoadOrStore(subProvNameHash(adm.Subject, provName), adm); loaded { if _, loaded := c.bySubProv.LoadOrStore(subProvNameHash(adm.Subject, provName), adm); loaded {
c.byID.Delete(adm.ID) c.byID.Delete(adm.ID)
@ -109,6 +113,21 @@ func (c *Collection) Store(adm *Admin) error {
} }
c.count++ c.count++
// Store sorted admins.
// Use the first 4 bytes (32bit) of the sum to insert the order
// Using big endian format to get the strings sorted:
// 0x00000000, 0x00000001, 0x00000002, ...
bi := make([]byte, 4)
_sum := sha1.Sum([]byte(adm.ID))
sum := _sum[:]
binary.BigEndian.PutUint32(bi, uint32(c.sorted.Len()))
sum[0], sum[1], sum[2], sum[3] = bi[0], bi[1], bi[2], bi[3]
c.sorted = append(c.sorted, uidAdmin{
admin: adm,
uid: hex.EncodeToString(sum),
})
sort.Sort(c.sorted)
return nil return nil
} }
@ -125,23 +144,22 @@ func (c *Collection) CountByProvisioner(provName string) int {
return 0 return 0
} }
/*
// Find implements pagination on a list of sorted provisioners. // Find implements pagination on a list of sorted provisioners.
func (c *Collection) Find(cursor string, limit int) (List, string) { func (c *Collection) Find(cursor string, limit int) ([]*Admin, string) {
switch { switch {
case limit <= 0: case limit <= 0:
limit = DefaultProvisionersLimit limit = DefaultAdminLimit
case limit > DefaultProvisionersMax: case limit > DefaultAdminMax:
limit = DefaultProvisionersMax limit = DefaultAdminMax
} }
n := c.sorted.Len() n := c.sorted.Len()
cursor = fmt.Sprintf("%040s", cursor) cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor }) i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
slice := List{} slice := []*Admin{}
for ; i < n && len(slice) < limit; i++ { for ; i < n && len(slice) < limit; i++ {
slice = append(slice, c.sorted[i].provisioner) slice = append(slice, c.sorted[i].admin)
} }
if i < n { if i < n {
@ -149,7 +167,6 @@ func (c *Collection) Find(cursor string, limit int) (List, string) {
} }
return slice, "" return slice, ""
} }
*/
func loadAdmin(m *sync.Map, key string) (*Admin, bool) { func loadAdmin(m *sync.Map, key string) (*Admin, bool) {
a, ok := m.Load(key) a, ok := m.Load(key)

View file

@ -175,7 +175,7 @@ func (a *Authority) ReloadAuthConfig() error {
} }
} }
// Store all the admins // Store all the admins
a.admins = admin.NewCollection() a.admins = admin.NewCollection(a.provisioners)
for _, adm := range a.config.AuthorityConfig.Admins { for _, adm := range a.config.AuthorityConfig.Admins {
if err := a.admins.Store(adm); err != nil { if err := a.admins.Store(adm); err != nil {
return err return err
@ -431,7 +431,7 @@ func (a *Authority) init() error {
} }
} }
// Store all the admins // Store all the admins
a.admins = admin.NewCollection() a.admins = admin.NewCollection(a.provisioners)
for _, adm := range a.config.AuthorityConfig.Admins { for _, adm := range a.config.AuthorityConfig.Admins {
if err := a.admins.Store(adm); err != nil { if err := a.admins.Store(adm); err != nil {
return err return err

View file

@ -1,55 +1,23 @@
package mgmt package mgmt
import ( import (
"context"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
) )
// AdminType specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR // AdminType specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR
type AdminType string type AdminType admin.Type
var ( var (
// AdminTypeSuper superadmin // AdminTypeSuper superadmin
AdminTypeSuper = AdminType("SUPER_ADMIN") AdminTypeSuper = admin.TypeSuper
// AdminTypeRegular regular // AdminTypeRegular regular
AdminTypeRegular = AdminType("REGULAR") AdminTypeRegular = admin.TypeRegular
) )
// Admin type. // Admin type.
type Admin struct { type Admin admin.Admin
ID string `json:"id"`
AuthorityID string `json:"-"`
ProvisionerID string `json:"provisionerID"`
Subject string `json:"subject"`
ProvisionerName string `json:"provisionerName"`
ProvisionerType string `json:"provisionerType"`
Type AdminType `json:"type"`
Status StatusType `json:"status"`
}
// CreateAdmin builds and stores an admin type in the DB.
func CreateAdmin(ctx context.Context, db DB, provName, sub string, typ AdminType) (*Admin, error) {
adm := &Admin{
Subject: sub,
ProvisionerName: provName,
Type: typ,
Status: StatusActive,
}
if err := db.CreateAdmin(ctx, adm); err != nil {
return nil, WrapErrorISE(err, "error creating admin")
}
return adm, nil
}
// ToCertificates converts an Admin to the Admin type expected by the authority. // ToCertificates converts an Admin to the Admin type expected by the authority.
func (adm *Admin) ToCertificates() (*admin.Admin, error) { func (adm *Admin) ToCertificates() (*admin.Admin, error) {
return &admin.Admin{ return (*admin.Admin)(adm), nil
ID: adm.ID,
Subject: adm.Subject,
ProvisionerID: adm.ProvisionerID,
ProvisionerName: adm.ProvisionerName,
ProvisionerType: adm.ProvisionerType,
Type: admin.Type(adm.Type),
}, nil
} }

View file

@ -8,27 +8,34 @@ import (
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
) )
// CreateAdminRequest represents the body for a CreateAdmin request. // CreateAdminRequest represents the body for a CreateAdmin request.
type CreateAdminRequest struct { type CreateAdminRequest struct {
Subject string `json:"subject"` Subject string `json:"subject"`
Provisioner string `json:"provisioner"` Provisioner string `json:"provisioner"`
Type mgmt.AdminType `json:"type"` Type admin.Type `json:"type"`
} }
// Validate validates a new-admin request body. // Validate validates a new-admin request body.
func (car *CreateAdminRequest) Validate(c *admin.Collection) error { func (car *CreateAdminRequest) Validate(c *admin.Collection) error {
if _, ok := c.LoadBySubProv(car.Subject, car.Provisioner); ok { if _, ok := c.LoadBySubProv(car.Subject, car.Provisioner); ok {
return mgmt.NewError(mgmt.ErrorBadRequestType, return mgmt.NewError(mgmt.ErrorBadRequestType,
"admin with subject %s and provisioner name %s already exists", car.Subject, car.Provisioner) "admin with subject: '%s' and provisioner: '%s' already exists", car.Subject, car.Provisioner)
} }
return nil return nil
} }
// GetAdminsResponse for returning a list of admins.
type GetAdminsResponse struct {
Admins []*admin.Admin `json:"admins"`
NextCursor string `json:"nextCursor"`
}
// UpdateAdminRequest represents the body for a UpdateAdmin request. // UpdateAdminRequest represents the body for a UpdateAdmin request.
type UpdateAdminRequest struct { type UpdateAdminRequest struct {
Type mgmt.AdminType `json:"type"` Type admin.Type `json:"type"`
} }
// Validate validates a new-admin request body. // Validate validates a new-admin request body.
@ -43,27 +50,31 @@ type DeleteResponse struct {
// GetAdmin returns the requested admin, or an error. // GetAdmin returns the requested admin, or an error.
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
prov, err := h.db.GetAdmin(ctx, id) adm, ok := h.auth.GetAdminCollection().LoadByID(id)
if err != nil { if !ok {
api.WriteError(w, err) api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType,
"admin %s not found", id))
return return
} }
api.JSON(w, prov) api.JSON(w, adm)
} }
// GetAdmins returns all admins associated with the authority. // GetAdmins returns all admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() cursor, limit, err := api.ParseCursor(r)
admins, err := h.db.GetAdmins(ctx)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, mgmt.WrapError(mgmt.ErrorBadRequestType, err,
"error parsing cursor and limit from query params"))
return return
} }
api.JSON(w, admins)
admins, nextCursor := h.auth.GetAdminCollection().Find(cursor, limit)
api.JSON(w, &GetAdminsResponse{
Admins: admins,
NextCursor: nextCursor,
})
} }
// CreateAdmin creates a new admin. // CreateAdmin creates a new admin.
@ -81,16 +92,24 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return return
} }
p, ok := h.auth.GetProvisionerCollection().LoadByName(body.Provisioner)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", body.Provisioner))
return
}
adm := &mgmt.Admin{ adm := &mgmt.Admin{
ProvisionerName: body.Provisioner, ProvisionerID: p.GetID(),
Subject: body.Subject, Subject: body.Subject,
Type: body.Type, Type: body.Type,
Status: mgmt.StatusActive, Status: status.Active,
} }
if err := h.db.CreateAdmin(ctx, adm); err != nil { if err := h.db.CreateAdmin(ctx, adm); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error creating admin")) api.WriteError(w, mgmt.WrapErrorISE(err, "error creating admin"))
return return
} }
adm.ProvisionerName = p.GetName()
adm.ProvisionerType = p.GetType().String()
api.JSON(w, adm) api.JSON(w, adm)
if err := h.auth.ReloadAuthConfig(); err != nil { if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err) fmt.Printf("err = %+v\n", err)
@ -112,7 +131,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id)) api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id))
return return
} }
adm.Status = mgmt.StatusDeleted adm.Status = status.Deleted
if err := h.db.UpdateAdmin(ctx, adm); err != nil { if err := h.db.UpdateAdmin(ctx, adm); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id)) api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id))
return return
@ -135,17 +154,19 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
adm, err := h.db.GetAdmin(ctx, id) adm, ok := h.auth.GetAdminCollection().LoadByID(id)
if err != nil { if !ok {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id)) api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "admin %s not found", id))
return
}
if adm.Type == body.Type {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, "admin %s already has type %s", id, adm.Type))
return return
} }
// TODO validate
adm.Type = body.Type adm.Type = body.Type
if err := h.db.UpdateAdmin(ctx, adm); err != nil { if err := h.db.UpdateAdmin(ctx, (*mgmt.Admin)(adm)); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id)) api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id))
return return
} }

View file

@ -8,6 +8,7 @@ import (
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/status"
) )
// CreateProvisionerRequest represents the body for a CreateProvisioner request. // CreateProvisionerRequest represents the body for a CreateProvisioner request.
@ -55,7 +56,13 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
prov, err := h.db.GetProvisionerByName(ctx, name) p, ok := h.auth.GetProvisionerCollection().LoadByName(name)
if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name))
return
}
prov, err := h.db.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -123,6 +130,8 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
c := h.auth.GetAdminCollection() c := h.auth.GetAdminCollection()
fmt.Printf("c.Count() = %+v\n", c.Count())
fmt.Printf("c.CountByProvisioner() = %+v\n", c.CountByProvisioner(name))
if c.Count() == c.CountByProvisioner(name) { if c.Count() == c.CountByProvisioner(name) {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType,
"cannot remove provisioner %s because no admins will remain", name)) "cannot remove provisioner %s because no admins will remain", name))
@ -130,14 +139,18 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
prov, err := h.db.GetProvisionerByName(ctx, name) p, ok := h.auth.GetProvisionerCollection().LoadByName(name)
if err != nil { if !ok {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving provisioner %s", name)) api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name))
return return
} }
fmt.Printf("prov = %+v\n", prov) prov, err := h.db.GetProvisioner(ctx, p.GetID())
prov.Status = mgmt.StatusDeleted if err != nil {
if err := h.db.UpdateProvisioner(ctx, name, prov); err != nil { api.WriteError(w, mgmt.WrapErrorISE(err, "error loading provisioner %s from db", name))
return
}
prov.Status = status.Deleted
if err := h.db.UpdateProvisioner(ctx, prov); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating provisioner %s", name)) api.WriteError(w, mgmt.WrapErrorISE(err, "error updating provisioner %s", name))
return return
} }
@ -150,8 +163,8 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
ID: adm.ID, ID: adm.ID,
ProvisionerID: adm.ProvisionerID, ProvisionerID: adm.ProvisionerID,
Subject: adm.Subject, Subject: adm.Subject,
Type: mgmt.AdminType(adm.Type), Type: adm.Type,
Status: mgmt.StatusDeleted, Status: status.Deleted,
}); err != nil { }); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, name)) api.WriteError(w, mgmt.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, name))
return return

View file

@ -4,6 +4,7 @@ import (
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/status"
) )
// AuthConfig represents the Authority Configuration. // AuthConfig represents the Authority Configuration.
@ -15,7 +16,7 @@ type AuthConfig struct {
Admins []*Admin `json:"-"` Admins []*Admin `json:"-"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Backdate string `json:"backdate,omitempty"` Backdate string `json:"backdate,omitempty"`
Status StatusType `json:"status,omitempty"` Status status.Type `json:"status,omitempty"`
} }
func NewDefaultAuthConfig() *AuthConfig { func NewDefaultAuthConfig() *AuthConfig {
@ -23,7 +24,7 @@ func NewDefaultAuthConfig() *AuthConfig {
Claims: NewDefaultClaims(), Claims: NewDefaultClaims(),
ASN1DN: &config.ASN1DN{}, ASN1DN: &config.ASN1DN{},
Backdate: config.DefaultBackdate.String(), Backdate: config.DefaultBackdate.String(),
Status: StatusActive, Status: status.Active,
} }
} }

View file

@ -14,16 +14,6 @@ const (
DefaultAuthorityID = "00000000-0000-0000-0000-000000000000" DefaultAuthorityID = "00000000-0000-0000-0000-000000000000"
) )
// StatusType is the type for status.
type StatusType string
var (
// StatusActive active
StatusActive = StatusType("active")
// StatusDeleted deleted
StatusDeleted = StatusType("deleted")
)
// Claims encapsulates all x509 and ssh claims applied to the authority // Claims encapsulates all x509 and ssh claims applied to the authority
// configuration. E.g. maxTLSCertDuration, defaultSSHCertDuration, etc. // configuration. E.g. maxTLSCertDuration, defaultSSHCertDuration, etc.
type Claims struct { type Claims struct {

View file

@ -14,9 +14,8 @@ var ErrNotFound = errors.New("not found")
type DB interface { type DB interface {
CreateProvisioner(ctx context.Context, prov *Provisioner) error CreateProvisioner(ctx context.Context, prov *Provisioner) error
GetProvisioner(ctx context.Context, id string) (*Provisioner, error) GetProvisioner(ctx context.Context, id string) (*Provisioner, error)
GetProvisionerByName(ctx context.Context, name string) (*Provisioner, error)
GetProvisioners(ctx context.Context) ([]*Provisioner, error) GetProvisioners(ctx context.Context) ([]*Provisioner, error)
UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error UpdateProvisioner(ctx context.Context, prov *Provisioner) error
CreateAdmin(ctx context.Context, admin *Admin) error CreateAdmin(ctx context.Context, admin *Admin) error
GetAdmin(ctx context.Context, id string) (*Admin, error) GetAdmin(ctx context.Context, id string) (*Admin, error)
@ -31,11 +30,10 @@ type DB interface {
// MockDB is an implementation of the DB interface that should only be used as // MockDB is an implementation of the DB interface that should only be used as
// a mock in tests. // a mock in tests.
type MockDB struct { type MockDB struct {
MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error
MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error) MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error)
MockGetProvisionerByName func(ctx context.Context, name string) (*Provisioner, error) MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error)
MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error) MockUpdateProvisioner func(ctx context.Context, prov *Provisioner) error
MockUpdateProvisioner func(ctx context.Context, name string, prov *Provisioner) error
MockCreateAdmin func(ctx context.Context, adm *Admin) error MockCreateAdmin func(ctx context.Context, adm *Admin) error
MockGetAdmin func(ctx context.Context, id string) (*Admin, error) MockGetAdmin func(ctx context.Context, id string) (*Admin, error)
@ -70,16 +68,6 @@ func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*Provisioner, e
return m.MockRet1.(*Provisioner), m.MockError return m.MockRet1.(*Provisioner), m.MockError
} }
// GetProvisionerByName mock.
func (m *MockDB) GetProvisionerByName(ctx context.Context, id string) (*Provisioner, error) {
if m.MockGetProvisionerByName != nil {
return m.MockGetProvisionerByName(ctx, id)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*Provisioner), m.MockError
}
// GetProvisioners mock // GetProvisioners mock
func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) { func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) {
if m.MockGetProvisioners != nil { if m.MockGetProvisioners != nil {
@ -91,9 +79,9 @@ func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) {
} }
// UpdateProvisioner mock // UpdateProvisioner mock
func (m *MockDB) UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error { func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *Provisioner) error {
if m.MockUpdateProvisioner != nil { if m.MockUpdateProvisioner != nil {
return m.MockUpdateProvisioner(ctx, name, prov) return m.MockUpdateProvisioner(ctx, prov)
} }
return m.MockError return m.MockError
} }

View file

@ -6,19 +6,21 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
// dbAdmin is the database representation of the Admin type. // dbAdmin is the database representation of the Admin type.
type dbAdmin struct { type dbAdmin struct {
ID string `json:"id"` ID string `json:"id"`
AuthorityID string `json:"authorityID"` AuthorityID string `json:"authorityID"`
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Subject string `json:"subject"` Subject string `json:"subject"`
Type mgmt.AdminType `json:"type"` Type admin.Type `json:"type"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
DeletedAt time.Time `json:"deletedAt"` DeletedAt time.Time `json:"deletedAt"`
} }
func (dbp *dbAdmin) clone() *dbAdmin { func (dbp *dbAdmin) clone() *dbAdmin {
@ -71,10 +73,10 @@ func unmarshalAdmin(data []byte, id string) (*mgmt.Admin, error) {
ProvisionerID: dba.ProvisionerID, ProvisionerID: dba.ProvisionerID,
Subject: dba.Subject, Subject: dba.Subject,
Type: dba.Type, Type: dba.Type,
Status: mgmt.StatusActive, Status: status.Active,
} }
if !dba.DeletedAt.IsZero() { if !dba.DeletedAt.IsZero() {
adm.Status = mgmt.StatusDeleted adm.Status = status.Deleted
} }
return adm, nil return adm, nil
} }
@ -89,19 +91,13 @@ func (db *DB) GetAdmin(ctx context.Context, id string) (*mgmt.Admin, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if adm.Status == mgmt.StatusDeleted { if adm.Status == status.Deleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID) return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID)
} }
if adm.AuthorityID != db.authorityID { if adm.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType, return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"admin %s is not owned by authority %s", adm.ID, db.authorityID) "admin %s is not owned by authority %s", adm.ID, db.authorityID)
} }
prov, err := db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return nil, err
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
return adm, nil return adm, nil
} }
@ -114,34 +110,18 @@ func (db *DB) GetAdmins(ctx context.Context) ([]*mgmt.Admin, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error loading admins") return nil, errors.Wrap(err, "error loading admins")
} }
var ( var admins = []*mgmt.Admin{}
provCache = map[string]*mgmt.Provisioner{}
admins []*mgmt.Admin
)
for _, entry := range dbEntries { for _, entry := range dbEntries {
adm, err := unmarshalAdmin(entry.Value, string(entry.Key)) adm, err := unmarshalAdmin(entry.Value, string(entry.Key))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if adm.Status == mgmt.StatusDeleted { if adm.Status == status.Deleted {
continue continue
} }
if adm.AuthorityID != db.authorityID { if adm.AuthorityID != db.authorityID {
continue continue
} }
var (
prov *mgmt.Provisioner
ok bool
)
if prov, ok = provCache[adm.ProvisionerID]; !ok {
prov, err = db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return nil, err
}
provCache[adm.ProvisionerID] = prov
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
admins = append(admins, adm) admins = append(admins, adm)
} }
return admins, nil return admins, nil
@ -156,24 +136,6 @@ func (db *DB) CreateAdmin(ctx context.Context, adm *mgmt.Admin) error {
} }
adm.AuthorityID = db.authorityID adm.AuthorityID = db.authorityID
// If provisionerID is set, then use it, otherwise load the provisioner
// to get the name.
if adm.ProvisionerID == "" {
prov, err := db.GetProvisionerByName(ctx, adm.ProvisionerName)
if err != nil {
return err
}
adm.ProvisionerID = prov.ID
adm.ProvisionerType = prov.Type
} else {
prov, err := db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return err
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
}
dba := &dbAdmin{ dba := &dbAdmin{
ID: adm.ID, ID: adm.ID,
AuthorityID: db.authorityID, AuthorityID: db.authorityID,
@ -196,7 +158,7 @@ func (db *DB) UpdateAdmin(ctx context.Context, adm *mgmt.Admin) error {
nu := old.clone() nu := old.clone()
// If the admin was active but is now deleted ... // If the admin was active but is now deleted ...
if old.DeletedAt.IsZero() && adm.Status == mgmt.StatusDeleted { if old.DeletedAt.IsZero() && adm.Status == status.Deleted {
nu.DeletedAt = clock.Now() nu.DeletedAt = clock.Now()
} }
nu.Type = adm.Type nu.Type = adm.Type

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
@ -106,7 +107,7 @@ func (db *DB) UpdateAuthConfig(ctx context.Context, ac *mgmt.AuthConfig) error {
nu := old.clone() nu := old.clone()
// If the authority was active but is now deleted ... // If the authority was active but is now deleted ...
if old.DeletedAt.IsZero() && ac.Status == mgmt.StatusDeleted { if old.DeletedAt.IsZero() && ac.Status == status.Deleted {
nu.DeletedAt = clock.Now() nu.DeletedAt = clock.Now()
} }
nu.Claims = ac.Claims nu.Claims = ac.Claims

View file

@ -3,7 +3,6 @@ package nosql
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -60,7 +59,6 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old) return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
} }
} }
fmt.Printf("oldB = %+v\n", oldB)
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
switch { switch {

View file

@ -3,13 +3,12 @@ package nosql
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/status"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
) )
// dbProvisioner is the database representation of a Provisioner type. // dbProvisioner is the database representation of a Provisioner type.
@ -39,20 +38,6 @@ func (dbp *dbProvisioner) clone() *dbProvisioner {
return &u return &u
} }
func (db *DB) getProvisionerIDByName(ctx context.Context, name string) (string, error) {
data, err := db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if nosql.IsErrNotFound(err) {
return "", mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name)
} else if err != nil {
return "", mgmt.WrapErrorISE(err, "error loading provisioner %s", name)
}
ni := new(provisionerNameID)
if err := json.Unmarshal(data, ni); err != nil {
return "", mgmt.WrapErrorISE(err, "error unmarshaling provisionerNameID for provisioner %s", name)
}
return ni.ID, nil
}
func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) { func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
data, err := db.db.Get(authorityProvisionersTable, []byte(id)) data, err := db.db.Get(authorityProvisionersTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
@ -82,25 +67,6 @@ func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner,
return dbp, nil return dbp, nil
} }
func (db *DB) getDBProvisionerByName(ctx context.Context, name string) (*dbProvisioner, error) {
id, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
dbp, err := db.getDBProvisioner(ctx, id)
if err != nil {
return nil, err
}
if !dbp.DeletedAt.IsZero() {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", name)
}
if dbp.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", name, db.authorityID)
}
return dbp, nil
}
// GetProvisioner retrieves and unmarshals a provisioner from the database. // GetProvisioner retrieves and unmarshals a provisioner from the database.
func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) { func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) {
data, err := db.getDBProvisionerBytes(ctx, id) data, err := db.getDBProvisionerBytes(ctx, id)
@ -112,7 +78,7 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner,
if err != nil { if err != nil {
return nil, err return nil, err
} }
if prov.Status == mgmt.StatusDeleted { if prov.Status == status.Deleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", prov.ID) return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", prov.ID)
} }
if prov.AuthorityID != db.authorityID { if prov.AuthorityID != db.authorityID {
@ -122,15 +88,6 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner,
return prov, nil return prov, nil
} }
// GetProvisionerByName retrieves a provisioner from the database by name.
func (db *DB) GetProvisionerByName(ctx context.Context, name string) (*mgmt.Provisioner, error) {
p, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
return db.GetProvisioner(ctx, id)
}
func unmarshalDBProvisioner(data []byte, name string) (*dbProvisioner, error) { func unmarshalDBProvisioner(data []byte, name string) (*dbProvisioner, error) {
var dbp = new(dbProvisioner) var dbp = new(dbProvisioner)
if err := json.Unmarshal(data, dbp); err != nil { if err := json.Unmarshal(data, dbp); err != nil {
@ -157,14 +114,14 @@ func unmarshalProvisioner(data []byte, name string) (*mgmt.Provisioner, error) {
Name: dbp.Name, Name: dbp.Name,
Claims: dbp.Claims, Claims: dbp.Claims,
Details: details, Details: details,
Status: mgmt.StatusActive, Status: status.Active,
X509Template: dbp.X509Template, X509Template: dbp.X509Template,
X509TemplateData: dbp.X509TemplateData, X509TemplateData: dbp.X509TemplateData,
SSHTemplate: dbp.SSHTemplate, SSHTemplate: dbp.SSHTemplate,
SSHTemplateData: dbp.SSHTemplateData, SSHTemplateData: dbp.SSHTemplateData,
} }
if !dbp.DeletedAt.IsZero() { if !dbp.DeletedAt.IsZero() {
prov.Status = mgmt.StatusDeleted prov.Status = status.Deleted
} }
return prov, nil return prov, nil
} }
@ -182,7 +139,7 @@ func (db *DB) GetProvisioners(ctx context.Context) ([]*mgmt.Provisioner, error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if prov.Status == mgmt.StatusDeleted { if prov.Status == status.Deleted {
continue continue
} }
if prov.AuthorityID != db.authorityID { if prov.AuthorityID != db.authorityID {
@ -219,37 +176,8 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err
SSHTemplateData: prov.SSHTemplateData, SSHTemplateData: prov.SSHTemplateData,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
} }
dbpBytes, err := json.Marshal(dbp)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameIndex %s", prov.Name)
}
if err := db.db.Update(&database.Tx{ if err := db.save(ctx, prov.ID, dbp, nil, "provisioner", authorityProvisionersTable); err != nil {
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersTable,
Key: []byte(dbp.ID),
Cmd: database.CmpAndSwap,
Value: dbpBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(dbp.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
},
}); err != nil {
return mgmt.WrapErrorISE(err, "error creating provisioner %s", prov.Name) return mgmt.WrapErrorISE(err, "error creating provisioner %s", prov.Name)
} }
@ -257,22 +185,11 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err
} }
// UpdateProvisioner saves an updated provisioner to the database. // UpdateProvisioner saves an updated provisioner to the database.
func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Provisioner) error { func (db *DB) UpdateProvisioner(ctx context.Context, prov *mgmt.Provisioner) error {
id, err := db.getProvisionerIDByName(ctx, name) old, err := db.getDBProvisioner(ctx, prov.ID)
if err != nil { if err != nil {
return err return err
} }
prov.ID = id
oldBytes, err := db.getDBProvisionerBytes(ctx, id)
if err != nil {
return err
}
fmt.Printf("oldBytes = %+v\n", oldBytes)
old, err := unmarshalDBProvisioner(oldBytes, id)
if err != nil {
return err
}
fmt.Printf("old = %+v\n", old)
nu := old.clone() nu := old.clone()
@ -281,91 +198,15 @@ func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Pro
nu.Claims = prov.Claims nu.Claims = prov.Claims
nu.Details, err = json.Marshal(prov.Details) nu.Details, err = json.Marshal(prov.Details)
if err != nil { if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", name) return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name)
} }
nu.X509Template = prov.X509Template nu.X509Template = prov.X509Template
nu.X509TemplateData = prov.X509TemplateData nu.X509TemplateData = prov.X509TemplateData
nu.SSHTemplateData = prov.SSHTemplateData nu.SSHTemplateData = prov.SSHTemplateData
var txs = []*database.TxEntry{} if err := db.save(ctx, prov.ID, nu, old, "provisioner", authorityProvisionersTable); err != nil {
// If the provisioner was active but is now deleted ...
if old.DeletedAt.IsZero() && prov.Status == mgmt.StatusDeleted {
nu.DeletedAt = clock.Now()
txs = append(txs, &database.TxEntry{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
})
}
if prov.Name != name {
// If the new name does not match the old name then:
// 1) check that the new name is not already taken
// 2) delete the old name-id index resource
// 3) create a new name-id index resource
// 4) update the provisioner resource
nuBytes, err := json.Marshal(nu)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameID for provisioner %s", prov.Name)
}
_, err = db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if err == nil {
return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", prov.Name)
} else if !nosql.IsErrNotFound(err) {
return mgmt.WrapErrorISE(err, "error loading provisionerNameID %s", prov.Name)
}
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(prov.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
} else {
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
}
if err != nil {
return mgmt.WrapErrorISE(err, "error updating provisioner %s", prov.Name) return mgmt.WrapErrorISE(err, "error updating provisioner %s", prov.Name)
} }
return nil return nil
} }

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/authority/status"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
) )
@ -69,7 +70,7 @@ type Provisioner struct {
X509TemplateData []byte `json:"x509TemplateData"` X509TemplateData []byte `json:"x509TemplateData"`
SSHTemplate string `json:"sshTemplate"` SSHTemplate string `json:"sshTemplate"`
SSHTemplateData []byte `json:"sshTemplateData"` SSHTemplateData []byte `json:"sshTemplateData"`
Status StatusType `json:"status"` Status status.Type `json:"status"`
} }
func (p *Provisioner) GetOptions() *provisioner.Options { func (p *Provisioner) GetOptions() *provisioner.Options {
@ -101,7 +102,7 @@ func CreateProvisioner(ctx context.Context, db DB, typ, name string, opts ...Pro
X509TemplateData: pc.X509TemplateData, X509TemplateData: pc.X509TemplateData,
SSHTemplate: pc.SSHTemplate, SSHTemplate: pc.SSHTemplate,
SSHTemplateData: pc.SSHTemplateData, SSHTemplateData: pc.SSHTemplateData,
Status: StatusActive, Status: status.Active,
} }
if err := db.CreateProvisioner(ctx, p); err != nil { if err := db.CreateProvisioner(ctx, p); err != nil {

View file

@ -183,7 +183,7 @@ func (c *Collection) Store(p Interface) error {
// Store provisioner always by name. // Store provisioner always by name.
if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded { if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded {
c.byID.Delete(p.GetID()) c.byID.Delete(p.GetID())
return errors.New("cannot add multiple provisioners with the same id") return errors.New("cannot add multiple provisioners with the same name")
} }
// Store provisioner in byKey if EncryptedKey is defined. // Store provisioner in byKey if EncryptedKey is defined.

View file

@ -0,0 +1,11 @@
package status
// Type is the type for status.
type Type string
var (
// Active active
Active = Type("active")
// Deleted deleted
Deleted = Type("deleted")
)

View file

@ -7,8 +7,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"strconv"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
mgmtAPI "github.com/smallstep/certificates/authority/mgmt/api" mgmtAPI "github.com/smallstep/certificates/authority/mgmt/api"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -88,6 +90,80 @@ retry:
return adm, nil return adm, nil
} }
// AdminOption is the type of options passed to the Provisioner method.
type AdminOption func(o *adminOptions) error
type adminOptions struct {
cursor string
limit int
}
func (o *adminOptions) apply(opts []AdminOption) (err error) {
for _, fn := range opts {
if err = fn(o); err != nil {
return
}
}
return
}
func (o *adminOptions) rawQuery() string {
v := url.Values{}
if len(o.cursor) > 0 {
v.Set("cursor", o.cursor)
}
if o.limit > 0 {
v.Set("limit", strconv.Itoa(o.limit))
}
return v.Encode()
}
// WithAdminCursor will request the admins starting with the given cursor.
func WithAdminCursor(cursor string) AdminOption {
return func(o *adminOptions) error {
o.cursor = cursor
return nil
}
}
// WithAdminLimit will request the given number of admins.
func WithAdminLimit(limit int) AdminOption {
return func(o *adminOptions) error {
o.limit = limit
return nil
}
}
// GetAdmins performs the GET /mgmt/admins request to the CA.
func (c *MgmtClient) GetAdmins(opts ...AdminOption) (*mgmtAPI.GetAdminsResponse, error) {
var retried bool
o := new(adminOptions)
if err := o.apply(opts); err != nil {
return nil, err
}
u := c.endpoint.ResolveReference(&url.URL{
Path: "/mgmt/admins",
RawQuery: o.rawQuery(),
})
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var body = new(mgmtAPI.GetAdminsResponse)
if err := readJSON(resp.Body, body); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return body, nil
}
// CreateAdmin performs the POST /mgmt/admin request to the CA. // CreateAdmin performs the POST /mgmt/admin request to the CA.
func (c *MgmtClient) CreateAdmin(req *mgmtAPI.CreateAdminRequest) (*mgmt.Admin, error) { func (c *MgmtClient) CreateAdmin(req *mgmtAPI.CreateAdminRequest) (*mgmt.Admin, error) {
var retried bool var retried bool
@ -139,7 +215,7 @@ retry:
} }
// UpdateAdmin performs the PUT /mgmt/admin/{id} request to the CA. // UpdateAdmin performs the PUT /mgmt/admin/{id} request to the CA.
func (c *MgmtClient) UpdateAdmin(id string, uar *mgmtAPI.UpdateAdminRequest) (*mgmt.Admin, error) { func (c *MgmtClient) UpdateAdmin(id string, uar *mgmtAPI.UpdateAdminRequest) (*admin.Admin, error) {
var retried bool var retried bool
body, err := json.Marshal(uar) body, err := json.Marshal(uar)
if err != nil { if err != nil {
@ -162,36 +238,13 @@ retry:
} }
return nil, readMgmtError(resp.Body) return nil, readMgmtError(resp.Body)
} }
var adm = new(mgmt.Admin) var adm = new(admin.Admin)
if err := readJSON(resp.Body, adm); err != nil { if err := readJSON(resp.Body, adm); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errors.Wrapf(err, "error reading %s", u)
} }
return adm, nil return adm, nil
} }
// GetAdmins performs the GET /mgmt/admins request to the CA.
func (c *MgmtClient) GetAdmins() ([]*mgmt.Admin, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/mgmt/admins"})
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var admins = new([]*mgmt.Admin)
if err := readJSON(resp.Body, admins); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return *admins, nil
}
// GetProvisioner performs the GET /mgmt/provisioner/{id} request to the CA. // GetProvisioner performs the GET /mgmt/provisioner/{id} request to the CA.
func (c *MgmtClient) GetProvisioner(id string) (*mgmt.Provisioner, error) { func (c *MgmtClient) GetProvisioner(id string) (*mgmt.Provisioner, error) {
var retried bool var retried bool