forked from TrueCloudLab/neoneo-go
parent
bd98940a54
commit
0cbd7823ab
4 changed files with 250 additions and 38 deletions
|
@ -42,3 +42,37 @@ func DecodeBinary(data []byte, a io.Serializable) error {
|
||||||
a.DecodeBinary(r)
|
a.DecodeBinary(r)
|
||||||
return r.Err
|
return r.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type encodable interface {
|
||||||
|
Encode(*io.BinWriter) error
|
||||||
|
Decode(*io.BinReader) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeDecode checks if expected stays the same after
|
||||||
|
// serializing/deserializing via encodable methods.
|
||||||
|
func EncodeDecode(t *testing.T, expected, actual encodable) {
|
||||||
|
data, err := Encode(expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, Decode(data, actual))
|
||||||
|
require.Equal(t, expected, actual)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode serializes a to a byte slice.
|
||||||
|
func Encode(a encodable) ([]byte, error) {
|
||||||
|
w := io.NewBufBinWriter()
|
||||||
|
err := a.Encode(w.BinWriter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return w.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode deserializes a from a byte slice.
|
||||||
|
func Decode(data []byte, a encodable) error {
|
||||||
|
r := io.NewBinReaderFromBuf(data)
|
||||||
|
err := a.Decode(r)
|
||||||
|
if r.Err != nil {
|
||||||
|
return r.Err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
33
pkg/network/compress.go
Normal file
33
pkg/network/compress.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/pierrec/lz4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// compress compresses bytes using lz4.
|
||||||
|
func compress(source []byte) ([]byte, error) {
|
||||||
|
dest := new(bytes.Buffer)
|
||||||
|
w := lz4.NewWriter(dest)
|
||||||
|
_, err := io.Copy(w, bytes.NewReader(source))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if w.Close() != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dest.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompress decompresses bytes using lz4.
|
||||||
|
func decompress(source []byte) ([]byte, error) {
|
||||||
|
dest := new(bytes.Buffer)
|
||||||
|
r := lz4.NewReader(bytes.NewReader(source))
|
||||||
|
_, err := io.Copy(dest, r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dest.Bytes(), nil
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package network
|
package network
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/consensus"
|
"github.com/nspcc-dev/neo-go/pkg/consensus"
|
||||||
|
@ -12,18 +13,37 @@ import (
|
||||||
|
|
||||||
//go:generate stringer -type=CommandType
|
//go:generate stringer -type=CommandType
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PayloadMaxSize is maximum payload size in decompressed form.
|
||||||
|
PayloadMaxSize = 0x02000000
|
||||||
|
// CompressionMinSize is the lower bound to apply compression.
|
||||||
|
CompressionMinSize = 1024
|
||||||
|
)
|
||||||
|
|
||||||
// Message is the complete message send between nodes.
|
// Message is the complete message send between nodes.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
|
// Flags that represents whether a message is compressed.
|
||||||
|
// 0 for None, 1 for Compressed.
|
||||||
|
Flags MessageFlag
|
||||||
// Command is byte command code.
|
// Command is byte command code.
|
||||||
Command CommandType
|
Command CommandType
|
||||||
|
|
||||||
// Length of the payload.
|
|
||||||
Length uint32
|
|
||||||
|
|
||||||
// Payload send with the message.
|
// Payload send with the message.
|
||||||
Payload payload.Payload
|
Payload payload.Payload
|
||||||
|
|
||||||
|
// Compressed message payload.
|
||||||
|
compressedPayload []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MessageFlag represents compression level of message payload
|
||||||
|
type MessageFlag byte
|
||||||
|
|
||||||
|
// Possible message flags
|
||||||
|
const (
|
||||||
|
None MessageFlag = 0
|
||||||
|
Compressed MessageFlag = 1 << iota
|
||||||
|
)
|
||||||
|
|
||||||
// CommandType represents the type of a message command.
|
// CommandType represents the type of a message command.
|
||||||
type CommandType byte
|
type CommandType byte
|
||||||
|
|
||||||
|
@ -65,47 +85,45 @@ const (
|
||||||
|
|
||||||
// NewMessage returns a new message with the given payload.
|
// NewMessage returns a new message with the given payload.
|
||||||
func NewMessage(cmd CommandType, p payload.Payload) *Message {
|
func NewMessage(cmd CommandType, p payload.Payload) *Message {
|
||||||
var (
|
|
||||||
size uint32
|
|
||||||
)
|
|
||||||
|
|
||||||
if p != nil {
|
|
||||||
buf := io.NewBufBinWriter()
|
|
||||||
p.EncodeBinary(buf.BinWriter)
|
|
||||||
if buf.Err != nil {
|
|
||||||
panic(buf.Err)
|
|
||||||
}
|
|
||||||
b := buf.Bytes()
|
|
||||||
size = uint32(len(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Message{
|
return &Message{
|
||||||
Command: cmd,
|
Command: cmd,
|
||||||
Length: size,
|
|
||||||
Payload: p,
|
Payload: p,
|
||||||
|
Flags: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode decodes a Message from the given reader.
|
// Decode decodes a Message from the given reader.
|
||||||
func (m *Message) Decode(br *io.BinReader) error {
|
func (m *Message) Decode(br *io.BinReader) error {
|
||||||
|
m.Flags = MessageFlag(br.ReadB())
|
||||||
m.Command = CommandType(br.ReadB())
|
m.Command = CommandType(br.ReadB())
|
||||||
m.Length = br.ReadU32LE()
|
l := br.ReadVarUint()
|
||||||
if br.Err != nil {
|
// check the length first in order not to allocate memory
|
||||||
return br.Err
|
// for an empty compressed payload
|
||||||
}
|
if l == 0 {
|
||||||
// return if their is no payload.
|
m.Payload = payload.NewNullPayload()
|
||||||
if m.Length == 0 {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.decodePayload(br)
|
m.compressedPayload = make([]byte, l)
|
||||||
}
|
br.ReadBytes(m.compressedPayload)
|
||||||
|
|
||||||
func (m *Message) decodePayload(br *io.BinReader) error {
|
|
||||||
buf := make([]byte, m.Length)
|
|
||||||
br.ReadBytes(buf)
|
|
||||||
if br.Err != nil {
|
if br.Err != nil {
|
||||||
return br.Err
|
return br.Err
|
||||||
}
|
}
|
||||||
|
if len(m.compressedPayload) > PayloadMaxSize {
|
||||||
|
return errors.New("invalid payload size")
|
||||||
|
}
|
||||||
|
return m.decodePayload()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Message) decodePayload() error {
|
||||||
|
buf := m.compressedPayload
|
||||||
|
// try decompression
|
||||||
|
if m.Flags&Compressed != 0 {
|
||||||
|
d, err := decompress(m.compressedPayload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buf = d
|
||||||
|
}
|
||||||
|
|
||||||
r := io.NewBinReaderFromBuf(buf)
|
r := io.NewBinReaderFromBuf(buf)
|
||||||
var p payload.Payload
|
var p payload.Payload
|
||||||
|
@ -147,16 +165,17 @@ func (m *Message) decodePayload(br *io.BinReader) error {
|
||||||
|
|
||||||
// Encode encodes a Message to any given BinWriter.
|
// Encode encodes a Message to any given BinWriter.
|
||||||
func (m *Message) Encode(br *io.BinWriter) error {
|
func (m *Message) Encode(br *io.BinWriter) error {
|
||||||
|
if err := m.tryCompressPayload(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
br.WriteB(byte(m.Flags))
|
||||||
br.WriteB(byte(m.Command))
|
br.WriteB(byte(m.Command))
|
||||||
br.WriteU32LE(m.Length)
|
if m.compressedPayload != nil {
|
||||||
if m.Payload != nil {
|
br.WriteVarBytes(m.compressedPayload)
|
||||||
m.Payload.EncodeBinary(br)
|
} else {
|
||||||
|
br.WriteB(0)
|
||||||
}
|
}
|
||||||
if br.Err != nil {
|
return br.Err
|
||||||
return br.Err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bytes serializes a Message into the new allocated buffer and returns it.
|
// Bytes serializes a Message into the new allocated buffer and returns it.
|
||||||
|
@ -170,3 +189,37 @@ func (m *Message) Bytes() ([]byte, error) {
|
||||||
}
|
}
|
||||||
return w.Bytes(), nil
|
return w.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tryCompressPayload sets message's compressed payload to serialized
|
||||||
|
// payload and compresses it in case if its size exceeds CompressionMinSize
|
||||||
|
func (m *Message) tryCompressPayload() error {
|
||||||
|
if m.Payload == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
buf := io.NewBufBinWriter()
|
||||||
|
m.Payload.EncodeBinary(buf.BinWriter)
|
||||||
|
if buf.Err != nil {
|
||||||
|
return buf.Err
|
||||||
|
}
|
||||||
|
compressedPayload := buf.Bytes()
|
||||||
|
if m.Flags&Compressed == 0 {
|
||||||
|
switch m.Payload.(type) {
|
||||||
|
case *payload.Headers, *payload.MerkleBlock, *payload.NullPayload:
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
size := len(compressedPayload)
|
||||||
|
// try compression
|
||||||
|
if size > CompressionMinSize {
|
||||||
|
c, err := compress(compressedPayload)
|
||||||
|
if err == nil {
|
||||||
|
compressedPayload = c
|
||||||
|
m.Flags |= Compressed
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.compressedPayload = compressedPayload
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1 +1,93 @@
|
||||||
package network
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/block"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/internal/testserdes"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/network/payload"
|
||||||
|
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEncodeDecodeVersion(t *testing.T) {
|
||||||
|
// message with tiny payload, shouldn't be compressed
|
||||||
|
expected := NewMessage(CMDVersion, &payload.Version{
|
||||||
|
Magic: 1,
|
||||||
|
Version: 2,
|
||||||
|
Services: 1,
|
||||||
|
Timestamp: uint32(time.Now().UnixNano()),
|
||||||
|
Port: 1234,
|
||||||
|
Nonce: 987,
|
||||||
|
UserAgent: []byte{1, 2, 3},
|
||||||
|
StartHeight: 123,
|
||||||
|
Relay: true,
|
||||||
|
})
|
||||||
|
testserdes.EncodeDecode(t, expected, &Message{})
|
||||||
|
uncompressed, err := testserdes.EncodeBinary(expected.Payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(expected.compressedPayload), len(uncompressed))
|
||||||
|
|
||||||
|
// large payload should be compressed
|
||||||
|
largeArray := make([]byte, CompressionMinSize)
|
||||||
|
for i := range largeArray {
|
||||||
|
largeArray[i] = byte(i)
|
||||||
|
}
|
||||||
|
expected.Payload.(*payload.Version).UserAgent = largeArray
|
||||||
|
testserdes.EncodeDecode(t, expected, &Message{})
|
||||||
|
uncompressed, err = testserdes.EncodeBinary(expected.Payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEqual(t, len(expected.compressedPayload), len(uncompressed))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeHeaders(t *testing.T) {
|
||||||
|
// shouldn't try to compress headers payload
|
||||||
|
headers := &payload.Headers{Hdrs: make([]*block.Header, CompressionMinSize)}
|
||||||
|
for i := range headers.Hdrs {
|
||||||
|
h := &block.Header{
|
||||||
|
Base: block.Base{
|
||||||
|
Index: uint32(i + 1),
|
||||||
|
Script: transaction.Witness{
|
||||||
|
InvocationScript: []byte{0x0},
|
||||||
|
VerificationScript: []byte{0x1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h.Hash()
|
||||||
|
headers.Hdrs[i] = h
|
||||||
|
}
|
||||||
|
expected := NewMessage(CMDHeaders, headers)
|
||||||
|
testserdes.EncodeDecode(t, expected, &Message{})
|
||||||
|
uncompressed, err := testserdes.EncodeBinary(expected.Payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, len(expected.compressedPayload), len(uncompressed))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeGetAddr(t *testing.T) {
|
||||||
|
// NullPayload should be handled properly
|
||||||
|
expected := NewMessage(CMDGetAddr, payload.NewNullPayload())
|
||||||
|
testserdes.EncodeDecode(t, expected, &Message{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeNil(t *testing.T) {
|
||||||
|
// nil payload should be decoded into NullPayload
|
||||||
|
expected := NewMessage(CMDGetAddr, nil)
|
||||||
|
encoded, err := testserdes.Encode(expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
decoded := &Message{}
|
||||||
|
err = testserdes.Decode(encoded, decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, NewMessage(CMDGetAddr, payload.NewNullPayload()), decoded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodePing(t *testing.T) {
|
||||||
|
expected := NewMessage(CMDPing, payload.NewPing(123, 456))
|
||||||
|
testserdes.EncodeDecode(t, expected, &Message{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeInventory(t *testing.T) {
|
||||||
|
expected := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}}))
|
||||||
|
testserdes.EncodeDecode(t, expected, &Message{})
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue