network: replace ConsensusType with ExtensibleType

This commit is contained in:
Evgeniy Stratonikov 2021-01-14 16:38:40 +03:00
parent b918ec3abc
commit 5d83c28bc9
9 changed files with 139 additions and 62 deletions

View file

@ -52,11 +52,11 @@ type Service interface {
Shutdown() Shutdown()
// OnPayload is a callback to notify Service about new received payload. // OnPayload is a callback to notify Service about new received payload.
OnPayload(p *Payload) OnPayload(p *npayload.Extensible)
// OnTransaction is a callback to notify Service about new received transaction. // OnTransaction is a callback to notify Service about new received transaction.
OnTransaction(tx *transaction.Transaction) OnTransaction(tx *transaction.Transaction)
// GetPayload returns Payload with specified hash if it is present in the local cache. // GetPayload returns Payload with specified hash if it is present in the local cache.
GetPayload(h util.Uint256) *Payload GetPayload(h util.Uint256) *npayload.Extensible
} }
type service struct { type service struct {
@ -98,7 +98,7 @@ type Config struct {
Logger *zap.Logger Logger *zap.Logger
// Broadcast is a callback which is called to notify server // Broadcast is a callback which is called to notify server
// about new consensus payload to sent. // about new consensus payload to sent.
Broadcast func(p *Payload) Broadcast func(p *npayload.Extensible)
// Chain is a core.Blockchainer instance. // Chain is a core.Blockchainer instance.
Chain blockchainer.Blockchainer Chain blockchainer.Blockchainer
// RequestTx is a callback to which will be called // RequestTx is a callback to which will be called
@ -367,13 +367,26 @@ func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, c
return -1, nil, nil return -1, nil, nil
} }
func (s *service) payloadFromExtensible(ep *npayload.Extensible) *Payload {
return &Payload{
Extensible: *ep,
message: message{
stateRootEnabled: s.stateRootEnabled,
},
}
}
// OnPayload handles Payload receive. // OnPayload handles Payload receive.
func (s *service) OnPayload(cp *Payload) { func (s *service) OnPayload(cp *npayload.Extensible) {
log := s.log.With(zap.Stringer("hash", cp.Hash())) log := s.log.With(zap.Stringer("hash", cp.Hash()))
if s.cache.Has(cp.Hash()) { if s.cache.Has(cp.Hash()) {
log.Debug("payload is already in cache") log.Debug("payload is already in cache")
return return
} else if !s.validatePayload(cp) { }
p := s.payloadFromExtensible(cp)
p.decodeData()
if !s.validatePayload(p) {
log.Info("can't validate payload") log.Info("can't validate payload")
return return
} }
@ -387,14 +400,14 @@ func (s *service) OnPayload(cp *Payload) {
} }
// decode payload data into message // decode payload data into message
if cp.message.payload == nil { if p.message.payload == nil {
if err := cp.decodeData(); err != nil { if err := p.decodeData(); err != nil {
log.Info("can't decode payload data") log.Info("can't decode payload data")
return return
} }
} }
s.messages <- *cp s.messages <- *p
} }
func (s *service) OnTransaction(tx *transaction.Transaction) { func (s *service) OnTransaction(tx *transaction.Transaction) {
@ -404,13 +417,13 @@ func (s *service) OnTransaction(tx *transaction.Transaction) {
} }
// GetPayload returns payload stored in cache. // GetPayload returns payload stored in cache.
func (s *service) GetPayload(h util.Uint256) *Payload { func (s *service) GetPayload(h util.Uint256) *npayload.Extensible {
p := s.cache.Get(h) p := s.cache.Get(h)
if p == nil { if p == nil {
return (*Payload)(nil) return (*npayload.Extensible)(nil)
} }
cp := *p.(*Payload) cp := *p.(*npayload.Extensible)
return &cp return &cp
} }
@ -420,8 +433,9 @@ func (s *service) broadcast(p payload.ConsensusPayload) {
s.log.Warn("can't sign consensus payload", zap.Error(err)) s.log.Warn("can't sign consensus payload", zap.Error(err))
} }
s.cache.Add(p) ep := &p.(*Payload).Extensible
s.Config.Broadcast(p.(*Payload)) s.cache.Add(ep)
s.Config.Broadcast(ep)
} }
func (s *service) getTx(h util.Uint256) block.Transaction { func (s *service) getTx(h util.Uint256) block.Transaction {

View file

@ -21,6 +21,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -345,9 +346,10 @@ func TestService_OnPayload(t *testing.T) {
p := new(Payload) p := new(Payload)
p.SetValidatorIndex(1) p.SetValidatorIndex(1)
p.SetPayload(&prepareRequest{}) p.SetPayload(&prepareRequest{})
p.encodeData()
// sender is invalid // sender is invalid
srv.OnPayload(p) srv.OnPayload(&p.Extensible)
shouldNotReceive(t, srv.messages) shouldNotReceive(t, srv.messages)
require.Nil(t, srv.GetPayload(p.Hash())) require.Nil(t, srv.GetPayload(p.Hash()))
@ -356,12 +358,12 @@ func TestService_OnPayload(t *testing.T) {
p.Sender = priv.GetScriptHash() p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{}) p.SetPayload(&prepareRequest{})
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
srv.OnPayload(p) srv.OnPayload(&p.Extensible)
shouldReceive(t, srv.messages) shouldReceive(t, srv.messages)
require.Equal(t, p, srv.GetPayload(p.Hash())) require.Equal(t, &p.Extensible, srv.GetPayload(p.Hash()))
// payload has already been received // payload has already been received
srv.OnPayload(p) srv.OnPayload(&p.Extensible)
shouldNotReceive(t, srv.messages) shouldNotReceive(t, srv.messages)
srv.Chain.Close() srv.Chain.Close()
} }
@ -477,7 +479,7 @@ func newTestService(t *testing.T) *service {
func newTestServiceWithChain(t *testing.T, bc *core.Blockchain) *service { func newTestServiceWithChain(t *testing.T, bc *core.Blockchain) *service {
srv, err := NewService(Config{ srv, err := NewService(Config{
Logger: zaptest.NewLogger(t), Logger: zaptest.NewLogger(t),
Broadcast: func(*Payload) {}, Broadcast: func(*npayload.Extensible) {},
Chain: bc, Chain: bc,
RequestTx: func(...util.Uint256) {}, RequestTx: func(...util.Uint256) {},
TimePerBlock: time.Duration(bc.GetConfig().SecondsPerBlock) * time.Second, TimePerBlock: time.Duration(bc.GetConfig().SecondsPerBlock) * time.Second,

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
@ -74,7 +73,7 @@ const (
CMDNotFound CommandType = 0x2a CMDNotFound CommandType = 0x2a
CMDTX = CommandType(payload.TXType) CMDTX = CommandType(payload.TXType)
CMDBlock = CommandType(payload.BlockType) CMDBlock = CommandType(payload.BlockType)
CMDConsensus = CommandType(payload.ConsensusType) CMDExtensible = CommandType(payload.ExtensibleType)
CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType) CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType)
CMDReject CommandType = 0x2f CMDReject CommandType = 0x2f
@ -147,8 +146,8 @@ func (m *Message) decodePayload() error {
p = &payload.AddressList{} p = &payload.AddressList{}
case CMDBlock: case CMDBlock:
p = block.New(m.Network, m.StateRootInHeader) p = block.New(m.Network, m.StateRootInHeader)
case CMDConsensus: case CMDExtensible:
p = consensus.NewPayload(m.Network, m.StateRootInHeader) p = payload.NewExtensible(m.Network)
case CMDP2PNotaryRequest: case CMDP2PNotaryRequest:
p = &payload.P2PNotaryRequest{Network: m.Network} p = &payload.P2PNotaryRequest{Network: m.Network}
case CMDGetBlocks: case CMDGetBlocks:

View file

@ -24,7 +24,8 @@ func _() {
_ = x[CMDNotFound-42] _ = x[CMDNotFound-42]
_ = x[CMDTX-43] _ = x[CMDTX-43]
_ = x[CMDBlock-44] _ = x[CMDBlock-44]
_ = x[CMDConsensus-45] _ = x[CMDExtensible-46]
_ = x[CMDP2PNotaryRequest-80]
_ = x[CMDReject-47] _ = x[CMDReject-47]
_ = x[CMDFilterLoad-48] _ = x[CMDFilterLoad-48]
_ = x[CMDFilterAdd-49] _ = x[CMDFilterAdd-49]
@ -39,10 +40,11 @@ const (
_CommandType_name_2 = "CMDPingCMDPong" _CommandType_name_2 = "CMDPingCMDPong"
_CommandType_name_3 = "CMDGetHeadersCMDHeaders" _CommandType_name_3 = "CMDGetHeadersCMDHeaders"
_CommandType_name_4 = "CMDGetBlocksCMDMempool" _CommandType_name_4 = "CMDGetBlocksCMDMempool"
_CommandType_name_5 = "CMDInvCMDGetDataCMDGetBlockByIndexCMDNotFoundCMDTXCMDBlockCMDConsensus" _CommandType_name_5 = "CMDInvCMDGetDataCMDGetBlockByIndexCMDNotFoundCMDTXCMDBlock"
_CommandType_name_6 = "CMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" _CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear"
_CommandType_name_7 = "CMDMerkleBlock" _CommandType_name_7 = "CMDMerkleBlock"
_CommandType_name_8 = "CMDAlert" _CommandType_name_8 = "CMDAlert"
_CommandType_name_9 = "CMDP2PNotaryRequest"
) )
var ( var (
@ -51,8 +53,8 @@ var (
_CommandType_index_2 = [...]uint8{0, 7, 14} _CommandType_index_2 = [...]uint8{0, 7, 14}
_CommandType_index_3 = [...]uint8{0, 13, 23} _CommandType_index_3 = [...]uint8{0, 13, 23}
_CommandType_index_4 = [...]uint8{0, 12, 22} _CommandType_index_4 = [...]uint8{0, 12, 22}
_CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58, 70} _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58}
_CommandType_index_6 = [...]uint8{0, 9, 22, 34, 48} _CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61}
) )
func (i CommandType) String() string { func (i CommandType) String() string {
@ -71,16 +73,18 @@ func (i CommandType) String() string {
case 36 <= i && i <= 37: case 36 <= i && i <= 37:
i -= 36 i -= 36
return _CommandType_name_4[_CommandType_index_4[i]:_CommandType_index_4[i+1]] return _CommandType_name_4[_CommandType_index_4[i]:_CommandType_index_4[i+1]]
case 39 <= i && i <= 45: case 39 <= i && i <= 44:
i -= 39 i -= 39
return _CommandType_name_5[_CommandType_index_5[i]:_CommandType_index_5[i+1]] return _CommandType_name_5[_CommandType_index_5[i]:_CommandType_index_5[i+1]]
case 47 <= i && i <= 50: case 46 <= i && i <= 50:
i -= 47 i -= 46
return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]] return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]]
case i == 56: case i == 56:
return _CommandType_name_7 return _CommandType_name_7
case i == 64: case i == 64:
return _CommandType_name_8 return _CommandType_name_8
case i == 80:
return _CommandType_name_9
default: default:
return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")"
} }

View file

@ -96,7 +96,7 @@ func TestEncodeDecodePing(t *testing.T) {
} }
func TestEncodeDecodeInventory(t *testing.T) { func TestEncodeDecodeInventory(t *testing.T) {
testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}})) testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{{1, 2, 3}}))
} }
func TestEncodeDecodeAddr(t *testing.T) { func TestEncodeDecodeAddr(t *testing.T) {

View file

@ -18,8 +18,8 @@ func (i InventoryType) String() string {
return "TX" return "TX"
case BlockType: case BlockType:
return "block" return "block"
case ConsensusType: case ExtensibleType:
return "consensus" return "extensible"
case P2PNotaryRequestType: case P2PNotaryRequestType:
return "p2pNotaryRequest" return "p2pNotaryRequest"
default: default:
@ -29,14 +29,14 @@ func (i InventoryType) String() string {
// Valid returns true if the inventory (type) is known. // Valid returns true if the inventory (type) is known.
func (i InventoryType) Valid(p2pSigExtensionsEnabled bool) bool { func (i InventoryType) Valid(p2pSigExtensionsEnabled bool) bool {
return i == BlockType || i == TXType || i == ConsensusType || (p2pSigExtensionsEnabled && i == P2PNotaryRequestType) return i == BlockType || i == TXType || i == ExtensibleType || (p2pSigExtensionsEnabled && i == P2PNotaryRequestType)
} }
// List of valid InventoryTypes. // List of valid InventoryTypes.
const ( const (
TXType InventoryType = 0x2b TXType InventoryType = 0x2b
BlockType InventoryType = 0x2c BlockType InventoryType = 0x2c
ConsensusType InventoryType = 0x2d ExtensibleType InventoryType = 0x2e
P2PNotaryRequestType InventoryType = 0x50 P2PNotaryRequestType InventoryType = 0x50
) )

View file

@ -35,8 +35,8 @@ func TestValid(t *testing.T) {
require.True(t, TXType.Valid(true)) require.True(t, TXType.Valid(true))
require.True(t, BlockType.Valid(false)) require.True(t, BlockType.Valid(false))
require.True(t, BlockType.Valid(true)) require.True(t, BlockType.Valid(true))
require.True(t, ConsensusType.Valid(false)) require.True(t, ExtensibleType.Valid(false))
require.True(t, ConsensusType.Valid(true)) require.True(t, ExtensibleType.Valid(true))
require.False(t, P2PNotaryRequestType.Valid(false)) require.False(t, P2PNotaryRequestType.Valid(false))
require.True(t, P2PNotaryRequestType.Valid(true)) require.True(t, P2PNotaryRequestType.Valid(true))
require.False(t, InventoryType(0xFF).Valid(false)) require.False(t, InventoryType(0xFF).Valid(false))
@ -46,7 +46,7 @@ func TestValid(t *testing.T) {
func TestString(t *testing.T) { func TestString(t *testing.T) {
require.Equal(t, "TX", TXType.String()) require.Equal(t, "TX", TXType.String())
require.Equal(t, "block", BlockType.String()) require.Equal(t, "block", BlockType.String())
require.Equal(t, "consensus", ConsensusType.String()) require.Equal(t, "extensible", ExtensibleType.String())
require.Equal(t, "p2pNotaryRequest", P2PNotaryRequestType.String()) require.Equal(t, "p2pNotaryRequest", P2PNotaryRequestType.String())
require.True(t, strings.Contains(InventoryType(0xFF).String(), "unknown")) require.True(t, strings.Contains(InventoryType(0xFF).String(), "unknown"))
} }

View file

@ -536,7 +536,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
var typExists = map[payload.InventoryType]func(util.Uint256) bool{ var typExists = map[payload.InventoryType]func(util.Uint256) bool{
payload.TXType: s.chain.HasTransaction, payload.TXType: s.chain.HasTransaction,
payload.BlockType: s.chain.HasBlock, payload.BlockType: s.chain.HasBlock,
payload.ConsensusType: func(h util.Uint256) bool { payload.ExtensibleType: func(h util.Uint256) bool {
cp := s.consensus.GetPayload(h) cp := s.consensus.GetPayload(h)
return cp != nil return cp != nil
}, },
@ -557,7 +557,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error {
if err != nil { if err != nil {
return err return err
} }
if inv.Type == payload.ConsensusType { if inv.Type == payload.ExtensibleType {
return p.EnqueueHPPacket(true, pkt) return p.EnqueueHPPacket(true, pkt)
} }
return p.EnqueueP2PPacket(pkt) return p.EnqueueP2PPacket(pkt)
@ -605,9 +605,9 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
} else { } else {
notFound = append(notFound, hash) notFound = append(notFound, hash)
} }
case payload.ConsensusType: case payload.ExtensibleType:
if cp := s.consensus.GetPayload(hash); cp != nil { if cp := s.consensus.GetPayload(hash); cp != nil {
msg = NewMessage(CMDConsensus, cp) msg = NewMessage(CMDExtensible, cp)
} }
case payload.P2PNotaryRequestType: case payload.P2PNotaryRequestType:
if nrp, ok := s.notaryRequestPool.TryGetData(hash); ok { // already have checked P2PSigExtEnabled if nrp, ok := s.notaryRequestPool.TryGetData(hash); ok { // already have checked P2PSigExtEnabled
@ -619,7 +619,7 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error {
if msg != nil { if msg != nil {
pkt, err := msg.Bytes() pkt, err := msg.Bytes()
if err == nil { if err == nil {
if inv.Type == payload.ConsensusType { if inv.Type == payload.ExtensibleType {
err = p.EnqueueHPPacket(true, pkt) err = p.EnqueueHPPacket(true, pkt)
} else { } else {
err = p.EnqueueP2PPacket(pkt) err = p.EnqueueP2PPacket(pkt)
@ -715,10 +715,29 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error
return p.EnqueueP2PMessage(msg) return p.EnqueueP2PMessage(msg)
} }
// handleConsensusCmd processes received consensus payload. const extensibleVerifyMaxGAS = 2000000
// It never returns an error.
func (s *Server) handleConsensusCmd(cp *consensus.Payload) error { // handleExtensibleCmd processes received extensible payload.
s.consensus.OnPayload(cp) func (s *Server) handleExtensibleCmd(e *payload.Extensible) error {
if err := s.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil {
return err
}
h := s.chain.BlockHeight()
if h < e.ValidBlockStart || e.ValidBlockEnd <= h {
// We can receive consensus payload for the last or next block
// which leads to unwanted node disconnect.
if e.ValidBlockEnd == h {
return nil
}
return errors.New("invalid height")
}
switch e.Category {
case consensus.Category:
s.consensus.OnPayload(e)
default:
return errors.New("invalid category")
}
return nil return nil
} }
@ -895,9 +914,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
case CMDBlock: case CMDBlock:
block := msg.Payload.(*block.Block) block := msg.Payload.(*block.Block)
return s.handleBlockCmd(peer, block) return s.handleBlockCmd(peer, block)
case CMDConsensus: case CMDExtensible:
cp := msg.Payload.(*consensus.Payload) cp := msg.Payload.(*payload.Extensible)
return s.handleConsensusCmd(cp) return s.handleExtensibleCmd(cp)
case CMDTX: case CMDTX:
tx := msg.Payload.(*transaction.Transaction) tx := msg.Payload.(*transaction.Transaction)
return s.handleTxCmd(tx) return s.handleTxCmd(tx)
@ -933,8 +952,8 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error {
return nil return nil
} }
func (s *Server) handleNewPayload(p *consensus.Payload) { func (s *Server) handleNewPayload(p *payload.Extensible) {
msg := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{p.Hash()}))
// It's high priority because it directly affects consensus process, // It's high priority because it directly affects consensus process,
// even though it's just an inv. // even though it's just an inv.
s.broadcastHPMessage(msg) s.broadcastHPMessage(msg)

View file

@ -31,7 +31,7 @@ import (
type fakeConsensus struct { type fakeConsensus struct {
started atomic.Bool started atomic.Bool
stopped atomic.Bool stopped atomic.Bool
payloads []*consensus.Payload payloads []*payload.Extensible
txs []*transaction.Transaction txs []*transaction.Transaction
} }
@ -42,9 +42,9 @@ func newFakeConsensus(c consensus.Config) (consensus.Service, error) {
} }
func (f *fakeConsensus) Start() { f.started.Store(true) } func (f *fakeConsensus) Start() { f.started.Store(true) }
func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) } func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) }
func (f *fakeConsensus) OnPayload(p *consensus.Payload) { f.payloads = append(f.payloads, p) } func (f *fakeConsensus) OnPayload(p *payload.Extensible) { f.payloads = append(f.payloads, p) }
func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) } func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) }
func (f *fakeConsensus) GetPayload(h util.Uint256) *consensus.Payload { panic("implement me") } func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") }
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {
bc := &testChain{} bc := &testChain{}
@ -405,9 +405,48 @@ func TestConsensus(t *testing.T) {
s, shutdown := startTestServer(t) s, shutdown := startTestServer(t)
defer shutdown() defer shutdown()
pl := consensus.NewPayload(netmode.UnitTestNet, false) atomic2.StoreUint32(&s.chain.(*testChain).blockheight, 4)
s.testHandleMessage(t, nil, CMDConsensus, pl) p := newLocalPeer(t, s)
require.Contains(t, s.consensus.(*fakeConsensus).payloads, pl) p.handshaked = true
newConsensusMessage := func(start, end uint32) *Message {
pl := payload.NewExtensible(netmode.UnitTestNet)
pl.Category = consensus.Category
pl.ValidBlockStart = start
pl.ValidBlockEnd = end
return NewMessage(CMDExtensible, pl)
}
s.chain.(*testChain).verifyWitnessF = func() error { return errors.New("invalid") }
msg := newConsensusMessage(0, s.chain.BlockHeight()+1)
require.Error(t, s.handleMessage(p, msg))
s.chain.(*testChain).verifyWitnessF = func() error { return nil }
require.NoError(t, s.handleMessage(p, msg))
require.Contains(t, s.consensus.(*fakeConsensus).payloads, msg.Payload.(*payload.Extensible))
t.Run("small ValidUntilBlockEnd", func(t *testing.T) {
t.Run("current height", func(t *testing.T) {
msg := newConsensusMessage(0, s.chain.BlockHeight())
require.NoError(t, s.handleMessage(p, msg))
require.NotContains(t, s.consensus.(*fakeConsensus).payloads, msg.Payload.(*payload.Extensible))
})
t.Run("invalid", func(t *testing.T) {
msg := newConsensusMessage(0, s.chain.BlockHeight()-1)
require.Error(t, s.handleMessage(p, msg))
})
})
t.Run("big ValidUntiLBlockStart", func(t *testing.T) {
msg := newConsensusMessage(s.chain.BlockHeight()+1, s.chain.BlockHeight()+2)
require.Error(t, s.handleMessage(p, msg))
})
t.Run("invalid category", func(t *testing.T) {
pl := payload.NewExtensible(netmode.UnitTestNet)
pl.Category = "invalid"
pl.ValidBlockEnd = s.chain.BlockHeight() + 1
msg := NewMessage(CMDExtensible, pl)
require.Error(t, s.handleMessage(p, msg))
})
} }
func TestTransaction(t *testing.T) { func TestTransaction(t *testing.T) {
@ -448,7 +487,7 @@ func (s *Server) testHandleGetData(t *testing.T, invType payload.InventoryType,
p.handshaked = true p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
switch msg.Command { switch msg.Command {
case CMDTX, CMDBlock, CMDConsensus, CMDP2PNotaryRequest: case CMDTX, CMDBlock, CMDExtensible, CMDP2PNotaryRequest:
require.Equal(t, found, msg.Payload) require.Equal(t, found, msg.Payload)
recvResponse.Store(true) recvResponse.Store(true)
case CMDNotFound: case CMDNotFound: