consensus: remove unused dBFT payload methods

Signed-off-by: Anna Shaleva <shaleva.ann@nspcc.ru>
This commit is contained in:
Anna Shaleva 2024-03-21 23:38:24 +03:00
parent 3e6dfff503
commit 708b439f4a
18 changed files with 101 additions and 192 deletions

View file

@ -56,9 +56,6 @@ func (n *neoBlock) SetTransactions(txes []dbft.Transaction[util.Uint256]) {
} }
} }
// Version implements the block.Block interface.
func (n *neoBlock) Version() uint32 { return n.Block.Version }
// PrevHash implements the block.Block interface. // PrevHash implements the block.Block interface.
func (n *neoBlock) PrevHash() util.Uint256 { return n.Block.PrevHash } func (n *neoBlock) PrevHash() util.Uint256 { return n.Block.PrevHash }
@ -71,11 +68,5 @@ func (n *neoBlock) Timestamp() uint64 { return n.Block.Timestamp * nsInMs }
// Index implements the block.Block interface. // Index implements the block.Block interface.
func (n *neoBlock) Index() uint32 { return n.Block.Index } func (n *neoBlock) Index() uint32 { return n.Block.Index }
// ConsensusData implements the block.Block interface.
func (n *neoBlock) ConsensusData() uint64 { return n.Block.Nonce }
// NextConsensus implements the block.Block interface.
func (n *neoBlock) NextConsensus() util.Uint160 { return n.Block.NextConsensus }
// Signature implements the block.Block interface. // Signature implements the block.Block interface.
func (n *neoBlock) Signature() []byte { return n.signature } func (n *neoBlock) Signature() []byte { return n.signature }

View file

