network: do not duplicate MPT nodes in GetMPTNodes response

Also tests are added.
This commit is contained in:
Anna Shaleva 2021-09-06 15:16:47 +03:00
parent 51c8c0d82b
commit 0fa48691f7
3 changed files with 161 additions and 8 deletions

View file

@ -22,6 +22,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
uatomic "go.uber.org/atomic"
) )
// FakeChain implements Blockchainer interface, but does not provide real functionality. // FakeChain implements Blockchainer interface, but does not provide real functionality.
@ -45,9 +46,11 @@ type FakeChain struct {
// FakeStateSync implements StateSync interface. // FakeStateSync implements StateSync interface.
type FakeStateSync struct { type FakeStateSync struct {
IsActiveFlag bool IsActiveFlag uatomic.Bool
IsInitializedFlag bool IsInitializedFlag uatomic.Bool
InitFunc func(h uint32) error InitFunc func(h uint32) error
TraverseFunc func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error
AddMPTNodesFunc func(nodes [][]byte) error
} }
// NewFakeChain returns new FakeChain structure. // NewFakeChain returns new FakeChain structure.
@ -461,7 +464,10 @@ func (s *FakeStateSync) AddHeaders(...*block.Header) error {
} }
// AddMPTNodes implements StateSync interface. // AddMPTNodes implements StateSync interface.
func (s *FakeStateSync) AddMPTNodes([][]byte) error { func (s *FakeStateSync) AddMPTNodes(nodes [][]byte) error {
if s.AddMPTNodesFunc != nil {
return s.AddMPTNodesFunc(nodes)
}
panic("TODO") panic("TODO")
} }
@ -471,11 +477,11 @@ func (s *FakeStateSync) BlockHeight() uint32 {
} }
// IsActive implements StateSync interface. // IsActive implements StateSync interface.
func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag } func (s *FakeStateSync) IsActive() bool { return s.IsActiveFlag.Load() }
// IsInitialized implements StateSync interface. // IsInitialized implements StateSync interface.
func (s *FakeStateSync) IsInitialized() bool { func (s *FakeStateSync) IsInitialized() bool {
return s.IsInitializedFlag return s.IsInitializedFlag.Load()
} }
// Init implements StateSync interface. // Init implements StateSync interface.
@ -496,6 +502,9 @@ func (s *FakeStateSync) NeedMPTNodes() bool {
// Traverse implements StateSync interface. // Traverse implements StateSync interface.
func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error { func (s *FakeStateSync) Traverse(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
if s.TraverseFunc != nil {
return s.TraverseFunc(root, process)
}
panic("TODO") panic("TODO")
} }

View file

@ -827,18 +827,23 @@ func (s *Server) handleGetMPTDataCmd(p Peer, inv *payload.MPTInventory) error {
} }
resp := payload.MPTData{} resp := payload.MPTData{}
capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes))) capLeft := payload.MaxSize - 8 // max(io.GetVarSize(len(resp.Nodes)))
added := make(map[util.Uint256]struct{})
for _, h := range inv.Hashes { for _, h := range inv.Hashes {
if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type if capLeft <= 2 { // at least 1 byte for len(nodeBytes) and 1 byte for node type
break break
} }
err := s.stateSync.Traverse(h, err := s.stateSync.Traverse(h,
func(_ mpt.Node, node []byte) bool { func(n mpt.Node, node []byte) bool {
if _, ok := added[n.Hash()]; ok {
return false
}
l := len(node) l := len(node)
size := l + io.GetVarSize(l) size := l + io.GetVarSize(l)
if size > capLeft { if size > capLeft {
return true return true
} }
resp.Nodes = append(resp.Nodes, node) resp.Nodes = append(resp.Nodes, node)
added[n.Hash()] = struct{}{}
capLeft -= size capLeft -= size
return false return false
}) })

View file

@ -17,6 +17,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/mpt"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
@ -737,6 +738,139 @@ func TestInv(t *testing.T) {
}) })
} }
func TestHandleGetMPTData(t *testing.T) {
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
s := startTestServer(t)
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
Hashes: []util.Uint256{{1, 2, 3}},
})
require.Error(t, s.handleMessage(p, msg))
})
t.Run("KeepOnlyLatestState on", func(t *testing.T) {
s := startTestServer(t)
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{
Hashes: []util.Uint256{{1, 2, 3}},
})
require.Error(t, s.handleMessage(p, msg))
})
t.Run("good", func(t *testing.T) {
s := startTestServer(t)
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
var recvResponse atomic.Bool
r1 := random.Uint256()
r2 := random.Uint256()
r3 := random.Uint256()
node := []byte{1, 2, 3}
s.stateSync.(*fakechain.FakeStateSync).TraverseFunc = func(root util.Uint256, process func(node mpt.Node, nodeBytes []byte) bool) error {
if !(root.Equals(r1) || root.Equals(r2)) {
t.Fatal("unexpected root")
}
require.False(t, process(mpt.NewHashNode(r3), node))
return nil
}
found := &payload.MPTData{
Nodes: [][]byte{node}, // no duplicates expected
}
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
switch msg.Command {
case CMDMPTData:
require.Equal(t, found, msg.Payload)
recvResponse.Store(true)
}
}
hs := []util.Uint256{r1, r2}
s.testHandleMessage(t, p, CMDGetMPTData, payload.NewMPTInventory(hs))
require.Eventually(t, recvResponse.Load, time.Second, time.Millisecond)
})
}
func TestHandleMPTData(t *testing.T) {
t.Run("P2PStateExchange extensions off", func(t *testing.T) {
s := startTestServer(t)
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDMPTData, &payload.MPTData{
Nodes: [][]byte{{1, 2, 3}},
})
require.Error(t, s.handleMessage(p, msg))
})
t.Run("good", func(t *testing.T) {
s := startTestServer(t)
expected := [][]byte{{1, 2, 3}, {2, 3, 4}}
s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true
s.stateSync = &fakechain.FakeStateSync{
AddMPTNodesFunc: func(nodes [][]byte) error {
require.Equal(t, expected, nodes)
return nil
},
}
p := newLocalPeer(t, s)
p.handshaked = true
msg := NewMessage(CMDMPTData, &payload.MPTData{
Nodes: expected,
})
require.NoError(t, s.handleMessage(p, msg))
})
}
func TestRequestMPTNodes(t *testing.T) {
s := startTestServer(t)
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDGetMPTData {
actual = append(actual, msg.Payload.(*payload.MPTInventory).Hashes...)
}
}
s.register <- p
s.register <- p // ensure previous send was handled
t.Run("no hashes, no message", func(t *testing.T) {
actual = nil
require.NoError(t, s.requestMPTNodes(p, nil))
require.Nil(t, actual)
})
t.Run("good, small", func(t *testing.T) {
actual = nil
expected := []util.Uint256{random.Uint256(), random.Uint256()}
require.NoError(t, s.requestMPTNodes(p, expected))
require.Equal(t, expected, actual)
})
t.Run("good, exactly one chunk", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxMPTHashesCount)
for i := range expected {
expected[i] = random.Uint256()
}
require.NoError(t, s.requestMPTNodes(p, expected))
require.Equal(t, expected, actual)
})
t.Run("good, too large chunk", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxMPTHashesCount+1)
for i := range expected {
expected[i] = random.Uint256()
}
require.NoError(t, s.requestMPTNodes(p, expected))
require.Equal(t, expected[:payload.MaxMPTHashesCount], actual)
})
}
func TestRequestTx(t *testing.T) { func TestRequestTx(t *testing.T) {
s := startTestServer(t) s := startTestServer(t)
@ -912,7 +1046,10 @@ func TestTryInitStateSync(t *testing.T) {
t.Run("module already initialized", func(t *testing.T) { t.Run("module already initialized", func(t *testing.T) {
s := startTestServer(t) s := startTestServer(t)
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: true} ss := &fakechain.FakeStateSync{}
ss.IsActiveFlag.Store(true)
ss.IsInitializedFlag.Store(true)
s.stateSync = ss
s.tryInitStateSync() s.tryInitStateSync()
}) })
@ -930,12 +1067,14 @@ func TestTryInitStateSync(t *testing.T) {
s.peers[p] = true s.peers[p] = true
var expectedH uint32 = 8 // median peer var expectedH uint32 = 8 // median peer
s.stateSync = &fakechain.FakeStateSync{IsActiveFlag: true, IsInitializedFlag: false, InitFunc: func(h uint32) error { ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error {
if h != expectedH { if h != expectedH {
return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h) return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)
} }
return nil return nil
}} }}
ss.IsActiveFlag.Store(true)
s.stateSync = ss
s.tryInitStateSync() s.tryInitStateSync()
}) })
} }