cli: add db dump/restore commands

Dump given number of blocks (from the given offset) to file and restore them
from it. Fixes #436.
This commit is contained in:
Roman Khimov 2019-10-21 08:41:05 +03:00
parent b533dfceba
commit c885a69edb

View file

@ -9,6 +9,7 @@ import (
"github.com/CityOfZion/neo-go/config"
"github.com/CityOfZion/neo-go/pkg/core"
"github.com/CityOfZion/neo-go/pkg/core/storage"
"github.com/CityOfZion/neo-go/pkg/io"
"github.com/CityOfZion/neo-go/pkg/network"
"github.com/CityOfZion/neo-go/pkg/rpc"
"github.com/pkg/errors"
@ -18,18 +19,63 @@ import (
// NewCommands returns 'node' command.
func NewCommands() []cli.Command {
return []cli.Command{{
Name: "node",
Usage: "start a NEO node",
Action: startServer,
Flags: []cli.Flag{
cli.StringFlag{Name: "config-path"},
cli.BoolFlag{Name: "privnet, p"},
cli.BoolFlag{Name: "mainnet, m"},
cli.BoolFlag{Name: "testnet, t"},
cli.BoolFlag{Name: "debug, d"},
var cfgFlags = []cli.Flag{
cli.StringFlag{Name: "config-path"},
cli.BoolFlag{Name: "privnet, p"},
cli.BoolFlag{Name: "mainnet, m"},
cli.BoolFlag{Name: "testnet, t"},
cli.BoolFlag{Name: "debug, d"},
}
var cfgWithCountFlags = make([]cli.Flag, len(cfgFlags))
copy(cfgWithCountFlags, cfgFlags)
cfgWithCountFlags = append(cfgWithCountFlags,
cli.UintFlag{
Name: "count, c",
Usage: "number of blocks to be processed (default or 0: all chain)",
},
}}
cli.UintFlag{
Name: "skip, s",
Usage: "number of blocks to skip (default: 0)",
},
)
var cfgCountOutFlags = make([]cli.Flag, len(cfgWithCountFlags))
copy(cfgCountOutFlags, cfgWithCountFlags)
cfgCountOutFlags = append(cfgCountOutFlags, cli.StringFlag{
Name: "out, o",
Usage: "Output file (stdout if not given)",
})
var cfgCountInFlags = make([]cli.Flag, len(cfgWithCountFlags))
copy(cfgCountInFlags, cfgWithCountFlags)
cfgCountInFlags = append(cfgCountInFlags, cli.StringFlag{
Name: "in, i",
Usage: "Input file (stdin if not given)",
})
return []cli.Command{
{
Name: "node",
Usage: "start a NEO node",
Action: startServer,
Flags: cfgFlags,
},
{
Name: "db",
Usage: "database manipulations",
Subcommands: []cli.Command{
{
Name: "dump",
Usage: "dump blocks (starting with block #1) to the file",
Action: dumpDB,
Flags: cfgCountOutFlags,
},
{
Name: "restore",
Usage: "restore blocks from the file",
Action: restoreDB,
Flags: cfgCountInFlags,
},
},
},
}
}
func newGraceContext() context.Context {
@ -43,26 +89,154 @@ func newGraceContext() context.Context {
return ctx
}
func startServer(ctx *cli.Context) error {
net := config.ModePrivNet
// getConfigFromContext looks at path and mode flags in the given config and
// returns appropriate config.
func getConfigFromContext(ctx *cli.Context) (config.Config, error) {
var net = config.ModePrivNet
if ctx.Bool("testnet") {
net = config.ModeTestNet
}
if ctx.Bool("mainnet") {
net = config.ModeMainNet
}
grace, cancel := context.WithCancel(newGraceContext())
defer cancel()
configPath := "./config"
if argCp := ctx.String("config-path"); argCp != "" {
configPath = argCp
}
cfg, err := config.Load(configPath, net)
return config.Load(configPath, net)
}
// handleLoggingParams enables debugging output is that's requested by the user.
func handleLoggingParams(ctx *cli.Context) {
if ctx.Bool("debug") {
log.SetLevel(log.DebugLevel)
}
}
func getCountAndSkipFromContext(ctx *cli.Context) (uint32, uint32) {
count := uint32(ctx.Uint("count"))
skip := uint32(ctx.Uint("skip"))
return count, skip
}
func dumpDB(ctx *cli.Context) error {
cfg, err := getConfigFromContext(ctx)
if err != nil {
return cli.NewExitError(err, 1)
}
handleLoggingParams(ctx)
count, skip := getCountAndSkipFromContext(ctx)
var outStream = os.Stdout
if out := ctx.String("out"); out != "" {
outStream, err = os.Create(out)
if err != nil {
return cli.NewExitError(err, 1)
}
}
defer outStream.Close()
writer := io.NewBinWriterFromIO(outStream)
grace, cancel := context.WithCancel(newGraceContext())
defer cancel()
chain, err := initBlockChain(cfg)
if err != nil {
return cli.NewExitError(err, 1)
}
go chain.Run(grace)
chainHeight := chain.BlockHeight()
if skip+count > chainHeight {
return cli.NewExitError(fmt.Errorf("chain is not that high (%d) to dump %d blocks starting from %d", chainHeight, count, skip), 1)
}
if count == 0 {
count = chainHeight - skip
}
writer.WriteLE(count)
for i := skip + 1; i <= count; i++ {
bh := chain.GetHeaderHash(int(i))
b, err := chain.GetBlock(bh)
if err != nil {
return cli.NewExitError(fmt.Errorf("failed to get block %d: %s", i, err), 1)
}
b.EncodeBinary(writer)
if writer.Err != nil {
return cli.NewExitError(err, 1)
}
}
return nil
}
func restoreDB(ctx *cli.Context) error {
cfg, err := getConfigFromContext(ctx)
if err != nil {
return err
}
handleLoggingParams(ctx)
count, skip := getCountAndSkipFromContext(ctx)
var inStream = os.Stdin
if in := ctx.String("in"); in != "" {
inStream, err = os.Open(in)
if err != nil {
return cli.NewExitError(err, 1)
}
}
defer inStream.Close()
reader := io.NewBinReaderFromIO(inStream)
grace, cancel := context.WithCancel(newGraceContext())
defer cancel()
chain, err := initBlockChain(cfg)
if err != nil {
return err
}
go chain.Run(grace)
var allBlocks uint32
reader.ReadLE(&allBlocks)
if reader.Err != nil {
return cli.NewExitError(err, 1)
}
if skip+count > allBlocks {
return cli.NewExitError(fmt.Errorf("input file has only %d blocks, can't read %d starting from %d", allBlocks, count, skip), 1)
}
if count == 0 {
count = allBlocks
}
i := uint32(0)
for ; i < skip; i++ {
b := &core.Block{}
b.DecodeBinary(reader)
if reader.Err != nil {
return cli.NewExitError(err, 1)
}
}
for ; i < count; i++ {
b := &core.Block{}
b.DecodeBinary(reader)
if reader.Err != nil {
return cli.NewExitError(err, 1)
}
err := chain.AddBlock(b)
if err != nil {
return cli.NewExitError(fmt.Errorf("failed to add block %d: %s", i, err), 1)
}
}
return nil
}
func startServer(ctx *cli.Context) error {
cfg, err := getConfigFromContext(ctx)
if err != nil {
return err
}
handleLoggingParams(ctx)
grace, cancel := context.WithCancel(newGraceContext())
defer cancel()
serverConfig := network.NewServerConfig(cfg)
@ -71,10 +245,6 @@ func startServer(ctx *cli.Context) error {
return err
}
if ctx.Bool("debug") {
log.SetLevel(log.DebugLevel)
}
server := network.NewServer(serverConfig, chain)
rpcServer := rpc.NewServer(chain, cfg.ApplicationConfiguration.RPCPort, server)
errChan := make(chan error)