@ -22,9 +22,6 @@ func TestNeoBlock_Sign(t *testing.T) {
func TestNeoBlock_Setters(t *testing.T) { func TestNeoBlock_Setters(t *testing.T) {
b := new(neoBlock) b := new(neoBlock)
b.Block.Version = 1
require.EqualValues(t, 1, b.Version())
b.Block.Index = 12 b.Block.Index = 12
require.EqualValues(t, 12, b.Index()) require.EqualValues(t, 12, b.Index())
@ -35,9 +32,6 @@ func TestNeoBlock_Setters(t *testing.T) {
b.Block.MerkleRoot = util.Uint256{1, 2, 3, 4} b.Block.MerkleRoot = util.Uint256{1, 2, 3, 4}
require.Equal(t, util.Uint256{1, 2, 3, 4}, b.MerkleRoot()) require.Equal(t, util.Uint256{1, 2, 3, 4}, b.MerkleRoot())
b.Block.NextConsensus = util.Uint160{9, 2}
require.Equal(t, util.Uint160{9, 2}, b.NextConsensus())
b.Block.PrevHash = util.Uint256{9, 8, 7} b.Block.PrevHash = util.Uint256{9, 8, 7}
require.Equal(t, util.Uint256{9, 8, 7}, b.PrevHash()) require.Equal(t, util.Uint256{9, 8, 7}, b.PrevHash())

View file

@ -3,7 +3,6 @@ package consensus
import ( import (
"testing" "testing"
"github.com/nspcc-dev/dbft"
"github.com/nspcc-dev/neo-go/internal/random" "github.com/nspcc-dev/neo-go/internal/random"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -52,8 +51,8 @@ func getDifferentPayloads(t *testing.T, n int) (payloads []Payload) {
var sign [signatureSize]byte var sign [signatureSize]byte
random.Fill(sign[:]) random.Fill(sign[:])
payloads[i].SetValidatorIndex(uint16(i)) payloads[i].message.ValidatorIndex = byte(i)
payloads[i].SetType(dbft.MessageType(commitType)) payloads[i].message.Type = commitType
payloads[i].payload = &commit{ payloads[i].payload = &commit{
signature: sign, signature: sign,
} }

View file

@ -29,17 +29,5 @@ func (c *changeView) DecodeBinary(r *io.BinReader) {
// NewViewNumber implements the payload.ChangeView interface. // NewViewNumber implements the payload.ChangeView interface.
func (c changeView) NewViewNumber() byte { return c.newViewNumber } func (c changeView) NewViewNumber() byte { return c.newViewNumber }
// SetNewViewNumber implements the payload.ChangeView interface.
func (c *changeView) SetNewViewNumber(view byte) { c.newViewNumber = view }
// Timestamp implements the payload.ChangeView interface.
func (c changeView) Timestamp() uint64 { return c.timestamp * nsInMs }
// SetTimestamp implements the payload.ChangeView interface.
func (c *changeView) SetTimestamp(ts uint64) { c.timestamp = ts / nsInMs }
// Reason implements the payload.ChangeView interface. // Reason implements the payload.ChangeView interface.
func (c changeView) Reason() dbft.ChangeViewReason { return c.reason } func (c changeView) Reason() dbft.ChangeViewReason { return c.reason }
// SetReason implements the payload.ChangeView interface.
func (c *changeView) SetReason(reason dbft.ChangeViewReason) { c.reason = reason }

View file

@ -3,15 +3,16 @@ package consensus
import ( import (
"testing" "testing"
"github.com/nspcc-dev/dbft"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestChangeView_Setters(t *testing.T) { func TestChangeView_Getters(t *testing.T) {
var c changeView var c = &changeView{
newViewNumber: 2,
reason: dbft.CVTimeout,
}
c.SetTimestamp(123 * nsInMs)
require.EqualValues(t, 123*nsInMs, c.Timestamp())
c.SetNewViewNumber(2)
require.EqualValues(t, 2, c.NewViewNumber()) require.EqualValues(t, 2, c.NewViewNumber())
require.EqualValues(t, dbft.CVTimeout, c.Reason())
} }

View file

@ -28,8 +28,3 @@ func (c *commit) DecodeBinary(r *io.BinReader) {
// Signature implements the payload.Commit interface. // Signature implements the payload.Commit interface.
func (c commit) Signature() []byte { return c.signature[:] } func (c commit) Signature() []byte { return c.signature[:] }
// SetSignature implements the payload.Commit interface.
func (c *commit) SetSignature(signature []byte) {
copy(c.signature[:], signature)
}

View file

@ -7,11 +7,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestCommit_Setters(t *testing.T) { func TestCommit_Getters(t *testing.T) {
var sign [signatureSize]byte var sign [signatureSize]byte
random.Fill(sign[:]) random.Fill(sign[:])
var c commit var c = &commit{
c.SetSignature(sign[:]) signature: sign,
}
require.Equal(t, sign[:], c.Signature()) require.Equal(t, sign[:], c.Signature())
} }

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/nspcc-dev/dbft" "github.com/nspcc-dev/dbft"
"github.com/nspcc-dev/dbft/timer"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/config/netmode"
coreb "github.com/nspcc-dev/neo-go/pkg/core/block" coreb "github.com/nspcc-dev/neo-go/pkg/core/block"
@ -179,6 +180,7 @@ func NewService(cfg Config) (Service, error) {
} }
srv.dbft = dbft.New[util.Uint256]( srv.dbft = dbft.New[util.Uint256](
dbft.WithTimer[util.Uint256](timer.New()),
dbft.WithLogger[util.Uint256](srv.log), dbft.WithLogger[util.Uint256](srv.log),
dbft.WithSecondsPerBlock[util.Uint256](cfg.TimePerBlock), dbft.WithSecondsPerBlock[util.Uint256](cfg.TimePerBlock),
dbft.WithGetKeyPair[util.Uint256](srv.getKeyPair), dbft.WithGetKeyPair[util.Uint256](srv.getKeyPair),
@ -234,15 +236,15 @@ func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload {
func (s *service) newPayload(c *dbft.Context[util.Uint256], t dbft.MessageType, msg any) dbft.ConsensusPayload[util.Uint256] { func (s *service) newPayload(c *dbft.Context[util.Uint256], t dbft.MessageType, msg any) dbft.ConsensusPayload[util.Uint256] {
cp := NewPayload(s.ProtocolConfiguration.Magic, s.ProtocolConfiguration.StateRootInHeader) cp := NewPayload(s.ProtocolConfiguration.Magic, s.ProtocolConfiguration.StateRootInHeader)
cp.SetHeight(c.BlockIndex) cp.BlockIndex = c.BlockIndex
cp.SetValidatorIndex(uint16(c.MyIndex)) cp.message.ValidatorIndex = byte(c.MyIndex)
cp.SetViewNumber(c.ViewNumber) cp.message.ViewNumber = c.ViewNumber
cp.SetType(t) cp.message.Type = messageType(t)
if pr, ok := msg.(*prepareRequest); ok { if pr, ok := msg.(*prepareRequest); ok {
pr.SetPrevHash(s.dbft.PrevHash) pr.prevHash = s.dbft.PrevHash
pr.SetVersion(coreb.VersionInitial) pr.version = coreb.VersionInitial
} }
cp.SetPayload(msg) cp.payload = msg.(io.Serializable)
cp.Extensible.ValidBlockStart = 0 cp.Extensible.ValidBlockStart = 0
cp.Extensible.ValidBlockEnd = c.BlockIndex cp.Extensible.ValidBlockEnd = c.BlockIndex

View file

@ -212,14 +212,14 @@ func TestService_GetVerified(t *testing.T) {
p := new(Payload) p := new(Payload)
// One PrepareRequest and three ChangeViews. // One PrepareRequest and three ChangeViews.
if i == 1 { if i == 1 {
p.SetType(dbft.PrepareRequestType) p.message.Type = messageType(dbft.PrepareRequestType)
p.SetPayload(&prepareRequest{prevHash: srv.Chain.CurrentBlockHash(), transactionHashes: hashes}) p.payload = &prepareRequest{prevHash: srv.Chain.CurrentBlockHash(), transactionHashes: hashes}
} else { } else {
p.SetType(dbft.ChangeViewType) p.message.Type = messageType(dbft.ChangeViewType)
p.SetPayload(&changeView{newViewNumber: 1, timestamp: uint64(time.Now().UnixNano() / nsInMs)}) p.payload = &changeView{newViewNumber: 1, timestamp: uint64(time.Now().UnixNano() / nsInMs)}
} }
p.SetHeight(1) p.BlockIndex = 1
p.SetValidatorIndex(uint16(i)) p.message.ValidatorIndex = byte(i)
priv, _ := getTestValidator(i) priv, _ := getTestValidator(i)
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
@ -253,10 +253,10 @@ func TestService_ValidatePayload(t *testing.T) {
priv, _ := getTestValidator(1) priv, _ := getTestValidator(1)
p := new(Payload) p := new(Payload)
p.Sender = priv.GetScriptHash() p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{}) p.payload = &prepareRequest{}
t.Run("invalid validator index", func(t *testing.T) { t.Run("invalid validator index", func(t *testing.T) {
p.SetValidatorIndex(11) p.message.ValidatorIndex = 11
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
var ok bool var ok bool
@ -265,20 +265,20 @@ func TestService_ValidatePayload(t *testing.T) {
}) })
t.Run("wrong validator index", func(t *testing.T) { t.Run("wrong validator index", func(t *testing.T) {
p.SetValidatorIndex(2) p.message.ValidatorIndex = 2
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
require.False(t, srv.validatePayload(p)) require.False(t, srv.validatePayload(p))
}) })
t.Run("invalid sender", func(t *testing.T) { t.Run("invalid sender", func(t *testing.T) {
p.SetValidatorIndex(1) p.message.ValidatorIndex = 1
p.Sender = util.Uint160{} p.Sender = util.Uint160{}
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
require.False(t, srv.validatePayload(p)) require.False(t, srv.validatePayload(p))
}) })
t.Run("normal case", func(t *testing.T) { t.Run("normal case", func(t *testing.T) {
p.SetValidatorIndex(1) p.message.ValidatorIndex = 1
p.Sender = priv.GetScriptHash() p.Sender = priv.GetScriptHash()
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
require.True(t, srv.validatePayload(p)) require.True(t, srv.validatePayload(p))
@ -328,12 +328,12 @@ func TestService_PrepareRequest(t *testing.T) {
priv, _ := getTestValidator(1) priv, _ := getTestValidator(1)
p := new(Payload) p := new(Payload)
p.SetValidatorIndex(1) p.message.ValidatorIndex = 1
prevHash := srv.Chain.CurrentBlockHash() prevHash := srv.Chain.CurrentBlockHash()
checkRequest := func(t *testing.T, expectedErr error, req *prepareRequest) { checkRequest := func(t *testing.T, expectedErr error, req *prepareRequest) {
p.SetPayload(req) p.payload = req
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
err := srv.verifyRequest(p) err := srv.verifyRequest(p)
if expectedErr == nil { if expectedErr == nil {
@ -375,8 +375,8 @@ func TestService_OnPayload(t *testing.T) {
priv, _ := getTestValidator(1) priv, _ := getTestValidator(1)
p := new(Payload) p := new(Payload)
p.SetValidatorIndex(1) p.message.ValidatorIndex = 1
p.SetPayload(&prepareRequest{}) p.payload = &prepareRequest{}
p.encodeData() p.encodeData()
// sender is invalid // sender is invalid
@ -384,9 +384,9 @@ func TestService_OnPayload(t *testing.T) {
shouldNotReceive(t, srv.messages) shouldNotReceive(t, srv.messages)
p = new(Payload) p = new(Payload)
p.SetValidatorIndex(1) p.message.ValidatorIndex = 1
p.Sender = priv.GetScriptHash() p.Sender = priv.GetScriptHash()
p.SetPayload(&prepareRequest{}) p.payload = &prepareRequest{}
require.NoError(t, p.Sign(priv)) require.NoError(t, p.Sign(priv))
require.NoError(t, srv.OnPayload(&p.Extensible)) require.NoError(t, srv.OnPayload(&p.Extensible))
shouldReceive(t, srv.messages) shouldReceive(t, srv.messages)

View file

@ -4,6 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"errors" "errors"
"github.com/nspcc-dev/dbft"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
) )
@ -13,6 +14,8 @@ type privateKey struct {
*keys.PrivateKey *keys.PrivateKey
} }
var _ dbft.PrivateKey = &privateKey{}
// Sign implements the dbft's crypto.PrivateKey interface. // Sign implements the dbft's crypto.PrivateKey interface.
func (p *privateKey) Sign(data []byte) ([]byte, error) { func (p *privateKey) Sign(data []byte) ([]byte, error) {
return p.PrivateKey.Sign(data), nil return p.PrivateKey.Sign(data), nil
@ -24,6 +27,8 @@ type publicKey struct {
*keys.PublicKey *keys.PublicKey
} }
var _ dbft.PublicKey = &publicKey{}
// MarshalBinary implements the encoding.BinaryMarshaler interface. // MarshalBinary implements the encoding.BinaryMarshaler interface.
func (p publicKey) MarshalBinary() (data []byte, err error) { func (p publicKey) MarshalBinary() (data []byte, err error) {
return p.PublicKey.Bytes(), nil return p.PublicKey.Bytes(), nil

View file

@ -44,6 +44,8 @@ const (
payloadGasLimit = 2000000 // 0.02 GAS payloadGasLimit = 2000000 // 0.02 GAS
) )
var _ dbft.ConsensusPayload[util.Uint256] = &Payload{}
// ViewNumber implements the payload.ConsensusPayload interface. // ViewNumber implements the payload.ConsensusPayload interface.
func (p Payload) ViewNumber() byte { func (p Payload) ViewNumber() byte {
return p.message.ViewNumber return p.message.ViewNumber
@ -59,21 +61,11 @@ func (p Payload) Type() dbft.MessageType {
return dbft.MessageType(p.message.Type) return dbft.MessageType(p.message.Type)
} }
// SetType implements the payload.ConsensusPayload interface.
func (p *Payload) SetType(t dbft.MessageType) {
p.message.Type = messageType(t)
}
// Payload implements the payload.ConsensusPayload interface. // Payload implements the payload.ConsensusPayload interface.
func (p Payload) Payload() any { func (p Payload) Payload() any {
return p.payload return p.payload
} }
// SetPayload implements the payload.ConsensusPayload interface.
func (p *Payload) SetPayload(pl any) {
p.payload = pl.(io.Serializable)
}
// GetChangeView implements the payload.ConsensusPayload interface. // GetChangeView implements the payload.ConsensusPayload interface.
func (p Payload) GetChangeView() dbft.ChangeView { return p.payload.(dbft.ChangeView) } func (p Payload) GetChangeView() dbft.ChangeView { return p.payload.(dbft.ChangeView) }
@ -115,11 +107,6 @@ func (p Payload) Height() uint32 {
return p.message.BlockIndex return p.message.BlockIndex
} }
// SetHeight implements the payload.ConsensusPayload interface.
func (p *Payload) SetHeight(h uint32) {
p.message.BlockIndex = h
}
// EncodeBinary implements the io.Serializable interface. // EncodeBinary implements the io.Serializable interface.
func (p *Payload) EncodeBinary(w *io.BinWriter) { func (p *Payload) EncodeBinary(w *io.BinWriter) {
p.encodeData() p.encodeData()

View file

@ -29,50 +29,45 @@ var messageTypes = []messageType{
recoveryMessageType, recoveryMessageType,
} }
func TestConsensusPayload_Setters(t *testing.T) { func TestConsensusPayload_Getters(t *testing.T) {
var p Payload var p = &Payload{
Extensible: npayload.Extensible{},
message: message{
Type: prepareRequestType,
BlockIndex: 11,
ValidatorIndex: 4,
ViewNumber: 2,
},
}
//p.SetVersion(1)
//assert.EqualValues(t, 1, p.Version())
//p.SetPrevHash(util.Uint256{1, 2, 3})
//assert.Equal(t, util.Uint256{1, 2, 3}, p.PrevHash())
p.SetValidatorIndex(4)
assert.EqualValues(t, 4, p.ValidatorIndex()) assert.EqualValues(t, 4, p.ValidatorIndex())
p.SetHeight(11)
assert.EqualValues(t, 11, p.Height()) assert.EqualValues(t, 11, p.Height())
p.SetViewNumber(2)
assert.EqualValues(t, 2, p.ViewNumber()) assert.EqualValues(t, 2, p.ViewNumber())
p.SetType(dbft.PrepareRequestType)
assert.Equal(t, dbft.PrepareRequestType, p.Type()) assert.Equal(t, dbft.PrepareRequestType, p.Type())
pl := randomMessage(t, prepareRequestType) pl := randomMessage(t, prepareRequestType)
p.SetPayload(pl) p.payload = pl
require.Equal(t, pl, p.Payload()) require.Equal(t, pl, p.Payload())
require.Equal(t, pl, p.GetPrepareRequest()) require.Equal(t, pl, p.GetPrepareRequest())
pl = randomMessage(t, prepareResponseType) pl = randomMessage(t, prepareResponseType)
p.SetPayload(pl) p.payload = pl
require.Equal(t, pl, p.GetPrepareResponse()) require.Equal(t, pl, p.GetPrepareResponse())
pl = randomMessage(t, commitType) pl = randomMessage(t, commitType)
p.SetPayload(pl) p.payload = pl
require.Equal(t, pl, p.GetCommit()) require.Equal(t, pl, p.GetCommit())
pl = randomMessage(t, changeViewType) pl = randomMessage(t, changeViewType)
p.SetPayload(pl) p.payload = pl
require.Equal(t, pl, p.GetChangeView()) require.Equal(t, pl, p.GetChangeView())
pl = randomMessage(t, recoveryRequestType) pl = randomMessage(t, recoveryRequestType)
p.SetPayload(pl) p.payload = pl
require.Equal(t, pl, p.GetRecoveryRequest()) require.Equal(t, pl, p.GetRecoveryRequest())
pl = randomMessage(t, recoveryMessageType) pl = randomMessage(t, recoveryMessageType)
p.SetPayload(pl) p.payload = pl
require.Equal(t, pl, p.GetRecoveryMessage()) require.Equal(t, pl, p.GetRecoveryMessage())
} }

View file

@ -47,46 +47,11 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
} }
} }
// Version implements the payload.PrepareRequest interface.
func (p prepareRequest) Version() uint32 {
return p.version
}
// SetVersion implements the payload.PrepareRequest interface.
func (p *prepareRequest) SetVersion(v uint32) {
p.version = v
}
// PrevHash implements the payload.PrepareRequest interface.
func (p prepareRequest) PrevHash() util.Uint256 {
return p.prevHash
}
// SetPrevHash implements the payload.PrepareRequest interface.
func (p *prepareRequest) SetPrevHash(h util.Uint256) {
p.prevHash = h
}
// Timestamp implements the payload.PrepareRequest interface. // Timestamp implements the payload.PrepareRequest interface.
func (p *prepareRequest) Timestamp() uint64 { return p.timestamp * nsInMs } func (p *prepareRequest) Timestamp() uint64 { return p.timestamp * nsInMs }
// SetTimestamp implements the payload.PrepareRequest interface.
func (p *prepareRequest) SetTimestamp(ts uint64) { p.timestamp = ts / nsInMs }
// Nonce implements the payload.PrepareRequest interface. // Nonce implements the payload.PrepareRequest interface.
func (p *prepareRequest) Nonce() uint64 { return p.nonce } func (p *prepareRequest) Nonce() uint64 { return p.nonce }
// SetNonce implements the payload.PrepareRequest interface.
func (p *prepareRequest) SetNonce(nonce uint64) { p.nonce = nonce }
// TransactionHashes implements the payload.PrepareRequest interface. // TransactionHashes implements the payload.PrepareRequest interface.
func (p *prepareRequest) TransactionHashes() []util.Uint256 { return p.transactionHashes } func (p *prepareRequest) TransactionHashes() []util.Uint256 { return p.transactionHashes }
// SetTransactionHashes implements the payload.PrepareRequest interface.
func (p *prepareRequest) SetTransactionHashes(hs []util.Uint256) { p.transactionHashes = hs }
// NextConsensus implements the payload.PrepareRequest interface.
func (p *prepareRequest) NextConsensus() util.Uint160 { return util.Uint160{} }
// SetNextConsensus implements the payload.PrepareRequest interface.
func (p *prepareRequest) SetNextConsensus(_ util.Uint160) {}

View file

@ -10,24 +10,17 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestPrepareRequest_Setters(t *testing.T) { func TestPrepareRequest_Getters(t *testing.T) {
var p prepareRequest hashes := []util.Uint256{random.Uint256(), random.Uint256()}
var p = &prepareRequest{
version: 123,
prevHash: util.Uint256{1, 2, 3},
timestamp: 123,
transactionHashes: hashes,
}
p.SetTimestamp(123) require.EqualValues(t, 123000000, p.Timestamp())
// 123ns -> 0ms -> 0ns require.Equal(t, hashes, p.TransactionHashes())
require.EqualValues(t, 0, p.Timestamp())
p.SetTimestamp(1230000)
// 1230000ns -> 1ms -> 1000000ns
require.EqualValues(t, 1000000, p.Timestamp())
p.SetNextConsensus(util.Uint160{5, 6, 7})
require.Equal(t, util.Uint160{}, p.NextConsensus())
hashes := [2]util.Uint256{random.Uint256(), random.Uint256()}
p.SetTransactionHashes(hashes[:])
require.Equal(t, hashes[:], p.TransactionHashes())
} }
func TestPrepareRequest_EncodeDecodeBinary(t *testing.T) { func TestPrepareRequest_EncodeDecodeBinary(t *testing.T) {

View file

@ -201,7 +201,7 @@ func (m *recoveryMessage) GetPrepareRequest(p dbft.ConsensusPayload[util.Uint256
} }
req := fromPayload(prepareRequestType, p.(*Payload), m.prepareRequest.payload) req := fromPayload(prepareRequestType, p.(*Payload), m.prepareRequest.payload)
req.SetValidatorIndex(primary) req.message.ValidatorIndex = byte(primary)
req.Sender = validators[primary].(*publicKey).GetScriptHash() req.Sender = validators[primary].(*publicKey).GetScriptHash()
req.Witness.InvocationScript = compact.InvocationScript req.Witness.InvocationScript = compact.InvocationScript
req.Witness.VerificationScript = getVerificationScript(uint8(primary), validators) req.Witness.VerificationScript = getVerificationScript(uint8(primary), validators)
@ -221,7 +221,7 @@ func (m *recoveryMessage) GetPrepareResponses(p dbft.ConsensusPayload[util.Uint2
r := fromPayload(prepareResponseType, p.(*Payload), &prepareResponse{ r := fromPayload(prepareResponseType, p.(*Payload), &prepareResponse{
preparationHash: *m.preparationHash, preparationHash: *m.preparationHash,
}) })
r.SetValidatorIndex(uint16(resp.ValidatorIndex)) r.message.ValidatorIndex = resp.ValidatorIndex
r.Sender = validators[resp.ValidatorIndex].(*publicKey).GetScriptHash() r.Sender = validators[resp.ValidatorIndex].(*publicKey).GetScriptHash()
r.Witness.InvocationScript = resp.InvocationScript r.Witness.InvocationScript = resp.InvocationScript
r.Witness.VerificationScript = getVerificationScript(resp.ValidatorIndex, validators) r.Witness.VerificationScript = getVerificationScript(resp.ValidatorIndex, validators)
@ -242,7 +242,7 @@ func (m *recoveryMessage) GetChangeViews(p dbft.ConsensusPayload[util.Uint256],
timestamp: cv.Timestamp, timestamp: cv.Timestamp,
}) })
c.message.ViewNumber = cv.OriginalViewNumber c.message.ViewNumber = cv.OriginalViewNumber
c.SetValidatorIndex(uint16(cv.ValidatorIndex)) c.message.ValidatorIndex = cv.ValidatorIndex
c.Sender = validators[cv.ValidatorIndex].(*publicKey).GetScriptHash() c.Sender = validators[cv.ValidatorIndex].(*publicKey).GetScriptHash()
c.Witness.InvocationScript = cv.InvocationScript c.Witness.InvocationScript = cv.InvocationScript
c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators) c.Witness.VerificationScript = getVerificationScript(cv.ValidatorIndex, validators)
@ -259,7 +259,7 @@ func (m *recoveryMessage) GetCommits(p dbft.ConsensusPayload[util.Uint256], vali
for i, c := range m.commitPayloads { for i, c := range m.commitPayloads {
cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature}) cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature})
cc.SetValidatorIndex(uint16(c.ValidatorIndex)) cc.message.ValidatorIndex = c.ValidatorIndex
cc.Sender = validators[c.ValidatorIndex].(*publicKey).GetScriptHash() cc.Sender = validators[c.ValidatorIndex].(*publicKey).GetScriptHash()
cc.Witness.InvocationScript = c.InvocationScript cc.Witness.InvocationScript = c.InvocationScript
cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators) cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators)
@ -275,11 +275,6 @@ func (m *recoveryMessage) PreparationHash() *util.Uint256 {
return m.preparationHash return m.preparationHash
} }
// SetPreparationHash implements the payload.RecoveryMessage interface.
func (m *recoveryMessage) SetPreparationHash(h *util.Uint256) {
m.preparationHash = h
}
func getVerificationScript(i uint8, validators []dbft.PublicKey) []byte { func getVerificationScript(i uint8, validators []dbft.PublicKey) []byte {
if int(i) >= len(validators) { if int(i) >= len(validators) {
return nil return nil

View file

@ -31,9 +31,9 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
r := &recoveryMessage{stateRootEnabled: enableStateRoot} r := &recoveryMessage{stateRootEnabled: enableStateRoot}
p := NewPayload(netmode.UnitTestNet, enableStateRoot) p := NewPayload(netmode.UnitTestNet, enableStateRoot)
p.SetType(dbft.RecoveryMessageType) p.message.Type = messageType(dbft.RecoveryMessageType)
p.SetHeight(msgHeight) p.BlockIndex = msgHeight
p.SetPayload(r) p.payload = r
// sign payload to have verification script // sign payload to have verification script
require.NoError(t, p.Sign(privs[0])) require.NoError(t, p.Sign(privs[0]))
@ -43,21 +43,21 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
stateRootEnabled: enableStateRoot, stateRootEnabled: enableStateRoot,
} }
p1 := NewPayload(netmode.UnitTestNet, enableStateRoot) p1 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p1.SetType(dbft.PrepareRequestType) p1.message.Type = messageType(dbft.PrepareRequestType)
p1.SetHeight(msgHeight) p1.BlockIndex = msgHeight
p1.SetPayload(req) p1.payload = req
p1.SetValidatorIndex(0) p1.message.ValidatorIndex = 0
p1.Sender = privs[0].GetScriptHash() p1.Sender = privs[0].GetScriptHash()
require.NoError(t, p1.Sign(privs[0])) require.NoError(t, p1.Sign(privs[0]))
t.Run("prepare response is added", func(t *testing.T) { t.Run("prepare response is added", func(t *testing.T) {
p2 := NewPayload(netmode.UnitTestNet, enableStateRoot) p2 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p2.SetType(dbft.PrepareResponseType) p2.message.Type = messageType(dbft.PrepareResponseType)
p2.SetHeight(msgHeight) p2.BlockIndex = msgHeight
p2.SetPayload(&prepareResponse{ p2.payload = &prepareResponse{
preparationHash: p1.Hash(), preparationHash: p1.Hash(),
}) }
p2.SetValidatorIndex(1) p2.message.ValidatorIndex = 1
p2.Sender = privs[1].GetScriptHash() p2.Sender = privs[1].GetScriptHash()
require.NoError(t, p2.Sign(privs[1])) require.NoError(t, p2.Sign(privs[1]))
@ -90,13 +90,13 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
t.Run("change view is added", func(t *testing.T) { t.Run("change view is added", func(t *testing.T) {
p3 := NewPayload(netmode.UnitTestNet, enableStateRoot) p3 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p3.SetType(dbft.ChangeViewType) p3.message.Type = messageType(dbft.ChangeViewType)
p3.SetHeight(msgHeight) p3.BlockIndex = msgHeight
p3.SetPayload(&changeView{ p3.payload = &changeView{
newViewNumber: 1, newViewNumber: 1,
timestamp: 12345, timestamp: 12345,
}) }
p3.SetValidatorIndex(3) p3.message.ValidatorIndex = 3
p3.Sender = privs[3].GetScriptHash() p3.Sender = privs[3].GetScriptHash()
require.NoError(t, p3.Sign(privs[3])) require.NoError(t, p3.Sign(privs[3]))
@ -114,10 +114,10 @@ func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) {
t.Run("commit is added", func(t *testing.T) { t.Run("commit is added", func(t *testing.T) {
p4 := NewPayload(netmode.UnitTestNet, enableStateRoot) p4 := NewPayload(netmode.UnitTestNet, enableStateRoot)
p4.SetType(dbft.CommitType) p4.message.Type = messageType(dbft.CommitType)
p4.SetHeight(msgHeight) p4.BlockIndex = msgHeight
p4.SetPayload(randomMessage(t, commitType)) p4.payload = randomMessage(t, commitType)
p4.SetValidatorIndex(3) p4.message.ValidatorIndex = 3
p4.Sender = privs[3].GetScriptHash() p4.Sender = privs[3].GetScriptHash()
require.NoError(t, p4.Sign(privs[3])) require.NoError(t, p4.Sign(privs[3]))

View file

@ -24,6 +24,3 @@ func (m *recoveryRequest) EncodeBinary(w *io.BinWriter) {
// Timestamp implements the payload.RecoveryRequest interface. // Timestamp implements the payload.RecoveryRequest interface.
func (m *recoveryRequest) Timestamp() uint64 { return m.timestamp * nsInMs } func (m *recoveryRequest) Timestamp() uint64 { return m.timestamp * nsInMs }
// SetTimestamp implements the payload.RecoveryRequest interface.
func (m *recoveryRequest) SetTimestamp(ts uint64) { m.timestamp = ts / nsInMs }

View file

@ -6,9 +6,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestRecoveryRequest_Setters(t *testing.T) { func TestRecoveryRequest_Getters(t *testing.T) {
var r recoveryRequest var r = &recoveryRequest{
timestamp: 123,
}
r.SetTimestamp(123 * nsInMs)
require.EqualValues(t, 123*nsInMs, r.Timestamp()) require.EqualValues(t, 123*nsInMs, r.Timestamp())
} }