*: add config flag for enabling state root feature

This commit is contained in:
Evgenii Stratonikov 2020-06-23 09:41:53 +03:00
parent c06b3b669d
commit d128b55dbf
11 changed files with 183 additions and 66 deletions

View file

@ -2,6 +2,7 @@ ProtocolConfiguration:
Magic: 56753
AddressVersion: 23
SecondsPerBlock: 15
EnableStateRoot: true
LowPriorityThreshold: 0.000
MemPoolSize: 50000
StandbyValidators:

View file

@ -20,6 +20,8 @@ const (
type (
ProtocolConfiguration struct {
AddressVersion byte `yaml:"AddressVersion"`
// EnableStateRoot specifies if exchange of state roots should be enabled.
EnableStateRoot bool `yaml:"EnableStateRoot"`
// FeePerExtraByte sets the expected per-byte fee for
// transactions exceeding the MaxFreeTransactionSize.
FeePerExtraByte float64 `yaml:"FeePerExtraByte"`

View file

@ -9,6 +9,8 @@ import (
type commit struct {
signature [signatureSize]byte
stateSig [signatureSize]byte
stateRootEnabled bool
}
// signatureSize is an rfc6989 signature size in bytes
@ -20,13 +22,17 @@ var _ payload.Commit = (*commit)(nil)
// EncodeBinary implements io.Serializable interface.
func (c *commit) EncodeBinary(w *io.BinWriter) {
w.WriteBytes(c.signature[:])
w.WriteBytes(c.stateSig[:])
if c.stateRootEnabled {
w.WriteBytes(c.stateSig[:])
}
}
// DecodeBinary implements io.Serializable interface.
func (c *commit) DecodeBinary(r *io.BinReader) {
r.ReadBytes(c.signature[:])
r.ReadBytes(c.stateSig[:])
if c.stateRootEnabled {
r.ReadBytes(c.stateSig[:])
}
}
// Signature implements payload.Commit interface.

View file

@ -216,35 +216,49 @@ func (s *service) eventLoop() {
func (s *service) newPayload() payload.ConsensusPayload {
return &Payload{
message: new(message),
message: &message{
stateRootEnabled: s.stateRootEnabled(),
},
}
}
// stateRootEnabled checks if state root feature is enabled on current height.
// It should be called only from dbft callbacks and is not protected by any mutex.
func (s *service) stateRootEnabled() bool {
return s.Chain.GetConfig().EnableStateRoot
}
func (s *service) newPrepareRequest() payload.PrepareRequest {
sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight())
if err != nil {
if !s.stateRootEnabled() {
return new(prepareRequest)
}
return &prepareRequest{
proposalStateRoot: sr.MPTRootBase,
sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight())
if err == nil {
return &prepareRequest{
stateRootEnabled: true,
proposalStateRoot: sr.MPTRootBase,
}
}
return &prepareRequest{stateRootEnabled: true}
}
func (s *service) newCommit() payload.Commit {
if !s.stateRootEnabled() {
return new(commit)
}
c := &commit{stateRootEnabled: true}
for _, p := range s.dbft.Context.PreparationPayloads {
if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType {
pr := p.GetPrepareRequest().(*prepareRequest)
data := pr.proposalStateRoot.GetSignedPart()
sign, err := s.dbft.Priv.Sign(data)
if err == nil {
var c commit
copy(c.stateSig[:], sign)
return &c
}
break
}
}
return new(commit)
return c
}
func (s *service) validatePayload(p *Payload) bool {
@ -299,8 +313,8 @@ func (s *service) OnPayload(cp *Payload) {
// decode payload data into message
if cp.message == nil {
if err := cp.decodeData(); err != nil {
log.Debug("can't decode payload data")
if err := cp.decodeData(s.stateRootEnabled()); err != nil {
log.Debug("can't decode payload data", zap.Error(err))
return
}
}
@ -378,6 +392,9 @@ func (s *service) verifyBlock(b block.Block) bool {
}
func (s *service) verifyRequest(p payload.ConsensusPayload) error {
if !s.stateRootEnabled() {
return nil
}
r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1)
if err != nil {
return fmt.Errorf("can't get local state root: %v", err)

View file

@ -22,6 +22,8 @@ type (
Type messageType
ViewNumber byte
stateRootEnabled bool
payload io.Serializable
}
@ -283,15 +285,21 @@ func (m *message) DecodeBinary(r *io.BinReader) {
cv.newViewNumber = m.ViewNumber + 1
m.payload = cv
case prepareRequestType:
m.payload = new(prepareRequest)
m.payload = &prepareRequest{
stateRootEnabled: m.stateRootEnabled,
}
case prepareResponseType:
m.payload = new(prepareResponse)
case commitType:
m.payload = new(commit)
m.payload = &commit{
stateRootEnabled: m.stateRootEnabled,
}
case recoveryRequestType:
m.payload = new(recoveryRequest)
case recoveryMessageType:
m.payload = new(recoveryMessage)
m.payload = &recoveryMessage{
stateRootEnabled: m.stateRootEnabled,
}
default:
r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type))
return
@ -320,8 +328,8 @@ func (t messageType) String() string {
}
// decodeData decodes data of payload into it's message.
func (p *Payload) decodeData() error {
m := new(message)
func (p *Payload) decodeData(stateRootEnabled bool) error {
m := &message{stateRootEnabled: stateRootEnabled}
br := io.NewBinReaderFromBuf(p.data)
m.DecodeBinary(br)
if br.Err != nil {

View file

@ -94,13 +94,13 @@ func TestConsensusPayload_Serializable(t *testing.T) {
// message is nil after decoding as we didn't yet call decodeData
require.Nil(t, actual.message)
// message should now be decoded from actual.data byte array
assert.NoError(t, actual.decodeData())
assert.NoError(t, actual.decodeData(false))
require.Equal(t, p, actual)
data = p.MarshalUnsigned()
pu := new(Payload)
require.NoError(t, pu.UnmarshalUnsigned(data))
assert.NoError(t, pu.decodeData())
assert.NoError(t, pu.decodeData(false))
p.Witness = transaction.Witness{}
require.Equal(t, p, pu)
@ -144,14 +144,14 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
p := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, p))
// decode `data` into `message`
assert.NoError(t, p.decodeData())
assert.NoError(t, p.decodeData(false))
require.Equal(t, expected, p)
// invalid type
buf[typeIndex] = 0xFF
actual := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, actual))
require.Error(t, actual.decodeData())
require.Error(t, actual.decodeData(false))
// invalid format
buf[delimeterIndex] = 0
@ -165,9 +165,16 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
require.Error(t, testserdes.DecodeBinary(buf, new(Payload)))
}
func testEncodeDecode(srEnabled bool, mt messageType, actual io.Serializable) func(t *testing.T) {
return func(t *testing.T) {
expected := randomMessage(t, mt, srEnabled)
testserdes.EncodeDecodeBinary(t, expected, actual)
}
}
func TestCommit_Serializable(t *testing.T) {
c := randomMessage(t, commitType)
testserdes.EncodeDecodeBinary(t, c, new(commit))
t.Run("WithStateRoot", testEncodeDecode(true, commitType, &commit{stateRootEnabled: true}))
t.Run("NoStateRoot", testEncodeDecode(false, commitType, &commit{stateRootEnabled: false}))
}
func TestPrepareResponse_Serializable(t *testing.T) {
@ -176,8 +183,8 @@ func TestPrepareResponse_Serializable(t *testing.T) {
}
func TestPrepareRequest_Serializable(t *testing.T) {
req := randomMessage(t, prepareRequestType)
testserdes.EncodeDecodeBinary(t, req, new(prepareRequest))
t.Run("WithStateRoot", testEncodeDecode(true, prepareRequestType, &prepareRequest{stateRootEnabled: true}))
t.Run("NoStateRoot", testEncodeDecode(false, prepareRequestType, &prepareRequest{stateRootEnabled: false}))
}
func TestRecoveryRequest_Serializable(t *testing.T) {
@ -186,8 +193,8 @@ func TestRecoveryRequest_Serializable(t *testing.T) {
}
func TestRecoveryMessage_Serializable(t *testing.T) {
msg := randomMessage(t, recoveryMessageType)
testserdes.EncodeDecodeBinary(t, msg, new(recoveryMessage))
t.Run("WithStateRoot", testEncodeDecode(true, recoveryMessageType, &recoveryMessage{stateRootEnabled: true}))
t.Run("NoStateRoot", testEncodeDecode(false, recoveryMessageType, &recoveryMessage{stateRootEnabled: false}))
}
func randomPayload(t *testing.T, mt messageType) *Payload {
@ -215,32 +222,35 @@ func randomPayload(t *testing.T, mt messageType) *Payload {
return p
}
func randomMessage(t *testing.T, mt messageType) io.Serializable {
func randomMessage(t *testing.T, mt messageType, srEnabled ...bool) io.Serializable {
switch mt {
case changeViewType:
return &changeView{
timestamp: rand.Uint32(),
}
case prepareRequestType:
return randomPrepareRequest(t)
return randomPrepareRequest(t, srEnabled...)
case prepareResponseType:
return &prepareResponse{preparationHash: random.Uint256()}
case commitType:
var c commit
random.Fill(c.signature[:])
random.Fill(c.stateSig[:])
if len(srEnabled) > 0 && srEnabled[0] {
c.stateRootEnabled = true
random.Fill(c.stateSig[:])
}
return &c
case recoveryRequestType:
return &recoveryRequest{timestamp: rand.Uint32()}
case recoveryMessageType:
return randomRecoveryMessage(t)
return randomRecoveryMessage(t, srEnabled...)
default:
require.Fail(t, "invalid type")
return nil
}
}
func randomPrepareRequest(t *testing.T) *prepareRequest {
func randomPrepareRequest(t *testing.T, srEnabled ...bool) *prepareRequest {
const txCount = 3
req := &prepareRequest{
@ -256,15 +266,22 @@ func randomPrepareRequest(t *testing.T) *prepareRequest {
}
req.nextConsensus = random.Uint160()
if len(srEnabled) > 0 && srEnabled[0] {
req.stateRootEnabled = true
req.proposalStateRoot.Index = rand.Uint32()
req.proposalStateRoot.PrevHash = random.Uint256()
req.proposalStateRoot.Root = random.Uint256()
}
return req
}
func randomRecoveryMessage(t *testing.T) *recoveryMessage {
result := randomMessage(t, prepareRequestType)
func randomRecoveryMessage(t *testing.T, srEnabled ...bool) *recoveryMessage {
result := randomMessage(t, prepareRequestType, srEnabled...)
require.IsType(t, (*prepareRequest)(nil), result)
prepReq := result.(*prepareRequest)
return &recoveryMessage{
rec := &recoveryMessage{
preparationPayloads: []*preparationCompact{
{
ValidatorIndex: 1,
@ -276,14 +293,12 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage {
ViewNumber: 0,
ValidatorIndex: 1,
Signature: [64]byte{1, 2, 3},
StateSignature: [64]byte{4, 5, 6},
InvocationScript: random.Bytes(20),
},
{
ViewNumber: 0,
ValidatorIndex: 2,
Signature: [64]byte{11, 3, 4, 98},
StateSignature: [64]byte{4, 8, 15, 16, 23, 42},
InvocationScript: random.Bytes(10),
},
},
@ -300,6 +315,15 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage {
payload: prepReq,
},
}
if len(srEnabled) > 0 && srEnabled[0] {
rec.stateRootEnabled = true
rec.prepareRequest.stateRootEnabled = true
for _, c := range rec.commitPayloads {
c.stateRootEnabled = true
random.Fill(c.StateSignature[:])
}
}
return rec
}
func TestPayload_Sign(t *testing.T) {

View file

@ -16,6 +16,8 @@ type prepareRequest struct {
minerTx transaction.Transaction
nextConsensus util.Uint160
proposalStateRoot state.MPTRootBase
stateRootEnabled bool
}
var _ payload.PrepareRequest = (*prepareRequest)(nil)
@ -27,7 +29,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) {
w.WriteBytes(p.nextConsensus[:])
w.WriteArray(p.transactionHashes)
p.minerTx.EncodeBinary(w)
p.proposalStateRoot.EncodeBinary(w)
if p.stateRootEnabled {
p.proposalStateRoot.EncodeBinary(w)
}
}
// DecodeBinary implements io.Serializable interface.
@ -37,7 +41,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
r.ReadBytes(p.nextConsensus[:])
r.ReadArray(&p.transactionHashes)
p.minerTx.DecodeBinary(r)
p.proposalStateRoot.DecodeBinary(r)
if p.stateRootEnabled {
p.proposalStateRoot.DecodeBinary(r)
}
}
// Timestamp implements payload.PrepareRequest interface.

View file

@ -3,6 +3,7 @@ package consensus
import (
"github.com/nspcc-dev/dbft/crypto"
"github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/pkg/errors"
@ -16,6 +17,8 @@ type (
commitPayloads []*commitCompact
changeViewPayloads []*changeViewCompact
prepareRequest *message
stateRootEnabled bool
}
changeViewCompact struct {
@ -31,6 +34,8 @@ type (
Signature [signatureSize]byte
StateSignature [signatureSize]byte
InvocationScript []byte
stateRootEnabled bool
}
preparationCompact struct {
@ -47,7 +52,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
var hasReq = r.ReadBool()
if hasReq {
m.prepareRequest = new(message)
m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled}
m.prepareRequest.DecodeBinary(r)
if r.Err == nil && m.prepareRequest.Type != prepareRequestType {
r.Err = errors.New("recovery message PrepareRequest has wrong type")
@ -68,7 +73,16 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
}
r.ReadArray(&m.preparationPayloads)
r.ReadArray(&m.commitPayloads)
lu := r.ReadVarUint()
if lu > state.MaxValidatorsVoted {
r.Err = errors.New("too many preparation payloads")
return
}
m.commitPayloads = make([]*commitCompact, lu)
for i := uint64(0); i < lu; i++ {
m.commitPayloads[i] = &commitCompact{stateRootEnabled: m.stateRootEnabled}
m.commitPayloads[i].DecodeBinary(r)
}
}
// EncodeBinary implements io.Serializable interface.
@ -113,7 +127,9 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) {
p.ViewNumber = r.ReadB()
p.ValidatorIndex = r.ReadU16LE()
r.ReadBytes(p.Signature[:])
r.ReadBytes(p.StateSignature[:])
if p.stateRootEnabled {
r.ReadBytes(p.StateSignature[:])
}
p.InvocationScript = r.ReadVarBytes(1024)
}
@ -122,7 +138,9 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) {
w.WriteB(p.ViewNumber)
w.WriteU16LE(p.ValidatorIndex)
w.WriteBytes(p.Signature[:])
w.WriteBytes(p.StateSignature[:])
if p.stateRootEnabled {
w.WriteBytes(p.StateSignature[:])
}
w.WriteVarBytes(p.InvocationScript)
}
@ -146,6 +164,8 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
Type: prepareRequestType,
ViewNumber: p.ViewNumber(),
payload: p.GetPrepareRequest().(*prepareRequest),
stateRootEnabled: m.stateRootEnabled,
}
h := p.Hash()
m.preparationHash = &h
@ -172,6 +192,7 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
})
case payload.CommitType:
m.commitPayloads = append(m.commitPayloads, &commitCompact{
stateRootEnabled: m.stateRootEnabled,
ValidatorIndex: p.ValidatorIndex(),
ViewNumber: p.ViewNumber(),
Signature: p.GetCommit().(*commit).signature,
@ -254,7 +275,12 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr
ps := make([]payload.ConsensusPayload, len(m.commitPayloads))
for i, c := range m.commitPayloads {
cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature, stateSig: c.StateSignature})
cc := fromPayload(commitType, p.(*Payload), &commit{
signature: c.Signature,
stateSig: c.StateSignature,
stateRootEnabled: m.stateRootEnabled,
})
cc.SetValidatorIndex(c.ValidatorIndex)
cc.Witness.InvocationScript = c.InvocationScript
cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators)
@ -294,6 +320,8 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload {
Type: t,
ViewNumber: recovery.message.ViewNumber,
payload: p,
stateRootEnabled: recovery.stateRootEnabled,
},
version: recovery.Version(),
prevHash: recovery.PrevHash(),

View file

@ -231,8 +231,10 @@ func (bc *Blockchain) init() error {
}
bc.blockHeight = bHeight
bc.persistedHeight = bHeight
if err = bc.dao.InitMPT(bHeight); err != nil {
return errors.Wrapf(err, "can't init MPT at height %d", bHeight)
if bc.config.EnableStateRoot {
if err = bc.dao.InitMPT(bHeight); err != nil {
return errors.Wrapf(err, "can't init MPT at height %d", bHeight)
}
}
hashes, err := bc.dao.GetHeaderHashes()
@ -558,12 +560,18 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 {
// GetStateProof returns proof of having key in the MPT with the specified root.
func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) {
if !bc.config.EnableStateRoot {
return nil, errors.New("state root feature is not enabled")
}
tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store))
return tr.GetProof(key)
}
// GetStateRoot returns state root for a given height.
func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
if !bc.config.EnableStateRoot {
return nil, errors.New("state root feature is not enabled")
}
return bc.dao.GetStateRoot(height)
}
@ -835,24 +843,26 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
}
}
root := bc.dao.MPT.StateRoot()
var prevHash util.Uint256
if block.Index > 0 {
prev, err := bc.dao.GetStateRoot(block.Index - 1)
if err != nil {
return errors.WithMessagef(err, "can't get previous state root")
if bc.config.EnableStateRoot {
root := bc.dao.MPT.StateRoot()
var prevHash util.Uint256
if block.Index > 0 {
prev, err := bc.dao.GetStateRoot(block.Index - 1)
if err != nil {
return errors.WithMessagef(err, "can't get previous state root")
}
prevHash = hash.DoubleSha256(prev.GetSignedPart())
}
err := bc.AddStateRoot(&state.MPTRoot{
MPTRootBase: state.MPTRootBase{
Index: block.Index,
PrevHash: prevHash,
Root: root,
},
})
if err != nil {
return err
}
prevHash = hash.DoubleSha256(prev.GetSignedPart())
}
err := bc.AddStateRoot(&state.MPTRoot{
MPTRootBase: state.MPTRootBase{
Index: block.Index,
PrevHash: prevHash,
Root: root,
},
})
if err != nil {
return err
}
if bc.config.SaveStorageBatch {
@ -860,12 +870,14 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
}
bc.lock.Lock()
_, err = cache.Persist()
_, err := cache.Persist()
if err != nil {
bc.lock.Unlock()
return err
}
bc.dao.MPT.Flush()
if bc.config.EnableStateRoot {
bc.dao.MPT.Flush()
}
// Every persist cycle we also compact our in-memory MPT.
persistedHeight := atomic.LoadUint32(&bc.persistedHeight)
if persistedHeight == block.Index-1 {
@ -1784,6 +1796,10 @@ func (bc *Blockchain) StateHeight() uint32 {
// AddStateRoot add new (possibly unverified) state root to the blockchain.
func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error {
if !bc.config.EnableStateRoot {
bc.log.Warn("state root is being added but not enabled in config")
return nil
}
our, err := bc.GetStateRoot(r.Index)
if err == nil {
if our.Flag == state.Verified {

View file

@ -602,6 +602,9 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error {
// handleGetRootsCmd processees `getroots` request.
func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
count := gr.Count
if count > payload.MaxStateRootsAllowed {
count = payload.MaxStateRootsAllowed
@ -621,6 +624,9 @@ func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error {
// handleStateRootsCmd processees `roots` request.
func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
h := s.chain.StateHeight()
for i := range rs.Roots {
if rs.Roots[i].Index <= h {
@ -652,6 +658,9 @@ func (s *Server) requestStateRoot(p Peer) error {
// handleStateRootCmd processees `stateroot` request.
func (s *Server) handleStateRootCmd(r *state.MPTRoot) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
// we ignore error, because there is nothing wrong if we already have this state root
err := s.chain.AddStateRoot(r)
if err == nil && !s.stateCache.Has(r.Hash()) {

View file

@ -251,7 +251,7 @@ func (p *TCPPeer) StartProtocol() {
if p.LastBlockIndex() > p.server.chain.BlockHeight() {
err = p.server.requestBlocks(p)
}
if err == nil {
if err == nil && p.server.chain.GetConfig().EnableStateRoot {
err = p.server.requestStateRoot(p)
}
if err == nil {