This commit is contained in:
max furman 2021-05-17 21:07:25 -07:00
parent 4d48072746
commit 5d09d04d14
24 changed files with 1005 additions and 262 deletions

View file

@ -8,6 +8,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
@ -18,6 +19,9 @@ func WriteError(w http.ResponseWriter, err error) {
case *acme.Error: case *acme.Error:
acme.WriteError(w, k) acme.WriteError(w, k)
return return
case *mgmt.Error:
mgmt.WriteError(w, k)
return
default: default:
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
} }

View file

@ -1,12 +0,0 @@
package authority
// Admin is the type definining Authority admins. Admins can update Authority
// configuration, provisioners, and even other admins.
type Admin struct {
ID string `json:"-"`
AuthorityID string `json:"-"`
Name string `json:"name"`
Provisioner string `json:"provisioner"`
IsSuperAdmin bool `json:"isSuperAdmin"`
IsDeleted bool `json:"isDeleted"`
}

15
authority/admin/admin.go Normal file
View file

@ -0,0 +1,15 @@
package admin
// Type specifies the type of administrator privileges the admin has.
type Type string
// Admin type.
type Admin struct {
ID string `json:"id"`
AuthorityID string `json:"-"`
Subject string `json:"subject"`
ProvisionerName string `json:"provisionerName"`
ProvisionerType string `json:"provisionerType"`
ProvisionerID string `json:"provisionerID"`
Type Type `json:"type"`
}

View file

@ -0,0 +1,173 @@
package admin
import (
"crypto/sha1"
"sync"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
)
// DefaultProvisionersLimit is the default limit for listing provisioners.
const DefaultProvisionersLimit = 20
// DefaultProvisionersMax is the maximum limit for listing provisioners.
const DefaultProvisionersMax = 100
/*
type uidProvisioner struct {
provisioner Interface
uid string
}
type provisionerSlice []uidProvisioner
func (p provisionerSlice) Len() int { return len(p) }
func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid }
func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
*/
// loadByTokenPayload is a payload used to extract the id used to load the
// provisioner.
type loadByTokenPayload struct {
jose.Claims
AuthorizedParty string `json:"azp"` // OIDC client id
TenantID string `json:"tid"` // Microsoft Azure tenant id
}
// Collection is a memory map of admins.
type Collection struct {
byID *sync.Map
bySubProv *sync.Map
byProv *sync.Map
count int
countByProvisioner map[string]int
}
// NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner.
func NewCollection() *Collection {
return &Collection{
byID: new(sync.Map),
byProv: new(sync.Map),
bySubProv: new(sync.Map),
countByProvisioner: map[string]int{},
}
}
// LoadByID a admin by the ID.
func (c *Collection) LoadByID(id string) (*Admin, bool) {
return loadAdmin(c.byID, id)
}
func subProvNameHash(sub, provName string) string {
subHash := sha1.Sum([]byte(sub))
provNameHash := sha1.Sum([]byte(provName))
_res := sha1.Sum(append(subHash[:], provNameHash[:]...))
return string(_res[:])
}
// LoadBySubProv a admin by the subject and provisioner name.
func (c *Collection) LoadBySubProv(sub, provName string) (*Admin, bool) {
return loadAdmin(c.bySubProv, subProvNameHash(sub, provName))
}
// LoadByProvisioner a admin by the subject and provisioner name.
func (c *Collection) LoadByProvisioner(provName string) ([]*Admin, bool) {
a, ok := c.byProv.Load(provName)
if !ok {
return nil, false
}
admins, ok := a.([]*Admin)
if !ok {
return nil, false
}
return admins, true
}
// Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos.
func (c *Collection) Store(adm *Admin) error {
provName := adm.ProvisionerName
// Store admin always in byID. ID must be unique.
if _, loaded := c.byID.LoadOrStore(adm.ID, adm); loaded {
return errors.New("cannot add multiple admins with the same id")
}
// Store admin alwasy in bySubProv. Subject <-> ProvisionerName must be unique.
if _, loaded := c.bySubProv.LoadOrStore(subProvNameHash(adm.Subject, provName), adm); loaded {
c.byID.Delete(adm.ID)
return errors.New("cannot add multiple admins with the same subject and provisioner")
}
if admins, ok := c.LoadByProvisioner(provName); ok {
c.byProv.Store(provName, append(admins, adm))
c.countByProvisioner[provName]++
} else {
c.byProv.Store(provName, []*Admin{adm})
c.countByProvisioner[provName] = 1
}
c.count++
return nil
}
// Count returns the total number of admins.
func (c *Collection) Count() int {
return c.count
}
// CountByProvisioner returns the total number of admins.
func (c *Collection) CountByProvisioner(provName string) int {
if cnt, ok := c.countByProvisioner[provName]; ok {
return cnt
}
return 0
}
/*
// Find implements pagination on a list of sorted provisioners.
func (c *Collection) Find(cursor string, limit int) (List, string) {
switch {
case limit <= 0:
limit = DefaultProvisionersLimit
case limit > DefaultProvisionersMax:
limit = DefaultProvisionersMax
}
n := c.sorted.Len()
cursor = fmt.Sprintf("%040s", cursor)
i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor })
slice := List{}
for ; i < n && len(slice) < limit; i++ {
slice = append(slice, c.sorted[i].provisioner)
}
if i < n {
return slice, strings.TrimLeft(c.sorted[i].uid, "0")
}
return slice, ""
}
*/
func loadAdmin(m *sync.Map, key string) (*Admin, bool) {
a, ok := m.Load(key)
if !ok {
return nil, false
}
adm, ok := a.(*Admin)
if !ok {
return nil, false
}
return adm, true
}
/*
// provisionerSum returns the SHA1 of the provisioners ID. From this we will
// create the unique and sorted id.
func provisionerSum(p Interface) []byte {
sum := sha1.Sum([]byte(p.GetID()))
return sum[:]
}
*/

View file

