This commit is contained in:
max furman 2021-05-21 13:31:41 -07:00
parent d8d5d7332b
commit 9bfb1c2e7b
5 changed files with 142 additions and 99 deletions

View file

@ -32,24 +32,24 @@ func (p adminSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
// Collection is a memory map of admins. // Collection is a memory map of admins.
type Collection struct { type Collection struct {
byID *sync.Map byID *sync.Map
bySubProv *sync.Map bySubProv *sync.Map
byProv *sync.Map byProv *sync.Map
sorted adminSlice sorted adminSlice
provisioners *provisioner.Collection provisioners *provisioner.Collection
count int superCount int
countByProvisioner map[string]int superCountByProvisioner map[string]int
} }
// NewCollection initializes a collection of provisioners. The given list of // NewCollection initializes a collection of provisioners. The given list of
// audiences are the audiences used by the JWT provisioner. // audiences are the audiences used by the JWT provisioner.
func NewCollection(provisioners *provisioner.Collection) *Collection { func NewCollection(provisioners *provisioner.Collection) *Collection {
return &Collection{ return &Collection{
byID: new(sync.Map), byID: new(sync.Map),
byProv: new(sync.Map), byProv: new(sync.Map),
bySubProv: new(sync.Map), bySubProv: new(sync.Map),
countByProvisioner: map[string]int{}, superCountByProvisioner: map[string]int{},
provisioners: provisioners, provisioners: provisioners,
} }
} }
@ -106,12 +106,12 @@ func (c *Collection) Store(adm *Admin) error {
if admins, ok := c.LoadByProvisioner(provName); ok { if admins, ok := c.LoadByProvisioner(provName); ok {
c.byProv.Store(provName, append(admins, adm)) c.byProv.Store(provName, append(admins, adm))
c.countByProvisioner[provName]++ c.superCountByProvisioner[provName]++
} else { } else {
c.byProv.Store(provName, []*Admin{adm}) c.byProv.Store(provName, []*Admin{adm})
c.countByProvisioner[provName] = 1 c.superCountByProvisioner[provName] = 1
} }
c.count++ c.superCount++
// Store sorted admins. // Store sorted admins.
// Use the first 4 bytes (32bit) of the sum to insert the order // Use the first 4 bytes (32bit) of the sum to insert the order
@ -131,14 +131,14 @@ func (c *Collection) Store(adm *Admin) error {
return nil return nil
} }
// Count returns the total number of admins. // SuperCount returns the total number of admins.
func (c *Collection) Count() int { func (c *Collection) SuperCount() int {
return c.count return c.superCount
} }
// CountByProvisioner returns the total number of admins. // SuperCountByProvisioner returns the total number of admins.
func (c *Collection) CountByProvisioner(provName string) int { func (c *Collection) SuperCountByProvisioner(provName string) int {
if cnt, ok := c.countByProvisioner[provName]; ok { if cnt, ok := c.superCountByProvisioner[provName]; ok {
return cnt return cnt
} }
return 0 return 0

View file

@ -120,8 +120,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if h.auth.GetAdminCollection().Count() == 1 { if h.auth.GetAdminCollection().SuperCount() == 1 {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, "cannot remove last admin")) api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType, "cannot remove the last super admin"))
return return
} }

View file

@ -101,34 +101,17 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
var body CreateProvisionerRequest var prov = new(mgmt.Provisioner)
if err := api.ReadJSON(r.Body, &body); err != nil { if err := api.ReadJSON(r.Body, prov); err != nil {
api.WriteError(w, err)
return
}
if err := body.Validate(h.auth.GetProvisionerCollection()); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
details, err := mgmt.UnmarshalProvisionerDetails(body.Details) // TODO: validate
if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error unmarshaling provisioner details"))
return
}
claims := mgmt.NewDefaultClaims() // TODO: fix this
prov.Claims = mgmt.NewDefaultClaims()
prov := &mgmt.Provisioner{
Type: body.Type,
Name: body.Name,
Claims: claims,
Details: details,
X509Template: body.X509Template,
X509TemplateData: body.X509TemplateData,
SSHTemplate: body.SSHTemplate,
SSHTemplateData: body.SSHTemplateData,
}
if err := h.db.CreateProvisioner(ctx, prov); err != nil { if err := h.db.CreateProvisioner(ctx, prov); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
@ -144,21 +127,20 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
c := h.auth.GetAdminCollection()
fmt.Printf("c.Count() = %+v\n", c.Count())
fmt.Printf("c.CountByProvisioner() = %+v\n", c.CountByProvisioner(name))
if c.Count() == c.CountByProvisioner(name) {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType,
"cannot remove provisioner %s because no admins will remain", name))
return
}
ctx := r.Context()
p, ok := h.auth.GetProvisionerCollection().LoadByName(name) p, ok := h.auth.GetProvisionerCollection().LoadByName(name)
if !ok { if !ok {
api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name)) api.WriteError(w, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name))
return return
} }
c := h.auth.GetAdminCollection()
if c.SuperCount() == c.SuperCountByProvisioner(name) {
api.WriteError(w, mgmt.NewError(mgmt.ErrorBadRequestType,
"cannot remove provisioner %s because no super admins will remain", name))
return
}
ctx := r.Context()
prov, err := h.db.GetProvisioner(ctx, p.GetID()) prov, err := h.db.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
api.WriteError(w, mgmt.WrapErrorISE(err, "error loading provisioner %s from db", name)) api.WriteError(w, mgmt.WrapErrorISE(err, "error loading provisioner %s from db", name))

