network: do not duplicate MPT nodes in GetMPTNodes response
Also tests are added.
This commit is contained in:
parent
51c8c0d82b
commit
0fa48691f7
3 changed files with 161 additions and 8 deletions
|
@ -22,6 +22,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
uatomic "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// FakeChain implements Blockchainer interface, but does not provide real functionality.
|
||||
|
@ -45,9 +46,11 @@ type FakeChain struct {
|
|||
|
||||
// FakeStateSync implements StateSync interface.
|
||||
type FakeStateSync struct {
|
||||
IsActiveFlag bool
|
||||
IsInitializedFlag bool
|
||||
IsActiveFlag uatomic.Bool
|
||||
IsInitializedFlag uatomic.Bool
|
||||
InitFunc func(h uint32) error
|
||||
TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
|
||||
AddMPTNodesFunc func(nodes [][]byte) error
|
||||
}
|
||||
|
||||
// NewFakeChain returns new FakeChain structure.
|
||||
|
@ -461,7 +464,10 @@ func (s *FakeStateSync) AddHeaders(...*block.Header) error {
|
|||
}
|
||||
|
||||
// AddMPTNodes implements StateSync interface.
|
||||
func (s *FakeStateSync) AddMPTNodes([][]byte) error {
|
||||
func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
|
||||
if s.AddMPTNodesFunc != nil {
|
||||
return s.AddMPTNodesFunc(nodes)
|
||||
}
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
|
@ -471,11 +477,11 @@ func (s *FakeStateSync) BlockHeight() uint32 {
|
|||
}
|
||||
|
||||
// IsActive implements StateSync interface.
|
||||
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag }
|
||||
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
|
||||
|
||||
// IsInitialized implements StateSync interface.
|
||||
func (s *FakeStateSync) IsInitialized() bool {
|
||||
return s.IsInitializedFlag
|
||||
return s.IsInitializedFlag.Load()
|
||||
}
|
||||
|
||||
// Init implements StateSync interface.
|
||||
|
@ -496,6 +502,9 @@ func (s *FakeStateSync) NeedMPTNodes() bool {
|
|||
|
||||
// Traverse implements StateSync interface.
|
||||
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
if s.TraverseFunc != nil {
|
||||
return s.TraverseFunc(root, process)
|
||||
}
|
||||
panic("TODO")
|
||||
}
|
||||
|
||||
|
|
|
@ -827,18 +827,23 @@ func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
|
|||
}
|
||||
resp := payload.MPTData{}
|
||||
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
|
||||
added := make(map[util.Uint256]struct{})
|
||||
for _, h := range inv.Hashes {
|
||||
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
|
||||
break
|
||||
}
|
||||
err := s.stateSync.Traverse(h,
|
||||
func(_ mpt.Node, node []byte) bool {
|
||||
func(n mpt.Node, node []byte) bool {
|
||||
if _, ok := added[n.Hash()]; ok {
|
||||
return false
|
||||
}
|
||||
l := len(node)
|
||||
size := l + io.GetVarSize(l)
|
||||
if size > capLeft {
|
||||
return true
|
||||
}
|
||||
resp.Nodes = append(resp.Nodes, node)
|
||||
added[n.Hash()] = struct{}{}
|
||||
capLeft -= size
|
||||
return false
|
||||
})
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/nspcc-dev/neo-go/pkg/core"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||
|
@ -737,6 +738,139 @@ func TestInv(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestHandleGetMPTData(t *testing.T) {
|
||||
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||
Hashes: []util.Uint256{{1, 2, 3}},
|
||||
})
|
||||
require.Error(t, s.handleMessage(p, msg))
|
||||
})
|
||||
|
||||
t.Run("KeepOnlyLatestState on", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||
s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
|
||||
Hashes: []util.Uint256{{1, 2, 3}},
|
||||
})
|
||||
require.Error(t, s.handleMessage(p, msg))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||
var recvResponse atomic.Bool
|
||||
r1 := random.Uint256()
|
||||
r2 := random.Uint256()
|
||||
r3 := random.Uint256()
|
||||
node := []byte{1, 2, 3}
|
||||
s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
|
||||
if !(root.Equals(r1) || root.Equals(r2)) {
|
||||
t.Fatal("unexpected root")
|
||||
}
|
||||
require.False(t, process(mpt.NewHashNode(r3), node))
|
||||
return nil
|
||||
}
|
||||
found := &payload.MPTData{
|
||||
Nodes: [][]byte{node}, // no duplicates expected
|
||||
}
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||
switch msg.Command {
|
||||
case CMDMPTData:
|
||||
require.Equal(t, found, msg.Payload)
|
||||
recvResponse.Store(true)
|
||||
}
|
||||
}
|
||||
hs := []util.Uint256{r1, r2}
|
||||
s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs))
|
||||
|
||||
require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleMPTData(t *testing.T) {
|
||||
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||
Nodes: [][]byte{{1, 2, 3}},
|
||||
})
|
||||
require.Error(t, s.handleMessage(p, msg))
|
||||
})
|
||||
|
||||
t.Run("good", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
expected := [][]byte{{1, 2, 3}, {2, 3, 4}}
|
||||
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
|
||||
s.stateSync = &fakechain.FakeStateSync{
|
||||
AddMPTNodesFunc: func(nodes [][]byte) error {
|
||||
require.Equal(t, expected, nodes)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
msg := NewMessage(CMDMPTData, &payload.MPTData{
|
||||
Nodes: expected,
|
||||
})
|
||||
require.NoError(t, s.handleMessage(p, msg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestMPTNodes(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
|
||||
var actual []util.Uint256
|
||||
p := newLocalPeer(t, s)
|
||||
p.handshaked = true
|
||||
p.messageHandler = func(t *testing.T, msg *Message) {
|
||||
if msg.Command == CMDGetMPTData {
|
||||
actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...)
|
||||
}
|
||||
}
|
||||
s.register <- p
|
||||
s.register <- p // ensure previous send was handled
|
||||
|
||||
t.Run("no hashes, no message", func(t *testing.T) {
|
||||
actual = nil
|
||||
require.NoError(t, s.requestMPTNodes(p, nil))
|
||||
require.Nil(t, actual)
|
||||
})
|
||||
t.Run("good, small", func(t *testing.T) {
|
||||
actual = nil
|
||||
expected := []util.Uint256{random.Uint256(), random.Uint256()}
|
||||
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||
require.Equal(t, expected, actual)
|
||||
})
|
||||
t.Run("good, exactly one chunk", func(t *testing.T) {
|
||||
actual = nil
|
||||
expected := make([]util.Uint256, payload.MaxMPTHashesCount)
|
||||
for i := range expected {
|
||||
expected[i] = random.Uint256()
|
||||
}
|
||||
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||
require.Equal(t, expected, actual)
|
||||
})
|
||||
t.Run("good, too large chunk", func(t *testing.T) {
|
||||
actual = nil
|
||||
expected := make([]util.Uint256, payload.MaxMPTHashesCount+1)
|
||||
for i := range expected {
|
||||
expected[i] = random.Uint256()
|
||||
}
|
||||
require.NoError(t, s.requestMPTNodes(p, expected))
|
||||
require.Equal(t, expected[:payload.MaxMPTHashesCount], actual)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestTx(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
|
||||
|
@ -912,7 +1046,10 @@ func TestTryInitStateSync(t *testing.T) {
|
|||
|
||||
t.Run("module already initialized", func(t *testing.T) {
|
||||
s := startTestServer(t)
|
||||
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true}
|
||||
ss := &fakechain.FakeStateSync{}
|
||||
ss.IsActiveFlag.Store(true)
|
||||
ss.IsInitializedFlag.Store(true)
|
||||
s.stateSync = ss
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
|
||||
|
@ -930,12 +1067,14 @@ func TestTryInitStateSync(t *testing.T) {
|
|||
s.peers[p] = true
|
||||
var expectedH uint32 = 8 // median peer
|
||||
|
||||
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error {
|
||||
ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error {
|
||||
if h != expectedH {
|
||||
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
|
||||
}
|
||||
return nil
|
||||
}}
|
||||
ss.IsActiveFlag.Store(true)
|
||||
s.stateSync = ss
|
||||
s.tryInitStateSync()
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue