mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-26 09:42:22 +00:00
network: do not duplicate MPT nodes in GetMPTNodes response
Also tests are added.
This commit is contained in:
parent
51c8c0d82b
commit
0fa48691f7
3 changed files with 161 additions and 8 deletions
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||||
|
uatomic "go.uber.org/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FakeChain implements Blockchainer interface, but does not provide real functionality.
|
// FakeChain implements Blockchainer interface, but does not provide real functionality.
|
||||||
|
@ -45,9 +46,11 @@ type FakeChain struct {
|
||||||
|
|
||||||
// FakeStateSync implements StateSync interface.
|
// FakeStateSync implements StateSync interface.
|
||||||
type FakeStateSync struct {
|
type FakeStateSync struct {
|
||||||
IsActiveFlag bool
|
IsActiveFlag uatomic.Bool
|
||||||
IsInitializedFlag bool
|
IsInitializedFlag uatomic.Bool
|
||||||
InitFunc func(h uint32) error
|
InitFunc func(h uint32) error
|
||||||
|
TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||||
|
AddMPTNodesFunc func(nodes [][]byte) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFakeChain returns new FakeChain structure.
|
// NewFakeChain returns new FakeChain structure.
|
||||||
|
@ -461,7 +464,10 @@ func (s *FakeStateSync) AddHeaders(...*block.Header) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMPTNodes implements StateSync interface.
|
// AddMPTNodes implements StateSync interface.
|
||||||
func (s *FakeStateSync) AddMPTNodes([][]byte) error {
|
func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
|
||||||
|
if s.AddMPTNodesFunc != nil {
|
||||||
|
return s.AddMPTNodesFunc(nodes)
|
||||||
|
}
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -471,11 +477,11 @@ func (s *FakeStateSync) BlockHeight() uint32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsActive implements StateSync interface.
|
// IsActive implements StateSync interface.
|
||||||
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag }
|
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
|
||||||
|
|
||||||
// IsInitialized implements StateSync interface.
|
// IsInitialized implements StateSync interface.
|
||||||
func (s *FakeStateSync) IsInitialized() bool {
|
func (s *FakeStateSync) IsInitialized() bool {
|
||||||
return s.IsInitializedFlag
|
return s.IsInitializedFlag.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init implements StateSync interface.
|
// Init implements StateSync interface.
|
||||||
|
@ -496,6 +502,9 @@ func (s *FakeStateSync) NeedMPTNodes() bool {
|
||||||
|
|
||||||
// Traverse implements StateSync interface.
|
// Traverse implements StateSync interface.
|
||||||
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
if s.TraverseFunc != nil {
|
||||||
|
return s.TraverseFunc(root, process)
|
||||||
|
}
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -827,18 +827,23 @@ func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
|
||||||
}
|
}
|
||||||
resp := payload.MPTData{}
|
resp := payload.MPTData{}
|
||||||
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
||||||
|
added := make(map[util.Uint256]struct{})
|
||||||
for _, h := range inv.Hashes {
|
for _, h := range inv.Hashes {
|
||||||
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
err := s.stateSync.Traverse(h,
|
err := s.stateSync.Traverse(h,
|
||||||
func(_ mpt.Node, node []byte) bool {
|
func(n mpt.Node, node []byte) bool {
|
||||||
|
if _, ok := added[n.Hash()]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
l := len(node)
|
l := len(node)
|
||||||
size := l + io.GetVarSize(l)
|
size := l + io.GetVarSize(l)
|
||||||
if size > capLeft {
|
if size > capLeft {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
resp.Nodes = append(resp.Nodes, node)
|
resp.Nodes = append(resp.Nodes, node)
|
||||||
|
added[n.Hash()] = struct{}{}
|
||||||
capLeft -= size
|
capLeft -= size
|
||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core"
|
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||||
|
@ -737,6 +738,139 @@ func TestInv(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleGetMPTData(t *testing.T) {
|
||||||
|
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||||
|
Hashes: []util.Uint256{{1, 2, 3}},
|
||||||
|
})
|
||||||
|
require.Error(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("KeepOnlyLatestState on", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||||
|
s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||||
|
Hashes: []util.Uint256{{1, 2, 3}},
|
||||||
|
})
|
||||||
|
require.Error(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||||
|
var recvResponse atomic.Bool
|
||||||
|
r1 := random.Uint256()
|
||||||
|
r2 := random.Uint256()
|
||||||
|
r3 := random.Uint256()
|
||||||
|
node := []byte{1, 2, 3}
|
||||||
|
s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||||
|
if !(root.Equals(r1) || root.Equals(r2)) {
|
||||||
|
t.Fatal("unexpected root")
|
||||||
|
}
|
||||||
|
require.False(t, process(mpt.NewHashNode(r3), node))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
found := &payload.MPTData{
|
||||||
|
Nodes: [][]byte{node}, // no duplicates expected
|
||||||
|
}
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||||
|
switch msg.Command {
|
||||||
|
case CMDMPTData:
|
||||||
|
require.Equal(t, found, msg.Payload)
|
||||||
|
recvResponse.Store(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hs := []util.Uint256{r1, r2}
|
||||||
|
s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs))
|
||||||
|
|
||||||
|
require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleMPTData(t *testing.T) {
|
||||||
|
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||||
|
Nodes: [][]byte{{1, 2, 3}},
|
||||||
|
})
|
||||||
|
require.Error(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("good", func(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
expected := [][]byte{{1, 2, 3}, {2, 3, 4}}
|
||||||
|
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||||
|
s.stateSync = &fakechain.FakeStateSync{
|
||||||
|
AddMPTNodesFunc: func(nodes [][]byte) error {
|
||||||
|
require.Equal(t, expected, nodes)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||||
|
Nodes: expected,
|
||||||
|
})
|
||||||
|
require.NoError(t, s.handleMessage(p, msg))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestMPTNodes(t *testing.T) {
|
||||||
|
s := startTestServer(t)
|
||||||
|
|
||||||
|
var actual []util.Uint256
|
||||||
|
p := newLocalPeer(t, s)
|
||||||
|
p.handshaked = true
|
||||||
|
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||||
|
if msg.Command == CMDGetMPTData {
|
||||||
|
actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.register <- p
|
||||||
|
s.register <- p // ensure previous send was handled
|
||||||
|
|
||||||
|
t.Run("no hashes, no message", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, nil))
|
||||||
|
require.Nil(t, actual)
|
||||||
|
})
|
||||||
|
t.Run("good, small", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
expected := []util.Uint256{random.Uint256(), random.Uint256()}
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||||
|
require.Equal(t, expected, actual)
|
||||||
|
})
|
||||||
|
t.Run("good, exactly one chunk", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
expected := make([]util.Uint256, payload.MaxMPTHashesCount)
|
||||||
|
for i := range expected {
|
||||||
|
expected[i] = random.Uint256()
|
||||||
|
}
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||||
|
require.Equal(t, expected, actual)
|
||||||
|
})
|
||||||
|
t.Run("good, too large chunk", func(t *testing.T) {
|
||||||
|
actual = nil
|
||||||
|
expected := make([]util.Uint256, payload.MaxMPTHashesCount+1)
|
||||||
|
for i := range expected {
|
||||||
|
expected[i] = random.Uint256()
|
||||||
|
}
|
||||||
|
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||||
|
require.Equal(t, expected[:payload.MaxMPTHashesCount], actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestRequestTx(t *testing.T) {
|
func TestRequestTx(t *testing.T) {
|
||||||
s := startTestServer(t)
|
s := startTestServer(t)
|
||||||
|
|
||||||
|
@ -912,7 +1046,10 @@ func TestTryInitStateSync(t *testing.T) {
|
||||||
|
|
||||||
t.Run("module already initialized", func(t *testing.T) {
|
t.Run("module already initialized", func(t *testing.T) {
|
||||||
s := startTestServer(t)
|
s := startTestServer(t)
|
||||||
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true}
|
ss := &fakechain.FakeStateSync{}
|
||||||
|
ss.IsActiveFlag.Store(true)
|
||||||
|
ss.IsInitializedFlag.Store(true)
|
||||||
|
s.stateSync = ss
|
||||||
s.tryInitStateSync()
|
s.tryInitStateSync()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -930,12 +1067,14 @@ func TestTryInitStateSync(t *testing.T) {
|
||||||
s.peers[p] = true
|
s.peers[p] = true
|
||||||
var expectedH uint32 = 8 // median peer
|
var expectedH uint32 = 8 // median peer
|
||||||
|
|
||||||
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error {
|
ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error {
|
||||||
if h != expectedH {
|
if h != expectedH {
|
||||||
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}}
|
}}
|
||||||
|
ss.IsActiveFlag.Store(true)
|
||||||
|
s.stateSync = ss
|
||||||
s.tryInitStateSync()
|
s.tryInitStateSync()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue