Merge pull request #987 from nspcc-dev/neo3/protocol/compression

protocol: add payload compression
This commit is contained in:
Roman Khimov 2020-05-26 23:31:17 +03:00 committed by GitHub
commit 5733573d2e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 250 additions and 38 deletions

View file

@ -42,3 +42,37 @@ func DecodeBinary(data []byte, a io.Serializable) error {
a.DecodeBinary(r) a.DecodeBinary(r)
return r.Err return r.Err
} }
type encodable interface {
Encode(*io.BinWriter) error
Decode(*io.BinReader) error
}
// EncodeDecode checks if expected stays the same after
// serializing/deserializing via encodable methods.
func EncodeDecode(t *testing.T, expected, actual encodable) {
data, err := Encode(expected)
require.NoError(t, err)
require.NoError(t, Decode(data, actual))
require.Equal(t, expected, actual)
}
// Encode serializes a to a byte slice.
func Encode(a encodable) ([]byte, error) {
w := io.NewBufBinWriter()
err := a.Encode(w.BinWriter)
if err != nil {
return nil, err
}
return w.Bytes(), nil
}
// Decode deserializes a from a byte slice.
func Decode(data []byte, a encodable) error {
r := io.NewBinReaderFromBuf(data)
err := a.Decode(r)
if r.Err != nil {
return r.Err
}
return err
}

33
pkg/network/compress.go Normal file
View file

@ -0,0 +1,33 @@
package network
import (
"bytes"
"io"
"github.com/pierrec/lz4"
)
// compress compresses bytes using lz4.
func compress(source []byte) ([]byte, error) {
dest := new(bytes.Buffer)
w := lz4.NewWriter(dest)
_, err := io.Copy(w, bytes.NewReader(source))
if err != nil {
return nil, err
}
if w.Close() != nil {
return nil, err
}
return dest.Bytes(), nil
}
// decompress decompresses bytes using lz4.
func decompress(source []byte) ([]byte, error) {
dest := new(bytes.Buffer)
r := lz4.NewReader(bytes.NewReader(source))
_, err := io.Copy(dest, r)
if err != nil {
return nil, err
}
return dest.Bytes(), nil
}

View file

@ -1,6 +1,7 @@
package network package network
import ( import (
"errors"
"fmt" "fmt"
"github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/consensus"
@ -12,18 +13,37 @@ import (
//go:generate stringer -type=CommandType //go:generate stringer -type=CommandType
const (
// PayloadMaxSize is maximum payload size in decompressed form.
PayloadMaxSize = 0x02000000
// CompressionMinSize is the lower bound to apply compression.
CompressionMinSize = 1024
)
// Message is the complete message send between nodes. // Message is the complete message send between nodes.
type Message struct { type Message struct {
// Flags that represents whether a message is compressed.
// 0 for None, 1 for Compressed.
Flags MessageFlag
// Command is byte command code. // Command is byte command code.
Command CommandType Command CommandType
// Length of the payload.
Length uint32
// Payload send with the message. // Payload send with the message.
Payload payload.Payload Payload payload.Payload
// Compressed message payload.
compressedPayload []byte
} }
// MessageFlag represents compression level of message payload
type MessageFlag byte
// Possible message flags
const (
None MessageFlag = 0
Compressed MessageFlag = 1 << iota
)
// CommandType represents the type of a message command. // CommandType represents the type of a message command.
type CommandType byte type CommandType byte
@ -65,47 +85,45 @@ const (
// NewMessage returns a new message with the given payload. // NewMessage returns a new message with the given payload.
func NewMessage(cmd CommandType, p payload.Payload) *Message { func NewMessage(cmd CommandType, p payload.Payload) *Message {
var (
size uint32
)
if p != nil {
buf := io.NewBufBinWriter()
p.EncodeBinary(buf.BinWriter)
if buf.Err != nil {
panic(buf.Err)
}
b := buf.Bytes()
size = uint32(len(b))
}
return &Message{ return &Message{
Command: cmd, Command: cmd,
Length: size,
Payload: p, Payload: p,
Flags: None,
} }
} }
// Decode decodes a Message from the given reader. // Decode decodes a Message from the given reader.
func (m *Message) Decode(br *io.BinReader) error { func (m *Message) Decode(br *io.BinReader) error {
m.Flags = MessageFlag(br.ReadB())
m.Command = CommandType(br.ReadB()) m.Command = CommandType(br.ReadB())
m.Length = br.ReadU32LE() l := br.ReadVarUint()
if br.Err != nil { // check the length first in order not to allocate memory
return br.Err // for an empty compressed payload
} if l == 0 {
// return if their is no payload. m.Payload = payload.NewNullPayload()
if m.Length == 0 {
return nil return nil
} }
return m.decodePayload(br) m.compressedPayload = make([]byte, l)
} br.ReadBytes(m.compressedPayload)
func (m *Message) decodePayload(br *io.BinReader) error {
buf := make([]byte, m.Length)
br.ReadBytes(buf)
if br.Err != nil { if br.Err != nil {
return br.Err return br.Err
} }
if len(m.compressedPayload) > PayloadMaxSize {
return errors.New("invalid payload size")
}
return m.decodePayload()
}
func (m *Message) decodePayload() error {
buf := m.compressedPayload
// try decompression
if m.Flags&Compressed != 0 {
d, err := decompress(m.compressedPayload)
if err != nil {
return err
}
buf = d
}
r := io.NewBinReaderFromBuf(buf) r := io.NewBinReaderFromBuf(buf)
var p payload.Payload var p payload.Payload
@ -147,16 +165,17 @@ func (m *Message) decodePayload(br *io.BinReader) error {
// Encode encodes a Message to any given BinWriter. // Encode encodes a Message to any given BinWriter.
func (m *Message) Encode(br *io.BinWriter) error { func (m *Message) Encode(br *io.BinWriter) error {
if err := m.tryCompressPayload(); err != nil {
return err
}
br.WriteB(byte(m.Flags))
br.WriteB(byte(m.Command)) br.WriteB(byte(m.Command))
br.WriteU32LE(m.Length) if m.compressedPayload != nil {
if m.Payload != nil { br.WriteVarBytes(m.compressedPayload)
m.Payload.EncodeBinary(br) } else {
br.WriteB(0)
} }
if br.Err != nil { return br.Err
return br.Err
}
return nil
} }
// Bytes serializes a Message into the new allocated buffer and returns it. // Bytes serializes a Message into the new allocated buffer and returns it.
@ -170,3 +189,37 @@ func (m *Message) Bytes() ([]byte, error) {
} }
return w.Bytes(), nil return w.Bytes(), nil
} }
// tryCompressPayload sets message's compressed payload to serialized
// payload and compresses it in case if its size exceeds CompressionMinSize
func (m *Message) tryCompressPayload() error {
if m.Payload == nil {
return nil
}
buf := io.NewBufBinWriter()
m.Payload.EncodeBinary(buf.BinWriter)
if buf.Err != nil {
return buf.Err
}
compressedPayload := buf.Bytes()
if m.Flags&Compressed == 0 {
switch m.Payload.(type) {
case *payload.Headers, *payload.MerkleBlock, *payload.NullPayload:
break
default:
size := len(compressedPayload)
// try compression
if size > CompressionMinSize {
c, err := compress(compressedPayload)
if err == nil {
compressedPayload = c
m.Flags |= Compressed
} else {
return err
}
}
}
}
m.compressedPayload = compressedPayload
return nil
}

View file

@ -1 +1,93 @@
package network package network
import (
"testing"
"time"
"github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/require"
)
func TestEncodeDecodeVersion(t *testing.T) {
// message with tiny payload, shouldn't be compressed
expected := NewMessage(CMDVersion, &payload.Version{
Magic: 1,
Version: 2,
Services: 1,
Timestamp: uint32(time.Now().UnixNano()),
Port: 1234,
Nonce: 987,
UserAgent: []byte{1, 2, 3},
StartHeight: 123,
Relay: true,
})
testserdes.EncodeDecode(t, expected, &Message{})
uncompressed, err := testserdes.EncodeBinary(expected.Payload)
require.NoError(t, err)
require.Equal(t, len(expected.compressedPayload), len(uncompressed))
// large payload should be compressed
largeArray := make([]byte, CompressionMinSize)
for i := range largeArray {
largeArray[i] = byte(i)
}
expected.Payload.(*payload.Version).UserAgent = largeArray
testserdes.EncodeDecode(t, expected, &Message{})
uncompressed, err = testserdes.EncodeBinary(expected.Payload)
require.NoError(t, err)
require.NotEqual(t, len(expected.compressedPayload), len(uncompressed))
}
func TestEncodeDecodeHeaders(t *testing.T) {
// shouldn't try to compress headers payload
headers := &payload.Headers{Hdrs: make([]*block.Header, CompressionMinSize)}
for i := range headers.Hdrs {
h := &block.Header{
Base: block.Base{
Index: uint32(i + 1),
Script: transaction.Witness{
InvocationScript: []byte{0x0},
VerificationScript: []byte{0x1},
},
},
}
h.Hash()
headers.Hdrs[i] = h
}
expected := NewMessage(CMDHeaders, headers)
testserdes.EncodeDecode(t, expected, &Message{})
uncompressed, err := testserdes.EncodeBinary(expected.Payload)
require.NoError(t, err)
require.Equal(t, len(expected.compressedPayload), len(uncompressed))
}
func TestEncodeDecodeGetAddr(t *testing.T) {
// NullPayload should be handled properly
expected := NewMessage(CMDGetAddr, payload.NewNullPayload())
testserdes.EncodeDecode(t, expected, &Message{})
}
func TestEncodeDecodeNil(t *testing.T) {
// nil payload should be decoded into NullPayload
expected := NewMessage(CMDGetAddr, nil)
encoded, err := testserdes.Encode(expected)
require.NoError(t, err)
decoded := &Message{}
err = testserdes.Decode(encoded, decoded)
require.NoError(t, err)
require.Equal(t, NewMessage(CMDGetAddr, payload.NewNullPayload()), decoded)
}
func TestEncodeDecodePing(t *testing.T) {
expected := NewMessage(CMDPing, payload.NewPing(123, 456))
testserdes.EncodeDecode(t, expected, &Message{})
}
func TestEncodeDecodeInventory(t *testing.T) {
expected := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}}))
testserdes.EncodeDecode(t, expected, &Message{})
}