certificates/authority/mgmt/db/nosql/provisioner.go
max furman 5d09d04d14 wip
2021-05-19 15:20:16 -07:00

371 lines
11 KiB
Go

package nosql
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/mgmt"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
// dbProvisioner is the database representation of a Provisioner type.
type dbProvisioner struct {
ID string `json:"id"`
AuthorityID string `json:"authorityID"`
Type string `json:"type"`
// Name is the key
Name string `json:"name"`
Claims *mgmt.Claims `json:"claims"`
Details []byte `json:"details"`
X509Template string `json:"x509Template"`
X509TemplateData []byte `json:"x509TemplateData"`
SSHTemplate string `json:"sshTemplate"`
SSHTemplateData []byte `json:"sshTemplateData"`
CreatedAt time.Time `json:"createdAt"`
DeletedAt time.Time `json:"deletedAt"`
}
type provisionerNameID struct {
Name string `json:"name"`
ID string `json:"id"`
}
func (dbp *dbProvisioner) clone() *dbProvisioner {
u := *dbp
return &u
}
func (db *DB) getProvisionerIDByName(ctx context.Context, name string) (string, error) {
data, err := db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if nosql.IsErrNotFound(err) {
return "", mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", name)
} else if err != nil {
return "", mgmt.WrapErrorISE(err, "error loading provisioner %s", name)
}
ni := new(provisionerNameID)
if err := json.Unmarshal(data, ni); err != nil {
return "", mgmt.WrapErrorISE(err, "error unmarshaling provisionerNameID for provisioner %s", name)
}
return ni.ID, nil
}
func (db *DB) getDBProvisionerBytes(ctx context.Context, id string) ([]byte, error) {
data, err := db.db.Get(authorityProvisionersTable, []byte(id))
if nosql.IsErrNotFound(err) {
return nil, mgmt.NewError(mgmt.ErrorNotFoundType, "provisioner %s not found", id)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading provisioner %s", id)
}
return data, nil
}
func (db *DB) getDBProvisioner(ctx context.Context, id string) (*dbProvisioner, error) {
data, err := db.getDBProvisionerBytes(ctx, id)
if err != nil {
return nil, err
}
dbp, err := unmarshalDBProvisioner(data, id)
if err != nil {
return nil, err
}
if !dbp.DeletedAt.IsZero() {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", id)
}
if dbp.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", dbp.ID, db.authorityID)
}
return dbp, nil
}
func (db *DB) getDBProvisionerByName(ctx context.Context, name string) (*dbProvisioner, error) {
id, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
dbp, err := db.getDBProvisioner(ctx, id)
if err != nil {
return nil, err
}
if !dbp.DeletedAt.IsZero() {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", name)
}
if dbp.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", name, db.authorityID)
}
return dbp, nil
}
// GetProvisioner retrieves and unmarshals a provisioner from the database.
func (db *DB) GetProvisioner(ctx context.Context, id string) (*mgmt.Provisioner, error) {
data, err := db.getDBProvisionerBytes(ctx, id)
if err != nil {
return nil, err
}
prov, err := unmarshalProvisioner(data, id)
if err != nil {
return nil, err
}
if prov.Status == mgmt.StatusDeleted {
return nil, mgmt.NewError(mgmt.ErrorDeletedType, "provisioner %s is deleted", prov.ID)
}
if prov.AuthorityID != db.authorityID {
return nil, mgmt.NewError(mgmt.ErrorAuthorityMismatchType,
"provisioner %s is not owned by authority %s", prov.ID, db.authorityID)
}
return prov, nil
}
// GetProvisionerByName retrieves a provisioner from the database by name.
func (db *DB) GetProvisionerByName(ctx context.Context, name string) (*mgmt.Provisioner, error) {
p, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return nil, err
}
return db.GetProvisioner(ctx, id)
}
func unmarshalDBProvisioner(data []byte, name string) (*dbProvisioner, error) {
var dbp = new(dbProvisioner)
if err := json.Unmarshal(data, dbp); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling provisioner %s into dbProvisioner", name)
}
return dbp, nil
}
func unmarshalProvisioner(data []byte, name string) (*mgmt.Provisioner, error) {
dbp, err := unmarshalDBProvisioner(data, name)
if err != nil {
return nil, err
}
details, err := mgmt.UnmarshalProvisionerDetails(dbp.Details)
if err != nil {
return nil, err
}
prov := &mgmt.Provisioner{
ID: dbp.ID,
AuthorityID: dbp.AuthorityID,
Type: dbp.Type,
Name: dbp.Name,
Claims: dbp.Claims,
Details: details,
Status: mgmt.StatusActive,
X509Template: dbp.X509Template,
X509TemplateData: dbp.X509TemplateData,
SSHTemplate: dbp.SSHTemplate,
SSHTemplateData: dbp.SSHTemplateData,
}
if !dbp.DeletedAt.IsZero() {
prov.Status = mgmt.StatusDeleted
}
return prov, nil
}
// GetProvisioners retrieves and unmarshals all active (not deleted) provisioners
// from the database.
func (db *DB) GetProvisioners(ctx context.Context) ([]*mgmt.Provisioner, error) {
dbEntries, err := db.db.List(authorityProvisionersTable)
if err != nil {
return nil, mgmt.WrapErrorISE(err, "error loading provisioners")
}
var provs []*mgmt.Provisioner
for _, entry := range dbEntries {
prov, err := unmarshalProvisioner(entry.Value, string(entry.Key))
if err != nil {
return nil, err
}
if prov.Status == mgmt.StatusDeleted {
continue
}
if prov.AuthorityID != db.authorityID {
continue
}
provs = append(provs, prov)
}
return provs, nil
}
// CreateProvisioner stores a new provisioner to the database.
func (db *DB) CreateProvisioner(ctx context.Context, prov *mgmt.Provisioner) error {
var err error
prov.ID, err = randID()
if err != nil {
return errors.Wrap(err, "error generating random id for provisioner")
}
details, err := json.Marshal(prov.Details)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when creating provisioner %s", prov.Name)
}
dbp := &dbProvisioner{
ID: prov.ID,
AuthorityID: db.authorityID,
Type: prov.Type,
Name: prov.Name,
Claims: prov.Claims,
Details: details,
X509Template: prov.X509Template,
X509TemplateData: prov.X509TemplateData,
SSHTemplate: prov.SSHTemplate,
SSHTemplateData: prov.SSHTemplateData,
CreatedAt: clock.Now(),
}
dbpBytes, err := json.Marshal(dbp)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameIndex %s", prov.Name)
}
if err := db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersTable,
Key: []byte(dbp.ID),
Cmd: database.CmpAndSwap,
Value: dbpBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(dbp.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
},
}); err != nil {
return mgmt.WrapErrorISE(err, "error creating provisioner %s", prov.Name)
}
return nil
}
// UpdateProvisioner saves an updated provisioner to the database.
func (db *DB) UpdateProvisioner(ctx context.Context, name string, prov *mgmt.Provisioner) error {
id, err := db.getProvisionerIDByName(ctx, name)
if err != nil {
return err
}
prov.ID = id
oldBytes, err := db.getDBProvisionerBytes(ctx, id)
if err != nil {
return err
}
fmt.Printf("oldBytes = %+v\n", oldBytes)
old, err := unmarshalDBProvisioner(oldBytes, id)
if err != nil {
return err
}
fmt.Printf("old = %+v\n", old)
nu := old.clone()
nu.Type = prov.Type
nu.Name = prov.Name
nu.Claims = prov.Claims
nu.Details, err = json.Marshal(prov.Details)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling details when updating provisioner %s", name)
}
nu.X509Template = prov.X509Template
nu.X509TemplateData = prov.X509TemplateData
nu.SSHTemplateData = prov.SSHTemplateData
var txs = []*database.TxEntry{}
// If the provisioner was active but is now deleted ...
if old.DeletedAt.IsZero() && prov.Status == mgmt.StatusDeleted {
nu.DeletedAt = clock.Now()
txs = append(txs, &database.TxEntry{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
})
}
if prov.Name != name {
// If the new name does not match the old name then:
// 1) check that the new name is not already taken
// 2) delete the old name-id index resource
// 3) create a new name-id index resource
// 4) update the provisioner resource
nuBytes, err := json.Marshal(nu)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling dbProvisioner %s", prov.Name)
}
pni := &provisionerNameID{
Name: prov.Name,
ID: prov.ID,
}
pniBytes, err := json.Marshal(pni)
if err != nil {
return mgmt.WrapErrorISE(err, "error marshaling provisionerNameID for provisioner %s", prov.Name)
}
_, err = db.db.Get(authorityProvisionersNameIDIndexTable, []byte(name))
if err == nil {
return mgmt.NewError(mgmt.ErrorBadRequestType, "provisioner with name %s already exists", prov.Name)
} else if !nosql.IsErrNotFound(err) {
return mgmt.WrapErrorISE(err, "error loading provisionerNameID %s", prov.Name)
}
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(prov.Name),
Cmd: database.CmpAndSwap,
Value: pniBytes,
CmpValue: nil,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
} else {
err = db.db.Update(&database.Tx{
Operations: []*database.TxEntry{
{
Bucket: authorityProvisionersNameIDIndexTable,
Key: []byte(name),
Cmd: database.Delete,
},
{
Bucket: authorityProvisionersTable,
Key: []byte(nu.ID),
Cmd: database.CmpAndSwap,
Value: nuBytes,
CmpValue: oldBytes,
},
},
})
}
if err != nil {
return mgmt.WrapErrorISE(err, "error updating provisioner %s", prov.Name)
}
return nil
}