parent
bd98940a54
commit
0cbd7823ab
4 changed files with 250 additions and 38 deletions
|
@ -42,3 +42,37 @@ func DecodeBinary(data []byte, a io.Serializable) error {
|
|||
a.DecodeBinary(r)
|
||||
return r.Err
|
||||
}
|
||||
|
||||
type encodable interface {
|
||||
Encode(*io.BinWriter) error
|
||||
Decode(*io.BinReader) error
|
||||
}
|
||||
|
||||
// EncodeDecode checks if expected stays the same after
|
||||
// serializing/deserializing via encodable methods.
|
||||
func EncodeDecode(t *testing.T, expected, actual encodable) {
|
||||
data, err := Encode(expected)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, Decode(data, actual))
|
||||
require.Equal(t, expected, actual)
|
||||
}
|
||||
|
||||
// Encode serializes a to a byte slice.
|
||||
func Encode(a encodable) ([]byte, error) {
|
||||
w := io.NewBufBinWriter()
|
||||
err := a.Encode(w.BinWriter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return w.Bytes(), nil
|
||||
}
|
||||
|
||||
// Decode deserializes a from a byte slice.
|
||||
func Decode(data []byte, a encodable) error {
|
||||
r := io.NewBinReaderFromBuf(data)
|
||||
err := a.Decode(r)
|
||||
if r.Err != nil {
|
||||
return r.Err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
33
pkg/network/compress.go
Normal file
33
pkg/network/compress.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/pierrec/lz4"
|
||||
)
|
||||
|
||||
// compress compresses bytes using lz4.
|
||||
func compress(source []byte) ([]byte, error) {
|
||||
dest := new(bytes.Buffer)
|
||||
w := lz4.NewWriter(dest)
|
||||
_, err := io.Copy(w, bytes.NewReader(source))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if w.Close() != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dest.Bytes(), nil
|
||||
}
|
||||
|
||||
// decompress decompresses bytes using lz4.
|
||||
func decompress(source []byte) ([]byte, error) {
|
||||
dest := new(bytes.Buffer)
|
||||
r := lz4.NewReader(bytes.NewReader(source))
|
||||
_, err := io.Copy(dest, r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dest.Bytes(), nil
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/consensus"
|
||||
|
@ -12,18 +13,37 @@ import (
|
|||
|
||||
//go:generate stringer -type=CommandType
|
||||
|
||||
const (
|
||||
// PayloadMaxSize is maximum payload size in decompressed form.
|
||||
PayloadMaxSize = 0x02000000
|
||||
// CompressionMinSize is the lower bound to apply compression.
|
||||
CompressionMinSize = 1024
|
||||
)
|
||||
|
||||
// Message is the complete message send between nodes.
|
||||
type Message struct {
|
||||
// Flags that represents whether a message is compressed.
|
||||
// 0 for None, 1 for Compressed.
|
||||
Flags MessageFlag
|
||||
// Command is byte command code.
|
||||
Command CommandType
|
||||
|
||||
// Length of the payload.
|
||||
Length uint32
|
||||
|
||||
// Payload send with the message.
|
||||
Payload payload.Payload
|
||||
|
||||
// Compressed message payload.
|
||||
compressedPayload []byte
|
||||
}
|
||||
|
||||
// MessageFlag represents compression level of message payload
|
||||
type MessageFlag byte
|
||||
|
||||
// Possible message flags
|
||||
const (
|
||||
None MessageFlag = 0
|
||||
Compressed MessageFlag = 1 << iota
|
||||
)
|
||||
|
||||
// CommandType represents the type of a message command.
|
||||
type CommandType byte
|
||||
|
||||
|
@ -65,47 +85,45 @@ const (
|
|||
|
||||
// NewMessage returns a new message with the given payload.
|
||||
func NewMessage(cmd CommandType, p payload.Payload) *Message {
|
||||
var (
|
||||
size uint32
|
||||
)
|
||||
|
||||
if p != nil {
|
||||
buf := io.NewBufBinWriter()
|
||||
p.EncodeBinary(buf.BinWriter)
|
||||
if buf.Err != nil {
|
||||
panic(buf.Err)
|
||||
}
|
||||
b := buf.Bytes()
|
||||
size = uint32(len(b))
|
||||
}
|
||||
|
||||
return &Message{
|
||||
Command: cmd,
|
||||
Length: size,
|
||||
Payload: p,
|
||||
Flags: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes a Message from the given reader.
|
||||
func (m *Message) Decode(br *io.BinReader) error {
|
||||
m.Flags = MessageFlag(br.ReadB())
|
||||
m.Command = CommandType(br.ReadB())
|
||||
m.Length = br.ReadU32LE()
|
||||
if br.Err != nil {
|
||||
return br.Err
|
||||
}
|
||||
// return if their is no payload.
|
||||
if m.Length == 0 {
|
||||
l := br.ReadVarUint()
|
||||
// check the length first in order not to allocate memory
|
||||
// for an empty compressed payload
|
||||
if l == 0 {
|
||||
m.Payload = payload.NewNullPayload()
|
||||
return nil
|
||||
}
|
||||
return m.decodePayload(br)
|
||||
}
|
||||
|
||||
func (m *Message) decodePayload(br *io.BinReader) error {
|
||||
buf := make([]byte, m.Length)
|
||||
br.ReadBytes(buf)
|
||||
m.compressedPayload = make([]byte, l)
|
||||
br.ReadBytes(m.compressedPayload)
|
||||
if br.Err != nil {
|
||||
return br.Err
|
||||
}
|
||||
if len(m.compressedPayload) > PayloadMaxSize {
|
||||
return errors.New("invalid payload size")
|
||||
}
|
||||
return m.decodePayload()
|
||||
}
|
||||
|
||||
func (m *Message) decodePayload() error {
|
||||
buf := m.compressedPayload
|
||||
// try decompression
|
||||
if m.Flags&Compressed != 0 {
|
||||
d, err := decompress(m.compressedPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf = d
|
||||
}
|
||||
|
||||
r := io.NewBinReaderFromBuf(buf)
|
||||
var p payload.Payload
|
||||
|
@ -147,17 +165,18 @@ func (m *Message) decodePayload(br *io.BinReader) error {
|
|||
|
||||
// Encode encodes a Message to any given BinWriter.
|
||||
func (m *Message) Encode(br *io.BinWriter) error {
|
||||
if err := m.tryCompressPayload(); err != nil {
|
||||
return err
|
||||
}
|
||||
br.WriteB(byte(m.Flags))
|
||||
br.WriteB(byte(m.Command))
|
||||
br.WriteU32LE(m.Length)
|
||||
if m.Payload != nil {
|
||||
m.Payload.EncodeBinary(br)
|
||||
|
||||
if m.compressedPayload != nil {
|
||||
br.WriteVarBytes(m.compressedPayload)
|
||||
} else {
|
||||
br.WriteB(0)
|
||||
}
|
||||
if br.Err != nil {
|
||||
return br.Err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bytes serializes a Message into the new allocated buffer and returns it.
|
||||
func (m *Message) Bytes() ([]byte, error) {
|
||||
|
@ -170,3 +189,37 @@ func (m *Message) Bytes() ([]byte, error) {
|
|||
}
|
||||
return w.Bytes(), nil
|
||||
}
|
||||
|
||||
// tryCompressPayload sets message's compressed payload to serialized
|
||||
// payload and compresses it in case if its size exceeds CompressionMinSize
|
||||
func (m *Message) tryCompressPayload() error {
|
||||
if m.Payload == nil {
|
||||
return nil
|
||||
}
|
||||
buf := io.NewBufBinWriter()
|
||||
m.Payload.EncodeBinary(buf.BinWriter)
|
||||
if buf.Err != nil {
|
||||
return buf.Err
|
||||
}
|
||||
compressedPayload := buf.Bytes()
|
||||
if m.Flags&Compressed == 0 {
|
||||
switch m.Payload.(type) {
|
||||
case *payload.Headers, *payload.MerkleBlock, *payload.NullPayload:
|
||||
break
|
||||
default:
|
||||
size := len(compressedPayload)
|
||||
// try compression
|
||||
if size > CompressionMinSize {
|
||||
c, err := compress(compressedPayload)
|
||||
if err == nil {
|
||||
compressedPayload = c
|
||||
m.Flags |= Compressed
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
m.compressedPayload = compressedPayload
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1 +1,93 @@
|
|||
package network
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEncodeDecodeVersion(t *testing.T) {
|
||||
// message with tiny payload, shouldn't be compressed
|
||||
expected := NewMessage(CMDVersion, &payload.Version{
|
||||
Magic: 1,
|
||||
Version: 2,
|
||||
Services: 1,
|
||||
Timestamp: uint32(time.Now().UnixNano()),
|
||||
Port: 1234,
|
||||
Nonce: 987,
|
||||
UserAgent: []byte{1, 2, 3},
|
||||
StartHeight: 123,
|
||||
Relay: true,
|
||||
})
|
||||
testserdes.EncodeDecode(t, expected, &Message{})
|
||||
uncompressed, err := testserdes.EncodeBinary(expected.Payload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(expected.compressedPayload), len(uncompressed))
|
||||
|
||||
// large payload should be compressed
|
||||
largeArray := make([]byte, CompressionMinSize)
|
||||
for i := range largeArray {
|
||||
largeArray[i] = byte(i)
|
||||
}
|
||||
expected.Payload.(*payload.Version).UserAgent = largeArray
|
||||
testserdes.EncodeDecode(t, expected, &Message{})
|
||||
uncompressed, err = testserdes.EncodeBinary(expected.Payload)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, len(expected.compressedPayload), len(uncompressed))
|
||||
}
|
||||
|
||||
func TestEncodeDecodeHeaders(t *testing.T) {
|
||||
// shouldn't try to compress headers payload
|
||||
headers := &payload.Headers{Hdrs: make([]*block.Header, CompressionMinSize)}
|
||||
for i := range headers.Hdrs {
|
||||
h := &block.Header{
|
||||
Base: block.Base{
|
||||
Index: uint32(i + 1),
|
||||
Script: transaction.Witness{
|
||||
InvocationScript: []byte{0x0},
|
||||
VerificationScript: []byte{0x1},
|
||||
},
|
||||
},
|
||||
}
|
||||
h.Hash()
|
||||
headers.Hdrs[i] = h
|
||||
}
|
||||
expected := NewMessage(CMDHeaders, headers)
|
||||
testserdes.EncodeDecode(t, expected, &Message{})
|
||||
uncompressed, err := testserdes.EncodeBinary(expected.Payload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(expected.compressedPayload), len(uncompressed))
|
||||
}
|
||||
|
||||
func TestEncodeDecodeGetAddr(t *testing.T) {
|
||||
// NullPayload should be handled properly
|
||||
expected := NewMessage(CMDGetAddr, payload.NewNullPayload())
|
||||
testserdes.EncodeDecode(t, expected, &Message{})
|
||||
}
|
||||
|
||||
func TestEncodeDecodeNil(t *testing.T) {
|
||||
// nil payload should be decoded into NullPayload
|
||||
expected := NewMessage(CMDGetAddr, nil)
|
||||
encoded, err := testserdes.Encode(expected)
|
||||
require.NoError(t, err)
|
||||
decoded := &Message{}
|
||||
err = testserdes.Decode(encoded, decoded)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, NewMessage(CMDGetAddr, payload.NewNullPayload()), decoded)
|
||||
}
|
||||
|
||||
func TestEncodeDecodePing(t *testing.T) {
|
||||
expected := NewMessage(CMDPing, payload.NewPing(123, 456))
|
||||
testserdes.EncodeDecode(t, expected, &Message{})
|
||||
}
|
||||
|
||||
func TestEncodeDecodeInventory(t *testing.T) {
|
||||
expected := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}}))
|
||||
testserdes.EncodeDecode(t, expected, &Message{})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue