Merge pull request #1674 from nspcc-dev/extensible_pool

Add pool for `Extensible` payloads
This commit is contained in:
Roman Khimov 2021-01-28 21:01:05 +03:00 committed by GitHub
commit a6921ceecb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 379 additions and 76 deletions

View file

@ -55,16 +55,12 @@ type Service interface {
OnPayload(p *npayload.Extensible) OnPayload(p *npayload.Extensible)
// OnTransaction is a callback to notify Service about new received transaction. // OnTransaction is a callback to notify Service about new received transaction.
OnTransaction(tx *transaction.Transaction) OnTransaction(tx *transaction.Transaction)
// GetPayload returns Payload with specified hash if it is present in the local cache.
GetPayload(h util.Uint256) *npayload.Extensible
} }
type service struct { type service struct {
Config Config
log *zap.Logger log *zap.Logger
// cache is a fifo cache which stores recent payloads.
cache *relayCache
// txx is a fifo cache which stores miner transactions. // txx is a fifo cache which stores miner transactions.
txx *relayCache txx *relayCache
dbft *dbft.DBFT dbft *dbft.DBFT
@ -124,7 +120,6 @@ func NewService(cfg Config) (Service, error) {
Config: cfg, Config: cfg,
log: cfg.Logger, log: cfg.Logger,
cache: newFIFOCache(cacheMaxCapacity),
txx: newFIFOCache(cacheMaxCapacity), txx: newFIFOCache(cacheMaxCapacity),
messages: make(chan Payload, 100), messages: make(chan Payload, 100),
@ -379,11 +374,6 @@ func (s *service) payloadFromExtensible(ep *npayload.Extensible) *Payload {
// OnPayload handles Payload receive. // OnPayload handles Payload receive.
func (s *service) OnPayload(cp *npayload.Extensible) { func (s *service) OnPayload(cp *npayload.Extensible) {
log := s.log.With(zap.Stringer("hash", cp.Hash())) log := s.log.With(zap.Stringer("hash", cp.Hash()))
if s.cache.Has(cp.Hash()) {
log.Debug("payload is already in cache")
return
}
p := s.payloadFromExtensible(cp) p := s.payloadFromExtensible(cp)
p.decodeData() p.decodeData()
if !s.validatePayload(p) { if !s.validatePayload(p) {
@ -391,9 +381,6 @@ func (s *service) OnPayload(cp *npayload.Extensible) {
return return
} }
s.Config.Broadcast(cp)
s.cache.Add(cp)
if s.dbft == nil || !s.started.Load() { if s.dbft == nil || !s.started.Load() {
log.Debug("dbft is inactive or not started yet") log.Debug("dbft is inactive or not started yet")
return return
@ -416,25 +403,12 @@ func (s *service) OnTransaction(tx *transaction.Transaction) {
} }
} }
// GetPayload returns payload stored in cache.
func (s *service) GetPayload(h util.Uint256) *npayload.Extensible {
p := s.cache.Get(h)
if p == nil {
return (*npayload.Extensible)(nil)
}
cp := *p.(*npayload.Extensible)
return &cp
}
func (s *service) broadcast(p payload.ConsensusPayload) { func (s *service) broadcast(p payload.ConsensusPayload) {
if err := p.(*Payload).Sign(s.dbft.Priv.(*privateKey)); err != nil { if err := p.(*Payload).Sign(s.dbft.Priv.(*privateKey)); err != nil {
s.log.Warn("can't sign consensus payload", zap.Error(err)) s.log.Warn("can't sign consensus payload", zap.Error(err))
} }
ep := &p.(*Payload).Extensible ep := &p.(*Payload).Extensible
s.cache.Add(ep)
s.Config.Broadcast(ep) s.Config.Broadcast(ep)
} }

View file

@ -351,7 +351,6 @@ func TestService_OnPayload(t *testing.T) {
// sender is invalid // sender is invalid
srv.OnPayload(&p.Extensible) srv.OnPayload(&p.Extensible)
shouldNotReceive(t, srv.messages) shouldNotReceive(t, srv.messages)
require.Nil(t, srv.GetPayload(p.Hash()))
p = new(Payload) p = new(Payload)
p.SetValidatorIndex(1) p.SetValidatorIndex(1)
@ -360,11 +359,6 @@ func TestService_OnPayload(t *testing.T) {
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
srv.OnPayload(&p.Extensible) srv.OnPayload(&p.Extensible)
shouldReceive(t, srv.messages) shouldReceive(t, srv.messages)
require.Equal(t, &p.Extensible, srv.GetPayload(p.Hash()))
// payload has already been received
srv.OnPayload(&p.Extensible)
shouldNotReceive(t, srv.messages)
srv.Chain.Close() srv.Chain.Close()
} }

View file

@ -131,6 +131,8 @@ type Blockchain struct {
contracts native.Contracts contracts native.Contracts
extensible atomic.Value
// Notification subsystem. // Notification subsystem.
events chan bcEvent events chan bcEvent
subCh chan interface{} subCh chan interface{}
@ -297,7 +299,7 @@ func (bc *Blockchain) init() error {
return fmt.Errorf("can't init cache for Management native contract: %w", err) return fmt.Errorf("can't init cache for Management native contract: %w", err)
} }
return nil return bc.updateExtensibleWhitelist(bHeight)
} }
// Run runs chain loop, it needs to be run as goroutine and executing it is // Run runs chain loop, it needs to be run as goroutine and executing it is
@ -759,6 +761,10 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
for _, f := range bc.postBlock { for _, f := range bc.postBlock {
f(bc, txpool, block) f(bc, txpool, block)
} }
if err := bc.updateExtensibleWhitelist(block.Index); err != nil {
bc.lock.Unlock()
return err
}
bc.lock.Unlock() bc.lock.Unlock()
updateBlockHeightMetric(block.Index) updateBlockHeightMetric(block.Index)
@ -771,6 +777,68 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
return nil return nil
} }
func (bc *Blockchain) updateExtensibleWhitelist(height uint32) error {
updateCommittee := native.ShouldUpdateCommittee(height, bc)
oracles, oh, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, native.RoleOracle, height)
if err != nil {
return err
}
stateVals, sh, err := bc.contracts.Designate.GetDesignatedByRole(bc.dao, native.RoleStateValidator, height)
if err != nil {
return err
}
if bc.extensible.Load() != nil && !updateCommittee && oh != height && sh != height {
return nil
}
newList := []util.Uint160{bc.contracts.NEO.GetCommitteeAddress()}
nextVals := bc.contracts.NEO.GetNextBlockValidatorsInternal()
script, err := smartcontract.CreateDefaultMultiSigRedeemScript(nextVals)
if err != nil {
return err
}
newList = append(newList, hash.Hash160(script))
bc.updateExtensibleList(&newList, bc.contracts.NEO.GetNextBlockValidatorsInternal())
if len(oracles) > 0 {
h, err := bc.contracts.Designate.GetLastDesignatedHash(bc.dao, native.RoleOracle)
if err != nil {
return err
}
newList = append(newList, h)
bc.updateExtensibleList(&newList, oracles)
}
if len(stateVals) > 0 {
h, err := bc.contracts.Designate.GetLastDesignatedHash(bc.dao, native.RoleStateValidator)
if err != nil {
return err
}
newList = append(newList, h)
bc.updateExtensibleList(&newList, stateVals)
}
sort.Slice(newList, func(i, j int) bool {
return newList[i].Less(newList[j])
})
bc.extensible.Store(newList)
return nil
}
func (bc *Blockchain) updateExtensibleList(s *[]util.Uint160, pubs keys.PublicKeys) {
for _, pub := range pubs {
*s = append(*s, pub.GetScriptHash())
}
}
// IsExtensibleAllowed determines if script hash is allowed to send extensible payloads.
func (bc *Blockchain) IsExtensibleAllowed(u util.Uint160) bool {
us := bc.extensible.Load().([]util.Uint160)
n := sort.Search(len(us), func(i int) bool { return !us[i].Less(u) })
return n < len(us)
}
func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache *dao.Cached, trig trigger.Type) (*state.AppExecResult, error) { func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache *dao.Cached, trig trigger.Type) (*state.AppExecResult, error) {
systemInterop := bc.newInteropContext(trig, cache, block, nil) systemInterop := bc.newInteropContext(trig, cache, block, nil)
v := systemInterop.SpawnVM() v := systemInterop.SpawnVM()

View file

@ -41,6 +41,7 @@ type Blockchainer interface {
CurrentBlockHash() util.Uint256 CurrentBlockHash() util.Uint256
HasBlock(util.Uint256) bool HasBlock(util.Uint256) bool
HasTransaction(util.Uint256) bool HasTransaction(util.Uint256) bool
IsExtensibleAllowed(util.Uint160) bool
GetAppExecResults(util.Uint256, trigger.Type) ([]state.AppExecResult, error) GetAppExecResults(util.Uint256, trigger.Type) ([]state.AppExecResult, error)
GetNotaryDepositExpiration(acc util.Uint160) uint32 GetNotaryDepositExpiration(acc util.Uint160) uint32
GetNativeContractScriptHash(string) (util.Uint160, error) GetNativeContractScriptHash(string) (util.Uint160, error)

View file

@ -30,6 +30,8 @@ type Designate struct {
rolesChangedFlag atomic.Value rolesChangedFlag atomic.Value
oracles atomic.Value oracles atomic.Value
stateVals atomic.Value
notaries atomic.Value
// p2pSigExtensionsEnabled defines whether the P2P signature extensions logic is relevant. // p2pSigExtensionsEnabled defines whether the P2P signature extensions logic is relevant.
p2pSigExtensionsEnabled bool p2pSigExtensionsEnabled bool
@ -37,7 +39,7 @@ type Designate struct {
OracleService atomic.Value OracleService atomic.Value
} }
type oraclesData struct { type roleData struct {
nodes keys.PublicKeys nodes keys.PublicKeys
addr util.Uint160 addr util.Uint160
height uint32 height uint32
@ -109,20 +111,18 @@ func (s *Designate) PostPersist(ic *interop.Context) error {
return nil return nil
} }
nodeKeys, height, err := s.GetDesignatedByRole(ic.DAO, RoleOracle, math.MaxUint32) if err := s.updateCachedRoleData(&s.oracles, ic.DAO, RoleOracle); err != nil {
if err != nil {
return err return err
} }
if err := s.updateCachedRoleData(&s.stateVals, ic.DAO, RoleStateValidator); err != nil {
return err
}
if s.p2pSigExtensionsEnabled {
if err := s.updateCachedRoleData(&s.notaries, ic.DAO, RoleP2PNotary); err != nil {
return err
}
}
od := &oraclesData{
nodes: nodeKeys,
addr: oracleHashFromNodes(nodeKeys),
height: height,
}
s.oracles.Store(od)
if orc, _ := s.OracleService.Load().(services.Oracle); orc != nil {
orc.UpdateOracleNodes(od.nodes.Copy())
}
s.rolesChangedFlag.Store(false) s.rolesChangedFlag.Store(false)
return nil return nil
} }
@ -157,23 +157,64 @@ func (s *Designate) rolesChanged() bool {
return rc == nil || rc.(bool) return rc == nil || rc.(bool)
} }
func oracleHashFromNodes(nodes keys.PublicKeys) util.Uint160 { func (s *Designate) hashFromNodes(r Role, nodes keys.PublicKeys) util.Uint160 {
if len(nodes) == 0 { if len(nodes) == 0 {
return util.Uint160{} return util.Uint160{}
} }
script, _ := smartcontract.CreateMajorityMultiSigRedeemScript(nodes.Copy()) var script []byte
switch r {
case RoleOracle:
script, _ = smartcontract.CreateDefaultMultiSigRedeemScript(nodes.Copy())
case RoleP2PNotary:
script, _ = smartcontract.CreateMultiSigRedeemScript(1, nodes.Copy())
default:
script, _ = smartcontract.CreateMajorityMultiSigRedeemScript(nodes.Copy())
}
return hash.Hash160(script) return hash.Hash160(script)
} }
func (s *Designate) getLastDesignatedHash(d dao.DAO, r Role) (util.Uint160, error) { func (s *Designate) updateCachedRoleData(v *atomic.Value, d dao.DAO, r Role) error {
nodeKeys, height, err := s.GetDesignatedByRole(d, r, math.MaxUint32)
if err != nil {
return err
}
v.Store(&roleData{
nodes: nodeKeys,
addr: s.hashFromNodes(r, nodeKeys),
height: height,
})
if r == RoleOracle {
if orc, _ := s.OracleService.Load().(services.Oracle); orc != nil {
orc.UpdateOracleNodes(nodeKeys.Copy())
}
}
return nil
}
func (s *Designate) getCachedRoleData(r Role) *roleData {
var val interface{}
switch r {
case RoleOracle:
val = s.oracles.Load()
case RoleStateValidator:
val = s.stateVals.Load()
case RoleP2PNotary:
val = s.notaries.Load()
}
if val != nil {
return val.(*roleData)
}
return nil
}
// GetLastDesignatedHash returns last designated hash of a given role.
func (s *Designate) GetLastDesignatedHash(d dao.DAO, r Role) (util.Uint160, error) {
if !s.isValidRole(r) { if !s.isValidRole(r) {
return util.Uint160{}, ErrInvalidRole return util.Uint160{}, ErrInvalidRole
} }
if r == RoleOracle && !s.rolesChanged() { if !s.rolesChanged() {
odVal := s.oracles.Load() if val := s.getCachedRoleData(r); val != nil {
if odVal != nil { return val.addr, nil
od := odVal.(*oraclesData)
return od.addr, nil
} }
} }
nodes, _, err := s.GetDesignatedByRole(d, r, math.MaxUint32) nodes, _, err := s.GetDesignatedByRole(d, r, math.MaxUint32)
@ -181,7 +222,7 @@ func (s *Designate) getLastDesignatedHash(d dao.DAO, r Role) (util.Uint160, erro
return util.Uint160{}, err return util.Uint160{}, err
} }
// We only have hashing defined for oracles now. // We only have hashing defined for oracles now.
return oracleHashFromNodes(nodes), nil return s.hashFromNodes(r, nodes), nil
} }
// GetDesignatedByRole returns nodes for role r. // GetDesignatedByRole returns nodes for role r.
@ -189,13 +230,9 @@ func (s *Designate) GetDesignatedByRole(d dao.DAO, r Role, index uint32) (keys.P
if !s.isValidRole(r) { if !s.isValidRole(r) {
return nil, 0, ErrInvalidRole return nil, 0, ErrInvalidRole
} }
if r == RoleOracle && !s.rolesChanged() { if !s.rolesChanged() {
odVal := s.oracles.Load() if val := s.getCachedRoleData(r); val != nil && val.height <= index {
if odVal != nil { return val.nodes.Copy(), val.height, nil
od := odVal.(*oraclesData)
if od.height <= index {
return od.nodes, od.height, nil
}
} }
} }
kvs, err := d.GetStorageItemsWithPrefix(s.ContractID, []byte{byte(r)}) kvs, err := d.GetStorageItemsWithPrefix(s.ContractID, []byte{byte(r)})

View file

@ -377,7 +377,7 @@ func (o *Oracle) PutRequestInternal(id uint64, req *state.OracleRequest, d dao.D
// GetScriptHash returns script hash or oracle nodes. // GetScriptHash returns script hash or oracle nodes.
func (o *Oracle) GetScriptHash(d dao.DAO) (util.Uint160, error) { func (o *Oracle) GetScriptHash(d dao.DAO) (util.Uint160, error) {
return o.Desig.getLastDesignatedHash(d, RoleOracle) return o.Desig.GetLastDesignatedHash(d, RoleOracle)
} }
// GetOracleNodes returns public keys of oracle nodes. // GetOracleNodes returns public keys of oracle nodes.

View file

@ -93,6 +93,13 @@ func TestDesignate_DesignateAsRoleTx(t *testing.T) {
bc.getNodesByRole(t, false, native.RoleOracle, 100500, 0) bc.getNodesByRole(t, false, native.RoleOracle, 100500, 0)
bc.getNodesByRole(t, true, native.RoleOracle, 0, 0) // returns an empty list bc.getNodesByRole(t, true, native.RoleOracle, 0, 0) // returns an empty list
bc.getNodesByRole(t, true, native.RoleOracle, index, 1) // returns pubs bc.getNodesByRole(t, true, native.RoleOracle, index, 1) // returns pubs
priv1, err := keys.NewPrivateKey()
require.NoError(t, err)
pubs = keys.PublicKeys{priv1.PublicKey()}
bc.setNodesByRole(t, true, native.RoleStateValidator, pubs)
bc.getNodesByRole(t, true, native.RoleStateValidator, bc.BlockHeight()+1, 1)
} }
func TestDesignate_DesignateAsRole(t *testing.T) { func TestDesignate_DesignateAsRole(t *testing.T) {

View file

@ -0,0 +1,93 @@
package extpool
import (
"errors"
"sync"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
)
// Pool represents pool of extensible payloads.
type Pool struct {
lock sync.RWMutex
verified map[util.Uint256]*payload.Extensible
chain blockchainer.Blockchainer
}
// New returns new payload pool using provided chain.
func New(bc blockchainer.Blockchainer) *Pool {
return &Pool{
verified: make(map[util.Uint256]*payload.Extensible),
chain: bc,
}
}
var (
errDisallowedSender = errors.New("disallowed sender")
errInvalidHeight = errors.New("invalid height")
)
// Add adds extensible payload to the pool.
// First return value specifies if payload was new.
// Second one is nil if and only if payload is valid.
func (p *Pool) Add(e *payload.Extensible) (bool, error) {
if ok, err := p.verify(e); err != nil || !ok {
return ok, err
}
p.lock.Lock()
defer p.lock.Unlock()
h := e.Hash()
if _, ok := p.verified[h]; ok {
return false, nil
}
p.verified[h] = e
return true, nil
}
func (p *Pool) verify(e *payload.Extensible) (bool, error) {
if err := p.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil {
return false, err
}
h := p.chain.BlockHeight()
if h < e.ValidBlockStart || e.ValidBlockEnd <= h {
// We can receive consensus payload for the last or next block
// which leads to unwanted node disconnect.
if e.ValidBlockEnd == h {
return false, nil
}
return false, errInvalidHeight
}
if !p.chain.IsExtensibleAllowed(e.Sender) {
return false, errDisallowedSender
}
return true, nil
}
// Get returns payload by hash.
func (p *Pool) Get(h util.Uint256) *payload.Extensible {
p.lock.RLock()
defer p.lock.RUnlock()
return p.verified[h]
}
const extensibleVerifyMaxGAS = 2000000
// RemoveStale removes invalid payloads after block processing.
func (p *Pool) RemoveStale(index uint32) {
p.lock.Lock()
defer p.lock.Unlock()
for h, e := range p.verified {
if e.ValidBlockEnd <= index || !p.chain.IsExtensibleAllowed(e.Sender) {
delete(p.verified, h)
continue
}
if err := p.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil {
delete(p.verified, h)
}
}
}

View file

@ -0,0 +1,105 @@
package extpool
import (
"errors"
"testing"
"github.com/nspcc-dev/neo-go/pkg/core/blockchainer"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestAddGet(t *testing.T) {
bc := newTestChain()
bc.height = 10
p := New(bc)
t.Run("invalid witness", func(t *testing.T) {
ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x42}}
p.testAdd(t, false, errVerification, ep)
})
t.Run("disallowed sender", func(t *testing.T) {
ep := &payload.Extensible{ValidBlockEnd: 100, Sender: util.Uint160{0x41}}
p.testAdd(t, false, errDisallowedSender, ep)
})
t.Run("bad height", func(t *testing.T) {
ep := &payload.Extensible{ValidBlockEnd: 9}
p.testAdd(t, false, errInvalidHeight, ep)
ep = &payload.Extensible{ValidBlockEnd: 10}
p.testAdd(t, false, nil, ep)
})
t.Run("good", func(t *testing.T) {
ep := &payload.Extensible{ValidBlockEnd: 100}
p.testAdd(t, true, nil, ep)
require.Equal(t, ep, p.Get(ep.Hash()))
p.testAdd(t, false, nil, ep)
})
}
func TestRemoveStale(t *testing.T) {
bc := newTestChain()
bc.height = 10
p := New(bc)
eps := []*payload.Extensible{
{ValidBlockEnd: 11}, // small height
{ValidBlockEnd: 12}, // good
{Sender: util.Uint160{0x11}, ValidBlockEnd: 12}, // invalid sender
{Sender: util.Uint160{0x12}, ValidBlockEnd: 12}, // invalid witness
}
for i := range eps {
p.testAdd(t, true, nil, eps[i])
}
bc.verifyWitness = func(u util.Uint160) bool { println("call"); return u[0] != 0x12 }
bc.isAllowed = func(u util.Uint160) bool { return u[0] != 0x11 }
p.RemoveStale(11)
require.Nil(t, p.Get(eps[0].Hash()))
require.Equal(t, eps[1], p.Get(eps[1].Hash()))
require.Nil(t, p.Get(eps[2].Hash()))
require.Nil(t, p.Get(eps[3].Hash()))
}
func (p *Pool) testAdd(t *testing.T, expectedOk bool, expectedErr error, ep *payload.Extensible) {
ok, err := p.Add(ep)
if expectedErr != nil {
require.True(t, errors.Is(err, expectedErr), "got: %v", err)
} else {
require.NoError(t, err)
}
require.Equal(t, expectedOk, ok)
}
type testChain struct {
blockchainer.Blockchainer
height uint32
verifyWitness func(util.Uint160) bool
isAllowed func(util.Uint160) bool
}
var errVerification = errors.New("verification failed")
func newTestChain() *testChain {
return &testChain{
verifyWitness: func(u util.Uint160) bool {
return u[0] != 0x42
},
isAllowed: func(u util.Uint160) bool {
return u[0] != 0x42 && u[0] != 0x41
},
}
}
func (c *testChain) VerifyWitness(u util.Uint160, _ crypto.Verifiable, _ *transaction.Witness, _ int64) error {
if !c.verifyWitness(u) {
return errVerification
}
return nil
}
func (c *testChain) IsExtensibleAllowed(u util.Uint160) bool {
return c.isAllowed(u)
}
func (c *testChain) BlockHeight() uint32 { return c.height }

View file

@ -82,6 +82,9 @@ func (chain *testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transa
func (chain *testChain) IsTxStillRelevant(t *transaction.Transaction, txpool *mempool.Pool, isPartialTx bool) bool { func (chain *testChain) IsTxStillRelevant(t *transaction.Transaction, txpool *mempool.Pool, isPartialTx bool) bool {
panic("TODO") panic("TODO")
} }
func (*testChain) IsExtensibleAllowed(uint160 util.Uint160) bool {
return true
}
func (chain *testChain) GetNotaryDepositExpiration(acc util.Uint160) uint32 { func (chain *testChain) GetNotaryDepositExpiration(acc util.Uint160) uint32 {
if chain.notaryDepositExpiration != 0 { if chain.notaryDepositExpiration != 0 {

View file

@ -19,6 +19,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/extpool"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/services/oracle" "github.com/nspcc-dev/neo-go/pkg/services/oracle"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -67,6 +68,7 @@ type (
bQueue *blockQueue bQueue *blockQueue
consensus consensus.Service consensus consensus.Service
notaryRequestPool *mempool.Pool notaryRequestPool *mempool.Pool
extensiblePool *extpool.Pool
NotaryFeer NotaryFeer NotaryFeer NotaryFeer
lock sync.RWMutex lock sync.RWMutex
@ -127,6 +129,7 @@ func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchai
unregister: make(chan peerDrop), unregister: make(chan peerDrop),
peers: make(map[Peer]bool), peers: make(map[Peer]bool),
consensusStarted: atomic.NewBool(false), consensusStarted: atomic.NewBool(false),
extensiblePool: extpool.New(chain),
log: log, log: log,
transactions: make(chan *transaction.Transaction, 64), transactions: make(chan *transaction.Transaction, 64),
} }
@ -574,7 +577,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
payload.TXType: s.chain.HasTransaction, payload.TXType: s.chain.HasTransaction,
payload.BlockType: s.chain.HasBlock, payload.BlockType: s.chain.HasBlock,
payload.ExtensibleType: func(h util.Uint256) bool { payload.ExtensibleType: func(h util.Uint256) bool {
cp := s.consensus.GetPayload(h) cp := s.extensiblePool.Get(h)
return cp != nil return cp != nil
}, },
payload.P2PNotaryRequestType: func(h util.Uint256) bool { payload.P2PNotaryRequestType: func(h util.Uint256) bool {
@ -643,7 +646,7 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
notFound = append(notFound, hash) notFound = append(notFound, hash)
} }
case payload.ExtensibleType: case payload.ExtensibleType:
if cp := s.consensus.GetPayload(hash); cp != nil { if cp := s.extensiblePool.Get(hash); cp != nil {
msg = NewMessage(CMDExtensible, cp) msg = NewMessage(CMDExtensible, cp)
} }
case payload.P2PNotaryRequestType: case payload.P2PNotaryRequestType:
@ -752,29 +755,28 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
return p.EnqueueP2PMessage(msg) return p.EnqueueP2PMessage(msg)
} }
const extensibleVerifyMaxGAS = 2000000
// handleExtensibleCmd processes received extensible payload. // handleExtensibleCmd processes received extensible payload.
func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
if err := s.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil { ok, err := s.extensiblePool.Add(e)
if err != nil {
return err return err
} }
h := s.chain.BlockHeight() if !ok { // payload is already in cache
if h < e.ValidBlockStart || e.ValidBlockEnd <= h { return nil
// We can receive consensus payload for the last or next block
// which leads to unwanted node disconnect.
if e.ValidBlockEnd == h {
return nil
}
return errors.New("invalid height")
} }
switch e.Category { switch e.Category {
case consensus.Category: case consensus.Category:
s.consensus.OnPayload(e) s.consensus.OnPayload(e)
default: default:
return errors.New("invalid category") return errors.New("invalid category")
} }
msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{e.Hash()}))
if e.Category == consensus.Category {
s.broadcastHPMessage(msg)
} else {
s.broadcastMessage(msg)
}
return nil return nil
} }
@ -990,6 +992,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
} }
func (s *Server) handleNewPayload(p *payload.Extensible) { func (s *Server) handleNewPayload(p *payload.Extensible) {
_, err := s.extensiblePool.Add(p)
if err != nil {
s.log.Error("created payload is not valid", zap.Error(err))
return
}
msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{p.Hash()})) msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{p.Hash()}))
// It's high priority because it directly affects consensus process, // It's high priority because it directly affects consensus process,
// even though it's just an inv. // even though it's just an inv.
@ -1100,6 +1108,7 @@ func (s *Server) relayBlocksLoop() {
s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool {
return p.Handshaked() && p.LastBlockIndex() < b.Index return p.Handshaked() && p.LastBlockIndex() < b.Index
}) })
s.extensiblePool.RemoveStale(b.Index)
} }
} }
} }

View file

@ -714,6 +714,18 @@ func TestInv(t *testing.T) {
}) })
require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual) require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual)
}) })
t.Run("extensible", func(t *testing.T) {
ep := payload.NewExtensible(netmode.UnitTestNet)
s.chain.(*testChain).verifyWitnessF = func() error { return nil }
ep.ValidBlockEnd = s.chain.(*testChain).BlockHeight() + 1
ok, err := s.extensiblePool.Add(ep)
require.NoError(t, err)
require.True(t, ok)
s.testHandleMessage(t, p, CMDInv, &payload.Inventory{
Type: payload.ExtensibleType,
Hashes: []util.Uint256{ep.Hash()},
})
})
t.Run("p2pNotaryRequest", func(t *testing.T) { t.Run("p2pNotaryRequest", func(t *testing.T) {
fallbackTx := transaction.New(netmode.UnitTestNet, random.Bytes(100), 123) fallbackTx := transaction.New(netmode.UnitTestNet, random.Bytes(100), 123)
fallbackTx.Signers = []transaction.Signer{{Account: random.Uint160()}, {Account: random.Uint160()}} fallbackTx.Signers = []transaction.Signer{{Account: random.Uint160()}, {Account: random.Uint160()}}