huge message and payload refactor.

This commit is contained in:
anthdm 2018-01-28 16:06:41 +01:00
parent 1821ff1a0e
commit 8bbe1435fb
6 changed files with 111 additions and 134 deletions

View file

@ -14,6 +14,7 @@ import (
const ( const (
// The minimum size of a valid message. // The minimum size of a valid message.
minMessageSize = 24 minMessageSize = 24
cmdSize = 12
) )
// NetMode type that is compatible with netModes below. // NetMode type that is compatible with netModes below.
@ -53,14 +54,14 @@ type Message struct {
Magic NetMode Magic NetMode
// Command is utf8 code, of which the length is 12 bytes, // Command is utf8 code, of which the length is 12 bytes,
// the extra part is filled with 0. // the extra part is filled with 0.
Command []byte Command [cmdSize]byte
// Length of the payload // Length of the payload
Length uint32 Length uint32
// Checksum is the first 4 bytes of the value that two times SHA256 // Checksum is the first 4 bytes of the value that two times SHA256
// hash of the payload // hash of the payload
Checksum uint32 Checksum uint32
// Payload send with the message. // Payload send with the message.
Payload payload.Payloader Payload payload.Payload
} }
type commandType string type commandType string
@ -80,7 +81,7 @@ const (
cmdTX = "tx" cmdTX = "tx"
) )
func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message { func newMessage(magic NetMode, cmd commandType, p payload.Payload) *Message {
var ( var (
size uint32 size uint32
checksum []byte checksum []byte
@ -88,15 +89,18 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
if p != nil { if p != nil {
size = p.Size() size = p.Size()
b, _ := p.MarshalBinary() buf := new(bytes.Buffer)
checksum = sumSHA256(sumSHA256(b)) if err := p.EncodeBinary(buf); err != nil {
panic(err)
}
checksum = sumSHA256(sumSHA256(buf.Bytes()))
} else { } else {
checksum = sumSHA256(sumSHA256([]byte{})) checksum = sumSHA256(sumSHA256([]byte{}))
} }
return &Message{ return &Message{
Magic: magic, Magic: magic,
Command: cmdToByteSlice(cmd), Command: cmdToByteArray(cmd),
Length: size, Length: size,
Payload: p, Payload: p,
Checksum: binary.LittleEndian.Uint32(checksum[:4]), Checksum: binary.LittleEndian.Uint32(checksum[:4]),
@ -105,7 +109,7 @@ func newMessage(magic NetMode, cmd commandType, p payload.Payloader) *Message {
// Converts the 12 byte command slice to a commandType. // Converts the 12 byte command slice to a commandType.
func (m *Message) commandType() commandType { func (m *Message) commandType() commandType {
cmd := string(bytes.TrimRight(m.Command, "\x00")) cmd := cmdByteArrayToString(m.Command)
switch cmd { switch cmd {
case "version": case "version":
return cmdVersion return cmdVersion
@ -136,31 +140,20 @@ func (m *Message) commandType() commandType {
// decode a Message from the given reader. // decode a Message from the given reader.
func (m *Message) decode(r io.Reader) error { func (m *Message) decode(r io.Reader) error {
// 24 bytes for the fixed sized fields. binary.Read(r, binary.LittleEndian, &m.Magic)
buf := make([]byte, minMessageSize) binary.Read(r, binary.LittleEndian, &m.Command)
n, err := r.Read(buf) binary.Read(r, binary.LittleEndian, &m.Length)
if err != nil { binary.Read(r, binary.LittleEndian, &m.Checksum)
return err
}
if n != minMessageSize {
return fmt.Errorf("Expected to read exactly %d bytes got %d", minMessageSize, n)
}
m.Magic = NetMode(binary.LittleEndian.Uint32(buf[0:4]))
m.Command = buf[4:16]
m.Length = binary.LittleEndian.Uint32(buf[16:20])
m.Checksum = binary.LittleEndian.Uint32(buf[20:24])
// return if their is no payload. // return if their is no payload.
if m.Length == 0 { if m.Length == 0 {
return nil return nil
} }
return m.unmarshalPayload(r) return m.decodePayload(r)
} }
func (m *Message) unmarshalPayload(r io.Reader) error { func (m *Message) decodePayload(r io.Reader) error {
pbuf := make([]byte, m.Length) pbuf := make([]byte, m.Length)
n, err := r.Read(pbuf) n, err := r.Read(pbuf)
if err != nil { if err != nil {
@ -176,23 +169,24 @@ func (m *Message) unmarshalPayload(r io.Reader) error {
return errors.New("checksum mismatch error") return errors.New("checksum mismatch error")
} }
var p payload.Payloader rr := bytes.NewReader(pbuf)
var p payload.Payload
switch m.commandType() { switch m.commandType() {
case cmdVersion: case cmdVersion:
p = &payload.Version{} p = &payload.Version{}
if err := p.UnmarshalBinary(pbuf); err != nil { if err := p.DecodeBinary(rr); err != nil {
return err
}
case cmdInv:
p = &payload.Inventory{}
if err := p.UnmarshalBinary(pbuf); err != nil {
return err
}
case cmdAddr:
p = &payload.AddressList{}
if err := p.UnmarshalBinary(pbuf); err != nil {
return err return err
} }
// case cmdInv:
// p = &payload.Inventory{}
// if err := p.UnmarshalBinary(pbuf); err != nil {
// return err
// }
// case cmdAddr:
// p = &payload.AddressList{}
// if err := p.UnmarshalBinary(pbuf); err != nil {
// return err
// }
} }
m.Payload = p m.Payload = p
@ -202,37 +196,13 @@ func (m *Message) unmarshalPayload(r io.Reader) error {
// encode a Message to any given io.Writer. // encode a Message to any given io.Writer.
func (m *Message) encode(w io.Writer) error { func (m *Message) encode(w io.Writer) error {
buf := make([]byte, minMessageSize+m.Length) binary.Write(w, binary.LittleEndian, m.Magic)
binary.Write(w, binary.LittleEndian, m.Command)
binary.LittleEndian.PutUint32(buf[0:4], uint32(m.Magic)) binary.Write(w, binary.LittleEndian, m.Length)
copy(buf[4:16], m.Command) binary.Write(w, binary.LittleEndian, m.Checksum)
binary.LittleEndian.PutUint32(buf[16:20], m.Length)
binary.LittleEndian.PutUint32(buf[20:24], m.Checksum)
if m.Payload != nil { if m.Payload != nil {
payload, err := m.Payload.MarshalBinary() return m.Payload.EncodeBinary(w)
if err != nil {
return err
}
copy(buf[minMessageSize:minMessageSize+m.Length], payload)
}
n, err := w.Write(buf)
if err != nil {
return err
}
// safety check to if we have written enough bytes.
if m.Length > 0 {
expectWritten := minMessageSize + m.Length
if uint32(n) != expectWritten {
return fmt.Errorf("expected to written exactly %d did %d", expectWritten, n)
}
} else {
expectWritten := minMessageSize
if n != expectWritten {
return fmt.Errorf("expected to written exactly %d did %d", expectWritten, n)
}
} }
return nil return nil
@ -240,21 +210,31 @@ func (m *Message) encode(w io.Writer) error {
// convert a command (string) to a byte slice filled with 0 bytes till // convert a command (string) to a byte slice filled with 0 bytes till
// size 12. // size 12.
func cmdToByteSlice(cmd commandType) []byte { func cmdToByteArray(cmd commandType) [cmdSize]byte {
cmdLen := len(cmd) cmdLen := len(cmd)
if cmdLen > 12 { if cmdLen > cmdSize {
panic("exceeded command max length of size 12") panic("exceeded command max length of size 12")
} }
// The command can have max 12 bytes, rest is filled with 0. // The command can have max 12 bytes, rest is filled with 0.
b := []byte(cmd) b := [cmdSize]byte{}
for i := 0; i < 12-cmdLen; i++ { for i := 0; i < cmdLen; i++ {
b = append(b, byte('\x00')) b[i] = cmd[i]
} }
return b return b
} }
func cmdByteArrayToString(cmd [cmdSize]byte) string {
buf := []byte{}
for i := 0; i < cmdSize; i++ {
if cmd[i] != 0 {
buf = append(buf, cmd[i])
}
}
return string(buf)
}
func sumSHA256(b []byte) []byte { func sumSHA256(b []byte) []byte {
h := sha256.New() h := sha256.New()
h.Write(b) h.Write(b)

View file

@ -0,0 +1,10 @@
package payload
import "io"
// Payload is anything that can be binary encoded/decoded.
type Payload interface {
EncodeBinary(io.Writer) error
DecodeBinary(io.Reader) error
Size() uint32
}

View file

@ -1,13 +0,0 @@
package payload
import (
"encoding"
)
// Payloader is anything that can be binary marshaled and unmarshaled.
// Every payload embbedded in messages need to satisfy the Payloader interface.
type Payloader interface {
encoding.BinaryMarshaler
encoding.BinaryUnmarshaler
Size() uint32
}

View file

@ -2,9 +2,11 @@ package payload
import ( import (
"encoding/binary" "encoding/binary"
"io"
) )
const ( const (
lenUA = 12
minVersionSize = 27 minVersionSize = 27
) )
@ -21,7 +23,7 @@ type Version struct {
// it's used to distinguish the node from public IP // it's used to distinguish the node from public IP
Nonce uint32 Nonce uint32
// client id // client id
UserAgent []byte UserAgent [lenUA]byte
// Height of the block chain // Height of the block chain
StartHeight uint32 StartHeight uint32
// Whether to receive and forward // Whether to receive and forward
@ -36,63 +38,53 @@ func NewVersion(id uint32, p uint16, ua string, h uint32, r bool) *Version {
Timestamp: 12345, Timestamp: 12345,
Port: p, Port: p,
Nonce: id, Nonce: id,
UserAgent: []byte(ua), UserAgent: uaToByteArray(ua),
StartHeight: 0, StartHeight: 0,
Relay: r, Relay: r,
} }
} }
// UnmarshalBinary implements the Payloader interface. // DecodeBinary implements the Payload interface.
func (p *Version) UnmarshalBinary(b []byte) error { func (p *Version) DecodeBinary(r io.Reader) error {
// Length of the user agent should be calculated dynamicaly. // TODO: Length of the user agent should be calculated dynamicaly.
// There is no information about the size or format of this. // There is no information about the size or format of this.
// the only thing we know is by looking at the #c source code. // the only thing we know is by looking at the #c source code.
// /NEO:{0}/ => /NEO:2.6.0/ // /NEO:{0}/ => /NEO:2.6.0/
lenUA := len(b) - minVersionSize err := binary.Read(r, binary.LittleEndian, &p.Version)
err = binary.Read(r, binary.LittleEndian, &p.Services)
err = binary.Read(r, binary.LittleEndian, &p.Timestamp)
err = binary.Read(r, binary.LittleEndian, &p.Port)
err = binary.Read(r, binary.LittleEndian, &p.Nonce)
err = binary.Read(r, binary.LittleEndian, &p.UserAgent)
err = binary.Read(r, binary.LittleEndian, &p.StartHeight)
err = binary.Read(r, binary.LittleEndian, &p.Relay)
p.Version = binary.LittleEndian.Uint32(b[0:4]) return err
p.Services = binary.LittleEndian.Uint64(b[4:12])
p.Timestamp = binary.LittleEndian.Uint32(b[12:16])
// FIXME: port's byteorder should be big endian according to the docs.
// but when connecting to the privnet docker image it's little endian.
p.Port = binary.LittleEndian.Uint16(b[16:18])
p.Nonce = binary.LittleEndian.Uint32(b[18:22])
p.UserAgent = b[22 : 22+lenUA]
curlen := 22 + lenUA
p.StartHeight = binary.LittleEndian.Uint32(b[curlen : curlen+4])
p.Relay = b[len(b)-1 : len(b)][0] == 1
return nil
} }
// MarshalBinary implements the Payloader interface. // EncodeBinary implements the Payload interface.
func (p *Version) MarshalBinary() ([]byte, error) { func (p *Version) EncodeBinary(w io.Writer) error {
b := make([]byte, p.Size()) err := binary.Write(w, binary.LittleEndian, p.Version)
err = binary.Write(w, binary.LittleEndian, p.Services)
err = binary.Write(w, binary.LittleEndian, p.Timestamp)
err = binary.Write(w, binary.LittleEndian, p.Port)
err = binary.Write(w, binary.LittleEndian, p.Nonce)
err = binary.Write(w, binary.LittleEndian, p.UserAgent)
err = binary.Write(w, binary.LittleEndian, p.StartHeight)
err = binary.Write(w, binary.LittleEndian, p.Relay)
binary.LittleEndian.PutUint32(b[0:4], p.Version) return err
binary.LittleEndian.PutUint64(b[4:12], p.Services)
binary.LittleEndian.PutUint32(b[12:16], p.Timestamp)
// FIXME: byte order (little / big)?
binary.LittleEndian.PutUint16(b[16:18], p.Port)
binary.LittleEndian.PutUint32(b[18:22], p.Nonce)
copy(b[22:22+len(p.UserAgent)], p.UserAgent) //
curLen := 22 + len(p.UserAgent)
binary.LittleEndian.PutUint32(b[curLen:curLen+4], p.StartHeight)
// yikes
var bln []byte
if p.Relay {
bln = []byte{1}
} else {
bln = []byte{0}
}
copy(b[curLen+4:len(b)], bln)
return b, nil
} }
// Size implements the payloader interface. // Size implements the payloader interface.
func (p *Version) Size() uint32 { func (p *Version) Size() uint32 {
return uint32(minVersionSize + len(p.UserAgent)) return uint32(minVersionSize + len(p.UserAgent))
} }
func uaToByteArray(ua string) [lenUA]byte {
buf := [lenUA]byte{}
for i := 0; i < lenUA; i++ {
buf[i] = ua[i]
}
return buf
}

View file

@ -7,15 +7,23 @@ import (
) )
func TestVersionEncodeDecode(t *testing.T) { func TestVersionEncodeDecode(t *testing.T) {
p := NewVersion(3000, "/NEO/", 0, true) version := NewVersion(13337, 3000, "./NEO:0.0.1/", 0, true)
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
p.Encode(buf) if err := version.EncodeBinary(buf); err != nil {
t.Fatal(err)
}
pd := &Version{} versionDecoded := &Version{}
pd.Decode(buf) if err := versionDecoded.DecodeBinary(buf); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(p, pd) { if !reflect.DeepEqual(version, versionDecoded) {
t.Fatalf("expect %v to be equal to %v", p, pd) t.Fatalf("expected both version payload to be equal: %+v and %+v", version, versionDecoded)
}
if version.Size() != uint32(minVersionSize+len(version.UserAgent)) {
t.Fatalf("Expected version size of %d", minVersionSize+lenUA)
} }
} }

View file

@ -181,9 +181,9 @@ func (s *Server) processMessage(msg *Message, peer *Peer) error {
return s.handleVersionCmd(msg.Payload.(*payload.Version), peer) return s.handleVersionCmd(msg.Payload.(*payload.Version), peer)
case cmdVerack: case cmdVerack:
case cmdGetAddr: case cmdGetAddr:
return s.handleGetAddrCmd(msg, peer) // return s.handleGetAddrCmd(msg, peer)
case cmdAddr: case cmdAddr:
return s.handleAddrCmd(msg.Payload.(*payload.AddressList), peer) // return s.handleAddrCmd(msg.Payload.(*payload.AddressList), peer)
case cmdGetHeaders: case cmdGetHeaders:
case cmdHeaders: case cmdHeaders:
case cmdGetBlocks: case cmdGetBlocks: