consensus: refactor payloads structure

1. `Version` and `PrevHash` are now in `PrepareRequest`.
2. Serialization is done via `Extensible` payload.
3. Update dbft version.
This commit is contained in:
Evgeniy Stratonikov 2021-01-14 14:17:00 +03:00
parent 59a193c7c7
commit b918ec3abc
10 changed files with 287 additions and 309 deletions

2
go.mod
View file

@ -11,7 +11,7 @@ require (
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/mr-tron/base58 v1.1.2 github.com/mr-tron/base58 v1.1.2
github.com/nspcc-dev/dbft v0.0.0-20201221101812-e13a1a1c3cb2 github.com/nspcc-dev/dbft v0.0.0-20210122071512-d9a728094f0d
github.com/nspcc-dev/rfc6979 v0.2.0 github.com/nspcc-dev/rfc6979 v0.2.0
github.com/pierrec/lz4 v2.5.2+incompatible github.com/pierrec/lz4 v2.5.2+incompatible
github.com/prometheus/client_golang v1.2.1 github.com/prometheus/client_golang v1.2.1

4
go.sum
View file

@ -166,8 +166,8 @@ github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a h1:ajvxgEe9qY4vvoSm
github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a/go.mod h1:/YFK+XOxxg0Bfm6P92lY5eDSLYfp06XOdL8KAVgXjVk= github.com/nspcc-dev/dbft v0.0.0-20200117124306-478e5cfbf03a/go.mod h1:/YFK+XOxxg0Bfm6P92lY5eDSLYfp06XOdL8KAVgXjVk=
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1 h1:yEx9WznS+rjE0jl0dLujCxuZSIb+UTjF+005TJu/nNI= github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1 h1:yEx9WznS+rjE0jl0dLujCxuZSIb+UTjF+005TJu/nNI=
github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1/go.mod h1:O0qtn62prQSqizzoagHmuuKoz8QMkU3SzBoKdEvm3aQ= github.com/nspcc-dev/dbft v0.0.0-20200219114139-199d286ed6c1/go.mod h1:O0qtn62prQSqizzoagHmuuKoz8QMkU3SzBoKdEvm3aQ=
github.com/nspcc-dev/dbft v0.0.0-20201221101812-e13a1a1c3cb2 h1:vbPjd6xbX8w61abcNfzUvSI7WT0QeS9fHWp1Mocv9N0= github.com/nspcc-dev/dbft v0.0.0-20210122071512-d9a728094f0d h1:uUaRysqa/9VtHETVARUlteqfbXAgwxR2nvUc4DzK4pI=
github.com/nspcc-dev/dbft v0.0.0-20201221101812-e13a1a1c3cb2/go.mod h1:I5D0W3tu3epdt2RMCTxS//HDr4S+OHRqajouQTOAHI8= github.com/nspcc-dev/dbft v0.0.0-20210122071512-d9a728094f0d/go.mod h1:I5D0W3tu3epdt2RMCTxS//HDr4S+OHRqajouQTOAHI8=
github.com/nspcc-dev/neo-go v0.73.1-pre.0.20200303142215-f5a1b928ce09/go.mod h1:pPYwPZ2ks+uMnlRLUyXOpLieaDQSEaf4NM3zHVbRjmg= github.com/nspcc-dev/neo-go v0.73.1-pre.0.20200303142215-f5a1b928ce09/go.mod h1:pPYwPZ2ks+uMnlRLUyXOpLieaDQSEaf4NM3zHVbRjmg=
github.com/nspcc-dev/neofs-crypto v0.2.0 h1:ftN+59WqxSWz/RCgXYOfhmltOOqU+udsNQSvN6wkFck= github.com/nspcc-dev/neofs-crypto v0.2.0 h1:ftN+59WqxSWz/RCgXYOfhmltOOqU+udsNQSvN6wkFck=
github.com/nspcc-dev/neofs-crypto v0.2.0/go.mod h1:F/96fUzPM3wR+UGsPi3faVNmFlA9KAEAUQR7dMxZmNA= github.com/nspcc-dev/neofs-crypto v0.2.0/go.mod h1:F/96fUzPM3wR+UGsPi3faVNmFlA9KAEAUQR7dMxZmNA=

View file

@ -52,12 +52,12 @@ func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) {
var sign [signatureSize]byte var sign [signatureSize]byte
random.Fill(sign[:]) random.Fill(sign[:])
payloads[i].message = &message{}
payloads[i].SetValidatorIndex(uint16(i)) payloads[i].SetValidatorIndex(uint16(i))
payloads[i].SetType(payload.MessageType(commitType)) payloads[i].SetType(payload.MessageType(commitType))
payloads[i].payload = &commit{ payloads[i].payload = &commit{
signature: sign, signature: sign,
} }
payloads[i].encodeData()
} }
return return

View file

