diff --git a/cli/server/server.go b/cli/server/server.go index 62d45f6d1..bf51603d6 100644 --- a/cli/server/server.go +++ b/cli/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "os" "os/signal" @@ -30,6 +31,17 @@ func NewCommand() cli.Command { } } +func newGraceContext() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt) + go func() { + <-stop + cancel() + }() + return ctx +} + func startServer(ctx *cli.Context) error { net := config.ModePrivNet if ctx.Bool("testnet") { @@ -39,6 +51,9 @@ func startServer(ctx *cli.Context) error { net = config.ModeMainNet } + grace, cancel := context.WithCancel(newGraceContext()) + defer cancel() + configPath := "../config" if argCp := ctx.String("config-path"); argCp != "" { configPath = argCp @@ -48,11 +63,8 @@ func startServer(ctx *cli.Context) error { return cli.NewExitError(err, 1) } - interruptChan := make(chan os.Signal, 1) - signal.Notify(interruptChan, os.Interrupt) - serverConfig := network.NewServerConfig(cfg) - chain, err := core.NewBlockchainLevelDB(cfg) + chain, err := core.NewBlockchainLevelDB(grace, cfg) if err != nil { err = fmt.Errorf("could not initialize blockchain: %s", err) return cli.NewExitError(err, 1) @@ -79,9 +91,9 @@ Main: select { case err := <-errChan: shutdownErr = errors.Wrap(err, "Error encountered by server") - interruptChan <- os.Kill + cancel() - case <-interruptChan: + case <-grace.Done(): server.Shutdown() if serverErr := rpcServer.Shutdown(); serverErr != nil { shutdownErr = errors.Wrap(serverErr, "Error encountered whilst shutting down server") diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 7ed7493f2..9324ce46b 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -2,6 +2,7 @@ package core import ( "bytes" + "context" "encoding/binary" "fmt" "sync/atomic" @@ -60,7 +61,7 @@ type headersOpFunc func(headerList *HeaderHashList) // NewBlockchain return a new blockchain object the will use the // given Store as its underlying storage. -func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockchain, error) { +func NewBlockchain(ctx context.Context, s storage.Store, cfg config.ProtocolConfiguration) (*Blockchain, error) { bc := &Blockchain{ config: cfg, Store: s, @@ -69,7 +70,7 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockcha blockCache: NewCache(), verifyBlocks: false, } - go bc.run() + go bc.run(ctx) if err := bc.init(); err != nil { return nil, err @@ -79,8 +80,9 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockcha } // GetBlockchainLevelDB returns blockchain based on configuration -func NewBlockchainLevelDB(cfg config.Config) (*Blockchain, error) { +func NewBlockchainLevelDB(ctx context.Context, cfg config.Config) (*Blockchain, error) { store, err := storage.NewLevelDBStore( + ctx, cfg.ApplicationConfiguration.DataDirectoryPath, nil, ) @@ -88,7 +90,7 @@ func NewBlockchainLevelDB(cfg config.Config) (*Blockchain, error) { return nil, err } - return NewBlockchain(store, cfg.ProtocolConfiguration) + return NewBlockchain(ctx, store, cfg.ProtocolConfiguration) } func (bc *Blockchain) init() error { @@ -165,15 +167,18 @@ func (bc *Blockchain) init() error { return nil } -func (bc *Blockchain) run() { +func (bc *Blockchain) run(ctx context.Context) { persistTimer := time.NewTimer(persistInterval) + defer persistTimer.Stop() for { select { + case <-ctx.Done(): + return case op := <-bc.headersOp: op(bc.headerList) bc.headersOpDone <- struct{}{} case <-persistTimer.C: - go bc.persist() + go bc.persist(ctx) persistTimer.Reset(persistInterval) } } @@ -395,7 +400,7 @@ func (bc *Blockchain) persistBlock(block *Block) error { return nil } -func (bc *Blockchain) persist() (err error) { +func (bc *Blockchain) persist(ctx context.Context) (err error) { var ( start = time.Now() persisted = 0 @@ -422,7 +427,13 @@ func (bc *Blockchain) persist() (err error) { } } } - <-bc.headersOpDone + + select { + case <-ctx.Done(): + return + case <-bc.headersOpDone: + // + } if persisted > 0 { log.WithFields(log.Fields{ diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 8d4842de1..033769ef5 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -1,6 +1,7 @@ package core import ( + "context" "testing" "github.com/CityOfZion/neo-go/config" @@ -54,7 +55,7 @@ func TestAddBlock(t *testing.T) { t.Log(bc.blockCache) - if err := bc.persist(); err != nil { + if err := bc.persist(context.Background()); err != nil { t.Fatal(err) } @@ -118,7 +119,7 @@ func TestHasBlock(t *testing.T) { t.Fatal(err) } } - assert.Nil(t, bc.persist()) + assert.Nil(t, bc.persist(context.Background())) for i := 0; i < len(blocks); i++ { assert.True(t, bc.HasBlock(blocks[i].Hash())) @@ -148,7 +149,7 @@ func newTestChain(t *testing.T) *Blockchain { if err != nil { t.Fatal(err) } - chain, err := NewBlockchain(storage.NewMemoryStore(), cfg.ProtocolConfiguration) + chain, err := NewBlockchain(context.Background(), storage.NewMemoryStore(), cfg.ProtocolConfiguration) if err != nil { t.Fatal(err) } diff --git a/pkg/core/storage/leveldb_store.go b/pkg/core/storage/leveldb_store.go index e05ea9881..c78343a33 100644 --- a/pkg/core/storage/leveldb_store.go +++ b/pkg/core/storage/leveldb_store.go @@ -1,6 +1,8 @@ package storage import ( + "context" + "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/opt" "github.com/syndtr/goleveldb/leveldb/util" @@ -15,11 +17,18 @@ type LevelDBStore struct { // NewLevelDBStore return a new LevelDBStore object that will // initialize the database found at the given path. -func NewLevelDBStore(path string, opts *opt.Options) (*LevelDBStore, error) { +func NewLevelDBStore(ctx context.Context, path string, opts *opt.Options) (*LevelDBStore, error) { db, err := leveldb.OpenFile(path, opts) if err != nil { return nil, err } + + // graceful shutdown + go func() { + <-ctx.Done() + db.Close() + }() + return &LevelDBStore{ path: path, db: db, diff --git a/pkg/rpc/server_test.go b/pkg/rpc/server_test.go index 5a739e8ad..c9b0d374b 100644 --- a/pkg/rpc/server_test.go +++ b/pkg/rpc/server_test.go @@ -2,6 +2,7 @@ package rpc import ( "bytes" + "context" "fmt" "io/ioutil" "net/http" @@ -25,7 +26,7 @@ func TestHandler(t *testing.T) { t.Fatal("could not create levelDB chain", err) } - chain, err := core.NewBlockchainLevelDB(cfg) + chain, err := core.NewBlockchainLevelDB(context.Background(), cfg) if err != nil { t.Fatal("could not create levelDB chain", err) }