mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2025-01-10 15:54:05 +00:00
1b83dc2476
Mostly it's about Go 1.22+ syntax with ranging over integers, but it also prefers ranging over slices where possible (it makes code a little better to read). Notice that we have a number of dangerous loops where slices are mutated during loop execution; many of these can't be converted since we need proper length evaluation at every iteration. Signed-off-by: Roman Khimov <roman@nspcc.ru>
234 lines
6 KiB
Go
234 lines
6 KiB
Go
package network
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"slices"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/nspcc-dev/neo-go/pkg/network/capability"
|
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// fakeTransp is a test double returned as a Transporter by newFakeTransp.
// It performs no network I/O: it only records what was asked of it.
type fakeTransp struct {
	// retFalse makes Dial return an error whenever it holds a value > 0.
	retFalse atomic.Int32
	// started is set by Accept; a second Accept call panics.
	started atomic.Bool
	// closed is set by Close; a second Close call panics.
	closed atomic.Bool
	// dialCh receives every address passed to Dial.
	dialCh chan string
	// host and port are what HostPort reports. They are filled in by
	// newFakeTransp (from the given address) and overwritten by Accept.
	host, port string
}
|
|
|
|
type fakeAPeer struct {
|
|
addr string
|
|
peer string
|
|
version *payload.Version
|
|
}
|
|
|
|
func (f *fakeAPeer) ConnectionAddr() string {
|
|
return f.addr
|
|
}
|
|
|
|
func (f *fakeAPeer) PeerAddr() net.Addr {
|
|
tcpAddr, err := net.ResolveTCPAddr("tcp", f.peer)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return tcpAddr
|
|
}
|
|
|
|
func (f *fakeAPeer) Version() *payload.Version {
|
|
return f.version
|
|
}
|
|
|
|
func newFakeTransp(s *Server, addr string) Transporter {
|
|
tr := &fakeTransp{}
|
|
h, p, err := net.SplitHostPort(addr)
|
|
if err == nil {
|
|
tr.host = h
|
|
tr.port = p
|
|
}
|
|
return tr
|
|
}
|
|
|
|
func (ft *fakeTransp) Dial(addr string, timeout time.Duration) (AddressablePeer, error) {
|
|
var ret error
|
|
if ft.retFalse.Load() > 0 {
|
|
ret = errors.New("smth bad happened")
|
|
}
|
|
ft.dialCh <- addr
|
|
|
|
return &fakeAPeer{addr: addr, peer: addr}, ret
|
|
}
|
|
func (ft *fakeTransp) Accept() {
|
|
if ft.started.Load() {
|
|
panic("started twice")
|
|
}
|
|
ft.host = "0.0.0.0"
|
|
ft.port = "42"
|
|
ft.started.Store(true)
|
|
}
|
|
func (ft *fakeTransp) Proto() string {
|
|
return ""
|
|
}
|
|
func (ft *fakeTransp) HostPort() (string, string) {
|
|
return ft.host, ft.port
|
|
}
|
|
func (ft *fakeTransp) Close() {
|
|
if ft.closed.Load() {
|
|
panic("closed twice")
|
|
}
|
|
ft.closed.Store(true)
|
|
}
|
|
func TestDefaultDiscoverer(t *testing.T) {
|
|
ts := &fakeTransp{}
|
|
ts.dialCh = make(chan string)
|
|
d := NewDefaultDiscovery(nil, time.Second/16, ts)
|
|
|
|
tryMaxWait = 1 // Don't waste time.
|
|
var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
|
|
slices.Sort(set1)
|
|
|
|
// Added addresses should end up in the pool and in the unconnected set.
|
|
// Done twice to check re-adding unconnected addresses, which should be
|
|
// a no-op.
|
|
for range 2 {
|
|
d.BackFill(set1...)
|
|
assert.Equal(t, len(set1), d.PoolCount())
|
|
set1D := d.UnconnectedPeers()
|
|
slices.Sort(set1D)
|
|
assert.Equal(t, 0, len(d.GoodPeers()))
|
|
assert.Equal(t, 0, len(d.BadPeers()))
|
|
require.Equal(t, set1, set1D)
|
|
}
|
|
require.Equal(t, 2, d.GetFanOut())
|
|
|
|
// Request should make goroutines dial our addresses draining the pool.
|
|
d.RequestRemote(len(set1))
|
|
dialled := make([]string, 0)
|
|
for range set1 {
|
|
select {
|
|
case a := <-ts.dialCh:
|
|
dialled = append(dialled, a)
|
|
d.RegisterConnected(&fakeAPeer{addr: a, peer: a})
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout expecting for transport dial")
|
|
}
|
|
}
|
|
require.Eventually(t, func() bool { return len(d.UnconnectedPeers()) == 0 }, 2*time.Second, 50*time.Millisecond)
|
|
slices.Sort(dialled)
|
|
assert.Equal(t, 0, d.PoolCount())
|
|
assert.Equal(t, 0, len(d.BadPeers()))
|
|
assert.Equal(t, 0, len(d.GoodPeers()))
|
|
require.Equal(t, set1, dialled)
|
|
|
|
// Registered good addresses should end up in appropriate set.
|
|
for _, addr := range set1 {
|
|
d.RegisterGood(&fakeAPeer{
|
|
addr: addr,
|
|
peer: addr,
|
|
version: &payload.Version{
|
|
Capabilities: capability.Capabilities{{
|
|
Type: capability.FullNode,
|
|
Data: &capability.Node{StartHeight: 123},
|
|
}},
|
|
},
|
|
})
|
|
}
|
|
gAddrWithCap := d.GoodPeers()
|
|
gAddrs := make([]string, len(gAddrWithCap))
|
|
for i, addr := range gAddrWithCap {
|
|
require.Equal(t, capability.Capabilities{
|
|
{
|
|
Type: capability.FullNode,
|
|
Data: &capability.Node{StartHeight: 123},
|
|
},
|
|
}, addr.Capabilities)
|
|
gAddrs[i] = addr.Address
|
|
}
|
|
slices.Sort(gAddrs)
|
|
assert.Equal(t, 0, d.PoolCount())
|
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
|
assert.Equal(t, 0, len(d.BadPeers()))
|
|
require.Equal(t, set1, gAddrs)
|
|
|
|
// Re-adding connected addresses should be no-op.
|
|
d.BackFill(set1...)
|
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
|
assert.Equal(t, 0, len(d.BadPeers()))
|
|
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
|
require.Equal(t, 0, d.PoolCount())
|
|
|
|
// Unregistering connected should work.
|
|
for _, addr := range set1 {
|
|
d.UnregisterConnected(&fakeAPeer{addr: addr, peer: addr}, false)
|
|
}
|
|
assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically.
|
|
assert.Equal(t, 0, len(d.BadPeers()))
|
|
assert.Equal(t, len(set1), len(d.GoodPeers()))
|
|
require.Equal(t, 2, d.PoolCount())
|
|
|
|
// Now make Dial() fail and wait to see addresses in the bad list.
|
|
ts.retFalse.Store(1)
|
|
assert.Equal(t, len(set1), d.PoolCount())
|
|
set1D := d.UnconnectedPeers()
|
|
slices.Sort(set1D)
|
|
assert.Equal(t, 0, len(d.BadPeers()))
|
|
require.Equal(t, set1, set1D)
|
|
|
|
dialledBad := make([]string, 0)
|
|
d.RequestRemote(len(set1))
|
|
for i := range connRetries {
|
|
for j := range set1 {
|
|
select {
|
|
case a := <-ts.dialCh:
|
|
dialledBad = append(dialledBad, a)
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout expecting for transport dial; i: %d, j: %d", i, j)
|
|
}
|
|
}
|
|
}
|
|
require.Eventually(t, func() bool { return d.PoolCount() == 0 }, 2*time.Second, 50*time.Millisecond)
|
|
slices.Sort(dialledBad)
|
|
for i := range set1 {
|
|
for j := range connRetries {
|
|
assert.Equal(t, set1[i], dialledBad[i*connRetries+j])
|
|
}
|
|
}
|
|
require.Eventually(t, func() bool { return len(d.BadPeers()) == len(set1) }, 2*time.Second, 50*time.Millisecond)
|
|
assert.Equal(t, 0, len(d.GoodPeers()))
|
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
|
|
|
// Re-adding bad addresses is a no-op.
|
|
d.BackFill(set1...)
|
|
assert.Equal(t, 0, len(d.UnconnectedPeers()))
|
|
assert.Equal(t, len(set1), len(d.BadPeers()))
|
|
assert.Equal(t, 0, len(d.GoodPeers()))
|
|
require.Equal(t, 0, d.PoolCount())
|
|
}
|
|
|
|
func TestSeedDiscovery(t *testing.T) {
|
|
var seeds = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
|
|
ts := &fakeTransp{}
|
|
ts.dialCh = make(chan string)
|
|
ts.retFalse.Store(1) // Fail all dial requests.
|
|
slices.Sort(seeds)
|
|
|
|
d := NewDefaultDiscovery(seeds, time.Second/10, ts)
|
|
tryMaxWait = 1 // Don't waste time.
|
|
|
|
d.RequestRemote(len(seeds))
|
|
for range connRetries * 2 {
|
|
for range seeds {
|
|
select {
|
|
case <-ts.dialCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatalf("timeout expecting for transport dial")
|
|
}
|
|
}
|
|
}
|
|
}
|