From f28d8f9ab66c4ce361c5e2ce1795adcb1e5fe27f Mon Sep 17 00:00:00 2001 From: anthdm Date: Sun, 28 Jan 2018 08:03:18 +0100 Subject: [PATCH] uint256 + inventoryType --- pkg/network/message.go | 110 +++++++++++++------------------ pkg/network/payload/payloader.go | 12 ++-- pkg/network/payload/version.go | 59 +++++++---------- pkg/network/server.go | 20 ++++-- 4 files changed, 89 insertions(+), 112 deletions(-) diff --git a/pkg/network/message.go b/pkg/network/message.go index 84099c555..1b214a5e9 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/binary" "errors" - "fmt" "io" "github.com/anthdm/neo-go/pkg/network/payload" @@ -81,15 +80,25 @@ const ( ) func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { - var size uint32 + var ( + size uint32 + checksum []byte + ) + if p != nil { size = p.Size() + b, _ := p.MarshalBinary() + checksum = sumSHA256(sumSHA256(b)) + } else { + checksum = sumSHA256(sumSHA256([]byte{})) } + return &Message{ - Magic: magic, - Command: cmdToByteSlice(cmd), - Length: size, - Payload: p, + Magic: magic, + Command: cmdToByteSlice(cmd), + Length: size, + Payload: p, + Checksum: binary.LittleEndian.Uint32(checksum[:4]), } } @@ -137,40 +146,39 @@ func (m *Message) decode(r io.Reader) error { m.Length = binary.LittleEndian.Uint32(buf[16:20]) m.Checksum = binary.LittleEndian.Uint32(buf[20:24]) - // if their is no payload. - if m.Length == 0 || !needPayloadDecode(m.commandType()) { + // return if their is no payload. + if m.Length == 0 { return nil } - return m.decodePayload(r) + return m.unmarshalPayload(r) } -func (m *Message) decodePayload(r io.Reader) error { - // write to a buffer what we read to calculate the checksum. - buffer := new(bytes.Buffer) - tr := io.TeeReader(r, buffer) - var p payload.Payloader - - switch m.commandType() { - case cmdVersion: - p = &payload.Version{} - if err := p.Decode(tr); err != nil { - return err - } - case cmdInv: - p = payload.Inventories{} - if err := p.Decode(tr); err != nil { - return err - } - default: - return fmt.Errorf("unknown command to decode: %s", m.commandType()) +func (m *Message) unmarshalPayload(r io.Reader) error { + pbuf := make([]byte, m.Length) + if _, err := r.Read(pbuf); err != nil { + return err } // Compare the checksum of the payload. - if !compareChecksum(m.Checksum, buffer.Bytes()) { + if !compareChecksum(m.Checksum, pbuf) { return errors.New("checksum mismatch error") } + var p payload.Payloader + switch m.commandType() { + case cmdVersion: + p = &payload.Version{} + if err := p.UnmarshalBinary(pbuf); err != nil { + return err + } + case cmdInv: + p = &payload.Inventory{} + if err := p.UnmarshalBinary(pbuf); err != nil { + return err + } + } + m.Payload = p return nil @@ -178,49 +186,23 @@ func (m *Message) decodePayload(r io.Reader) error { // encode a Message to any given io.Writer. func (m *Message) encode(w io.Writer) error { - buf := make([]byte, minMessageSize) - pbuf := new(bytes.Buffer) + buf := make([]byte, minMessageSize+m.Length) - // if there is a payload fill its allocated buffer. - var checksum []byte - if m.Payload != nil { - if err := m.Payload.Encode(pbuf); err != nil { - return err - } - checksum = sumSHA256(sumSHA256(pbuf.Bytes()))[:4] - } else { - checksum = sumSHA256(sumSHA256([]byte{}))[:4] - } - - m.Checksum = binary.LittleEndian.Uint32(checksum) - - // fill the message buffer binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic)) copy(buf[4:16], m.Command) binary.LittleEndian.PutUint32(buf[16:20], m.Length) binary.LittleEndian.PutUint32(buf[20:24], m.Checksum) - // write the message - n, err := w.Write(buf) - if err != nil { - return err - } - - // we need to have at least writen exactly minMessageSize bytes. - if n != minMessageSize { - return errors.New("long/short read error when encoding message") - } - - // write the payload if there was any - if pbuf.Len() > 0 { - n, err = w.Write(pbuf.Bytes()) + if m.Payload != nil { + payload, err := m.Payload.MarshalBinary() if err != nil { return err } + copy(buf[minMessageSize:minMessageSize+m.Length], payload) + } - if uint32(n) != m.Payload.Size() { - return errors.New("long/short read error when encoding payload") - } + if _, err := w.Write(buf); err != nil { + return err } return nil @@ -243,10 +225,6 @@ func cmdToByteSlice(cmd commandType) []byte { return b } -func needPayloadDecode(cmd commandType) bool { - return cmd != cmdVerack && cmd != cmdGetAddr -} - func sumSHA256(b []byte) []byte { h := sha256.New() h.Write(b) diff --git a/pkg/network/payload/payloader.go b/pkg/network/payload/payloader.go index 015b43066..902753524 100644 --- a/pkg/network/payload/payloader.go +++ b/pkg/network/payload/payloader.go @@ -1,11 +1,13 @@ package payload -import "io" +import ( + "encoding" +) -// Payloader is anything that can be binary encoded and decoded. -// Every payload used in messages need to satisfy the Payloader interface. +// Payloader is anything that can be binary marshaled and unmarshaled. +// Every payload embbedded in messages need to satisfy the Payloader interface. type Payloader interface { - Encode(io.Writer) error - Decode(io.Reader) error + encoding.BinaryMarshaler + encoding.BinaryUnmarshaler Size() uint32 } diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 526dde118..f0e7a79c0 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -2,7 +2,6 @@ package payload import ( "encoding/binary" - "io" ) const ( @@ -31,36 +30,21 @@ type Version struct { } // NewVersion returns a pointer to a Version payload. -func NewVersion(p uint16, ua string, h uint32, r bool) *Version { +func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { return &Version{ Version: 0, Services: 1, Timestamp: 12345, Port: p, - Nonce: 19110, + Nonce: id, UserAgent: []byte(ua), StartHeight: 0, Relay: r, } } -// Size implements the Payloader interface. -func (p *Version) Size() uint32 { - n := minVersionSize - return uint32(n) -} - -// Decode implements the Payloader interface. -func (p *Version) Decode(r io.Reader) error { - b := make([]byte, minVersionSize) - if _, err := r.Read(b); err != nil { - return err - } - - // 27 bytes for the fixed size fields + the length of the user agent - // which is kinda variable, according to the docs. - lenUA := len(b) - minVersionSize - +// UnmarshalBinary implements the Payloader interface. +func (p *Version) UnmarshalBinary(b []byte) error { p.Version = binary.LittleEndian.Uint32(b[0:4]) p.Services = binary.LittleEndian.Uint64(b[4:12]) p.Timestamp = binary.LittleEndian.Uint32(b[12:16]) @@ -76,30 +60,33 @@ func (p *Version) Decode(r io.Reader) error { return nil } -// Encode implements the Payloader interface. -func (p *Version) Encode(w io.Writer) error { - buf := make([]byte, p.Size()) +// MarshalBinary implements the Payloader interface. +func (p *Version) MarshalBinary() ([]byte, error) { + b := make([]byte, p.Size()) - binary.LittleEndian.PutUint32(buf[0:4], p.Version) - binary.LittleEndian.PutUint64(buf[4:12], p.Services) - binary.LittleEndian.PutUint32(buf[12:16], p.Timestamp) + binary.LittleEndian.PutUint32(b[0:4], p.Version) + binary.LittleEndian.PutUint64(b[4:12], p.Services) + binary.LittleEndian.PutUint32(b[12:16], p.Timestamp) // FIXME: byte order (little / big)? - binary.LittleEndian.PutUint16(buf[16:18], p.Port) - binary.LittleEndian.PutUint32(buf[18:22], p.Nonce) - copy(buf[22:22+len(p.UserAgent)], p.UserAgent) // + binary.LittleEndian.PutUint16(b[16:18], p.Port) + binary.LittleEndian.PutUint32(b[18:22], p.Nonce) + copy(b[22:22+len(p.UserAgent)], p.UserAgent) // curLen := 22 + len(p.UserAgent) - binary.LittleEndian.PutUint32(buf[curLen:curLen+4], p.StartHeight) + binary.LittleEndian.PutUint32(b[curLen:curLen+4], p.StartHeight) // yikes - var b []byte + var bln []byte if p.Relay { - b = []byte{1} + bln = []byte{1} } else { - b = []byte{0} + bln = []byte{0} } - copy(buf[curLen+4:len(buf)], b) + copy(b[curLen+4:len(b)], bln) - _, err := w.Write(buf) - return err + return b, nil +} + +func (p *Version) Size() uint32 { + return uint32(minVersionSize + len(p.UserAgent)) } diff --git a/pkg/network/server.go b/pkg/network/server.go index 99c411c5f..13bd593ce 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -9,6 +9,7 @@ import ( "strconv" "github.com/anthdm/neo-go/pkg/network/payload" + "github.com/anthdm/neo-go/pkg/util" ) const ( @@ -33,6 +34,9 @@ type messageTuple struct { type Server struct { logger *log.Logger + // id of the server + id uint32 + // the port the TCP listener is listening on. port uint16 @@ -72,6 +76,7 @@ func NewServer(net NetMode) *Server { s := &Server{ // It is important to have this user agent correct. Otherwise we will get // disconnected. + id: util.RandUint32(1111111, 9999999), userAgent: fmt.Sprintf("\v/NEO:%s/", version), logger: logger, peers: make(map[*Peer]bool), @@ -95,8 +100,10 @@ func (s *Server) Start(port string, seeds []string) { s.port = uint16(p) fmt.Println(logo()) - s.logger.Printf("running %s on %s - TCP %d - relay: %v", - s.userAgent, s.net, int(s.port), s.relay) + fmt.Println(string(s.userAgent)) + fmt.Println("") + s.logger.Printf("NET: %s - TCP: %d - RELAY: %v - ID: %d", + s.net, int(s.port), s.relay, s.id) go listenTCP(s, port) @@ -163,7 +170,10 @@ func (s *Server) loop() { // TODO: unregister peers on error. // processMessage processes the received message from a remote node. func (s *Server) processMessage(msg *Message, peer *Peer) error { - rpcLogger.Printf("IN :: %+v", string(msg.Command)) + rpcLogger.Printf("IN :: %s", msg.commandType()) + if msg.Length > 0 { + rpcLogger.Printf("IN :: %+v", msg.Payload) + } switch msg.commandType() { case cmdVersion: @@ -190,7 +200,7 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error { // No further communication should been made before both sides has received // the version of eachother. func (s *Server) handlePeerConnected() (*Message, error) { - payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay) + payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) msg := newMessage(s.net, cmdVersion, payload) return msg, nil } @@ -199,7 +209,7 @@ func (s *Server) handlePeerConnected() (*Message, error) { func (s *Server) handleVersionCmd(v *payload.Version, peer *Peer) error { // TODO: check version and verify to trust that node. - payload := payload.NewVersion(s.port, s.userAgent, 0, s.relay) + payload := payload.NewVersion(s.id, s.port, s.userAgent, 0, s.relay) // we respond with our version. versionMsg := newMessage(s.net, cmdVersion, payload) peer.send <- versionMsg