From 8bbe1435fbf01fc09e42589cf772a0ad0614a37d Mon Sep 17 00:00:00 2001 From: anthdm Date: Sun, 28 Jan 2018 16:06:41 +0100 Subject: [PATCH] huge message and payload refactor. --- pkg/network/message.go | 120 ++++++++++++---------------- pkg/network/payload/payload.go | 10 +++ pkg/network/payload/payloader.go | 13 --- pkg/network/payload/version.go | 78 ++++++++---------- pkg/network/payload/version_test.go | 20 +++-- pkg/network/server.go | 4 +- 6 files changed, 111 insertions(+), 134 deletions(-) create mode 100644 pkg/network/payload/payload.go delete mode 100644 pkg/network/payload/payloader.go diff --git a/pkg/network/message.go b/pkg/network/message.go index e8cefc255..f1c27e709 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -14,6 +14,7 @@ import ( const ( // The minimum size of a valid message. minMessageSize = 24 + cmdSize = 12 ) // NetMode type that is compatible with netModes below. @@ -53,14 +54,14 @@ type Message struct { Magic NetMode // Command is utf8 code, of which the length is 12 bytes, // the extra part is filled with 0. - Command []byte + Command [cmdSize]byte // Length of the payload Length uint32 // Checksum is the first 4 bytes of the value that two times SHA256 // hash of the payload Checksum uint32 // Payload send with the message. - Payload payload.Payloader + Payload payload.Payload } type commandType string @@ -80,7 +81,7 @@ const ( cmdTX = "tx" ) -func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { +func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message { var ( size uint32 checksum []byte @@ -88,15 +89,18 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { if p != nil { size = p.Size() - b, _ := p.MarshalBinary() - checksum = sumSHA256(sumSHA256(b)) + buf := new(bytes.Buffer) + if err := p.EncodeBinary(buf); err != nil { + panic(err) + } + checksum = sumSHA256(sumSHA256(buf.Bytes())) } else { checksum = sumSHA256(sumSHA256([]byte{})) } return &Message{ Magic: magic, - Command: cmdToByteSlice(cmd), + Command: cmdToByteArray(cmd), Length: size, Payload: p, Checksum: binary.LittleEndian.Uint32(checksum[:4]), @@ -105,7 +109,7 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { // Converts the 12 byte command slice to a commandType. func (m *Message) commandType() commandType { - cmd := string(bytes.TrimRight(m.Command, "\x00")) + cmd := cmdByteArrayToString(m.Command) switch cmd { case "version": return cmdVersion @@ -136,31 +140,20 @@ func (m *Message) commandType() commandType { // decode a Message from the given reader. func (m *Message) decode(r io.Reader) error { - // 24 bytes for the fixed sized fields. - buf := make([]byte, minMessageSize) - n, err := r.Read(buf) - if err != nil { - return err - } - - if n != minMessageSize { - return fmt.Errorf("Expected to read exactly %d bytes got %d", minMessageSize, n) - } - - m.Magic = NetMode(binary.LittleEndian.Uint32(buf[0:4])) - m.Command = buf[4:16] - m.Length = binary.LittleEndian.Uint32(buf[16:20]) - m.Checksum = binary.LittleEndian.Uint32(buf[20:24]) + binary.Read(r, binary.LittleEndian, &m.Magic) + binary.Read(r, binary.LittleEndian, &m.Command) + binary.Read(r, binary.LittleEndian, &m.Length) + binary.Read(r, binary.LittleEndian, &m.Checksum) // return if their is no payload. if m.Length == 0 { return nil } - return m.unmarshalPayload(r) + return m.decodePayload(r) } -func (m *Message) unmarshalPayload(r io.Reader) error { +func (m *Message) decodePayload(r io.Reader) error { pbuf := make([]byte, m.Length) n, err := r.Read(pbuf) if err != nil { @@ -176,23 +169,24 @@ func (m *Message) unmarshalPayload(r io.Reader) error { return errors.New("checksum mismatch error") } - var p payload.Payloader + rr := bytes.NewReader(pbuf) + var p payload.Payload switch m.commandType() { case cmdVersion: p = &payload.Version{} - if err := p.UnmarshalBinary(pbuf); err != nil { - return err - } - case cmdInv: - p = &payload.Inventory{} - if err := p.UnmarshalBinary(pbuf); err != nil { - return err - } - case cmdAddr: - p = &payload.AddressList{} - if err := p.UnmarshalBinary(pbuf); err != nil { + if err := p.DecodeBinary(rr); err != nil { return err } + // case cmdInv: + // p = &payload.Inventory{} + // if err := p.UnmarshalBinary(pbuf); err != nil { + // return err + // } + // case cmdAddr: + // p = &payload.AddressList{} + // if err := p.UnmarshalBinary(pbuf); err != nil { + // return err + // } } m.Payload = p @@ -202,37 +196,13 @@ func (m *Message) unmarshalPayload(r io.Reader) error { // encode a Message to any given io.Writer. func (m *Message) encode(w io.Writer) error { - buf := make([]byte, minMessageSize+m.Length) - - binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic)) - copy(buf[4:16], m.Command) - binary.LittleEndian.PutUint32(buf[16:20], m.Length) - binary.LittleEndian.PutUint32(buf[20:24], m.Checksum) + binary.Write(w, binary.LittleEndian, m.Magic) + binary.Write(w, binary.LittleEndian, m.Command) + binary.Write(w, binary.LittleEndian, m.Length) + binary.Write(w, binary.LittleEndian, m.Checksum) if m.Payload != nil { - payload, err := m.Payload.MarshalBinary() - if err != nil { - return err - } - copy(buf[minMessageSize:minMessageSize+m.Length], payload) - } - - n, err := w.Write(buf) - if err != nil { - return err - } - - // safety check to if we have written enough bytes. - if m.Length > 0 { - expectWritten := minMessageSize + m.Length - if uint32(n) != expectWritten { - return fmt.Errorf("expected to written exactly %d did %d", expectWritten, n) - } - } else { - expectWritten := minMessageSize - if n != expectWritten { - return fmt.Errorf("expected to written exactly %d did %d", expectWritten, n) - } + return m.Payload.EncodeBinary(w) } return nil @@ -240,21 +210,31 @@ func (m *Message) encode(w io.Writer) error { // convert a command (string) to a byte slice filled with 0 bytes till // size 12. -func cmdToByteSlice(cmd commandType) []byte { +func cmdToByteArray(cmd commandType) [cmdSize]byte { cmdLen := len(cmd) - if cmdLen > 12 { + if cmdLen > cmdSize { panic("exceeded command max length of size 12") } // The command can have max 12 bytes, rest is filled with 0. - b := []byte(cmd) - for i := 0; i < 12-cmdLen; i++ { - b = append(b, byte('\x00')) + b := [cmdSize]byte{} + for i := 0; i < cmdLen; i++ { + b[i] = cmd[i] } return b } +func cmdByteArrayToString(cmd [cmdSize]byte) string { + buf := []byte{} + for i := 0; i < cmdSize; i++ { + if cmd[i] != 0 { + buf = append(buf, cmd[i]) + } + } + return string(buf) +} + func sumSHA256(b []byte) []byte { h := sha256.New() h.Write(b) diff --git a/pkg/network/payload/payload.go b/pkg/network/payload/payload.go new file mode 100644 index 000000000..104a3b9f5 --- /dev/null +++ b/pkg/network/payload/payload.go @@ -0,0 +1,10 @@ +package payload + +import "io" + +// Payload is anything that can be binary encoded/decoded. +type Payload interface { + EncodeBinary(io.Writer) error + DecodeBinary(io.Reader) error + Size() uint32 +} diff --git a/pkg/network/payload/payloader.go b/pkg/network/payload/payloader.go deleted file mode 100644 index 902753524..000000000 --- a/pkg/network/payload/payloader.go +++ /dev/null @@ -1,13 +0,0 @@ -package payload - -import ( - "encoding" -) - -// Payloader is anything that can be binary marshaled and unmarshaled. -// Every payload embbedded in messages need to satisfy the Payloader interface. -type Payloader interface { - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler - Size() uint32 -} diff --git a/pkg/network/payload/version.go b/pkg/network/payload/version.go index 34828a9d9..329d920b7 100644 --- a/pkg/network/payload/version.go +++ b/pkg/network/payload/version.go @@ -2,9 +2,11 @@ package payload import ( "encoding/binary" + "io" ) const ( + lenUA = 12 minVersionSize = 27 ) @@ -21,7 +23,7 @@ type Version struct { // it's used to distinguish the node from public IP Nonce uint32 // client id - UserAgent []byte + UserAgent [lenUA]byte // Height of the block chain StartHeight uint32 // Whether to receive and forward @@ -36,63 +38,53 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version { Timestamp: 12345, Port: p, Nonce: id, - UserAgent: []byte(ua), + UserAgent: uaToByteArray(ua), StartHeight: 0, Relay: r, } } -// UnmarshalBinary implements the Payloader interface. -func (p *Version) UnmarshalBinary(b []byte) error { - // Length of the user agent should be calculated dynamicaly. +// DecodeBinary implements the Payload interface. +func (p *Version) DecodeBinary(r io.Reader) error { + // TODO: Length of the user agent should be calculated dynamicaly. // There is no information about the size or format of this. // the only thing we know is by looking at the #c source code. // /NEO:{0}/ => /NEO:2.6.0/ - lenUA := len(b) - minVersionSize + err := binary.Read(r, binary.LittleEndian, &p.Version) + err = binary.Read(r, binary.LittleEndian, &p.Services) + err = binary.Read(r, binary.LittleEndian, &p.Timestamp) + err = binary.Read(r, binary.LittleEndian, &p.Port) + err = binary.Read(r, binary.LittleEndian, &p.Nonce) + err = binary.Read(r, binary.LittleEndian, &p.UserAgent) + err = binary.Read(r, binary.LittleEndian, &p.StartHeight) + err = binary.Read(r, binary.LittleEndian, &p.Relay) - p.Version = binary.LittleEndian.Uint32(b[0:4]) - p.Services = binary.LittleEndian.Uint64(b[4:12]) - p.Timestamp = binary.LittleEndian.Uint32(b[12:16]) - // FIXME: port's byteorder should be big endian according to the docs. - // but when connecting to the privnet docker image it's little endian. - p.Port = binary.LittleEndian.Uint16(b[16:18]) - p.Nonce = binary.LittleEndian.Uint32(b[18:22]) - p.UserAgent = b[22 : 22+lenUA] - curlen := 22 + lenUA - p.StartHeight = binary.LittleEndian.Uint32(b[curlen : curlen+4]) - p.Relay = b[len(b)-1 : len(b)][0] == 1 - - return nil + return err } -// MarshalBinary implements the Payloader interface. -func (p *Version) MarshalBinary() ([]byte, error) { - b := make([]byte, p.Size()) +// EncodeBinary implements the Payload interface. +func (p *Version) EncodeBinary(w io.Writer) error { + err := binary.Write(w, binary.LittleEndian, p.Version) + err = binary.Write(w, binary.LittleEndian, p.Services) + err = binary.Write(w, binary.LittleEndian, p.Timestamp) + err = binary.Write(w, binary.LittleEndian, p.Port) + err = binary.Write(w, binary.LittleEndian, p.Nonce) + err = binary.Write(w, binary.LittleEndian, p.UserAgent) + err = binary.Write(w, binary.LittleEndian, p.StartHeight) + err = binary.Write(w, binary.LittleEndian, p.Relay) - binary.LittleEndian.PutUint32(b[0:4], p.Version) - binary.LittleEndian.PutUint64(b[4:12], p.Services) - binary.LittleEndian.PutUint32(b[12:16], p.Timestamp) - // FIXME: byte order (little / big)? - binary.LittleEndian.PutUint16(b[16:18], p.Port) - binary.LittleEndian.PutUint32(b[18:22], p.Nonce) - copy(b[22:22+len(p.UserAgent)], p.UserAgent) // - curLen := 22 + len(p.UserAgent) - binary.LittleEndian.PutUint32(b[curLen:curLen+4], p.StartHeight) - - // yikes - var bln []byte - if p.Relay { - bln = []byte{1} - } else { - bln = []byte{0} - } - - copy(b[curLen+4:len(b)], bln) - - return b, nil + return err } // Size implements the payloader interface. func (p *Version) Size() uint32 { return uint32(minVersionSize + len(p.UserAgent)) } + +func uaToByteArray(ua string) [lenUA]byte { + buf := [lenUA]byte{} + for i := 0; i < lenUA; i++ { + buf[i] = ua[i] + } + return buf +} diff --git a/pkg/network/payload/version_test.go b/pkg/network/payload/version_test.go index 66d562156..6eac5b645 100644 --- a/pkg/network/payload/version_test.go +++ b/pkg/network/payload/version_test.go @@ -7,15 +7,23 @@ import ( ) func TestVersionEncodeDecode(t *testing.T) { - p := NewVersion(3000, "/NEO/", 0, true) + version := NewVersion(13337, 3000, "./NEO:0.0.1/", 0, true) buf := new(bytes.Buffer) - p.Encode(buf) + if err := version.EncodeBinary(buf); err != nil { + t.Fatal(err) + } - pd := &Version{} - pd.Decode(buf) + versionDecoded := &Version{} + if err := versionDecoded.DecodeBinary(buf); err != nil { + t.Fatal(err) + } - if !reflect.DeepEqual(p, pd) { - t.Fatalf("expect %v to be equal to %v", p, pd) + if !reflect.DeepEqual(version, versionDecoded) { + t.Fatalf("expected both version payload to be equal: %+v and %+v", version, versionDecoded) + } + + if version.Size() != uint32(minVersionSize+len(version.UserAgent)) { + t.Fatalf("Expected version size of %d", minVersionSize+lenUA) } } diff --git a/pkg/network/server.go b/pkg/network/server.go index 9f5d615c7..8795c8ea1 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -181,9 +181,9 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error { return s.handleVersionCmd(msg.Payload.(*payload.Version), peer) case cmdVerack: case cmdGetAddr: - return s.handleGetAddrCmd(msg, peer) + // return s.handleGetAddrCmd(msg, peer) case cmdAddr: - return s.handleAddrCmd(msg.Payload.(*payload.AddressList), peer) + // return s.handleAddrCmd(msg.Payload.(*payload.AddressList), peer) case cmdGetHeaders: case cmdHeaders: case cmdGetBlocks: