network: unplug stateroot service from the Server

Notice that it makes the node accept Extensible payloads with any category
which is the same way C# node works. We're trusting Extensible senders,
improper payloads are harmless until they DoS the network, but we have some
protections against that too (and spamming with proper category doesn't differ
a lot).
This commit is contained in:
Roman Khimov 2022-01-12 21:09:37 +03:00
parent 0ad3ea5944
commit 66aafd868b
6 changed files with 46 additions and 55 deletions

View file

@ -18,6 +18,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/metrics" "github.com/nspcc-dev/neo-go/pkg/network/metrics"
"github.com/nspcc-dev/neo-go/pkg/rpc/server" "github.com/nspcc-dev/neo-go/pkg/rpc/server"
"github.com/nspcc-dev/neo-go/pkg/services/oracle" "github.com/nspcc-dev/neo-go/pkg/services/oracle"
"github.com/nspcc-dev/neo-go/pkg/services/stateroot"
"github.com/urfave/cli" "github.com/urfave/cli"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore" "go.uber.org/zap/zapcore"
@ -361,6 +362,12 @@ func startServer(ctx *cli.Context) error {
if err != nil { if err != nil {
return cli.NewExitError(fmt.Errorf("failed to create network server: %w", err), 1) return cli.NewExitError(fmt.Errorf("failed to create network server: %w", err), 1)
} }
sr, err := stateroot.New(serverConfig.StateRootCfg, log, chain, serv.BroadcastExtensible)
if err != nil {
return cli.NewExitError(fmt.Errorf("can't initialize StateRoot service: %w", err), 1)
}
serv.AddExtensibleService(sr, stateroot.Category, sr.OnPayload)
oracleSrv, err := mkOracle(serverConfig, chain, serv, log) oracleSrv, err := mkOracle(serverConfig, chain, serv, log)
if err != nil { if err != nil {
return err return err

View file

@ -52,7 +52,7 @@ type Service interface {
Shutdown() Shutdown()
// OnPayload is a callback to notify Service about new received payload. // OnPayload is a callback to notify Service about new received payload.
OnPayload(p *npayload.Extensible) OnPayload(p *npayload.Extensible) error
// OnTransaction is a callback to notify Service about new received transaction. // OnTransaction is a callback to notify Service about new received transaction.
OnTransaction(tx *transaction.Transaction) OnTransaction(tx *transaction.Transaction)
} }
@ -365,26 +365,27 @@ func (s *service) payloadFromExtensible(ep *npayload.Extensible) *Payload {
} }
// OnPayload handles Payload receive. // OnPayload handles Payload receive.
func (s *service) OnPayload(cp *npayload.Extensible) { func (s *service) OnPayload(cp *npayload.Extensible) error {
log := s.log.With(zap.Stringer("hash", cp.Hash())) log := s.log.With(zap.Stringer("hash", cp.Hash()))
p := s.payloadFromExtensible(cp) p := s.payloadFromExtensible(cp)
// decode payload data into message // decode payload data into message
if err := p.decodeData(); err != nil { if err := p.decodeData(); err != nil {
log.Info("can't decode payload data", zap.Error(err)) log.Info("can't decode payload data", zap.Error(err))
return return nil
} }
if !s.validatePayload(p) { if !s.validatePayload(p) {
log.Info("can't validate payload") log.Info("can't validate payload")
return return nil
} }
if s.dbft == nil || !s.started.Load() { if s.dbft == nil || !s.started.Load() {
log.Debug("dbft is inactive or not started yet") log.Debug("dbft is inactive or not started yet")
return return nil
} }
s.messages <- *p s.messages <- *p
return nil
} }
func (s *service) OnTransaction(tx *transaction.Transaction) { func (s *service) OnTransaction(tx *transaction.Transaction) {

View file

@ -351,7 +351,7 @@ func TestService_OnPayload(t *testing.T) {
p.encodeData() p.encodeData()
// sender is invalid // sender is invalid
srv.OnPayload(&p.Extensible) require.NoError(t, srv.OnPayload(&p.Extensible))
shouldNotReceive(t, srv.messages) shouldNotReceive(t, srv.messages)
p = new(Payload) p = new(Payload)
@ -359,7 +359,7 @@ func TestService_OnPayload(t *testing.T) {
p.Sender = priv.GetScriptHash() p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{}) p.SetPayload(&prepareRequest{})
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
srv.OnPayload(&p.Extensible) require.NoError(t, srv.OnPayload(&p.Extensible))
shouldReceive(t, srv.messages) shouldReceive(t, srv.messages)
} }

View file

@ -25,7 +25,6 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/extpool" "github.com/nspcc-dev/neo-go/pkg/network/extpool"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/services/notary" "github.com/nspcc-dev/neo-go/pkg/services/notary"
"github.com/nspcc-dev/neo-go/pkg/services/stateroot"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
@ -83,6 +82,7 @@ type (
notaryFeer NotaryFeer notaryFeer NotaryFeer
notaryModule *notary.Notary notaryModule *notary.Notary
services []Service services []Service
extensHandlers map[string]func(*payload.Extensible) error
txInLock sync.Mutex txInLock sync.Mutex
txInMap map[util.Uint256]struct{} txInMap map[util.Uint256]struct{}
@ -103,7 +103,6 @@ type (
syncReached *atomic.Bool syncReached *atomic.Bool
stateRoot stateroot.Service
stateSync blockchainer.StateSync stateSync blockchainer.StateSync
log *zap.Logger log *zap.Logger
@ -159,6 +158,7 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
extensiblePool: extpool.New(chain, config.ExtensiblePoolSize), extensiblePool: extpool.New(chain, config.ExtensiblePoolSize),
log: log, log: log,
transactions: make(chan *transaction.Transaction, 64), transactions: make(chan *transaction.Transaction, 64),
extensHandlers: make(map[string]func(*payload.Extensible) error),
} }
if chain.P2PSigExtensionsEnabled() { if chain.P2PSigExtensionsEnabled() {
s.notaryFeer = NewNotaryFeer(chain) s.notaryFeer = NewNotaryFeer(chain)
@ -194,17 +194,6 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
s.tryStartServices() s.tryStartServices()
}) })
if config.StateRootCfg.Enabled && chain.GetConfig().StateRootInHeader {
return nil, errors.New("`StateRootInHeader` should be disabled when state service is enabled")
}
sr, err := stateroot.New(config.StateRootCfg, s.log, chain, s.handleNewPayload)
if err != nil {
return nil, fmt.Errorf("can't initialize StateRoot service: %w", err)
}
s.stateRoot = sr
s.services = append(s.services, sr)
sSync := chain.GetStateSyncModule() sSync := chain.GetStateSyncModule()
s.stateSync = sSync s.stateSync = sSync
s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil) s.bSyncQueue = newBlockQueue(maxBlockBatch, sSync, log, nil)
@ -212,7 +201,7 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
if config.Wallet != nil { if config.Wallet != nil {
srv, err := newConsensus(consensus.Config{ srv, err := newConsensus(consensus.Config{
Logger: log, Logger: log,
Broadcast: s.handleNewPayload, Broadcast: s.BroadcastExtensible,
Chain: chain, Chain: chain,
ProtocolConfiguration: chain.GetConfig(), ProtocolConfiguration: chain.GetConfig(),
RequestTx: s.requestTx, RequestTx: s.requestTx,
@ -225,7 +214,7 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
} }
s.consensus = srv s.consensus = srv
s.services = append(s.services, srv) s.AddExtensibleService(srv, consensus.Category, srv.OnPayload)
} }
if s.MinPeers < 0 { if s.MinPeers < 0 {
@ -306,9 +295,10 @@ func (s *Server) AddService(svc Service) {
s.services = append(s.services, svc) s.services = append(s.services, svc)
} }
// GetStateRoot returns state root service instance. // AddExtensibleService register a service that handles extensible payload of some kind.
func (s *Server) GetStateRoot() stateroot.Service { func (s *Server) AddExtensibleService(svc Service, category string, handler func(*payload.Extensible) error) {
return s.stateRoot s.extensHandlers[category] = handler
s.AddService(svc)
} }
// UnconnectedPeers returns a list of peers that are in the discovery peer list // UnconnectedPeers returns a list of peers that are in the discovery peer list
@ -946,27 +936,26 @@ func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
if !ok { // payload is already in cache if !ok { // payload is already in cache
return nil return nil
} }
switch e.Category { handler := s.extensHandlers[e.Category]
case consensus.Category: if handler != nil {
if s.consensus != nil { err = handler(e)
s.consensus.OnPayload(e)
}
case stateroot.Category:
err := s.stateRoot.OnPayload(e)
if err != nil { if err != nil {
return err return err
} }
default: }
return errors.New("invalid category") s.advertiseExtensible(e)
return nil
} }
func (s *Server) advertiseExtensible(e *payload.Extensible) {
msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{e.Hash()})) msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{e.Hash()}))
if e.Category == consensus.Category { if e.Category == consensus.Category {
// It's high priority because it directly affects consensus process,
// even though it's just an inv.
s.broadcastHPMessage(msg) s.broadcastHPMessage(msg)
} else { } else {
s.broadcastMessage(msg) s.broadcastMessage(msg)
} }
return nil
} }
// handleTxCmd processes received transaction. // handleTxCmd processes received transaction.
@ -1253,22 +1242,17 @@ func (s *Server) tryInitStateSync() {
} }
} }
} }
func (s *Server) handleNewPayload(p *payload.Extensible) {
// BroadcastExtensible add locally-generated Extensible payload to the pool
// and advertises it to peers.
func (s *Server) BroadcastExtensible(p *payload.Extensible) {
_, err := s.extensiblePool.Add(p) _, err := s.extensiblePool.Add(p)
if err != nil { if err != nil {
s.log.Error("created payload is not valid", zap.Error(err)) s.log.Error("created payload is not valid", zap.Error(err))
return return
} }
msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{p.Hash()})) s.advertiseExtensible(p)
switch p.Category {
case consensus.Category:
// It's high priority because it directly affects consensus process,
// even though it's just an inv.
s.broadcastHPMessage(msg)
default:
s.broadcastMessage(msg)
}
} }
func (s *Server) requestTx(hashes ...util.Uint256) { func (s *Server) requestTx(hashes ...util.Uint256) {

View file

@ -43,7 +43,10 @@ func newFakeConsensus(c consensus.Config) (consensus.Service, error) {
} }
func (f *fakeConsensus) Start() { f.started.Store(true) } func (f *fakeConsensus) Start() { f.started.Store(true) }
func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) } func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) }
func (f *fakeConsensus) OnPayload(p *payload.Extensible) { f.payloads = append(f.payloads, p) } func (f *fakeConsensus) OnPayload(p *payload.Extensible) error {
f.payloads = append(f.payloads, p)
return nil
}
func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) } func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) }
func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") } func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
@ -455,13 +458,6 @@ func TestConsensus(t *testing.T) {
msg := newConsensusMessage(s.chain.BlockHeight()+1, s.chain.BlockHeight()+2) msg := newConsensusMessage(s.chain.BlockHeight()+1, s.chain.BlockHeight()+2)
require.Error(t, s.handleMessage(p, msg)) require.Error(t, s.handleMessage(p, msg))
}) })
t.Run("invalid category", func(t *testing.T) {
pl := payload.NewExtensible()
pl.Category = "invalid"
pl.ValidBlockEnd = s.chain.BlockHeight() + 1
msg := NewMessage(CMDExtensible, pl)
require.Error(t, s.handleMessage(p, msg))
})
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {

View file

@ -77,6 +77,9 @@ func New(cfg config.StateRoot, log *zap.Logger, bc blockchainer.Blockchainer, cb
s.MainCfg = cfg s.MainCfg = cfg
if cfg.Enabled { if cfg.Enabled {
if bcConf.StateRootInHeader {
return nil, errors.New("`StateRootInHeader` should be disabled when state service is enabled")
}
var err error var err error
w := cfg.UnlockWallet w := cfg.UnlockWallet
if s.wallet, err = wallet.NewWalletFromFile(w.Path); err != nil { if s.wallet, err = wallet.NewWalletFromFile(w.Path); err != nil {