Merge pull request #1595 from nspcc-dev/tests/network

network: add tests for most of the commands
This commit is contained in:
Roman Khimov 2020-12-09 15:31:57 +03:00 committed by GitHub
commit 3d0ed6eac3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 831 additions and 144 deletions

View file

@ -10,7 +10,7 @@ import (
)
func TestBlockQueue(t *testing.T) {
chain := &testChain{}
chain := newTestChain()
// notice, it's not yet running
bq := newBlockQueue(0, chain, zaptest.NewLogger(t), nil)
blocks := make([]*block.Block, 11)

View file

@ -67,6 +67,10 @@ func NewDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) *Defa
return d
}
func newDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) Discoverer {
return NewDefaultDiscovery(addrs, dt, ts)
}
// BackFill implements the Discoverer interface and will backfill the
// the pool with the given addresses.
func (d *DefaultDiscovery) BackFill(addrs ...string) {

View file

@ -2,6 +2,7 @@ package network
import (
"errors"
"net"
"sort"
"sync/atomic"
"testing"
@ -10,11 +11,19 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
atomic2 "go.uber.org/atomic"
)
type fakeTransp struct {
retFalse int32
started atomic2.Bool
closed atomic2.Bool
dialCh chan string
addr string
}
func newFakeTransp(s *Server) Transporter {
return &fakeTransp{}
}
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
@ -26,14 +35,23 @@ func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
return nil
}
func (ft *fakeTransp) Accept() {
if ft.started.Load() {
panic("started twice")
}
ft.addr = net.JoinHostPort("0.0.0.0", "42")
ft.started.Store(true)
}
func (ft *fakeTransp) Proto() string {
return ""
}
func (ft *fakeTransp) Address() string {
return ""
return ft.addr
}
func (ft *fakeTransp) Close() {
if ft.closed.Load() {
panic("closed twice")
}
ft.closed.Store(true)
}
func TestDefaultDiscoverer(t *testing.T) {
ts := &fakeTransp{}

View file

@ -1,14 +1,17 @@
package network
import (
"errors"
"fmt"
"math/big"
"math/rand"
"net"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/state"
@ -21,45 +24,76 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
)
type testChain struct {
config.ProtocolConfiguration
*mempool.Pool
blocksCh []chan<- *block.Block
blockheight uint32
poolTx func(*transaction.Transaction) error
blocks map[util.Uint256]*block.Block
hdrHashes map[uint32]util.Uint256
txs map[util.Uint256]*transaction.Transaction
}
func (chain testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction {
func newTestChain() *testChain {
return &testChain{
Pool: mempool.New(10),
poolTx: func(*transaction.Transaction) error { return nil },
blocks: make(map[util.Uint256]*block.Block),
hdrHashes: make(map[uint32]util.Uint256),
txs: make(map[util.Uint256]*transaction.Transaction),
}
}
func (chain *testChain) putBlock(b *block.Block) {
chain.blocks[b.Hash()] = b
chain.hdrHashes[b.Index] = b.Hash()
atomic.StoreUint32(&chain.blockheight, b.Index)
}
func (chain *testChain) putHeader(b *block.Block) {
chain.hdrHashes[b.Index] = b.Hash()
}
func (chain *testChain) putTx(tx *transaction.Transaction) {
chain.txs[tx.Hash()] = tx
}
func (chain *testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction {
panic("TODO")
}
func (chain testChain) GetConfig() config.ProtocolConfiguration {
panic("TODO")
func (chain *testChain) GetConfig() config.ProtocolConfiguration {
return chain.ProtocolConfiguration
}
func (chain testChain) CalculateClaimable(util.Uint160, uint32) (*big.Int, error) {
func (chain *testChain) CalculateClaimable(util.Uint160, uint32) (*big.Int, error) {
panic("TODO")
}
func (chain testChain) FeePerByte() int64 {
func (chain *testChain) FeePerByte() int64 {
panic("TODO")
}
func (chain testChain) P2PSigExtensionsEnabled() bool {
func (chain *testChain) P2PSigExtensionsEnabled() bool {
return false
}
func (chain testChain) GetMaxBlockSystemFee() int64 {
func (chain *testChain) GetMaxBlockSystemFee() int64 {
panic("TODO")
}
func (chain testChain) GetMaxBlockSize() uint32 {
func (chain *testChain) GetMaxBlockSize() uint32 {
panic("TODO")
}
func (chain testChain) AddHeaders(...*block.Header) error {
func (chain *testChain) AddHeaders(...*block.Header) error {
panic("TODO")
}
func (chain *testChain) AddBlock(block *block.Block) error {
if block.Index == chain.blockheight+1 {
atomic.StoreUint32(&chain.blockheight, block.Index)
if block.Index == atomic.LoadUint32(&chain.blockheight)+1 {
chain.putBlock(block)
}
return nil
}
@ -72,148 +106,200 @@ func (chain *testChain) BlockHeight() uint32 {
func (chain *testChain) Close() {
panic("TODO")
}
func (chain testChain) HeaderHeight() uint32 {
return 0
func (chain *testChain) HeaderHeight() uint32 {
return atomic.LoadUint32(&chain.blockheight)
}
func (chain testChain) GetAppExecResults(hash util.Uint256, trig trigger.Type) ([]state.AppExecResult, error) {
func (chain *testChain) GetAppExecResults(hash util.Uint256, trig trigger.Type) ([]state.AppExecResult, error) {
panic("TODO")
}
func (chain testChain) GetBlock(hash util.Uint256) (*block.Block, error) {
func (chain *testChain) GetBlock(hash util.Uint256) (*block.Block, error) {
if b, ok := chain.blocks[hash]; ok {
return b, nil
}
return nil, errors.New("not found")
}
func (chain *testChain) GetCommittee() (keys.PublicKeys, error) {
panic("TODO")
}
func (chain testChain) GetCommittee() (keys.PublicKeys, error) {
func (chain *testChain) GetContractState(hash util.Uint160) *state.Contract {
panic("TODO")
}
func (chain testChain) GetContractState(hash util.Uint160) *state.Contract {
func (chain *testChain) GetContractScriptHash(id int32) (util.Uint160, error) {
panic("TODO")
}
func (chain testChain) GetContractScriptHash(id int32) (util.Uint160, error) {
func (chain *testChain) GetNativeContractScriptHash(name string) (util.Uint160, error) {
panic("TODO")
}
func (chain testChain) GetNativeContractScriptHash(name string) (util.Uint160, error) {
func (chain *testChain) GetHeaderHash(n int) util.Uint256 {
return chain.hdrHashes[uint32(n)]
}
func (chain *testChain) GetHeader(hash util.Uint256) (*block.Header, error) {
b, err := chain.GetBlock(hash)
if err != nil {
return nil, err
}
return b.Header(), nil
}
func (chain *testChain) GetNextBlockValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain testChain) GetHeaderHash(int) util.Uint256 {
func (chain *testChain) ForEachNEP17Transfer(util.Uint160, func(*state.NEP17Transfer) (bool, error)) error {
panic("TODO")
}
func (chain *testChain) GetNEP17Balances(util.Uint160) *state.NEP17Balances {
panic("TODO")
}
func (chain *testChain) GetValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain *testChain) GetStandByCommittee() keys.PublicKeys {
panic("TODO")
}
func (chain *testChain) GetStandByValidators() keys.PublicKeys {
panic("TODO")
}
func (chain *testChain) GetEnrollments() ([]state.Validator, error) {
panic("TODO")
}
func (chain *testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) {
panic("TODO")
}
func (chain *testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
panic("TODO")
}
func (chain *testChain) GetStorageItem(id int32, key []byte) *state.StorageItem {
panic("TODO")
}
func (chain *testChain) GetTestVM(tx *transaction.Transaction, b *block.Block) *vm.VM {
panic("TODO")
}
func (chain *testChain) GetStorageItems(id int32) (map[string]*state.StorageItem, error) {
panic("TODO")
}
func (chain *testChain) CurrentHeaderHash() util.Uint256 {
return util.Uint256{}
}
func (chain testChain) GetHeader(hash util.Uint256) (*block.Header, error) {
panic("TODO")
}
func (chain testChain) GetNextBlockValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain testChain) ForEachNEP17Transfer(util.Uint160, func(*state.NEP17Transfer) (bool, error)) error {
panic("TODO")
}
func (chain testChain) GetNEP17Balances(util.Uint160) *state.NEP17Balances {
panic("TODO")
}
func (chain testChain) GetValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain testChain) GetStandByCommittee() keys.PublicKeys {
panic("TODO")
}
func (chain testChain) GetStandByValidators() keys.PublicKeys {
panic("TODO")
}
func (chain testChain) GetEnrollments() ([]state.Validator, error) {
panic("TODO")
}
func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) {
panic("TODO")
}
func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
panic("TODO")
}
func (chain testChain) GetStorageItem(id int32, key []byte) *state.StorageItem {
panic("TODO")
}
func (chain testChain) GetTestVM(tx *transaction.Transaction, b *block.Block) *vm.VM {
panic("TODO")
}
func (chain testChain) GetStorageItems(id int32) (map[string]*state.StorageItem, error) {
panic("TODO")
}
func (chain testChain) CurrentHeaderHash() util.Uint256 {
func (chain *testChain) CurrentBlockHash() util.Uint256 {
return util.Uint256{}
}
func (chain testChain) CurrentBlockHash() util.Uint256 {
return util.Uint256{}
func (chain *testChain) HasBlock(h util.Uint256) bool {
_, ok := chain.blocks[h]
return ok
}
func (chain testChain) HasBlock(util.Uint256) bool {
return false
func (chain *testChain) HasTransaction(h util.Uint256) bool {
_, ok := chain.txs[h]
return ok
}
func (chain testChain) HasTransaction(util.Uint256) bool {
return false
func (chain *testChain) GetTransaction(h util.Uint256) (*transaction.Transaction, uint32, error) {
if tx, ok := chain.txs[h]; ok {
return tx, 1, nil
}
func (chain testChain) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) {
return nil, 0, errors.New("not found")
}
func (chain *testChain) GetMemPool() *mempool.Pool {
return chain.Pool
}
func (chain *testChain) GetGoverningTokenBalance(acc util.Uint160) (*big.Int, uint32) {
panic("TODO")
}
func (chain testChain) GetMemPool() *mempool.Pool {
func (chain *testChain) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int {
panic("TODO")
}
func (chain testChain) GetGoverningTokenBalance(acc util.Uint160) (*big.Int, uint32) {
func (chain *testChain) PoolTx(tx *transaction.Transaction, _ ...*mempool.Pool) error {
return chain.poolTx(tx)
}
func (chain *testChain) SubscribeForBlocks(ch chan<- *block.Block) {
chain.blocksCh = append(chain.blocksCh, ch)
}
func (chain *testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) {
panic("TODO")
}
func (chain *testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) {
panic("TODO")
}
func (chain *testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) {
panic("TODO")
}
func (chain testChain) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int {
func (chain *testChain) VerifyTx(*transaction.Transaction) error {
panic("TODO")
}
func (*testChain) VerifyWitness(util.Uint160, crypto.Verifiable, *transaction.Witness, int64) error {
panic("TODO")
}
func (chain testChain) PoolTx(*transaction.Transaction, ...*mempool.Pool) error {
func (chain *testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) {
for i, c := range chain.blocksCh {
if c == ch {
if i < len(chain.blocksCh) {
copy(chain.blocksCh[i:], chain.blocksCh[i+1:])
}
chain.blocksCh = chain.blocksCh[:len(chain.blocksCh)]
}
}
}
func (chain *testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) {
panic("TODO")
}
func (chain *testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) {
panic("TODO")
}
func (chain *testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
panic("TODO")
}
func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) {
panic("TODO")
}
func (chain testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) {
panic("TODO")
}
func (chain testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) {
panic("TODO")
}
func (chain testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) {
panic("TODO")
type testDiscovery struct {
sync.Mutex
bad []string
good []string
connected []string
unregistered []string
backfill []string
}
func (chain testChain) VerifyTx(*transaction.Transaction) error {
panic("TODO")
}
func (testChain) VerifyWitness(util.Uint160, crypto.Verifiable, *transaction.Witness, int64) error {
panic("TODO")
}
func newTestDiscovery([]string, time.Duration, Transporter) Discoverer { return new(testDiscovery) }
func (chain testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) {
panic("TODO")
func (d *testDiscovery) BackFill(addrs ...string) {
d.Lock()
defer d.Unlock()
d.backfill = append(d.backfill, addrs...)
}
func (chain testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) {
panic("TODO")
func (d *testDiscovery) Close() {}
func (d *testDiscovery) PoolCount() int { return 0 }
func (d *testDiscovery) RegisterBadAddr(addr string) {
d.Lock()
defer d.Unlock()
d.bad = append(d.bad, addr)
}
func (chain testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) {
panic("TODO")
func (d *testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {}
func (d *testDiscovery) RegisterConnectedAddr(addr string) {
d.Lock()
defer d.Unlock()
d.connected = append(d.connected, addr)
}
func (chain testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
panic("TODO")
func (d *testDiscovery) UnregisterConnectedAddr(addr string) {
d.Lock()
defer d.Unlock()
d.unregistered = append(d.unregistered, addr)
}
type testDiscovery struct{}
func (d testDiscovery) BackFill(addrs ...string) {}
func (d testDiscovery) Close() {}
func (d testDiscovery) PoolCount() int { return 0 }
func (d testDiscovery) RegisterBadAddr(string) {}
func (d testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {}
func (d testDiscovery) RegisterConnectedAddr(string) {}
func (d testDiscovery) UnregisterConnectedAddr(string) {}
func (d testDiscovery) UnconnectedPeers() []string { return []string{} }
func (d testDiscovery) RequestRemote(n int) {}
func (d testDiscovery) BadPeers() []string { return []string{} }
func (d testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} }
func (d *testDiscovery) UnconnectedPeers() []string {
d.Lock()
defer d.Unlock()
return d.unregistered
}
func (d *testDiscovery) RequestRemote(n int) {}
func (d *testDiscovery) BadPeers() []string {
d.Lock()
defer d.Unlock()
return d.bad
}
func (d *testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} }
var defaultMessageHandler = func(t *testing.T, msg *Message) {}
@ -228,6 +314,7 @@ type localPeer struct {
messageHandler func(t *testing.T, msg *Message)
pingSent int
getAddrSent int
droppedWith atomic.Value
}
func newLocalPeer(t *testing.T, s *Server) *localPeer {
@ -247,7 +334,13 @@ func (p *localPeer) PeerAddr() net.Addr {
return &p.netaddr
}
func (p *localPeer) StartProtocol() {}
func (p *localPeer) Disconnect(err error) {}
func (p *localPeer) Disconnect(err error) {
if p.droppedWith.Load() == nil {
p.droppedWith.Store(err)
}
fmt.Println("peer dropped:", err)
p.server.unregister <- peerDrop{p, err}
}
func (p *localPeer) EnqueueMessage(msg *Message) error {
b, err := msg.Bytes()
@ -266,7 +359,7 @@ func (p *localPeer) EnqueueP2PPacket(m []byte) error {
return p.EnqueueHPPacket(m)
}
func (p *localPeer) EnqueueHPPacket(m []byte) error {
msg := &Message{}
msg := &Message{Network: netmode.UnitTestNet}
r := io.NewBinReaderFromBuf(m)
err := msg.Decode(r)
if err == nil {
@ -333,17 +426,8 @@ func (p *localPeer) CanProcessAddr() bool {
}
func newTestServer(t *testing.T, serverConfig ServerConfig) *Server {
s := &Server{
ServerConfig: serverConfig,
chain: &testChain{},
discovery: testDiscovery{},
id: rand.Uint32(),
quit: make(chan struct{}),
register: make(chan Peer),
unregister: make(chan peerDrop),
peers: make(map[Peer]bool),
log: zaptest.NewLogger(t),
}
s.transport = NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log)
s, err := newServerFromConstructors(serverConfig, newTestChain(), zaptest.NewLogger(t),
newFakeTransp, newFakeConsensus, newTestDiscovery)
require.NoError(t, err)
return s
}

View file

@ -128,10 +128,10 @@ func TestEncodeDecodeAddr(t *testing.T) {
func TestEncodeDecodeBlock(t *testing.T) {
t.Run("good", func(t *testing.T) {
testEncodeDecode(t, CMDBlock, newDummyBlock(1))
testEncodeDecode(t, CMDBlock, newDummyBlock(12, 1))
})
t.Run("invalid state root enabled setting", func(t *testing.T) {
expected := NewMessage(CMDBlock, newDummyBlock(1))
expected := NewMessage(CMDBlock, newDummyBlock(31, 1))
expected.Network = netmode.UnitTestNet
data, err := testserdes.Encode(expected)
require.NoError(t, err)
@ -270,7 +270,7 @@ func TestInvalidMessages(t *testing.T) {
require.NoError(t, err)
})
t.Run("trimmed payload", func(t *testing.T) {
m := NewMessage(CMDBlock, newDummyBlock(0))
m := NewMessage(CMDBlock, newDummyBlock(1, 0))
data, err := testserdes.Encode(m)
require.NoError(t, err)
data = data[:len(data)-1]
@ -288,8 +288,9 @@ func (f failSer) EncodeBinary(r *io.BinWriter) {
func (failSer) DecodeBinary(w *io.BinReader) {}
func newDummyBlock(txCount int) *block.Block {
func newDummyBlock(height uint32, txCount int) *block.Block {
b := block.New(netmode.UnitTestNet, false)
b.Index = height
b.PrevHash = random.Uint256()
b.Timestamp = rand.Uint64()
b.Script.InvocationScript = random.Bytes(2)
@ -303,7 +304,7 @@ func newDummyBlock(txCount int) *block.Block {
}
func newDummyTx() *transaction.Transaction {
tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), int64(rand.Uint64()>>1))
tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), 123)
tx.Signers = []transaction.Signer{{Account: random.Uint160()}}
tx.Size()
tx.Hash()

View file

@ -96,6 +96,16 @@ func randomID() uint32 {
// NewServer returns a new Server, initialized with the given configuration.
func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger) (*Server, error) {
return newServerFromConstructors(config, chain, log, func(s *Server) Transporter {
return NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log)
}, consensus.NewService, newDefaultDiscovery)
}
func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger,
newTransport func(*Server) Transporter,
newConsensus func(consensus.Config) (consensus.Service, error),
newDiscovery func([]string, time.Duration, Transporter) Discoverer,
) (*Server, error) {
if log == nil {
return nil, errors.New("logger is a required parameter")
}
@ -120,7 +130,7 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo
}
})
srv, err := consensus.NewService(consensus.Config{
srv, err := newConsensus(consensus.Config{
Logger: log,
Broadcast: s.handleNewPayload,
Chain: chain,
@ -156,8 +166,8 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo
s.AttemptConnPeers = defaultAttemptConnPeers
}
s.transport = NewTCPTransport(s, net.JoinHostPort(config.Address, strconv.Itoa(int(config.Port))), s.log)
s.discovery = NewDefaultDiscovery(
s.transport = newTransport(s)
s.discovery = newDiscovery(
s.Seeds,
s.DialTimeout,
s.transport,
@ -595,7 +605,7 @@ func (s *Server) handleGetBlocksCmd(p Peer, gb *payload.GetBlocks) error {
return err
}
blockHashes := make([]util.Uint256, 0)
for i := start.Index + 1; i < start.Index+uint32(count); i++ {
for i := start.Index + 1; i <= start.Index+uint32(count); i++ {
hash := s.chain.GetHeaderHash(int(i))
if hash.Equals(util.Uint256{}) {
break
@ -839,12 +849,15 @@ func (s *Server) requestTx(hashes ...util.Uint256) {
return
}
for i := 0; i < len(hashes)/payload.MaxHashesCount; i++ {
for i := 0; i <= len(hashes)/payload.MaxHashesCount; i++ {
start := i * payload.MaxHashesCount
stop := (i + 1) * payload.MaxHashesCount
if stop < len(hashes) {
if stop > len(hashes) {
stop = len(hashes)
}
if start == stop {
break
}
msg := NewMessage(CMDGetData, payload.NewInventory(payload.TXType, hashes[start:stop]))
// It's high priority because it directly affects consensus process,
// even though it's getdata.

View file

@ -1,17 +1,186 @@
package network
import (
"errors"
"math/big"
"net"
"strconv"
atomic2 "sync/atomic"
"testing"
"time"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/zap/zaptest"
)
type fakeConsensus struct {
started atomic.Bool
stopped atomic.Bool
payloads []*consensus.Payload
txs []*transaction.Transaction
}
var _ consensus.Service = (*fakeConsensus)(nil)
func newFakeConsensus(c consensus.Config) (consensus.Service, error) {
return new(fakeConsensus), nil
}
func (f *fakeConsensus) Start() { f.started.Store(true) }
func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) }
func (f *fakeConsensus) OnPayload(p *consensus.Payload) { f.payloads = append(f.payloads, p) }
func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) }
func (f *fakeConsensus) GetPayload(h util.Uint256) *consensus.Payload { panic("implement me") }
func TestNewServer(t *testing.T) {
bc := &testChain{}
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
require.Error(t, err)
t.Run("set defaults", func(t *testing.T) {
s = newTestServer(t, ServerConfig{MinPeers: -1})
defer s.discovery.Close()
require.True(t, s.ID() != 0)
require.Equal(t, defaultMinPeers, s.ServerConfig.MinPeers)
require.Equal(t, defaultMaxPeers, s.ServerConfig.MaxPeers)
require.Equal(t, defaultAttemptConnPeers, s.ServerConfig.AttemptConnPeers)
})
t.Run("don't defaults", func(t *testing.T) {
cfg := ServerConfig{
MinPeers: 1,
MaxPeers: 2,
AttemptConnPeers: 3,
}
s = newTestServer(t, cfg)
defer s.discovery.Close()
require.True(t, s.ID() != 0)
require.Equal(t, 1, s.ServerConfig.MinPeers)
require.Equal(t, 2, s.ServerConfig.MaxPeers)
require.Equal(t, 3, s.ServerConfig.AttemptConnPeers)
})
t.Run("consensus error is not dropped", func(t *testing.T) {
errConsensus := errors.New("can't create consensus")
_, err = newServerFromConstructors(ServerConfig{MinPeers: -1}, bc, zaptest.NewLogger(t), newFakeTransp,
func(consensus.Config) (consensus.Service, error) { return nil, errConsensus },
newTestDiscovery)
require.True(t, errors.Is(err, errConsensus), "got: %#v", err)
})
}
func startWithChannel(s *Server) chan error {
ch := make(chan error)
go func() {
s.Start(ch)
close(ch)
}()
return ch
}
func TestServerStartAndShutdown(t *testing.T) {
t.Run("no consensus", func(t *testing.T) {
s := newTestServer(t, ServerConfig{})
ch := startWithChannel(s)
p := newLocalPeer(t, s)
s.register <- p
require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10)
assert.True(t, s.transport.(*fakeTransp).started.Load())
assert.False(t, s.consensus.(*fakeConsensus).started.Load())
s.Shutdown()
<-ch
require.True(t, s.transport.(*fakeTransp).closed.Load())
require.False(t, s.consensus.(*fakeConsensus).stopped.Load())
err, ok := p.droppedWith.Load().(error)
require.True(t, ok)
require.True(t, errors.Is(err, errServerShutdown))
})
t.Run("with consensus", func(t *testing.T) {
s := newTestServer(t, ServerConfig{Wallet: new(config.Wallet)})
ch := startWithChannel(s)
p := newLocalPeer(t, s)
s.register <- p
assert.True(t, s.consensus.(*fakeConsensus).started.Load())
s.Shutdown()
<-ch
require.True(t, s.consensus.(*fakeConsensus).stopped.Load())
})
}
func TestServerRegisterPeer(t *testing.T) {
const peerCount = 3
s := newTestServer(t, ServerConfig{MaxPeers: 2})
ps := make([]*localPeer, peerCount)
for i := range ps {
ps[i] = newLocalPeer(t, s)
ps[i].netaddr.Port = i + 1
}
ch := startWithChannel(s)
defer func() {
s.Shutdown()
<-ch
}()
s.register <- ps[0]
require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10)
s.register <- ps[1]
require.Eventually(t, func() bool { return 2 == s.PeerCount() }, time.Second, time.Millisecond*10)
require.Equal(t, 0, len(s.discovery.UnconnectedPeers()))
s.register <- ps[2]
require.Eventually(t, func() bool { return len(s.discovery.UnconnectedPeers()) > 0 }, time.Second, time.Millisecond*100)
index := -1
addrs := s.discovery.UnconnectedPeers()
for _, addr := range addrs {
for j := range ps {
if ps[j].PeerAddr().String() == addr {
index = j
break
}
}
}
require.True(t, index >= 0)
err, ok := ps[index].droppedWith.Load().(error)
require.True(t, ok)
require.True(t, errors.Is(err, errMaxPeers))
index = (index + 1) % peerCount
s.unregister <- peerDrop{ps[index], errIdenticalID}
require.Eventually(t, func() bool {
bad := s.BadPeers()
for i := range bad {
if bad[i] == ps[index].PeerAddr().String() {
return true
}
}
return false
}, time.Second, time.Millisecond*50)
}
func TestGetBlocksByIndex(t *testing.T) {
s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"})
ps := make([]*localPeer, 10)
@ -72,8 +241,7 @@ func TestSendVersion(t *testing.T) {
p = newLocalPeer(t, s)
)
// we need to set listener at least to handle dynamic port correctly
go s.transport.Accept()
require.Eventually(t, func() bool { return s.transport.Address() != "" }, time.Second, 10*time.Millisecond)
s.transport.Accept()
p.messageHandler = func(t *testing.T, msg *Message) {
// listener is already set, so Address() gives us proper address with port
_, p, err := net.SplitHostPort(s.transport.Address())
@ -136,7 +304,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
// invalid version and disconnects the peer.
func TestServerNotSendsVerack(t *testing.T) {
var (
s = newTestServer(t, ServerConfig{Net: 56753})
s = newTestServer(t, ServerConfig{MaxPeers: 10, Net: 56753})
p = newLocalPeer(t, s)
p2 = newLocalPeer(t, s)
)
@ -196,3 +364,402 @@ func TestServerNotSendsVerack(t *testing.T) {
assert.NotNil(t, err)
require.Equal(t, errAlreadyConnected, err)
}
func (s *Server) testHandleMessage(t *testing.T, p Peer, cmd CommandType, pl payload.Payload) *Server {
if p == nil {
p = newLocalPeer(t, s)
p.(*localPeer).handshaked = true
}
msg := NewMessage(cmd, pl)
msg.Network = netmode.UnitTestNet
require.NoError(t, s.handleMessage(p, msg))
return s
}
func startTestServer(t *testing.T) (*Server, func()) {
s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"})
ch := startWithChannel(s)
return s, func() {
s.Shutdown()
<-ch
}
}
func TestBlock(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
atomic2.StoreUint32(&s.chain.(*testChain).blockheight, 12344)
require.Equal(t, uint32(12344), s.chain.BlockHeight())
b := block.New(netmode.UnitTestNet, false)
b.Index = 12345
s.testHandleMessage(t, nil, CMDBlock, b)
require.Eventually(t, func() bool { return s.chain.BlockHeight() == 12345 }, time.Second, time.Millisecond*500)
}
func TestConsensus(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
pl := consensus.NewPayload(netmode.UnitTestNet, false)
s.testHandleMessage(t, nil, CMDConsensus, pl)
require.Contains(t, s.consensus.(*fakeConsensus).payloads, pl)
}
func TestTransaction(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
t.Run("good", func(t *testing.T) {
tx := newDummyTx()
p := newLocalPeer(t, s)
p.isFullNode = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDInv {
inv := msg.Payload.(*payload.Inventory)
require.Equal(t, payload.TXType, inv.Type)
require.Equal(t, []util.Uint256{tx.Hash()}, inv.Hashes)
}
}
s.register <- p
s.testHandleMessage(t, nil, CMDTX, tx)
require.Contains(t, s.consensus.(*fakeConsensus).txs, tx)
})
t.Run("bad", func(t *testing.T) {
tx := newDummyTx()
s.chain.(*testChain).poolTx = func(*transaction.Transaction) error { return core.ErrInsufficientFunds }
s.testHandleMessage(t, nil, CMDTX, tx)
for _, ftx := range s.consensus.(*fakeConsensus).txs {
require.NotEqual(t, ftx, tx)
}
})
}
func (s *Server) testHandleGetData(t *testing.T, invType payload.InventoryType, hs, notFound []util.Uint256, found payload.Payload) {
var recvResponse atomic.Bool
var recvNotFound atomic.Bool
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
switch msg.Command {
case CMDTX, CMDBlock, CMDConsensus:
require.Equal(t, found, msg.Payload)
recvResponse.Store(true)
case CMDNotFound:
require.Equal(t, notFound, msg.Payload.(*payload.Inventory).Hashes)
recvNotFound.Store(true)
}
}
s.testHandleMessage(t, p, CMDGetData, payload.NewInventory(invType, hs))
require.Eventually(t, func() bool { return recvResponse.Load() }, time.Second, time.Millisecond)
require.Eventually(t, func() bool { return recvNotFound.Load() }, time.Second, time.Millisecond)
}
func TestGetData(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
t.Run("block", func(t *testing.T) {
b := newDummyBlock(2, 0)
hs := []util.Uint256{random.Uint256(), b.Hash(), random.Uint256()}
s.chain.(*testChain).putBlock(b)
notFound := []util.Uint256{hs[0], hs[2]}
s.testHandleGetData(t, payload.BlockType, hs, notFound, b)
})
t.Run("transaction", func(t *testing.T) {
tx := newDummyTx()
hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()}
s.chain.(*testChain).putTx(tx)
notFound := []util.Uint256{hs[0], hs[2]}
s.testHandleGetData(t, payload.TXType, hs, notFound, tx)
})
}
func initGetBlocksTest(t *testing.T) (*Server, func(), []*block.Block) {
s, shutdown := startTestServer(t)
var blocks []*block.Block
for i := uint32(12); i <= 15; i++ {
b := newDummyBlock(i, 3)
s.chain.(*testChain).putBlock(b)
blocks = append(blocks, b)
}
return s, shutdown, blocks
}
func TestGetBlocks(t *testing.T) {
s, shutdown, blocks := initGetBlocksTest(t)
defer shutdown()
expected := make([]util.Uint256, len(blocks))
for i := range blocks {
expected[i] = blocks[i].Hash()
}
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDInv {
actual = msg.Payload.(*payload.Inventory).Hashes
}
}
t.Run("2", func(t *testing.T) {
s.testHandleMessage(t, p, CMDGetBlocks, &payload.GetBlocks{HashStart: expected[0], Count: 2})
require.Equal(t, expected[1:3], actual)
})
t.Run("-1", func(t *testing.T) {
s.testHandleMessage(t, p, CMDGetBlocks, &payload.GetBlocks{HashStart: expected[0], Count: -1})
require.Equal(t, expected[1:], actual)
})
t.Run("invalid start", func(t *testing.T) {
msg := NewMessage(CMDGetBlocks, &payload.GetBlocks{HashStart: util.Uint256{}, Count: -1})
msg.Network = netmode.UnitTestNet
require.Error(t, s.handleMessage(p, msg))
})
}
func TestGetBlockByIndex(t *testing.T) {
s, shutdown, blocks := initGetBlocksTest(t)
defer shutdown()
var expected []*block.Block
var actual []*block.Block
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDBlock {
actual = append(actual, msg.Payload.(*block.Block))
if len(actual) == len(expected) {
require.Equal(t, expected, actual)
}
}
}
t.Run("2", func(t *testing.T) {
actual = nil
expected = blocks[:2]
s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 2})
})
t.Run("-1", func(t *testing.T) {
actual = nil
expected = blocks
s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1})
})
t.Run("-1, last header", func(t *testing.T) {
s.chain.(*testChain).putHeader(newDummyBlock(16, 2))
actual = nil
expected = blocks
s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1})
})
}
func TestGetHeaders(t *testing.T) {
s, shutdown, blocks := initGetBlocksTest(t)
defer shutdown()
expected := make([]*block.Header, len(blocks))
for i := range blocks {
expected[i] = blocks[i].Header()
}
var actual *payload.Headers
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDHeaders {
actual = msg.Payload.(*payload.Headers)
}
}
t.Run("2", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 2})
require.Equal(t, expected[:2], actual.Hdrs)
})
t.Run("more, than we have", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 10})
require.Equal(t, expected, actual.Hdrs)
})
t.Run("-1", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1})
require.Equal(t, expected, actual.Hdrs)
})
t.Run("no headers", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: 123, Count: -1})
require.Nil(t, actual)
})
}
func TestInv(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDGetData {
actual = msg.Payload.(*payload.Inventory).Hashes
}
}
t.Run("blocks", func(t *testing.T) {
b := newDummyBlock(10, 3)
s.chain.(*testChain).putBlock(b)
hs := []util.Uint256{random.Uint256(), b.Hash(), random.Uint256()}
s.testHandleMessage(t, p, CMDInv, &payload.Inventory{
Type: payload.BlockType,
Hashes: hs,
})
require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual)
})
t.Run("transaction", func(t *testing.T) {
tx := newDummyTx()
s.chain.(*testChain).putTx(tx)
hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()}
s.testHandleMessage(t, p, CMDInv, &payload.Inventory{
Type: payload.TXType,
Hashes: hs,
})
require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual)
})
}
func TestRequestTx(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDGetData {
actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...)
}
}
s.register <- p
s.register <- p // ensure previous send was handled
t.Run("no hashes, no message", func(t *testing.T) {
actual = nil
s.requestTx()
require.Nil(t, actual)
})
t.Run("good, small", func(t *testing.T) {
actual = nil
expected := []util.Uint256{random.Uint256(), random.Uint256()}
s.requestTx(expected...)
require.Equal(t, expected, actual)
})
t.Run("good, exactly one chunk", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxHashesCount)
for i := range expected {
expected[i] = random.Uint256()
}
s.requestTx(expected...)
require.Equal(t, expected, actual)
})
t.Run("good, multiple chunks", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxHashesCount*2+payload.MaxHashesCount/2)
for i := range expected {
expected[i] = random.Uint256()
}
s.requestTx(expected...)
require.Equal(t, expected, actual)
})
}
func TestAddrs(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
ips := make([][16]byte, 4)
copy(ips[0][:], net.IPv4(1, 2, 3, 4))
copy(ips[1][:], net.IPv4(7, 8, 9, 0))
for i := range ips[2] {
ips[2][i] = byte(i)
}
p := newLocalPeer(t, s)
p.handshaked = true
p.getAddrSent = 1
pl := &payload.AddressList{
Addrs: []*payload.AddressAndTime{
{
IP: ips[0],
Capabilities: capability.Capabilities{{
Type: capability.TCPServer,
Data: &capability.Server{Port: 12},
}},
},
{
IP: ips[1],
Capabilities: capability.Capabilities{},
},
{
IP: ips[2],
Capabilities: capability.Capabilities{{
Type: capability.TCPServer,
Data: &capability.Server{Port: 42},
}},
},
},
}
s.testHandleMessage(t, p, CMDAddr, pl)
addrs := s.discovery.(*testDiscovery).backfill
require.Equal(t, 2, len(addrs))
require.Equal(t, "1.2.3.4:12", addrs[0])
require.Equal(t, net.JoinHostPort(net.IP(ips[2][:]).String(), "42"), addrs[1])
t.Run("CMDAddr not requested", func(t *testing.T) {
msg := NewMessage(CMDAddr, pl)
msg.Network = netmode.UnitTestNet
require.Error(t, s.handleMessage(p, msg))
})
}
type feerStub struct {
blockHeight uint32
}
func (f feerStub) FeePerByte() int64 { return 1 }
func (f feerStub) GetUtilityTokenBalance(util.Uint160) *big.Int { return big.NewInt(100000000) }
func (f feerStub) BlockHeight() uint32 { return f.blockHeight }
func (f feerStub) P2PSigExtensionsEnabled() bool { return false }
func TestMemPool(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDInv {
actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...)
}
}
bc := s.chain.(*testChain)
expected := make([]util.Uint256, 4)
for i := range expected {
tx := newDummyTx()
require.NoError(t, bc.Pool.Add(tx, &feerStub{blockHeight: 10}))
expected[i] = tx.Hash()
}
s.testHandleMessage(t, p, CMDMempool, payload.NullPayload{})
require.ElementsMatch(t, expected, actual)
}