diff --git a/pkg/network/blockqueue_test.go b/pkg/network/blockqueue_test.go index 85382c526..b796cf3f6 100644 --- a/pkg/network/blockqueue_test.go +++ b/pkg/network/blockqueue_test.go @@ -10,7 +10,7 @@ import ( ) func TestBlockQueue(t *testing.T) { - chain := &testChain{} + chain := newTestChain() // notice, it's not yet running bq := newBlockQueue(0, chain, zaptest.NewLogger(t), nil) blocks := make([]*block.Block, 11) diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index 93781c045..c9e8dc49c 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -67,6 +67,10 @@ func NewDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) *Defa return d } +func newDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) Discoverer { + return NewDefaultDiscovery(addrs, dt, ts) +} + // BackFill implements the Discoverer interface and will backfill the // the pool with the given addresses. func (d *DefaultDiscovery) BackFill(addrs ...string) { diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index 8b829ddba..dc69d3d7b 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -2,6 +2,7 @@ package network import ( "errors" + "net" "sort" "sync/atomic" "testing" @@ -10,11 +11,19 @@ import ( "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + atomic2 "go.uber.org/atomic" ) type fakeTransp struct { retFalse int32 + started atomic2.Bool + closed atomic2.Bool dialCh chan string + addr string +} + +func newFakeTransp(s *Server) Transporter { + return &fakeTransp{} } func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error { @@ -26,14 +35,23 @@ func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error { return nil } func (ft *fakeTransp) Accept() { + if ft.started.Load() { + panic("started twice") + } + ft.addr = net.JoinHostPort("0.0.0.0", "42") + ft.started.Store(true) } func (ft *fakeTransp) Proto() string { return "" } func (ft *fakeTransp) Address() string { - return "" + return ft.addr } func (ft *fakeTransp) Close() { + if ft.closed.Load() { + panic("closed twice") + } + ft.closed.Store(true) } func TestDefaultDiscoverer(t *testing.T) { ts := &fakeTransp{} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 33aa4f2a6..894a5f03a 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -1,14 +1,17 @@ package network import ( + "errors" + "fmt" "math/big" - "math/rand" "net" - "strconv" + "sync" "sync/atomic" "testing" + "time" "github.com/nspcc-dev/neo-go/pkg/config" + "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/state" @@ -21,45 +24,76 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" + "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) type testChain struct { + config.ProtocolConfiguration + *mempool.Pool + blocksCh []chan<- *block.Block blockheight uint32 + poolTx func(*transaction.Transaction) error + blocks map[util.Uint256]*block.Block + hdrHashes map[uint32]util.Uint256 + txs map[util.Uint256]*transaction.Transaction } -func (chain testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction { +func newTestChain() *testChain { + return &testChain{ + Pool: mempool.New(10), + poolTx: func(*transaction.Transaction) error { return nil }, + blocks: make(map[util.Uint256]*block.Block), + hdrHashes: make(map[uint32]util.Uint256), + txs: make(map[util.Uint256]*transaction.Transaction), + } +} + +func (chain *testChain) putBlock(b *block.Block) { + chain.blocks[b.Hash()] = b + chain.hdrHashes[b.Index] = b.Hash() + atomic.StoreUint32(&chain.blockheight, b.Index) +} +func (chain *testChain) putHeader(b *block.Block) { + chain.hdrHashes[b.Index] = b.Hash() +} + +func (chain *testChain) putTx(tx *transaction.Transaction) { + chain.txs[tx.Hash()] = tx +} + +func (chain *testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction { panic("TODO") } -func (chain testChain) GetConfig() config.ProtocolConfiguration { - panic("TODO") +func (chain *testChain) GetConfig() config.ProtocolConfiguration { + return chain.ProtocolConfiguration } -func (chain testChain) CalculateClaimable(util.Uint160, uint32) (*big.Int, error) { +func (chain *testChain) CalculateClaimable(util.Uint160, uint32) (*big.Int, error) { panic("TODO") } -func (chain testChain) FeePerByte() int64 { +func (chain *testChain) FeePerByte() int64 { panic("TODO") } -func (chain testChain) P2PSigExtensionsEnabled() bool { +func (chain *testChain) P2PSigExtensionsEnabled() bool { return false } -func (chain testChain) GetMaxBlockSystemFee() int64 { +func (chain *testChain) GetMaxBlockSystemFee() int64 { panic("TODO") } -func (chain testChain) GetMaxBlockSize() uint32 { +func (chain *testChain) GetMaxBlockSize() uint32 { panic("TODO") } -func (chain testChain) AddHeaders(...*block.Header) error { +func (chain *testChain) AddHeaders(...*block.Header) error { panic("TODO") } func (chain *testChain) AddBlock(block *block.Block) error { - if block.Index == chain.blockheight+1 { - atomic.StoreUint32(&chain.blockheight, block.Index) + if block.Index == atomic.LoadUint32(&chain.blockheight)+1 { + chain.putBlock(block) } return nil } @@ -72,148 +106,200 @@ func (chain *testChain) BlockHeight() uint32 { func (chain *testChain) Close() { panic("TODO") } -func (chain testChain) HeaderHeight() uint32 { - return 0 +func (chain *testChain) HeaderHeight() uint32 { + return atomic.LoadUint32(&chain.blockheight) } -func (chain testChain) GetAppExecResults(hash util.Uint256, trig trigger.Type) ([]state.AppExecResult, error) { +func (chain *testChain) GetAppExecResults(hash util.Uint256, trig trigger.Type) ([]state.AppExecResult, error) { panic("TODO") } -func (chain testChain) GetBlock(hash util.Uint256) (*block.Block, error) { +func (chain *testChain) GetBlock(hash util.Uint256) (*block.Block, error) { + if b, ok := chain.blocks[hash]; ok { + return b, nil + } + return nil, errors.New("not found") +} +func (chain *testChain) GetCommittee() (keys.PublicKeys, error) { panic("TODO") } -func (chain testChain) GetCommittee() (keys.PublicKeys, error) { +func (chain *testChain) GetContractState(hash util.Uint160) *state.Contract { panic("TODO") } -func (chain testChain) GetContractState(hash util.Uint160) *state.Contract { +func (chain *testChain) GetContractScriptHash(id int32) (util.Uint160, error) { panic("TODO") } -func (chain testChain) GetContractScriptHash(id int32) (util.Uint160, error) { +func (chain *testChain) GetNativeContractScriptHash(name string) (util.Uint160, error) { panic("TODO") } -func (chain testChain) GetNativeContractScriptHash(name string) (util.Uint160, error) { +func (chain *testChain) GetHeaderHash(n int) util.Uint256 { + return chain.hdrHashes[uint32(n)] +} +func (chain *testChain) GetHeader(hash util.Uint256) (*block.Header, error) { + b, err := chain.GetBlock(hash) + if err != nil { + return nil, err + } + return b.Header(), nil +} + +func (chain *testChain) GetNextBlockValidators() ([]*keys.PublicKey, error) { panic("TODO") } -func (chain testChain) GetHeaderHash(int) util.Uint256 { +func (chain *testChain) ForEachNEP17Transfer(util.Uint160, func(*state.NEP17Transfer) (bool, error)) error { + panic("TODO") +} +func (chain *testChain) GetNEP17Balances(util.Uint160) *state.NEP17Balances { + panic("TODO") +} +func (chain *testChain) GetValidators() ([]*keys.PublicKey, error) { + panic("TODO") +} +func (chain *testChain) GetStandByCommittee() keys.PublicKeys { + panic("TODO") +} +func (chain *testChain) GetStandByValidators() keys.PublicKeys { + panic("TODO") +} +func (chain *testChain) GetEnrollments() ([]state.Validator, error) { + panic("TODO") +} +func (chain *testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) { + panic("TODO") +} +func (chain *testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + panic("TODO") +} +func (chain *testChain) GetStorageItem(id int32, key []byte) *state.StorageItem { + panic("TODO") +} +func (chain *testChain) GetTestVM(tx *transaction.Transaction, b *block.Block) *vm.VM { + panic("TODO") +} +func (chain *testChain) GetStorageItems(id int32) (map[string]*state.StorageItem, error) { + panic("TODO") +} +func (chain *testChain) CurrentHeaderHash() util.Uint256 { return util.Uint256{} } -func (chain testChain) GetHeader(hash util.Uint256) (*block.Header, error) { - panic("TODO") -} - -func (chain testChain) GetNextBlockValidators() ([]*keys.PublicKey, error) { - panic("TODO") -} -func (chain testChain) ForEachNEP17Transfer(util.Uint160, func(*state.NEP17Transfer) (bool, error)) error { - panic("TODO") -} -func (chain testChain) GetNEP17Balances(util.Uint160) *state.NEP17Balances { - panic("TODO") -} -func (chain testChain) GetValidators() ([]*keys.PublicKey, error) { - panic("TODO") -} -func (chain testChain) GetStandByCommittee() keys.PublicKeys { - panic("TODO") -} -func (chain testChain) GetStandByValidators() keys.PublicKeys { - panic("TODO") -} -func (chain testChain) GetEnrollments() ([]state.Validator, error) { - panic("TODO") -} -func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) { - panic("TODO") -} -func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { - panic("TODO") -} -func (chain testChain) GetStorageItem(id int32, key []byte) *state.StorageItem { - panic("TODO") -} -func (chain testChain) GetTestVM(tx *transaction.Transaction, b *block.Block) *vm.VM { - panic("TODO") -} -func (chain testChain) GetStorageItems(id int32) (map[string]*state.StorageItem, error) { - panic("TODO") -} -func (chain testChain) CurrentHeaderHash() util.Uint256 { +func (chain *testChain) CurrentBlockHash() util.Uint256 { return util.Uint256{} } -func (chain testChain) CurrentBlockHash() util.Uint256 { - return util.Uint256{} +func (chain *testChain) HasBlock(h util.Uint256) bool { + _, ok := chain.blocks[h] + return ok } -func (chain testChain) HasBlock(util.Uint256) bool { - return false +func (chain *testChain) HasTransaction(h util.Uint256) bool { + _, ok := chain.txs[h] + return ok } -func (chain testChain) HasTransaction(util.Uint256) bool { - return false +func (chain *testChain) GetTransaction(h util.Uint256) (*transaction.Transaction, uint32, error) { + if tx, ok := chain.txs[h]; ok { + return tx, 1, nil + } + return nil, 0, errors.New("not found") } -func (chain testChain) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) { + +func (chain *testChain) GetMemPool() *mempool.Pool { + return chain.Pool +} + +func (chain *testChain) GetGoverningTokenBalance(acc util.Uint160) (*big.Int, uint32) { panic("TODO") } -func (chain testChain) GetMemPool() *mempool.Pool { +func (chain *testChain) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int { panic("TODO") } -func (chain testChain) GetGoverningTokenBalance(acc util.Uint160) (*big.Int, uint32) { +func (chain *testChain) PoolTx(tx *transaction.Transaction, _ ...*mempool.Pool) error { + return chain.poolTx(tx) +} + +func (chain *testChain) SubscribeForBlocks(ch chan<- *block.Block) { + chain.blocksCh = append(chain.blocksCh, ch) +} +func (chain *testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) { + panic("TODO") +} +func (chain *testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) { + panic("TODO") +} +func (chain *testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) { panic("TODO") } -func (chain testChain) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int { +func (chain *testChain) VerifyTx(*transaction.Transaction) error { + panic("TODO") +} +func (*testChain) VerifyWitness(util.Uint160, crypto.Verifiable, *transaction.Witness, int64) error { panic("TODO") } -func (chain testChain) PoolTx(*transaction.Transaction, ...*mempool.Pool) error { +func (chain *testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) { + for i, c := range chain.blocksCh { + if c == ch { + if i < len(chain.blocksCh) { + copy(chain.blocksCh[i:], chain.blocksCh[i+1:]) + } + chain.blocksCh = chain.blocksCh[:len(chain.blocksCh)] + } + } +} +func (chain *testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) { + panic("TODO") +} +func (chain *testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) { + panic("TODO") +} +func (chain *testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { panic("TODO") } -func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) { - panic("TODO") -} -func (chain testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) { - panic("TODO") -} -func (chain testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) { - panic("TODO") -} -func (chain testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) { - panic("TODO") +type testDiscovery struct { + sync.Mutex + bad []string + good []string + connected []string + unregistered []string + backfill []string } -func (chain testChain) VerifyTx(*transaction.Transaction) error { - panic("TODO") -} -func (testChain) VerifyWitness(util.Uint160, crypto.Verifiable, *transaction.Witness, int64) error { - panic("TODO") -} +func newTestDiscovery([]string, time.Duration, Transporter) Discoverer { return new(testDiscovery) } -func (chain testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) { - panic("TODO") +func (d *testDiscovery) BackFill(addrs ...string) { + d.Lock() + defer d.Unlock() + d.backfill = append(d.backfill, addrs...) } -func (chain testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) { - panic("TODO") +func (d *testDiscovery) Close() {} +func (d *testDiscovery) PoolCount() int { return 0 } +func (d *testDiscovery) RegisterBadAddr(addr string) { + d.Lock() + defer d.Unlock() + d.bad = append(d.bad, addr) } -func (chain testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) { - panic("TODO") +func (d *testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {} +func (d *testDiscovery) RegisterConnectedAddr(addr string) { + d.Lock() + defer d.Unlock() + d.connected = append(d.connected, addr) } -func (chain testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { - panic("TODO") +func (d *testDiscovery) UnregisterConnectedAddr(addr string) { + d.Lock() + defer d.Unlock() + d.unregistered = append(d.unregistered, addr) } - -type testDiscovery struct{} - -func (d testDiscovery) BackFill(addrs ...string) {} -func (d testDiscovery) Close() {} -func (d testDiscovery) PoolCount() int { return 0 } -func (d testDiscovery) RegisterBadAddr(string) {} -func (d testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {} -func (d testDiscovery) RegisterConnectedAddr(string) {} -func (d testDiscovery) UnregisterConnectedAddr(string) {} -func (d testDiscovery) UnconnectedPeers() []string { return []string{} } -func (d testDiscovery) RequestRemote(n int) {} -func (d testDiscovery) BadPeers() []string { return []string{} } -func (d testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} } +func (d *testDiscovery) UnconnectedPeers() []string { + d.Lock() + defer d.Unlock() + return d.unregistered +} +func (d *testDiscovery) RequestRemote(n int) {} +func (d *testDiscovery) BadPeers() []string { + d.Lock() + defer d.Unlock() + return d.bad +} +func (d *testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} } var defaultMessageHandler = func(t *testing.T, msg *Message) {} @@ -228,6 +314,7 @@ type localPeer struct { messageHandler func(t *testing.T, msg *Message) pingSent int getAddrSent int + droppedWith atomic.Value } func newLocalPeer(t *testing.T, s *Server) *localPeer { @@ -246,8 +333,14 @@ func (p *localPeer) RemoteAddr() net.Addr { func (p *localPeer) PeerAddr() net.Addr { return &p.netaddr } -func (p *localPeer) StartProtocol() {} -func (p *localPeer) Disconnect(err error) {} +func (p *localPeer) StartProtocol() {} +func (p *localPeer) Disconnect(err error) { + if p.droppedWith.Load() == nil { + p.droppedWith.Store(err) + } + fmt.Println("peer dropped:", err) + p.server.unregister <- peerDrop{p, err} +} func (p *localPeer) EnqueueMessage(msg *Message) error { b, err := msg.Bytes() @@ -266,7 +359,7 @@ func (p *localPeer) EnqueueP2PPacket(m []byte) error { return p.EnqueueHPPacket(m) } func (p *localPeer) EnqueueHPPacket(m []byte) error { - msg := &Message{} + msg := &Message{Network: netmode.UnitTestNet} r := io.NewBinReaderFromBuf(m) err := msg.Decode(r) if err == nil { @@ -333,17 +426,8 @@ func (p *localPeer) CanProcessAddr() bool { } func newTestServer(t *testing.T, serverConfig ServerConfig) *Server { - s := &Server{ - ServerConfig: serverConfig, - chain: &testChain{}, - discovery: testDiscovery{}, - id: rand.Uint32(), - quit: make(chan struct{}), - register: make(chan Peer), - unregister: make(chan peerDrop), - peers: make(map[Peer]bool), - log: zaptest.NewLogger(t), - } - s.transport = NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log) + s, err := newServerFromConstructors(serverConfig, newTestChain(), zaptest.NewLogger(t), + newFakeTransp, newFakeConsensus, newTestDiscovery) + require.NoError(t, err) return s } diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index 58ffdb3ef..df2010b88 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -128,10 +128,10 @@ func TestEncodeDecodeAddr(t *testing.T) { func TestEncodeDecodeBlock(t *testing.T) { t.Run("good", func(t *testing.T) { - testEncodeDecode(t, CMDBlock, newDummyBlock(1)) + testEncodeDecode(t, CMDBlock, newDummyBlock(12, 1)) }) t.Run("invalid state root enabled setting", func(t *testing.T) { - expected := NewMessage(CMDBlock, newDummyBlock(1)) + expected := NewMessage(CMDBlock, newDummyBlock(31, 1)) expected.Network = netmode.UnitTestNet data, err := testserdes.Encode(expected) require.NoError(t, err) @@ -270,7 +270,7 @@ func TestInvalidMessages(t *testing.T) { require.NoError(t, err) }) t.Run("trimmed payload", func(t *testing.T) { - m := NewMessage(CMDBlock, newDummyBlock(0)) + m := NewMessage(CMDBlock, newDummyBlock(1, 0)) data, err := testserdes.Encode(m) require.NoError(t, err) data = data[:len(data)-1] @@ -288,8 +288,9 @@ func (f failSer) EncodeBinary(r *io.BinWriter) { func (failSer) DecodeBinary(w *io.BinReader) {} -func newDummyBlock(txCount int) *block.Block { +func newDummyBlock(height uint32, txCount int) *block.Block { b := block.New(netmode.UnitTestNet, false) + b.Index = height b.PrevHash = random.Uint256() b.Timestamp = rand.Uint64() b.Script.InvocationScript = random.Bytes(2) @@ -303,7 +304,7 @@ func newDummyBlock(txCount int) *block.Block { } func newDummyTx() *transaction.Transaction { - tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), int64(rand.Uint64()>>1)) + tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), 123) tx.Signers = []transaction.Signer{{Account: random.Uint160()}} tx.Size() tx.Hash() diff --git a/pkg/network/server.go b/pkg/network/server.go index ae77a016b..7fb73769b 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -96,6 +96,16 @@ func randomID() uint32 { // NewServer returns a new Server, initialized with the given configuration. func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger) (*Server, error) { + return newServerFromConstructors(config, chain, log, func(s *Server) Transporter { + return NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log) + }, consensus.NewService, newDefaultDiscovery) +} + +func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger, + newTransport func(*Server) Transporter, + newConsensus func(consensus.Config) (consensus.Service, error), + newDiscovery func([]string, time.Duration, Transporter) Discoverer, +) (*Server, error) { if log == nil { return nil, errors.New("logger is a required parameter") } @@ -120,7 +130,7 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo } }) - srv, err := consensus.NewService(consensus.Config{ + srv, err := newConsensus(consensus.Config{ Logger: log, Broadcast: s.handleNewPayload, Chain: chain, @@ -156,8 +166,8 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo s.AttemptConnPeers = defaultAttemptConnPeers } - s.transport = NewTCPTransport(s, net.JoinHostPort(config.Address, strconv.Itoa(int(config.Port))), s.log) - s.discovery = NewDefaultDiscovery( + s.transport = newTransport(s) + s.discovery = newDiscovery( s.Seeds, s.DialTimeout, s.transport, diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 4f2e56c33..de4321fd9 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -1,17 +1,186 @@ package network import ( + "errors" + "math/big" "net" "strconv" + atomic2 "sync/atomic" "testing" "time" + "github.com/nspcc-dev/neo-go/internal/random" + "github.com/nspcc-dev/neo-go/pkg/config" + "github.com/nspcc-dev/neo-go/pkg/config/netmode" + "github.com/nspcc-dev/neo-go/pkg/consensus" + "github.com/nspcc-dev/neo-go/pkg/core" + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/atomic" + "go.uber.org/zap/zaptest" ) +type fakeConsensus struct { + started atomic.Bool + stopped atomic.Bool + payloads []*consensus.Payload + txs []*transaction.Transaction +} + +var _ consensus.Service = (*fakeConsensus)(nil) + +func newFakeConsensus(c consensus.Config) (consensus.Service, error) { + return new(fakeConsensus), nil +} +func (f *fakeConsensus) Start() { f.started.Store(true) } +func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) } +func (f *fakeConsensus) OnPayload(p *consensus.Payload) { f.payloads = append(f.payloads, p) } +func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) } +func (f *fakeConsensus) GetPayload(h util.Uint256) *consensus.Payload { panic("implement me") } + +func TestNewServer(t *testing.T) { + bc := &testChain{} + s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery) + require.Error(t, err) + + t.Run("set defaults", func(t *testing.T) { + s = newTestServer(t, ServerConfig{MinPeers: -1}) + defer s.discovery.Close() + + require.True(t, s.ID() != 0) + require.Equal(t, defaultMinPeers, s.ServerConfig.MinPeers) + require.Equal(t, defaultMaxPeers, s.ServerConfig.MaxPeers) + require.Equal(t, defaultAttemptConnPeers, s.ServerConfig.AttemptConnPeers) + }) + t.Run("don't defaults", func(t *testing.T) { + cfg := ServerConfig{ + MinPeers: 1, + MaxPeers: 2, + AttemptConnPeers: 3, + } + s = newTestServer(t, cfg) + defer s.discovery.Close() + + require.True(t, s.ID() != 0) + require.Equal(t, 1, s.ServerConfig.MinPeers) + require.Equal(t, 2, s.ServerConfig.MaxPeers) + require.Equal(t, 3, s.ServerConfig.AttemptConnPeers) + }) + t.Run("consensus error is not dropped", func(t *testing.T) { + errConsensus := errors.New("can't create consensus") + _, err = newServerFromConstructors(ServerConfig{MinPeers: -1}, bc, zaptest.NewLogger(t), newFakeTransp, + func(consensus.Config) (consensus.Service, error) { return nil, errConsensus }, + newTestDiscovery) + require.True(t, errors.Is(err, errConsensus), "got: %#v", err) + }) +} + +func startWithChannel(s *Server) chan error { + ch := make(chan error) + go func() { + s.Start(ch) + close(ch) + }() + return ch +} + +func TestServerStartAndShutdown(t *testing.T) { + t.Run("no consensus", func(t *testing.T) { + s := newTestServer(t, ServerConfig{}) + + ch := startWithChannel(s) + p := newLocalPeer(t, s) + s.register <- p + require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10) + + assert.True(t, s.transport.(*fakeTransp).started.Load()) + assert.False(t, s.consensus.(*fakeConsensus).started.Load()) + + s.Shutdown() + <-ch + + require.True(t, s.transport.(*fakeTransp).closed.Load()) + require.False(t, s.consensus.(*fakeConsensus).stopped.Load()) + err, ok := p.droppedWith.Load().(error) + require.True(t, ok) + require.True(t, errors.Is(err, errServerShutdown)) + }) + t.Run("with consensus", func(t *testing.T) { + s := newTestServer(t, ServerConfig{Wallet: new(config.Wallet)}) + + ch := startWithChannel(s) + p := newLocalPeer(t, s) + s.register <- p + + assert.True(t, s.consensus.(*fakeConsensus).started.Load()) + + s.Shutdown() + <-ch + + require.True(t, s.consensus.(*fakeConsensus).stopped.Load()) + }) +} + +func TestServerRegisterPeer(t *testing.T) { + const peerCount = 3 + + s := newTestServer(t, ServerConfig{MaxPeers: 2}) + ps := make([]*localPeer, peerCount) + for i := range ps { + ps[i] = newLocalPeer(t, s) + ps[i].netaddr.Port = i + 1 + } + + ch := startWithChannel(s) + defer func() { + s.Shutdown() + <-ch + }() + + s.register <- ps[0] + require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10) + + s.register <- ps[1] + require.Eventually(t, func() bool { return 2 == s.PeerCount() }, time.Second, time.Millisecond*10) + + require.Equal(t, 0, len(s.discovery.UnconnectedPeers())) + s.register <- ps[2] + require.Eventually(t, func() bool { return len(s.discovery.UnconnectedPeers()) > 0 }, time.Second, time.Millisecond*100) + + index := -1 + addrs := s.discovery.UnconnectedPeers() + for _, addr := range addrs { + for j := range ps { + if ps[j].PeerAddr().String() == addr { + index = j + break + } + } + } + require.True(t, index >= 0) + err, ok := ps[index].droppedWith.Load().(error) + require.True(t, ok) + require.True(t, errors.Is(err, errMaxPeers)) + + index = (index + 1) % peerCount + s.unregister <- peerDrop{ps[index], errIdenticalID} + require.Eventually(t, func() bool { + bad := s.BadPeers() + for i := range bad { + if bad[i] == ps[index].PeerAddr().String() { + return true + } + } + return false + }, time.Second, time.Millisecond*50) + +} + func TestGetBlocksByIndex(t *testing.T) { s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"}) ps := make([]*localPeer, 10) @@ -72,8 +241,7 @@ func TestSendVersion(t *testing.T) { p = newLocalPeer(t, s) ) // we need to set listener at least to handle dynamic port correctly - go s.transport.Accept() - require.Eventually(t, func() bool { return s.transport.Address() != "" }, time.Second, 10*time.Millisecond) + s.transport.Accept() p.messageHandler = func(t *testing.T, msg *Message) { // listener is already set, so Address() gives us proper address with port _, p, err := net.SplitHostPort(s.transport.Address()) @@ -136,7 +304,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { // invalid version and disconnects the peer. func TestServerNotSendsVerack(t *testing.T) { var ( - s = newTestServer(t, ServerConfig{Net: 56753}) + s = newTestServer(t, ServerConfig{MaxPeers: 10, Net: 56753}) p = newLocalPeer(t, s) p2 = newLocalPeer(t, s) ) @@ -196,3 +364,402 @@ func TestServerNotSendsVerack(t *testing.T) { assert.NotNil(t, err) require.Equal(t, errAlreadyConnected, err) } + +func (s *Server) testHandleMessage(t *testing.T, p Peer, cmd CommandType, pl payload.Payload) *Server { + if p == nil { + p = newLocalPeer(t, s) + p.(*localPeer).handshaked = true + } + msg := NewMessage(cmd, pl) + msg.Network = netmode.UnitTestNet + require.NoError(t, s.handleMessage(p, msg)) + return s +} + +func startTestServer(t *testing.T) (*Server, func()) { + s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"}) + ch := startWithChannel(s) + return s, func() { + s.Shutdown() + <-ch + } +} + +func TestBlock(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + atomic2.StoreUint32(&s.chain.(*testChain).blockheight, 12344) + require.Equal(t, uint32(12344), s.chain.BlockHeight()) + + b := block.New(netmode.UnitTestNet, false) + b.Index = 12345 + s.testHandleMessage(t, nil, CMDBlock, b) + require.Eventually(t, func() bool { return s.chain.BlockHeight() == 12345 }, time.Second, time.Millisecond*500) +} + +func TestConsensus(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + pl := consensus.NewPayload(netmode.UnitTestNet, false) + s.testHandleMessage(t, nil, CMDConsensus, pl) + require.Contains(t, s.consensus.(*fakeConsensus).payloads, pl) +} + +func TestTransaction(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + t.Run("good", func(t *testing.T) { + tx := newDummyTx() + p := newLocalPeer(t, s) + p.isFullNode = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDInv { + inv := msg.Payload.(*payload.Inventory) + require.Equal(t, payload.TXType, inv.Type) + require.Equal(t, []util.Uint256{tx.Hash()}, inv.Hashes) + } + } + s.register <- p + + s.testHandleMessage(t, nil, CMDTX, tx) + require.Contains(t, s.consensus.(*fakeConsensus).txs, tx) + }) + t.Run("bad", func(t *testing.T) { + tx := newDummyTx() + s.chain.(*testChain).poolTx = func(*transaction.Transaction) error { return core.ErrInsufficientFunds } + s.testHandleMessage(t, nil, CMDTX, tx) + for _, ftx := range s.consensus.(*fakeConsensus).txs { + require.NotEqual(t, ftx, tx) + } + }) +} + +func (s *Server) testHandleGetData(t *testing.T, invType payload.InventoryType, hs, notFound []util.Uint256, found payload.Payload) { + var recvResponse atomic.Bool + var recvNotFound atomic.Bool + + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + switch msg.Command { + case CMDTX, CMDBlock, CMDConsensus: + require.Equal(t, found, msg.Payload) + recvResponse.Store(true) + case CMDNotFound: + require.Equal(t, notFound, msg.Payload.(*payload.Inventory).Hashes) + recvNotFound.Store(true) + } + } + + s.testHandleMessage(t, p, CMDGetData, payload.NewInventory(invType, hs)) + + require.Eventually(t, func() bool { return recvResponse.Load() }, time.Second, time.Millisecond) + require.Eventually(t, func() bool { return recvNotFound.Load() }, time.Second, time.Millisecond) +} + +func TestGetData(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + t.Run("block", func(t *testing.T) { + b := newDummyBlock(2, 0) + hs := []util.Uint256{random.Uint256(), b.Hash(), random.Uint256()} + s.chain.(*testChain).putBlock(b) + notFound := []util.Uint256{hs[0], hs[2]} + s.testHandleGetData(t, payload.BlockType, hs, notFound, b) + }) + t.Run("transaction", func(t *testing.T) { + tx := newDummyTx() + hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()} + s.chain.(*testChain).putTx(tx) + notFound := []util.Uint256{hs[0], hs[2]} + s.testHandleGetData(t, payload.TXType, hs, notFound, tx) + }) +} + +func initGetBlocksTest(t *testing.T) (*Server, func(), []*block.Block) { + s, shutdown := startTestServer(t) + + var blocks []*block.Block + for i := uint32(12); i <= 15; i++ { + b := newDummyBlock(i, 3) + s.chain.(*testChain).putBlock(b) + blocks = append(blocks, b) + } + return s, shutdown, blocks +} + +func TestGetBlocks(t *testing.T) { + s, shutdown, blocks := initGetBlocksTest(t) + defer shutdown() + + expected := make([]util.Uint256, len(blocks)) + for i := range blocks { + expected[i] = blocks[i].Hash() + } + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDInv { + actual = msg.Payload.(*payload.Inventory).Hashes + } + } + + t.Run("2", func(t *testing.T) { + s.testHandleMessage(t, p, CMDGetBlocks, &payload.GetBlocks{HashStart: expected[0], Count: 2}) + require.Equal(t, expected[1:3], actual) + }) + t.Run("-1", func(t *testing.T) { + s.testHandleMessage(t, p, CMDGetBlocks, &payload.GetBlocks{HashStart: expected[0], Count: -1}) + require.Equal(t, expected[1:], actual) + }) + t.Run("invalid start", func(t *testing.T) { + msg := NewMessage(CMDGetBlocks, &payload.GetBlocks{HashStart: util.Uint256{}, Count: -1}) + msg.Network = netmode.UnitTestNet + require.Error(t, s.handleMessage(p, msg)) + }) +} + +func TestGetBlockByIndex(t *testing.T) { + s, shutdown, blocks := initGetBlocksTest(t) + defer shutdown() + + var expected []*block.Block + var actual []*block.Block + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDBlock { + actual = append(actual, msg.Payload.(*block.Block)) + if len(actual) == len(expected) { + require.Equal(t, expected, actual) + } + } + } + + t.Run("2", func(t *testing.T) { + actual = nil + expected = blocks[:2] + s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 2}) + }) + t.Run("-1", func(t *testing.T) { + actual = nil + expected = blocks + s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1}) + }) + t.Run("-1, last header", func(t *testing.T) { + s.chain.(*testChain).putHeader(newDummyBlock(16, 2)) + actual = nil + expected = blocks + s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1}) + }) +} + +func TestGetHeaders(t *testing.T) { + s, shutdown, blocks := initGetBlocksTest(t) + defer shutdown() + + expected := make([]*block.Header, len(blocks)) + for i := range blocks { + expected[i] = blocks[i].Header() + } + + var actual *payload.Headers + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDHeaders { + actual = msg.Payload.(*payload.Headers) + } + } + + t.Run("2", func(t *testing.T) { + actual = nil + s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 2}) + require.Equal(t, expected[:2], actual.Hdrs) + }) + t.Run("more, than we have", func(t *testing.T) { + actual = nil + s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 10}) + require.Equal(t, expected, actual.Hdrs) + }) + t.Run("-1", func(t *testing.T) { + actual = nil + s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1}) + require.Equal(t, expected, actual.Hdrs) + }) + t.Run("no headers", func(t *testing.T) { + actual = nil + s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: 123, Count: -1}) + require.Nil(t, actual) + }) +} + +func TestInv(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDGetData { + actual = msg.Payload.(*payload.Inventory).Hashes + } + } + + t.Run("blocks", func(t *testing.T) { + b := newDummyBlock(10, 3) + s.chain.(*testChain).putBlock(b) + hs := []util.Uint256{random.Uint256(), b.Hash(), random.Uint256()} + s.testHandleMessage(t, p, CMDInv, &payload.Inventory{ + Type: payload.BlockType, + Hashes: hs, + }) + require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual) + }) + t.Run("transaction", func(t *testing.T) { + tx := newDummyTx() + s.chain.(*testChain).putTx(tx) + hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()} + s.testHandleMessage(t, p, CMDInv, &payload.Inventory{ + Type: payload.TXType, + Hashes: hs, + }) + require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual) + }) +} + +func TestRequestTx(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDGetData { + actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...) + } + } + s.register <- p + s.register <- p // ensure previous send was handled + + t.Run("no hashes, no message", func(t *testing.T) { + actual = nil + s.requestTx() + require.Nil(t, actual) + }) + t.Run("good, small", func(t *testing.T) { + actual = nil + expected := []util.Uint256{random.Uint256(), random.Uint256()} + s.requestTx(expected...) + require.Equal(t, expected, actual) + }) + t.Run("good, exactly one chunk", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxHashesCount) + for i := range expected { + expected[i] = random.Uint256() + } + s.requestTx(expected...) + require.Equal(t, expected, actual) + }) + t.Run("good, multiple chunks", func(t *testing.T) { + actual = nil + expected := make([]util.Uint256, payload.MaxHashesCount*2+payload.MaxHashesCount/2) + for i := range expected { + expected[i] = random.Uint256() + } + s.requestTx(expected...) + require.Equal(t, expected, actual) + }) +} + +func TestAddrs(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + ips := make([][16]byte, 4) + copy(ips[0][:], net.IPv4(1, 2, 3, 4)) + copy(ips[1][:], net.IPv4(7, 8, 9, 0)) + for i := range ips[2] { + ips[2][i] = byte(i) + } + + p := newLocalPeer(t, s) + p.handshaked = true + p.getAddrSent = 1 + pl := &payload.AddressList{ + Addrs: []*payload.AddressAndTime{ + { + IP: ips[0], + Capabilities: capability.Capabilities{{ + Type: capability.TCPServer, + Data: &capability.Server{Port: 12}, + }}, + }, + { + IP: ips[1], + Capabilities: capability.Capabilities{}, + }, + { + IP: ips[2], + Capabilities: capability.Capabilities{{ + Type: capability.TCPServer, + Data: &capability.Server{Port: 42}, + }}, + }, + }, + } + s.testHandleMessage(t, p, CMDAddr, pl) + + addrs := s.discovery.(*testDiscovery).backfill + require.Equal(t, 2, len(addrs)) + require.Equal(t, "1.2.3.4:12", addrs[0]) + require.Equal(t, net.JoinHostPort(net.IP(ips[2][:]).String(), "42"), addrs[1]) + + t.Run("CMDAddr not requested", func(t *testing.T) { + msg := NewMessage(CMDAddr, pl) + msg.Network = netmode.UnitTestNet + require.Error(t, s.handleMessage(p, msg)) + }) +} + +type feerStub struct { + blockHeight uint32 +} + +func (f feerStub) FeePerByte() int64 { return 1 } +func (f feerStub) GetUtilityTokenBalance(util.Uint160) *big.Int { return big.NewInt(100000000) } +func (f feerStub) BlockHeight() uint32 { return f.blockHeight } +func (f feerStub) P2PSigExtensionsEnabled() bool { return false } + +func TestMemPool(t *testing.T) { + s, shutdown := startTestServer(t) + defer shutdown() + + var actual []util.Uint256 + p := newLocalPeer(t, s) + p.handshaked = true + p.messageHandler = func(t *testing.T, msg *Message) { + if msg.Command == CMDInv { + actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...) + } + } + + bc := s.chain.(*testChain) + expected := make([]util.Uint256, 4) + for i := range expected { + tx := newDummyTx() + require.NoError(t, bc.Pool.Add(tx, &feerStub{blockHeight: 10})) + expected[i] = tx.Hash() + } + + s.testHandleMessage(t, p, CMDMempool, payload.NullPayload{}) + require.ElementsMatch(t, expected, actual) +}