diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 1f6649194..a35230f02 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -55,16 +55,12 @@ type Service interface { OnPayload(p *npayload.Extensible) // OnTransaction is a callback to notify Service about new received transaction. OnTransaction(tx *transaction.Transaction) - // GetPayload returns Payload with specified hash if it is present in the local cache. - GetPayload(h util.Uint256) *npayload.Extensible } type service struct { Config log *zap.Logger - // cache is a fifo cache which stores recent payloads. - cache *relayCache // txx is a fifo cache which stores miner transactions. txx *relayCache dbft *dbft.DBFT @@ -124,7 +120,6 @@ func NewService(cfg Config) (Service, error) { Config: cfg, log: cfg.Logger, - cache: newFIFOCache(cacheMaxCapacity), txx: newFIFOCache(cacheMaxCapacity), messages: make(chan Payload, 100), @@ -379,11 +374,6 @@ func (s *service) payloadFromExtensible(ep *npayload.Extensible) *Payload { // OnPayload handles Payload receive. func (s *service) OnPayload(cp *npayload.Extensible) { log := s.log.With(zap.Stringer("hash", cp.Hash())) - if s.cache.Has(cp.Hash()) { - log.Debug("payload is already in cache") - return - } - p := s.payloadFromExtensible(cp) p.decodeData() if !s.validatePayload(p) { @@ -391,9 +381,6 @@ func (s *service) OnPayload(cp *npayload.Extensible) { return } - s.Config.Broadcast(cp) - s.cache.Add(cp) - if s.dbft == nil || !s.started.Load() { log.Debug("dbft is inactive or not started yet") return @@ -416,25 +403,12 @@ func (s *service) OnTransaction(tx *transaction.Transaction) { } } -// GetPayload returns payload stored in cache. -func (s *service) GetPayload(h util.Uint256) *npayload.Extensible { - p := s.cache.Get(h) - if p == nil { - return (*npayload.Extensible)(nil) - } - - cp := *p.(*npayload.Extensible) - - return &cp -} - func (s *service) broadcast(p payload.ConsensusPayload) { if err := p.(*Payload).Sign(s.dbft.Priv.(*privateKey)); err != nil { s.log.Warn("can't sign consensus payload", zap.Error(err)) } ep := &p.(*Payload).Extensible - s.cache.Add(ep) s.Config.Broadcast(ep) } diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 3f8737dc0..541033b4a 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -351,7 +351,6 @@ func TestService_OnPayload(t *testing.T) { // sender is invalid srv.OnPayload(&p.Extensible) shouldNotReceive(t, srv.messages) - require.Nil(t, srv.GetPayload(p.Hash())) p = new(Payload) p.SetValidatorIndex(1) @@ -360,11 +359,6 @@ func TestService_OnPayload(t *testing.T) { require.NoError(t, p.Sign(priv)) srv.OnPayload(&p.Extensible) shouldReceive(t, srv.messages) - require.Equal(t, &p.Extensible, srv.GetPayload(p.Hash())) - - // payload has already been received - srv.OnPayload(&p.Extensible) - shouldNotReceive(t, srv.messages) srv.Chain.Close() } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index d179f8454..0b7f4e336 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -131,6 +131,8 @@ type Blockchain struct { contracts native.Contracts + extensible atomic.Value + // Notification subsystem. events chan bcEvent subCh chan interface{} @@ -297,7 +299,7 @@ func (bc *Blockchain) init() error { return fmt.Errorf("can't init cache for Management native contract: %w", err) } - return nil + return bc.updateExtensibleWhitelist(bHeight) } // Run runs chain loop, it needs to be run as goroutine and executing it is @@ -759,6 +761,10 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error for _, f := range bc.postBlock { f(bc, txpool, block) } + if err := bc.updateExtensibleWhitelist(block.Index); err != nil { + bc.lock.Unlock() + return err + } bc.lock.Unlock() updateBlockHeightMetric(block.Index) @@ -771,6 +777,68 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error return nil } +func (bc *Blockchain) updateExtensibleWhitelist(height uint32) error { + updateCommittee := native.ShouldUpdateCommittee(height, bc) + oracles, oh, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, native.RoleOracle, height) + if err != nil { + return err + } + stateVals, sh, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, native.RoleStateValidator, height) + if err != nil { + return err + } + + if bc.extensible.Load() != nil && !updateCommittee && oh != height && sh != height { + return nil + } + + newList := []util.Uint160{bc.contracts.NEO.GetCommitteeAddress()} + nextVals := bc.contracts.NEO.GetNextBlockValidatorsInternal() + script, err := smartcontract.CreateDefaultMultiSigRedeemScript(nextVals) + if err != nil { + return err + } + newList = append(newList, hash.Hash160(script)) + bc.updateExtensibleList(&newList, bc.contracts.NEO.GetNextBlockValidatorsInternal()) + + if len(oracles) > 0 { + h, err := bc.contracts.Designate.GetLastDesignatedHash(bc.dao, native.RoleOracle) + if err != nil { + return err + } + newList = append(newList, h) + bc.updateExtensibleList(&newList, oracles) + } + + if len(stateVals) > 0 { + h, err := bc.contracts.Designate.GetLastDesignatedHash(bc.dao, native.RoleStateValidator) + if err != nil { + return err + } + newList = append(newList, h) + bc.updateExtensibleList(&newList, stateVals) + } + + sort.Slice(newList, func(i, j int) bool { + return newList[i].Less(newList[j]) + }) + bc.extensible.Store(newList) + return nil +} + +func (bc *Blockchain) updateExtensibleList(s *[]util.Uint160, pubs keys.PublicKeys) { + for _, pub := range pubs { + *s = append(*s, pub.GetScriptHash()) + } +} + +// IsExtensibleAllowed determines if script hash is allowed to send extensible payloads. +func (bc *Blockchain) IsExtensibleAllowed(u util.Uint160) bool { + us := bc.extensible.Load().([]util.Uint160) + n := sort.Search(len(us), func(i int) bool { return !us[i].Less(u) }) + return n < len(us) +} + func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache *dao.Cached, trig trigger.Type) (*state.AppExecResult, error) { systemInterop := bc.newInteropContext(trig, cache, block, nil) v := systemInterop.SpawnVM() diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index b3b7cc8db..99a151c11 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -41,6 +41,7 @@ type Blockchainer interface { CurrentBlockHash() util.Uint256 HasBlock(util.Uint256) bool HasTransaction(util.Uint256) bool + IsExtensibleAllowed(util.Uint160) bool GetAppExecResults(util.Uint256, trigger.Type) ([]state.AppExecResult, error) GetNotaryDepositExpiration(acc util.Uint160) uint32 GetNativeContractScriptHash(string) (util.Uint160, error) diff --git a/pkg/core/native/designate.go b/pkg/core/native/designate.go index f4190a9de..d507fccec 100644 --- a/pkg/core/native/designate.go +++ b/pkg/core/native/designate.go @@ -207,7 +207,8 @@ func (s *Designate) getCachedRoleData(r Role) *roleData { return nil } -func (s *Designate) getLastDesignatedHash(d dao.DAO, r Role) (util.Uint160, error) { +// GetLastDesignatedHash returns last designated hash of a given role. +func (s *Designate) GetLastDesignatedHash(d dao.DAO, r Role) (util.Uint160, error) { if !s.isValidRole(r) { return util.Uint160{}, ErrInvalidRole } diff --git a/pkg/core/native/oracle.go b/pkg/core/native/oracle.go index 1c43623ec..ed775d7d4 100644 --- a/pkg/core/native/oracle.go +++ b/pkg/core/native/oracle.go @@ -377,7 +377,7 @@ func (o *Oracle) PutRequestInternal(id uint64, req *state.OracleRequest, d dao.D // GetScriptHash returns script hash or oracle nodes. func (o *Oracle) GetScriptHash(d dao.DAO) (util.Uint160, error) { - return o.Desig.getLastDesignatedHash(d, RoleOracle) + return o.Desig.GetLastDesignatedHash(d, RoleOracle) } // GetOracleNodes returns public keys of oracle nodes. diff --git a/pkg/network/extpool/pool.go b/pkg/network/extpool/pool.go new file mode 100644 index 000000000..ef9bba218 --- /dev/null +++ b/pkg/network/extpool/pool.go @@ -0,0 +1,93 @@ +package extpool + +import ( + "errors" + "sync" + + "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// Pool represents pool of extensible payloads. +type Pool struct { + lock sync.RWMutex + verified map[util.Uint256]*payload.Extensible + chain blockchainer.Blockchainer +} + +// New returns new payload pool using provided chain. +func New(bc blockchainer.Blockchainer) *Pool { + return &Pool{ + verified: make(map[util.Uint256]*payload.Extensible), + chain: bc, + } +} + +var ( + errDisallowedSender = errors.New("disallowed sender") + errInvalidHeight = errors.New("invalid height") +) + +// Add adds extensible payload to the pool. +// First return value specifies if payload was new. +// Second one is nil if and only if payload is valid. +func (p *Pool) Add(e *payload.Extensible) (bool, error) { + if ok, err := p.verify(e); err != nil || !ok { + return ok, err + } + + p.lock.Lock() + defer p.lock.Unlock() + + h := e.Hash() + if _, ok := p.verified[h]; ok { + return false, nil + } + p.verified[h] = e + return true, nil +} + +func (p *Pool) verify(e *payload.Extensible) (bool, error) { + if err := p.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil { + return false, err + } + h := p.chain.BlockHeight() + if h < e.ValidBlockStart || e.ValidBlockEnd <= h { + // We can receive consensus payload for the last or next block + // which leads to unwanted node disconnect. + if e.ValidBlockEnd == h { + return false, nil + } + return false, errInvalidHeight + } + if !p.chain.IsExtensibleAllowed(e.Sender) { + return false, errDisallowedSender + } + return true, nil +} + +// Get returns payload by hash. +func (p *Pool) Get(h util.Uint256) *payload.Extensible { + p.lock.RLock() + defer p.lock.RUnlock() + + return p.verified[h] +} + +const extensibleVerifyMaxGAS = 2000000 + +// RemoveStale removes invalid payloads after block processing. +func (p *Pool) RemoveStale(index uint32) { + p.lock.Lock() + defer p.lock.Unlock() + for h, e := range p.verified { + if e.ValidBlockEnd <= index || !p.chain.IsExtensibleAllowed(e.Sender) { + delete(p.verified, h) + continue + } + if err := p.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil { + delete(p.verified, h) + } + } +} diff --git a/pkg/network/extpool/pool_test.go b/pkg/network/extpool/pool_test.go new file mode 100644 index 000000000..5757b5310 --- /dev/null +++ b/pkg/network/extpool/pool_test.go @@ -0,0 +1,105 @@ +package extpool + +import ( + "errors" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto" + "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestAddGet(t *testing.T) { + bc := newTestChain() + bc.height = 10 + + p := New(bc) + t.Run("invalid witness", func(t *testing.T) { + ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}} + p.testAdd(t, false, errVerification, ep) + }) + t.Run("disallowed sender", func(t *testing.T) { + ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}} + p.testAdd(t, false, errDisallowedSender, ep) + }) + t.Run("bad height", func(t *testing.T) { + ep := &payload.Extensible{ValidBlockEnd: 9} + p.testAdd(t, false, errInvalidHeight, ep) + + ep = &payload.Extensible{ValidBlockEnd: 10} + p.testAdd(t, false, nil, ep) + }) + t.Run("good", func(t *testing.T) { + ep := &payload.Extensible{ValidBlockEnd: 100} + p.testAdd(t, true, nil, ep) + require.Equal(t, ep, p.Get(ep.Hash())) + + p.testAdd(t, false, nil, ep) + }) +} + +func TestRemoveStale(t *testing.T) { + bc := newTestChain() + bc.height = 10 + + p := New(bc) + eps := []*payload.Extensible{ + {ValidBlockEnd: 11}, // small height + {ValidBlockEnd: 12}, // good + {Sender: util.Uint160{0x11}, ValidBlockEnd: 12}, // invalid sender + {Sender: util.Uint160{0x12}, ValidBlockEnd: 12}, // invalid witness + } + for i := range eps { + p.testAdd(t, true, nil, eps[i]) + } + bc.verifyWitness = func(u util.Uint160) bool { println("call"); return u[0] != 0x12 } + bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 } + p.RemoveStale(11) + require.Nil(t, p.Get(eps[0].Hash())) + require.Equal(t, eps[1], p.Get(eps[1].Hash())) + require.Nil(t, p.Get(eps[2].Hash())) + require.Nil(t, p.Get(eps[3].Hash())) +} + +func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) { + ok, err := p.Add(ep) + if expectedErr != nil { + require.True(t, errors.Is(err, expectedErr), "got: %v", err) + } else { + require.NoError(t, err) + } + require.Equal(t, expectedOk, ok) +} + +type testChain struct { + blockchainer.Blockchainer + height uint32 + verifyWitness func(util.Uint160) bool + isAllowed func(util.Uint160) bool +} + +var errVerification = errors.New("verification failed") + +func newTestChain() *testChain { + return &testChain{ + verifyWitness: func(u util.Uint160) bool { + return u[0] != 0x42 + }, + isAllowed: func(u util.Uint160) bool { + return u[0] != 0x42 && u[0] != 0x41 + }, + } +} +func (c *testChain) VerifyWitness(u util.Uint160, _ crypto.Verifiable, _ *transaction.Witness, _ int64) error { + if !c.verifyWitness(u) { + return errVerification + } + return nil +} +func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool { + return c.isAllowed(u) +} +func (c *testChain) BlockHeight() uint32 { return c.height } diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 9fa810334..473bb0e58 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -82,6 +82,9 @@ func (chain *testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transa func (chain *testChain) IsTxStillRelevant(t *transaction.Transaction, txpool *mempool.Pool, isPartialTx bool) bool { panic("TODO") } +func (*testChain) IsExtensibleAllowed(uint160 util.Uint160) bool { + return true +} func (chain *testChain) GetNotaryDepositExpiration(acc util.Uint160) uint32 { if chain.notaryDepositExpiration != 0 { diff --git a/pkg/network/server.go b/pkg/network/server.go index d6616a404..7e4a10946 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -19,6 +19,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/capability" + "github.com/nspcc-dev/neo-go/pkg/network/extpool" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/services/oracle" "github.com/nspcc-dev/neo-go/pkg/util" @@ -67,6 +68,7 @@ type ( bQueue *blockQueue consensus consensus.Service notaryRequestPool *mempool.Pool + extensiblePool *extpool.Pool NotaryFeer NotaryFeer lock sync.RWMutex @@ -127,6 +129,7 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai unregister: make(chan peerDrop), peers: make(map[Peer]bool), consensusStarted: atomic.NewBool(false), + extensiblePool: extpool.New(chain), log: log, transactions: make(chan *transaction.Transaction, 64), } @@ -574,7 +577,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { payload.TXType: s.chain.HasTransaction, payload.BlockType: s.chain.HasBlock, payload.ExtensibleType: func(h util.Uint256) bool { - cp := s.consensus.GetPayload(h) + cp := s.extensiblePool.Get(h) return cp != nil }, payload.P2PNotaryRequestType: func(h util.Uint256) bool { @@ -643,7 +646,7 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { notFound = append(notFound, hash) } case payload.ExtensibleType: - if cp := s.consensus.GetPayload(hash); cp != nil { + if cp := s.extensiblePool.Get(hash); cp != nil { msg = NewMessage(CMDExtensible, cp) } case payload.P2PNotaryRequestType: @@ -752,29 +755,28 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error return p.EnqueueP2PMessage(msg) } -const extensibleVerifyMaxGAS = 2000000 - // handleExtensibleCmd processes received extensible payload. func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { - if err := s.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil { + ok, err := s.extensiblePool.Add(e) + if err != nil { return err } - h := s.chain.BlockHeight() - if h < e.ValidBlockStart || e.ValidBlockEnd <= h { - // We can receive consensus payload for the last or next block - // which leads to unwanted node disconnect. - if e.ValidBlockEnd == h { - return nil - } - return errors.New("invalid height") + if !ok { // payload is already in cache + return nil } - switch e.Category { case consensus.Category: s.consensus.OnPayload(e) default: return errors.New("invalid category") } + + msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{e.Hash()})) + if e.Category == consensus.Category { + s.broadcastHPMessage(msg) + } else { + s.broadcastMessage(msg) + } return nil } @@ -990,6 +992,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { } func (s *Server) handleNewPayload(p *payload.Extensible) { + _, err := s.extensiblePool.Add(p) + if err != nil { + s.log.Error("created payload is not valid", zap.Error(err)) + return + } + msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{p.Hash()})) // It's high priority because it directly affects consensus process, // even though it's just an inv. @@ -1100,6 +1108,7 @@ func (s *Server) relayBlocksLoop() { s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { return p.Handshaked() && p.LastBlockIndex() < b.Index }) + s.extensiblePool.RemoveStale(b.Index) } } } diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 365c4fe71..04cf91ca7 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -714,6 +714,18 @@ func TestInv(t *testing.T) { }) require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual) }) + t.Run("extensible", func(t *testing.T) { + ep := payload.NewExtensible(netmode.UnitTestNet) + s.chain.(*testChain).verifyWitnessF = func() error { return nil } + ep.ValidBlockEnd = s.chain.(*testChain).BlockHeight() + 1 + ok, err := s.extensiblePool.Add(ep) + require.NoError(t, err) + require.True(t, ok) + s.testHandleMessage(t, p, CMDInv, &payload.Inventory{ + Type: payload.ExtensibleType, + Hashes: []util.Uint256{ep.Hash()}, + }) + }) t.Run("p2pNotaryRequest", func(t *testing.T) { fallbackTx := transaction.New(netmode.UnitTestNet, random.Bytes(100), 123) fallbackTx.Signers = []transaction.Signer{{Account: random.Uint160()}, {Account: random.Uint160()}}