network/test: add tests for server commands

This commit is contained in:
Evgenii Stratonikov 2020-12-07 12:52:19 +03:00
parent bd81b19a7a
commit 27624946d9
7 changed files with 825 additions and 141 deletions

View file

@ -10,7 +10,7 @@ import (
) )
func TestBlockQueue(t *testing.T) { func TestBlockQueue(t *testing.T) {
chain := &testChain{} chain := newTestChain()
// notice, it's not yet running // notice, it's not yet running
bq := newBlockQueue(0, chain, zaptest.NewLogger(t), nil) bq := newBlockQueue(0, chain, zaptest.NewLogger(t), nil)
blocks := make([]*block.Block, 11) blocks := make([]*block.Block, 11)

View file

@ -67,6 +67,10 @@ func NewDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) *Defa
return d return d
} }
func newDefaultDiscovery(addrs []string, dt time.Duration, ts Transporter) Discoverer {
return NewDefaultDiscovery(addrs, dt, ts)
}
// BackFill implements the Discoverer interface and will backfill the // BackFill implements the Discoverer interface and will backfill the
// the pool with the given addresses. // the pool with the given addresses.
func (d *DefaultDiscovery) BackFill(addrs ...string) { func (d *DefaultDiscovery) BackFill(addrs ...string) {

View file

@ -2,6 +2,7 @@ package network
import ( import (
"errors" "errors"
"net"
"sort" "sort"
"sync/atomic" "sync/atomic"
"testing" "testing"
@ -10,11 +11,19 @@ import (
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
atomic2 "go.uber.org/atomic"
) )
type fakeTransp struct { type fakeTransp struct {
retFalse int32 retFalse int32
started atomic2.Bool
closed atomic2.Bool
dialCh chan string dialCh chan string
addr string
}
func newFakeTransp(s *Server) Transporter {
return &fakeTransp{}
} }
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error { func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
@ -26,14 +35,23 @@ func (ft *fakeTransp) Dial(addr string, timeout time.Duration) error {
return nil return nil
} }
func (ft *fakeTransp) Accept() { func (ft *fakeTransp) Accept() {
if ft.started.Load() {
panic("started twice")
}
ft.addr = net.JoinHostPort("0.0.0.0", "42")
ft.started.Store(true)
} }
func (ft *fakeTransp) Proto() string { func (ft *fakeTransp) Proto() string {
return "" return ""
} }
func (ft *fakeTransp) Address() string { func (ft *fakeTransp) Address() string {
return "" return ft.addr
} }
func (ft *fakeTransp) Close() { func (ft *fakeTransp) Close() {
if ft.closed.Load() {
panic("closed twice")
}
ft.closed.Store(true)
} }
func TestDefaultDiscoverer(t *testing.T) { func TestDefaultDiscoverer(t *testing.T) {
ts := &fakeTransp{} ts := &fakeTransp{}

View file

@ -1,14 +1,17 @@
package network package network
import ( import (
"errors"
"fmt"
"math/big" "math/big"
"math/rand"
"net" "net"
"strconv" "sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/mempool" "github.com/nspcc-dev/neo-go/pkg/core/mempool"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
@ -21,45 +24,76 @@ import (
"github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger" "github.com/nspcc-dev/neo-go/pkg/smartcontract/trigger"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
) )
type testChain struct { type testChain struct {
config.ProtocolConfiguration
*mempool.Pool
blocksCh []chan<- *block.Block
blockheight uint32 blockheight uint32
poolTx func(*transaction.Transaction) error
blocks map[util.Uint256]*block.Block
hdrHashes map[uint32]util.Uint256
txs map[util.Uint256]*transaction.Transaction
} }
func (chain testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction { func newTestChain() *testChain {
return &testChain{
Pool: mempool.New(10),
poolTx: func(*transaction.Transaction) error { return nil },
blocks: make(map[util.Uint256]*block.Block),
hdrHashes: make(map[uint32]util.Uint256),
txs: make(map[util.Uint256]*transaction.Transaction),
}
}
func (chain *testChain) putBlock(b *block.Block) {
chain.blocks[b.Hash()] = b
chain.hdrHashes[b.Index] = b.Hash()
atomic.StoreUint32(&chain.blockheight, b.Index)
}
func (chain *testChain) putHeader(b *block.Block) {
chain.hdrHashes[b.Index] = b.Hash()
}
func (chain *testChain) putTx(tx *transaction.Transaction) {
chain.txs[tx.Hash()] = tx
}
func (chain *testChain) ApplyPolicyToTxSet([]*transaction.Transaction) []*transaction.Transaction {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetConfig() config.ProtocolConfiguration { func (chain *testChain) GetConfig() config.ProtocolConfiguration {
panic("TODO") return chain.ProtocolConfiguration
} }
func (chain testChain) CalculateClaimable(util.Uint160, uint32) (*big.Int, error) { func (chain *testChain) CalculateClaimable(util.Uint160, uint32) (*big.Int, error) {
panic("TODO") panic("TODO")
} }
func (chain testChain) FeePerByte() int64 { func (chain *testChain) FeePerByte() int64 {
panic("TODO") panic("TODO")
} }
func (chain testChain) P2PSigExtensionsEnabled() bool { func (chain *testChain) P2PSigExtensionsEnabled() bool {
return false return false
} }
func (chain testChain) GetMaxBlockSystemFee() int64 { func (chain *testChain) GetMaxBlockSystemFee() int64 {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetMaxBlockSize() uint32 { func (chain *testChain) GetMaxBlockSize() uint32 {
panic("TODO") panic("TODO")
} }
func (chain testChain) AddHeaders(...*block.Header) error { func (chain *testChain) AddHeaders(...*block.Header) error {
panic("TODO") panic("TODO")
} }
func (chain *testChain) AddBlock(block *block.Block) error { func (chain *testChain) AddBlock(block *block.Block) error {
if block.Index == chain.blockheight+1 { if block.Index == atomic.LoadUint32(&chain.blockheight)+1 {
atomic.StoreUint32(&chain.blockheight, block.Index) chain.putBlock(block)
} }
return nil return nil
} }
@ -72,148 +106,200 @@ func (chain *testChain) BlockHeight() uint32 {
func (chain *testChain) Close() { func (chain *testChain) Close() {
panic("TODO") panic("TODO")
} }
func (chain testChain) HeaderHeight() uint32 { func (chain *testChain) HeaderHeight() uint32 {
return 0 return atomic.LoadUint32(&chain.blockheight)
} }
func (chain testChain) GetAppExecResults(hash util.Uint256, trig trigger.Type) ([]state.AppExecResult, error) { func (chain *testChain) GetAppExecResults(hash util.Uint256, trig trigger.Type) ([]state.AppExecResult, error) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetBlock(hash util.Uint256) (*block.Block, error) { func (chain *testChain) GetBlock(hash util.Uint256) (*block.Block, error) {
if b, ok := chain.blocks[hash]; ok {
return b, nil
}
return nil, errors.New("not found")
}
func (chain *testChain) GetCommittee() (keys.PublicKeys, error) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetCommittee() (keys.PublicKeys, error) { func (chain *testChain) GetContractState(hash util.Uint160) *state.Contract {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetContractState(hash util.Uint160) *state.Contract { func (chain *testChain) GetContractScriptHash(id int32) (util.Uint160, error) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetContractScriptHash(id int32) (util.Uint160, error) { func (chain *testChain) GetNativeContractScriptHash(name string) (util.Uint160, error) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetNativeContractScriptHash(name string) (util.Uint160, error) { func (chain *testChain) GetHeaderHash(n int) util.Uint256 {
return chain.hdrHashes[uint32(n)]
}
func (chain *testChain) GetHeader(hash util.Uint256) (*block.Header, error) {
b, err := chain.GetBlock(hash)
if err != nil {
return nil, err
}
return b.Header(), nil
}
func (chain *testChain) GetNextBlockValidators() ([]*keys.PublicKey, error) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetHeaderHash(int) util.Uint256 { func (chain *testChain) ForEachNEP17Transfer(util.Uint160, func(*state.NEP17Transfer) (bool, error)) error {
panic("TODO")
}
func (chain *testChain) GetNEP17Balances(util.Uint160) *state.NEP17Balances {
panic("TODO")
}
func (chain *testChain) GetValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain *testChain) GetStandByCommittee() keys.PublicKeys {
panic("TODO")
}
func (chain *testChain) GetStandByValidators() keys.PublicKeys {
panic("TODO")
}
func (chain *testChain) GetEnrollments() ([]state.Validator, error) {
panic("TODO")
}
func (chain *testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) {
panic("TODO")
}
func (chain *testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
panic("TODO")
}
func (chain *testChain) GetStorageItem(id int32, key []byte) *state.StorageItem {
panic("TODO")
}
func (chain *testChain) GetTestVM(tx *transaction.Transaction, b *block.Block) *vm.VM {
panic("TODO")
}
func (chain *testChain) GetStorageItems(id int32) (map[string]*state.StorageItem, error) {
panic("TODO")
}
func (chain *testChain) CurrentHeaderHash() util.Uint256 {
return util.Uint256{} return util.Uint256{}
} }
func (chain testChain) GetHeader(hash util.Uint256) (*block.Header, error) { func (chain *testChain) CurrentBlockHash() util.Uint256 {
panic("TODO")
}
func (chain testChain) GetNextBlockValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain testChain) ForEachNEP17Transfer(util.Uint160, func(*state.NEP17Transfer) (bool, error)) error {
panic("TODO")
}
func (chain testChain) GetNEP17Balances(util.Uint160) *state.NEP17Balances {
panic("TODO")
}
func (chain testChain) GetValidators() ([]*keys.PublicKey, error) {
panic("TODO")
}
func (chain testChain) GetStandByCommittee() keys.PublicKeys {
panic("TODO")
}
func (chain testChain) GetStandByValidators() keys.PublicKeys {
panic("TODO")
}
func (chain testChain) GetEnrollments() ([]state.Validator, error) {
panic("TODO")
}
func (chain testChain) GetStateProof(util.Uint256, []byte) ([][]byte, error) {
panic("TODO")
}
func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
panic("TODO")
}
func (chain testChain) GetStorageItem(id int32, key []byte) *state.StorageItem {
panic("TODO")
}
func (chain testChain) GetTestVM(tx *transaction.Transaction, b *block.Block) *vm.VM {
panic("TODO")
}
func (chain testChain) GetStorageItems(id int32) (map[string]*state.StorageItem, error) {
panic("TODO")
}
func (chain testChain) CurrentHeaderHash() util.Uint256 {
return util.Uint256{} return util.Uint256{}
} }
func (chain testChain) CurrentBlockHash() util.Uint256 { func (chain *testChain) HasBlock(h util.Uint256) bool {
return util.Uint256{} _, ok := chain.blocks[h]
return ok
} }
func (chain testChain) HasBlock(util.Uint256) bool { func (chain *testChain) HasTransaction(h util.Uint256) bool {
return false _, ok := chain.txs[h]
return ok
} }
func (chain testChain) HasTransaction(util.Uint256) bool { func (chain *testChain) GetTransaction(h util.Uint256) (*transaction.Transaction, uint32, error) {
return false if tx, ok := chain.txs[h]; ok {
return tx, 1, nil
}
return nil, 0, errors.New("not found")
} }
func (chain testChain) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) {
func (chain *testChain) GetMemPool() *mempool.Pool {
return chain.Pool
}
func (chain *testChain) GetGoverningTokenBalance(acc util.Uint160) (*big.Int, uint32) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetMemPool() *mempool.Pool { func (chain *testChain) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetGoverningTokenBalance(acc util.Uint160) (*big.Int, uint32) { func (chain *testChain) PoolTx(tx *transaction.Transaction, _ ...*mempool.Pool) error {
return chain.poolTx(tx)
}
func (chain *testChain) SubscribeForBlocks(ch chan<- *block.Block) {
chain.blocksCh = append(chain.blocksCh, ch)
}
func (chain *testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) {
panic("TODO")
}
func (chain *testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) {
panic("TODO")
}
func (chain *testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) {
panic("TODO") panic("TODO")
} }
func (chain testChain) GetUtilityTokenBalance(uint160 util.Uint160) *big.Int { func (chain *testChain) VerifyTx(*transaction.Transaction) error {
panic("TODO")
}
func (*testChain) VerifyWitness(util.Uint160, crypto.Verifiable, *transaction.Witness, int64) error {
panic("TODO") panic("TODO")
} }
func (chain testChain) PoolTx(*transaction.Transaction, ...*mempool.Pool) error { func (chain *testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) {
for i, c := range chain.blocksCh {
if c == ch {
if i < len(chain.blocksCh) {
copy(chain.blocksCh[i:], chain.blocksCh[i+1:])
}
chain.blocksCh = chain.blocksCh[:len(chain.blocksCh)]
}
}
}
func (chain *testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) {
panic("TODO")
}
func (chain *testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) {
panic("TODO")
}
func (chain *testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) {
panic("TODO") panic("TODO")
} }
func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) { type testDiscovery struct {
panic("TODO") sync.Mutex
} bad []string
func (chain testChain) SubscribeForExecutions(ch chan<- *state.AppExecResult) { good []string
panic("TODO") connected []string
} unregistered []string
func (chain testChain) SubscribeForNotifications(ch chan<- *state.NotificationEvent) { backfill []string
panic("TODO")
}
func (chain testChain) SubscribeForTransactions(ch chan<- *transaction.Transaction) {
panic("TODO")
} }
func (chain testChain) VerifyTx(*transaction.Transaction) error { func newTestDiscovery([]string, time.Duration, Transporter) Discoverer { return new(testDiscovery) }
panic("TODO")
}
func (testChain) VerifyWitness(util.Uint160, crypto.Verifiable, *transaction.Witness, int64) error {
panic("TODO")
}
func (chain testChain) UnsubscribeFromBlocks(ch chan<- *block.Block) { func (d *testDiscovery) BackFill(addrs ...string) {
panic("TODO") d.Lock()
defer d.Unlock()
d.backfill = append(d.backfill, addrs...)
} }
func (chain testChain) UnsubscribeFromExecutions(ch chan<- *state.AppExecResult) { func (d *testDiscovery) Close() {}
panic("TODO") func (d *testDiscovery) PoolCount() int { return 0 }
func (d *testDiscovery) RegisterBadAddr(addr string) {
d.Lock()
defer d.Unlock()
d.bad = append(d.bad, addr)
} }
func (chain testChain) UnsubscribeFromNotifications(ch chan<- *state.NotificationEvent) { func (d *testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {}
panic("TODO") func (d *testDiscovery) RegisterConnectedAddr(addr string) {
d.Lock()
defer d.Unlock()
d.connected = append(d.connected, addr)
} }
func (chain testChain) UnsubscribeFromTransactions(ch chan<- *transaction.Transaction) { func (d *testDiscovery) UnregisterConnectedAddr(addr string) {
panic("TODO") d.Lock()
defer d.Unlock()
d.unregistered = append(d.unregistered, addr)
} }
func (d *testDiscovery) UnconnectedPeers() []string {
type testDiscovery struct{} d.Lock()
defer d.Unlock()
func (d testDiscovery) BackFill(addrs ...string) {} return d.unregistered
func (d testDiscovery) Close() {} }
func (d testDiscovery) PoolCount() int { return 0 } func (d *testDiscovery) RequestRemote(n int) {}
func (d testDiscovery) RegisterBadAddr(string) {} func (d *testDiscovery) BadPeers() []string {
func (d testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {} d.Lock()
func (d testDiscovery) RegisterConnectedAddr(string) {} defer d.Unlock()
func (d testDiscovery) UnregisterConnectedAddr(string) {} return d.bad
func (d testDiscovery) UnconnectedPeers() []string { return []string{} } }
func (d testDiscovery) RequestRemote(n int) {} func (d *testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} }
func (d testDiscovery) BadPeers() []string { return []string{} }
func (d testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} }
var defaultMessageHandler = func(t *testing.T, msg *Message) {} var defaultMessageHandler = func(t *testing.T, msg *Message) {}
@ -228,6 +314,7 @@ type localPeer struct {
messageHandler func(t *testing.T, msg *Message) messageHandler func(t *testing.T, msg *Message)
pingSent int pingSent int
getAddrSent int getAddrSent int
droppedWith atomic.Value
} }
func newLocalPeer(t *testing.T, s *Server) *localPeer { func newLocalPeer(t *testing.T, s *Server) *localPeer {
@ -247,7 +334,13 @@ func (p *localPeer) PeerAddr() net.Addr {
return &p.netaddr return &p.netaddr
} }
func (p *localPeer) StartProtocol() {} func (p *localPeer) StartProtocol() {}
func (p *localPeer) Disconnect(err error) {} func (p *localPeer) Disconnect(err error) {
if p.droppedWith.Load() == nil {
p.droppedWith.Store(err)
}
fmt.Println("peer dropped:", err)
p.server.unregister <- peerDrop{p, err}
}
func (p *localPeer) EnqueueMessage(msg *Message) error { func (p *localPeer) EnqueueMessage(msg *Message) error {
b, err := msg.Bytes() b, err := msg.Bytes()
@ -266,7 +359,7 @@ func (p *localPeer) EnqueueP2PPacket(m []byte) error {
return p.EnqueueHPPacket(m) return p.EnqueueHPPacket(m)
} }
func (p *localPeer) EnqueueHPPacket(m []byte) error { func (p *localPeer) EnqueueHPPacket(m []byte) error {
msg := &Message{} msg := &Message{Network: netmode.UnitTestNet}
r := io.NewBinReaderFromBuf(m) r := io.NewBinReaderFromBuf(m)
err := msg.Decode(r) err := msg.Decode(r)
if err == nil { if err == nil {
@ -333,17 +426,8 @@ func (p *localPeer) CanProcessAddr() bool {
} }
func newTestServer(t *testing.T, serverConfig ServerConfig) *Server { func newTestServer(t *testing.T, serverConfig ServerConfig) *Server {
s := &Server{ s, err := newServerFromConstructors(serverConfig, newTestChain(), zaptest.NewLogger(t),
ServerConfig: serverConfig, newFakeTransp, newFakeConsensus, newTestDiscovery)
chain: &testChain{}, require.NoError(t, err)
discovery: testDiscovery{},
id: rand.Uint32(),
quit: make(chan struct{}),
register: make(chan Peer),
unregister: make(chan peerDrop),
peers: make(map[Peer]bool),
log: zaptest.NewLogger(t),
}
s.transport = NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log)
return s return s
} }

View file

@ -128,10 +128,10 @@ func TestEncodeDecodeAddr(t *testing.T) {
func TestEncodeDecodeBlock(t *testing.T) { func TestEncodeDecodeBlock(t *testing.T) {
t.Run("good", func(t *testing.T) { t.Run("good", func(t *testing.T) {
testEncodeDecode(t, CMDBlock, newDummyBlock(1)) testEncodeDecode(t, CMDBlock, newDummyBlock(12, 1))
}) })
t.Run("invalid state root enabled setting", func(t *testing.T) { t.Run("invalid state root enabled setting", func(t *testing.T) {
expected := NewMessage(CMDBlock, newDummyBlock(1)) expected := NewMessage(CMDBlock, newDummyBlock(31, 1))
expected.Network = netmode.UnitTestNet expected.Network = netmode.UnitTestNet
data, err := testserdes.Encode(expected) data, err := testserdes.Encode(expected)
require.NoError(t, err) require.NoError(t, err)
@ -270,7 +270,7 @@ func TestInvalidMessages(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("trimmed payload", func(t *testing.T) { t.Run("trimmed payload", func(t *testing.T) {
m := NewMessage(CMDBlock, newDummyBlock(0)) m := NewMessage(CMDBlock, newDummyBlock(1, 0))
data, err := testserdes.Encode(m) data, err := testserdes.Encode(m)
require.NoError(t, err) require.NoError(t, err)
data = data[:len(data)-1] data = data[:len(data)-1]
@ -288,8 +288,9 @@ func (f failSer) EncodeBinary(r *io.BinWriter) {
func (failSer) DecodeBinary(w *io.BinReader) {} func (failSer) DecodeBinary(w *io.BinReader) {}
func newDummyBlock(txCount int) *block.Block { func newDummyBlock(height uint32, txCount int) *block.Block {
b := block.New(netmode.UnitTestNet, false) b := block.New(netmode.UnitTestNet, false)
b.Index = height
b.PrevHash = random.Uint256() b.PrevHash = random.Uint256()
b.Timestamp = rand.Uint64() b.Timestamp = rand.Uint64()
b.Script.InvocationScript = random.Bytes(2) b.Script.InvocationScript = random.Bytes(2)
@ -303,7 +304,7 @@ func newDummyBlock(txCount int) *block.Block {
} }
func newDummyTx() *transaction.Transaction { func newDummyTx() *transaction.Transaction {
tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), int64(rand.Uint64()>>1)) tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), 123)
tx.Signers = []transaction.Signer{{Account: random.Uint160()}} tx.Signers = []transaction.Signer{{Account: random.Uint160()}}
tx.Size() tx.Size()
tx.Hash() tx.Hash()

View file

@ -96,6 +96,16 @@ func randomID() uint32 {
// NewServer returns a new Server, initialized with the given configuration. // NewServer returns a new Server, initialized with the given configuration.
func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger) (*Server, error) { func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger) (*Server, error) {
return newServerFromConstructors(config, chain, log, func(s *Server) Transporter {
return NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log)
}, consensus.NewService, newDefaultDiscovery)
}
func newServerFromConstructors(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Logger,
newTransport func(*Server) Transporter,
newConsensus func(consensus.Config) (consensus.Service, error),
newDiscovery func([]string, time.Duration, Transporter) Discoverer,
) (*Server, error) {
if log == nil { if log == nil {
return nil, errors.New("logger is a required parameter") return nil, errors.New("logger is a required parameter")
} }
@ -120,7 +130,7 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo
} }
}) })
srv, err := consensus.NewService(consensus.Config{ srv, err := newConsensus(consensus.Config{
Logger: log, Logger: log,
Broadcast: s.handleNewPayload, Broadcast: s.handleNewPayload,
Chain: chain, Chain: chain,
@ -156,8 +166,8 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo
s.AttemptConnPeers = defaultAttemptConnPeers s.AttemptConnPeers = defaultAttemptConnPeers
} }
s.transport = NewTCPTransport(s, net.JoinHostPort(config.Address, strconv.Itoa(int(config.Port))), s.log) s.transport = newTransport(s)
s.discovery = NewDefaultDiscovery( s.discovery = newDiscovery(
s.Seeds, s.Seeds,
s.DialTimeout, s.DialTimeout,
s.transport, s.transport,

View file

@ -1,17 +1,186 @@
package network package network
import ( import (
"errors"
"math/big"
"net" "net"
"strconv" "strconv"
atomic2 "sync/atomic"
"testing" "testing"
"time" "time"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/consensus"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/zap/zaptest"
) )
type fakeConsensus struct {
started atomic.Bool
stopped atomic.Bool
payloads []*consensus.Payload
txs []*transaction.Transaction
}
var _ consensus.Service = (*fakeConsensus)(nil)
func newFakeConsensus(c consensus.Config) (consensus.Service, error) {
return new(fakeConsensus), nil
}
func (f *fakeConsensus) Start() { f.started.Store(true) }
func (f *fakeConsensus) Shutdown() { f.stopped.Store(true) }
func (f *fakeConsensus) OnPayload(p *consensus.Payload) { f.payloads = append(f.payloads, p) }
func (f *fakeConsensus) OnTransaction(tx *transaction.Transaction) { f.txs = append(f.txs, tx) }
func (f *fakeConsensus) GetPayload(h util.Uint256) *consensus.Payload { panic("implement me") }
func TestNewServer(t *testing.T) {
bc := &testChain{}
s, err := newServerFromConstructors(ServerConfig{}, bc, nil, newFakeTransp, newFakeConsensus, newTestDiscovery)
require.Error(t, err)
t.Run("set defaults", func(t *testing.T) {
s = newTestServer(t, ServerConfig{MinPeers: -1})
defer s.discovery.Close()
require.True(t, s.ID() != 0)
require.Equal(t, defaultMinPeers, s.ServerConfig.MinPeers)
require.Equal(t, defaultMaxPeers, s.ServerConfig.MaxPeers)
require.Equal(t, defaultAttemptConnPeers, s.ServerConfig.AttemptConnPeers)
})
t.Run("don't defaults", func(t *testing.T) {
cfg := ServerConfig{
MinPeers: 1,
MaxPeers: 2,
AttemptConnPeers: 3,
}
s = newTestServer(t, cfg)
defer s.discovery.Close()
require.True(t, s.ID() != 0)
require.Equal(t, 1, s.ServerConfig.MinPeers)
require.Equal(t, 2, s.ServerConfig.MaxPeers)
require.Equal(t, 3, s.ServerConfig.AttemptConnPeers)
})
t.Run("consensus error is not dropped", func(t *testing.T) {
errConsensus := errors.New("can't create consensus")
_, err = newServerFromConstructors(ServerConfig{MinPeers: -1}, bc, zaptest.NewLogger(t), newFakeTransp,
func(consensus.Config) (consensus.Service, error) { return nil, errConsensus },
newTestDiscovery)
require.True(t, errors.Is(err, errConsensus), "got: %#v", err)
})
}
func startWithChannel(s *Server) chan error {
ch := make(chan error)
go func() {
s.Start(ch)
close(ch)
}()
return ch
}
func TestServerStartAndShutdown(t *testing.T) {
t.Run("no consensus", func(t *testing.T) {
s := newTestServer(t, ServerConfig{})
ch := startWithChannel(s)
p := newLocalPeer(t, s)
s.register <- p
require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10)
assert.True(t, s.transport.(*fakeTransp).started.Load())
assert.False(t, s.consensus.(*fakeConsensus).started.Load())
s.Shutdown()
<-ch
require.True(t, s.transport.(*fakeTransp).closed.Load())
require.False(t, s.consensus.(*fakeConsensus).stopped.Load())
err, ok := p.droppedWith.Load().(error)
require.True(t, ok)
require.True(t, errors.Is(err, errServerShutdown))
})
t.Run("with consensus", func(t *testing.T) {
s := newTestServer(t, ServerConfig{Wallet: new(config.Wallet)})
ch := startWithChannel(s)
p := newLocalPeer(t, s)
s.register <- p
assert.True(t, s.consensus.(*fakeConsensus).started.Load())
s.Shutdown()
<-ch
require.True(t, s.consensus.(*fakeConsensus).stopped.Load())
})
}
func TestServerRegisterPeer(t *testing.T) {
const peerCount = 3
s := newTestServer(t, ServerConfig{MaxPeers: 2})
ps := make([]*localPeer, peerCount)
for i := range ps {
ps[i] = newLocalPeer(t, s)
ps[i].netaddr.Port = i + 1
}
ch := startWithChannel(s)
defer func() {
s.Shutdown()
<-ch
}()
s.register <- ps[0]
require.Eventually(t, func() bool { return 1 == s.PeerCount() }, time.Second, time.Millisecond*10)
s.register <- ps[1]
require.Eventually(t, func() bool { return 2 == s.PeerCount() }, time.Second, time.Millisecond*10)
require.Equal(t, 0, len(s.discovery.UnconnectedPeers()))
s.register <- ps[2]
require.Eventually(t, func() bool { return len(s.discovery.UnconnectedPeers()) > 0 }, time.Second, time.Millisecond*100)
index := -1
addrs := s.discovery.UnconnectedPeers()
for _, addr := range addrs {
for j := range ps {
if ps[j].PeerAddr().String() == addr {
index = j
break
}
}
}
require.True(t, index >= 0)
err, ok := ps[index].droppedWith.Load().(error)
require.True(t, ok)
require.True(t, errors.Is(err, errMaxPeers))
index = (index + 1) % peerCount
s.unregister <- peerDrop{ps[index], errIdenticalID}
require.Eventually(t, func() bool {
bad := s.BadPeers()
for i := range bad {
if bad[i] == ps[index].PeerAddr().String() {
return true
}
}
return false
}, time.Second, time.Millisecond*50)
}
func TestGetBlocksByIndex(t *testing.T) { func TestGetBlocksByIndex(t *testing.T) {
s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"}) s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"})
ps := make([]*localPeer, 10) ps := make([]*localPeer, 10)
@ -72,8 +241,7 @@ func TestSendVersion(t *testing.T) {
p = newLocalPeer(t, s) p = newLocalPeer(t, s)
) )
// we need to set listener at least to handle dynamic port correctly // we need to set listener at least to handle dynamic port correctly
go s.transport.Accept() s.transport.Accept()
require.Eventually(t, func() bool { return s.transport.Address() != "" }, time.Second, 10*time.Millisecond)
p.messageHandler = func(t *testing.T, msg *Message) { p.messageHandler = func(t *testing.T, msg *Message) {
// listener is already set, so Address() gives us proper address with port // listener is already set, so Address() gives us proper address with port
_, p, err := net.SplitHostPort(s.transport.Address()) _, p, err := net.SplitHostPort(s.transport.Address())
@ -136,7 +304,7 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) {
// invalid version and disconnects the peer. // invalid version and disconnects the peer.
func TestServerNotSendsVerack(t *testing.T) { func TestServerNotSendsVerack(t *testing.T) {
var ( var (
s = newTestServer(t, ServerConfig{Net: 56753}) s = newTestServer(t, ServerConfig{MaxPeers: 10, Net: 56753})
p = newLocalPeer(t, s) p = newLocalPeer(t, s)
p2 = newLocalPeer(t, s) p2 = newLocalPeer(t, s)
) )
@ -196,3 +364,402 @@ func TestServerNotSendsVerack(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
require.Equal(t, errAlreadyConnected, err) require.Equal(t, errAlreadyConnected, err)
} }
func (s *Server) testHandleMessage(t *testing.T, p Peer, cmd CommandType, pl payload.Payload) *Server {
if p == nil {
p = newLocalPeer(t, s)
p.(*localPeer).handshaked = true
}
msg := NewMessage(cmd, pl)
msg.Network = netmode.UnitTestNet
require.NoError(t, s.handleMessage(p, msg))
return s
}
func startTestServer(t *testing.T) (*Server, func()) {
s := newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"})
ch := startWithChannel(s)
return s, func() {
s.Shutdown()
<-ch
}
}
func TestBlock(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
atomic2.StoreUint32(&s.chain.(*testChain).blockheight, 12344)
require.Equal(t, uint32(12344), s.chain.BlockHeight())
b := block.New(netmode.UnitTestNet, false)
b.Index = 12345
s.testHandleMessage(t, nil, CMDBlock, b)
require.Eventually(t, func() bool { return s.chain.BlockHeight() == 12345 }, time.Second, time.Millisecond*500)
}
func TestConsensus(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
pl := consensus.NewPayload(netmode.UnitTestNet, false)
s.testHandleMessage(t, nil, CMDConsensus, pl)
require.Contains(t, s.consensus.(*fakeConsensus).payloads, pl)
}
func TestTransaction(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
t.Run("good", func(t *testing.T) {
tx := newDummyTx()
p := newLocalPeer(t, s)
p.isFullNode = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDInv {
inv := msg.Payload.(*payload.Inventory)
require.Equal(t, payload.TXType, inv.Type)
require.Equal(t, []util.Uint256{tx.Hash()}, inv.Hashes)
}
}
s.register <- p
s.testHandleMessage(t, nil, CMDTX, tx)
require.Contains(t, s.consensus.(*fakeConsensus).txs, tx)
})
t.Run("bad", func(t *testing.T) {
tx := newDummyTx()
s.chain.(*testChain).poolTx = func(*transaction.Transaction) error { return core.ErrInsufficientFunds }
s.testHandleMessage(t, nil, CMDTX, tx)
for _, ftx := range s.consensus.(*fakeConsensus).txs {
require.NotEqual(t, ftx, tx)
}
})
}
func (s *Server) testHandleGetData(t *testing.T, invType payload.InventoryType, hs, notFound []util.Uint256, found payload.Payload) {
var recvResponse atomic.Bool
var recvNotFound atomic.Bool
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
switch msg.Command {
case CMDTX, CMDBlock, CMDConsensus:
require.Equal(t, found, msg.Payload)
recvResponse.Store(true)
case CMDNotFound:
require.Equal(t, notFound, msg.Payload.(*payload.Inventory).Hashes)
recvNotFound.Store(true)
}
}
s.testHandleMessage(t, p, CMDGetData, payload.NewInventory(invType, hs))
require.Eventually(t, func() bool { return recvResponse.Load() }, time.Second, time.Millisecond)
require.Eventually(t, func() bool { return recvNotFound.Load() }, time.Second, time.Millisecond)
}
func TestGetData(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
t.Run("block", func(t *testing.T) {
b := newDummyBlock(2, 0)
hs := []util.Uint256{random.Uint256(), b.Hash(), random.Uint256()}
s.chain.(*testChain).putBlock(b)
notFound := []util.Uint256{hs[0], hs[2]}
s.testHandleGetData(t, payload.BlockType, hs, notFound, b)
})
t.Run("transaction", func(t *testing.T) {
tx := newDummyTx()
hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()}
s.chain.(*testChain).putTx(tx)
notFound := []util.Uint256{hs[0], hs[2]}
s.testHandleGetData(t, payload.TXType, hs, notFound, tx)
})
}
func initGetBlocksTest(t *testing.T) (*Server, func(), []*block.Block) {
s, shutdown := startTestServer(t)
var blocks []*block.Block
for i := uint32(12); i <= 15; i++ {
b := newDummyBlock(i, 3)
s.chain.(*testChain).putBlock(b)
blocks = append(blocks, b)
}
return s, shutdown, blocks
}
func TestGetBlocks(t *testing.T) {
s, shutdown, blocks := initGetBlocksTest(t)
defer shutdown()
expected := make([]util.Uint256, len(blocks))
for i := range blocks {
expected[i] = blocks[i].Hash()
}
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDInv {
actual = msg.Payload.(*payload.Inventory).Hashes
}
}
t.Run("2", func(t *testing.T) {
s.testHandleMessage(t, p, CMDGetBlocks, &payload.GetBlocks{HashStart: expected[0], Count: 2})
require.Equal(t, expected[1:3], actual)
})
t.Run("-1", func(t *testing.T) {
s.testHandleMessage(t, p, CMDGetBlocks, &payload.GetBlocks{HashStart: expected[0], Count: -1})
require.Equal(t, expected[1:], actual)
})
t.Run("invalid start", func(t *testing.T) {
msg := NewMessage(CMDGetBlocks, &payload.GetBlocks{HashStart: util.Uint256{}, Count: -1})
msg.Network = netmode.UnitTestNet
require.Error(t, s.handleMessage(p, msg))
})
}
func TestGetBlockByIndex(t *testing.T) {
s, shutdown, blocks := initGetBlocksTest(t)
defer shutdown()
var expected []*block.Block
var actual []*block.Block
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDBlock {
actual = append(actual, msg.Payload.(*block.Block))
if len(actual) == len(expected) {
require.Equal(t, expected, actual)
}
}
}
t.Run("2", func(t *testing.T) {
actual = nil
expected = blocks[:2]
s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 2})
})
t.Run("-1", func(t *testing.T) {
actual = nil
expected = blocks
s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1})
})
t.Run("-1, last header", func(t *testing.T) {
s.chain.(*testChain).putHeader(newDummyBlock(16, 2))
actual = nil
expected = blocks
s.testHandleMessage(t, p, CMDGetBlockByIndex, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1})
})
}
func TestGetHeaders(t *testing.T) {
s, shutdown, blocks := initGetBlocksTest(t)
defer shutdown()
expected := make([]*block.Header, len(blocks))
for i := range blocks {
expected[i] = blocks[i].Header()
}
var actual *payload.Headers
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDHeaders {
actual = msg.Payload.(*payload.Headers)
}
}
t.Run("2", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 2})
require.Equal(t, expected[:2], actual.Hdrs)
})
t.Run("more, than we have", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: 10})
require.Equal(t, expected, actual.Hdrs)
})
t.Run("-1", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: blocks[0].Index, Count: -1})
require.Equal(t, expected, actual.Hdrs)
})
t.Run("no headers", func(t *testing.T) {
actual = nil
s.testHandleMessage(t, p, CMDGetHeaders, &payload.GetBlockByIndex{IndexStart: 123, Count: -1})
require.Nil(t, actual)
})
}
func TestInv(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDGetData {
actual = msg.Payload.(*payload.Inventory).Hashes
}
}
t.Run("blocks", func(t *testing.T) {
b := newDummyBlock(10, 3)
s.chain.(*testChain).putBlock(b)
hs := []util.Uint256{random.Uint256(), b.Hash(), random.Uint256()}
s.testHandleMessage(t, p, CMDInv, &payload.Inventory{
Type: payload.BlockType,
Hashes: hs,
})
require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual)
})
t.Run("transaction", func(t *testing.T) {
tx := newDummyTx()
s.chain.(*testChain).putTx(tx)
hs := []util.Uint256{random.Uint256(), tx.Hash(), random.Uint256()}
s.testHandleMessage(t, p, CMDInv, &payload.Inventory{
Type: payload.TXType,
Hashes: hs,
})
require.Equal(t, []util.Uint256{hs[0], hs[2]}, actual)
})
}
func TestRequestTx(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDGetData {
actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...)
}
}
s.register <- p
s.register <- p // ensure previous send was handled
t.Run("no hashes, no message", func(t *testing.T) {
actual = nil
s.requestTx()
require.Nil(t, actual)
})
t.Run("good, small", func(t *testing.T) {
actual = nil
expected := []util.Uint256{random.Uint256(), random.Uint256()}
s.requestTx(expected...)
require.Equal(t, expected, actual)
})
t.Run("good, exactly one chunk", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxHashesCount)
for i := range expected {
expected[i] = random.Uint256()
}
s.requestTx(expected...)
require.Equal(t, expected, actual)
})
t.Run("good, multiple chunks", func(t *testing.T) {
actual = nil
expected := make([]util.Uint256, payload.MaxHashesCount*2+payload.MaxHashesCount/2)
for i := range expected {
expected[i] = random.Uint256()
}
s.requestTx(expected...)
require.Equal(t, expected, actual)
})
}
func TestAddrs(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
ips := make([][16]byte, 4)
copy(ips[0][:], net.IPv4(1, 2, 3, 4))
copy(ips[1][:], net.IPv4(7, 8, 9, 0))
for i := range ips[2] {
ips[2][i] = byte(i)
}
p := newLocalPeer(t, s)
p.handshaked = true
p.getAddrSent = 1
pl := &payload.AddressList{
Addrs: []*payload.AddressAndTime{
{
IP: ips[0],
Capabilities: capability.Capabilities{{
Type: capability.TCPServer,
Data: &capability.Server{Port: 12},
}},
},
{
IP: ips[1],
Capabilities: capability.Capabilities{},
},
{
IP: ips[2],
Capabilities: capability.Capabilities{{
Type: capability.TCPServer,
Data: &capability.Server{Port: 42},
}},
},
},
}
s.testHandleMessage(t, p, CMDAddr, pl)
addrs := s.discovery.(*testDiscovery).backfill
require.Equal(t, 2, len(addrs))
require.Equal(t, "1.2.3.4:12", addrs[0])
require.Equal(t, net.JoinHostPort(net.IP(ips[2][:]).String(), "42"), addrs[1])
t.Run("CMDAddr not requested", func(t *testing.T) {
msg := NewMessage(CMDAddr, pl)
msg.Network = netmode.UnitTestNet
require.Error(t, s.handleMessage(p, msg))
})
}
type feerStub struct {
blockHeight uint32
}
func (f feerStub) FeePerByte() int64 { return 1 }
func (f feerStub) GetUtilityTokenBalance(util.Uint160) *big.Int { return big.NewInt(100000000) }
func (f feerStub) BlockHeight() uint32 { return f.blockHeight }
func (f feerStub) P2PSigExtensionsEnabled() bool { return false }
func TestMemPool(t *testing.T) {
s, shutdown := startTestServer(t)
defer shutdown()
var actual []util.Uint256
p := newLocalPeer(t, s)
p.handshaked = true
p.messageHandler = func(t *testing.T, msg *Message) {
if msg.Command == CMDInv {
actual = append(actual, msg.Payload.(*payload.Inventory).Hashes...)
}
}
bc := s.chain.(*testChain)
expected := make([]util.Uint256, 4)
for i := range expected {
tx := newDummyTx()
require.NoError(t, bc.Pool.Add(tx, &feerStub{blockHeight: 10}))
expected[i] = tx.Hash()
}
s.testHandleMessage(t, p, CMDMempool, payload.NullPayload{})
require.ElementsMatch(t, expected, actual)
}