*: add config flag for enabling state root feature

This commit is contained in:
Evgenii Stratonikov 2020-06-23 09:41:53 +03:00
parent c06b3b669d
commit d128b55dbf
11 changed files with 183 additions and 66 deletions

View file

@ -2,6 +2,7 @@ ProtocolConfiguration:
Magic: 56753 Magic: 56753
AddressVersion: 23 AddressVersion: 23
SecondsPerBlock: 15 SecondsPerBlock: 15
EnableStateRoot: true
LowPriorityThreshold: 0.000 LowPriorityThreshold: 0.000
MemPoolSize: 50000 MemPoolSize: 50000
StandbyValidators: StandbyValidators:

View file

@ -20,6 +20,8 @@ const (
type ( type (
ProtocolConfiguration struct { ProtocolConfiguration struct {
AddressVersion byte `yaml:"AddressVersion"` AddressVersion byte `yaml:"AddressVersion"`
// EnableStateRoot specifies if exchange of state roots should be enabled.
EnableStateRoot bool `yaml:"EnableStateRoot"`
// FeePerExtraByte sets the expected per-byte fee for // FeePerExtraByte sets the expected per-byte fee for
// transactions exceeding the MaxFreeTransactionSize. // transactions exceeding the MaxFreeTransactionSize.
FeePerExtraByte float64 `yaml:"FeePerExtraByte"` FeePerExtraByte float64 `yaml:"FeePerExtraByte"`

View file

@ -9,6 +9,8 @@ import (
type commit struct { type commit struct {
signature [signatureSize]byte signature [signatureSize]byte
stateSig [signatureSize]byte stateSig [signatureSize]byte
stateRootEnabled bool
} }
// signatureSize is an rfc6989 signature size in bytes // signatureSize is an rfc6989 signature size in bytes
@ -20,13 +22,17 @@ var _ payload.Commit = (*commit)(nil)
// EncodeBinary implements io.Serializable interface. // EncodeBinary implements io.Serializable interface.
func (c *commit) EncodeBinary(w *io.BinWriter) { func (c *commit) EncodeBinary(w *io.BinWriter) {
w.WriteBytes(c.signature[:]) w.WriteBytes(c.signature[:])
if c.stateRootEnabled {
w.WriteBytes(c.stateSig[:]) w.WriteBytes(c.stateSig[:])
}
} }
// DecodeBinary implements io.Serializable interface. // DecodeBinary implements io.Serializable interface.
func (c *commit) DecodeBinary(r *io.BinReader) { func (c *commit) DecodeBinary(r *io.BinReader) {
r.ReadBytes(c.signature[:]) r.ReadBytes(c.signature[:])
if c.stateRootEnabled {
r.ReadBytes(c.stateSig[:]) r.ReadBytes(c.stateSig[:])
}
} }
// Signature implements payload.Commit interface. // Signature implements payload.Commit interface.

View file

@ -216,35 +216,49 @@ func (s *service) eventLoop() {
func (s *service) newPayload() payload.ConsensusPayload { func (s *service) newPayload() payload.ConsensusPayload {
return &Payload{ return &Payload{
message: new(message), message: &message{
stateRootEnabled: s.stateRootEnabled(),
},
} }
} }
// stateRootEnabled checks if state root feature is enabled on current height.
// It should be called only from dbft callbacks and is not protected by any mutex.
func (s *service) stateRootEnabled() bool {
return s.Chain.GetConfig().EnableStateRoot
}
func (s *service) newPrepareRequest() payload.PrepareRequest { func (s *service) newPrepareRequest() payload.PrepareRequest {
sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight()) if !s.stateRootEnabled() {
if err != nil {
return new(prepareRequest) return new(prepareRequest)
} }
sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight())
if err == nil {
return &prepareRequest{ return &prepareRequest{
stateRootEnabled: true,
proposalStateRoot: sr.MPTRootBase, proposalStateRoot: sr.MPTRootBase,
} }
}
return &prepareRequest{stateRootEnabled: true}
} }
func (s *service) newCommit() payload.Commit { func (s *service) newCommit() payload.Commit {
if !s.stateRootEnabled() {
return new(commit)
}
c := &commit{stateRootEnabled: true}
for _, p := range s.dbft.Context.PreparationPayloads { for _, p := range s.dbft.Context.PreparationPayloads {
if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType { if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType {
pr := p.GetPrepareRequest().(*prepareRequest) pr := p.GetPrepareRequest().(*prepareRequest)
data := pr.proposalStateRoot.GetSignedPart() data := pr.proposalStateRoot.GetSignedPart()
sign, err := s.dbft.Priv.Sign(data) sign, err := s.dbft.Priv.Sign(data)
if err == nil { if err == nil {
var c commit
copy(c.stateSig[:], sign) copy(c.stateSig[:], sign)
return &c
} }
break break
} }
} }
return new(commit) return c
} }
func (s *service) validatePayload(p *Payload) bool { func (s *service) validatePayload(p *Payload) bool {
@ -299,8 +313,8 @@ func (s *service) OnPayload(cp *Payload) {
// decode payload data into message // decode payload data into message
if cp.message == nil { if cp.message == nil {
if err := cp.decodeData(); err != nil { if err := cp.decodeData(s.stateRootEnabled()); err != nil {
log.Debug("can't decode payload data") log.Debug("can't decode payload data", zap.Error(err))
return return
} }
} }
@ -378,6 +392,9 @@ func (s *service) verifyBlock(b block.Block) bool {
} }
func (s *service) verifyRequest(p payload.ConsensusPayload) error { func (s *service) verifyRequest(p payload.ConsensusPayload) error {
if !s.stateRootEnabled() {
return nil
}
r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1)
if err != nil { if err != nil {
return fmt.Errorf("can't get local state root: %v", err) return fmt.Errorf("can't get local state root: %v", err)

View file

@ -22,6 +22,8 @@ type (
Type messageType Type messageType
ViewNumber byte ViewNumber byte
stateRootEnabled bool
payload io.Serializable payload io.Serializable
} }
@ -283,15 +285,21 @@ func (m *message) DecodeBinary(r *io.BinReader) {
cv.newViewNumber = m.ViewNumber + 1 cv.newViewNumber = m.ViewNumber + 1
m.payload = cv m.payload = cv
case prepareRequestType: case prepareRequestType:
m.payload = new(prepareRequest) m.payload = &prepareRequest{
stateRootEnabled: m.stateRootEnabled,
}
case prepareResponseType: case prepareResponseType:
m.payload = new(prepareResponse) m.payload = new(prepareResponse)
case commitType: case commitType:
m.payload = new(commit) m.payload = &commit{
stateRootEnabled: m.stateRootEnabled,
}
case recoveryRequestType: case recoveryRequestType:
m.payload = new(recoveryRequest) m.payload = new(recoveryRequest)
case recoveryMessageType: case recoveryMessageType:
m.payload = new(recoveryMessage) m.payload = &recoveryMessage{
stateRootEnabled: m.stateRootEnabled,
}
default: default:
r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type)) r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type))
return return
@ -320,8 +328,8 @@ func (t messageType) String() string {
} }
// decodeData decodes data of payload into it's message. // decodeData decodes data of payload into it's message.
func (p *Payload) decodeData() error { func (p *Payload) decodeData(stateRootEnabled bool) error {
m := new(message) m := &message{stateRootEnabled: stateRootEnabled}
br := io.NewBinReaderFromBuf(p.data) br := io.NewBinReaderFromBuf(p.data)
m.DecodeBinary(br) m.DecodeBinary(br)
if br.Err != nil { if br.Err != nil {

View file

@ -94,13 +94,13 @@ func TestConsensusPayload_Serializable(t *testing.T) {
// message is nil after decoding as we didn't yet call decodeData // message is nil after decoding as we didn't yet call decodeData
require.Nil(t, actual.message) require.Nil(t, actual.message)
// message should now be decoded from actual.data byte array // message should now be decoded from actual.data byte array
assert.NoError(t, actual.decodeData()) assert.NoError(t, actual.decodeData(false))
require.Equal(t, p, actual) require.Equal(t, p, actual)
data = p.MarshalUnsigned() data = p.MarshalUnsigned()
pu := new(Payload) pu := new(Payload)
require.NoError(t, pu.UnmarshalUnsigned(data)) require.NoError(t, pu.UnmarshalUnsigned(data))
assert.NoError(t, pu.decodeData()) assert.NoError(t, pu.decodeData(false))
p.Witness = transaction.Witness{} p.Witness = transaction.Witness{}
require.Equal(t, p, pu) require.Equal(t, p, pu)
@ -144,14 +144,14 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
p := new(Payload) p := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, p)) require.NoError(t, testserdes.DecodeBinary(buf, p))
// decode `data` into `message` // decode `data` into `message`
assert.NoError(t, p.decodeData()) assert.NoError(t, p.decodeData(false))
require.Equal(t, expected, p) require.Equal(t, expected, p)
// invalid type // invalid type
buf[typeIndex] = 0xFF buf[typeIndex] = 0xFF
actual := new(Payload) actual := new(Payload)
require.NoError(t, testserdes.DecodeBinary(buf, actual)) require.NoError(t, testserdes.DecodeBinary(buf, actual))
require.Error(t, actual.decodeData()) require.Error(t, actual.decodeData(false))
// invalid format // invalid format
buf[delimeterIndex] = 0 buf[delimeterIndex] = 0
@ -165,9 +165,16 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) {
require.Error(t, testserdes.DecodeBinary(buf, new(Payload))) require.Error(t, testserdes.DecodeBinary(buf, new(Payload)))
} }
func testEncodeDecode(srEnabled bool, mt messageType, actual io.Serializable) func(t *testing.T) {
return func(t *testing.T) {
expected := randomMessage(t, mt, srEnabled)
testserdes.EncodeDecodeBinary(t, expected, actual)
}
}
func TestCommit_Serializable(t *testing.T) { func TestCommit_Serializable(t *testing.T) {
c := randomMessage(t, commitType) t.Run("WithStateRoot", testEncodeDecode(true, commitType, &commit{stateRootEnabled: true}))
testserdes.EncodeDecodeBinary(t, c, new(commit)) t.Run("NoStateRoot", testEncodeDecode(false, commitType, &commit{stateRootEnabled: false}))
} }
func TestPrepareResponse_Serializable(t *testing.T) { func TestPrepareResponse_Serializable(t *testing.T) {
@ -176,8 +183,8 @@ func TestPrepareResponse_Serializable(t *testing.T) {
} }
func TestPrepareRequest_Serializable(t *testing.T) { func TestPrepareRequest_Serializable(t *testing.T) {
req := randomMessage(t, prepareRequestType) t.Run("WithStateRoot", testEncodeDecode(true, prepareRequestType, &prepareRequest{stateRootEnabled: true}))
testserdes.EncodeDecodeBinary(t, req, new(prepareRequest)) t.Run("NoStateRoot", testEncodeDecode(false, prepareRequestType, &prepareRequest{stateRootEnabled: false}))
} }
func TestRecoveryRequest_Serializable(t *testing.T) { func TestRecoveryRequest_Serializable(t *testing.T) {
@ -186,8 +193,8 @@ func TestRecoveryRequest_Serializable(t *testing.T) {
} }
func TestRecoveryMessage_Serializable(t *testing.T) { func TestRecoveryMessage_Serializable(t *testing.T) {
msg := randomMessage(t, recoveryMessageType) t.Run("WithStateRoot", testEncodeDecode(true, recoveryMessageType, &recoveryMessage{stateRootEnabled: true}))
testserdes.EncodeDecodeBinary(t, msg, new(recoveryMessage)) t.Run("NoStateRoot", testEncodeDecode(false, recoveryMessageType, &recoveryMessage{stateRootEnabled: false}))
} }
func randomPayload(t *testing.T, mt messageType) *Payload { func randomPayload(t *testing.T, mt messageType) *Payload {
@ -215,32 +222,35 @@ func randomPayload(t *testing.T, mt messageType) *Payload {
return p return p
} }
func randomMessage(t *testing.T, mt messageType) io.Serializable { func randomMessage(t *testing.T, mt messageType, srEnabled ...bool) io.Serializable {
switch mt { switch mt {
case changeViewType: case changeViewType:
return &changeView{ return &changeView{
timestamp: rand.Uint32(), timestamp: rand.Uint32(),
} }
case prepareRequestType: case prepareRequestType:
return randomPrepareRequest(t) return randomPrepareRequest(t, srEnabled...)
case prepareResponseType: case prepareResponseType:
return &prepareResponse{preparationHash: random.Uint256()} return &prepareResponse{preparationHash: random.Uint256()}
case commitType: case commitType:
var c commit var c commit
random.Fill(c.signature[:]) random.Fill(c.signature[:])
if len(srEnabled) > 0 && srEnabled[0] {
c.stateRootEnabled = true
random.Fill(c.stateSig[:]) random.Fill(c.stateSig[:])
}
return &c return &c
case recoveryRequestType: case recoveryRequestType:
return &recoveryRequest{timestamp: rand.Uint32()} return &recoveryRequest{timestamp: rand.Uint32()}
case recoveryMessageType: case recoveryMessageType:
return randomRecoveryMessage(t) return randomRecoveryMessage(t, srEnabled...)
default: default:
require.Fail(t, "invalid type") require.Fail(t, "invalid type")
return nil return nil
} }
} }
func randomPrepareRequest(t *testing.T) *prepareRequest { func randomPrepareRequest(t *testing.T, srEnabled ...bool) *prepareRequest {
const txCount = 3 const txCount = 3
req := &prepareRequest{ req := &prepareRequest{
@ -256,15 +266,22 @@ func randomPrepareRequest(t *testing.T) *prepareRequest {
} }
req.nextConsensus = random.Uint160() req.nextConsensus = random.Uint160()
if len(srEnabled) > 0 && srEnabled[0] {
req.stateRootEnabled = true
req.proposalStateRoot.Index = rand.Uint32()
req.proposalStateRoot.PrevHash = random.Uint256()
req.proposalStateRoot.Root = random.Uint256()
}
return req return req
} }
func randomRecoveryMessage(t *testing.T) *recoveryMessage { func randomRecoveryMessage(t *testing.T, srEnabled ...bool) *recoveryMessage {
result := randomMessage(t, prepareRequestType) result := randomMessage(t, prepareRequestType, srEnabled...)
require.IsType(t, (*prepareRequest)(nil), result) require.IsType(t, (*prepareRequest)(nil), result)
prepReq := result.(*prepareRequest) prepReq := result.(*prepareRequest)
return &recoveryMessage{ rec := &recoveryMessage{
preparationPayloads: []*preparationCompact{ preparationPayloads: []*preparationCompact{
{ {
ValidatorIndex: 1, ValidatorIndex: 1,
@ -276,14 +293,12 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage {
ViewNumber: 0, ViewNumber: 0,
ValidatorIndex: 1, ValidatorIndex: 1,
Signature: [64]byte{1, 2, 3}, Signature: [64]byte{1, 2, 3},
StateSignature: [64]byte{4, 5, 6},
InvocationScript: random.Bytes(20), InvocationScript: random.Bytes(20),
}, },
{ {
ViewNumber: 0, ViewNumber: 0,
ValidatorIndex: 2, ValidatorIndex: 2,
Signature: [64]byte{11, 3, 4, 98}, Signature: [64]byte{11, 3, 4, 98},
StateSignature: [64]byte{4, 8, 15, 16, 23, 42},
InvocationScript: random.Bytes(10), InvocationScript: random.Bytes(10),
}, },
}, },
@ -300,6 +315,15 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage {
payload: prepReq, payload: prepReq,
}, },
} }
if len(srEnabled) > 0 && srEnabled[0] {
rec.stateRootEnabled = true
rec.prepareRequest.stateRootEnabled = true
for _, c := range rec.commitPayloads {
c.stateRootEnabled = true
random.Fill(c.StateSignature[:])
}
}
return rec
} }
func TestPayload_Sign(t *testing.T) { func TestPayload_Sign(t *testing.T) {

View file

@ -16,6 +16,8 @@ type prepareRequest struct {
minerTx transaction.Transaction minerTx transaction.Transaction
nextConsensus util.Uint160 nextConsensus util.Uint160
proposalStateRoot state.MPTRootBase proposalStateRoot state.MPTRootBase
stateRootEnabled bool
} }
var _ payload.PrepareRequest = (*prepareRequest)(nil) var _ payload.PrepareRequest = (*prepareRequest)(nil)
@ -27,7 +29,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) {
w.WriteBytes(p.nextConsensus[:]) w.WriteBytes(p.nextConsensus[:])
w.WriteArray(p.transactionHashes) w.WriteArray(p.transactionHashes)
p.minerTx.EncodeBinary(w) p.minerTx.EncodeBinary(w)
if p.stateRootEnabled {
p.proposalStateRoot.EncodeBinary(w) p.proposalStateRoot.EncodeBinary(w)
}
} }
// DecodeBinary implements io.Serializable interface. // DecodeBinary implements io.Serializable interface.
@ -37,7 +41,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) {
r.ReadBytes(p.nextConsensus[:]) r.ReadBytes(p.nextConsensus[:])
r.ReadArray(&p.transactionHashes) r.ReadArray(&p.transactionHashes)
p.minerTx.DecodeBinary(r) p.minerTx.DecodeBinary(r)
if p.stateRootEnabled {
p.proposalStateRoot.DecodeBinary(r) p.proposalStateRoot.DecodeBinary(r)
}
} }
// Timestamp implements payload.PrepareRequest interface. // Timestamp implements payload.PrepareRequest interface.

View file

@ -3,6 +3,7 @@ package consensus
import ( import (
"github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/crypto"
"github.com/nspcc-dev/dbft/payload" "github.com/nspcc-dev/dbft/payload"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/util"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -16,6 +17,8 @@ type (
commitPayloads []*commitCompact commitPayloads []*commitCompact
changeViewPayloads []*changeViewCompact changeViewPayloads []*changeViewCompact
prepareRequest *message prepareRequest *message
stateRootEnabled bool
} }
changeViewCompact struct { changeViewCompact struct {
@ -31,6 +34,8 @@ type (
Signature [signatureSize]byte Signature [signatureSize]byte
StateSignature [signatureSize]byte StateSignature [signatureSize]byte
InvocationScript []byte InvocationScript []byte
stateRootEnabled bool
} }
preparationCompact struct { preparationCompact struct {
@ -47,7 +52,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
var hasReq = r.ReadBool() var hasReq = r.ReadBool()
if hasReq { if hasReq {
m.prepareRequest = new(message) m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled}
m.prepareRequest.DecodeBinary(r) m.prepareRequest.DecodeBinary(r)
if r.Err == nil && m.prepareRequest.Type != prepareRequestType { if r.Err == nil && m.prepareRequest.Type != prepareRequestType {
r.Err = errors.New("recovery message PrepareRequest has wrong type") r.Err = errors.New("recovery message PrepareRequest has wrong type")
@ -68,7 +73,16 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) {
} }
r.ReadArray(&m.preparationPayloads) r.ReadArray(&m.preparationPayloads)
r.ReadArray(&m.commitPayloads) lu := r.ReadVarUint()
if lu > state.MaxValidatorsVoted {
r.Err = errors.New("too many preparation payloads")
return
}
m.commitPayloads = make([]*commitCompact, lu)
for i := uint64(0); i < lu; i++ {
m.commitPayloads[i] = &commitCompact{stateRootEnabled: m.stateRootEnabled}
m.commitPayloads[i].DecodeBinary(r)
}
} }
// EncodeBinary implements io.Serializable interface. // EncodeBinary implements io.Serializable interface.
@ -113,7 +127,9 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) {
p.ViewNumber = r.ReadB() p.ViewNumber = r.ReadB()
p.ValidatorIndex = r.ReadU16LE() p.ValidatorIndex = r.ReadU16LE()
r.ReadBytes(p.Signature[:]) r.ReadBytes(p.Signature[:])
if p.stateRootEnabled {
r.ReadBytes(p.StateSignature[:]) r.ReadBytes(p.StateSignature[:])
}
p.InvocationScript = r.ReadVarBytes(1024) p.InvocationScript = r.ReadVarBytes(1024)
} }
@ -122,7 +138,9 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) {
w.WriteB(p.ViewNumber) w.WriteB(p.ViewNumber)
w.WriteU16LE(p.ValidatorIndex) w.WriteU16LE(p.ValidatorIndex)
w.WriteBytes(p.Signature[:]) w.WriteBytes(p.Signature[:])
if p.stateRootEnabled {
w.WriteBytes(p.StateSignature[:]) w.WriteBytes(p.StateSignature[:])
}
w.WriteVarBytes(p.InvocationScript) w.WriteVarBytes(p.InvocationScript)
} }
@ -146,6 +164,8 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
Type: prepareRequestType, Type: prepareRequestType,
ViewNumber: p.ViewNumber(), ViewNumber: p.ViewNumber(),
payload: p.GetPrepareRequest().(*prepareRequest), payload: p.GetPrepareRequest().(*prepareRequest),
stateRootEnabled: m.stateRootEnabled,
} }
h := p.Hash() h := p.Hash()
m.preparationHash = &h m.preparationHash = &h
@ -172,6 +192,7 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) {
}) })
case payload.CommitType: case payload.CommitType:
m.commitPayloads = append(m.commitPayloads, &commitCompact{ m.commitPayloads = append(m.commitPayloads, &commitCompact{
stateRootEnabled: m.stateRootEnabled,
ValidatorIndex: p.ValidatorIndex(), ValidatorIndex: p.ValidatorIndex(),
ViewNumber: p.ViewNumber(), ViewNumber: p.ViewNumber(),
Signature: p.GetCommit().(*commit).signature, Signature: p.GetCommit().(*commit).signature,
@ -254,7 +275,12 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr
ps := make([]payload.ConsensusPayload, len(m.commitPayloads)) ps := make([]payload.ConsensusPayload, len(m.commitPayloads))
for i, c := range m.commitPayloads { for i, c := range m.commitPayloads {
cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature, stateSig: c.StateSignature}) cc := fromPayload(commitType, p.(*Payload), &commit{
signature: c.Signature,
stateSig: c.StateSignature,
stateRootEnabled: m.stateRootEnabled,
})
cc.SetValidatorIndex(c.ValidatorIndex) cc.SetValidatorIndex(c.ValidatorIndex)
cc.Witness.InvocationScript = c.InvocationScript cc.Witness.InvocationScript = c.InvocationScript
cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators) cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators)
@ -294,6 +320,8 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload {
Type: t, Type: t,
ViewNumber: recovery.message.ViewNumber, ViewNumber: recovery.message.ViewNumber,
payload: p, payload: p,
stateRootEnabled: recovery.stateRootEnabled,
}, },
version: recovery.Version(), version: recovery.Version(),
prevHash: recovery.PrevHash(), prevHash: recovery.PrevHash(),

View file

@ -231,9 +231,11 @@ func (bc *Blockchain) init() error {
} }
bc.blockHeight = bHeight bc.blockHeight = bHeight
bc.persistedHeight = bHeight bc.persistedHeight = bHeight
if bc.config.EnableStateRoot {
if err = bc.dao.InitMPT(bHeight); err != nil { if err = bc.dao.InitMPT(bHeight); err != nil {
return errors.Wrapf(err, "can't init MPT at height %d", bHeight) return errors.Wrapf(err, "can't init MPT at height %d", bHeight)
} }
}
hashes, err := bc.dao.GetHeaderHashes() hashes, err := bc.dao.GetHeaderHashes()
if err != nil { if err != nil {
@ -558,12 +560,18 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 {
// GetStateProof returns proof of having key in the MPT with the specified root. // GetStateProof returns proof of having key in the MPT with the specified root.
func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) { func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) {
if !bc.config.EnableStateRoot {
return nil, errors.New("state root feature is not enabled")
}
tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store)) tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store))
return tr.GetProof(key) return tr.GetProof(key)
} }
// GetStateRoot returns state root for a given height. // GetStateRoot returns state root for a given height.
func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) {
if !bc.config.EnableStateRoot {
return nil, errors.New("state root feature is not enabled")
}
return bc.dao.GetStateRoot(height) return bc.dao.GetStateRoot(height)
} }
@ -835,6 +843,7 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
} }
} }
if bc.config.EnableStateRoot {
root := bc.dao.MPT.StateRoot() root := bc.dao.MPT.StateRoot()
var prevHash util.Uint256 var prevHash util.Uint256
if block.Index > 0 { if block.Index > 0 {
@ -854,18 +863,21 @@ func (bc *Blockchain) storeBlock(block *block.Block) error {
if err != nil { if err != nil {
return err return err
} }
}
if bc.config.SaveStorageBatch { if bc.config.SaveStorageBatch {
bc.lastBatch = cache.DAO.GetBatch() bc.lastBatch = cache.DAO.GetBatch()
} }
bc.lock.Lock() bc.lock.Lock()
_, err = cache.Persist() _, err := cache.Persist()
if err != nil { if err != nil {
bc.lock.Unlock() bc.lock.Unlock()
return err return err
} }
if bc.config.EnableStateRoot {
bc.dao.MPT.Flush() bc.dao.MPT.Flush()
}
// Every persist cycle we also compact our in-memory MPT. // Every persist cycle we also compact our in-memory MPT.
persistedHeight := atomic.LoadUint32(&bc.persistedHeight) persistedHeight := atomic.LoadUint32(&bc.persistedHeight)
if persistedHeight == block.Index-1 { if persistedHeight == block.Index-1 {
@ -1784,6 +1796,10 @@ func (bc *Blockchain) StateHeight() uint32 {
// AddStateRoot add new (possibly unverified) state root to the blockchain. // AddStateRoot add new (possibly unverified) state root to the blockchain.
func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error {
if !bc.config.EnableStateRoot {
bc.log.Warn("state root is being added but not enabled in config")
return nil
}
our, err := bc.GetStateRoot(r.Index) our, err := bc.GetStateRoot(r.Index)
if err == nil { if err == nil {
if our.Flag == state.Verified { if our.Flag == state.Verified {

View file

@ -602,6 +602,9 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error {
// handleGetRootsCmd processees `getroots` request. // handleGetRootsCmd processees `getroots` request.
func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
count := gr.Count count := gr.Count
if count > payload.MaxStateRootsAllowed { if count > payload.MaxStateRootsAllowed {
count = payload.MaxStateRootsAllowed count = payload.MaxStateRootsAllowed
@ -621,6 +624,9 @@ func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error {
// handleStateRootsCmd processees `roots` request. // handleStateRootsCmd processees `roots` request.
func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error { func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
h := s.chain.StateHeight() h := s.chain.StateHeight()
for i := range rs.Roots { for i := range rs.Roots {
if rs.Roots[i].Index <= h { if rs.Roots[i].Index <= h {
@ -652,6 +658,9 @@ func (s *Server) requestStateRoot(p Peer) error {
// handleStateRootCmd processees `stateroot` request. // handleStateRootCmd processees `stateroot` request.
func (s *Server) handleStateRootCmd(r *state.MPTRoot) error { func (s *Server) handleStateRootCmd(r *state.MPTRoot) error {
if !s.chain.GetConfig().EnableStateRoot {
return nil
}
// we ignore error, because there is nothing wrong if we already have this state root // we ignore error, because there is nothing wrong if we already have this state root
err := s.chain.AddStateRoot(r) err := s.chain.AddStateRoot(r)
if err == nil && !s.stateCache.Has(r.Hash()) { if err == nil && !s.stateCache.Has(r.Hash()) {

View file

@ -251,7 +251,7 @@ func (p *TCPPeer) StartProtocol() {
if p.LastBlockIndex() > p.server.chain.BlockHeight() { if p.LastBlockIndex() > p.server.chain.BlockHeight() {
err = p.server.requestBlocks(p) err = p.server.requestBlocks(p)
} }
if err == nil { if err == nil && p.server.chain.GetConfig().EnableStateRoot {
err = p.server.requestStateRoot(p) err = p.server.requestStateRoot(p)
} }
if err == nil { if err == nil {