diff --git a/pkg/network/payload/getblocks.go b/pkg/network/payload/getblocks.go new file mode 100644 index 000000000..81371cb25 --- /dev/null +++ b/pkg/network/payload/getblocks.go @@ -0,0 +1,54 @@ +package payload + +import ( + "encoding/binary" + "io" + + . "github.com/anthdm/neo-go/pkg/util" +) + +// HashStartStop contains fields and methods to be shared with the +// "GetBlocks" and "GetHeaders" payload. +type HashStartStop struct { + // hash of latest block that node requests + HashStart []Uint256 + // hash of last block that node requests + HashStop Uint256 +} + +// DecodeBinary implements the payload interface. +func (p *HashStartStop) DecodeBinary(r io.Reader) error { + var lenStart uint8 + + err := binary.Read(r, binary.LittleEndian, &lenStart) + p.HashStart = make([]Uint256, lenStart) + err = binary.Read(r, binary.LittleEndian, &p.HashStart) + err = binary.Read(r, binary.LittleEndian, &p.HashStop) + + return err +} + +// EncodeBinary implements the payload interface. +func (p *HashStartStop) EncodeBinary(w io.Writer) error { + err := binary.Write(w, binary.LittleEndian, uint8(len(p.HashStart))) + err = binary.Write(w, binary.LittleEndian, p.HashStart) + err = binary.Write(w, binary.LittleEndian, p.HashStop) + + return err +} + +// Size implements the payload interface. +func (p *HashStartStop) Size() uint32 { return 0 } + +// GetBlocks payload +type GetBlocks struct { + HashStartStop +} + +// NewGetBlocks return a pointer to a GetBlocks object. +func NewGetBlocks(start []Uint256, stop Uint256) *GetBlocks { + p := &GetBlocks{} + p.HashStart = start + p.HashStop = stop + return p +} diff --git a/pkg/network/payload/getblocks_test.go b/pkg/network/payload/getblocks_test.go new file mode 100644 index 000000000..c73772386 --- /dev/null +++ b/pkg/network/payload/getblocks_test.go @@ -0,0 +1,37 @@ +package payload + +import ( + "bytes" + "crypto/sha256" + "reflect" + "testing" + + . "github.com/anthdm/neo-go/pkg/util" +) + +func TestGetBlocksEncodeDecode(t *testing.T) { + start := []Uint256{ + sha256.Sum256([]byte("a")), + sha256.Sum256([]byte("b")), + } + stop := sha256.Sum256([]byte("c")) + + p := NewGetBlocks(start, stop) + buf := new(bytes.Buffer) + if err := p.EncodeBinary(buf); err != nil { + t.Fatal(err) + } + + if have, want := buf.Len(), 1+64+32; have != want { + t.Fatalf("expecting a length of %d got %d", want, have) + } + + pDecode := &GetBlocks{} + if err := pDecode.DecodeBinary(buf); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(p, pDecode) { + t.Fatalf("expecting both getblocks payloads to be equal %v and %v", p, pDecode) + } +} diff --git a/pkg/network/payload/getheaders.go b/pkg/network/payload/getheaders.go new file mode 100644 index 000000000..9ade274f7 --- /dev/null +++ b/pkg/network/payload/getheaders.go @@ -0,0 +1,17 @@ +package payload + +import "github.com/anthdm/neo-go/pkg/util" + +// GetHeaders payload is the same as the "GetBlocks" payload. +type GetHeaders struct { + HashStartStop +} + +// NewGetHeaders return a pointer to a GetHeaders object. +func NewGetHeaders(start []util.Uint256, stop util.Uint256) *GetHeaders { + p := &GetHeaders{} + p.HashStart = start + p.HashStop = stop + + return p +} diff --git a/pkg/network/payload/getheaders_test.go b/pkg/network/payload/getheaders_test.go new file mode 100644 index 000000000..32cbb1b86 --- /dev/null +++ b/pkg/network/payload/getheaders_test.go @@ -0,0 +1,37 @@ +package payload + +import ( + "bytes" + "crypto/sha256" + "reflect" + "testing" + + "github.com/anthdm/neo-go/pkg/util" +) + +func TestGetHeadersEncodeDecode(t *testing.T) { + start := []util.Uint256{ + sha256.Sum256([]byte("a")), + sha256.Sum256([]byte("b")), + } + stop := sha256.Sum256([]byte("c")) + + p := NewGetHeaders(start, stop) + buf := new(bytes.Buffer) + if err := p.EncodeBinary(buf); err != nil { + t.Fatal(err) + } + + if have, want := buf.Len(), 1+64+32; have != want { + t.Fatalf("expecting a length of %d got %d", want, have) + } + + pDecode := &GetHeaders{} + if err := pDecode.DecodeBinary(buf); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(p, pDecode) { + t.Fatalf("expecting both getheaders payloads to be equal %v and %v", p, pDecode) + } +} diff --git a/pkg/network/peer.go b/pkg/network/peer.go index 025f4b7d8..36a4f87f2 100644 --- a/pkg/network/peer.go +++ b/pkg/network/peer.go @@ -9,7 +9,6 @@ import ( type Peer interface { id() uint32 addr() util.Endpoint - verack() bool disconnect() callVersion(*Message) callGetaddr(*Message) @@ -20,7 +19,6 @@ type Peer interface { type LocalPeer struct { s *Server nonce uint32 - isVerack bool endpoint util.Endpoint } @@ -39,6 +37,5 @@ func (p *LocalPeer) callGetaddr(msg *Message) { } func (p *LocalPeer) id() uint32 { return p.nonce } -func (p *LocalPeer) verack() bool { return p.isVerack } func (p *LocalPeer) addr() util.Endpoint { return p.endpoint } func (p *LocalPeer) disconnect() {} diff --git a/pkg/network/tcp.go b/pkg/network/tcp.go index a1280a6b0..648fe6274 100644 --- a/pkg/network/tcp.go +++ b/pkg/network/tcp.go @@ -22,6 +22,7 @@ func listenTCP(s *Server, port int) error { if err != nil { return err } + go handleConnection(s, conn) } } @@ -54,8 +55,10 @@ func handleConnection(s *Server, conn net.Conn) { s.unregister <- peer }() - // Start a goroutine that will handle all writes to the registered peer. + // Start a goroutine that will handle all outgoing messages. go peer.writeLoop() + // Start a goroutine that will handle all incomming messages. + go handleMessage(s, peer) // Read from the connection and decode it into a Message ready for processing. buf := make([]byte, 1024) @@ -71,39 +74,55 @@ func handleConnection(s *Server, conn net.Conn) { s.logger.Printf("decode error %s", err) break } - handleMessage(msg, s, peer) + + peer.receive <- msg } } // handleMessage hands the message received from a TCP connection over to the server. -func handleMessage(msg *Message, s *Server, p *TCPPeer) { - command := msg.commandType() +func handleMessage(s *Server, p *TCPPeer) { + // Disconnect the peer when we break out of the loop. + defer func() { + p.disconnect() + }() - s.logger.Printf("IN :: %d :: %s :: %v", p.id(), command, msg) + for { + msg := <-p.receive + command := msg.commandType() - switch command { - case cmdVersion: - resp := s.handleVersionCmd(msg, p) - p.isVerack = true - p.nonce = msg.Payload.(*payload.Version).Nonce - p.send <- resp - case cmdAddr: - s.handleAddrCmd(msg, p) - case cmdGetAddr: - s.handleGetaddrCmd(msg, p) - case cmdInv: - resp := s.handleInvCmd(msg, p) - p.send <- resp - case cmdBlock: - case cmdConsensus: - case cmdTX: - case cmdVerack: - go s.sendLoop(p) - case cmdGetHeaders: - case cmdGetBlocks: - case cmdGetData: - case cmdHeaders: - default: + s.logger.Printf("IN :: %d :: %s :: %v", p.id(), command, msg) + + switch command { + case cmdVersion: + resp := s.handleVersionCmd(msg, p) + p.nonce = msg.Payload.(*payload.Version).Nonce + p.send <- resp + + // after sending our version we want a "verack" and nothing else. + msg := <-p.receive + if msg.commandType() != cmdVerack { + break + } + // we can start the protocol now. + go s.sendLoop(p) + case cmdAddr: + s.handleAddrCmd(msg, p) + case cmdGetAddr: + s.handleGetaddrCmd(msg, p) + case cmdInv: + resp := s.handleInvCmd(msg, p) + p.send <- resp + case cmdBlock: + case cmdConsensus: + case cmdTX: + case cmdVerack: + // disconnect the peer, verack should already be handled. + break + case cmdGetHeaders: + case cmdGetBlocks: + case cmdGetData: + case cmdHeaders: + } } } @@ -118,8 +137,8 @@ type TCPPeer struct { endpoint util.Endpoint // channel to coordinate messages writen back to the connection. send chan *Message - // whether this peers version was acknowledged. - isVerack bool + // channel to receive from underlying connection. + receive chan *Message } // NewTCPPeer returns a pointer to a TCP Peer. @@ -129,6 +148,7 @@ func NewTCPPeer(conn net.Conn, s *Server) *TCPPeer { return &TCPPeer{ conn: conn, send: make(chan *Message), + receive: make(chan *Message), endpoint: e, s: s, } @@ -148,11 +168,6 @@ func (p *TCPPeer) addr() util.Endpoint { return p.endpoint } -// verack implements the peer interface -func (p *TCPPeer) verack() bool { - return p.isVerack -} - // callGetaddr will send the "getaddr" command to the remote. func (p *TCPPeer) callGetaddr(msg *Message) { p.send <- msg