diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 0dcc819a3..1f6649194 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -52,11 +52,11 @@ type Service interface { Shutdown() // OnPayload is a callback to notify Service about new received payload. - OnPayload(p *Payload) + OnPayload(p *npayload.Extensible) // OnTransaction is a callback to notify Service about new received transaction. OnTransaction(tx *transaction.Transaction) // GetPayload returns Payload with specified hash if it is present in the local cache. - GetPayload(h util.Uint256) *Payload + GetPayload(h util.Uint256) *npayload.Extensible } type service struct { @@ -98,7 +98,7 @@ type Config struct { Logger *zap.Logger // Broadcast is a callback which is called to notify server // about new consensus payload to sent. - Broadcast func(p *Payload) + Broadcast func(p *npayload.Extensible) // Chain is a core.Blockchainer instance. Chain blockchainer.Blockchainer // RequestTx is a callback to which will be called @@ -367,13 +367,26 @@ func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, c return -1, nil, nil } +func (s *service) payloadFromExtensible(ep *npayload.Extensible) *Payload { + return &Payload{ + Extensible: *ep, + message: message{ + stateRootEnabled: s.stateRootEnabled, + }, + } +} + // OnPayload handles Payload receive. -func (s *service) OnPayload(cp *Payload) { +func (s *service) OnPayload(cp *npayload.Extensible) { log := s.log.With(zap.Stringer("hash", cp.Hash())) if s.cache.Has(cp.Hash()) { log.Debug("payload is already in cache") return - } else if !s.validatePayload(cp) { + } + + p := s.payloadFromExtensible(cp) + p.decodeData() + if !s.validatePayload(p) { log.Info("can't validate payload") return } @@ -387,14 +400,14 @@ func (s *service) OnPayload(cp *Payload) { } // decode payload data into message - if cp.message.payload == nil { - if err := cp.decodeData(); err != nil { + if p.message.payload == nil { + if err := p.decodeData(); err != nil { log.Info("can't decode payload data") return } } - s.messages <- *cp + s.messages <- *p } func (s *service) OnTransaction(tx *transaction.Transaction) { @@ -404,13 +417,13 @@ func (s *service) OnTransaction(tx *transaction.Transaction) { } // GetPayload returns payload stored in cache. -func (s *service) GetPayload(h util.Uint256) *Payload { +func (s *service) GetPayload(h util.Uint256) *npayload.Extensible { p := s.cache.Get(h) if p == nil { - return (*Payload)(nil) + return (*npayload.Extensible)(nil) } - cp := *p.(*Payload) + cp := *p.(*npayload.Extensible) return &cp } @@ -420,8 +433,9 @@ func (s *service) broadcast(p payload.ConsensusPayload) { s.log.Warn("can't sign consensus payload", zap.Error(err)) } - s.cache.Add(p) - s.Config.Broadcast(p.(*Payload)) + ep := &p.(*Payload).Extensible + s.cache.Add(ep) + s.Config.Broadcast(ep) } func (s *service) getTx(h util.Uint256) block.Transaction { diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index 7e88d80fa..3f8737dc0 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -21,6 +21,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/io" + npayload "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/callflag" "github.com/nspcc-dev/neo-go/pkg/util" @@ -345,9 +346,10 @@ func TestService_OnPayload(t *testing.T) { p := new(Payload) p.SetValidatorIndex(1) p.SetPayload(&prepareRequest{}) + p.encodeData() // sender is invalid - srv.OnPayload(p) + srv.OnPayload(&p.Extensible) shouldNotReceive(t, srv.messages) require.Nil(t, srv.GetPayload(p.Hash())) @@ -356,12 +358,12 @@ func TestService_OnPayload(t *testing.T) { p.Sender = priv.GetScriptHash() p.SetPayload(&prepareRequest{}) require.NoError(t, p.Sign(priv)) - srv.OnPayload(p) + srv.OnPayload(&p.Extensible) shouldReceive(t, srv.messages) - require.Equal(t, p, srv.GetPayload(p.Hash())) + require.Equal(t, &p.Extensible, srv.GetPayload(p.Hash())) // payload has already been received - srv.OnPayload(p) + srv.OnPayload(&p.Extensible) shouldNotReceive(t, srv.messages) srv.Chain.Close() } @@ -477,7 +479,7 @@ func newTestService(t *testing.T) *service { func newTestServiceWithChain(t *testing.T, bc *core.Blockchain) *service { srv, err := NewService(Config{ Logger: zaptest.NewLogger(t), - Broadcast: func(*Payload) {}, + Broadcast: func(*npayload.Extensible) {}, Chain: bc, RequestTx: func(...util.Uint256) {}, TimePerBlock: time.Duration(bc.GetConfig().SecondsPerBlock) * time.Second, diff --git a/pkg/network/message.go b/pkg/network/message.go index c8bf606ae..9c64cc802 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/nspcc-dev/neo-go/pkg/config/netmode" - "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/io" @@ -74,7 +73,7 @@ const ( CMDNotFound CommandType = 0x2a CMDTX = CommandType(payload.TXType) CMDBlock = CommandType(payload.BlockType) - CMDConsensus = CommandType(payload.ConsensusType) + CMDExtensible = CommandType(payload.ExtensibleType) CMDP2PNotaryRequest = CommandType(payload.P2PNotaryRequestType) CMDReject CommandType = 0x2f @@ -147,8 +146,8 @@ func (m *Message) decodePayload() error { p = &payload.AddressList{} case CMDBlock: p = block.New(m.Network, m.StateRootInHeader) - case CMDConsensus: - p = consensus.NewPayload(m.Network, m.StateRootInHeader) + case CMDExtensible: + p = payload.NewExtensible(m.Network) case CMDP2PNotaryRequest: p = &payload.P2PNotaryRequest{Network: m.Network} case CMDGetBlocks: diff --git a/pkg/network/message_string.go b/pkg/network/message_string.go index 233c9084b..7da007079 100644 --- a/pkg/network/message_string.go +++ b/pkg/network/message_string.go @@ -24,7 +24,8 @@ func _() { _ = x[CMDNotFound-42] _ = x[CMDTX-43] _ = x[CMDBlock-44] - _ = x[CMDConsensus-45] + _ = x[CMDExtensible-46] + _ = x[CMDP2PNotaryRequest-80] _ = x[CMDReject-47] _ = x[CMDFilterLoad-48] _ = x[CMDFilterAdd-49] @@ -39,10 +40,11 @@ const ( _CommandType_name_2 = "CMDPingCMDPong" _CommandType_name_3 = "CMDGetHeadersCMDHeaders" _CommandType_name_4 = "CMDGetBlocksCMDMempool" - _CommandType_name_5 = "CMDInvCMDGetDataCMDGetBlockByIndexCMDNotFoundCMDTXCMDBlockCMDConsensus" - _CommandType_name_6 = "CMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" + _CommandType_name_5 = "CMDInvCMDGetDataCMDGetBlockByIndexCMDNotFoundCMDTXCMDBlock" + _CommandType_name_6 = "CMDExtensibleCMDRejectCMDFilterLoadCMDFilterAddCMDFilterClear" _CommandType_name_7 = "CMDMerkleBlock" _CommandType_name_8 = "CMDAlert" + _CommandType_name_9 = "CMDP2PNotaryRequest" ) var ( @@ -51,8 +53,8 @@ var ( _CommandType_index_2 = [...]uint8{0, 7, 14} _CommandType_index_3 = [...]uint8{0, 13, 23} _CommandType_index_4 = [...]uint8{0, 12, 22} - _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58, 70} - _CommandType_index_6 = [...]uint8{0, 9, 22, 34, 48} + _CommandType_index_5 = [...]uint8{0, 6, 16, 34, 45, 50, 58} + _CommandType_index_6 = [...]uint8{0, 13, 22, 35, 47, 61} ) func (i CommandType) String() string { @@ -71,16 +73,18 @@ func (i CommandType) String() string { case 36 <= i && i <= 37: i -= 36 return _CommandType_name_4[_CommandType_index_4[i]:_CommandType_index_4[i+1]] - case 39 <= i && i <= 45: + case 39 <= i && i <= 44: i -= 39 return _CommandType_name_5[_CommandType_index_5[i]:_CommandType_index_5[i+1]] - case 47 <= i && i <= 50: - i -= 47 + case 46 <= i && i <= 50: + i -= 46 return _CommandType_name_6[_CommandType_index_6[i]:_CommandType_index_6[i+1]] case i == 56: return _CommandType_name_7 case i == 64: return _CommandType_name_8 + case i == 80: + return _CommandType_name_9 default: return "CommandType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index df2010b88..697c7c7da 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -96,7 +96,7 @@ func TestEncodeDecodePing(t *testing.T) { } func TestEncodeDecodeInventory(t *testing.T) { - testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}})) + testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{{1, 2, 3}})) } func TestEncodeDecodeAddr(t *testing.T) { diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index 8727e2929..563ccdde1 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -18,8 +18,8 @@ func (i InventoryType) String() string { return "TX" case BlockType: return "block" - case ConsensusType: - return "consensus" + case ExtensibleType: + return "extensible" case P2PNotaryRequestType: return "p2pNotaryRequest" default: @@ -29,14 +29,14 @@ func (i InventoryType) String() string { // Valid returns true if the inventory (type) is known. func (i InventoryType) Valid(p2pSigExtensionsEnabled bool) bool { - return i == BlockType || i == TXType || i == ConsensusType || (p2pSigExtensionsEnabled && i == P2PNotaryRequestType) + return i == BlockType || i == TXType || i == ExtensibleType || (p2pSigExtensionsEnabled && i == P2PNotaryRequestType) } // List of valid InventoryTypes. const ( TXType InventoryType = 0x2b BlockType InventoryType = 0x2c - ConsensusType InventoryType = 0x2d + ExtensibleType InventoryType = 0x2e P2PNotaryRequestType InventoryType = 0x50 ) diff --git a/pkg/network/payload/inventory_test.go b/pkg/network/payload/inventory_test.go index 7d3920f9e..3b684367b 100644 --- a/pkg/network/payload/inventory_test.go +++ b/pkg/network/payload/inventory_test.go @@ -35,8 +35,8 @@ func TestValid(t *testing.T) { require.True(t, TXType.Valid(true)) require.True(t, BlockType.Valid(false)) require.True(t, BlockType.Valid(true)) - require.True(t, ConsensusType.Valid(false)) - require.True(t, ConsensusType.Valid(true)) + require.True(t, ExtensibleType.Valid(false)) + require.True(t, ExtensibleType.Valid(true)) require.False(t, P2PNotaryRequestType.Valid(false)) require.True(t, P2PNotaryRequestType.Valid(true)) require.False(t, InventoryType(0xFF).Valid(false)) @@ -46,7 +46,7 @@ func TestValid(t *testing.T) { func TestString(t *testing.T) { require.Equal(t, "TX", TXType.String()) require.Equal(t, "block", BlockType.String()) - require.Equal(t, "consensus", ConsensusType.String()) + require.Equal(t, "extensible", ExtensibleType.String()) require.Equal(t, "p2pNotaryRequest", P2PNotaryRequestType.String()) require.True(t, strings.Contains(InventoryType(0xFF).String(), "unknown")) } diff --git a/pkg/network/server.go b/pkg/network/server.go index b8befac11..8b5c1ea89 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -536,7 +536,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { var typExists = map[payload.InventoryType]func(util.Uint256) bool{ payload.TXType: s.chain.HasTransaction, payload.BlockType: s.chain.HasBlock, - payload.ConsensusType: func(h util.Uint256) bool { + payload.ExtensibleType: func(h util.Uint256) bool { cp := s.consensus.GetPayload(h) return cp != nil }, @@ -557,7 +557,7 @@ func (s *Server) handleInvCmd(p Peer, inv *payload.Inventory) error { if err != nil { return err } - if inv.Type == payload.ConsensusType { + if inv.Type == payload.ExtensibleType { return p.EnqueueHPPacket(true, pkt) } return p.EnqueueP2PPacket(pkt) @@ -605,9 +605,9 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { } else { notFound = append(notFound, hash) } - case payload.ConsensusType: + case payload.ExtensibleType: if cp := s.consensus.GetPayload(hash); cp != nil { - msg = NewMessage(CMDConsensus, cp) + msg = NewMessage(CMDExtensible, cp) } case payload.P2PNotaryRequestType: if nrp, ok := s.notaryRequestPool.TryGetData(hash); ok { // already have checked P2PSigExtEnabled @@ -619,7 +619,7 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { if msg != nil { pkt, err := msg.Bytes() if err == nil { - if inv.Type == payload.ConsensusType { + if inv.Type == payload.ExtensibleType { err = p.EnqueueHPPacket(true, pkt) } else { err = p.EnqueueP2PPacket(pkt) @@ -715,10 +715,29 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlockByIndex) error return p.EnqueueP2PMessage(msg) } -// handleConsensusCmd processes received consensus payload. -// It never returns an error. -func (s *Server) handleConsensusCmd(cp *consensus.Payload) error { - s.consensus.OnPayload(cp) +const extensibleVerifyMaxGAS = 2000000 + +// handleExtensibleCmd processes received extensible payload. +func (s *Server) handleExtensibleCmd(e *payload.Extensible) error { + if err := s.chain.VerifyWitness(e.Sender, e, &e.Witness, extensibleVerifyMaxGAS); err != nil { + return err + } + h := s.chain.BlockHeight() + if h < e.ValidBlockStart || e.ValidBlockEnd <= h { + // We can receive consensus payload for the last or next block + // which leads to unwanted node disconnect. + if e.ValidBlockEnd == h { + return nil + } + return errors.New("invalid height") + } + + switch e.Category { + case consensus.Category: + s.consensus.OnPayload(e) + default: + return errors.New("invalid category") + } return nil } @@ -895,9 +914,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDBlock: block := msg.Payload.(*block.Block) return s.handleBlockCmd(peer, block) - case CMDConsensus: - cp := msg.Payload.(*consensus.Payload) - return s.handleConsensusCmd(cp) + case CMDExtensible: + cp := msg.Payload.(*payload.Extensible) + return s.handleExtensibleCmd(cp) case CMDTX: tx := msg.Payload.(*transaction.Transaction) return s.handleTxCmd(tx) @@ -933,8 +952,8 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return nil } -func (s *Server) handleNewPayload(p *consensus.Payload) { - msg := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{p.Hash()})) +func (s *Server) handleNewPayload(p *payload.Extensible) { + msg := NewMessage(CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{p.Hash()})) // It's high priority because it directly affects consensus process, // even though it's just an inv. s.broadcastHPMessage(msg) diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 9e457f1a2..365c4fe71 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -31,7 +31,7 @@ import ( type fakeConsensus struct { started atomic.Bool stopped atomic.Bool - payloads []*consensus.Payload + payloads []*payload.Extensible txs []*transaction.Transaction } @@ -40,11 +40,11 @@ var _ consensus.Service = (*fakeConsensus)(nil) func newFakeConsensus(c consensus.Config) (consensus.Service, error) { return new(fakeConsensus), nil } -func (f *fakeConsensus) Start() { f.started.Store(true) } -func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) } -func (f *fakeConsensus) OnPayload(p *consensus.Payload) { f.payloads = append(f.payloads, p) } -func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) } -func (f *fakeConsensus) GetPayload(h util.Uint256) *consensus.Payload { panic("implement me") } +func (f *fakeConsensus) Start() { f.started.Store(true) } +func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) } +func (f *fakeConsensus) OnPayload(p *payload.Extensible) { f.payloads = append(f.payloads, p) } +func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) } +func (f *fakeConsensus) GetPayload(h util.Uint256) *payload.Extensible { panic("implement me") } func TestNewServer(t *testing.T) { bc := &testChain{} @@ -405,9 +405,48 @@ func TestConsensus(t *testing.T) { s, shutdown := startTestServer(t) defer shutdown() - pl := consensus.NewPayload(netmode.UnitTestNet, false) - s.testHandleMessage(t, nil, CMDConsensus, pl) - require.Contains(t, s.consensus.(*fakeConsensus).payloads, pl) + atomic2.StoreUint32(&s.chain.(*testChain).blockheight, 4) + p := newLocalPeer(t, s) + p.handshaked = true + + newConsensusMessage := func(start, end uint32) *Message { + pl := payload.NewExtensible(netmode.UnitTestNet) + pl.Category = consensus.Category + pl.ValidBlockStart = start + pl.ValidBlockEnd = end + return NewMessage(CMDExtensible, pl) + } + + s.chain.(*testChain).verifyWitnessF = func() error { return errors.New("invalid") } + msg := newConsensusMessage(0, s.chain.BlockHeight()+1) + require.Error(t, s.handleMessage(p, msg)) + + s.chain.(*testChain).verifyWitnessF = func() error { return nil } + require.NoError(t, s.handleMessage(p, msg)) + require.Contains(t, s.consensus.(*fakeConsensus).payloads, msg.Payload.(*payload.Extensible)) + + t.Run("small ValidUntilBlockEnd", func(t *testing.T) { + t.Run("current height", func(t *testing.T) { + msg := newConsensusMessage(0, s.chain.BlockHeight()) + require.NoError(t, s.handleMessage(p, msg)) + require.NotContains(t, s.consensus.(*fakeConsensus).payloads, msg.Payload.(*payload.Extensible)) + }) + t.Run("invalid", func(t *testing.T) { + msg := newConsensusMessage(0, s.chain.BlockHeight()-1) + require.Error(t, s.handleMessage(p, msg)) + }) + }) + t.Run("big ValidUntiLBlockStart", func(t *testing.T) { + msg := newConsensusMessage(s.chain.BlockHeight()+1, s.chain.BlockHeight()+2) + require.Error(t, s.handleMessage(p, msg)) + }) + t.Run("invalid category", func(t *testing.T) { + pl := payload.NewExtensible(netmode.UnitTestNet) + pl.Category = "invalid" + pl.ValidBlockEnd = s.chain.BlockHeight() + 1 + msg := NewMessage(CMDExtensible, pl) + require.Error(t, s.handleMessage(p, msg)) + }) } func TestTransaction(t *testing.T) { @@ -448,7 +487,7 @@ func (s *Server) testHandleGetData(t *testing.T, invType payload.InventoryType, p.handshaked = true p.messageHandler = func(t *testing.T, msg *Message) { switch msg.Command { - case CMDTX, CMDBlock, CMDConsensus, CMDP2PNotaryRequest: + case CMDTX, CMDBlock, CMDExtensible, CMDP2PNotaryRequest: require.Equal(t, found, msg.Payload) recvResponse.Store(true) case CMDNotFound: