Add GetID() and add authority to initial context
This commit is contained in:
parent
1e03bbb1af
commit
8942422973
2 changed files with 48 additions and 9 deletions
|
@ -250,6 +250,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
ctx := NewContext(context.Background(), a)
|
||||||
|
|
||||||
// Set password if they are not set.
|
// Set password if they are not set.
|
||||||
var configPassword []byte
|
var configPassword []byte
|
||||||
|
@ -285,7 +286,7 @@ func (a *Authority) init() error {
|
||||||
if a.config.KMS != nil {
|
if a.config.KMS != nil {
|
||||||
options = *a.config.KMS
|
options = *a.config.KMS
|
||||||
}
|
}
|
||||||
a.keyManager, err = kms.New(context.Background(), options)
|
a.keyManager, err = kms.New(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -315,7 +316,7 @@ func (a *Authority) init() error {
|
||||||
|
|
||||||
// Configure linked RA
|
// Configure linked RA
|
||||||
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
if linkedcaClient != nil && options.CertificateAuthority == "" {
|
||||||
conf, err := linkedcaClient.GetConfiguration(context.Background())
|
conf, err := linkedcaClient.GetConfiguration(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -349,7 +350,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.x509CAService, err = cas.New(context.Background(), options)
|
a.x509CAService, err = cas.New(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -536,7 +537,7 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
a.scepService, err = scep.NewService(context.Background(), options)
|
a.scepService, err = scep.NewService(ctx, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -558,19 +559,19 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
provs, err := a.adminDB.GetProvisioners(context.Background())
|
provs, err := a.adminDB.GetProvisioners(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
|
||||||
}
|
}
|
||||||
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
|
||||||
// Create First Provisioner
|
// Create First Provisioner
|
||||||
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password))
|
prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return admin.WrapErrorISE(err, "error creating first provisioner")
|
return admin.WrapErrorISE(err, "error creating first provisioner")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create first admin
|
// Create first admin
|
||||||
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{
|
if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
|
||||||
ProvisionerId: prov.Id,
|
ProvisionerId: prov.Id,
|
||||||
Subject: "step",
|
Subject: "step",
|
||||||
Type: linkedca.Admin_SUPER_ADMIN,
|
Type: linkedca.Admin_SUPER_ADMIN,
|
||||||
|
@ -581,12 +582,12 @@ func (a *Authority) init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load Provisioners and Admins
|
// Load Provisioners and Admins
|
||||||
if err := a.reloadAdminResources(context.Background()); err != nil {
|
if err := a.reloadAdminResources(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load x509 and SSH Policy Engines
|
// Load x509 and SSH Policy Engines
|
||||||
if err := a.reloadPolicyEngines(context.Background()); err != nil {
|
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -611,6 +612,15 @@ func (a *Authority) init() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetID returns the define authority id or a zero uuid.
|
||||||
|
func (a *Authority) GetID() string {
|
||||||
|
const zeroUUID = "00000000-0000-0000-0000-000000000000"
|
||||||
|
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return zeroUUID
|
||||||
|
}
|
||||||
|
|
||||||
// GetDatabase returns the authority database. If the configuration does not
|
// GetDatabase returns the authority database. If the configuration does not
|
||||||
// define a database, GetDatabase will return a db.SimpleDB instance.
|
// define a database, GetDatabase will return a db.SimpleDB instance.
|
||||||
func (a *Authority) GetDatabase() db.AuthDB {
|
func (a *Authority) GetDatabase() db.AuthDB {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
"github.com/smallstep/certificates/authority/config"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthority_GetID(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
authorityID string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
|
||||||
|
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
a := &Authority{
|
||||||
|
config: &config.Config{
|
||||||
|
AuthorityConfig: &config.AuthConfig{
|
||||||
|
AuthorityID: tt.fields.authorityID,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if got := a.GetID(); got != tt.want {
|
||||||
|
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue