diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index 426972c6b..6d67f0a6c 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -55,6 +55,15 @@ type FakeStateSync struct { // NewFakeChain returns new FakeChain structure. func NewFakeChain() *FakeChain { + return NewFakeChainWithCustomCfg(nil) +} + +// NewFakeChainWithCustomCfg returns new FakeChain structure with specified protocol configuration. +func NewFakeChainWithCustomCfg(protocolCfg func(c *config.ProtocolConfiguration)) *FakeChain { + cfg := config.ProtocolConfiguration{Magic: netmode.UnitTestNet, P2PNotaryRequestPayloadPoolSize: 10} + if protocolCfg != nil { + protocolCfg(&cfg) + } return &FakeChain{ Pool: mempool.New(10, 0, false), PoolTxF: func(*transaction.Transaction) error { return nil }, @@ -62,7 +71,7 @@ func NewFakeChain() *FakeChain { blocks: make(map[util.Uint256]*block.Block), hdrHashes: make(map[uint32]util.Uint256), txs: make(map[util.Uint256]*transaction.Transaction), - ProtocolConfiguration: config.ProtocolConfiguration{Magic: netmode.UnitTestNet, P2PNotaryRequestPayloadPoolSize: 10}, + ProtocolConfiguration: cfg, } } diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 4a1779a78..23c399c91 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/nspcc-dev/neo-go/internal/fakechain" + "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" @@ -187,7 +188,11 @@ func (p *localPeer) CanProcessAddr() bool { } func newTestServer(t *testing.T, serverConfig ServerConfig) *Server { - s, err := newServerFromConstructors(serverConfig, fakechain.NewFakeChain(), zaptest.NewLogger(t), + return newTestServerWithCustomCfg(t, serverConfig, nil) +} + +func newTestServerWithCustomCfg(t *testing.T, serverConfig ServerConfig, protocolCfg func(*config.ProtocolConfiguration)) *Server { + s, err := newServerFromConstructors(serverConfig, fakechain.NewFakeChainWithCustomCfg(protocolCfg), zaptest.NewLogger(t), newFakeTransp, newFakeConsensus, newTestDiscovery) require.NoError(t, err) t.Cleanup(s.discovery.Close) diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 69dc0d86b..e4c109c9f 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -379,8 +379,14 @@ func (s *Server) testHandleMessage(t *testing.T, p Peer, cmd CommandType, pl pay return s } -func startTestServer(t *testing.T) *Server { - s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"}) +func startTestServer(t *testing.T, protocolCfg ...func(*config.ProtocolConfiguration)) *Server { + var s *Server + srvCfg := ServerConfig{Port: 0, UserAgent: "/test/"} + if protocolCfg != nil { + s = newTestServerWithCustomCfg(t, srvCfg, protocolCfg[0]) + } else { + s = newTestServer(t, srvCfg) + } ch := startWithChannel(s) t.Cleanup(func() { s.Shutdown() @@ -750,9 +756,10 @@ func TestHandleGetMPTData(t *testing.T) { }) t.Run("KeepOnlyLatestState on", func(t *testing.T) { - s := startTestServer(t) - s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true - s.chain.(*fakechain.FakeChain).KeepOnlyLatestState = true + s := startTestServer(t, func(c *config.ProtocolConfiguration) { + c.P2PStateExchangeExtensions = true + c.KeepOnlyLatestState = true + }) p := newLocalPeer(t, s) p.handshaked = true msg := NewMessage(CMDGetMPTData, &payload.MPTInventory{ @@ -762,8 +769,9 @@ func TestHandleGetMPTData(t *testing.T) { }) t.Run("good", func(t *testing.T) { - s := startTestServer(t) - s.chain.(*fakechain.FakeChain).P2PStateExchangeExtensions = true + s := startTestServer(t, func(c *config.ProtocolConfiguration) { + c.P2PStateExchangeExtensions = true + }) var recvResponse atomic.Bool r1 := random.Uint256() r2 := random.Uint256() @@ -1059,14 +1067,15 @@ func TestTryInitStateSync(t *testing.T) { p := newLocalPeer(t, s) p.handshaked = true p.lastBlockIndex = h - s.peers[p] = true + s.register <- p } p := newLocalPeer(t, s) p.handshaked = false // one disconnected peer to check it won't be taken into attention p.lastBlockIndex = 5 - s.peers[p] = true - var expectedH uint32 = 8 // median peer + s.register <- p + require.Eventually(t, func() bool { return 7 == s.PeerCount() }, time.Second, time.Millisecond*10) + var expectedH uint32 = 8 // median peer ss := &fakechain.FakeStateSync{InitFunc: func(h uint32) error { if h != expectedH { return fmt.Errorf("invalid height: expected %d, got %d", expectedH, h)