diff --git a/api/errors.go b/api/errors.go index fa2d6a06..085d05cf 100644 --- a/api/errors.go +++ b/api/errors.go @@ -8,6 +8,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/acme" + "github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/logging" ) @@ -18,6 +19,9 @@ func WriteError(w http.ResponseWriter, err error) { case *acme.Error: acme.WriteError(w, k) return + case *mgmt.Error: + mgmt.WriteError(w, k) + return default: w.Header().Set("Content-Type", "application/json") } diff --git a/authority/admin.go b/authority/admin.go deleted file mode 100644 index 6c95de4f..00000000 --- a/authority/admin.go +++ /dev/null @@ -1,12 +0,0 @@ -package authority - -// Admin is the type definining Authority admins. Admins can update Authority -// configuration, provisioners, and even other admins. -type Admin struct { - ID string `json:"-"` - AuthorityID string `json:"-"` - Name string `json:"name"` - Provisioner string `json:"provisioner"` - IsSuperAdmin bool `json:"isSuperAdmin"` - IsDeleted bool `json:"isDeleted"` -} diff --git a/authority/admin/admin.go b/authority/admin/admin.go new file mode 100644 index 00000000..579538dc --- /dev/null +++ b/authority/admin/admin.go @@ -0,0 +1,15 @@ +package admin + +// Type specifies the type of administrator privileges the admin has. +type Type string + +// Admin type. +type Admin struct { + ID string `json:"id"` + AuthorityID string `json:"-"` + Subject string `json:"subject"` + ProvisionerName string `json:"provisionerName"` + ProvisionerType string `json:"provisionerType"` + ProvisionerID string `json:"provisionerID"` + Type Type `json:"type"` +} diff --git a/authority/admin/collection.go b/authority/admin/collection.go new file mode 100644 index 00000000..87dd63ce --- /dev/null +++ b/authority/admin/collection.go @@ -0,0 +1,173 @@ +package admin + +import ( + "crypto/sha1" + "sync" + + "github.com/pkg/errors" + "go.step.sm/crypto/jose" +) + +// DefaultProvisionersLimit is the default limit for listing provisioners. +const DefaultProvisionersLimit = 20 + +// DefaultProvisionersMax is the maximum limit for listing provisioners. +const DefaultProvisionersMax = 100 + +/* +type uidProvisioner struct { + provisioner Interface + uid string +} + +type provisionerSlice []uidProvisioner + +func (p provisionerSlice) Len() int { return len(p) } +func (p provisionerSlice) Less(i, j int) bool { return p[i].uid < p[j].uid } +func (p provisionerSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } +*/ + +// loadByTokenPayload is a payload used to extract the id used to load the +// provisioner. +type loadByTokenPayload struct { + jose.Claims + AuthorizedParty string `json:"azp"` // OIDC client id + TenantID string `json:"tid"` // Microsoft Azure tenant id +} + +// Collection is a memory map of admins. +type Collection struct { + byID *sync.Map + bySubProv *sync.Map + byProv *sync.Map + count int + countByProvisioner map[string]int +} + +// NewCollection initializes a collection of provisioners. The given list of +// audiences are the audiences used by the JWT provisioner. +func NewCollection() *Collection { + return &Collection{ + byID: new(sync.Map), + byProv: new(sync.Map), + bySubProv: new(sync.Map), + countByProvisioner: map[string]int{}, + } +} + +// LoadByID a admin by the ID. +func (c *Collection) LoadByID(id string) (*Admin, bool) { + return loadAdmin(c.byID, id) +} + +func subProvNameHash(sub, provName string) string { + subHash := sha1.Sum([]byte(sub)) + provNameHash := sha1.Sum([]byte(provName)) + _res := sha1.Sum(append(subHash[:], provNameHash[:]...)) + return string(_res[:]) +} + +// LoadBySubProv a admin by the subject and provisioner name. +func (c *Collection) LoadBySubProv(sub, provName string) (*Admin, bool) { + return loadAdmin(c.bySubProv, subProvNameHash(sub, provName)) +} + +// LoadByProvisioner a admin by the subject and provisioner name. +func (c *Collection) LoadByProvisioner(provName string) ([]*Admin, bool) { + a, ok := c.byProv.Load(provName) + if !ok { + return nil, false + } + admins, ok := a.([]*Admin) + if !ok { + return nil, false + } + return admins, true +} + +// Store adds an admin to the collection and enforces the uniqueness of +// admin IDs and amdin subject <-> provisioner name combos. +func (c *Collection) Store(adm *Admin) error { + provName := adm.ProvisionerName + // Store admin always in byID. ID must be unique. + if _, loaded := c.byID.LoadOrStore(adm.ID, adm); loaded { + return errors.New("cannot add multiple admins with the same id") + } + + // Store admin alwasy in bySubProv. Subject <-> ProvisionerName must be unique. + if _, loaded := c.bySubProv.LoadOrStore(subProvNameHash(adm.Subject, provName), adm); loaded { + c.byID.Delete(adm.ID) + return errors.New("cannot add multiple admins with the same subject and provisioner") + } + + if admins, ok := c.LoadByProvisioner(provName); ok { + c.byProv.Store(provName, append(admins, adm)) + c.countByProvisioner[provName]++ + } else { + c.byProv.Store(provName, []*Admin{adm}) + c.countByProvisioner[provName] = 1 + } + c.count++ + + return nil +} + +// Count returns the total number of admins. +func (c *Collection) Count() int { + return c.count +} + +// CountByProvisioner returns the total number of admins. +func (c *Collection) CountByProvisioner(provName string) int { + if cnt, ok := c.countByProvisioner[provName]; ok { + return cnt + } + return 0 +} + +/* +// Find implements pagination on a list of sorted provisioners. +func (c *Collection) Find(cursor string, limit int) (List, string) { + switch { + case limit <= 0: + limit = DefaultProvisionersLimit + case limit > DefaultProvisionersMax: + limit = DefaultProvisionersMax + } + + n := c.sorted.Len() + cursor = fmt.Sprintf("%040s", cursor) + i := sort.Search(n, func(i int) bool { return c.sorted[i].uid >= cursor }) + + slice := List{} + for ; i < n && len(slice) < limit; i++ { + slice = append(slice, c.sorted[i].provisioner) + } + + if i < n { + return slice, strings.TrimLeft(c.sorted[i].uid, "0") + } + return slice, "" +} +*/ + +func loadAdmin(m *sync.Map, key string) (*Admin, bool) { + a, ok := m.Load(key) + if !ok { + return nil, false + } + adm, ok := a.(*Admin) + if !ok { + return nil, false + } + return adm, true +} + +/* +// provisionerSum returns the SHA1 of the provisioners ID. From this we will +// create the unique and sorted id. +func provisionerSum(p Interface) []byte { + sum := sha1.Sum([]byte(p.GetID())) + return sum[:] +} +*/ diff --git a/authority/authority.go b/authority/authority.go index ee82eb13..2da5b341 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -13,6 +13,7 @@ import ( "github.com/smallstep/certificates/cas" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/mgmt" authMgmtNosql "github.com/smallstep/certificates/authority/mgmt/db/nosql" @@ -34,6 +35,7 @@ type Authority struct { mgmtDB mgmt.DB keyManager kms.KeyManager provisioners *provisioner.Collection + admins *admin.Collection db db.AuthDB templates *templates.Templates @@ -127,6 +129,61 @@ func NewEmbedded(opts ...Option) (*Authority, error) { return a, nil } +func (a *Authority) ReloadAuthConfig() error { + mgmtAuthConfig, err := a.mgmtDB.GetAuthConfig(context.Background(), mgmt.DefaultAuthorityID) + if err != nil { + return mgmt.WrapErrorISE(err, "error getting authConfig from db") + } + + a.config.AuthorityConfig, err = mgmtAuthConfig.ToCertificates() + if err != nil { + return mgmt.WrapErrorISE(err, "error converting mgmt authConfig to certificates authConfig") + } + + // Merge global and configuration claims + claimer, err := provisioner.NewClaimer(a.config.AuthorityConfig.Claims, config.GlobalProvisionerClaims) + if err != nil { + return err + } + // TODO: should we also be combining the ssh federated roots here? + // If we rotate ssh roots keys, sshpop provisioner will lose ability to + // validate old SSH certificates, unless they are added as federated certs. + sshKeys, err := a.GetSSHRoots(context.Background()) + if err != nil { + return err + } + // Initialize provisioners + audiences := a.config.GetAudiences() + a.provisioners = provisioner.NewCollection(audiences) + config := provisioner.Config{ + Claims: claimer.Claims(), + Audiences: audiences, + DB: a.db, + SSHKeys: &provisioner.SSHKeys{ + UserKeys: sshKeys.UserKeys, + HostKeys: sshKeys.HostKeys, + }, + GetIdentityFunc: a.getIdentityFunc, + } + // Store all the provisioners + for _, p := range a.config.AuthorityConfig.Provisioners { + if err := p.Init(config); err != nil { + return err + } + if err := a.provisioners.Store(p); err != nil { + return err + } + } + // Store all the admins + a.admins = admin.NewCollection() + for _, adm := range a.config.AuthorityConfig.Admins { + if err := a.admins.Store(adm); err != nil { + return err + } + } + return nil +} + // init performs validation and initializes the fields of an Authority struct. func (a *Authority) init() error { // Check if handler has already been validated/initialized. @@ -373,6 +430,13 @@ func (a *Authority) init() error { return err } } + // Store all the admins + a.admins = admin.NewCollection() + for _, adm := range a.config.AuthorityConfig.Admins { + if err := a.admins.Store(adm); err != nil { + return err + } + } // Configure templates, currently only ssh templates are supported. if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil { @@ -406,6 +470,16 @@ func (a *Authority) GetMgmtDatabase() mgmt.DB { return a.mgmtDB } +// GetAdminCollection returns the admin collection. +func (a *Authority) GetAdminCollection() *admin.Collection { + return a.admins +} + +// GetProvisionerCollection returns the admin collection. +func (a *Authority) GetProvisionerCollection() *provisioner.Collection { + return a.provisioners +} + // Shutdown safely shuts down any clients, databases, etc. held by the Authority. func (a *Authority) Shutdown() error { if err := a.keyManager.Close(); err != nil { diff --git a/authority/config/config.go b/authority/config/config.go index 66b8bbe0..9fbf18e0 100644 --- a/authority/config/config.go +++ b/authority/config/config.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/provisioner" cas "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" @@ -95,6 +96,7 @@ type AuthConfig struct { *cas.Options AuthorityID string `json:"authorityID,omitempty"` Provisioners provisioner.List `json:"provisioners"` + Admins []*admin.Admin `json:"-"` Template *ASN1DN `json:"template,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` diff --git a/authority/mgmt/admin.go b/authority/mgmt/admin.go index 9ceabe93..f21abdce 100644 --- a/authority/mgmt/admin.go +++ b/authority/mgmt/admin.go @@ -1,27 +1,55 @@ package mgmt -import "context" +import ( + "context" + + "github.com/smallstep/certificates/authority/admin" +) + +// AdminType specifies the type of the admin. e.g. SUPER_ADMIN, REGULAR +type AdminType string + +var ( + // AdminTypeSuper superadmin + AdminTypeSuper = AdminType("SUPER_ADMIN") + // AdminTypeRegular regular + AdminTypeRegular = AdminType("REGULAR") +) // Admin type. type Admin struct { - ID string `json:"id"` - AuthorityID string `json:"-"` - ProvisionerID string `json:"provisionerID"` - Name string `json:"name"` - IsSuperAdmin bool `json:"isSuperAdmin"` - Status StatusType `json:"status"` + ID string `json:"id"` + AuthorityID string `json:"-"` + ProvisionerID string `json:"provisionerID"` + Subject string `json:"subject"` + ProvisionerName string `json:"provisionerName"` + ProvisionerType string `json:"provisionerType"` + Type AdminType `json:"type"` + Status StatusType `json:"status"` } // CreateAdmin builds and stores an admin type in the DB. -func CreateAdmin(ctx context.Context, db DB, name string, provID string, isSuperAdmin bool) (*Admin, error) { +func CreateAdmin(ctx context.Context, db DB, provName, sub string, typ AdminType) (*Admin, error) { adm := &Admin{ - Name: name, - ProvisionerID: provID, - IsSuperAdmin: isSuperAdmin, - Status: StatusActive, + Subject: sub, + ProvisionerName: provName, + Type: typ, + Status: StatusActive, } if err := db.CreateAdmin(ctx, adm); err != nil { return nil, WrapErrorISE(err, "error creating admin") } return adm, nil } + +// ToCertificates converts an Admin to the Admin type expected by the authority. +func (adm *Admin) ToCertificates() (*admin.Admin, error) { + return &admin.Admin{ + ID: adm.ID, + Subject: adm.Subject, + ProvisionerID: adm.ProvisionerID, + ProvisionerName: adm.ProvisionerName, + ProvisionerType: adm.ProvisionerType, + Type: admin.Type(adm.Type), + }, nil +} diff --git a/authority/mgmt/api/admin.go b/authority/mgmt/api/admin.go index ae60b75a..f997095a 100644 --- a/authority/mgmt/api/admin.go +++ b/authority/mgmt/api/admin.go @@ -1,31 +1,34 @@ package api import ( + "fmt" "net/http" "github.com/go-chi/chi" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/mgmt" ) // CreateAdminRequest represents the body for a CreateAdmin request. type CreateAdminRequest struct { - Name string `json:"name"` - ProvisionerID string `json:"provisionerID"` - IsSuperAdmin bool `json:"isSuperAdmin"` + Subject string `json:"subject"` + Provisioner string `json:"provisioner"` + Type mgmt.AdminType `json:"type"` } // Validate validates a new-admin request body. -func (car *CreateAdminRequest) Validate() error { +func (car *CreateAdminRequest) Validate(c *admin.Collection) error { + if _, ok := c.LoadBySubProv(car.Subject, car.Provisioner); ok { + return mgmt.NewError(mgmt.ErrorBadRequestType, + "admin with subject %s and provisioner name %s already exists", car.Subject, car.Provisioner) + } return nil } // UpdateAdminRequest represents the body for a UpdateAdmin request. type UpdateAdminRequest struct { - Name string `json:"name"` - ProvisionerID string `json:"provisionerID"` - IsSuperAdmin string `json:"isSuperAdmin"` - Status string `json:"status"` + Type mgmt.AdminType `json:"type"` } // Validate validates a new-admin request body. @@ -73,27 +76,37 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { return } - // TODO validate + if err := body.Validate(h.auth.GetAdminCollection()); err != nil { + api.WriteError(w, err) + return + } adm := &mgmt.Admin{ - ProvisionerID: body.ProvisionerID, - Name: body.Name, - IsSuperAdmin: body.IsSuperAdmin, - Status: mgmt.StatusActive, + ProvisionerName: body.Provisioner, + Subject: body.Subject, + Type: body.Type, + Status: mgmt.StatusActive, } if err := h.db.CreateAdmin(ctx, adm); err != nil { api.WriteError(w, mgmt.WrapErrorISE(err, "error creating admin")) return } api.JSON(w, adm) + if err := h.auth.ReloadAuthConfig(); err != nil { + fmt.Printf("err = %+v\n", err) + } } // DeleteAdmin deletes admin. func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - id := chi.URLParam(r, "id") + if h.auth.GetAdminCollection().Count() == 1 { + api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, "cannot remove last admin")) + return + } + + ctx := r.Context() adm, err := h.db.GetAdmin(ctx, id) if err != nil { api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving admin %s", id)) @@ -105,6 +118,9 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { return } api.JSON(w, &DeleteResponse{Status: "ok"}) + if err := h.auth.ReloadAuthConfig(); err != nil { + fmt.Printf("err = %+v\n", err) + } } // UpdateAdmin updates an existing admin. @@ -127,22 +143,14 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { // TODO validate - if len(body.Name) > 0 { - adm.Name = body.Name - } - if len(body.Status) > 0 { - adm.Status = mgmt.StatusActive // FIXME - } - // Set IsSuperAdmin iff the string was set in the update request. - if len(body.IsSuperAdmin) > 0 { - adm.IsSuperAdmin = (body.IsSuperAdmin == "true") - } - if len(body.ProvisionerID) > 0 { - adm.ProvisionerID = body.ProvisionerID - } + adm.Type = body.Type + if err := h.db.UpdateAdmin(ctx, adm); err != nil { api.WriteError(w, mgmt.WrapErrorISE(err, "error updating admin %s", id)) return } api.JSON(w, adm) + if err := h.auth.ReloadAuthConfig(); err != nil { + fmt.Printf("err = %+v\n", err) + } } diff --git a/authority/mgmt/api/authConfig.go b/authority/mgmt/api/authConfig.go index 283a4b66..7d5f7afb 100644 --- a/authority/mgmt/api/authConfig.go +++ b/authority/mgmt/api/authConfig.go @@ -46,39 +46,6 @@ func (h *Handler) GetAuthConfig(w http.ResponseWriter, r *http.Request) { api.JSON(w, ac) } -// CreateAuthConfig creates a new admin. -func (h *Handler) CreateAuthConfig(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - var body CreateAuthConfigRequest - if err := api.ReadJSON(r.Body, &body); err != nil { - api.WriteError(w, err) - return - } - if err := body.Validate(); err != nil { - api.WriteError(w, err) - } - - ac := &mgmt.AuthConfig{ - Status: mgmt.StatusActive, - Backdate: "1m", - } - if body.ASN1DN != nil { - ac.ASN1DN = body.ASN1DN - } - if body.Claims != nil { - ac.Claims = body.Claims - } - if body.Backdate != "" { - ac.Backdate = body.Backdate - } - if err := h.db.CreateAuthConfig(ctx, ac); err != nil { - api.WriteError(w, err) - return - } - api.JSONStatus(w, ac, http.StatusCreated) -} - // UpdateAuthConfig updates an existing AuthConfig. func (h *Handler) UpdateAuthConfig(w http.ResponseWriter, r *http.Request) { /* diff --git a/authority/mgmt/api/handler.go b/authority/mgmt/api/handler.go index 778cdaea..cb52736f 100644 --- a/authority/mgmt/api/handler.go +++ b/authority/mgmt/api/handler.go @@ -4,6 +4,7 @@ import ( "time" "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/mgmt" ) @@ -19,32 +20,32 @@ var clock Clock // Handler is the ACME API request handler. type Handler struct { - db mgmt.DB + db mgmt.DB + auth *authority.Authority } // NewHandler returns a new Authority Config Handler. -func NewHandler(db mgmt.DB) api.RouterHandler { - return &Handler{db} +func NewHandler(db mgmt.DB, auth *authority.Authority) api.RouterHandler { + return &Handler{db, auth} } // Route traffic and implement the Router interface. func (h *Handler) Route(r api.Router) { // Provisioners - r.MethodFunc("GET", "/provisioner/{id}", h.GetProvisioner) + r.MethodFunc("GET", "/provisioner/{name}", h.GetProvisioner) r.MethodFunc("GET", "/provisioners", h.GetProvisioners) r.MethodFunc("POST", "/provisioner", h.CreateProvisioner) - r.MethodFunc("PUT", "/provisioner/{id}", h.UpdateProvisioner) - //r.MethodFunc("DELETE", "/provisioner/{id}", h.UpdateAdmin) + r.MethodFunc("PUT", "/provisioner/{name}", h.UpdateProvisioner) + r.MethodFunc("DELETE", "/provisioner/{name}", h.DeleteProvisioner) // Admins r.MethodFunc("GET", "/admin/{id}", h.GetAdmin) r.MethodFunc("GET", "/admins", h.GetAdmins) r.MethodFunc("POST", "/admin", h.CreateAdmin) - r.MethodFunc("PUT", "/admin/{id}", h.UpdateAdmin) + r.MethodFunc("PATCH", "/admin/{id}", h.UpdateAdmin) r.MethodFunc("DELETE", "/admin/{id}", h.DeleteAdmin) // AuthConfig r.MethodFunc("GET", "/authconfig/{id}", h.GetAuthConfig) - r.MethodFunc("POST", "/authconfig", h.CreateAuthConfig) r.MethodFunc("PUT", "/authconfig/{id}", h.UpdateAuthConfig) } diff --git a/authority/mgmt/api/provisioner.go b/authority/mgmt/api/provisioner.go index b6b4c1c7..8a8da08f 100644 --- a/authority/mgmt/api/provisioner.go +++ b/authority/mgmt/api/provisioner.go @@ -7,6 +7,7 @@ import ( "github.com/go-chi/chi" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/mgmt" + "github.com/smallstep/certificates/authority/provisioner" ) // CreateProvisionerRequest represents the body for a CreateProvisioner request. @@ -14,7 +15,7 @@ type CreateProvisionerRequest struct { Type string `json:"type"` Name string `json:"name"` Claims *mgmt.Claims `json:"claims"` - Details interface{} `json:"details"` + Details []byte `json:"details"` X509Template string `json:"x509Template"` X509TemplateData []byte `json:"x509TemplateData"` SSHTemplate string `json:"sshTemplate"` @@ -22,31 +23,39 @@ type CreateProvisionerRequest struct { } // Validate validates a new-provisioner request body. -func (car *CreateProvisionerRequest) Validate() error { +func (cpr *CreateProvisionerRequest) Validate(c *provisioner.Collection) error { + if _, ok := c.LoadByName(cpr.Name); ok { + return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", cpr.Name) + } return nil } // UpdateProvisionerRequest represents the body for a UpdateProvisioner request. type UpdateProvisionerRequest struct { + Type string `json:"type"` + Name string `json:"name"` Claims *mgmt.Claims `json:"claims"` - Details interface{} `json:"details"` + Details []byte `json:"details"` X509Template string `json:"x509Template"` X509TemplateData []byte `json:"x509TemplateData"` SSHTemplate string `json:"sshTemplate"` SSHTemplateData []byte `json:"sshTemplateData"` } -// Validate validates a new-provisioner request body. -func (uar *UpdateProvisionerRequest) Validate() error { +// Validate validates a update-provisioner request body. +func (upr *UpdateProvisionerRequest) Validate(c *provisioner.Collection) error { + if _, ok := c.LoadByName(upr.Name); ok { + return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", upr.Name) + } return nil } // GetProvisioner returns the requested provisioner, or an error. func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - id := chi.URLParam(r, "id") + name := chi.URLParam(r, "name") - prov, err := h.db.GetProvisioner(ctx, id) + prov, err := h.db.GetProvisionerByName(ctx, name) if err != nil { api.WriteError(w, err) return @@ -63,7 +72,6 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - fmt.Printf("provs = %+v\n", provs) api.JSON(w, provs) } @@ -76,15 +84,24 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { api.WriteError(w, err) return } - if err := body.Validate(); err != nil { + if err := body.Validate(h.auth.GetProvisionerCollection()); err != nil { api.WriteError(w, err) + return } + details, err := mgmt.UnmarshalProvisionerDetails(body.Details) + if err != nil { + api.WriteError(w, mgmt.WrapErrorISE(err, "error unmarshaling provisioner details")) + return + } + + claims := mgmt.NewDefaultClaims() + prov := &mgmt.Provisioner{ Type: body.Type, Name: body.Name, - Claims: body.Claims, - Details: body.Details, + Claims: claims, + Details: details, X509Template: body.X509Template, X509TemplateData: body.X509TemplateData, SSHTemplate: body.SSHTemplate, @@ -95,6 +112,58 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { return } api.JSONStatus(w, prov, http.StatusCreated) + + if err := h.auth.ReloadAuthConfig(); err != nil { + fmt.Printf("err = %+v\n", err) + } +} + +// DeleteProvisioner deletes a provisioner. +func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "name") + + c := h.auth.GetAdminCollection() + if c.Count() == c.CountByProvisioner(name) { + api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, + "cannot remove provisioner %s because no admins will remain", name)) + return + } + + ctx := r.Context() + prov, err := h.db.GetProvisionerByName(ctx, name) + if err != nil { + api.WriteError(w, mgmt.WrapErrorISE(err, "error retrieiving provisioner %s", name)) + return + } + fmt.Printf("prov = %+v\n", prov) + prov.Status = mgmt.StatusDeleted + if err := h.db.UpdateProvisioner(ctx, name, prov); err != nil { + api.WriteError(w, mgmt.WrapErrorISE(err, "error updating provisioner %s", name)) + return + } + + // Delete all admins associated with the provisioner. + admins, ok := c.LoadByProvisioner(name) + if ok { + for _, adm := range admins { + if err := h.db.UpdateAdmin(ctx, &mgmt.Admin{ + ID: adm.ID, + ProvisionerID: adm.ProvisionerID, + Subject: adm.Subject, + Type: mgmt.AdminType(adm.Type), + Status: mgmt.StatusDeleted, + }); err != nil { + api.WriteError(w, mgmt.WrapErrorISE(err, "error deleting admin %s, as part of provisioner %s deletion", adm.Subject, name)) + return + } + } + } + + api.JSON(w, &DeleteResponse{Status: "ok"}) + + if err := h.auth.ReloadAuthConfig(); err != nil { + fmt.Printf("err = %+v\n", err) + } } // UpdateProvisioner updates an existing prov. diff --git a/authority/mgmt/authConfig.go b/authority/mgmt/authConfig.go index 734bca50..14448af4 100644 --- a/authority/mgmt/authConfig.go +++ b/authority/mgmt/authConfig.go @@ -1,6 +1,7 @@ package mgmt import ( + "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" ) @@ -9,7 +10,7 @@ import ( type AuthConfig struct { //*cas.Options `json:"cas"` ID string `json:"id"` - ASN1DN *config.ASN1DN `json:"template,omitempty"` + ASN1DN *config.ASN1DN `json:"asn1dn,omitempty"` Provisioners []*Provisioner `json:"-"` Admins []*Admin `json:"-"` Claims *Claims `json:"claims,omitempty"` @@ -46,9 +47,18 @@ func (ac *AuthConfig) ToCertificates() (*config.AuthConfig, error) { } provs = append(provs, authProv) } + var admins []*admin.Admin + for _, adm := range ac.Admins { + authAdmin, err := adm.ToCertificates() + if err != nil { + return nil, err + } + admins = append(admins, authAdmin) + } return &config.AuthConfig{ AuthorityID: ac.ID, Provisioners: provs, + Admins: admins, Template: ac.ASN1DN, Claims: claims, DisableIssuedAtCheck: false, diff --git a/authority/mgmt/config.go b/authority/mgmt/config.go index b3ece47f..d569ad6f 100644 --- a/authority/mgmt/config.go +++ b/authority/mgmt/config.go @@ -2,7 +2,6 @@ package mgmt import ( "context" - "fmt" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/config" @@ -16,26 +15,15 @@ const ( ) // StatusType is the type for status. -type StatusType int +type StatusType string -const ( +var ( // StatusActive active - StatusActive StatusType = iota + StatusActive = StatusType("active") // StatusDeleted deleted - StatusDeleted + StatusDeleted = StatusType("deleted") ) -func (st StatusType) String() string { - switch st { - case StatusActive: - return "active" - case StatusDeleted: - return "deleted" - default: - return fmt.Sprintf("status %d not found", st) - } -} - // Claims encapsulates all x509 and ssh claims applied to the authority // configuration. E.g. maxTLSCertDuration, defaultSSHCertDuration, etc. type Claims struct { @@ -123,14 +111,18 @@ func CreateAuthority(ctx context.Context, db DB, options ...AuthorityOption) (*A return nil, WrapErrorISE(err, "error creating first provisioner") } - admin, err := CreateAdmin(ctx, db, "Change Me", prov.ID, true) - if err != nil { + adm := &Admin{ + ProvisionerID: prov.ID, + Subject: "Change Me", + Type: AdminTypeSuper, + } + if err := db.CreateAdmin(ctx, adm); err != nil { // TODO should we try to clean up? - return nil, WrapErrorISE(err, "error creating first provisioner") + return nil, WrapErrorISE(err, "error creating first admin") } ac.Provisioners = []*Provisioner{prov} - ac.Admins = []*Admin{admin} + ac.Admins = []*Admin{adm} return ac, nil } diff --git a/authority/mgmt/db.go b/authority/mgmt/db.go index 9cfab3e5..8546c4b4 100644 --- a/authority/mgmt/db.go +++ b/authority/mgmt/db.go @@ -14,8 +14,9 @@ var ErrNotFound = errors.New("not found") type DB interface { CreateProvisioner(ctx context.Context, prov *Provisioner) error GetProvisioner(ctx context.Context, id string) (*Provisioner, error) + GetProvisionerByName(ctx context.Context, name string) (*Provisioner, error) GetProvisioners(ctx context.Context) ([]*Provisioner, error) - UpdateProvisioner(ctx context.Context, prov *Provisioner) error + UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error CreateAdmin(ctx context.Context, admin *Admin) error GetAdmin(ctx context.Context, id string) (*Admin, error) @@ -30,10 +31,11 @@ type DB interface { // MockDB is an implementation of the DB interface that should only be used as // a mock in tests. type MockDB struct { - MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error - MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error) - MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error) - MockUpdateProvisioner func(ctx context.Context, prov *Provisioner) error + MockCreateProvisioner func(ctx context.Context, prov *Provisioner) error + MockGetProvisioner func(ctx context.Context, id string) (*Provisioner, error) + MockGetProvisionerByName func(ctx context.Context, name string) (*Provisioner, error) + MockGetProvisioners func(ctx context.Context) ([]*Provisioner, error) + MockUpdateProvisioner func(ctx context.Context, name string, prov *Provisioner) error MockCreateAdmin func(ctx context.Context, adm *Admin) error MockGetAdmin func(ctx context.Context, id string) (*Admin, error) @@ -68,6 +70,16 @@ func (m *MockDB) GetProvisioner(ctx context.Context, id string) (*Provisioner, e return m.MockRet1.(*Provisioner), m.MockError } +// GetProvisionerByName mock. +func (m *MockDB) GetProvisionerByName(ctx context.Context, id string) (*Provisioner, error) { + if m.MockGetProvisionerByName != nil { + return m.MockGetProvisionerByName(ctx, id) + } else if m.MockError != nil { + return nil, m.MockError + } + return m.MockRet1.(*Provisioner), m.MockError +} + // GetProvisioners mock func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) { if m.MockGetProvisioners != nil { @@ -79,9 +91,9 @@ func (m *MockDB) GetProvisioners(ctx context.Context) ([]*Provisioner, error) { } // UpdateProvisioner mock -func (m *MockDB) UpdateProvisioner(ctx context.Context, prov *Provisioner) error { +func (m *MockDB) UpdateProvisioner(ctx context.Context, name string, prov *Provisioner) error { if m.MockUpdateProvisioner != nil { - return m.MockUpdateProvisioner(ctx, prov) + return m.MockUpdateProvisioner(ctx, name, prov) } return m.MockError } diff --git a/authority/mgmt/db/nosql/admin.go b/authority/mgmt/db/nosql/admin.go index 70cb12d1..4e465489 100644 --- a/authority/mgmt/db/nosql/admin.go +++ b/authority/mgmt/db/nosql/admin.go @@ -12,13 +12,13 @@ import ( // dbAdmin is the database representation of the Admin type. type dbAdmin struct { - ID string `json:"id"` - AuthorityID string `json:"authorityID"` - ProvisionerID string `json:"provisionerID"` - Name string `json:"name"` - IsSuperAdmin bool `json:"isSuperAdmin"` - CreatedAt time.Time `json:"createdAt"` - DeletedAt time.Time `json:"deletedAt"` + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + ProvisionerID string `json:"provisionerID"` + Subject string `json:"subject"` + Type mgmt.AdminType `json:"type"` + CreatedAt time.Time `json:"createdAt"` + DeletedAt time.Time `json:"deletedAt"` } func (dbp *dbAdmin) clone() *dbAdmin { @@ -52,27 +52,6 @@ func (db *DB) getDBAdmin(ctx context.Context, id string) (*dbAdmin, error) { return dba, nil } -// GetAdmin retrieves and unmarshals a admin from the database. -func (db *DB) GetAdmin(ctx context.Context, id string) (*mgmt.Admin, error) { - data, err := db.getDBAdminBytes(ctx, id) - if err != nil { - return nil, err - } - adm, err := unmarshalAdmin(data, id) - if err != nil { - return nil, err - } - if adm.Status == mgmt.StatusDeleted { - return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID) - } - if adm.AuthorityID != db.authorityID { - return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType, - "admin %s is not owned by authority %s", adm.ID, db.authorityID) - } - - return adm, nil -} - func unmarshalDBAdmin(data []byte, id string) (*dbAdmin, error) { var dba = new(dbAdmin) if err := json.Unmarshal(data, dba); err != nil { @@ -90,8 +69,9 @@ func unmarshalAdmin(data []byte, id string) (*mgmt.Admin, error) { ID: dba.ID, AuthorityID: dba.AuthorityID, ProvisionerID: dba.ProvisionerID, - Name: dba.Name, - IsSuperAdmin: dba.IsSuperAdmin, + Subject: dba.Subject, + Type: dba.Type, + Status: mgmt.StatusActive, } if !dba.DeletedAt.IsZero() { adm.Status = mgmt.StatusDeleted @@ -99,6 +79,33 @@ func unmarshalAdmin(data []byte, id string) (*mgmt.Admin, error) { return adm, nil } +// GetAdmin retrieves and unmarshals a admin from the database. +func (db *DB) GetAdmin(ctx context.Context, id string) (*mgmt.Admin, error) { + data, err := db.getDBAdminBytes(ctx, id) + if err != nil { + return nil, err + } + adm, err := unmarshalAdmin(data, id) + if err != nil { + return nil, err + } + if adm.Status == mgmt.StatusDeleted { + return nil, mgmt.NewError(mgmt.ErrorDeletedType, "admin %s is deleted", adm.ID) + } + if adm.AuthorityID != db.authorityID { + return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType, + "admin %s is not owned by authority %s", adm.ID, db.authorityID) + } + prov, err := db.GetProvisioner(ctx, adm.ProvisionerID) + if err != nil { + return nil, err + } + adm.ProvisionerName = prov.Name + adm.ProvisionerType = prov.Type + + return adm, nil +} + // GetAdmins retrieves and unmarshals all active (not deleted) admins // from the database. // TODO should we be paginating? @@ -107,7 +114,10 @@ func (db *DB) GetAdmins(ctx context.Context) ([]*mgmt.Admin, error) { if err != nil { return nil, errors.Wrap(err, "error loading admins") } - var admins []*mgmt.Admin + var ( + provCache = map[string]*mgmt.Provisioner{} + admins []*mgmt.Admin + ) for _, entry := range dbEntries { adm, err := unmarshalAdmin(entry.Value, string(entry.Key)) if err != nil { @@ -119,6 +129,19 @@ func (db *DB) GetAdmins(ctx context.Context) ([]*mgmt.Admin, error) { if adm.AuthorityID != db.authorityID { continue } + var ( + prov *mgmt.Provisioner + ok bool + ) + if prov, ok = provCache[adm.ProvisionerID]; !ok { + prov, err = db.GetProvisioner(ctx, adm.ProvisionerID) + if err != nil { + return nil, err + } + provCache[adm.ProvisionerID] = prov + } + adm.ProvisionerName = prov.Name + adm.ProvisionerType = prov.Type admins = append(admins, adm) } return admins, nil @@ -129,16 +152,34 @@ func (db *DB) CreateAdmin(ctx context.Context, adm *mgmt.Admin) error { var err error adm.ID, err = randID() if err != nil { - return errors.Wrap(err, "error generating random id for admin") + return mgmt.WrapErrorISE(err, "error generating random id for admin") } adm.AuthorityID = db.authorityID + // If provisionerID is set, then use it, otherwise load the provisioner + // to get the name. + if adm.ProvisionerID == "" { + prov, err := db.GetProvisionerByName(ctx, adm.ProvisionerName) + if err != nil { + return err + } + adm.ProvisionerID = prov.ID + adm.ProvisionerType = prov.Type + } else { + prov, err := db.GetProvisioner(ctx, adm.ProvisionerID) + if err != nil { + return err + } + adm.ProvisionerName = prov.Name + adm.ProvisionerType = prov.Type + } + dba := &dbAdmin{ ID: adm.ID, AuthorityID: db.authorityID, ProvisionerID: adm.ProvisionerID, - Name: adm.Name, - IsSuperAdmin: adm.IsSuperAdmin, + Subject: adm.Subject, + Type: adm.Type, CreatedAt: clock.Now(), } @@ -158,8 +199,7 @@ func (db *DB) UpdateAdmin(ctx context.Context, adm *mgmt.Admin) error { if old.DeletedAt.IsZero() && adm.Status == mgmt.StatusDeleted { nu.DeletedAt = clock.Now() } - nu.ProvisionerID = adm.ProvisionerID - nu.IsSuperAdmin = adm.IsSuperAdmin + nu.Type = adm.Type return db.save(ctx, old.ID, nu, old, "admin", authorityAdminsTable) } diff --git a/authority/mgmt/db/nosql/authConfig.go b/authority/mgmt/db/nosql/authConfig.go index fe189ce3..b9108fdc 100644 --- a/authority/mgmt/db/nosql/authConfig.go +++ b/authority/mgmt/db/nosql/authConfig.go @@ -60,9 +60,14 @@ func (db *DB) GetAuthConfig(ctx context.Context, id string) (*mgmt.AuthConfig, e if err != nil { return nil, err } + admins, err := db.GetAdmins(ctx) + if err != nil { + return nil, err + } return &mgmt.AuthConfig{ ID: dba.ID, + Admins: admins, Provisioners: provs, ASN1DN: dba.ASN1DN, Backdate: dba.Backdate, diff --git a/authority/mgmt/db/nosql/nosql.go b/authority/mgmt/db/nosql/nosql.go index d71b2804..d97b7700 100644 --- a/authority/mgmt/db/nosql/nosql.go +++ b/authority/mgmt/db/nosql/nosql.go @@ -3,6 +3,7 @@ package nosql import ( "context" "encoding/json" + "fmt" "time" "github.com/pkg/errors" @@ -11,9 +12,10 @@ import ( ) var ( - authorityAdminsTable = []byte("authority_admins") - authorityConfigsTable = []byte("authority_configs") - authorityProvisionersTable = []byte("authority_provisioners") + authorityAdminsTable = []byte("authority_admins") + authorityConfigsTable = []byte("authority_configs") + authorityProvisionersTable = []byte("authority_provisioners") + authorityProvisionersNameIDIndexTable = []byte("authority_provisioners_name_id_index") ) // DB is a struct that implements the AcmeDB interface. @@ -24,7 +26,7 @@ type DB struct { // New configures and returns a new Authority DB backend implemented using a nosql DB. func New(db nosqlDB.DB, authorityID string) (*DB, error) { - tables := [][]byte{authorityAdminsTable, authorityConfigsTable, authorityProvisionersTable} + tables := [][]byte{authorityAdminsTable, authorityConfigsTable, authorityProvisionersTable, authorityProvisionersNameIDIndexTable} for _, b := range tables { if err := db.CreateTable(b); err != nil { return nil, errors.Wrapf(err, "error creating table %s", @@ -58,6 +60,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, old) } } + fmt.Printf("oldB = %+v\n", oldB) _, swapped, err := db.db.CmpAndSwap(table, []byte(id), oldB, newB) switch { @@ -73,7 +76,7 @@ func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface var idLen = 32 func randID() (val string, err error) { - val, err = randutil.Alphanumeric(idLen) + val, err = randutil.UUIDv4() if err != nil { return "", errors.Wrap(err, "error generating random alphanumeric ID") } diff --git a/authority/mgmt/db/nosql/provisioner.go b/authority/mgmt/db/nosql/provisioner.go index 6d9f74ab..d44c3be1 100644 --- a/authority/mgmt/db/nosql/provisioner.go +++ b/authority/mgmt/db/nosql/provisioner.go @@ -9,13 +9,15 @@ import ( "github.com/pkg/errors" "github.com/smallstep/certificates/authority/mgmt" "github.com/smallstep/nosql" + "github.com/smallstep/nosql/database" ) // dbProvisioner is the database representation of a Provisioner type. type dbProvisioner struct { - ID string `json:"id"` - AuthorityID string `json:"authorityID"` - Type string `json:"type"` + ID string `json:"id"` + AuthorityID string `json:"authorityID"` + Type string `json:"type"` + // Name is the key Name string `json:"name"` Claims *mgmt.Claims `json:"claims"` Details []byte `json:"details"` @@ -27,11 +29,30 @@ type dbProvisioner struct { DeletedAt time.Time `json:"deletedAt"` } +type provisionerNameID struct { + Name string `json:"name"` + ID string `json:"id"` +} + func (dbp *dbProvisioner) clone() *dbProvisioner { u := *dbp return &u } +func (db *DB) getProvisionerIDByName(ctx context.Context, name string) (string, error) { + data, err := db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name)) + if nosql.IsErrNotFound(err) { + return "", mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name) + } else if err != nil { + return "", mgmt.WrapErrorISE(err, "error loading provisioner %s", name) + } + ni := new(provisionerNameID) + if err := json.Unmarshal(data, ni); err != nil { + return "", mgmt.WrapErrorISE(err, "error unmarshaling provisionerNameID for provisioner %s", name) + } + return ni.ID, nil +} + func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) { data, err := db.db.Get(authorityProvisionersTable, []byte(id)) if nosql.IsErrNotFound(err) { @@ -51,6 +72,9 @@ func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, if err != nil { return nil, err } + if !dbp.DeletedAt.IsZero() { + return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", id) + } if dbp.AuthorityID != db.authorityID { return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType, "provisioner %s is not owned by authority %s", dbp.ID, db.authorityID) @@ -58,6 +82,25 @@ func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, return dbp, nil } +func (db *DB) getDBProvisionerByName(ctx context.Context, name string) (*dbProvisioner, error) { + id, err := db.getProvisionerIDByName(ctx, name) + if err != nil { + return nil, err + } + dbp, err := db.getDBProvisioner(ctx, id) + if err != nil { + return nil, err + } + if !dbp.DeletedAt.IsZero() { + return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", name) + } + if dbp.AuthorityID != db.authorityID { + return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType, + "provisioner %s is not owned by authority %s", name, db.authorityID) + } + return dbp, nil +} + // GetProvisioner retrieves and unmarshals a provisioner from the database. func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) { data, err := db.getDBProvisionerBytes(ctx, id) @@ -79,29 +122,30 @@ func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, return prov, nil } -func unmarshalDBProvisioner(data []byte, id string) (*dbProvisioner, error) { +// GetProvisionerByName retrieves a provisioner from the database by name. +func (db *DB) GetProvisionerByName(ctx context.Context, name string) (*mgmt.Provisioner, error) { + p, err := db.getProvisionerIDByName(ctx, name) + if err != nil { + return nil, err + } + return db.GetProvisioner(ctx, id) +} + +func unmarshalDBProvisioner(data []byte, name string) (*dbProvisioner, error) { var dbp = new(dbProvisioner) if err := json.Unmarshal(data, dbp); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", id) + return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", name) } return dbp, nil } -type detailsType struct { - Type mgmt.ProvisionerType -} - -func unmarshalProvisioner(data []byte, id string) (*mgmt.Provisioner, error) { - dbp, err := unmarshalDBProvisioner(data, id) +func unmarshalProvisioner(data []byte, name string) (*mgmt.Provisioner, error) { + dbp, err := unmarshalDBProvisioner(data, name) if err != nil { return nil, err } - dt := new(detailsType) - if err := json.Unmarshal(dbp.Details, dt); err != nil { - return nil, mgmt.WrapErrorISE(err, "error unmarshaling details to detailsType for provisioner %s", id) - } - details, err := unmarshalDetails(dt.Type, dbp.Details) + details, err := mgmt.UnmarshalProvisionerDetails(dbp.Details) if err != nil { return nil, err } @@ -113,6 +157,7 @@ func unmarshalProvisioner(data []byte, id string) (*mgmt.Provisioner, error) { Name: dbp.Name, Claims: dbp.Claims, Details: details, + Status: mgmt.StatusActive, X509Template: dbp.X509Template, X509TemplateData: dbp.X509TemplateData, SSHTemplate: dbp.SSHTemplate, @@ -129,7 +174,7 @@ func unmarshalProvisioner(data []byte, id string) (*mgmt.Provisioner, error) { func (db *DB) GetProvisioners(ctx context.Context) ([]*mgmt.Provisioner, error) { dbEntries, err := db.db.List(authorityProvisionersTable) if err != nil { - return nil, errors.Wrap(err, "error loading provisioners") + return nil, mgmt.WrapErrorISE(err, "error loading provisioners") } var provs []*mgmt.Provisioner for _, entry := range dbEntries { @@ -158,76 +203,169 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) err details, err := json.Marshal(prov.Details) if err != nil { - return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner") + return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name) } dbp := &dbProvisioner{ - ID: prov.ID, - AuthorityID: db.authorityID, - Type: prov.Type, - Name: prov.Name, - Claims: prov.Claims, - Details: details, - X509Template: prov.X509Template, - SSHTemplate: prov.SSHTemplate, - CreatedAt: clock.Now(), + ID: prov.ID, + AuthorityID: db.authorityID, + Type: prov.Type, + Name: prov.Name, + Claims: prov.Claims, + Details: details, + X509Template: prov.X509Template, + X509TemplateData: prov.X509TemplateData, + SSHTemplate: prov.SSHTemplate, + SSHTemplateData: prov.SSHTemplateData, + CreatedAt: clock.Now(), + } + dbpBytes, err := json.Marshal(dbp) + if err != nil { + return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name) + } + pni := &provisionerNameID{ + Name: prov.Name, + ID: prov.ID, + } + pniBytes, err := json.Marshal(pni) + if err != nil { + return mgmt.WrapErrorISE(err, "error marshaling provisionerNameIndex %s", prov.Name) } - return db.save(ctx, dbp.ID, dbp, nil, "provisioner", authorityProvisionersTable) + if err := db.db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: authorityProvisionersTable, + Key: []byte(dbp.ID), + Cmd: database.CmpAndSwap, + Value: dbpBytes, + CmpValue: nil, + }, + { + Bucket: authorityProvisionersNameIDIndexTable, + Key: []byte(dbp.Name), + Cmd: database.CmpAndSwap, + Value: pniBytes, + CmpValue: nil, + }, + }, + }); err != nil { + return mgmt.WrapErrorISE(err, "error creating provisioner %s", prov.Name) + } + + return nil } // UpdateProvisioner saves an updated provisioner to the database. -func (db *DB) UpdateProvisioner(ctx context.Context, prov *mgmt.Provisioner) error { - old, err := db.getDBProvisioner(ctx, prov.ID) +func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Provisioner) error { + id, err := db.getProvisionerIDByName(ctx, name) if err != nil { return err } + prov.ID = id + oldBytes, err := db.getDBProvisionerBytes(ctx, id) + if err != nil { + return err + } + fmt.Printf("oldBytes = %+v\n", oldBytes) + old, err := unmarshalDBProvisioner(oldBytes, id) + if err != nil { + return err + } + fmt.Printf("old = %+v\n", old) nu := old.clone() + nu.Type = prov.Type + nu.Name = prov.Name + nu.Claims = prov.Claims + nu.Details, err = json.Marshal(prov.Details) + if err != nil { + return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", name) + } + nu.X509Template = prov.X509Template + nu.X509TemplateData = prov.X509TemplateData + nu.SSHTemplateData = prov.SSHTemplateData + + var txs = []*database.TxEntry{} // If the provisioner was active but is now deleted ... if old.DeletedAt.IsZero() && prov.Status == mgmt.StatusDeleted { nu.DeletedAt = clock.Now() + txs = append(txs, &database.TxEntry{ + Bucket: authorityProvisionersNameIDIndexTable, + Key: []byte(name), + Cmd: database.Delete, + }) } - nu.Claims = prov.Claims - nu.X509Template = prov.X509Template - nu.SSHTemplate = prov.SSHTemplate - nu.Details, err = json.Marshal(prov.Details) + if prov.Name != name { + // If the new name does not match the old name then: + // 1) check that the new name is not already taken + // 2) delete the old name-id index resource + // 3) create a new name-id index resource + // 4) update the provisioner resource + nuBytes, err := json.Marshal(nu) + if err != nil { + return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name) + } + pni := &provisionerNameID{ + Name: prov.Name, + ID: prov.ID, + } + pniBytes, err := json.Marshal(pni) + if err != nil { + return mgmt.WrapErrorISE(err, "error marshaling provisionerNameID for provisioner %s", prov.Name) + } + + _, err = db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name)) + if err == nil { + return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", prov.Name) + } else if !nosql.IsErrNotFound(err) { + return mgmt.WrapErrorISE(err, "error loading provisionerNameID %s", prov.Name) + } + err = db.db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: authorityProvisionersNameIDIndexTable, + Key: []byte(name), + Cmd: database.Delete, + }, + { + Bucket: authorityProvisionersNameIDIndexTable, + Key: []byte(prov.Name), + Cmd: database.CmpAndSwap, + Value: pniBytes, + CmpValue: nil, + }, + { + Bucket: authorityProvisionersTable, + Key: []byte(nu.ID), + Cmd: database.CmpAndSwap, + Value: nuBytes, + CmpValue: oldBytes, + }, + }, + }) + } else { + err = db.db.Update(&database.Tx{ + Operations: []*database.TxEntry{ + { + Bucket: authorityProvisionersNameIDIndexTable, + Key: []byte(name), + Cmd: database.Delete, + }, + { + Bucket: authorityProvisionersTable, + Key: []byte(nu.ID), + Cmd: database.CmpAndSwap, + Value: nuBytes, + CmpValue: oldBytes, + }, + }, + }) + } if err != nil { - return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner") + return mgmt.WrapErrorISE(err, "error updating provisioner %s", prov.Name) } - - return db.save(ctx, old.ID, nu, old, "provisioner", authorityProvisionersTable) -} - -func unmarshalDetails(typ mgmt.ProvisionerType, data []byte) (mgmt.ProvisionerDetails, error) { - var v mgmt.ProvisionerDetails - switch typ { - case mgmt.ProvisionerTypeJWK: - v = new(mgmt.ProvisionerDetailsJWK) - case mgmt.ProvisionerTypeOIDC: - v = new(mgmt.ProvisionerDetailsOIDC) - case mgmt.ProvisionerTypeGCP: - v = new(mgmt.ProvisionerDetailsGCP) - case mgmt.ProvisionerTypeAWS: - v = new(mgmt.ProvisionerDetailsAWS) - case mgmt.ProvisionerTypeAZURE: - v = new(mgmt.ProvisionerDetailsAzure) - case mgmt.ProvisionerTypeACME: - v = new(mgmt.ProvisionerDetailsACME) - case mgmt.ProvisionerTypeX5C: - v = new(mgmt.ProvisionerDetailsX5C) - case mgmt.ProvisionerTypeK8SSA: - v = new(mgmt.ProvisionerDetailsK8SSA) - case mgmt.ProvisionerTypeSSHPOP: - v = new(mgmt.ProvisionerDetailsSSHPOP) - default: - return nil, fmt.Errorf("unsupported provisioner type %s", typ) - } - - if err := json.Unmarshal(data, v); err != nil { - return nil, err - } - return v, nil + return nil } diff --git a/authority/mgmt/errors.go b/authority/mgmt/errors.go index f0f90400..63c0652e 100644 --- a/authority/mgmt/errors.go +++ b/authority/mgmt/errors.go @@ -88,12 +88,11 @@ var ( // Error represents an ACME type Error struct { - Type string `json:"type"` - Detail string `json:"detail"` - Subproblems []interface{} `json:"subproblems,omitempty"` - Identifier interface{} `json:"identifier,omitempty"` - Err error `json:"-"` - Status int `json:"-"` + Type string `json:"type"` + Detail string `json:"detail"` + Message string `json:"message"` + Err error `json:"-"` + Status int `json:"-"` } // IsType returns true if the error type matches the input type. @@ -160,7 +159,7 @@ func (e *Error) StatusCode() int { // Error allows AError to implement the error interface. func (e *Error) Error() string { - return e.Detail + return e.Err.Error() } // Cause returns the internal error and implements the Causer interface. @@ -182,9 +181,10 @@ func (e *Error) ToLog() (interface{}, error) { // WriteError writes to w a JSON representation of the given error. func WriteError(w http.ResponseWriter, err *Error) { - w.Header().Set("Content-Type", "application/problem+json") + w.Header().Set("Content-Type", "application/json") w.WriteHeader(err.StatusCode()) + err.Message = err.Err.Error() // Write errors in the response writer if rl, ok := w.(logging.ResponseLogger); ok { rl.WithFields(map[string]interface{}{ @@ -199,6 +199,7 @@ func WriteError(w http.ResponseWriter, err *Error) { } } + fmt.Printf("err = %+v\n", err) if err := json.NewEncoder(w).Encode(err); err != nil { log.Println(err) } diff --git a/authority/mgmt/provisioner.go b/authority/mgmt/provisioner.go index 961907f8..fb39f6bc 100644 --- a/authority/mgmt/provisioner.go +++ b/authority/mgmt/provisioner.go @@ -59,8 +59,8 @@ func WithPassword(pass string) func(*ProvisionerCtx) { // Provisioner type. type Provisioner struct { - ID string `json:"id"` - AuthorityID string `json:"authorityID"` + ID string `json:"-"` + AuthorityID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Claims *Claims `json:"claims"` @@ -87,8 +87,7 @@ func (p *Provisioner) GetOptions() *provisioner.Options { func CreateProvisioner(ctx context.Context, db DB, typ, name string, opts ...ProvisionerOption) (*Provisioner, error) { pc := NewProvisionerCtx(opts...) - - details, err := createJWKDetails(pc) + details, err := NewProvisionerDetails(ProvisionerType(typ), pc) if err != nil { return nil, err } @@ -180,6 +179,27 @@ func (*ProvisionerDetailsK8SSA) isProvisionerDetails() {} func (*ProvisionerDetailsSSHPOP) isProvisionerDetails() {} +func NewProvisionerDetails(typ ProvisionerType, pc *ProvisionerCtx) (ProvisionerDetails, error) { + switch typ { + case ProvisionerTypeJWK: + return createJWKDetails(pc) + /* + case ProvisionerTypeOIDC: + return createOIDCDetails(pc) + case ProvisionerTypeACME: + return createACMEDetails(pc) + case ProvisionerTypeK8SSA: + return createK8SSADetails(pc) + case ProvisionerTypeSSHPOP: + return createSSHPOPDetails(pc) + case ProvisionerTypeX5C: + return createSSHPOPDetails(pc) + */ + default: + return nil, NewErrorISE("unsupported provisioner type %s", typ) + } +} + func createJWKDetails(pc *ProvisionerCtx) (*ProvisionerDetailsJWK, error) { var err error @@ -231,6 +251,7 @@ func (p *Provisioner) ToCertificates() (provisioner.Interface, error) { return nil, err } return &provisioner.JWK{ + ID: p.ID, Type: p.Type, Name: p.Name, Key: jwk, @@ -386,3 +407,43 @@ func (c *Claims) ToCertificates() (*provisioner.Claims, error) { EnableSSHCA: &c.SSH.Enabled, }, nil } + +type detailsType struct { + Type ProvisionerType +} + +func UnmarshalProvisionerDetails(data []byte) (ProvisionerDetails, error) { + dt := new(detailsType) + if err := json.Unmarshal(data, dt); err != nil { + return nil, WrapErrorISE(err, "error unmarshaling provisioner details") + } + + var v ProvisionerDetails + switch dt.Type { + case ProvisionerTypeJWK: + v = new(ProvisionerDetailsJWK) + case ProvisionerTypeOIDC: + v = new(ProvisionerDetailsOIDC) + case ProvisionerTypeGCP: + v = new(ProvisionerDetailsGCP) + case ProvisionerTypeAWS: + v = new(ProvisionerDetailsAWS) + case ProvisionerTypeAZURE: + v = new(ProvisionerDetailsAzure) + case ProvisionerTypeACME: + v = new(ProvisionerDetailsACME) + case ProvisionerTypeX5C: + v = new(ProvisionerDetailsX5C) + case ProvisionerTypeK8SSA: + v = new(ProvisionerDetailsK8SSA) + case ProvisionerTypeSSHPOP: + v = new(ProvisionerDetailsSSHPOP) + default: + return nil, fmt.Errorf("unsupported provisioner type %s", dt.Type) + } + + if err := json.Unmarshal(data, v); err != nil { + return nil, err + } + return v, nil +} diff --git a/authority/provisioner/collection.go b/authority/provisioner/collection.go index 13b7be4d..b6ff8b9e 100644 --- a/authority/provisioner/collection.go +++ b/authority/provisioner/collection.go @@ -45,6 +45,7 @@ type loadByTokenPayload struct { type Collection struct { byID *sync.Map byKey *sync.Map + byName *sync.Map sorted provisionerSlice audiences Audiences } @@ -55,6 +56,7 @@ func NewCollection(audiences Audiences) *Collection { return &Collection{ byID: new(sync.Map), byKey: new(sync.Map), + byName: new(sync.Map), audiences: audiences, } } @@ -64,6 +66,11 @@ func (c *Collection) Load(id string) (Interface, bool) { return loadProvisioner(c.byID, id) } +// LoadByName a provisioner by name. +func (c *Collection) LoadByName(name string) (Interface, bool) { + return loadProvisioner(c.byName, name) +} + // LoadByToken parses the token claims and loads the provisioner associated. func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims) (Interface, bool) { var audiences []string @@ -173,6 +180,11 @@ func (c *Collection) Store(p Interface) error { if _, loaded := c.byID.LoadOrStore(p.GetID(), p); loaded { return errors.New("cannot add multiple provisioners with the same id") } + // Store provisioner always by name. + if _, loaded := c.byName.LoadOrStore(p.GetName(), p); loaded { + c.byID.Delete(p.GetID()) + return errors.New("cannot add multiple provisioners with the same id") + } // Store provisioner in byKey if EncryptedKey is defined. if kid, _, ok := p.GetEncryptedKey(); ok { diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index d6a97e2b..a2d3e0b1 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -28,6 +28,7 @@ type stepPayload struct { // signature requests. type JWK struct { *base + ID string `json:"-"` Type string `json:"type"` Name string `json:"name"` Key *jose.JSONWebKey `json:"key"` @@ -41,6 +42,9 @@ type JWK struct { // GetID returns the provisioner unique identifier. The name and credential id // should uniquely identify any JWK provisioner. func (p *JWK) GetID() string { + if p.ID != "" { + return p.ID + } return p.Name + ":" + p.Key.KeyID } diff --git a/ca/ca.go b/ca/ca.go index ec10f74c..ecbc55e2 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -172,10 +172,9 @@ func (ca *CA) Init(config *config.Config) (*CA, error) { }) // MGMT Router - mgmtDB := auth.GetMgmtDatabase() if mgmtDB != nil { - mgmtHandler := mgmtAPI.NewHandler(mgmtDB) + mgmtHandler := mgmtAPI.NewHandler(mgmtDB, auth) mux.Route("/mgmt", func(r chi.Router) { mgmtHandler.Route(r) }) diff --git a/ca/mgmtClient.go b/ca/mgmtClient.go index 47316ad6..96c33f03 100644 --- a/ca/mgmtClient.go +++ b/ca/mgmtClient.go @@ -3,6 +3,7 @@ package ca import ( "bytes" "encoding/json" + "io" "net/http" "net/url" "path" @@ -78,7 +79,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readMgmtError(resp.Body) } var adm = new(mgmt.Admin) if err := readJSON(resp.Body, adm); err != nil { @@ -105,7 +106,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readMgmtError(resp.Body) } var adm = new(mgmt.Admin) if err := readJSON(resp.Body, adm); err != nil { @@ -132,7 +133,7 @@ retry: retried = true goto retry } - return readError(resp.Body) + return readMgmtError(resp.Body) } return nil } @@ -145,7 +146,7 @@ func (c *MgmtClient) UpdateAdmin(id string, uar *mgmtAPI.UpdateAdminRequest) (*m return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/admin", id)}) - req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(body)) + req, err := http.NewRequest("PATCH", u.String(), bytes.NewReader(body)) if err != nil { return nil, errors.Wrapf(err, "create PUT %s request failed", u) } @@ -159,7 +160,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readMgmtError(resp.Body) } var adm = new(mgmt.Admin) if err := readJSON(resp.Body, adm); err != nil { @@ -182,7 +183,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readMgmtError(resp.Body) } var admins = new([]*mgmt.Admin) if err := readJSON(resp.Body, admins); err != nil { @@ -191,6 +192,29 @@ retry: return *admins, nil } +// GetProvisioner performs the GET /mgmt/provisioner/{id} request to the CA. +func (c *MgmtClient) GetProvisioner(id string) (*mgmt.Provisioner, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/provisioner", id)}) +retry: + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readMgmtError(resp.Body) + } + var prov = new(mgmt.Provisioner) + if err := readJSON(resp.Body, prov); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return prov, nil +} + // GetProvisioners performs the GET /mgmt/provisioners request to the CA. func (c *MgmtClient) GetProvisioners() ([]*mgmt.Provisioner, error) { var retried bool @@ -205,7 +229,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readMgmtError(resp.Body) } var provs = new([]*mgmt.Provisioner) if err := readJSON(resp.Body, provs); err != nil { @@ -213,3 +237,116 @@ retry: } return *provs, nil } + +// RemoveProvisioner performs the DELETE /mgmt/provisioner/{name} request to the CA. +func (c *MgmtClient) RemoveProvisioner(name string) error { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/provisioner", name)}) + req, err := http.NewRequest("DELETE", u.String(), nil) + if err != nil { + return errors.Wrapf(err, "create DELETE %s request failed", u) + } +retry: + resp, err := c.client.Do(req) + if err != nil { + return errors.Wrapf(err, "client DELETE %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return readMgmtError(resp.Body) + } + return nil +} + +// CreateProvisioner performs the POST /mgmt/provisioner request to the CA. +func (c *MgmtClient) CreateProvisioner(req *mgmtAPI.CreateProvisionerRequest) (*mgmt.Provisioner, error) { + var retried bool + body, err := json.Marshal(req) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: "/mgmt/provisioner"}) +retry: + resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readMgmtError(resp.Body) + } + var prov = new(mgmt.Provisioner) + if err := readJSON(resp.Body, prov); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return prov, nil +} + +// UpdateProvisioner performs the PUT /mgmt/provisioner/{id} request to the CA. +func (c *MgmtClient) UpdateProvisioner(id string, upr *mgmtAPI.UpdateProvisionerRequest) (*mgmt.Provisioner, error) { + var retried bool + body, err := json.Marshal(upr) + if err != nil { + return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") + } + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/provisioner", id)}) + req, err := http.NewRequest("PUT", u.String(), bytes.NewReader(body)) + if err != nil { + return nil, errors.Wrapf(err, "create PUT %s request failed", u) + } +retry: + resp, err := c.client.Do(req) + if err != nil { + return nil, errors.Wrapf(err, "client PUT %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readMgmtError(resp.Body) + } + var prov = new(mgmt.Provisioner) + if err := readJSON(resp.Body, prov); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return prov, nil +} + +// GetAuthConfig performs the GET /mgmt/authconfig/{id} request to the CA. +func (c *MgmtClient) GetAuthConfig(id string) (*mgmt.AuthConfig, error) { + var retried bool + u := c.endpoint.ResolveReference(&url.URL{Path: path.Join("/mgmt/authconfig", id)}) +retry: + resp, err := c.client.Get(u.String()) + if err != nil { + return nil, errors.Wrapf(err, "client GET %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readMgmtError(resp.Body) + } + var ac = new(mgmt.AuthConfig) + if err := readJSON(resp.Body, ac); err != nil { + return nil, errors.Wrapf(err, "error reading %s", u) + } + return ac, nil +} + +func readMgmtError(r io.ReadCloser) error { + defer r.Close() + mgmtErr := new(mgmt.Error) + if err := json.NewDecoder(r).Decode(mgmtErr); err != nil { + return err + } + return errors.New(mgmtErr.Message) +}