diff --git a/authority/authority.go b/authority/authority.go index 5497dc2d..33340029 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -29,8 +29,19 @@ type Authority struct { initOnce bool } +// Option sets options to the Authority. +type Option func(*Authority) + +// WithDatabase sets an already initialized authority database to a new +// authority. This option is intended to be use on graceful reloads. +func WithDatabase(db db.AuthDB) Option { + return func(a *Authority) { + a.db = db + } +} + // New creates and initiates a new Authority type. -func New(config *Config) (*Authority, error) { +func New(config *Config, opts ...Option) (*Authority, error) { err := config.Validate() if err != nil { return nil, err @@ -41,6 +52,9 @@ func New(config *Config) (*Authority, error) { certificates: new(sync.Map), provisioners: provisioner.NewCollection(config.getAudiences()), } + for _, opt := range opts { + opt(a) + } if err := a.init(); err != nil { return nil, err } @@ -55,11 +69,12 @@ func (a *Authority) init() error { } var err error - - // Initialize step-ca Database. + // Initialize step-ca Database if it's not already initialized with WithDB. // If a.config.DB is nil then a simple, barebones in memory DB will be used. - if a.db, err = db.New(a.config.DB); err != nil { - return err + if a.db == nil { + if a.db, err = db.New(a.config.DB); err != nil { + return err + } } // Load the root certificates and add them to the certificate store @@ -118,6 +133,12 @@ func (a *Authority) init() error { return nil } +// GetDatabase returns the authority database. If the configuration does not +// define a database, GetDatabase will return a db.SimpleDB instance. +func (a *Authority) GetDatabase() db.AuthDB { + return a.db +} + // Shutdown safely shuts down any clients, databases, etc. held by the Authority. func (a *Authority) Shutdown() error { return a.db.Shutdown() diff --git a/authority/authority_test.go b/authority/authority_test.go index 8f180457..30ee3121 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -3,11 +3,13 @@ package authority import ( "crypto/sha256" "encoding/hex" + "reflect" "testing" "github.com/pkg/errors" "github.com/smallstep/assert" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" stepJOSE "github.com/smallstep/cli/jose" ) @@ -139,3 +141,25 @@ func TestAuthorityNew(t *testing.T) { }) } } + +func TestAuthority_GetDatabase(t *testing.T) { + auth := testAuthority(t) + authWithDatabase, err := New(auth.config, WithDatabase(auth.db)) + assert.FatalError(t, err) + + tests := []struct { + name string + auth *Authority + want db.AuthDB + }{ + {"ok", auth, auth.db}, + {"ok WithDatabase", authWithDatabase, auth.db}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.auth.GetDatabase(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetDatabase() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 800c44d2..09565b1d 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -525,7 +525,10 @@ func doReload(ca *CA) error { return errors.Wrap(err, "error reloading ca") } - newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile)) + newCA, err := New(config, + WithPassword(ca.opts.password), + WithConfigFile(ca.opts.configFile), + WithDatabase(ca.auth.GetDatabase())) if err != nil { return errors.Wrap(err, "error reloading ca") } diff --git a/ca/ca.go b/ca/ca.go index d14c496d..06f36975 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -5,11 +5,13 @@ import ( "crypto/x509" "log" "net/http" + "reflect" "github.com/go-chi/chi" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/server" @@ -18,6 +20,7 @@ import ( type options struct { configFile string password []byte + database db.AuthDB } func (o *options) apply(opts []Option) { @@ -45,6 +48,13 @@ func WithPassword(password []byte) Option { } } +// WithDatabase sets the given authority database to the CA options. +func WithDatabase(db db.AuthDB) Option { + return func(o *options) { + o.database = db + } +} + // CA is the type used to build the complete certificate authority. It builds // the HTTP server, set ups the middlewares and the HTTP handlers. type CA struct { @@ -71,7 +81,12 @@ func (ca *CA) Init(config *authority.Config) (*CA, error) { ca.config.Password = string(ca.opts.password) } - auth, err := authority.New(config) + var opts []authority.Option + if ca.opts.database != nil { + opts = append(opts, authority.WithDatabase(ca.opts.database)) + } + + auth, err := authority.New(config, opts...) if err != nil { return nil, err } @@ -132,60 +147,34 @@ func (ca *CA) Stop() error { // Reload reloads the configuration of the CA and calls to the server Reload // method. func (ca *CA) Reload() error { - var hasDB bool - if ca.config.DB != nil { - hasDB = true - } - if ca.opts.configFile == "" { - return errors.New("error reloading ca: configuration file is not set") - } - config, err := authority.LoadConfiguration(ca.opts.configFile) if err != nil { return errors.Wrap(err, "error reloading ca configuration") } - logShutDown := func(ss ...string) { - for _, s := range ss { - log.Println(s) - } - log.Println("Continuing to serve requests may result in inconsistent state. Shutting Down ...") - } logContinue := func(reason string) { log.Println(reason) log.Println("Continuing to run with the original configuration.") log.Println("You can force a restart by sending a SIGTERM signal and then restarting the step-ca.") } - // Shut down the old authority (shut down the database). If New or Reload - // fails then the CA will continue to run but the database will have been - // shutdown, which will cause errors. - if err := ca.auth.Shutdown(); err != nil { - if hasDB { - logShutDown("Attempt to shut down the ca.Authority has failed.") - return ca.Stop() - } - logContinue("Reload failed because the ca.Authority could not be shut down.") - return err + // Do not allow reload if the database configuration has changed. + if !reflect.DeepEqual(ca.config.DB, config.DB) { + logContinue("Reload failed because the database configuration has changed.") + return errors.New("error reloading ca: database configuration cannot change") } - newCA, err := New(config, WithPassword(ca.opts.password), WithConfigFile(ca.opts.configFile)) + + newCA, err := New(config, + WithPassword(ca.opts.password), + WithConfigFile(ca.opts.configFile), + WithDatabase(ca.auth.GetDatabase()), + ) if err != nil { - if hasDB { - logShutDown("Attempt to initialize a CA with the new configuration has failed.", - "The database has already been shutdown.") - return ca.Stop() - } - logContinue("Reload failed because the CA with new configuration could " + - "not be initialized.") + logContinue("Reload failed because the CA with new configuration could not be initialized.") return errors.Wrap(err, "error reloading ca") } if err = ca.srv.Reload(newCA.srv); err != nil { - if hasDB { - logShutDown("Attempt to replace the old CA server has failed.", - "The database has already been shutdown.") - return ca.Stop() - } logContinue("Reload failed because server could not be replaced.") return errors.Wrap(err, "error reloading server") }