@ -21,6 +21,7 @@ import (
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/encoding/address" "github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
@ -39,6 +40,9 @@ const defaultTimePerBlock = 15 * time.Second
// Number of nanoseconds in millisecond. // Number of nanoseconds in millisecond.
const nsInMs = 1000000 const nsInMs = 1000000
// Category is message category for extensible payloads.
const Category = "Consensus"
// Service represents consensus instance. // Service represents consensus instance.
type Service interface { type Service interface {
// Start initializes dBFT and starts event loop for consensus service. // Start initializes dBFT and starts event loop for consensus service.
@ -204,15 +208,33 @@ var (
// NewPayload creates new consensus payload for the provided network. // NewPayload creates new consensus payload for the provided network.
func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload { func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload {
return &Payload{ return &Payload{
network: m, Extensible: npayload.Extensible{
message: &message{ Network: m,
Category: Category,
},
message: message{
stateRootEnabled: stateRootEnabled, stateRootEnabled: stateRootEnabled,
}, },
} }
} }
func (s *service) newPayload() payload.ConsensusPayload { func (s *service) newPayload(c *dbft.Context, t payload.MessageType, msg interface{}) payload.ConsensusPayload {
return NewPayload(s.network, s.stateRootEnabled) cp := NewPayload(s.network, s.stateRootEnabled)
cp.SetHeight(c.BlockIndex)
cp.SetValidatorIndex(uint16(c.MyIndex))
cp.SetViewNumber(c.ViewNumber)
cp.SetType(t)
if pr, ok := msg.(*prepareRequest); ok {
pr.SetPrevHash(s.dbft.PrevHash)
pr.SetVersion(s.dbft.Version)
}
cp.SetPayload(msg)
cp.Extensible.ValidBlockStart = 0
cp.Extensible.ValidBlockEnd = c.BlockIndex
cp.Extensible.Sender = c.Validators[c.MyIndex].(*publicKey).GetScriptHash()
return cp
} }
func (s *service) newPrepareRequest() payload.PrepareRequest { func (s *service) newPrepareRequest() payload.PrepareRequest {
@ -257,7 +279,7 @@ events:
s.dbft.OnTimeout(hv) s.dbft.OnTimeout(hv)
case msg := <-s.messages: case msg := <-s.messages:
fields := []zap.Field{ fields := []zap.Field{
zap.Uint8("from", msg.validatorIndex), zap.Uint8("from", msg.message.ValidatorIndex),
zap.Stringer("type", msg.Type()), zap.Stringer("type", msg.Type()),
} }
@ -312,14 +334,13 @@ func (s *service) handleChainBlock(b *coreb.Block) {
func (s *service) validatePayload(p *Payload) bool { func (s *service) validatePayload(p *Payload) bool {
validators := s.getValidators() validators := s.getValidators()
if int(p.validatorIndex) >= len(validators) { if int(p.message.ValidatorIndex) >= len(validators) {
return false return false
} }
pub := validators[p.validatorIndex] pub := validators[p.message.ValidatorIndex]
h := pub.(*publicKey).GetScriptHash() h := pub.(*publicKey).GetScriptHash()
return p.Sender == h
return s.Chain.VerifyWitness(h, p, &p.Witness, payloadGasLimit) == nil
} }
func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey) { func (s *service) getKeyPair(pubs []crypto.PublicKey) (int, crypto.PrivateKey, crypto.PublicKey) {
@ -353,7 +374,7 @@ func (s *service) OnPayload(cp *Payload) {
log.Debug("payload is already in cache") log.Debug("payload is already in cache")
return return
} else if !s.validatePayload(cp) { } else if !s.validatePayload(cp) {
log.Debug("can't validate payload") log.Info("can't validate payload")
return return
} }
@ -368,7 +389,7 @@ func (s *service) OnPayload(cp *Payload) {
// decode payload data into message // decode payload data into message
if cp.message.payload == nil { if cp.message.payload == nil {
if err := cp.decodeData(); err != nil { if err := cp.decodeData(); err != nil {
log.Debug("can't decode payload data") log.Info("can't decode payload data")
return return
} }
} }
@ -479,14 +500,26 @@ func (s *service) verifyBlock(b block.Block) bool {
return true return true
} }
var (
errInvalidPrevHash = errors.New("invalid PrevHash")
errInvalidVersion = errors.New("invalid Version")
errInvalidStateRoot = errors.New("state root mismatch")
)
func (s *service) verifyRequest(p payload.ConsensusPayload) error { func (s *service) verifyRequest(p payload.ConsensusPayload) error {
req := p.GetPrepareRequest().(*prepareRequest) req := p.GetPrepareRequest().(*prepareRequest)
if req.prevHash != s.dbft.PrevHash {
return errInvalidPrevHash
}
if req.version != s.dbft.Version {
return errInvalidVersion
}
if s.stateRootEnabled { if s.stateRootEnabled {
sr, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) sr, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1)
if err != nil { if err != nil {
return err return err
} else if sr.Root != req.stateRoot { } else if sr.Root != req.stateRoot {
return fmt.Errorf("state root mismatch: %s != %s", sr.Root, req.stateRoot) return fmt.Errorf("%w: %s != %s", errInvalidStateRoot, sr.Root, req.stateRoot)
} }
} }
// Save lastProposal for getVerified(). // Save lastProposal for getVerified().

View file

@ -1,12 +1,14 @@
package consensus package consensus
import ( import (
"errors"
"testing" "testing"
"time" "time"
"github.com/nspcc-dev/dbft/block" "github.com/nspcc-dev/dbft/block"
"github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/dbft/timer" "github.com/nspcc-dev/dbft/timer"
"github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/internal/testchain" "github.com/nspcc-dev/neo-go/internal/testchain"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
@ -180,11 +182,10 @@ func TestService_GetVerified(t *testing.T) {
// Everyone sends a message. // Everyone sends a message.
for i := 0; i < 4; i++ { for i := 0; i < 4; i++ {
p := new(Payload) p := new(Payload)
p.message = &message{}
// One PrepareRequest and three ChangeViews. // One PrepareRequest and three ChangeViews.
if i == 1 { if i == 1 {
p.SetType(payload.PrepareRequestType) p.SetType(payload.PrepareRequestType)
p.SetPayload(&prepareRequest{transactionHashes: hashes}) p.SetPayload(&prepareRequest{prevHash: srv.Chain.CurrentBlockHash(), transactionHashes: hashes})
} else { } else {
p.SetType(payload.ChangeViewType) p.SetType(payload.ChangeViewType)
p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint64(time.Now().UnixNano() / nsInMs)}) p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint64(time.Now().UnixNano() / nsInMs)})
@ -224,8 +225,7 @@ func TestService_ValidatePayload(t *testing.T) {
srv := newTestService(t) srv := newTestService(t)
priv, _ := getTestValidator(1) priv, _ := getTestValidator(1)
p := new(Payload) p := new(Payload)
p.message = &message{} p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{}) p.SetPayload(&prepareRequest{})
t.Run("invalid validator index", func(t *testing.T) { t.Run("invalid validator index", func(t *testing.T) {
@ -243,8 +243,16 @@ func TestService_ValidatePayload(t *testing.T) {
require.False(t, srv.validatePayload(p)) require.False(t, srv.validatePayload(p))
}) })
t.Run("invalid sender", func(t *testing.T) {
p.SetValidatorIndex(1)
p.Sender = util.Uint160{}
require.NoError(t, p.Sign(priv))
require.False(t, srv.validatePayload(p))
})
t.Run("normal case", func(t *testing.T) { t.Run("normal case", func(t *testing.T) {
p.SetValidatorIndex(1) p.SetValidatorIndex(1)
p.Sender = priv.GetScriptHash()
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
require.True(t, srv.validatePayload(p)) require.True(t, srv.validatePayload(p))
}) })
@ -295,22 +303,35 @@ func TestService_PrepareRequest(t *testing.T) {
priv, _ := getTestValidator(1) priv, _ := getTestValidator(1)
p := new(Payload) p := new(Payload)
p.message = &message{}
p.SetValidatorIndex(1) p.SetValidatorIndex(1)
p.SetPayload(&prepareRequest{}) prevHash := srv.Chain.CurrentBlockHash()
require.NoError(t, p.Sign(priv))
require.Error(t, srv.verifyRequest(p), "invalid stateroot setting")
p.SetPayload(&prepareRequest{stateRootEnabled: true}) checkRequest := func(t *testing.T, expectedErr error, req *prepareRequest) {
require.NoError(t, p.Sign(priv)) p.SetPayload(req)
require.Error(t, srv.verifyRequest(p), "invalid state root") require.NoError(t, p.Sign(priv))
err := srv.verifyRequest(p)
if expectedErr == nil {
require.NoError(t, err)
return
}
require.True(t, errors.Is(err, expectedErr), "got: %v", err)
}
checkRequest(t, errInvalidVersion, &prepareRequest{version: 0xFF, prevHash: prevHash})
checkRequest(t, errInvalidPrevHash, &prepareRequest{prevHash: random.Uint256()})
checkRequest(t, errInvalidStateRoot, &prepareRequest{
stateRootEnabled: true,
prevHash: prevHash,
})
sr, err := srv.Chain.GetStateRoot(srv.dbft.BlockIndex - 1) sr, err := srv.Chain.GetStateRoot(srv.dbft.BlockIndex - 1)
require.NoError(t, err) require.NoError(t, err)
p.SetPayload(&prepareRequest{stateRootEnabled: true, stateRoot: sr.Root}) checkRequest(t, nil, &prepareRequest{
require.NoError(t, p.Sign(priv)) stateRootEnabled: true,
require.NoError(t, srv.verifyRequest(p)) prevHash: prevHash,
stateRoot: sr.Root,
})
} }
func TestService_OnPayload(t *testing.T) { func TestService_OnPayload(t *testing.T) {
@ -322,15 +343,18 @@ func TestService_OnPayload(t *testing.T) {
priv, _ := getTestValidator(1) priv, _ := getTestValidator(1)
p := new(Payload) p := new(Payload)
p.message = &message{}
p.SetValidatorIndex(1) p.SetValidatorIndex(1)
p.SetPayload(&prepareRequest{}) p.SetPayload(&prepareRequest{})
// payload is not signed // sender is invalid
srv.OnPayload(p) srv.OnPayload(p)
shouldNotReceive(t, srv.messages) shouldNotReceive(t, srv.messages)
require.Nil(t, srv.GetPayload(p.Hash())) require.Nil(t, srv.GetPayload(p.Hash()))
p = new(Payload)
p.SetValidatorIndex(1)
p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{})
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
srv.OnPayload(p) srv.OnPayload(p)
shouldReceive(t, srv.messages) shouldReceive(t, srv.messages)

View file

@ -1,14 +1,11 @@
package consensus package consensus
import ( import (
"errors"
"fmt" "fmt"
"github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/emit"
) )
@ -17,8 +14,10 @@ type (
messageType byte messageType byte
message struct { message struct {
Type messageType Type messageType
ViewNumber byte BlockIndex uint32
ValidatorIndex byte
ViewNumber byte
payload io.Serializable payload io.Serializable
// stateRootEnabled specifies if state root is exchanged during consensus. // stateRootEnabled specifies if state root is exchanged during consensus.
@ -27,20 +26,8 @@ type (
// Payload is a type for consensus-related messages. // Payload is a type for consensus-related messages.
Payload struct { Payload struct {
*message npayload.Extensible
message
network netmode.Magic
data []byte
version uint32
validatorIndex uint8
prevHash util.Uint256
height uint32
Witness transaction.Witness
hash util.Uint256
signedHash util.Uint256
signedpart []byte
} }
) )
@ -111,99 +98,36 @@ func (p Payload) GetRecoveryMessage() payload.RecoveryMessage {
return p.payload.(payload.RecoveryMessage) return p.payload.(payload.RecoveryMessage)
} }
// MarshalUnsigned implements payload.ConsensusPayload interface.
func (p *Payload) MarshalUnsigned() []byte {
if p.signedpart == nil {
w := io.NewBufBinWriter()
p.encodeHashData(w.BinWriter)
p.signedpart = w.Bytes()
}
return p.signedpart
}
// UnmarshalUnsigned implements payload.ConsensusPayload interface.
func (p *Payload) UnmarshalUnsigned(data []byte) error {
r := io.NewBinReaderFromBuf(data)
p.network = netmode.Magic(r.ReadU32LE())
p.DecodeBinaryUnsigned(r)
return r.Err
}
// Version implements payload.ConsensusPayload interface.
func (p Payload) Version() uint32 {
return p.version
}
// SetVersion implements payload.ConsensusPayload interface.
func (p *Payload) SetVersion(v uint32) {
p.version = v
}
// ValidatorIndex implements payload.ConsensusPayload interface. // ValidatorIndex implements payload.ConsensusPayload interface.
func (p Payload) ValidatorIndex() uint16 { func (p Payload) ValidatorIndex() uint16 {
return uint16(p.validatorIndex) return uint16(p.message.ValidatorIndex)
} }
// SetValidatorIndex implements payload.ConsensusPayload interface. // SetValidatorIndex implements payload.ConsensusPayload interface.
func (p *Payload) SetValidatorIndex(i uint16) { func (p *Payload) SetValidatorIndex(i uint16) {
p.validatorIndex = uint8(i) p.message.ValidatorIndex = byte(i)
}
// PrevHash implements payload.ConsensusPayload interface.
func (p Payload) PrevHash() util.Uint256 {
return p.prevHash
}
// SetPrevHash implements payload.ConsensusPayload interface.
func (p *Payload) SetPrevHash(h util.Uint256) {
p.prevHash = h
} }
// Height implements payload.ConsensusPayload interface. // Height implements payload.ConsensusPayload interface.
func (p Payload) Height() uint32 { func (p Payload) Height() uint32 {
return p.height return p.message.BlockIndex
} }
// SetHeight implements payload.ConsensusPayload interface. // SetHeight implements payload.ConsensusPayload interface.
func (p *Payload) SetHeight(h uint32) { func (p *Payload) SetHeight(h uint32) {
p.height = h p.message.BlockIndex = h
}
// EncodeBinaryUnsigned writes payload to w excluding signature.
func (p *Payload) EncodeBinaryUnsigned(w *io.BinWriter) {
w.WriteU32LE(p.version)
w.WriteBytes(p.prevHash[:])
w.WriteU32LE(p.height)
w.WriteB(p.validatorIndex)
if p.data == nil {
ww := io.NewBufBinWriter()
p.message.EncodeBinary(ww.BinWriter)
p.data = ww.Bytes()
}
w.WriteVarBytes(p.data)
} }
// EncodeBinary implements io.Serializable interface. // EncodeBinary implements io.Serializable interface.
func (p *Payload) EncodeBinary(w *io.BinWriter) { func (p *Payload) EncodeBinary(w *io.BinWriter) {
if p.signedpart == nil { p.encodeData()
_ = p.MarshalUnsigned() p.Extensible.EncodeBinary(w)
}
w.WriteBytes(p.signedpart[4:])
w.WriteB(1)
p.Witness.EncodeBinary(w)
}
func (p *Payload) encodeHashData(w *io.BinWriter) {
w.WriteU32LE(uint32(p.network))
p.EncodeBinaryUnsigned(w)
} }
// Sign signs payload using the private key. // Sign signs payload using the private key.
// It also sets corresponding verification and invocation scripts. // It also sets corresponding verification and invocation scripts.
func (p *Payload) Sign(key *privateKey) error { func (p *Payload) Sign(key *privateKey) error {
p.encodeData()
sig := key.SignHash(p.GetSignedHash()) sig := key.SignHash(p.GetSignedHash())
buf := io.NewBufBinWriter() buf := io.NewBufBinWriter()
@ -216,78 +140,39 @@ func (p *Payload) Sign(key *privateKey) error {
// GetSignedPart implements crypto.Verifiable interface. // GetSignedPart implements crypto.Verifiable interface.
func (p *Payload) GetSignedPart() []byte { func (p *Payload) GetSignedPart() []byte {
return p.MarshalUnsigned() if p.Extensible.Data == nil {
} p.encodeData()
// DecodeBinaryUnsigned reads payload from w excluding signature.
func (p *Payload) DecodeBinaryUnsigned(r *io.BinReader) {
p.version = r.ReadU32LE()
r.ReadBytes(p.prevHash[:])
p.height = r.ReadU32LE()
p.validatorIndex = r.ReadB()
p.data = r.ReadVarBytes()
if r.Err != nil {
return
} }
return p.Extensible.GetSignedPart()
} }
// GetSignedHash returns a hash of the payload used to verify it. // GetSignedHash returns a hash of the payload used to verify it.
func (p *Payload) GetSignedHash() util.Uint256 { func (p *Payload) GetSignedHash() util.Uint256 {
if p.signedHash.Equals(util.Uint256{}) { if p.Extensible.Data == nil {
if p.createHash() != nil { p.encodeData()
panic("failed to compute hash!")
}
} }
return p.signedHash return p.Extensible.GetSignedHash()
} }
// Hash implements payload.ConsensusPayload interface. // Hash implements payload.ConsensusPayload interface.
func (p *Payload) Hash() util.Uint256 { func (p *Payload) Hash() util.Uint256 {
if p.hash.Equals(util.Uint256{}) { if p.Extensible.Data == nil {
if p.createHash() != nil { p.encodeData()
panic("failed to compute hash!")
}
} }
return p.hash return p.Extensible.Hash()
}
// createHash creates hashes of the payload.
func (p *Payload) createHash() error {
b := p.GetSignedPart()
if b == nil {
return errors.New("failed to serialize hashable data")
}
p.updateHashes(b)
return nil
}
// updateHashes updates Payload's hashes based on the given buffer which should
// be a signable data slice.
func (p *Payload) updateHashes(b []byte) {
p.signedHash = hash.Sha256(b)
p.hash = hash.Sha256(p.signedHash.BytesBE())
} }
// DecodeBinary implements io.Serializable interface. // DecodeBinary implements io.Serializable interface.
func (p *Payload) DecodeBinary(r *io.BinReader) { func (p *Payload) DecodeBinary(r *io.BinReader) {
p.DecodeBinaryUnsigned(r) p.Extensible.DecodeBinary(r)
if r.Err != nil { p.decodeData()
return
}
var b = r.ReadB()
if b != 1 {
r.Err = errors.New("invalid format")
return
}
p.Witness.DecodeBinary(r)
} }
// EncodeBinary implements io.Serializable interface. // EncodeBinary implements io.Serializable interface.
func (m *message) EncodeBinary(w *io.BinWriter) { func (m *message) EncodeBinary(w *io.BinWriter) {
w.WriteBytes([]byte{byte(m.Type)}) w.WriteB(byte(m.Type))
w.WriteU32LE(m.BlockIndex)
w.WriteB(m.ValidatorIndex)
w.WriteB(m.ViewNumber) w.WriteB(m.ViewNumber)
m.payload.EncodeBinary(w) m.payload.EncodeBinary(w)
} }
@ -295,6 +180,8 @@ func (m *message) EncodeBinary(w *io.BinWriter) {
// DecodeBinary implements io.Serializable interface. // DecodeBinary implements io.Serializable interface.
func (m *message) DecodeBinary(r *io.BinReader) { func (m *message) DecodeBinary(r *io.BinReader) {
m.Type = messageType(r.ReadB()) m.Type = messageType(r.ReadB())
m.BlockIndex = r.ReadU32LE()
m.ValidatorIndex = r.ReadB()
m.ViewNumber = r.ReadB() m.ViewNumber = r.ReadB()
switch m.Type { switch m.Type {
@ -348,14 +235,22 @@ func (t messageType) String() string {
} }
} }
func (p *Payload) encodeData() {
if p.Extensible.Data == nil {
p.Extensible.ValidBlockStart = 0
p.Extensible.ValidBlockEnd = p.BlockIndex
bw := io.NewBufBinWriter()
p.message.EncodeBinary(bw.BinWriter)
p.Extensible.Data = bw.Bytes()
}
}
// decode data of payload into it's message // decode data of payload into it's message
func (p *Payload) decodeData() error { func (p *Payload) decodeData() error {
m := p.message br := io.NewBinReaderFromBuf(p.Extensible.Data)
br := io.NewBinReaderFromBuf(p.data) p.message.DecodeBinary(br)
m.DecodeBinary(br)
if br.Err != nil { if br.Err != nil {
return fmt.Errorf("can't decode message: %w", br.Err) return fmt.Errorf("can't decode message: %w", br.Err)
} }
p.message = m
return nil return nil
} }

View file

@ -1,18 +1,16 @@
package consensus package consensus
import ( import (
"encoding/hex"
gio "io"
"math/rand" "math/rand"
"testing" "testing"
"github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/random"
"github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/nspcc-dev/neo-go/pkg/config/netmode"
"github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -30,13 +28,12 @@ var messageTypes = []messageType{
func TestConsensusPayload_Setters(t *testing.T) { func TestConsensusPayload_Setters(t *testing.T) {
var p Payload var p Payload
p.message = &message{}
p.SetVersion(1) //p.SetVersion(1)
assert.EqualValues(t, 1, p.Version()) //assert.EqualValues(t, 1, p.Version())
p.SetPrevHash(util.Uint256{1, 2, 3}) //p.SetPrevHash(util.Uint256{1, 2, 3})
assert.Equal(t, util.Uint256{1, 2, 3}, p.PrevHash()) //assert.Equal(t, util.Uint256{1, 2, 3}, p.PrevHash())
p.SetValidatorIndex(4) p.SetValidatorIndex(4)
assert.EqualValues(t, 4, p.ValidatorIndex()) assert.EqualValues(t, 4, p.ValidatorIndex())
@ -76,22 +73,22 @@ func TestConsensusPayload_Setters(t *testing.T) {
require.Equal(t, pl, p.GetRecoveryMessage()) require.Equal(t, pl, p.GetRecoveryMessage())
} }
func TestConsensusPayload_Verify(t *testing.T) { //func TestConsensusPayload_Verify(t *testing.T) {
// signed payload from mixed privnet (Go + 3C# nodes) // // signed payload from mixed privnet (Go + 3C# nodes)
dataHex := "000000004c02d52305a6a8981bd1598f0c3076d6de15a44a60ca692e189cd8a7249f175c0800000003222100f24b9147a21e09562c68abdec56d3c5fc09936592933aea5692800b75edbab2301420c40b2b8080ab02b703bc4e64407a6f31bb7ae4c9b1b1c8477668afa752eba6148e03b3ffc7e06285c09bdce4582188466209f876c38f9921a88b545393543ab201a290c2103d90c07df63e690ce77912e10ab51acc944b66860237b608c4f8f8309e71ee6990b4195440d78" // dataHex := "000000004c02d52305a6a8981bd1598f0c3076d6de15a44a60ca692e189cd8a7249f175c0800000003222100f24b9147a21e09562c68abdec56d3c5fc09936592933aea5692800b75edbab2301420c40b2b8080ab02b703bc4e64407a6f31bb7ae4c9b1b1c8477668afa752eba6148e03b3ffc7e06285c09bdce4582188466209f876c38f9921a88b545393543ab201a290c2103d90c07df63e690ce77912e10ab51acc944b66860237b608c4f8f8309e71ee6990b4195440d78"
data, err := hex.DecodeString(dataHex) // data, err := hex.DecodeString(dataHex)
require.NoError(t, err) // require.NoError(t, err)
//
h, err := util.Uint160DecodeStringLE("a8826043c40abacfac1d9acc6b92a4458308ca18") // h, err := util.Uint160DecodeStringLE("a8826043c40abacfac1d9acc6b92a4458308ca18")
require.NoError(t, err) // require.NoError(t, err)
//
p := NewPayload(netmode.PrivNet, false) // p := NewPayload(netmode.PrivNet, false)
require.NoError(t, testserdes.DecodeBinary(data, p)) // require.NoError(t, testserdes.DecodeBinary(data, p))
require.NoError(t, p.decodeData()) // require.NoError(t, p.decodeData())
bc := newTestChain(t, false) // bc := newTestChain(t, false)
defer bc.Close() // defer bc.Close()
require.NoError(t, bc.VerifyWitness(h, p, &p.Witness, payloadGasLimit)) // require.NoError(t, bc.VerifyWitness(h, p, &p.Witness, payloadGasLimit))
} //}
func TestConsensusPayload_Serializable(t *testing.T) { func TestConsensusPayload_Serializable(t *testing.T) {
for _, mt := range messageTypes { for _, mt := range messageTypes {
@ -99,85 +96,70 @@ func TestConsensusPayload_Serializable(t *testing.T) {
actual := new(Payload) actual := new(Payload)
data, err := testserdes.EncodeBinary(p) data, err := testserdes.EncodeBinary(p)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, testserdes.DecodeBinary(data, actual)) require.NoError(t, testserdes.DecodeBinary(data, &actual.Extensible))
// message is nil after decoding as we didn't yet call decodeData
require.Nil(t, actual.message)
actual.message = new(message)
// message should now be decoded from actual.data byte array
actual.message = new(message)
assert.NoError(t, actual.decodeData()) assert.NoError(t, actual.decodeData())
assert.NotNil(t, actual.MarshalUnsigned())
require.Equal(t, p, actual) require.Equal(t, p, actual)
data = p.MarshalUnsigned()
pu := NewPayload(netmode.Magic(rand.Uint32()), false)
require.NoError(t, pu.UnmarshalUnsigned(data))
assert.NoError(t, pu.decodeData())
_ = pu.MarshalUnsigned()
p.Witness = transaction.Witness{}
require.Equal(t, p, pu)
} }
} }
func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { //func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
// PrepareResponse ConsensusPayload consists of: // // PrepareResponse ConsensusPayload consists of:
// 41-byte common prefix // // 41-byte common prefix
// 1-byte varint length of the payload (34), // // 1-byte varint length of the payload (34),
// - 1-byte view number // // - 1-byte view number
// - 1-byte message type (PrepareResponse) // // - 1-byte message type (PrepareResponse)
// - 32-byte preparation hash // // - 32-byte preparation hash
// 1-byte delimiter (1) // // 1-byte delimiter (1)
// 2-byte for empty invocation and verification scripts // // 2-byte for empty invocation and verification scripts
const ( // const (
lenIndex = 41 // lenIndex = 41
typeIndex = lenIndex + 1 // typeIndex = lenIndex + 1
delimeterIndex = typeIndex + 34 // delimeterIndex = typeIndex + 34
) // )
//
buf := make([]byte, delimeterIndex+1+2) // buf := make([]byte, delimeterIndex+1+2)
//
expected := &Payload{ // expected := &Payload{
message: &message{ // message: &message{
Type: prepareResponseType, // Type: prepareResponseType,
payload: &prepareResponse{}, // payload: &prepareResponse{},
}, // },
Witness: transaction.Witness{ // Extensible: transaction.Witness{
InvocationScript: []byte{}, // InvocationScript: []byte{},
VerificationScript: []byte{}, // VerificationScript: []byte{},
}, // },
} // }
// fill `data` for next check // // fill `data` for next check
_ = expected.Hash() // _ = expected.Hash()
//
// valid payload // // valid payload
buf[delimeterIndex] = 1 // buf[delimeterIndex] = 1
buf[lenIndex] = 34 // buf[lenIndex] = 34
buf[typeIndex] = byte(prepareResponseType) // buf[typeIndex] = byte(prepareResponseType)
p := &Payload{message: new(message)} // p := &Payload{message: new(message)}
require.NoError(t, testserdes.DecodeBinary(buf, p)) // require.NoError(t, testserdes.DecodeBinary(buf, p))
// decode `data` into `message` // // decode `data` into `message`
_ = p.Hash() // _ = p.Hash()
assert.NoError(t, p.decodeData()) // assert.NoError(t, p.decodeData())
require.Equal(t, expected, p) // require.Equal(t, expected, p)
//
// invalid type // // invalid type
buf[typeIndex] = 0xFF // buf[typeIndex] = 0xFF
actual := &Payload{message: new(message)} // actual := &Payload{message: new(message)}
require.NoError(t, testserdes.DecodeBinary(buf, actual)) // require.NoError(t, testserdes.DecodeBinary(buf, actual))
require.Error(t, actual.decodeData()) // require.Error(t, actual.decodeData())
//
// invalid format // // invalid format
buf[delimeterIndex] = 0 // buf[delimeterIndex] = 0
buf[typeIndex] = byte(prepareResponseType) // buf[typeIndex] = byte(prepareResponseType)
require.Error(t, testserdes.DecodeBinary(buf, new(Payload))) // require.Error(t, testserdes.DecodeBinary(buf, new(Payload)))
//
// invalid message length // // invalid message length
buf[delimeterIndex] = 1 // buf[delimeterIndex] = 1
buf[lenIndex] = 0xFF // buf[lenIndex] = 0xFF
buf[typeIndex] = byte(prepareResponseType) // buf[typeIndex] = byte(prepareResponseType)
require.Error(t, testserdes.DecodeBinary(buf, new(Payload))) // require.Error(t, testserdes.DecodeBinary(buf, new(Payload)))
} //}
func TestCommit_Serializable(t *testing.T) { func TestCommit_Serializable(t *testing.T) {
c := randomMessage(t, commitType) c := randomMessage(t, commitType)
@ -206,18 +188,18 @@ func TestRecoveryMessage_Serializable(t *testing.T) {
func randomPayload(t *testing.T, mt messageType) *Payload { func randomPayload(t *testing.T, mt messageType) *Payload {
p := &Payload{ p := &Payload{
message: &message{ message: message{
Type: mt, Type: mt,
ViewNumber: byte(rand.Uint32()), ValidatorIndex: byte(rand.Uint32()),
payload: randomMessage(t, mt), BlockIndex: rand.Uint32(),
ViewNumber: byte(rand.Uint32()),
payload: randomMessage(t, mt),
}, },
version: 1, Extensible: npayload.Extensible{
validatorIndex: 13, Witness: transaction.Witness{
height: rand.Uint32(), InvocationScript: random.Bytes(3),
prevHash: random.Uint256(), VerificationScript: []byte{byte(opcode.PUSH0)},
Witness: transaction.Witness{ },
InvocationScript: random.Bytes(3),
VerificationScript: []byte{byte(opcode.PUSH0)},
}, },
} }
@ -334,19 +316,19 @@ func TestMessageType_String(t *testing.T) {
require.Equal(t, "UNKNOWN(0xff)", messageType(0xff).String()) require.Equal(t, "UNKNOWN(0xff)", messageType(0xff).String())
} }
func TestPayload_DecodeFromPrivnet(t *testing.T) { //func TestPayload_DecodeFromPrivnet(t *testing.T) {
hexDump := "000000004c02d52305a6a8981bd1598f0c3076d6de15a44a60ca692e189cd8a7249f175c08000000004230000368c5c5401d40eef6b8a9899d2041d29fd2e6300980fdcaa6660c10b85965f57852193cdb6f0d1e9f91dc510dff6df3a004b569fe2ad456d07007f6ccd55b1d01420c40e760250b821a4dcfc4b8727ecc409a758ab4bd3b288557fd3c3d76e083fe7c625b4ed25e763ad96c4eb0abc322600d82651fd32f8866fca1403fa04d3acc4675290c2102103a7f7dd016558597f7960d27c516a4394fd968b9e65155eb4b013e4040406e0b4195440d78" // hexDump := "000000004c02d52305a6a8981bd1598f0c3076d6de15a44a60ca692e189cd8a7249f175c08000000004230000368c5c5401d40eef6b8a9899d2041d29fd2e6300980fdcaa6660c10b85965f57852193cdb6f0d1e9f91dc510dff6df3a004b569fe2ad456d07007f6ccd55b1d01420c40e760250b821a4dcfc4b8727ecc409a758ab4bd3b288557fd3c3d76e083fe7c625b4ed25e763ad96c4eb0abc322600d82651fd32f8866fca1403fa04d3acc4675290c2102103a7f7dd016558597f7960d27c516a4394fd968b9e65155eb4b013e4040406e0b4195440d78"
data, err := hex.DecodeString(hexDump) // data, err := hex.DecodeString(hexDump)
require.NoError(t, err) // require.NoError(t, err)
//
buf := io.NewBinReaderFromBuf(data) // buf := io.NewBinReaderFromBuf(data)
p := NewPayload(netmode.PrivNet, false) // p := NewPayload(netmode.PrivNet, false)
p.DecodeBinary(buf) // p.DecodeBinary(buf)
require.NoError(t, buf.Err) // require.NoError(t, buf.Err)
require.NoError(t, p.decodeData()) // require.NoError(t, p.decodeData())
require.Equal(t, payload.CommitType, p.Type()) // require.Equal(t, payload.CommitType, p.Type())
require.Equal(t, uint32(8), p.Height()) // require.Equal(t, uint32(8), p.Height())
//
buf.ReadB() // buf.ReadB()
require.Equal(t, gio.EOF, buf.Err) // require.Equal(t, gio.EOF, buf.Err)
} //}

View file

@ -9,6 +9,8 @@ import (
// prepareRequest represents dBFT prepareRequest message. // prepareRequest represents dBFT prepareRequest message.
type prepareRequest struct { type prepareRequest struct {
version uint32
prevHash util.Uint256
timestamp uint64 timestamp uint64
nonce uint64 nonce uint64
transactionHashes []util.Uint256 transactionHashes []util.Uint256
@ -20,6 +22,8 @@ var _ payload.PrepareRequest = (*prepareRequest)(nil)
// EncodeBinary implements io.Serializable interface. // EncodeBinary implements io.Serializable interface.
func (p *prepareRequest) EncodeBinary(w *io.BinWriter) { func (p *prepareRequest) EncodeBinary(w *io.BinWriter) {
w.WriteU32LE(p.version)
w.WriteBytes(p.prevHash[:])
w.WriteU64LE(p.timestamp) w.WriteU64LE(p.timestamp)
w.WriteU64LE(p.nonce) w.WriteU64LE(p.nonce)
w.WriteArray(p.transactionHashes) w.WriteArray(p.transactionHashes)
@ -30,6 +34,8 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) {
// DecodeBinary implements io.Serializable interface. // DecodeBinary implements io.Serializable interface.
func (p *prepareRequest) DecodeBinary(r *io.BinReader) { func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
p.version = r.ReadU32LE()
r.ReadBytes(p.prevHash[:])
p.timestamp = r.ReadU64LE() p.timestamp = r.ReadU64LE()
p.nonce = r.ReadU64LE() p.nonce = r.ReadU64LE()
r.ReadArray(&p.transactionHashes, block.MaxTransactionsPerBlock) r.ReadArray(&p.transactionHashes, block.MaxTransactionsPerBlock)
@ -38,6 +44,26 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
} }
} }
// Version implements payload.PrepareRequest interface.
func (p prepareRequest) Version() uint32 {
return p.version
}
// SetVersion implements payload.PrepareRequest interface.
func (p *prepareRequest) SetVersion(v uint32) {
p.version = v
}
// PrevHash implements payload.PrepareRequest interface.
func (p prepareRequest) PrevHash() util.Uint256 {
return p.prevHash
}
// SetPrevHash implements payload.PrepareRequest interface.
func (p *prepareRequest) SetPrevHash(h util.Uint256) {
p.prevHash = h
}
// Timestamp implements payload.PrepareRequest interface. // Timestamp implements payload.PrepareRequest interface.
func (p *prepareRequest) Timestamp() uint64 { return p.timestamp * nsInMs } func (p *prepareRequest) Timestamp() uint64 { return p.timestamp * nsInMs }

View file

@ -6,6 +6,7 @@ import (
"github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/crypto"
"github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
npayload "github.com/nspcc-dev/neo-go/pkg/network/payload"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
) )
@ -202,6 +203,7 @@ func (m *recoveryMessage) GetPrepareRequest(p payload.ConsensusPayload, validato
req := fromPayload(prepareRequestType, p.(*Payload), m.prepareRequest.payload) req := fromPayload(prepareRequestType, p.(*Payload), m.prepareRequest.payload)
req.SetValidatorIndex(primary) req.SetValidatorIndex(primary)
req.Sender = validators[primary].(*publicKey).GetScriptHash()
req.Witness.InvocationScript = compact.InvocationScript req.Witness.InvocationScript = compact.InvocationScript
req.Witness.VerificationScript = getVerificationScript(uint8(primary), validators) req.Witness.VerificationScript = getVerificationScript(uint8(primary), validators)
@ -221,6 +223,7 @@ func (m *recoveryMessage) GetPrepareResponses(p payload.ConsensusPayload, valida
preparationHash: *m.preparationHash, preparationHash: *m.preparationHash,
}) })
r.SetValidatorIndex(uint16(resp.ValidatorIndex)) r.SetValidatorIndex(uint16(resp.ValidatorIndex))
r.Sender = validators[resp.ValidatorIndex].(*publicKey).GetScriptHash()
r.Witness.InvocationScript = resp.InvocationScript r.Witness.InvocationScript = resp.InvocationScript
r.Witness.VerificationScript = getVerificationScript(resp.ValidatorIndex, validators) r.Witness.VerificationScript = getVerificationScript(resp.ValidatorIndex, validators)
@ -241,6 +244,7 @@ func (m *recoveryMessage) GetChangeViews(p payload.ConsensusPayload, validators
}) })
c.message.ViewNumber = cv.OriginalViewNumber c.message.ViewNumber = cv.OriginalViewNumber
c.SetValidatorIndex(uint16(cv.ValidatorIndex)) c.SetValidatorIndex(uint16(cv.ValidatorIndex))
c.Sender = validators[cv.ValidatorIndex].(*publicKey).GetScriptHash()
c.Witness.InvocationScript = cv.InvocationScript c.Witness.InvocationScript = cv.InvocationScript
c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators) c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators)
@ -257,6 +261,7 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr
for i, c := range m.commitPayloads { for i, c := range m.commitPayloads {
cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature}) cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature})
cc.SetValidatorIndex(uint16(c.ValidatorIndex)) cc.SetValidatorIndex(uint16(c.ValidatorIndex))
cc.Sender = validators[c.ValidatorIndex].(*publicKey).GetScriptHash()
cc.Witness.InvocationScript = c.InvocationScript cc.Witness.InvocationScript = c.InvocationScript
cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators) cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators)
@ -291,15 +296,17 @@ func getVerificationScript(i uint8, validators []crypto.PublicKey) []byte {
func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload { func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload {
return &Payload{ return &Payload{
network: recovery.network, Extensible: npayload.Extensible{
message: &message{ Category: Category,
Network: recovery.Network,
ValidBlockEnd: recovery.BlockIndex,
},
message: message{
Type: t, Type: t,
BlockIndex: recovery.BlockIndex,
ViewNumber: recovery.message.ViewNumber, ViewNumber: recovery.message.ViewNumber,
payload: p, payload: p,
stateRootEnabled: recovery.stateRootEnabled, stateRootEnabled: recovery.stateRootEnabled,
}, },
version: recovery.Version(),
prevHash: recovery.PrevHash(),
height: recovery.Height(),
} }
} }

View file

@ -30,9 +30,12 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
privs[i], pubs[i] = getTestValidator(i) privs[i], pubs[i] = getTestValidator(i)
} }
const msgHeight = 10
r := &recoveryMessage{stateRootEnabled: enableStateRoot} r := &recoveryMessage{stateRootEnabled: enableStateRoot}
p := NewPayload(netmode.UnitTestNet, enableStateRoot) p := NewPayload(netmode.UnitTestNet, enableStateRoot)
p.SetType(payload.RecoveryMessageType) p.SetType(payload.RecoveryMessageType)
p.SetHeight(msgHeight)
p.SetPayload(r) p.SetPayload(r)
// sign payload to have verification script // sign payload to have verification script
require.NoError(t, p.Sign(privs[0])) require.NoError(t, p.Sign(privs[0]))
@ -45,17 +48,21 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
} }
p1 := NewPayload(netmode.UnitTestNet, enableStateRoot) p1 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p1.SetType(payload.PrepareRequestType) p1.SetType(payload.PrepareRequestType)
p1.SetHeight(msgHeight)
p1.SetPayload(req) p1.SetPayload(req)
p1.SetValidatorIndex(0) p1.SetValidatorIndex(0)
p1.Sender = privs[0].GetScriptHash()
require.NoError(t, p1.Sign(privs[0])) require.NoError(t, p1.Sign(privs[0]))
t.Run("prepare response is added", func(t *testing.T) { t.Run("prepare response is added", func(t *testing.T) {
p2 := NewPayload(netmode.UnitTestNet, enableStateRoot) p2 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p2.SetType(payload.PrepareResponseType) p2.SetType(payload.PrepareResponseType)
p2.SetHeight(msgHeight)
p2.SetPayload(&prepareResponse{ p2.SetPayload(&prepareResponse{
preparationHash: p1.Hash(), preparationHash: p1.Hash(),
}) })
p2.SetValidatorIndex(1) p2.SetValidatorIndex(1)
p2.Sender = privs[1].GetScriptHash()
require.NoError(t, p2.Sign(privs[1])) require.NoError(t, p2.Sign(privs[1]))
r.AddPayload(p2) r.AddPayload(p2)
@ -88,11 +95,13 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
t.Run("change view is added", func(t *testing.T) { t.Run("change view is added", func(t *testing.T) {
p3 := NewPayload(netmode.UnitTestNet, enableStateRoot) p3 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p3.SetType(payload.ChangeViewType) p3.SetType(payload.ChangeViewType)
p3.SetHeight(msgHeight)
p3.SetPayload(&changeView{ p3.SetPayload(&changeView{
newViewNumber: 1, newViewNumber: 1,
timestamp: 12345, timestamp: 12345,
}) })
p3.SetValidatorIndex(3) p3.SetValidatorIndex(3)
p3.Sender = privs[3].GetScriptHash()
require.NoError(t, p3.Sign(privs[3])) require.NoError(t, p3.Sign(privs[3]))
r.AddPayload(p3) r.AddPayload(p3)
@ -110,8 +119,10 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
t.Run("commit is added", func(t *testing.T) { t.Run("commit is added", func(t *testing.T) {
p4 := NewPayload(netmode.UnitTestNet, enableStateRoot) p4 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p4.SetType(payload.CommitType) p4.SetType(payload.CommitType)
p4.SetHeight(msgHeight)
p4.SetPayload(randomMessage(t, commitType)) p4.SetPayload(randomMessage(t, commitType))
p4.SetValidatorIndex(3) p4.SetValidatorIndex(3)
p4.Sender = privs[3].GetScriptHash()
require.NoError(t, p4.Sign(privs[3])) require.NoError(t, p4.Sign(privs[3]))
r.AddPayload(p4) r.AddPayload(p4)