Merge pull request #1591 from nspcc-dev/tests/network

network: add tests for `Message` serialization
This commit is contained in:
Roman Khimov 2020-12-04 22:40:15 +03:00 committed by GitHub
commit c6dbdddba9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 282 additions and 11 deletions

View file

@ -159,7 +159,7 @@ func (m *Message) decodePayload() error {
case CMDTX: case CMDTX:
p = &transaction.Transaction{Network: m.Network} p = &transaction.Transaction{Network: m.Network}
case CMDMerkleBlock: case CMDMerkleBlock:
p = &payload.MerkleBlock{} p = &payload.MerkleBlock{Network: m.Network}
case CMDPing, CMDPong: case CMDPing, CMDPong:
p = &payload.Ping{} p = &payload.Ping{}
case CMDNotFound: case CMDNotFound:
@ -196,9 +196,6 @@ func (m *Message) Bytes() ([]byte, error) {
if err := m.Encode(w.BinWriter); err != nil { if err := m.Encode(w.BinWriter); err != nil {
return nil, err return nil, err
} }
if w.Err != nil {
return nil, w.Err
}
return w.Bytes(), nil return w.Bytes(), nil
} }

View file

@ -1,12 +1,17 @@
package network package network
import ( import (
"errors"
"math/rand"
"testing" "testing"
"time" "time"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/network/capability" "github.com/nspcc-dev/neo-go/pkg/network/capability"
"github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -72,8 +77,7 @@ func TestEncodeDecodeHeaders(t *testing.T) {
func TestEncodeDecodeGetAddr(t *testing.T) { func TestEncodeDecodeGetAddr(t *testing.T) {
// NullPayload should be handled properly // NullPayload should be handled properly
expected := NewMessage(CMDGetAddr, payload.NewNullPayload()) testEncodeDecode(t, CMDGetAddr, payload.NewNullPayload())
testserdes.EncodeDecode(t, expected, &Message{})
} }
func TestEncodeDecodeNil(t *testing.T) { func TestEncodeDecodeNil(t *testing.T) {
@ -88,11 +92,239 @@ func TestEncodeDecodeNil(t *testing.T) {
} }
func TestEncodeDecodePing(t *testing.T) { func TestEncodeDecodePing(t *testing.T) {
expected := NewMessage(CMDPing, payload.NewPing(123, 456)) testEncodeDecode(t, CMDPing, payload.NewPing(123, 456))
testserdes.EncodeDecode(t, expected, &Message{})
} }
func TestEncodeDecodeInventory(t *testing.T) { func TestEncodeDecodeInventory(t *testing.T) {
expected := NewMessage(CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}})) testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ConsensusType, []util.Uint256{{1, 2, 3}}))
testserdes.EncodeDecode(t, expected, &Message{}) }
func TestEncodeDecodeAddr(t *testing.T) {
const count = 3
p := payload.NewAddressList(count)
p.Addrs[0] = &payload.AddressAndTime{
Timestamp: rand.Uint32(),
Capabilities: capability.Capabilities{{
Type: capability.FullNode,
Data: &capability.Node{StartHeight: rand.Uint32()},
}},
}
p.Addrs[1] = &payload.AddressAndTime{
Timestamp: rand.Uint32(),
Capabilities: capability.Capabilities{{
Type: capability.TCPServer,
Data: &capability.Server{Port: uint16(rand.Uint32())},
}},
}
p.Addrs[2] = &payload.AddressAndTime{
Timestamp: rand.Uint32(),
Capabilities: capability.Capabilities{{
Type: capability.WSServer,
Data: &capability.Server{Port: uint16(rand.Uint32())},
}},
}
testEncodeDecode(t, CMDAddr, p)
}
func TestEncodeDecodeBlock(t *testing.T) {
t.Run("good", func(t *testing.T) {
testEncodeDecode(t, CMDBlock, newDummyBlock(1))
})
t.Run("invalid state root enabled setting", func(t *testing.T) {
expected := NewMessage(CMDBlock, newDummyBlock(1))
expected.Network = netmode.UnitTestNet
data, err := testserdes.Encode(expected)
require.NoError(t, err)
require.Error(t, testserdes.Decode(data, &Message{Network: netmode.UnitTestNet, StateRootInHeader: true}))
})
}
func TestEncodeDecodeGetBlock(t *testing.T) {
t.Run("good, Count>0", func(t *testing.T) {
testEncodeDecode(t, CMDGetBlocks, &payload.GetBlocks{
HashStart: random.Uint256(),
Count: int16(rand.Uint32() >> 17),
})
})
t.Run("good, Count=-1", func(t *testing.T) {
testEncodeDecode(t, CMDGetBlocks, &payload.GetBlocks{
HashStart: random.Uint256(),
Count: -1,
})
})
t.Run("bad, Count=-2", func(t *testing.T) {
testEncodeDecodeFail(t, CMDGetBlocks, &payload.GetBlocks{
HashStart: random.Uint256(),
Count: -2,
})
})
}
func TestEnodeDecodeGetHeaders(t *testing.T) {
testEncodeDecode(t, CMDGetHeaders, &payload.GetBlockByIndex{
IndexStart: rand.Uint32(),
Count: payload.MaxHeadersAllowed,
})
}
func TestEncodeDecodeGetBlockByIndex(t *testing.T) {
t.Run("good, Count>0", func(t *testing.T) {
testEncodeDecode(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
IndexStart: rand.Uint32(),
Count: payload.MaxHeadersAllowed,
})
})
t.Run("bad, Count too big", func(t *testing.T) {
testEncodeDecodeFail(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
IndexStart: rand.Uint32(),
Count: payload.MaxHeadersAllowed + 1,
})
})
t.Run("good, Count=-1", func(t *testing.T) {
testEncodeDecode(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
IndexStart: rand.Uint32(),
Count: -1,
})
})
t.Run("bad, Count=-2", func(t *testing.T) {
testEncodeDecodeFail(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
IndexStart: rand.Uint32(),
Count: -2,
})
})
}
func TestEncodeDecodeTransaction(t *testing.T) {
testEncodeDecode(t, CMDTX, newDummyTx())
}
func TestEncodeDecodeMerkleBlock(t *testing.T) {
base := &block.Base{
PrevHash: random.Uint256(),
Timestamp: rand.Uint64(),
Script: transaction.Witness{
InvocationScript: random.Bytes(10),
VerificationScript: random.Bytes(11),
},
Network: netmode.UnitTestNet,
}
base.Hash()
t.Run("good", func(t *testing.T) {
testEncodeDecode(t, CMDMerkleBlock, &payload.MerkleBlock{
Network: netmode.UnitTestNet,
Base: base,
TxCount: 1,
Hashes: []util.Uint256{random.Uint256()},
Flags: []byte{0},
})
})
t.Run("bad, invalid TxCount", func(t *testing.T) {
testEncodeDecodeFail(t, CMDMerkleBlock, &payload.MerkleBlock{
Base: base,
TxCount: 2,
Hashes: []util.Uint256{random.Uint256()},
Flags: []byte{0},
})
})
}
func TestEncodeDecodeNotFound(t *testing.T) {
testEncodeDecode(t, CMDNotFound, &payload.Inventory{
Type: payload.TXType,
Hashes: []util.Uint256{random.Uint256()},
})
}
func TestInvalidMessages(t *testing.T) {
t.Run("CMDBlock, empty payload", func(t *testing.T) {
testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
})
t.Run("send decompressed with flag", func(t *testing.T) {
m := NewMessage(CMDTX, newDummyTx())
data, err := testserdes.Encode(m)
require.NoError(t, err)
require.True(t, m.Flags&Compressed == 0)
data[0] |= byte(Compressed)
require.Error(t, testserdes.Decode(data, &Message{Network: netmode.UnitTestNet}))
})
t.Run("invalid command", func(t *testing.T) {
testEncodeDecodeFail(t, CommandType(0xFF), &payload.Version{Magic: netmode.UnitTestNet})
})
t.Run("very big payload size", func(t *testing.T) {
m := NewMessage(CMDBlock, nil)
w := io.NewBufBinWriter()
w.WriteB(byte(m.Flags))
w.WriteB(byte(m.Command))
w.WriteVarBytes(make([]byte, payload.MaxSize+1))
require.NoError(t, w.Err)
require.Error(t, testserdes.Decode(w.Bytes(), &Message{Network: netmode.UnitTestNet}))
})
t.Run("fail to encode message if payload can't be serialized", func(t *testing.T) {
m := NewMessage(CMDBlock, failSer(true))
_, err := m.Bytes()
require.Error(t, err)
// good otherwise
m = NewMessage(CMDBlock, failSer(false))
_, err = m.Bytes()
require.NoError(t, err)
})
t.Run("trimmed payload", func(t *testing.T) {
m := NewMessage(CMDBlock, newDummyBlock(0))
data, err := testserdes.Encode(m)
require.NoError(t, err)
data = data[:len(data)-1]
require.Error(t, testserdes.Decode(data, &Message{Network: netmode.UnitTestNet}))
})
}
type failSer bool
func (f failSer) EncodeBinary(r *io.BinWriter) {
if f {
r.Err = errors.New("unserializable payload")
}
}
func (failSer) DecodeBinary(w *io.BinReader) {}
func newDummyBlock(txCount int) *block.Block {
b := block.New(netmode.UnitTestNet, false)
b.PrevHash = random.Uint256()
b.Timestamp = rand.Uint64()
b.Script.InvocationScript = random.Bytes(2)
b.Script.VerificationScript = random.Bytes(3)
b.Transactions = make([]*transaction.Transaction, txCount)
for i := range b.Transactions {
b.Transactions[i] = newDummyTx()
}
b.Hash()
return b
}
func newDummyTx() *transaction.Transaction {
tx := transaction.New(netmode.UnitTestNet, random.Bytes(100), int64(rand.Uint64()>>1))
tx.Signers = []transaction.Signer{{Account: random.Uint160()}}
tx.Size()
tx.Hash()
return tx
}
func testEncodeDecode(t *testing.T, cmd CommandType, p payload.Payload) *Message {
expected := NewMessage(cmd, p)
expected.Network = netmode.UnitTestNet
actual := &Message{Network: netmode.UnitTestNet}
testserdes.EncodeDecode(t, expected, actual)
return actual
}
func testEncodeDecodeFail(t *testing.T, cmd CommandType, p payload.Payload) *Message {
expected := NewMessage(cmd, p)
expected.Network = netmode.UnitTestNet
data, err := testserdes.Encode(expected)
require.NoError(t, err)
actual := &Message{Network: netmode.UnitTestNet}
require.Error(t, testserdes.Decode(data, actual))
return actual
} }

