network: fix MerkleBlock serialization

1. It contains `block.Base` thus needs network magic.
2. TxCount should match number of hashes.
This commit is contained in:
Evgenii Stratonikov 2020-12-04 14:59:10 +03:00
parent 2d7b823f25
commit 63aebfeae3
2 changed files with 9 additions and 2 deletions

View file

@ -159,7 +159,7 @@ func (m *Message) decodePayload() error {
case CMDTX: case CMDTX:
p = &transaction.Transaction{Network: m.Network} p = &transaction.Transaction{Network: m.Network}
case CMDMerkleBlock: case CMDMerkleBlock:
p = &payload.MerkleBlock{} p = &payload.MerkleBlock{Network: m.Network}
case CMDPing, CMDPong: case CMDPing, CMDPong:
p = &payload.Ping{} p = &payload.Ping{}
case CMDNotFound: case CMDNotFound:

View file

@ -1,6 +1,9 @@
package payload package payload
import ( import (
"errors"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/block"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
@ -9,6 +12,7 @@ import (
// MerkleBlock represents a merkle block packet payload. // MerkleBlock represents a merkle block packet payload.
type MerkleBlock struct { type MerkleBlock struct {
*block.Base *block.Base
Network netmode.Magic
TxCount int TxCount int
Hashes []util.Uint256 Hashes []util.Uint256
Flags []byte Flags []byte
@ -16,7 +20,7 @@ type MerkleBlock struct {
// DecodeBinary implements Serializable interface. // DecodeBinary implements Serializable interface.
func (m *MerkleBlock) DecodeBinary(br *io.BinReader) { func (m *MerkleBlock) DecodeBinary(br *io.BinReader) {
m.Base = &block.Base{} m.Base = &block.Base{Network: m.Network}
m.Base.DecodeBinary(br) m.Base.DecodeBinary(br)
txCount := int(br.ReadVarUint()) txCount := int(br.ReadVarUint())
@ -26,6 +30,9 @@ func (m *MerkleBlock) DecodeBinary(br *io.BinReader) {
} }
m.TxCount = txCount m.TxCount = txCount
br.ReadArray(&m.Hashes, m.TxCount) br.ReadArray(&m.Hashes, m.TxCount)
if txCount != len(m.Hashes) {
br.Err = errors.New("invalid tx count")
}
m.Flags = br.ReadVarBytes((txCount + 7) / 8) m.Flags = br.ReadVarBytes((txCount + 7) / 8)
} }