@ -13,6 +13,7 @@ import (
"github.com/smallstep/certificates/cas" "github.com/smallstep/certificates/cas"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
authMgmtNosql "github.com/smallstep/certificates/authority/mgmt/db/nosql" authMgmtNosql "github.com/smallstep/certificates/authority/mgmt/db/nosql"
@ -34,6 +35,7 @@ type Authority struct {
mgmtDB mgmt.DB mgmtDB mgmt.DB
keyManager kms.KeyManager keyManager kms.KeyManager
provisioners *provisioner.Collection provisioners *provisioner.Collection
admins *admin.Collection
db db.AuthDB db db.AuthDB
templates *templates.Templates templates *templates.Templates
@ -127,6 +129,61 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
return a, nil return a, nil
} }
func (a *Authority) ReloadAuthConfig() error {
mgmtAuthConfig, err := a.mgmtDB.GetAuthConfig(context.Background(), mgmt.DefaultAuthorityID)
if err != nil {
return mgmt.WrapErrorISE(err, "error getting authConfig from db")
}
a.config.AuthorityConfig, err = mgmtAuthConfig.ToCertificates()
if err != nil {
return mgmt.WrapErrorISE(err, "error converting mgmt authConfig to certificates authConfig")
}
// Merge global and configuration claims
claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims)
if err != nil {
return err
}
// TODO: should we also be combining the ssh federated roots here?
// If we rotate ssh roots keys, sshpop provisioner will lose ability to
// validate old SSH certificates, unless they are added as federated certs.
sshKeys, err := a.GetSSHRoots(context.Background())
if err != nil {
return err
}
// Initialize provisioners
audiences := a.config.GetAudiences()
a.provisioners = provisioner.NewCollection(audiences)
config := provisioner.Config{
Claims: claimer.Claims(),
Audiences: audiences,
DB: a.db,
SSHKeys: &provisioner.SSHKeys{
UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys,
},
GetIdentityFunc: a.getIdentityFunc,
}
// Store all the provisioners
for _, p := range a.config.AuthorityConfig.Provisioners {
if err := p.Init(config); err != nil {
return err
}
if err := a.provisioners.Store(p); err != nil {
return err
}
}
// Store all the admins
a.admins = admin.NewCollection()
for _, adm := range a.config.AuthorityConfig.Admins {
if err := a.admins.Store(adm); err != nil {
return err
}
}
return nil
}
// init performs validation and initializes the fields of an Authority struct. // init performs validation and initializes the fields of an Authority struct.
func (a *Authority) init() error { func (a *Authority) init() error {
// Check if handler has already been validated/initialized. // Check if handler has already been validated/initialized.
@ -373,6 +430,13 @@ func (a *Authority) init() error {
return err return err
} }
} }
// Store all the admins
a.admins = admin.NewCollection()
for _, adm := range a.config.AuthorityConfig.Admins {
if err := a.admins.Store(adm); err != nil {
return err
}
}
// Configure templates, currently only ssh templates are supported. // Configure templates, currently only ssh templates are supported.
if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil { if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil {
@ -406,6 +470,16 @@ func (a *Authority) GetMgmtDatabase() mgmt.DB {
return a.mgmtDB return a.mgmtDB
} }
// GetAdminCollection returns the admin collection.
func (a *Authority) GetAdminCollection() *admin.Collection {
return a.admins
}
// GetProvisionerCollection returns the admin collection.
func (a *Authority) GetProvisionerCollection() *provisioner.Collection {
return a.provisioners
}
// Shutdown safely shuts down any clients, databases, etc. held by the Authority. // Shutdown safely shuts down any clients, databases, etc. held by the Authority.
func (a *Authority) Shutdown() error { func (a *Authority) Shutdown() error {
if err := a.keyManager.Close(); err != nil { if err := a.keyManager.Close(); err != nil {

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
cas "github.com/smallstep/certificates/cas/apiv1" cas "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
@ -95,6 +96,7 @@ type AuthConfig struct {
*cas.Options *cas.Options
AuthorityID string `json:"authorityID,omitempty"` AuthorityID string `json:"authorityID,omitempty"`
Provisioners provisioner.List `json:"provisioners"` Provisioners provisioner.List `json:"provisioners"`
Admins []*admin.Admin `json:"-"`
Template *ASN1DN `json:"template,omitempty"` Template *ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`

View file

@ -1,23 +1,39 @@
package mgmt package mgmt
import "context" import (
"context"
"github.com/smallstep/certificates/authority/admin"
)
// AdminType specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR
type AdminType string
var (
// AdminTypeSuper superadmin
AdminTypeSuper = AdminType("SUPER_ADMIN")
// AdminTypeRegular regular
AdminTypeRegular = AdminType("REGULAR")
)
// Admin type. // Admin type.
type Admin struct { type Admin struct {
ID string `json:"id"` ID string `json:"id"`
AuthorityID string `json:"-"` AuthorityID string `json:"-"`
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Name string `json:"name"` Subject string `json:"subject"`
IsSuperAdmin bool `json:"isSuperAdmin"` ProvisionerName string `json:"provisionerName"`
ProvisionerType string `json:"provisionerType"`
Type AdminType `json:"type"`
Status StatusType `json:"status"` Status StatusType `json:"status"`
} }
// CreateAdmin builds and stores an admin type in the DB. // CreateAdmin builds and stores an admin type in the DB.
func CreateAdmin(ctx context.Context, db DB, name string, provID string, isSuperAdmin bool) (*Admin, error) { func CreateAdmin(ctx context.Context, db DB, provName, sub string, typ AdminType) (*Admin, error) {
adm := &Admin{ adm := &Admin{
Name: name, Subject: sub,
ProvisionerID: provID, ProvisionerName: provName,
IsSuperAdmin: isSuperAdmin, Type: typ,
Status: StatusActive, Status: StatusActive,
} }
if err := db.CreateAdmin(ctx, adm); err != nil { if err := db.CreateAdmin(ctx, adm); err != nil {
@ -25,3 +41,15 @@ func CreateAdmin(ctx context.Context, db DB, name string, provID string, isSuper
} }
return adm, nil return adm, nil
} }
// ToCertificates converts an Admin to the Admin type expected by the authority.
func (adm *Admin) ToCertificates() (*admin.Admin, error) {
return &admin.Admin{
ID: adm.ID,
Subject: adm.Subject,
ProvisionerID: adm.ProvisionerID,
ProvisionerName: adm.ProvisionerName,
ProvisionerType: adm.ProvisionerType,
Type: admin.Type(adm.Type),
}, nil
}

View file

@ -1,31 +1,34 @@
package api package api
import ( import (
"fmt"
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
) )
// CreateAdminRequest represents the body for a CreateAdmin request. // CreateAdminRequest represents the body for a CreateAdmin request.
type CreateAdminRequest struct { type CreateAdminRequest struct {
Name string `json:"name"` Subject string `json:"subject"`
ProvisionerID string `json:"provisionerID"` Provisioner string `json:"provisioner"`
IsSuperAdmin bool `json:"isSuperAdmin"` Type mgmt.AdminType `json:"type"`
} }
// Validate validates a new-admin request body. // Validate validates a new-admin request body.
func (car *CreateAdminRequest) Validate() error { func (car *CreateAdminRequest) Validate(c *admin.Collection) error {
if _, ok := c.LoadBySubProv(car.Subject, car.Provisioner); ok {
return mgmt.NewError(mgmt.ErrorBadRequestType,
"admin with subject %s and provisioner name %s already exists", car.Subject, car.Provisioner)
}
return nil return nil
} }
// UpdateAdminRequest represents the body for a UpdateAdmin request. // UpdateAdminRequest represents the body for a UpdateAdmin request.
type UpdateAdminRequest struct { type UpdateAdminRequest struct {
Name string `json:"name"` Type mgmt.AdminType `json:"type"`
ProvisionerID string `json:"provisionerID"`
IsSuperAdmin string `json:"isSuperAdmin"`
Status string `json:"status"`
} }
// Validate validates a new-admin request body. // Validate validates a new-admin request body.
@ -73,12 +76,15 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return return
} }
// TODO validate if err := body.Validate(h.auth.GetAdminCollection()); err != nil {
api.WriteError(w, err)
return
}
adm := &mgmt.Admin{ adm := &mgmt.Admin{
ProvisionerID: body.ProvisionerID, ProvisionerName: body.Provisioner,
Name: body.Name, Subject: body.Subject,
IsSuperAdmin: body.IsSuperAdmin, Type: body.Type,
Status: mgmt.StatusActive, Status: mgmt.StatusActive,
} }
if err := h.db.CreateAdmin(ctx, adm); err != nil { if err := h.db.CreateAdmin(ctx, adm); err != nil {
@ -86,14 +92,21 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return return
} }
api.JSON(w, adm) api.JSON(w, adm)
if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err)
}
} }
// DeleteAdmin deletes admin. // DeleteAdmin deletes admin.
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if h.auth.GetAdminCollection().Count() == 1 {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, "cannot remove last admin"))
return
}
ctx := r.Context()
adm, err := h.db.GetAdmin(ctx, id) adm, err := h.db.GetAdmin(ctx, id)
if err != nil { if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id)) api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id))
@ -105,6 +118,9 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
return return
} }
api.JSON(w, &DeleteResponse{Status: "ok"}) api.JSON(w, &DeleteResponse{Status: "ok"})
if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err)
}
} }
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
@ -127,22 +143,14 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
// TODO validate // TODO validate
if len(body.Name) > 0 { adm.Type = body.Type
adm.Name = body.Name
}
if len(body.Status) > 0 {
adm.Status = mgmt.StatusActive // FIXME
}
// Set IsSuperAdmin iff the string was set in the update request.
if len(body.IsSuperAdmin) > 0 {
adm.IsSuperAdmin = (body.IsSuperAdmin == "true")
}
if len(body.ProvisionerID) > 0 {
adm.ProvisionerID = body.ProvisionerID
}
if err := h.db.UpdateAdmin(ctx, adm); err != nil { if err := h.db.UpdateAdmin(ctx, adm); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id)) api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id))
return return
} }
api.JSON(w, adm) api.JSON(w, adm)
if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err)
}
} }

View file

@ -46,39 +46,6 @@ func (h *Handler) GetAuthConfig(w http.ResponseWriter, r *http.Request) {
api.JSON(w, ac) api.JSON(w, ac)
} }
// CreateAuthConfig creates a new admin.
func (h *Handler) CreateAuthConfig(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var body CreateAuthConfigRequest
if err := api.ReadJSON(r.Body, &body); err != nil {
api.WriteError(w, err)
return
}
if err := body.Validate(); err != nil {
api.WriteError(w, err)
}
ac := &mgmt.AuthConfig{
Status: mgmt.StatusActive,
Backdate: "1m",
}
if body.ASN1DN != nil {
ac.ASN1DN = body.ASN1DN
}
if body.Claims != nil {
ac.Claims = body.Claims
}
if body.Backdate != "" {
ac.Backdate = body.Backdate
}
if err := h.db.CreateAuthConfig(ctx, ac); err != nil {
api.WriteError(w, err)
return
}
api.JSONStatus(w, ac, http.StatusCreated)
}
// UpdateAuthConfig updates an existing AuthConfig. // UpdateAuthConfig updates an existing AuthConfig.
func (h *Handler) UpdateAuthConfig(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateAuthConfig(w http.ResponseWriter, r *http.Request) {
/* /*

View file

@ -4,6 +4,7 @@ import (
"time" "time"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
) )
@ -20,31 +21,31 @@ var clock Clock
// Handler is the ACME API request handler. // Handler is the ACME API request handler.
type Handler struct { type Handler struct {
db mgmt.DB db mgmt.DB
auth *authority.Authority
} }
// NewHandler returns a new Authority Config Handler. // NewHandler returns a new Authority Config Handler.
func NewHandler(db mgmt.DB) api.RouterHandler { func NewHandler(db mgmt.DB, auth *authority.Authority) api.RouterHandler {
return &Handler{db} return &Handler{db, auth}
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
// Provisioners // Provisioners
r.MethodFunc("GET", "/provisioner/{id}", h.GetProvisioner) r.MethodFunc("GET", "/provisioner/{name}", h.GetProvisioner)
r.MethodFunc("GET", "/provisioners", h.GetProvisioners) r.MethodFunc("GET", "/provisioners", h.GetProvisioners)
r.MethodFunc("POST", "/provisioner", h.CreateProvisioner) r.MethodFunc("POST", "/provisioner", h.CreateProvisioner)
r.MethodFunc("PUT", "/provisioner/{id}", h.UpdateProvisioner) r.MethodFunc("PUT", "/provisioner/{name}", h.UpdateProvisioner)
//r.MethodFunc("DELETE", "/provisioner/{id}", h.UpdateAdmin) r.MethodFunc("DELETE", "/provisioner/{name}", h.DeleteProvisioner)
// Admins // Admins
r.MethodFunc("GET", "/admin/{id}", h.GetAdmin) r.MethodFunc("GET", "/admin/{id}", h.GetAdmin)
r.MethodFunc("GET", "/admins", h.GetAdmins) r.MethodFunc("GET", "/admins", h.GetAdmins)
r.MethodFunc("POST", "/admin", h.CreateAdmin) r.MethodFunc("POST", "/admin", h.CreateAdmin)
r.MethodFunc("PUT", "/admin/{id}", h.UpdateAdmin) r.MethodFunc("PATCH", "/admin/{id}", h.UpdateAdmin)
r.MethodFunc("DELETE", "/admin/{id}", h.DeleteAdmin) r.MethodFunc("DELETE", "/admin/{id}", h.DeleteAdmin)
// AuthConfig // AuthConfig
r.MethodFunc("GET", "/authconfig/{id}", h.GetAuthConfig) r.MethodFunc("GET", "/authconfig/{id}", h.GetAuthConfig)
r.MethodFunc("POST", "/authconfig", h.CreateAuthConfig)
r.MethodFunc("PUT", "/authconfig/{id}", h.UpdateAuthConfig) r.MethodFunc("PUT", "/authconfig/{id}", h.UpdateAuthConfig)
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/certificates/authority/provisioner"
) )
// CreateProvisionerRequest represents the body for a CreateProvisioner request. // CreateProvisionerRequest represents the body for a CreateProvisioner request.
@ -14,7 +15,7 @@ type CreateProvisionerRequest struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *mgmt.Claims `json:"claims"` Claims *mgmt.Claims `json:"claims"`
Details interface{} `json:"details"` Details []byte `json:"details"`
X509Template string `json:"x509Template"` X509Template string `json:"x509Template"`
X509TemplateData []byte `json:"x509TemplateData"` X509TemplateData []byte `json:"x509TemplateData"`
SSHTemplate string `json:"sshTemplate"` SSHTemplate string `json:"sshTemplate"`
@ -22,31 +23,39 @@ type CreateProvisionerRequest struct {
} }
// Validate validates a new-provisioner request body. // Validate validates a new-provisioner request body.
func (car *CreateProvisionerRequest) Validate() error { func (cpr *CreateProvisionerRequest) Validate(c *provisioner.Collection) error {
if _, ok := c.LoadByName(cpr.Name); ok {
return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", cpr.Name)
}
return nil return nil
} }
// UpdateProvisionerRequest represents the body for a UpdateProvisioner request. // UpdateProvisionerRequest represents the body for a UpdateProvisioner request.
type UpdateProvisionerRequest struct { type UpdateProvisionerRequest struct {
Type string `json:"type"`
Name string `json:"name"`
Claims *mgmt.Claims `json:"claims"` Claims *mgmt.Claims `json:"claims"`
Details interface{} `json:"details"` Details []byte `json:"details"`
X509Template string `json:"x509Template"` X509Template string `json:"x509Template"`
X509TemplateData []byte `json:"x509TemplateData"` X509TemplateData []byte `json:"x509TemplateData"`
SSHTemplate string `json:"sshTemplate"` SSHTemplate string `json:"sshTemplate"`
SSHTemplateData []byte `json:"sshTemplateData"` SSHTemplateData []byte `json:"sshTemplateData"`
} }
// Validate validates a new-provisioner request body. // Validate validates a update-provisioner request body.
func (uar *UpdateProvisionerRequest) Validate() error { func (upr *UpdateProvisionerRequest) Validate(c *provisioner.Collection) error {
if _, ok := c.LoadByName(upr.Name); ok {
return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", upr.Name)
}
return nil return nil
} }
// GetProvisioner returns the requested provisioner, or an error. // GetProvisioner returns the requested provisioner, or an error.
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
id := chi.URLParam(r, "id") name := chi.URLParam(r, "name")
prov, err := h.db.GetProvisioner(ctx, id) prov, err := h.db.GetProvisionerByName(ctx, name)
if err != nil { if err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -63,7 +72,6 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
fmt.Printf("provs = %+v\n", provs)
api.JSON(w, provs) api.JSON(w, provs)
} }
@ -76,15 +84,24 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(h.auth.GetProvisionerCollection()); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return
} }
details, err := mgmt.UnmarshalProvisionerDetails(body.Details)
if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error unmarshaling provisioner details"))
return
}
claims := mgmt.NewDefaultClaims()
prov := &mgmt.Provisioner{ prov := &mgmt.Provisioner{
Type: body.Type, Type: body.Type,
Name: body.Name, Name: body.Name,
Claims: body.Claims, Claims: claims,
Details: body.Details, Details: details,
X509Template: body.X509Template, X509Template: body.X509Template,
X509TemplateData: body.X509TemplateData, X509TemplateData: body.X509TemplateData,
SSHTemplate: body.SSHTemplate, SSHTemplate: body.SSHTemplate,
@ -95,6 +112,58 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
return return
} }
api.JSONStatus(w, prov, http.StatusCreated) api.JSONStatus(w, prov, http.StatusCreated)
if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err)
}
}
// DeleteProvisioner deletes a provisioner.
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name")
c := h.auth.GetAdminCollection()
if c.Count() == c.CountByProvisioner(name) {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType,
"cannot remove provisioner %s because no admins will remain", name))
return
}
ctx := r.Context()
prov, err := h.db.GetProvisionerByName(ctx, name)
if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving provisioner %s", name))
return
}
fmt.Printf("prov = %+v\n", prov)
prov.Status = mgmt.StatusDeleted
if err := h.db.UpdateProvisioner(ctx, name, prov); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error updating provisioner %s", name))
return
}
// Delete all admins associated with the provisioner.
admins, ok := c.LoadByProvisioner(name)
if ok {
for _, adm := range admins {
if err := h.db.UpdateAdmin(ctx, &mgmt.Admin{
ID: adm.ID,
ProvisionerID: adm.ProvisionerID,
Subject: adm.Subject,
Type: mgmt.AdminType(adm.Type),
Status: mgmt.StatusDeleted,
}); err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, name))
return
}
}
}
api.JSON(w, &DeleteResponse{Status: "ok"})
if err := h.auth.ReloadAuthConfig(); err != nil {
fmt.Printf("err = %+v\n", err)
}
} }
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.

View file

@ -1,6 +1,7 @@
package mgmt package mgmt
import ( import (
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
@ -9,7 +10,7 @@ import (
type AuthConfig struct { type AuthConfig struct {
//*cas.Options `json:"cas"` //*cas.Options `json:"cas"`
ID string `json:"id"` ID string `json:"id"`
ASN1DN *config.ASN1DN `json:"template,omitempty"` ASN1DN *config.ASN1DN `json:"asn1dn,omitempty"`
Provisioners []*Provisioner `json:"-"` Provisioners []*Provisioner `json:"-"`
Admins []*Admin `json:"-"` Admins []*Admin `json:"-"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
@ -46,9 +47,18 @@ func (ac *AuthConfig) ToCertificates() (*config.AuthConfig, error) {
} }
provs = append(provs, authProv) provs = append(provs, authProv)
} }
var admins []*admin.Admin
for _, adm := range ac.Admins {
authAdmin, err := adm.ToCertificates()
if err != nil {
return nil, err
}
admins = append(admins, authAdmin)
}
return &config.AuthConfig{ return &config.AuthConfig{
AuthorityID: ac.ID, AuthorityID: ac.ID,
Provisioners: provs, Provisioners: provs,
Admins: admins,
Template: ac.ASN1DN, Template: ac.ASN1DN,
Claims: claims, Claims: claims,
DisableIssuedAtCheck: false, DisableIssuedAtCheck: false,

View file

@ -2,7 +2,6 @@ package mgmt
import ( import (
"context" "context"
"fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
@ -16,26 +15,15 @@ const (
) )
// StatusType is the type for status. // StatusType is the type for status.
type StatusType int type StatusType string
const ( var (
// StatusActive active // StatusActive active
StatusActive StatusType = iota StatusActive = StatusType("active")
// StatusDeleted deleted // StatusDeleted deleted
StatusDeleted StatusDeleted = StatusType("deleted")
) )
func (st StatusType) String() string {
switch st {
case StatusActive:
return "active"
case StatusDeleted:
return "deleted"
default:
return fmt.Sprintf("status %d not found", st)
}
}
// Claims encapsulates all x509 and ssh claims applied to the authority // Claims encapsulates all x509 and ssh claims applied to the authority
// configuration. E.g. maxTLSCertDuration, defaultSSHCertDuration, etc. // configuration. E.g. maxTLSCertDuration, defaultSSHCertDuration, etc.
type Claims struct { type Claims struct {
@ -123,14 +111,18 @@ func CreateAuthority(ctx context.Context, db DB, options ...AuthorityOption) (*A
return nil, WrapErrorISE(err, "error creating first provisioner") return nil, WrapErrorISE(err, "error creating first provisioner")
} }
admin, err := CreateAdmin(ctx, db, "Change Me", prov.ID, true) adm := &Admin{
if err != nil { ProvisionerID: prov.ID,
Subject: "Change Me",
Type: AdminTypeSuper,
}
if err := db.CreateAdmin(ctx, adm); err != nil {
// TODO should we try to clean up? // TODO should we try to clean up?
return nil, WrapErrorISE(err, "error creating first provisioner") return nil, WrapErrorISE(err, "error creating first admin")
} }
ac.Provisioners = []*Provisioner{prov} ac.Provisioners = []*Provisioner{prov}
ac.Admins = []*Admin{admin} ac.Admins = []*Admin{adm}
return ac, nil return ac, nil
} }

View file

@ -14,8 +14,9 @@ var ErrNotFound = errors.New("not found")
type DB interface { type DB interface {
CreateProvisioner(ctx context.Context, prov *Provisioner) error CreateProvisioner(ctx context.Context, prov *Provisioner) error
GetProvisioner(ctx context.Context, id string) (*Provisioner, error) GetProvisioner(ctx context.Context, id string) (*Provisioner, error)
GetProvisionerByName(ctx context.Context, name string) (*Provisioner, error)
GetProvisioners(ctx context.Context) ([]*Provisioner, error) GetProvisioners(ctx context.Context) ([]*Provisioner, error)
UpdateProvisioner(ctx context.Context, prov *Provisioner) error UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error
CreateAdmin(ctx context.Context, admin *Admin) error CreateAdmin(ctx context.Context, admin *Admin) error
GetAdmin(ctx context.Context, id string) (*Admin, error) GetAdmin(ctx context.Context, id string) (*Admin, error)
@ -32,8 +33,9 @@ type DB interface {
type MockDB struct { type MockDB struct {
MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error
MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error) MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error)
MockGetProvisionerByName func(ctx context.Context, name string) (*Provisioner, error)
MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error) MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error)
MockUpdateProvisioner func(ctx context.Context, prov *Provisioner) error MockUpdateProvisioner func(ctx context.Context, name string, prov *Provisioner) error
MockCreateAdmin func(ctx context.Context, adm *Admin) error MockCreateAdmin func(ctx context.Context, adm *Admin) error
MockGetAdmin func(ctx context.Context, id string) (*Admin, error) MockGetAdmin func(ctx context.Context, id string) (*Admin, error)
@ -68,6 +70,16 @@ func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*Provisioner, e
return m.MockRet1.(*Provisioner), m.MockError return m.MockRet1.(*Provisioner), m.MockError
} }
// GetProvisionerByName mock.
func (m *MockDB) GetProvisionerByName(ctx context.Context, id string) (*Provisioner, error) {
if m.MockGetProvisionerByName != nil {
return m.MockGetProvisionerByName(ctx, id)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*Provisioner), m.MockError
}
// GetProvisioners mock // GetProvisioners mock
func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) { func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) {
if m.MockGetProvisioners != nil { if m.MockGetProvisioners != nil {
@ -79,9 +91,9 @@ func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) {
} }
// UpdateProvisioner mock // UpdateProvisioner mock
func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *Provisioner) error { func (m *MockDB) UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error {
if m.MockUpdateProvisioner != nil { if m.MockUpdateProvisioner != nil {
return m.MockUpdateProvisioner(ctx, prov) return m.MockUpdateProvisioner(ctx, name, prov)
} }
return m.MockError return m.MockError
} }

View file

@ -15,8 +15,8 @@ type dbAdmin struct {
ID string `json:"id"` ID string `json:"id"`
AuthorityID string `json:"authorityID"` AuthorityID string `json:"authorityID"`
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Name string `json:"name"` Subject string `json:"subject"`
IsSuperAdmin bool `json:"isSuperAdmin"` Type mgmt.AdminType `json:"type"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
DeletedAt time.Time `json:"deletedAt"` DeletedAt time.Time `json:"deletedAt"`
} }
@ -52,27 +52,6 @@ func (db *DB) getDBAdmin(ctx context.Context, id string) (*dbAdmin, error) {
return dba, nil return dba, nil
} }
// GetAdmin retrieves and unmarshals a admin from the database.
func (db *DB) GetAdmin(ctx context.Context, id string) (*mgmt.Admin, error) {
data, err := db.getDBAdminBytes(ctx, id)
if err != nil {
return nil, err
}
adm, err := unmarshalAdmin(data, id)
if err != nil {
return nil, err
}
if adm.Status == mgmt.StatusDeleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID)
}
if adm.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"admin %s is not owned by authority %s", adm.ID, db.authorityID)
}
return adm, nil
}
func unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) { func unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) {
var dba = new(dbAdmin) var dba = new(dbAdmin)
if err := json.Unmarshal(data, dba); err != nil { if err := json.Unmarshal(data, dba); err != nil {
@ -90,8 +69,9 @@ func unmarshalAdmin(data []byte, id string) (*mgmt.Admin, error) {
ID: dba.ID, ID: dba.ID,
AuthorityID: dba.AuthorityID, AuthorityID: dba.AuthorityID,
ProvisionerID: dba.ProvisionerID, ProvisionerID: dba.ProvisionerID,
Name: dba.Name, Subject: dba.Subject,
IsSuperAdmin: dba.IsSuperAdmin, Type: dba.Type,
Status: mgmt.StatusActive,
} }
if !dba.DeletedAt.IsZero() { if !dba.DeletedAt.IsZero() {
adm.Status = mgmt.StatusDeleted adm.Status = mgmt.StatusDeleted
@ -99,6 +79,33 @@ func unmarshalAdmin(data []byte, id string) (*mgmt.Admin, error) {
return adm, nil return adm, nil
} }
// GetAdmin retrieves and unmarshals a admin from the database.
func (db *DB) GetAdmin(ctx context.Context, id string) (*mgmt.Admin, error) {
data, err := db.getDBAdminBytes(ctx, id)
if err != nil {
return nil, err
}
adm, err := unmarshalAdmin(data, id)
if err != nil {
return nil, err
}
if adm.Status == mgmt.StatusDeleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID)
}
if adm.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"admin %s is not owned by authority %s", adm.ID, db.authorityID)
}
prov, err := db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return nil, err
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
return adm, nil
}
// GetAdmins retrieves and unmarshals all active (not deleted) admins // GetAdmins retrieves and unmarshals all active (not deleted) admins
// from the database. // from the database.
// TODO should we be paginating? // TODO should we be paginating?
@ -107,7 +114,10 @@ func (db *DB) GetAdmins(ctx context.Context) ([]*mgmt.Admin, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error loading admins") return nil, errors.Wrap(err, "error loading admins")
} }
var admins []*mgmt.Admin var (
provCache = map[string]*mgmt.Provisioner{}
admins []*mgmt.Admin
)
for _, entry := range dbEntries { for _, entry := range dbEntries {
adm, err := unmarshalAdmin(entry.Value, string(entry.Key)) adm, err := unmarshalAdmin(entry.Value, string(entry.Key))
if err != nil { if err != nil {
@ -119,6 +129,19 @@ func (db *DB) GetAdmins(ctx context.Context) ([]*mgmt.Admin, error) {
if adm.AuthorityID != db.authorityID { if adm.AuthorityID != db.authorityID {
continue continue
} }
var (
prov *mgmt.Provisioner
ok bool
)
if prov, ok = provCache[adm.ProvisionerID]; !ok {
prov, err = db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return nil, err
}
provCache[adm.ProvisionerID] = prov
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
admins = append(admins, adm) admins = append(admins, adm)
} }
return admins, nil return admins, nil
@ -129,16 +152,34 @@ func (db *DB) CreateAdmin(ctx context.Context, adm *mgmt.Admin) error {
var err error var err error
adm.ID, err = randID() adm.ID, err = randID()
if err != nil { if err != nil {
return errors.Wrap(err, "error generating random id for admin") return mgmt.WrapErrorISE(err, "error generating random id for admin")
} }
adm.AuthorityID = db.authorityID adm.AuthorityID = db.authorityID
// If provisionerID is set, then use it, otherwise load the provisioner
// to get the name.
if adm.ProvisionerID == "" {
prov, err := db.GetProvisionerByName(ctx, adm.ProvisionerName)
if err != nil {
return err
}
adm.ProvisionerID = prov.ID
adm.ProvisionerType = prov.Type
} else {
prov, err := db.GetProvisioner(ctx, adm.ProvisionerID)
if err != nil {
return err
}
adm.ProvisionerName = prov.Name
adm.ProvisionerType = prov.Type
}
dba := &dbAdmin{ dba := &dbAdmin{
ID: adm.ID, ID: adm.ID,
AuthorityID: db.authorityID, AuthorityID: db.authorityID,
ProvisionerID: adm.ProvisionerID, ProvisionerID: adm.ProvisionerID,
Name: adm.Name, Subject: adm.Subject,
IsSuperAdmin: adm.IsSuperAdmin, Type: adm.Type,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
} }
@ -158,8 +199,7 @@ func (db *DB) UpdateAdmin(ctx context.Context, adm *mgmt.Admin) error {
if old.DeletedAt.IsZero() && adm.Status == mgmt.StatusDeleted { if old.DeletedAt.IsZero() && adm.Status == mgmt.StatusDeleted {
nu.DeletedAt = clock.Now() nu.DeletedAt = clock.Now()
} }
nu.ProvisionerID = adm.ProvisionerID nu.Type = adm.Type
nu.IsSuperAdmin = adm.IsSuperAdmin
return db.save(ctx, old.ID, nu, old, "admin", authorityAdminsTable) return db.save(ctx, old.ID, nu, old, "admin", authorityAdminsTable)
} }

View file

@ -60,9 +60,14 @@ func (db *DB) GetAuthConfig(ctx context.Context, id string) (*mgmt.AuthConfig, e
if err != nil { if err != nil {
return nil, err return nil, err
} }
admins, err := db.GetAdmins(ctx)
if err != nil {
return nil, err
}
return &mgmt.AuthConfig{ return &mgmt.AuthConfig{
ID: dba.ID, ID: dba.ID,
Admins: admins,
Provisioners: provs, Provisioners: provs,
ASN1DN: dba.ASN1DN, ASN1DN: dba.ASN1DN,
Backdate: dba.Backdate, Backdate: dba.Backdate,

View file

@ -3,6 +3,7 @@ package nosql
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -14,6 +15,7 @@ var (
authorityAdminsTable = []byte("authority_admins") authorityAdminsTable = []byte("authority_admins")
authorityConfigsTable = []byte("authority_configs") authorityConfigsTable = []byte("authority_configs")
authorityProvisionersTable = []byte("authority_provisioners") authorityProvisionersTable = []byte("authority_provisioners")
authorityProvisionersNameIDIndexTable = []byte("authority_provisioners_name_id_index")
) )
// DB is a struct that implements the AcmeDB interface. // DB is a struct that implements the AcmeDB interface.
@ -24,7 +26,7 @@ type DB struct {
// New configures and returns a new Authority DB backend implemented using a nosql DB. // New configures and returns a new Authority DB backend implemented using a nosql DB.
func New(db nosqlDB.DB, authorityID string) (*DB, error) { func New(db nosqlDB.DB, authorityID string) (*DB, error) {
tables := [][]byte{authorityAdminsTable, authorityConfigsTable, authorityProvisionersTable} tables := [][]byte{authorityAdminsTable, authorityConfigsTable, authorityProvisionersTable, authorityProvisionersNameIDIndexTable}
for _, b := range tables { for _, b := range tables {
if err := db.CreateTable(b); err != nil { if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s", return nil, errors.Wrapf(err, "error creating table %s",
@ -58,6 +60,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old) return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old)
} }
} }
fmt.Printf("oldB = %+v\n", oldB)
_, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB)
switch { switch {
@ -73,7 +76,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface
var idLen = 32 var idLen = 32
func randID() (val string, err error) { func randID() (val string, err error) {
val, err = randutil.Alphanumeric(idLen) val, err = randutil.UUIDv4()
if err != nil { if err != nil {
return "", errors.Wrap(err, "error generating random alphanumeric ID") return "", errors.Wrap(err, "error generating random alphanumeric ID")
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
) )
// dbProvisioner is the database representation of a Provisioner type. // dbProvisioner is the database representation of a Provisioner type.
@ -16,6 +17,7 @@ type dbProvisioner struct {
ID string `json:"id"` ID string `json:"id"`
AuthorityID string `json:"authorityID"` AuthorityID string `json:"authorityID"`
Type string `json:"type"` Type string `json:"type"`
// Name is the key
Name string `json:"name"` Name string `json:"name"`
Claims *mgmt.Claims `json:"claims"` Claims *mgmt.Claims `json:"claims"`
Details []byte `json:"details"` Details []byte `json:"details"`
@ -27,11 +29,30 @@ type dbProvisioner struct {
DeletedAt time.Time `json:"deletedAt"` DeletedAt time.Time `json:"deletedAt"`
} }
type provisionerNameID struct {
Name string `json:"name"`
ID string `json:"id"`
}
func (dbp *dbProvisioner) clone() *dbProvisioner { func (dbp *dbProvisioner) clone() *dbProvisioner {
u := *dbp u := *dbp
return &u return &u
} }
func (db *DB) getProvisionerIDByName(ctx context.Context, name string) (string, error) {
data, err := db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if nosql.IsErrNotFound(err) {
return "", mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name)
} else if err != nil {
return "", mgmt.WrapErrorISE(err, "error loading provisioner %s", name)
}
ni := new(provisionerNameID)
if err := json.Unmarshal(data, ni); err != nil {
return "", mgmt.WrapErrorISE(err, "error unmarshaling provisionerNameID for provisioner %s", name)
}
return ni.ID, nil
}
func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) { func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
data, err := db.db.Get(authorityProvisionersTable, []byte(id)) data, err := db.db.Get(authorityProvisionersTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
@ -51,6 +72,9 @@ func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner,
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !dbp.DeletedAt.IsZero() {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", id)
}
if dbp.AuthorityID != db.authorityID { if dbp.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType, return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", dbp.ID, db.authorityID) "provisioner %s is not owned by authority %s", dbp.ID, db.authorityID)
@ -58,6 +82,25 @@ func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner,
return dbp, nil return dbp, nil
} }
func (db *DB) getDBProvisionerByName(ctx context.Context, name string) (*dbProvisioner, error) {
id, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
dbp, err := db.getDBProvisioner(ctx, id)
if err != nil {
return nil, err
}
if !dbp.DeletedAt.IsZero() {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", name)
}
if dbp.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", name, db.authorityID)
}
return dbp, nil
}
// GetProvisioner retrieves and unmarshals a provisioner from the database. // GetProvisioner retrieves and unmarshals a provisioner from the database.
func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) { func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) {
data, err := db.getDBProvisionerBytes(ctx, id) data, err := db.getDBProvisionerBytes(ctx, id)
@ -79,29 +122,30 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner,
return prov, nil return prov, nil
} }
func unmarshalDBProvisioner(data []byte, id string) (*dbProvisioner, error) { // GetProvisionerByName retrieves a provisioner from the database by name.
func (db *DB) GetProvisionerByName(ctx context.Context, name string) (*mgmt.Provisioner, error) {
p, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
return db.GetProvisioner(ctx, id)
}
func unmarshalDBProvisioner(data []byte, name string) (*dbProvisioner, error) {
var dbp = new(dbProvisioner) var dbp = new(dbProvisioner)
if err := json.Unmarshal(data, dbp); err != nil { if err := json.Unmarshal(data, dbp); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", id) return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", name)
} }
return dbp, nil return dbp, nil
} }
type detailsType struct { func unmarshalProvisioner(data []byte, name string) (*mgmt.Provisioner, error) {
Type mgmt.ProvisionerType dbp, err := unmarshalDBProvisioner(data, name)
}
func unmarshalProvisioner(data []byte, id string) (*mgmt.Provisioner, error) {
dbp, err := unmarshalDBProvisioner(data, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dt := new(detailsType) details, err := mgmt.UnmarshalProvisionerDetails(dbp.Details)
if err := json.Unmarshal(dbp.Details, dt); err != nil {
return nil, mgmt.WrapErrorISE(err, "error unmarshaling details to detailsType for provisioner %s", id)
}
details, err := unmarshalDetails(dt.Type, dbp.Details)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,6 +157,7 @@ func unmarshalProvisioner(data []byte, id string) (*mgmt.Provisioner, error) {
Name: dbp.Name, Name: dbp.Name,
Claims: dbp.Claims, Claims: dbp.Claims,
Details: details, Details: details,
Status: mgmt.StatusActive,
X509Template: dbp.X509Template, X509Template: dbp.X509Template,
X509TemplateData: dbp.X509TemplateData, X509TemplateData: dbp.X509TemplateData,
SSHTemplate: dbp.SSHTemplate, SSHTemplate: dbp.SSHTemplate,
@ -129,7 +174,7 @@ func unmarshalProvisioner(data []byte, id string) (*mgmt.Provisioner, error) {
func (db *DB) GetProvisioners(ctx context.Context) ([]*mgmt.Provisioner, error) { func (db *DB) GetProvisioners(ctx context.Context) ([]*mgmt.Provisioner, error) {
dbEntries, err := db.db.List(authorityProvisionersTable) dbEntries, err := db.db.List(authorityProvisionersTable)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error loading provisioners") return nil, mgmt.WrapErrorISE(err, "error loading provisioners")
} }
var provs []*mgmt.Provisioner var provs []*mgmt.Provisioner
for _, entry := range dbEntries { for _, entry := range dbEntries {
@ -158,7 +203,7 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err
details, err := json.Marshal(prov.Details) details, err := json.Marshal(prov.Details)
if err != nil { if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner") return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name)
} }
dbp := &dbProvisioner{ dbp := &dbProvisioner{
@ -169,65 +214,158 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err
Claims: prov.Claims, Claims: prov.Claims,
Details: details, Details: details,
X509Template: prov.X509Template, X509Template: prov.X509Template,
X509TemplateData: prov.X509TemplateData,
SSHTemplate: prov.SSHTemplate, SSHTemplate: prov.SSHTemplate,
SSHTemplateData: prov.SSHTemplateData,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
} }
dbpBytes, err := json.Marshal(dbp)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameIndex %s", prov.Name)
}
return db.save(ctx, dbp.ID, dbp, nil, "provisioner", authorityProvisionersTable) if err := db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersTable,
Key: []byte(dbp.ID),
Cmd: database.CmpAndSwap,
Value: dbpBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(dbp.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
},
}); err != nil {
return mgmt.WrapErrorISE(err, "error creating provisioner %s", prov.Name)
}
return nil
} }
// UpdateProvisioner saves an updated provisioner to the database. // UpdateProvisioner saves an updated provisioner to the database.
func (db *DB) UpdateProvisioner(ctx context.Context, prov *mgmt.Provisioner) error { func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Provisioner) error {
old, err := db.getDBProvisioner(ctx, prov.ID) id, err := db.getProvisionerIDByName(ctx, name)
if err != nil { if err != nil {
return err return err
} }
prov.ID = id
oldBytes, err := db.getDBProvisionerBytes(ctx, id)
if err != nil {
return err
}
fmt.Printf("oldBytes = %+v\n", oldBytes)
old, err := unmarshalDBProvisioner(oldBytes, id)
if err != nil {
return err
}
fmt.Printf("old = %+v\n", old)
nu := old.clone() nu := old.clone()
nu.Type = prov.Type
nu.Name = prov.Name
nu.Claims = prov.Claims
nu.Details, err = json.Marshal(prov.Details)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", name)
}
nu.X509Template = prov.X509Template
nu.X509TemplateData = prov.X509TemplateData
nu.SSHTemplateData = prov.SSHTemplateData
var txs = []*database.TxEntry{}
// If the provisioner was active but is now deleted ... // If the provisioner was active but is now deleted ...
if old.DeletedAt.IsZero() && prov.Status == mgmt.StatusDeleted { if old.DeletedAt.IsZero() && prov.Status == mgmt.StatusDeleted {
nu.DeletedAt = clock.Now() nu.DeletedAt = clock.Now()
txs = append(txs, &database.TxEntry{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
})
} }
nu.Claims = prov.Claims
nu.X509Template = prov.X509Template
nu.SSHTemplate = prov.SSHTemplate
nu.Details, err = json.Marshal(prov.Details) if prov.Name != name {
// If the new name does not match the old name then:
// 1) check that the new name is not already taken
// 2) delete the old name-id index resource
// 3) create a new name-id index resource
// 4) update the provisioner resource
nuBytes, err := json.Marshal(nu)
if err != nil { if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner") return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameID for provisioner %s", prov.Name)
} }
return db.save(ctx, old.ID, nu, old, "provisioner", authorityProvisionersTable) _, err = db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if err == nil {
return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", prov.Name)
} else if !nosql.IsErrNotFound(err) {
return mgmt.WrapErrorISE(err, "error loading provisionerNameID %s", prov.Name)
} }
err = db.db.Update(&database.Tx{
func unmarshalDetails(typ mgmt.ProvisionerType, data []byte) (mgmt.ProvisionerDetails, error) { Operations: []*database.TxEntry{
var v mgmt.ProvisionerDetails {
switch typ { Bucket: authorityProvisionersNameIDIndexTable,
case mgmt.ProvisionerTypeJWK: Key: []byte(name),
v = new(mgmt.ProvisionerDetailsJWK) Cmd: database.Delete,
case mgmt.ProvisionerTypeOIDC: },
v = new(mgmt.ProvisionerDetailsOIDC) {
case mgmt.ProvisionerTypeGCP: Bucket: authorityProvisionersNameIDIndexTable,
v = new(mgmt.ProvisionerDetailsGCP) Key: []byte(prov.Name),
case mgmt.ProvisionerTypeAWS: Cmd: database.CmpAndSwap,
v = new(mgmt.ProvisionerDetailsAWS) Value: pniBytes,
case mgmt.ProvisionerTypeAZURE: CmpValue: nil,
v = new(mgmt.ProvisionerDetailsAzure) },
case mgmt.ProvisionerTypeACME: {
v = new(mgmt.ProvisionerDetailsACME) Bucket: authorityProvisionersTable,
case mgmt.ProvisionerTypeX5C: Key: []byte(nu.ID),
v = new(mgmt.ProvisionerDetailsX5C) Cmd: database.CmpAndSwap,
case mgmt.ProvisionerTypeK8SSA: Value: nuBytes,
v = new(mgmt.ProvisionerDetailsK8SSA) CmpValue: oldBytes,
case mgmt.ProvisionerTypeSSHPOP: },
v = new(mgmt.ProvisionerDetailsSSHPOP) },
default: })
return nil, fmt.Errorf("unsupported provisioner type %s", typ) } else {
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
} }
if err != nil {
if err := json.Unmarshal(data, v); err != nil { return mgmt.WrapErrorISE(err, "error updating provisioner %s", prov.Name)
return nil, err
} }
return v, nil return nil
} }

View file

@ -90,8 +90,7 @@ var (
type Error struct { type Error struct {
Type string `json:"type"` Type string `json:"type"`
Detail string `json:"detail"` Detail string `json:"detail"`
Subproblems []interface{} `json:"subproblems,omitempty"` Message string `json:"message"`
Identifier interface{} `json:"identifier,omitempty"`
Err error `json:"-"` Err error `json:"-"`
Status int `json:"-"` Status int `json:"-"`
} }
@ -160,7 +159,7 @@ func (e *Error) StatusCode() int {
// Error allows AError to implement the error interface. // Error allows AError to implement the error interface.
func (e *Error) Error() string { func (e *Error) Error() string {
return e.Detail return e.Err.Error()
} }
// Cause returns the internal error and implements the Causer interface. // Cause returns the internal error and implements the Causer interface.
@ -182,9 +181,10 @@ func (e *Error) ToLog() (interface{}, error) {
// WriteError writes to w a JSON representation of the given error. // WriteError writes to w a JSON representation of the given error.
func WriteError(w http.ResponseWriter, err *Error) { func WriteError(w http.ResponseWriter, err *Error) {
w.Header().Set("Content-Type", "application/problem+json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(err.StatusCode()) w.WriteHeader(err.StatusCode())
err.Message = err.Err.Error()
// Write errors in the response writer // Write errors in the response writer
if rl, ok := w.(logging.ResponseLogger); ok { if rl, ok := w.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{ rl.WithFields(map[string]interface{}{
@ -199,6 +199,7 @@ func WriteError(w http.ResponseWriter, err *Error) {
} }
} }
fmt.Printf("err = %+v\n", err)
if err := json.NewEncoder(w).Encode(err); err != nil { if err := json.NewEncoder(w).Encode(err); err != nil {
log.Println(err) log.Println(err)
} }

View file

@ -59,8 +59,8 @@ func WithPassword(pass string) func(*ProvisionerCtx) {
// Provisioner type. // Provisioner type.
type Provisioner struct { type Provisioner struct {
ID string `json:"id"` ID string `json:"-"`
AuthorityID string `json:"authorityID"` AuthorityID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *Claims `json:"claims"` Claims *Claims `json:"claims"`
@ -87,8 +87,7 @@ func (p *Provisioner) GetOptions() *provisioner.Options {
func CreateProvisioner(ctx context.Context, db DB, typ, name string, opts ...ProvisionerOption) (*Provisioner, error) { func CreateProvisioner(ctx context.Context, db DB, typ, name string, opts ...ProvisionerOption) (*Provisioner, error) {
pc := NewProvisionerCtx(opts...) pc := NewProvisionerCtx(opts...)
details, err := NewProvisionerDetails(ProvisionerType(typ), pc)
details, err := createJWKDetails(pc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,6 +179,27 @@ func (*ProvisionerDetailsK8SSA) isProvisionerDetails() {}
func (*ProvisionerDetailsSSHPOP) isProvisionerDetails() {} func (*ProvisionerDetailsSSHPOP) isProvisionerDetails() {}
func NewProvisionerDetails(typ ProvisionerType, pc *ProvisionerCtx) (ProvisionerDetails, error) {
switch typ {
case ProvisionerTypeJWK:
return createJWKDetails(pc)
/*
case ProvisionerTypeOIDC:
return createOIDCDetails(pc)
case ProvisionerTypeACME:
return createACMEDetails(pc)
case ProvisionerTypeK8SSA:
return createK8SSADetails(pc)
case ProvisionerTypeSSHPOP:
return createSSHPOPDetails(pc)
case ProvisionerTypeX5C:
return createSSHPOPDetails(pc)
*/
default:
return nil, NewErrorISE("unsupported provisioner type %s", typ)
}
}
func createJWKDetails(pc *ProvisionerCtx) (*ProvisionerDetailsJWK, error) { func createJWKDetails(pc *ProvisionerCtx) (*ProvisionerDetailsJWK, error) {
var err error var err error
@ -231,6 +251,7 @@ func (p *Provisioner) ToCertificates() (provisioner.Interface, error) {
return nil, err return nil, err
} }
return &provisioner.JWK{ return &provisioner.JWK{
ID: p.ID,
Type: p.Type, Type: p.Type,
Name: p.Name, Name: p.Name,
Key: jwk, Key: jwk,
@ -386,3 +407,43 @@ func (c *Claims) ToCertificates() (*provisioner.Claims, error) {
EnableSSHCA: &c.SSH.Enabled, EnableSSHCA: &c.SSH.Enabled,
}, nil }, nil
} }
type detailsType struct {
Type ProvisionerType
}
func UnmarshalProvisionerDetails(data []byte) (ProvisionerDetails, error) {
dt := new(detailsType)
if err := json.Unmarshal(data, dt); err != nil {
return nil, WrapErrorISE(err, "error unmarshaling provisioner details")
}
var v ProvisionerDetails
switch dt.Type {
case ProvisionerTypeJWK:
v = new(ProvisionerDetailsJWK)
case ProvisionerTypeOIDC:
v = new(ProvisionerDetailsOIDC)
case ProvisionerTypeGCP:
v = new(ProvisionerDetailsGCP)
case ProvisionerTypeAWS:
v = new(ProvisionerDetailsAWS)
case ProvisionerTypeAZURE:
v = new(ProvisionerDetailsAzure)
case ProvisionerTypeACME:
v = new(ProvisionerDetailsACME)
case ProvisionerTypeX5C:
v = new(ProvisionerDetailsX5C)
case ProvisionerTypeK8SSA:
v = new(ProvisionerDetailsK8SSA)
case ProvisionerTypeSSHPOP:
v = new(ProvisionerDetailsSSHPOP)
default:
return nil, fmt.Errorf("unsupported provisioner type %s", dt.Type)
}
if err := json.Unmarshal(data, v); err != nil {
return nil, err
}
return v, nil
}

View file

@ -45,6 +45,7 @@ type loadByTokenPayload struct {
type Collection struct { type Collection struct {
byID *sync.Map byID *sync.Map
byKey *sync.Map byKey *sync.Map
byName *sync.Map
sorted provisionerSlice sorted provisionerSlice
audiences Audiences audiences Audiences
} }
@ -55,6 +56,7 @@ func NewCollection(audiences Audiences) *Collection {
return &Collection{ return &Collection{
byID: new(sync.Map), byID: new(sync.Map),
byKey: new(sync.Map), byKey: new(sync.Map),
byName: new(sync.Map),
audiences: audiences, audiences: audiences,
} }
} }
@ -64,6 +66,11 @@ func (c *Collection) Load(id string) (Interface, bool) {
return loadProvisioner(c.byID, id) return loadProvisioner(c.byID, id)
} }
// LoadByName a provisioner by name.
func (c *Collection) LoadByName(name string) (Interface, bool) {
return loadProvisioner(c.byName, name)
}
// LoadByToken parses the token claims and loads the provisioner associated. // LoadByToken parses the token claims and loads the provisioner associated.
func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) {
var audiences []string var audiences []string
@ -173,6 +180,11 @@ func (c *Collection) Store(p Interface) error {
if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded { if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded {
return errors.New("cannot add multiple provisioners with the same id") return errors.New("cannot add multiple provisioners with the same id")
} }
// Store provisioner always by name.
if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded {
c.byID.Delete(p.GetID())
return errors.New("cannot add multiple provisioners with the same id")
}
// Store provisioner in byKey if EncryptedKey is defined. // Store provisioner in byKey if EncryptedKey is defined.
if kid, _, ok := p.GetEncryptedKey(); ok { if kid, _, ok := p.GetEncryptedKey(); ok {

View file

@ -28,6 +28,7 @@ type stepPayload struct {
// signature requests. // signature requests.
type JWK struct { type JWK struct {
*base *base
ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Key *jose.JSONWebKey `json:"key"` Key *jose.JSONWebKey `json:"key"`
@ -41,6 +42,9 @@ type JWK struct {
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
// should uniquely identify any JWK provisioner. // should uniquely identify any JWK provisioner.
func (p *JWK) GetID() string { func (p *JWK) GetID() string {
if p.ID != "" {
return p.ID
}
return p.Name + ":" + p.Key.KeyID return p.Name + ":" + p.Key.KeyID
} }

View file

@ -172,10 +172,9 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
}) })
// MGMT Router // MGMT Router
mgmtDB := auth.GetMgmtDatabase() mgmtDB := auth.GetMgmtDatabase()
if mgmtDB != nil { if mgmtDB != nil {
mgmtHandler := mgmtAPI.NewHandler(mgmtDB) mgmtHandler := mgmtAPI.NewHandler(mgmtDB, auth)
mux.Route("/mgmt", func(r chi.Router) { mux.Route("/mgmt", func(r chi.Router) {
mgmtHandler.Route(r) mgmtHandler.Route(r)
}) })

View file

@ -3,6 +3,7 @@ package ca
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"io"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
@ -78,7 +79,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readMgmtError(resp.Body)
} }
var adm = new(mgmt.Admin) var adm = new(mgmt.Admin)
if err := readJSON(resp.Body, adm); err != nil { if err := readJSON(resp.Body, adm); err != nil {
@ -105,7 +106,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readMgmtError(resp.Body)
} }
var adm = new(mgmt.Admin) var adm = new(mgmt.Admin)
if err := readJSON(resp.Body, adm); err != nil { if err := readJSON(resp.Body, adm); err != nil {
@ -132,7 +133,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return readError(resp.Body) return readMgmtError(resp.Body)
} }
return nil return nil
} }
@ -145,7 +146,7 @@ func (c *MgmtClient) UpdateAdmin(id string, uar *mgmtAPI.UpdateAdminRequest) (*m
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
} }
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/admin", id)}) u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/admin", id)})
req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(body)) req, err := http.NewRequest("PATCH", u.String(), bytes.NewReader(body))
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -159,7 +160,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readMgmtError(resp.Body)
} }
var adm = new(mgmt.Admin) var adm = new(mgmt.Admin)
if err := readJSON(resp.Body, adm); err != nil { if err := readJSON(resp.Body, adm); err != nil {
@ -182,7 +183,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readMgmtError(resp.Body)
} }
var admins = new([]*mgmt.Admin) var admins = new([]*mgmt.Admin)
if err := readJSON(resp.Body, admins); err != nil { if err := readJSON(resp.Body, admins); err != nil {
@ -191,6 +192,29 @@ retry:
return *admins, nil return *admins, nil
} }
// GetProvisioner performs the GET /mgmt/provisioner/{id} request to the CA.
func (c *MgmtClient) GetProvisioner(id string) (*mgmt.Provisioner, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/provisioner", id)})
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var prov = new(mgmt.Provisioner)
if err := readJSON(resp.Body, prov); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return prov, nil
}
// GetProvisioners performs the GET /mgmt/provisioners request to the CA. // GetProvisioners performs the GET /mgmt/provisioners request to the CA.
func (c *MgmtClient) GetProvisioners() ([]*mgmt.Provisioner, error) { func (c *MgmtClient) GetProvisioners() ([]*mgmt.Provisioner, error) {
var retried bool var retried bool
@ -205,7 +229,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body) return nil, readMgmtError(resp.Body)
} }
var provs = new([]*mgmt.Provisioner) var provs = new([]*mgmt.Provisioner)
if err := readJSON(resp.Body, provs); err != nil { if err := readJSON(resp.Body, provs); err != nil {
@ -213,3 +237,116 @@ retry:
} }
return *provs, nil return *provs, nil
} }
// RemoveProvisioner performs the DELETE /mgmt/provisioner/{name} request to the CA.
func (c *MgmtClient) RemoveProvisioner(name string) error {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/provisioner", name)})
req, err := http.NewRequest("DELETE", u.String(), nil)
if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u)
}
retry:
resp, err := c.client.Do(req)
if err != nil {
return errors.Wrapf(err, "client DELETE %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return readMgmtError(resp.Body)
}
return nil
}
// CreateProvisioner performs the POST /mgmt/provisioner request to the CA.
func (c *MgmtClient) CreateProvisioner(req *mgmtAPI.CreateProvisionerRequest) (*mgmt.Provisioner, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/mgmt/provisioner"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var prov = new(mgmt.Provisioner)
if err := readJSON(resp.Body, prov); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return prov, nil
}
// UpdateProvisioner performs the PUT /mgmt/provisioner/{id} request to the CA.
func (c *MgmtClient) UpdateProvisioner(id string, upr *mgmtAPI.UpdateProvisionerRequest) (*mgmt.Provisioner, error) {
var retried bool
body, err := json.Marshal(upr)
if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/provisioner", id)})
req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(body))
if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u)
}
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, errors.Wrapf(err, "client PUT %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var prov = new(mgmt.Provisioner)
if err := readJSON(resp.Body, prov); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return prov, nil
}
// GetAuthConfig performs the GET /mgmt/authconfig/{id} request to the CA.
func (c *MgmtClient) GetAuthConfig(id string) (*mgmt.AuthConfig, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/authconfig", id)})
retry:
resp, err := c.client.Get(u.String())
if err != nil {
return nil, errors.Wrapf(err, "client GET %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readMgmtError(resp.Body)
}
var ac = new(mgmt.AuthConfig)
if err := readJSON(resp.Body, ac); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u)
}
return ac, nil
}
func readMgmtError(r io.ReadCloser) error {
defer r.Close()
mgmtErr := new(mgmt.Error)
if err := json.NewDecoder(r).Decode(mgmtErr); err != nil {
return err
}
return errors.New(mgmtErr.Message)
}