diff --git a/authority/authority.go b/authority/authority.go index cdf2c8bf..c184c6e9 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -250,6 +250,7 @@ func (a *Authority) init() error { } var err error + ctx := NewContext(context.Background(), a) // Set password if they are not set. var configPassword []byte @@ -285,7 +286,7 @@ func (a *Authority) init() error { if a.config.KMS != nil { options = *a.config.KMS } - a.keyManager, err = kms.New(context.Background(), options) + a.keyManager, err = kms.New(ctx, options) if err != nil { return err } @@ -315,7 +316,7 @@ func (a *Authority) init() error { // Configure linked RA if linkedcaClient != nil && options.CertificateAuthority == "" { - conf, err := linkedcaClient.GetConfiguration(context.Background()) + conf, err := linkedcaClient.GetConfiguration(ctx) if err != nil { return err } @@ -349,7 +350,7 @@ func (a *Authority) init() error { } } - a.x509CAService, err = cas.New(context.Background(), options) + a.x509CAService, err = cas.New(ctx, options) if err != nil { return err } @@ -536,7 +537,7 @@ func (a *Authority) init() error { } } - a.scepService, err = scep.NewService(context.Background(), options) + a.scepService, err = scep.NewService(ctx, options) if err != nil { return err } @@ -558,19 +559,19 @@ func (a *Authority) init() error { } } - provs, err := a.adminDB.GetProvisioners(context.Background()) + provs, err := a.adminDB.GetProvisioners(ctx) if err != nil { return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") } if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { // Create First Provisioner - prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) + prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password)) if err != nil { return admin.WrapErrorISE(err, "error creating first provisioner") } // Create first admin - if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ + if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{ ProvisionerId: prov.Id, Subject: "step", Type: linkedca.Admin_SUPER_ADMIN, @@ -581,12 +582,12 @@ func (a *Authority) init() error { } // Load Provisioners and Admins - if err := a.reloadAdminResources(context.Background()); err != nil { + if err := a.reloadAdminResources(ctx); err != nil { return err } // Load x509 and SSH Policy Engines - if err := a.reloadPolicyEngines(context.Background()); err != nil { + if err := a.reloadPolicyEngines(ctx); err != nil { return err } @@ -611,6 +612,15 @@ func (a *Authority) init() error { return nil } +// GetID returns the define authority id or a zero uuid. +func (a *Authority) GetID() string { + const zeroUUID = "00000000-0000-0000-0000-000000000000" + if id := a.config.AuthorityConfig.AuthorityID; id != "" { + return id + } + return zeroUUID +} + // GetDatabase returns the authority database. If the configuration does not // define a database, GetDatabase will return a db.SimpleDB instance. func (a *Authority) GetDatabase() db.AuthDB { diff --git a/authority/authority_test.go b/authority/authority_test.go index 1f63333d..9f35f23e 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/db" "go.step.sm/crypto/jose" @@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) { }) } } + +func TestAuthority_GetID(t *testing.T) { + type fields struct { + authorityID string + } + tests := []struct { + name string + fields fields + want string + }{ + {"ok", fields{""}, "00000000-0000-0000-0000-000000000000"}, + {"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + AuthorityID: tt.fields.authorityID, + }, + }, + } + if got := a.GetID(); got != tt.want { + t.Errorf("Authority.GetID() = %v, want %v", got, tt.want) + } + }) + } +}