From 8c5c248e79463698485625dbd91e595367951f2a Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 22 May 2020 12:59:18 +0300 Subject: [PATCH] protocol: add capabilities to address payload Part of #871 --- pkg/network/discovery.go | 31 ++++++++++++++-------- pkg/network/discovery_test.go | 20 +++++++++++++-- pkg/network/helper_test.go | 23 +++++++++-------- pkg/network/payload/address.go | 40 +++++++++++++++++------------ pkg/network/payload/address_test.go | 21 ++++++++++++--- pkg/network/server.go | 9 ++++--- pkg/network/tcp_peer.go | 2 +- 7 files changed, 100 insertions(+), 46 deletions(-) diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index 65b0b8f36..cf21618c1 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -3,6 +3,8 @@ package network import ( "sync" "time" + + "github.com/nspcc-dev/neo-go/pkg/network/capability" ) const ( @@ -18,12 +20,18 @@ type Discoverer interface { PoolCount() int RequestRemote(int) RegisterBadAddr(string) - RegisterGoodAddr(string) + RegisterGoodAddr(string, capability.Capabilities) RegisterConnectedAddr(string) UnregisterConnectedAddr(string) UnconnectedPeers() []string BadPeers() []string - GoodPeers() []string + GoodPeers() []AddressWithCapabilities +} + +// AddressWithCapabilities represents node address with its capabilities +type AddressWithCapabilities struct { + Address string + Capabilities capability.Capabilities } // DefaultDiscovery default implementation of the Discoverer interface. @@ -34,7 +42,7 @@ type DefaultDiscovery struct { dialTimeout time.Duration badAddrs map[string]bool connectedAddrs map[string]bool - goodAddrs map[string]bool + goodAddrs map[string]capability.Capabilities unconnectedAddrs map[string]int isDead bool requestCh chan int @@ -48,7 +56,7 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery { dialTimeout: dt, badAddrs: make(map[string]bool), connectedAddrs: make(map[string]bool), - goodAddrs: make(map[string]bool), + goodAddrs: make(map[string]capability.Capabilities), unconnectedAddrs: make(map[string]int), requestCh: make(chan int), pool: make(chan string, maxPoolSize), @@ -135,11 +143,14 @@ func (d *DefaultDiscovery) BadPeers() []string { // GoodPeers returns all addresses of known good peers (that at least once // succeeded handshaking with us). -func (d *DefaultDiscovery) GoodPeers() []string { +func (d *DefaultDiscovery) GoodPeers() []AddressWithCapabilities { d.lock.RLock() - addrs := make([]string, 0, len(d.goodAddrs)) - for addr := range d.goodAddrs { - addrs = append(addrs, addr) + addrs := make([]AddressWithCapabilities, 0, len(d.goodAddrs)) + for addr, cap := range d.goodAddrs { + addrs = append(addrs, AddressWithCapabilities{ + Address: addr, + Capabilities: cap, + }) } d.lock.RUnlock() return addrs @@ -147,9 +158,9 @@ func (d *DefaultDiscovery) GoodPeers() []string { // RegisterGoodAddr registers good known connected address that passed // handshake successfully. -func (d *DefaultDiscovery) RegisterGoodAddr(s string) { +func (d *DefaultDiscovery) RegisterGoodAddr(s string, c capability.Capabilities) { d.lock.Lock() - d.goodAddrs[s] = true + d.goodAddrs[s] = c d.lock.Unlock() } diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index 81469def2..94af43d83 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -79,9 +80,24 @@ func TestDefaultDiscoverer(t *testing.T) { // Registered good addresses should end up in appropriate set. for _, addr := range set1 { - d.RegisterGoodAddr(addr) + d.RegisterGoodAddr(addr, capability.Capabilities{ + { + Type: capability.FullNode, + Data: &capability.Node{StartHeight: 123}, + }, + }) + } + gAddrWithCap := d.GoodPeers() + gAddrs := make([]string, len(gAddrWithCap)) + for i, addr := range gAddrWithCap { + require.Equal(t, capability.Capabilities{ + { + Type: capability.FullNode, + Data: &capability.Node{StartHeight: 123}, + }, + }, addr.Capabilities) + gAddrs[i] = addr.Address } - gAddrs := d.GoodPeers() sort.Strings(gAddrs) assert.Equal(t, 0, d.PoolCount()) assert.Equal(t, 0, len(d.UnconnectedPeers())) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 9bb54cb6b..1ab6acf6c 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -151,17 +152,17 @@ func (chain testChain) VerifyTx(*transaction.Transaction, *block.Block) error { type testDiscovery struct{} -func (d testDiscovery) BackFill(addrs ...string) {} -func (d testDiscovery) Close() {} -func (d testDiscovery) PoolCount() int { return 0 } -func (d testDiscovery) RegisterBadAddr(string) {} -func (d testDiscovery) RegisterGoodAddr(string) {} -func (d testDiscovery) RegisterConnectedAddr(string) {} -func (d testDiscovery) UnregisterConnectedAddr(string) {} -func (d testDiscovery) UnconnectedPeers() []string { return []string{} } -func (d testDiscovery) RequestRemote(n int) {} -func (d testDiscovery) BadPeers() []string { return []string{} } -func (d testDiscovery) GoodPeers() []string { return []string{} } +func (d testDiscovery) BackFill(addrs ...string) {} +func (d testDiscovery) Close() {} +func (d testDiscovery) PoolCount() int { return 0 } +func (d testDiscovery) RegisterBadAddr(string) {} +func (d testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {} +func (d testDiscovery) RegisterConnectedAddr(string) {} +func (d testDiscovery) UnregisterConnectedAddr(string) {} +func (d testDiscovery) UnconnectedPeers() []string { return []string{} } +func (d testDiscovery) RequestRemote(n int) {} +func (d testDiscovery) BadPeers() []string { return []string{} } +func (d testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} } var defaultMessageHandler = func(t *testing.T, msg *Message) {} diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index e3d34ae92..6608b2cf6 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -1,27 +1,27 @@ package payload import ( + "errors" "net" "strconv" "time" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/network/capability" ) // AddressAndTime payload. type AddressAndTime struct { - Timestamp uint32 - Services uint64 - IP [16]byte - Port uint16 + Timestamp uint32 + IP [16]byte + Capabilities capability.Capabilities } // NewAddressAndTime creates a new AddressAndTime object. -func NewAddressAndTime(e *net.TCPAddr, t time.Time) *AddressAndTime { +func NewAddressAndTime(e *net.TCPAddr, t time.Time, c capability.Capabilities) *AddressAndTime { aat := AddressAndTime{ - Timestamp: uint32(t.UTC().Unix()), - Services: 1, - Port: uint16(e.Port), + Timestamp: uint32(t.UTC().Unix()), + Capabilities: c, } copy(aat.IP[:], e.IP) return &aat @@ -30,26 +30,34 @@ func NewAddressAndTime(e *net.TCPAddr, t time.Time) *AddressAndTime { // DecodeBinary implements Serializable interface. func (p *AddressAndTime) DecodeBinary(br *io.BinReader) { p.Timestamp = br.ReadU32LE() - p.Services = br.ReadU64LE() br.ReadBytes(p.IP[:]) - p.Port = br.ReadU16BE() + p.Capabilities.DecodeBinary(br) } // EncodeBinary implements Serializable interface. func (p *AddressAndTime) EncodeBinary(bw *io.BinWriter) { bw.WriteU32LE(p.Timestamp) - bw.WriteU64LE(p.Services) bw.WriteBytes(p.IP[:]) - bw.WriteU16BE(p.Port) + p.Capabilities.EncodeBinary(bw) } -// IPPortString makes a string from IP and port specified. -func (p *AddressAndTime) IPPortString() string { +// GetTCPAddress makes a string from IP and port specified in TCPCapability. +// It returns an error if there's no such capability. +func (p *AddressAndTime) GetTCPAddress() (string, error) { var netip = make(net.IP, 16) copy(netip, p.IP[:]) - port := strconv.Itoa(int(p.Port)) - return netip.String() + ":" + port + port := -1 + for _, cap := range p.Capabilities { + if cap.Type == capability.TCPServer { + port = int(cap.Data.(*capability.Server).Port) + break + } + } + if port == -1 { + return "", errors.New("no TCP capability found") + } + return net.JoinHostPort(netip.String(), strconv.Itoa(port)), nil } // AddressList is a list with AddrAndTime. diff --git a/pkg/network/payload/address_test.go b/pkg/network/payload/address_test.go index 30a79d02f..6aa29414c 100644 --- a/pkg/network/payload/address_test.go +++ b/pkg/network/payload/address_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/stretchr/testify/assert" ) @@ -14,14 +15,23 @@ func TestEncodeDecodeAddress(t *testing.T) { var ( e, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:2000") ts = time.Now() - addr = NewAddressAndTime(e, ts) + addr = NewAddressAndTime(e, ts, capability.Capabilities{ + { + Type: capability.TCPServer, + Data: &capability.Server{Port: uint16(e.Port)}, + }, + }) ) assert.Equal(t, ts.UTC().Unix(), int64(addr.Timestamp)) aatip := make(net.IP, 16) copy(aatip, addr.IP[:]) assert.Equal(t, e.IP, aatip) - assert.Equal(t, e.Port, int(addr.Port)) + assert.Equal(t, 1, len(addr.Capabilities)) + assert.Equal(t, capability.Capability{ + Type: capability.TCPServer, + Data: &capability.Server{Port: uint16(e.Port)}, + }, addr.Capabilities[0]) testserdes.EncodeDecodeBinary(t, addr, new(AddressAndTime)) } @@ -31,7 +41,12 @@ func TestEncodeDecodeAddressList(t *testing.T) { addrList := NewAddressList(int(lenList)) for i := 0; i < int(lenList); i++ { e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i)) - addrList.Addrs[i] = NewAddressAndTime(e, time.Now()) + addrList.Addrs[i] = NewAddressAndTime(e, time.Now(), capability.Capabilities{ + { + Type: capability.TCPServer, + Data: &capability.Server{Port: 123}, + }, + }) } testserdes.EncodeDecodeBinary(t, addrList, new(AddressList)) diff --git a/pkg/network/server.go b/pkg/network/server.go index 8606a7506..6830c4612 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -642,7 +642,10 @@ func (s *Server) handleTxCmd(tx *transaction.Transaction) error { // handleAddrCmd will process received addresses. func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error { for _, a := range addrs.Addrs { - s.discovery.BackFill(a.IPPortString()) + addr, err := a.GetTCPAddress() + if err != nil { + s.discovery.BackFill(addr) + } } return nil } @@ -657,8 +660,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error { ts := time.Now() for i, addr := range addrs { // we know it's a good address, so it can't fail - netaddr, _ := net.ResolveTCPAddr("tcp", addr) - alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) + netaddr, _ := net.ResolveTCPAddr("tcp", addr.Address) + alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts, addr.Capabilities) } return p.EnqueueP2PMessage(NewMessage(CMDAddr, alist)) } diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 44fa9f13b..a6bd9bb41 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -234,7 +234,7 @@ func (p *TCPPeer) StartProtocol() { zap.Uint32("startHeight", p.lastBlockIndex), zap.Uint32("id", p.Version().Nonce)) - p.server.discovery.RegisterGoodAddr(p.PeerAddr().String()) + p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities) if p.server.chain.HeaderHeight() < p.LastBlockIndex() { err = p.server.requestHeaders(p) if err != nil {