From c590cc02f45ebaa5fcf7f79982a7e1eaa4154587 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 22 May 2020 12:17:17 +0300 Subject: [PATCH 1/3] protocol: add capabilities to version payload closes #871 --- pkg/network/capability/capability.go | 111 +++++++++++++++++++++++++++ pkg/network/capability/type.go | 13 ++++ pkg/network/discovery_test.go | 3 + pkg/network/helper_test.go | 31 ++++---- pkg/network/payload/version.go | 64 +++++---------- pkg/network/payload/version_test.go | 31 ++++++-- pkg/network/peer.go | 1 + pkg/network/server.go | 41 ++++++++-- pkg/network/server_test.go | 67 ++++++++++++---- pkg/network/tcp_peer.go | 49 ++++++++++-- pkg/network/tcp_peer_test.go | 4 +- pkg/network/tcp_transport.go | 14 ++++ pkg/network/transport.go | 1 + pkg/rpc/client/rpc_test.go | 2 +- pkg/rpc/response/result/version.go | 2 +- 15 files changed, 333 insertions(+), 101 deletions(-) create mode 100644 pkg/network/capability/capability.go create mode 100644 pkg/network/capability/type.go diff --git a/pkg/network/capability/capability.go b/pkg/network/capability/capability.go new file mode 100644 index 000000000..ff5380aa7 --- /dev/null +++ b/pkg/network/capability/capability.go @@ -0,0 +1,111 @@ +package capability + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MaxCapabilities is the maximum number of capabilities per payload +const MaxCapabilities = 32 + +// Capabilities is a list of Capability +type Capabilities []Capability + +// DecodeBinary implements Serializable interface. +func (cs *Capabilities) DecodeBinary(br *io.BinReader) { + br.ReadArray(cs, MaxCapabilities) + br.Err = cs.checkUniqueCapabilities() +} + +// EncodeBinary implements Serializable interface. +func (cs *Capabilities) EncodeBinary(br *io.BinWriter) { + br.WriteArray(*cs) +} + +// checkUniqueCapabilities checks whether payload capabilities have unique type. +func (cs Capabilities) checkUniqueCapabilities() error { + err := errors.New("capabilities with the same type are not allowed") + var isFullNode, isTCP, isWS bool + for _, cap := range cs { + switch cap.Type { + case FullNode: + if isFullNode { + return err + } + isFullNode = true + case TCPServer: + if isTCP { + return err + } + isTCP = true + case WSServer: + if isWS { + return err + } + isWS = true + } + } + return nil +} + +// Capability describes network service available for node +type Capability struct { + Type Type + Data io.Serializable +} + +// DecodeBinary implements Serializable interface. +func (c *Capability) DecodeBinary(br *io.BinReader) { + c.Type = Type(br.ReadB()) + switch c.Type { + case FullNode: + c.Data = &Node{} + case TCPServer, WSServer: + c.Data = &Server{} + default: + br.Err = errors.New("unknown node capability type") + } + c.Data.DecodeBinary(br) +} + +// EncodeBinary implements Serializable interface. +func (c *Capability) EncodeBinary(bw *io.BinWriter) { + if c.Data == nil { + bw.Err = errors.New("capability has no data") + return + } + bw.WriteB(byte(c.Type)) + c.Data.EncodeBinary(bw) +} + +// Node represents full node capability with start height +type Node struct { + StartHeight uint32 +} + +// DecodeBinary implements Serializable interface. +func (n *Node) DecodeBinary(br *io.BinReader) { + n.StartHeight = br.ReadU32LE() +} + +// EncodeBinary implements Serializable interface. +func (n *Node) EncodeBinary(bw *io.BinWriter) { + bw.WriteU32LE(n.StartHeight) +} + +// Server represents TCP or WS server capability with port +type Server struct { + // Port is the port this server is listening on + Port uint16 +} + +// DecodeBinary implements Serializable interface. +func (s *Server) DecodeBinary(br *io.BinReader) { + s.Port = br.ReadU16LE() +} + +// EncodeBinary implements Serializable interface. +func (s *Server) EncodeBinary(bw *io.BinWriter) { + bw.WriteU16LE(s.Port) +} diff --git a/pkg/network/capability/type.go b/pkg/network/capability/type.go new file mode 100644 index 000000000..b25b15397 --- /dev/null +++ b/pkg/network/capability/type.go @@ -0,0 +1,13 @@ +package capability + +// Type represents node capability type +type Type byte + +const ( + // TCPServer represents TCP node capability type + TCPServer Type = 0x01 + // WSServer represents WebSocket node capability type + WSServer Type = 0x02 + // FullNode represents full node capability type + FullNode Type = 0x10 +) diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index 9105cdc0b..81469def2 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -29,6 +29,9 @@ func (ft *fakeTransp) Accept() { func (ft *fakeTransp) Proto() string { return "" } +func (ft *fakeTransp) Address() string { + return "" +} func (ft *fakeTransp) Close() { } func TestDefaultDiscoverer(t *testing.T) { diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 411afc88d..9bb54cb6b 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -3,9 +3,9 @@ package network import ( "math/rand" "net" + "strconv" "sync/atomic" "testing" - "time" "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/core/block" @@ -163,15 +163,6 @@ func (d testDiscovery) RequestRemote(n int) {} func (d testDiscovery) BadPeers() []string { return []string{} } func (d testDiscovery) GoodPeers() []string { return []string{} } -type localTransport struct{} - -func (t localTransport) Dial(addr string, timeout time.Duration) error { - return nil -} -func (t localTransport) Accept() {} -func (t localTransport) Proto() string { return "local" } -func (t localTransport) Close() {} - var defaultMessageHandler = func(t *testing.T, msg *Message) {} type localPeer struct { @@ -180,6 +171,7 @@ type localPeer struct { version *payload.Version lastBlockIndex uint32 handshaked bool + isFullNode bool t *testing.T messageHandler func(t *testing.T, msg *Message) pingSent int @@ -240,7 +232,10 @@ func (p *localPeer) HandleVersion(v *payload.Version) error { return nil } func (p *localPeer) SendVersion() error { - m := p.server.getVersionMsg() + m, err := p.server.getVersionMsg() + if err != nil { + return err + } _ = p.EnqueueMessage(m) return nil } @@ -267,11 +262,14 @@ func (p *localPeer) Handshaked() bool { return p.handshaked } -func newTestServer(t *testing.T) *Server { - return &Server{ - ServerConfig: ServerConfig{}, +func (p *localPeer) IsFullNode() bool { + return p.isFullNode +} + +func newTestServer(t *testing.T, serverConfig ServerConfig) *Server { + s := &Server{ + ServerConfig: serverConfig, chain: &testChain{}, - transport: localTransport{}, discovery: testDiscovery{}, id: rand.Uint32(), quit: make(chan struct{}), @@ -280,5 +278,6 @@ func newTestServer(t *testing.T) *Server { peers: make(map[Peer]bool), log: zaptest.NewLogger(t), } - + s.transport = NewTCPTransport(s, net.JoinHostPort(s.ServerConfig.Address, strconv.Itoa(int(s.ServerConfig.Port))), s.log) + return s } diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 86be2b80c..b0747284c 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -5,19 +5,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/io" -) - -// Size of the payload not counting UserAgent encoding (which is at least 1 byte -// for zero-length string). -const minVersionSize = 27 - -// List of Services offered by the node. -const ( - nodePeerService uint64 = 1 - // BloomFilerService uint64 = 2 // Not implemented - // PrunedNode uint64 = 3 // Not implemented - // LightNode uint64 = 4 // Not implemented - + "github.com/nspcc-dev/neo-go/pkg/network/capability" ) // Version payload. @@ -26,34 +14,25 @@ type Version struct { Magic config.NetMode // currently the version of the protocol is 0 Version uint32 - // currently 1 - Services uint64 // timestamp Timestamp uint32 - // port this server is listening on - Port uint16 // it's used to distinguish the node from public IP Nonce uint32 // client id UserAgent []byte - // Height of the block chain - StartHeight uint32 - // Whether to receive and forward - Relay bool + // List of available network services + Capabilities capability.Capabilities } // NewVersion returns a pointer to a Version payload. -func NewVersion(magic config.NetMode, id uint32, p uint16, ua string, h uint32, r bool) *Version { +func NewVersion(magic config.NetMode, id uint32, ua string, c []capability.Capability) *Version { return &Version{ - Magic: magic, - Version: 0, - Services: nodePeerService, - Timestamp: uint32(time.Now().UTC().Unix()), - Port: p, - Nonce: id, - UserAgent: []byte(ua), - StartHeight: h, - Relay: r, + Magic: magic, + Version: 0, + Timestamp: uint32(time.Now().UTC().Unix()), + Nonce: id, + UserAgent: []byte(ua), + Capabilities: c, } } @@ -61,25 +40,18 @@ func NewVersion(magic config.NetMode, id uint32, p uint16, ua string, h uint32, func (p *Version) DecodeBinary(br *io.BinReader) { p.Magic = config.NetMode(br.ReadU32LE()) p.Version = br.ReadU32LE() - p.Services = br.ReadU64LE() p.Timestamp = br.ReadU32LE() - p.Port = br.ReadU16LE() p.Nonce = br.ReadU32LE() p.UserAgent = br.ReadVarBytes() - p.StartHeight = br.ReadU32LE() - p.Relay = br.ReadBool() + p.Capabilities.DecodeBinary(br) } // EncodeBinary implements Serializable interface. -func (p *Version) EncodeBinary(br *io.BinWriter) { - br.WriteU32LE(uint32(p.Magic)) - br.WriteU32LE(p.Version) - br.WriteU64LE(p.Services) - br.WriteU32LE(p.Timestamp) - br.WriteU16LE(p.Port) - br.WriteU32LE(p.Nonce) - - br.WriteVarBytes(p.UserAgent) - br.WriteU32LE(p.StartHeight) - br.WriteBool(p.Relay) +func (p *Version) EncodeBinary(bw *io.BinWriter) { + bw.WriteU32LE(uint32(p.Magic)) + bw.WriteU32LE(p.Version) + bw.WriteU32LE(p.Timestamp) + bw.WriteU32LE(p.Nonce) + bw.WriteVarBytes(p.UserAgent) + p.Capabilities.EncodeBinary(bw) } diff --git a/pkg/network/payload/version_test.go b/pkg/network/payload/version_test.go index c8c8ccad6..09cd9430b 100644 --- a/pkg/network/payload/version_test.go +++ b/pkg/network/payload/version_test.go @@ -5,25 +5,44 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/stretchr/testify/assert" ) func TestVersionEncodeDecode(t *testing.T) { var magic config.NetMode = 56753 - var port uint16 = 3000 + var tcpPort uint16 = 3000 + var wsPort uint16 = 3001 var id uint32 = 13337 useragent := "/NEO:0.0.1/" var height uint32 = 100500 - var relay = true + var capabilities = []capability.Capability{ + { + Type: capability.TCPServer, + Data: &capability.Server{ + Port: tcpPort, + }, + }, + { + Type: capability.WSServer, + Data: &capability.Server{ + Port: wsPort, + }, + }, + { + Type: capability.FullNode, + Data: &capability.Node{ + StartHeight: height, + }, + }, + } - version := NewVersion(magic, id, port, useragent, height, relay) + version := NewVersion(magic, id, useragent, capabilities) versionDecoded := &Version{} testserdes.EncodeDecodeBinary(t, version, versionDecoded) assert.Equal(t, versionDecoded.Nonce, id) - assert.Equal(t, versionDecoded.Port, port) + assert.ElementsMatch(t, capabilities, versionDecoded.Capabilities) assert.Equal(t, versionDecoded.UserAgent, []byte(useragent)) - assert.Equal(t, versionDecoded.StartHeight, height) - assert.Equal(t, versionDecoded.Relay, relay) assert.Equal(t, version, versionDecoded) } diff --git a/pkg/network/peer.go b/pkg/network/peer.go index b40751251..873347b4a 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -51,6 +51,7 @@ type Peer interface { Version() *payload.Version LastBlockIndex() uint32 Handshaked() bool + IsFullNode() bool // SendPing enqueues a ping message to be sent to the peer and does // appropriate protocol handling like timeouts and outstanding pings diff --git a/pkg/network/server.go b/pkg/network/server.go index ff99d3d9a..8606a7506 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -15,6 +15,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/blockchainer" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" "go.uber.org/atomic" @@ -346,16 +347,42 @@ func (s *Server) HandshakedPeersCount() int { } // getVersionMsg returns current version message. -func (s *Server) getVersionMsg() *Message { +func (s *Server) getVersionMsg() (*Message, error) { + var port uint16 + _, portStr, err := net.SplitHostPort(s.transport.Address()) + if err != nil { + port = s.Port + } else { + p, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + port = uint16(p) + } + + capabilities := []capability.Capability{ + { + Type: capability.TCPServer, + Data: &capability.Server{ + Port: port, + }, + }, + } + if s.Relay { + capabilities = append(capabilities, capability.Capability{ + Type: capability.FullNode, + Data: &capability.Node{ + StartHeight: s.chain.BlockHeight(), + }, + }) + } payload := payload.NewVersion( s.Net, s.id, - s.Port, s.UserAgent, - s.chain.BlockHeight(), - s.Relay, + capabilities, ) - return NewMessage(CMDVersion, payload) + return NewMessage(CMDVersion, payload), nil } // IsInSync answers the question of whether the server is in sync with the @@ -835,9 +862,7 @@ func (s *Server) broadcastTxHashes(hs []util.Uint256) { // We need to filter out non-relaying nodes, so plain broadcast // functions don't fit here. - s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, func(p Peer) bool { - return p.Handshaked() && p.Version().Relay - }) + s.iteratePeersWithSendMsg(msg, Peer.EnqueuePacket, Peer.IsFullNode) } // broadcastTxLoop is a loop for batching and sending diff --git a/pkg/network/server_test.go b/pkg/network/server_test.go index 19ed92b1d..6c783f53f 100644 --- a/pkg/network/server_test.go +++ b/pkg/network/server_test.go @@ -2,8 +2,11 @@ package network import ( "net" + "strconv" "testing" + "time" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -11,22 +14,33 @@ import ( func TestSendVersion(t *testing.T) { var ( - s = newTestServer(t) + s = newTestServer(t, ServerConfig{Port: 0, UserAgent: "/test/"}) p = newLocalPeer(t, s) ) - s.Port = 3000 - s.UserAgent = "/test/" - + // we need to set listener at least to handle dynamic port correctly + go s.transport.Accept() + require.Eventually(t, func() bool { return s.transport.Address() != "" }, time.Second, 10*time.Millisecond) p.messageHandler = func(t *testing.T, msg *Message) { + // listener is already set, so Address() gives us proper address with port + _, p, err := net.SplitHostPort(s.transport.Address()) + assert.NoError(t, err) + port, err := strconv.ParseUint(p, 10, 16) + assert.NoError(t, err) assert.Equal(t, CMDVersion, msg.Command) assert.IsType(t, msg.Payload, &payload.Version{}) version := msg.Payload.(*payload.Version) assert.NotZero(t, version.Nonce) - assert.Equal(t, uint16(3000), version.Port) - assert.Equal(t, uint64(1), version.Services) + assert.Equal(t, 1, len(version.Capabilities)) + assert.ElementsMatch(t, []capability.Capability{ + { + Type: capability.TCPServer, + Data: &capability.Server{ + Port: uint16(port), + }, + }, + }, version.Capabilities) assert.Equal(t, uint32(0), version.Version) assert.Equal(t, []byte("/test/"), version.UserAgent) - assert.Equal(t, uint32(0), version.StartHeight) } require.NoError(t, p.SendVersion()) @@ -35,7 +49,7 @@ func TestSendVersion(t *testing.T) { // Server should reply with a verack after receiving a valid version. func TestVerackAfterHandleVersionCmd(t *testing.T) { var ( - s = newTestServer(t) + s = newTestServer(t, ServerConfig{}) p = newLocalPeer(t, s) ) na, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:3000") @@ -45,7 +59,21 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { p.messageHandler = func(t *testing.T, msg *Message) { assert.Equal(t, CMDVerack, msg.Command) } - version := payload.NewVersion(0, 1337, 3000, "/NEO-GO/", 0, true) + capabilities := []capability.Capability{ + { + Type: capability.FullNode, + Data: &capability.Node{ + StartHeight: 0, + }, + }, + { + Type: capability.TCPServer, + Data: &capability.Server{ + Port: 3000, + }, + }, + } + version := payload.NewVersion(0, 1337, "/NEO-GO/", capabilities) require.NoError(t, s.handleVersionCmd(p, version)) } @@ -54,12 +82,11 @@ func TestVerackAfterHandleVersionCmd(t *testing.T) { // invalid version and disconnects the peer. func TestServerNotSendsVerack(t *testing.T) { var ( - s = newTestServer(t) + s = newTestServer(t, ServerConfig{Net: 56753}) p = newLocalPeer(t, s) p2 = newLocalPeer(t, s) ) s.id = 1 - s.Net = 56753 finished := make(chan struct{}) go func() { s.run() @@ -76,8 +103,22 @@ func TestServerNotSendsVerack(t *testing.T) { p2.netaddr = *na s.register <- p + capabilities := []capability.Capability{ + { + Type: capability.FullNode, + Data: &capability.Node{ + StartHeight: 0, + }, + }, + { + Type: capability.TCPServer, + Data: &capability.Server{ + Port: 3000, + }, + }, + } // identical id's - version := payload.NewVersion(56753, 1, 3000, "/NEO-GO/", 0, true) + version := payload.NewVersion(56753, 1, "/NEO-GO/", capabilities) err := s.handleVersionCmd(p, version) assert.NotNil(t, err) assert.Equal(t, errIdenticalID, err) @@ -104,7 +145,7 @@ func TestServerNotSendsVerack(t *testing.T) { func TestRequestHeaders(t *testing.T) { var ( - s = newTestServer(t) + s = newTestServer(t, ServerConfig{}) p = newLocalPeer(t, s) ) p.messageHandler = func(t *testing.T, msg *Message) { diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index c817eb338..44fa9f13b 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -9,6 +9,7 @@ import ( "time" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" "go.uber.org/zap" ) @@ -45,9 +46,10 @@ type TCPPeer struct { // Index of the last block. lastBlockIndex uint32 - lock sync.RWMutex - finale sync.Once - handShake handShakeStage + lock sync.RWMutex + finale sync.Once + handShake handShakeStage + isFullNode bool done chan struct{} sendQ chan []byte @@ -229,7 +231,7 @@ func (p *TCPPeer) StartProtocol() { p.server.log.Info("started protocol", zap.Stringer("addr", p.RemoteAddr()), zap.ByteString("userAgent", p.Version().UserAgent), - zap.Uint32("startHeight", p.Version().StartHeight), + zap.Uint32("startHeight", p.lastBlockIndex), zap.Uint32("id", p.Version().Nonce)) p.server.discovery.RegisterGoodAddr(p.PeerAddr().String()) @@ -267,18 +269,33 @@ func (p *TCPPeer) StartProtocol() { func (p *TCPPeer) Handshaked() bool { p.lock.RLock() defer p.lock.RUnlock() + return p.handshaked() +} + +// handshaked is internal unlocked version of Handshaked(). +func (p *TCPPeer) handshaked() bool { return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent) } +// IsFullNode returns whether the node has full capability or TCP/WS only. +func (p *TCPPeer) IsFullNode() bool { + p.lock.RLock() + defer p.lock.RUnlock() + return p.handshaked() && p.isFullNode +} + // SendVersion checks for the handshake state and sends a message to the peer. func (p *TCPPeer) SendVersion() error { - msg := p.server.getVersionMsg() + msg, err := p.server.getVersionMsg() + if err != nil { + return err + } p.lock.Lock() defer p.lock.Unlock() if p.handShake&versionSent != 0 { return errors.New("invalid handshake: already sent Version") } - err := p.writeMsg(msg) + err = p.writeMsg(msg) if err == nil { p.handShake |= versionSent } @@ -293,7 +310,14 @@ func (p *TCPPeer) HandleVersion(version *payload.Version) error { return errors.New("invalid handshake: already received Version") } p.version = version - p.lastBlockIndex = version.StartHeight + for _, cap := range version.Capabilities { + if cap.Type == capability.FullNode { + p.isFullNode = true + p.lastBlockIndex = cap.Data.(*capability.Node).StartHeight + break + } + } + p.handShake |= versionReceived return nil } @@ -352,7 +376,16 @@ func (p *TCPPeer) PeerAddr() net.Addr { if err != nil { return p.RemoteAddr() } - addrString := net.JoinHostPort(host, strconv.Itoa(int(p.version.Port))) + var port uint16 + for _, cap := range p.version.Capabilities { + if cap.Type == capability.TCPServer { + port = cap.Data.(*capability.Server).Port + } + } + if port == 0 { + return p.RemoteAddr() + } + addrString := net.JoinHostPort(host, strconv.Itoa(int(port))) tcpAddr, err := net.ResolveTCPAddr("tcp", addrString) if err != nil { return p.RemoteAddr() diff --git a/pkg/network/tcp_peer_test.go b/pkg/network/tcp_peer_test.go index 210a2a331..b1cdfa985 100644 --- a/pkg/network/tcp_peer_test.go +++ b/pkg/network/tcp_peer_test.go @@ -18,8 +18,8 @@ func connReadStub(conn net.Conn) { func TestPeerHandshake(t *testing.T) { server, client := net.Pipe() - tcpS := NewTCPPeer(server, newTestServer(t)) - tcpC := NewTCPPeer(client, newTestServer(t)) + tcpS := NewTCPPeer(server, newTestServer(t, ServerConfig{})) + tcpC := NewTCPPeer(client, newTestServer(t, ServerConfig{})) // Something should read things written into the pipe. go connReadStub(tcpS.conn) diff --git a/pkg/network/tcp_transport.go b/pkg/network/tcp_transport.go index 8195ca039..7bf0e39eb 100644 --- a/pkg/network/tcp_transport.go +++ b/pkg/network/tcp_transport.go @@ -3,6 +3,7 @@ package network import ( "net" "regexp" + "sync" "time" "go.uber.org/zap" @@ -14,6 +15,7 @@ type TCPTransport struct { server *Server listener net.Listener bindAddr string + lock sync.RWMutex } var reClosedNetwork = regexp.MustCompile(".* use of closed network connection") @@ -47,7 +49,9 @@ func (t *TCPTransport) Accept() { return } + t.lock.Lock() t.listener = l + t.lock.Unlock() for { conn, err := l.Accept() @@ -84,3 +88,13 @@ func (t *TCPTransport) Close() { func (t *TCPTransport) Proto() string { return "tcp" } + +// Address implements the Transporter interface. +func (t *TCPTransport) Address() string { + t.lock.RLock() + defer t.lock.RUnlock() + if t.listener != nil { + return t.listener.Addr().String() + } + return "" +} diff --git a/pkg/network/transport.go b/pkg/network/transport.go index 684f86717..0f4d9e821 100644 --- a/pkg/network/transport.go +++ b/pkg/network/transport.go @@ -8,5 +8,6 @@ type Transporter interface { Dial(addr string, timeout time.Duration) error Accept() Proto() string + Address() string Close() } diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 3bcf13f7a..f4f60a5e7 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -894,7 +894,7 @@ var rpcClientTestCases = map[string][]rpcClientTestCase{ invoke: func(c *Client) (interface{}, error) { return c.GetVersion() }, - serverResponse: `{"id":1,"jsonrpc":"2.0","result":{"port":20332,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, + serverResponse: `{"id":1,"jsonrpc":"2.0","result":{"tcp_port":20332,"nonce":2153672787,"useragent":"/NEO-GO:0.73.1-pre-273-ge381358/"}}`, result: func(c *Client) interface{} { return &result.Version{ Port: uint16(20332), diff --git a/pkg/rpc/response/result/version.go b/pkg/rpc/response/result/version.go index 145d197b8..5c80fed5e 100644 --- a/pkg/rpc/response/result/version.go +++ b/pkg/rpc/response/result/version.go @@ -4,7 +4,7 @@ type ( // Version model used for reporting server version // info. Version struct { - Port uint16 `json:"port"` + Port uint16 `json:"tcp_port"` Nonce uint32 `json:"nonce"` UserAgent string `json:"useragent"` } From 8c5c248e79463698485625dbd91e595367951f2a Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Fri, 22 May 2020 12:59:18 +0300 Subject: [PATCH 2/3] protocol: add capabilities to address payload Part of #871 --- pkg/network/discovery.go | 31 ++++++++++++++-------- pkg/network/discovery_test.go | 20 +++++++++++++-- pkg/network/helper_test.go | 23 +++++++++-------- pkg/network/payload/address.go | 40 +++++++++++++++++------------ pkg/network/payload/address_test.go | 21 ++++++++++++--- pkg/network/server.go | 9 ++++--- pkg/network/tcp_peer.go | 2 +- 7 files changed, 100 insertions(+), 46 deletions(-) diff --git a/pkg/network/discovery.go b/pkg/network/discovery.go index 65b0b8f36..cf21618c1 100644 --- a/pkg/network/discovery.go +++ b/pkg/network/discovery.go @@ -3,6 +3,8 @@ package network import ( "sync" "time" + + "github.com/nspcc-dev/neo-go/pkg/network/capability" ) const ( @@ -18,12 +20,18 @@ type Discoverer interface { PoolCount() int RequestRemote(int) RegisterBadAddr(string) - RegisterGoodAddr(string) + RegisterGoodAddr(string, capability.Capabilities) RegisterConnectedAddr(string) UnregisterConnectedAddr(string) UnconnectedPeers() []string BadPeers() []string - GoodPeers() []string + GoodPeers() []AddressWithCapabilities +} + +// AddressWithCapabilities represents node address with its capabilities +type AddressWithCapabilities struct { + Address string + Capabilities capability.Capabilities } // DefaultDiscovery default implementation of the Discoverer interface. @@ -34,7 +42,7 @@ type DefaultDiscovery struct { dialTimeout time.Duration badAddrs map[string]bool connectedAddrs map[string]bool - goodAddrs map[string]bool + goodAddrs map[string]capability.Capabilities unconnectedAddrs map[string]int isDead bool requestCh chan int @@ -48,7 +56,7 @@ func NewDefaultDiscovery(dt time.Duration, ts Transporter) *DefaultDiscovery { dialTimeout: dt, badAddrs: make(map[string]bool), connectedAddrs: make(map[string]bool), - goodAddrs: make(map[string]bool), + goodAddrs: make(map[string]capability.Capabilities), unconnectedAddrs: make(map[string]int), requestCh: make(chan int), pool: make(chan string, maxPoolSize), @@ -135,11 +143,14 @@ func (d *DefaultDiscovery) BadPeers() []string { // GoodPeers returns all addresses of known good peers (that at least once // succeeded handshaking with us). -func (d *DefaultDiscovery) GoodPeers() []string { +func (d *DefaultDiscovery) GoodPeers() []AddressWithCapabilities { d.lock.RLock() - addrs := make([]string, 0, len(d.goodAddrs)) - for addr := range d.goodAddrs { - addrs = append(addrs, addr) + addrs := make([]AddressWithCapabilities, 0, len(d.goodAddrs)) + for addr, cap := range d.goodAddrs { + addrs = append(addrs, AddressWithCapabilities{ + Address: addr, + Capabilities: cap, + }) } d.lock.RUnlock() return addrs @@ -147,9 +158,9 @@ func (d *DefaultDiscovery) GoodPeers() []string { // RegisterGoodAddr registers good known connected address that passed // handshake successfully. -func (d *DefaultDiscovery) RegisterGoodAddr(s string) { +func (d *DefaultDiscovery) RegisterGoodAddr(s string, c capability.Capabilities) { d.lock.Lock() - d.goodAddrs[s] = true + d.goodAddrs[s] = c d.lock.Unlock() } diff --git a/pkg/network/discovery_test.go b/pkg/network/discovery_test.go index 81469def2..94af43d83 100644 --- a/pkg/network/discovery_test.go +++ b/pkg/network/discovery_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -79,9 +80,24 @@ func TestDefaultDiscoverer(t *testing.T) { // Registered good addresses should end up in appropriate set. for _, addr := range set1 { - d.RegisterGoodAddr(addr) + d.RegisterGoodAddr(addr, capability.Capabilities{ + { + Type: capability.FullNode, + Data: &capability.Node{StartHeight: 123}, + }, + }) + } + gAddrWithCap := d.GoodPeers() + gAddrs := make([]string, len(gAddrWithCap)) + for i, addr := range gAddrWithCap { + require.Equal(t, capability.Capabilities{ + { + Type: capability.FullNode, + Data: &capability.Node{StartHeight: 123}, + }, + }, addr.Capabilities) + gAddrs[i] = addr.Address } - gAddrs := d.GoodPeers() sort.Strings(gAddrs) assert.Equal(t, 0, d.PoolCount()) assert.Equal(t, 0, len(d.UnconnectedPeers())) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 9bb54cb6b..1ab6acf6c 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" @@ -151,17 +152,17 @@ func (chain testChain) VerifyTx(*transaction.Transaction, *block.Block) error { type testDiscovery struct{} -func (d testDiscovery) BackFill(addrs ...string) {} -func (d testDiscovery) Close() {} -func (d testDiscovery) PoolCount() int { return 0 } -func (d testDiscovery) RegisterBadAddr(string) {} -func (d testDiscovery) RegisterGoodAddr(string) {} -func (d testDiscovery) RegisterConnectedAddr(string) {} -func (d testDiscovery) UnregisterConnectedAddr(string) {} -func (d testDiscovery) UnconnectedPeers() []string { return []string{} } -func (d testDiscovery) RequestRemote(n int) {} -func (d testDiscovery) BadPeers() []string { return []string{} } -func (d testDiscovery) GoodPeers() []string { return []string{} } +func (d testDiscovery) BackFill(addrs ...string) {} +func (d testDiscovery) Close() {} +func (d testDiscovery) PoolCount() int { return 0 } +func (d testDiscovery) RegisterBadAddr(string) {} +func (d testDiscovery) RegisterGoodAddr(string, capability.Capabilities) {} +func (d testDiscovery) RegisterConnectedAddr(string) {} +func (d testDiscovery) UnregisterConnectedAddr(string) {} +func (d testDiscovery) UnconnectedPeers() []string { return []string{} } +func (d testDiscovery) RequestRemote(n int) {} +func (d testDiscovery) BadPeers() []string { return []string{} } +func (d testDiscovery) GoodPeers() []AddressWithCapabilities { return []AddressWithCapabilities{} } var defaultMessageHandler = func(t *testing.T, msg *Message) {} diff --git a/pkg/network/payload/address.go b/pkg/network/payload/address.go index e3d34ae92..6608b2cf6 100644 --- a/pkg/network/payload/address.go +++ b/pkg/network/payload/address.go @@ -1,27 +1,27 @@ package payload import ( + "errors" "net" "strconv" "time" "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/network/capability" ) // AddressAndTime payload. type AddressAndTime struct { - Timestamp uint32 - Services uint64 - IP [16]byte - Port uint16 + Timestamp uint32 + IP [16]byte + Capabilities capability.Capabilities } // NewAddressAndTime creates a new AddressAndTime object. -func NewAddressAndTime(e *net.TCPAddr, t time.Time) *AddressAndTime { +func NewAddressAndTime(e *net.TCPAddr, t time.Time, c capability.Capabilities) *AddressAndTime { aat := AddressAndTime{ - Timestamp: uint32(t.UTC().Unix()), - Services: 1, - Port: uint16(e.Port), + Timestamp: uint32(t.UTC().Unix()), + Capabilities: c, } copy(aat.IP[:], e.IP) return &aat @@ -30,26 +30,34 @@ func NewAddressAndTime(e *net.TCPAddr, t time.Time) *AddressAndTime { // DecodeBinary implements Serializable interface. func (p *AddressAndTime) DecodeBinary(br *io.BinReader) { p.Timestamp = br.ReadU32LE() - p.Services = br.ReadU64LE() br.ReadBytes(p.IP[:]) - p.Port = br.ReadU16BE() + p.Capabilities.DecodeBinary(br) } // EncodeBinary implements Serializable interface. func (p *AddressAndTime) EncodeBinary(bw *io.BinWriter) { bw.WriteU32LE(p.Timestamp) - bw.WriteU64LE(p.Services) bw.WriteBytes(p.IP[:]) - bw.WriteU16BE(p.Port) + p.Capabilities.EncodeBinary(bw) } -// IPPortString makes a string from IP and port specified. -func (p *AddressAndTime) IPPortString() string { +// GetTCPAddress makes a string from IP and port specified in TCPCapability. +// It returns an error if there's no such capability. +func (p *AddressAndTime) GetTCPAddress() (string, error) { var netip = make(net.IP, 16) copy(netip, p.IP[:]) - port := strconv.Itoa(int(p.Port)) - return netip.String() + ":" + port + port := -1 + for _, cap := range p.Capabilities { + if cap.Type == capability.TCPServer { + port = int(cap.Data.(*capability.Server).Port) + break + } + } + if port == -1 { + return "", errors.New("no TCP capability found") + } + return net.JoinHostPort(netip.String(), strconv.Itoa(port)), nil } // AddressList is a list with AddrAndTime. diff --git a/pkg/network/payload/address_test.go b/pkg/network/payload/address_test.go index 30a79d02f..6aa29414c 100644 --- a/pkg/network/payload/address_test.go +++ b/pkg/network/payload/address_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/stretchr/testify/assert" ) @@ -14,14 +15,23 @@ func TestEncodeDecodeAddress(t *testing.T) { var ( e, _ = net.ResolveTCPAddr("tcp", "127.0.0.1:2000") ts = time.Now() - addr = NewAddressAndTime(e, ts) + addr = NewAddressAndTime(e, ts, capability.Capabilities{ + { + Type: capability.TCPServer, + Data: &capability.Server{Port: uint16(e.Port)}, + }, + }) ) assert.Equal(t, ts.UTC().Unix(), int64(addr.Timestamp)) aatip := make(net.IP, 16) copy(aatip, addr.IP[:]) assert.Equal(t, e.IP, aatip) - assert.Equal(t, e.Port, int(addr.Port)) + assert.Equal(t, 1, len(addr.Capabilities)) + assert.Equal(t, capability.Capability{ + Type: capability.TCPServer, + Data: &capability.Server{Port: uint16(e.Port)}, + }, addr.Capabilities[0]) testserdes.EncodeDecodeBinary(t, addr, new(AddressAndTime)) } @@ -31,7 +41,12 @@ func TestEncodeDecodeAddressList(t *testing.T) { addrList := NewAddressList(int(lenList)) for i := 0; i < int(lenList); i++ { e, _ := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:200%d", i)) - addrList.Addrs[i] = NewAddressAndTime(e, time.Now()) + addrList.Addrs[i] = NewAddressAndTime(e, time.Now(), capability.Capabilities{ + { + Type: capability.TCPServer, + Data: &capability.Server{Port: 123}, + }, + }) } testserdes.EncodeDecodeBinary(t, addrList, new(AddressList)) diff --git a/pkg/network/server.go b/pkg/network/server.go index 8606a7506..6830c4612 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -642,7 +642,10 @@ func (s *Server) handleTxCmd(tx *transaction.Transaction) error { // handleAddrCmd will process received addresses. func (s *Server) handleAddrCmd(p Peer, addrs *payload.AddressList) error { for _, a := range addrs.Addrs { - s.discovery.BackFill(a.IPPortString()) + addr, err := a.GetTCPAddress() + if err != nil { + s.discovery.BackFill(addr) + } } return nil } @@ -657,8 +660,8 @@ func (s *Server) handleGetAddrCmd(p Peer) error { ts := time.Now() for i, addr := range addrs { // we know it's a good address, so it can't fail - netaddr, _ := net.ResolveTCPAddr("tcp", addr) - alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts) + netaddr, _ := net.ResolveTCPAddr("tcp", addr.Address) + alist.Addrs[i] = payload.NewAddressAndTime(netaddr, ts, addr.Capabilities) } return p.EnqueueP2PMessage(NewMessage(CMDAddr, alist)) } diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 44fa9f13b..a6bd9bb41 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -234,7 +234,7 @@ func (p *TCPPeer) StartProtocol() { zap.Uint32("startHeight", p.lastBlockIndex), zap.Uint32("id", p.Version().Nonce)) - p.server.discovery.RegisterGoodAddr(p.PeerAddr().String()) + p.server.discovery.RegisterGoodAddr(p.PeerAddr().String(), p.version.Capabilities) if p.server.chain.HeaderHeight() < p.LastBlockIndex() { err = p.server.requestHeaders(p) if err != nil { From 7547a3efe1b34eb8ac83879140243346e96ba997 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Tue, 26 May 2020 19:49:25 +0300 Subject: [PATCH 3/3] network: minor code refactoring We need to handle IPv6 addresses correctly and net.JoinHostPort takes it into account. --- pkg/network/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/network/server.go b/pkg/network/server.go index 6830c4612..31f591a95 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -149,7 +149,7 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo s.AttemptConnPeers = defaultAttemptConnPeers } - s.transport = NewTCPTransport(s, fmt.Sprintf("%s:%d", config.Address, config.Port), s.log) + s.transport = NewTCPTransport(s, net.JoinHostPort(config.Address, strconv.Itoa(int(config.Port))), s.log) s.discovery = NewDefaultDiscovery( s.DialTimeout, s.transport,