diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index 49b876c06..426972c6b 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -22,6 +22,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" + uatomic "go.uber.org/atomic" ) // FakeChain implements Blockchainer interface, but does not provide real functionality. @@ -45,9 +46,11 @@ type FakeChain struct { // FakeStateSync implements StateSync interface. type FakeStateSync struct { - IsActiveFlag bool - IsInitializedFlag bool + IsActiveFlag uatomic.Bool + IsInitializedFlag uatomic.Bool InitFunc func(h uint32) error + TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error + AddMPTNodesFunc func(nodes [][]byte) error } // NewFakeChain returns new FakeChain structure. @@ -461,7 +464,10 @@ func (s *FakeStateSync) AddHeaders(...*block.Header) error { } // AddMPTNodes implements StateSync interface. -func (s *FakeStateSync) AddMPTNodes([][]byte) error { +func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error { + if s.AddMPTNodesFunc != nil { + return s.AddMPTNodesFunc(nodes) + } panic("TODO") } @@ -471,11 +477,11 @@ func (s *FakeStateSync) BlockHeight() uint32 { } // IsActive implements StateSync interface. -func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag } +func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() } // IsInitialized implements StateSync interface. func (s *FakeStateSync) IsInitialized() bool { - return s.IsInitializedFlag + return s.IsInitializedFlag.Load() } // Init implements StateSync interface. @@ -496,6 +502,9 @@ func (s *FakeStateSync) NeedMPTNodes() bool { // Traverse implements StateSync interface. func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + if s.TraverseFunc != nil { + return s.TraverseFunc(root, process) + } panic("TODO") } diff --git a/pkg/network/server.go b/pkg/network/server.go index bf57ade5f..16a6a46f5 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -827,18 +827,23 @@ func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error { } resp := payload.MPTData{} capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes))) + added := make(map[util.Uint256]struct{}) for _, h := range inv.Hashes { if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type break } err := s.stateSync.Traverse(h, - func(_ mpt.Node, node []byte) bool { + func(n mpt.Node, node []byte) bool { + if _, ok := added[n.Hash()]; ok { + return false + } l := len(node) size := l + io.GetVarSize(l) if size > capLeft { return true } resp.Nodes = append(resp.Nodes, node) + added[n.Hash()] = struct{}{} capLeft -= size return false }) diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 4b1de5b64..69dc0d86b 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -17,6 +17,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/interop" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -737,6 +738,139 @@ func TestInv(t *testing.T) { }) } +func TestHandleGetMPTData(t *testing.T) { + t.Run("P2PStateExchange extensions off", func(t *testing.T) { + s := startTestServer(t) + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("KeepOnlyLatestState on", func(t *testing.T) { + s := startTestServer(t) + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ + Hashes: []util.Uint256{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + var recvResponse atomic.Bool + r1 := random.Uint256() + r2 := random.Uint256() + r3 := random.Uint256() + node := []byte{1, 2, 3} + s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { + if !(root.Equals(r1) || root.Equals(r2)) { + t.Fatal("unexpected root") + } + require.False(t, process(mpt.NewHashNode(r3), node)) + return nil + } + found := &payload.MPTData{ + Nodes: [][]byte{node}, // no duplicates expected + } + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + switch msg.Command { + case CMDMPTData: + require.Equal(t, found, msg.Payload) + recvResponse.Store(true) + } + } + hs := []util.Uint256{r1, r2} + s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs)) + + require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond) + }) +} + +func TestHandleMPTData(t *testing.T) { + t.Run("P2PStateExchange extensions off", func(t *testing.T) { + s := startTestServer(t) + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDMPTData, &payload.MPTData{ + Nodes: [][]byte{{1, 2, 3}}, + }) + require.Error(t, s.handleMessage(p, msg)) + }) + + t.Run("good", func(t *testing.T) { + s := startTestServer(t) + expected := [][]byte{{1, 2, 3}, {2, 3, 4}} + s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s.stateSync = &fakechain.FakeStateSync{ + AddMPTNodesFunc: func(nodes [][]byte) error { + require.Equal(t, expected, nodes) + return nil + }, + } + + p := newLocalPeer(t, s) + p.handshaked = true + msg := NewMessage(CMDMPTData, &payload.MPTData{ + Nodes: expected, + }) + require.NoError(t, s.handleMessage(p, msg)) + }) +} + +func TestRequestMPTNodes(t *testing.T) { + s := startTestServer(t) + + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDGetMPTData { + actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...) + } + } + s.register <- p + s.register <- p // ensure previous send was handled + + t.Run("no hashes, no message", func(t *testing.T) { + actual = nil + require.NoError(t, s.requestMPTNodes(p, nil)) + require.Nil(t, actual) + }) + t.Run("good, small", func(t *testing.T) { + actual = nil + expected := []util.Uint256{random.Uint256(), random.Uint256()} + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected, actual) + }) + t.Run("good, exactly one chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxMPTHashesCount) + for i := range expected { + expected[i] = random.Uint256() + } + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected, actual) + }) + t.Run("good, too large chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxMPTHashesCount+1) + for i := range expected { + expected[i] = random.Uint256() + } + require.NoError(t, s.requestMPTNodes(p, expected)) + require.Equal(t, expected[:payload.MaxMPTHashesCount], actual) + }) +} + func TestRequestTx(t *testing.T) { s := startTestServer(t) @@ -912,7 +1046,10 @@ func TestTryInitStateSync(t *testing.T) { t.Run("module already initialized", func(t *testing.T) { s := startTestServer(t) - s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true} + ss := &fakechain.FakeStateSync{} + ss.IsActiveFlag.Store(true) + ss.IsInitializedFlag.Store(true) + s.stateSync = ss s.tryInitStateSync() }) @@ -930,12 +1067,14 @@ func TestTryInitStateSync(t *testing.T) { s.peers[p] = true var expectedH uint32 = 8 // median peer - s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error { + ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error { if h != expectedH { return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h) } return nil }} + ss.IsActiveFlag.Store(true) + s.stateSync = ss s.tryInitStateSync() }) }