View file

@ -72,3 +72,22 @@ func TestEncodeDecodeBadAddressList(t *testing.T) {
err = testserdes.DecodeBinary(bin, newAL) err = testserdes.DecodeBinary(bin, newAL)
require.Error(t, err) require.Error(t, err)
} }
func TestGetTCPAddress(t *testing.T) {
t.Run("bad, no capability", func(t *testing.T) {
p := &AddressAndTime{}
copy(p.IP[:], net.IPv4(1, 1, 1, 1))
p.Capabilities = append(p.Capabilities, capability.Capability{
Type: capability.TCPServer,
Data: &capability.Server{Port: 123},
})
s, err := p.GetTCPAddress()
require.NoError(t, err)
require.Equal(t, "1.1.1.1:123", s)
})
t.Run("bad, no capability", func(t *testing.T) {
p := &AddressAndTime{}
s, err := p.GetTCPAddress()
fmt.Println(s, err)
})
}

View file

@ -1,12 +1,14 @@
package payload package payload
import ( import (
"strings"
"testing" "testing"
"github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/hash"
. "github.com/nspcc-dev/neo-go/pkg/util" . "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestInventoryEncodeDecode(t *testing.T) { func TestInventoryEncodeDecode(t *testing.T) {
@ -27,3 +29,17 @@ func TestEmptyInv(t *testing.T) {
assert.Equal(t, []byte{byte(TXType), 0}, data) assert.Equal(t, []byte{byte(TXType), 0}, data)
assert.Equal(t, 0, len(msgInv.Hashes)) assert.Equal(t, 0, len(msgInv.Hashes))
} }
func TestValid(t *testing.T) {
require.True(t, TXType.Valid())
require.True(t, BlockType.Valid())
require.True(t, ConsensusType.Valid())
require.False(t, InventoryType(0xFF).Valid())
}
func TestString(t *testing.T) {
require.Equal(t, "TX", TXType.String())
require.Equal(t, "block", BlockType.String())
require.Equal(t, "consensus", ConsensusType.String())
require.True(t, strings.Contains(InventoryType(0xFF).String(), "unknown"))
}

View file

@ -1,6 +1,9 @@
package payload package payload
import ( import (
"errors"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -9,6 +12,7 @@ import (
// MerkleBlock represents a merkle block packet payload. // MerkleBlock represents a merkle block packet payload.
type MerkleBlock struct { type MerkleBlock struct {
*block.Base *block.Base
Network netmode.Magic
TxCount int TxCount int
Hashes []util.Uint256 Hashes []util.Uint256
Flags []byte Flags []byte
@ -16,7 +20,7 @@ type MerkleBlock struct {
// DecodeBinary implements Serializable interface. // DecodeBinary implements Serializable interface.
func (m *MerkleBlock) DecodeBinary(br *io.BinReader) { func (m *MerkleBlock) DecodeBinary(br *io.BinReader) {
m.Base = &block.Base{} m.Base = &block.Base{Network: m.Network}
m.Base.DecodeBinary(br) m.Base.DecodeBinary(br)
txCount := int(br.ReadVarUint()) txCount := int(br.ReadVarUint())
@ -26,6 +30,9 @@ func (m *MerkleBlock) DecodeBinary(br *io.BinReader) {
} }
m.TxCount = txCount m.TxCount = txCount
br.ReadArray(&m.Hashes, m.TxCount) br.ReadArray(&m.Hashes, m.TxCount)
if txCount != len(m.Hashes) {
br.Err = errors.New("invalid tx count")
}
m.Flags = br.ReadVarBytes((txCount + 7) / 8) m.Flags = br.ReadVarBytes((txCount + 7) / 8)
} }