package network import ( "context" "errors" "fmt" "net" "sync" "sync/atomic" "testing" "time" "github.com/nspcc-dev/neo-go/internal/fakechain" "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" ) type testDiscovery struct { sync.Mutex bad []string connected []string unregistered []string backfill []string } func newTestDiscovery([]string, time.Duration, Transporter) Discoverer { return new(testDiscovery) } func (d *testDiscovery) BackFill(addrs ...string) { d.Lock() defer d.Unlock() d.backfill = append(d.backfill, addrs...) } func (d *testDiscovery) PoolCount() int { return 0 } func (d *testDiscovery) RegisterSelf(p AddressablePeer) { d.Lock() defer d.Unlock() d.bad = append(d.bad, p.ConnectionAddr()) } func (d *testDiscovery) GetFanOut() int { d.Lock() defer d.Unlock() return (len(d.connected) + len(d.backfill)) * 2 / 3 } func (d *testDiscovery) NetworkSize() int { d.Lock() defer d.Unlock() return len(d.connected) + len(d.backfill) } func (d *testDiscovery) RegisterGood(AddressablePeer) {} func (d *testDiscovery) RegisterConnected(p AddressablePeer) { d.Lock() defer d.Unlock() d.connected = append(d.connected, p.ConnectionAddr()) } func (d *testDiscovery) UnregisterConnected(p AddressablePeer, force bool) { d.Lock() defer d.Unlock() d.unregistered = append(d.unregistered, p.ConnectionAddr()) } func (d *testDiscovery) UnconnectedPeers() []string { d.Lock() defer d.Unlock() return d.unregistered } func (d *testDiscovery) RequestRemote(n int) {} func (d *testDiscovery) BadPeers() []string { d.Lock() defer d.Unlock() return d.bad } func (d *testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} } var defaultMessageHandler = func(t *testing.T, msg *Message) {} type localPeer struct { netaddr net.TCPAddr server *Server version *payload.Version lastBlockIndex uint32 handshaked int32 // TODO: use atomic.Bool after #2626. isFullNode bool t *testing.T messageHandler func(t *testing.T, msg *Message) pingSent int getAddrSent int droppedWith atomic.Value } func newLocalPeer(t *testing.T, s *Server) *localPeer { naddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") return &localPeer{ t: t, server: s, netaddr: *naddr, messageHandler: defaultMessageHandler, } } func (p *localPeer) ConnectionAddr() string { return p.netaddr.String() } func (p *localPeer) RemoteAddr() net.Addr { return &p.netaddr } func (p *localPeer) PeerAddr() net.Addr { return &p.netaddr } func (p *localPeer) StartProtocol() {} func (p *localPeer) Disconnect(err error) { if p.droppedWith.Load() == nil { p.droppedWith.Store(err) } fmt.Println("peer dropped:", err) p.server.unregister <- peerDrop{p, err} } func (p *localPeer) BroadcastPacket(_ context.Context, m []byte) error { if len(m) == 0 { return errors.New("empty msg") } msg := &Message{} r := io.NewBinReaderFromBuf(m) for r.Len() > 0 { err := msg.Decode(r) if err == nil { p.messageHandler(p.t, msg) } } return nil } func (p *localPeer) EnqueueP2PMessage(msg *Message) error { return p.EnqueueHPMessage(msg) } func (p *localPeer) EnqueueP2PPacket(m []byte) error { return p.BroadcastPacket(context.TODO(), m) } func (p *localPeer) BroadcastHPPacket(ctx context.Context, m []byte) error { return p.BroadcastPacket(ctx, m) } func (p *localPeer) EnqueueHPMessage(msg *Message) error { p.messageHandler(p.t, msg) return nil } func (p *localPeer) EnqueueHPPacket(m []byte) error { return p.BroadcastPacket(context.TODO(), m) } func (p *localPeer) Version() *payload.Version { return p.version } func (p *localPeer) LastBlockIndex() uint32 { return p.lastBlockIndex } func (p *localPeer) HandleVersion(v *payload.Version) error { p.version = v return nil } func (p *localPeer) SendVersion() error { m, err := p.server.getVersionMsg(nil) if err != nil { return err } _ = p.EnqueueHPMessage(m) return nil } func (p *localPeer) SendVersionAck(m *Message) error { _ = p.EnqueueHPMessage(m) return nil } func (p *localPeer) HandleVersionAck() error { atomic.StoreInt32(&p.handshaked, 1) return nil } func (p *localPeer) SetPingTimer() { p.pingSent++ } func (p *localPeer) HandlePing(ping *payload.Ping) error { p.lastBlockIndex = ping.LastBlockIndex return nil } func (p *localPeer) HandlePong(pong *payload.Ping) error { p.lastBlockIndex = pong.LastBlockIndex p.pingSent-- return nil } func (p *localPeer) Handshaked() bool { return atomic.LoadInt32(&p.handshaked) != 0 } func (p *localPeer) IsFullNode() bool { return p.isFullNode } func (p *localPeer) AddGetAddrSent() { p.getAddrSent++ } func (p *localPeer) CanProcessAddr() bool { p.getAddrSent-- return p.getAddrSent >= 0 } func newTestServer(t *testing.T, serverConfig ServerConfig) *Server { return newTestServerWithCustomCfg(t, serverConfig, nil) } func newTestServerWithCustomCfg(t *testing.T, serverConfig ServerConfig, protocolCfg func(*config.ProtocolConfiguration)) *Server { if len(serverConfig.Addresses) == 0 { // Normally it will be done by ApplicationConfiguration.GetAddresses(). serverConfig.Addresses = []config.AnnounceableAddress{{Address: ":0"}} } s, err := newServerFromConstructors(serverConfig, fakechain.NewFakeChainWithCustomCfg(protocolCfg), new(fakechain.FakeStateSync), zaptest.NewLogger(t), newFakeTransp, newTestDiscovery) require.NoError(t, err) return s }