diff --git a/pkg/internal/testserdes/testing.go b/pkg/internal/testserdes/testing.go index 3883f461c..0be71a844 100644 --- a/pkg/internal/testserdes/testing.go +++ b/pkg/internal/testserdes/testing.go @@ -42,3 +42,37 @@ func DecodeBinary(data []byte, a io.Serializable) error { a.DecodeBinary(r) return r.Err } + +type encodable interface { + Encode(*io.BinWriter) error + Decode(*io.BinReader) error +} + +// EncodeDecode checks if expected stays the same after +// serializing/deserializing via encodable methods. +func EncodeDecode(t *testing.T, expected, actual encodable) { + data, err := Encode(expected) + require.NoError(t, err) + require.NoError(t, Decode(data, actual)) + require.Equal(t, expected, actual) +} + +// Encode serializes a to a byte slice. +func Encode(a encodable) ([]byte, error) { + w := io.NewBufBinWriter() + err := a.Encode(w.BinWriter) + if err != nil { + return nil, err + } + return w.Bytes(), nil +} + +// Decode deserializes a from a byte slice. +func Decode(data []byte, a encodable) error { + r := io.NewBinReaderFromBuf(data) + err := a.Decode(r) + if r.Err != nil { + return r.Err + } + return err +} diff --git a/pkg/network/compress.go b/pkg/network/compress.go new file mode 100644 index 000000000..7b6d6c3d0 --- /dev/null +++ b/pkg/network/compress.go @@ -0,0 +1,33 @@ +package network + +import ( + "bytes" + "io" + + "github.com/pierrec/lz4" +) + +// compress compresses bytes using lz4. +func compress(source []byte) ([]byte, error) { + dest := new(bytes.Buffer) + w := lz4.NewWriter(dest) + _, err := io.Copy(w, bytes.NewReader(source)) + if err != nil { + return nil, err + } + if w.Close() != nil { + return nil, err + } + return dest.Bytes(), nil +} + +// decompress decompresses bytes using lz4. +func decompress(source []byte) ([]byte, error) { + dest := new(bytes.Buffer) + r := lz4.NewReader(bytes.NewReader(source)) + _, err := io.Copy(dest, r) + if err != nil { + return nil, err + } + return dest.Bytes(), nil +} diff --git a/pkg/network/message.go b/pkg/network/message.go index 35402b844..b38c4f5eb 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -1,6 +1,7 @@ package network import ( + "errors" "fmt" "github.com/nspcc-dev/neo-go/pkg/consensus" @@ -12,18 +13,37 @@ import ( //go:generate stringer -type=CommandType +const ( + // PayloadMaxSize is maximum payload size in decompressed form. + PayloadMaxSize = 0x02000000 + // CompressionMinSize is the lower bound to apply compression. + CompressionMinSize = 1024 +) + // Message is the complete message send between nodes. type Message struct { + // Flags that represents whether a message is compressed. + // 0 for None, 1 for Compressed. + Flags MessageFlag // Command is byte command code. Command CommandType - // Length of the payload. - Length uint32 - // Payload send with the message. Payload payload.Payload + + // Compressed message payload. + compressedPayload []byte } +// MessageFlag represents compression level of message payload +type MessageFlag byte + +// Possible message flags +const ( + None MessageFlag = 0 + Compressed MessageFlag = 1 << iota +) + // CommandType represents the type of a message command. type CommandType byte @@ -65,47 +85,45 @@ const ( // NewMessage returns a new message with the given payload. func NewMessage(cmd CommandType, p payload.Payload) *Message { - var ( - size uint32 - ) - - if p != nil { - buf := io.NewBufBinWriter() - p.EncodeBinary(buf.BinWriter) - if buf.Err != nil { - panic(buf.Err) - } - b := buf.Bytes() - size = uint32(len(b)) - } - return &Message{ Command: cmd, - Length: size, Payload: p, + Flags: None, } } // Decode decodes a Message from the given reader. func (m *Message) Decode(br *io.BinReader) error { + m.Flags = MessageFlag(br.ReadB()) m.Command = CommandType(br.ReadB()) - m.Length = br.ReadU32LE() - if br.Err != nil { - return br.Err - } - // return if their is no payload. - if m.Length == 0 { + l := br.ReadVarUint() + // check the length first in order not to allocate memory + // for an empty compressed payload + if l == 0 { + m.Payload = payload.NewNullPayload() return nil } - return m.decodePayload(br) -} - -func (m *Message) decodePayload(br *io.BinReader) error { - buf := make([]byte, m.Length) - br.ReadBytes(buf) + m.compressedPayload = make([]byte, l) + br.ReadBytes(m.compressedPayload) if br.Err != nil { return br.Err } + if len(m.compressedPayload) > PayloadMaxSize { + return errors.New("invalid payload size") + } + return m.decodePayload() +} + +func (m *Message) decodePayload() error { + buf := m.compressedPayload + // try decompression + if m.Flags&Compressed != 0 { + d, err := decompress(m.compressedPayload) + if err != nil { + return err + } + buf = d + } r := io.NewBinReaderFromBuf(buf) var p payload.Payload @@ -147,16 +165,17 @@ func (m *Message) decodePayload(br *io.BinReader) error { // Encode encodes a Message to any given BinWriter. func (m *Message) Encode(br *io.BinWriter) error { + if err := m.tryCompressPayload(); err != nil { + return err + } + br.WriteB(byte(m.Flags)) br.WriteB(byte(m.Command)) - br.WriteU32LE(m.Length) - if m.Payload != nil { - m.Payload.EncodeBinary(br) - + if m.compressedPayload != nil { + br.WriteVarBytes(m.compressedPayload) + } else { + br.WriteB(0) } - if br.Err != nil { - return br.Err - } - return nil + return br.Err } // Bytes serializes a Message into the new allocated buffer and returns it. @@ -170,3 +189,37 @@ func (m *Message) Bytes() ([]byte, error) { } return w.Bytes(), nil } + +// tryCompressPayload sets message's compressed payload to serialized +// payload and compresses it in case if its size exceeds CompressionMinSize +func (m *Message) tryCompressPayload() error { + if m.Payload == nil { + return nil + } + buf := io.NewBufBinWriter() + m.Payload.EncodeBinary(buf.BinWriter) + if buf.Err != nil { + return buf.Err + } + compressedPayload := buf.Bytes() + if m.Flags&Compressed == 0 { + switch m.Payload.(type) { + case *payload.Headers, *payload.MerkleBlock, *payload.NullPayload: + break + default: + size := len(compressedPayload) + // try compression + if size > CompressionMinSize { + c, err := compress(compressedPayload) + if err == nil { + compressedPayload = c + m.Flags |= Compressed + } else { + return err + } + } + } + } + m.compressedPayload = compressedPayload + return nil +} diff --git a/pkg/network/message_test.go b/pkg/network/message_test.go index 1ae2e9d50..151bf8178 100644 --- a/pkg/network/message_test.go +++ b/pkg/network/message_test.go @@ -1 +1,93 @@ package network + +import ( + "testing" + "time" + + "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/network/payload" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestEncodeDecodeVersion(t *testing.T) { + // message with tiny payload, shouldn't be compressed + expected := NewMessage(CMDVersion, &payload.Version{ + Magic: 1, + Version: 2, + Services: 1, + Timestamp: uint32(time.Now().UnixNano()), + Port: 1234, + Nonce: 987, + UserAgent: []byte{1, 2, 3}, + StartHeight: 123, + Relay: true, + }) + testserdes.EncodeDecode(t, expected, &Message{}) + uncompressed, err := testserdes.EncodeBinary(expected.Payload) + require.NoError(t, err) + require.Equal(t, len(expected.compressedPayload), len(uncompressed)) + + // large payload should be compressed + largeArray := make([]byte, CompressionMinSize) + for i := range largeArray { + largeArray[i] = byte(i) + } + expected.Payload.(*payload.Version).UserAgent = largeArray + testserdes.EncodeDecode(t, expected, &Message{}) + uncompressed, err = testserdes.EncodeBinary(expected.Payload) + require.NoError(t, err) + require.NotEqual(t, len(expected.compressedPayload), len(uncompressed)) +} + +func TestEncodeDecodeHeaders(t *testing.T) { + // shouldn't try to compress headers payload + headers := &payload.Headers{Hdrs: make([]*block.Header, CompressionMinSize)} + for i := range headers.Hdrs { + h := &block.Header{ + Base: block.Base{ + Index: uint32(i + 1), + Script: transaction.Witness{ + InvocationScript: []byte{0x0}, + VerificationScript: []byte{0x1}, + }, + }, + } + h.Hash() + headers.Hdrs[i] = h + } + expected := NewMessage(CMDHeaders, headers) + testserdes.EncodeDecode(t, expected, &Message{}) + uncompressed, err := testserdes.EncodeBinary(expected.Payload) + require.NoError(t, err) + require.Equal(t, len(expected.compressedPayload), len(uncompressed)) +} + +func TestEncodeDecodeGetAddr(t *testing.T) { + // NullPayload should be handled properly + expected := NewMessage(CMDGetAddr, payload.NewNullPayload()) + testserdes.EncodeDecode(t, expected, &Message{}) +} + +func TestEncodeDecodeNil(t *testing.T) { + // nil payload should be decoded into NullPayload + expected := NewMessage(CMDGetAddr, nil) + encoded, err := testserdes.Encode(expected) + require.NoError(t, err) + decoded := &Message{} + err = testserdes.Decode(encoded, decoded) + require.NoError(t, err) + require.Equal(t, NewMessage(CMDGetAddr, payload.NewNullPayload()), decoded) +} + +func TestEncodeDecodePing(t *testing.T) { + expected := NewMessage(CMDPing, payload.NewPing(123, 456)) + testserdes.EncodeDecode(t, expected, &Message{}) +} + +func TestEncodeDecodeInventory(t *testing.T) { + expected := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}})) + testserdes.EncodeDecode(t, expected, &Message{}) +}