View file

@ -12,15 +12,6 @@ import (
type ProvisionerOption func(*ProvisionerCtx) type ProvisionerOption func(*ProvisionerCtx)
type ProvisionerCtx struct {
JWK *jose.JSONWebKey
JWE *jose.JSONWebEncryption
X509Template, SSHTemplate string
X509TemplateData, SSHTemplateData []byte
Claims *Claims
Password string
}
type ProvisionerType string type ProvisionerType string
var ( var (
@ -35,29 +26,6 @@ var (
ProvisionerTypeX5C = ProvisionerType("X5C") ProvisionerTypeX5C = ProvisionerType("X5C")
) )
func NewProvisionerCtx(opts ...ProvisionerOption) *ProvisionerCtx {
pc := &ProvisionerCtx{
Claims: NewDefaultClaims(),
}
for _, o := range opts {
o(pc)
}
return pc
}
func WithJWK(jwk *jose.JSONWebKey, jwe *jose.JSONWebEncryption) func(*ProvisionerCtx) {
return func(ctx *ProvisionerCtx) {
ctx.JWK = jwk
ctx.JWE = jwe
}
}
func WithPassword(pass string) func(*ProvisionerCtx) {
return func(ctx *ProvisionerCtx) {
ctx.Password = pass
}
}
type unmarshalProvisioner struct { type unmarshalProvisioner struct {
ID string `json:"-"` ID string `json:"-"`
AuthorityID string `json:"-"` AuthorityID string `json:"-"`
@ -157,6 +125,38 @@ func CreateProvisioner(ctx context.Context, db DB, typ, name string, opts ...Pro
return p, nil return p, nil
} }
type ProvisionerCtx struct {
JWK *jose.JSONWebKey
JWE *jose.JSONWebEncryption
X509Template, SSHTemplate string
X509TemplateData, SSHTemplateData []byte
Claims *Claims
Password string
}
func NewProvisionerCtx(opts ...ProvisionerOption) *ProvisionerCtx {
pc := &ProvisionerCtx{
Claims: NewDefaultClaims(),
}
for _, o := range opts {
o(pc)
}
return pc
}
func WithJWK(jwk *jose.JSONWebKey, jwe *jose.JSONWebEncryption) func(*ProvisionerCtx) {
return func(ctx *ProvisionerCtx) {
ctx.JWK = jwk
ctx.JWE = jwe
}
}
func WithPassword(pass string) func(*ProvisionerCtx) {
return func(ctx *ProvisionerCtx) {
ctx.Password = pass
}
}
// ProvisionerDetails is the interface implemented by all provisioner details // ProvisionerDetails is the interface implemented by all provisioner details
// attributes. // attributes.
type ProvisionerDetails interface { type ProvisionerDetails interface {
@ -172,37 +172,61 @@ type ProvisionerDetailsJWK struct {
// ProvisionerDetailsOIDC represents the values required by a OIDC provisioner. // ProvisionerDetailsOIDC represents the values required by a OIDC provisioner.
type ProvisionerDetailsOIDC struct { type ProvisionerDetailsOIDC struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
ConfigurationEndpoint string `json:"configurationEndpoint"`
Admins []string `json:"admins"`
Domains []string `json:"domains"`
Groups []string `json:"groups"`
ListenAddress string `json:"listenAddress"`
TenantID string `json:"tenantID"`
} }
// ProvisionerDetailsGCP represents the values required by a GCP provisioner. // ProvisionerDetailsGCP represents the values required by a GCP provisioner.
type ProvisionerDetailsGCP struct { type ProvisionerDetailsGCP struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
ServiceAccounts []string `json:"serviceAccounts"`
ProjectIDs []string `json:"projectIDs"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
InstanceAge string `json:"instanceAge"`
} }
// ProvisionerDetailsAWS represents the values required by a AWS provisioner. // ProvisionerDetailsAWS represents the values required by a AWS provisioner.
type ProvisionerDetailsAWS struct { type ProvisionerDetailsAWS struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
Accounts []string `json:"accounts"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
InstanceAge string `json:"instanceAge"`
} }
// ProvisionerDetailsAzure represents the values required by a Azure provisioner. // ProvisionerDetailsAzure represents the values required by a Azure provisioner.
type ProvisionerDetailsAzure struct { type ProvisionerDetailsAzure struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
ResourceGroups []string `json:"resourceGroups"`
Audience string `json:"audience"`
DisableCustomSANs bool `json:"disableCustomSANs"`
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
} }
// ProvisionerDetailsACME represents the values required by a ACME provisioner. // ProvisionerDetailsACME represents the values required by a ACME provisioner.
type ProvisionerDetailsACME struct { type ProvisionerDetailsACME struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
ForceCN bool `json:"forceCN"`
} }
// ProvisionerDetailsX5C represents the values required by a X5C provisioner. // ProvisionerDetailsX5C represents the values required by a X5C provisioner.
type ProvisionerDetailsX5C struct { type ProvisionerDetailsX5C struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
Roots []byte `json:"roots"`
} }
// ProvisionerDetailsK8SSA represents the values required by a K8SSA provisioner. // ProvisionerDetailsK8SSA represents the values required by a K8SSA provisioner.
type ProvisionerDetailsK8SSA struct { type ProvisionerDetailsK8SSA struct {
Type ProvisionerType `json:"type"` Type ProvisionerType `json:"type"`
PublicKeys []byte `json:"publicKeys"`
} }
// ProvisionerDetailsSSHPOP represents the values required by a SSHPOP provisioner. // ProvisionerDetailsSSHPOP represents the values required by a SSHPOP provisioner.
@ -285,6 +309,42 @@ func createJWKDetails(pc *ProvisionerCtx) (*ProvisionerDetailsJWK, error) {
}, nil }, nil
} }
func createACMEDetails(pc *ProvisionerCtx) (*ProvisionerDetailsJWK, error) {
var err error
if pc.JWK != nil && pc.JWE == nil {
return nil, NewErrorISE("JWE is required with JWK for createJWKProvisioner")
}
if pc.JWE != nil && pc.JWK == nil {
return nil, NewErrorISE("JWK is required with JWE for createJWKProvisioner")
}
if pc.JWK == nil && pc.JWE == nil {
// Create a new JWK w/ encrypted private key.
if pc.Password == "" {
return nil, NewErrorISE("password is required to provisioner with new keys")
}
pc.JWK, pc.JWE, err = jose.GenerateDefaultKeyPair([]byte(pc.Password))
if err != nil {
return nil, WrapErrorISE(err, "error generating JWK key pair")
}
}
jwkPubBytes, err := pc.JWK.MarshalJSON()
if err != nil {
return nil, WrapErrorISE(err, "error marshaling JWK")
}
jwePrivStr, err := pc.JWE.CompactSerialize()
if err != nil {
return nil, WrapErrorISE(err, "error serializing JWE")
}
return &ProvisionerDetailsJWK{
Type: ProvisionerTypeJWK,
PublicKey: jwkPubBytes,
PrivateKey: jwePrivStr,
}, nil
}
// ToCertificates converts the landlord provisioner type to the open source // ToCertificates converts the landlord provisioner type to the open source
// provisioner type. // provisioner type.
func (p *Provisioner) ToCertificates() (provisioner.Interface, error) { func (p *Provisioner) ToCertificates() (provisioner.Interface, error) {
@ -461,6 +521,7 @@ type detailsType struct {
Type ProvisionerType Type ProvisionerType
} }
// UnmarshalProvisionerDetails unmarshals bytes into the proper details type.
func UnmarshalProvisionerDetails(data json.RawMessage) (ProvisionerDetails, error) { func UnmarshalProvisionerDetails(data json.RawMessage) (ProvisionerDetails, error) {
dt := new(detailsType) dt := new(detailsType)
if err := json.Unmarshal(data, dt); err != nil { if err := json.Unmarshal(data, dt); err != nil {

View file

@ -315,9 +315,9 @@ retry:
} }
// CreateProvisioner performs the POST /admin/provisioners request to the CA. // CreateProvisioner performs the POST /admin/provisioners request to the CA.
func (c *AdminClient) CreateProvisioner(req *mgmtAPI.CreateProvisionerRequest) (*mgmt.Provisioner, error) { func (c *AdminClient) CreateProvisioner(prov *mgmt.Provisioner) (*mgmt.Provisioner, error) {
var retried bool var retried bool
body, err := json.Marshal(req) body, err := json.Marshal(prov)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request") return nil, errs.Wrap(http.StatusInternalServerError, err, "error marshaling request")
} }
@ -334,11 +334,11 @@ retry:
} }
return nil, readAdminError(resp.Body) return nil, readAdminError(resp.Body)
} }
var prov = new(mgmt.Provisioner) var nuProv = new(mgmt.Provisioner)
if err := readJSON(resp.Body, prov); err != nil { if err := readJSON(resp.Body, nuProv); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errors.Wrapf(err, "error reading %s", u)
} }
return prov, nil return nuProv, nil
} }
// UpdateProvisioner performs the PUT /admin/provisioners/{id} request to the CA. // UpdateProvisioner performs the PUT /admin/provisioners/{id} request to the CA.