diff --git a/config/protocol.testnet.yml b/config/protocol.testnet.yml index 711485d26..8ccb17005 100644 --- a/config/protocol.testnet.yml +++ b/config/protocol.testnet.yml @@ -2,6 +2,8 @@ ProtocolConfiguration: Magic: 1953787457 AddressVersion: 23 SecondsPerBlock: 15 + EnableStateRoot: true + StateRootEnableIndex: 4380100 LowPriorityThreshold: 0.000 MemPoolSize: 50000 StandbyValidators: diff --git a/config/protocol.unit_testnet.yml b/config/protocol.unit_testnet.yml index c21e1c3f0..b9ca4a761 100644 --- a/config/protocol.unit_testnet.yml +++ b/config/protocol.unit_testnet.yml @@ -2,6 +2,7 @@ ProtocolConfiguration: Magic: 56753 AddressVersion: 23 SecondsPerBlock: 15 + EnableStateRoot: true LowPriorityThreshold: 0.000 MemPoolSize: 50000 StandbyValidators: diff --git a/pkg/config/protocol_config.go b/pkg/config/protocol_config.go index c84de79a2..dff592038 100644 --- a/pkg/config/protocol_config.go +++ b/pkg/config/protocol_config.go @@ -20,6 +20,8 @@ const ( type ( ProtocolConfiguration struct { AddressVersion byte `yaml:"AddressVersion"` + // EnableStateRoot specifies if exchange of state roots should be enabled. + EnableStateRoot bool `yaml:"EnableStateRoot"` // FeePerExtraByte sets the expected per-byte fee for // transactions exceeding the MaxFreeTransactionSize. FeePerExtraByte float64 `yaml:"FeePerExtraByte"` @@ -34,11 +36,13 @@ type ( MaxFreeTransactionsPerBlock int `yaml:"MaxFreeTransactionsPerBlock"` MemPoolSize int `yaml:"MemPoolSize"` // SaveStorageBatch enables storage batch saving before every persist. - SaveStorageBatch bool `yaml:"SaveStorageBatch"` - SecondsPerBlock int `yaml:"SecondsPerBlock"` - SeedList []string `yaml:"SeedList"` - StandbyValidators []string `yaml:"StandbyValidators"` - SystemFee SystemFee `yaml:"SystemFee"` + SaveStorageBatch bool `yaml:"SaveStorageBatch"` + SecondsPerBlock int `yaml:"SecondsPerBlock"` + SeedList []string `yaml:"SeedList"` + StandbyValidators []string `yaml:"StandbyValidators"` + // StateRootEnableIndex specifies starting height for state root calculations and exchange. + StateRootEnableIndex uint32 `yaml:"StateRootEnableIndex"` + SystemFee SystemFee `yaml:"SystemFee"` // Whether to verify received blocks. VerifyBlocks bool `yaml:"VerifyBlocks"` // Whether to verify transactions in received blocks. diff --git a/pkg/consensus/commit.go b/pkg/consensus/commit.go index 372448576..a000b7abf 100644 --- a/pkg/consensus/commit.go +++ b/pkg/consensus/commit.go @@ -9,6 +9,8 @@ import ( type commit struct { signature [signatureSize]byte stateSig [signatureSize]byte + + stateRootEnabled bool } // signatureSize is an rfc6989 signature size in bytes @@ -20,13 +22,17 @@ var _ payload.Commit = (*commit)(nil) // EncodeBinary implements io.Serializable interface. func (c *commit) EncodeBinary(w *io.BinWriter) { w.WriteBytes(c.signature[:]) - w.WriteBytes(c.stateSig[:]) + if c.stateRootEnabled { + w.WriteBytes(c.stateSig[:]) + } } // DecodeBinary implements io.Serializable interface. func (c *commit) DecodeBinary(r *io.BinReader) { r.ReadBytes(c.signature[:]) - r.ReadBytes(c.stateSig[:]) + if c.stateRootEnabled { + r.ReadBytes(c.stateSig[:]) + } } // Signature implements payload.Commit interface. diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index a53e0257d..d3558ebe4 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -138,7 +138,7 @@ func NewService(cfg Config) (Service, error) { dbft.WithGetValidators(srv.getValidators), dbft.WithGetConsensusAddress(srv.getConsensusAddress), - dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { p := new(Payload); p.message = &message{}; return p }), + dbft.WithNewConsensusPayload(srv.newPayload), dbft.WithNewPrepareRequest(srv.newPrepareRequest), dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }), dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }), @@ -214,31 +214,51 @@ func (s *service) eventLoop() { } } -func (s *service) newPrepareRequest() payload.PrepareRequest { - sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight()) - if err != nil { - return new(prepareRequest) - } - return &prepareRequest{ - proposalStateRoot: sr.MPTRootBase, +func (s *service) newPayload() payload.ConsensusPayload { + return &Payload{ + message: &message{ + stateRootEnabled: s.stateRootEnabled(), + }, } } +// stateRootEnabled checks if state root feature is enabled on current height. +// It should be called only from dbft callbacks and is not protected by any mutex. +func (s *service) stateRootEnabled() bool { + return s.Chain.GetConfig().EnableStateRoot +} + +func (s *service) newPrepareRequest() payload.PrepareRequest { + if !s.stateRootEnabled() { + return new(prepareRequest) + } + sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight()) + if err == nil { + return &prepareRequest{ + stateRootEnabled: true, + proposalStateRoot: sr.MPTRootBase, + } + } + return &prepareRequest{stateRootEnabled: true} +} + func (s *service) newCommit() payload.Commit { + if !s.stateRootEnabled() { + return new(commit) + } + c := &commit{stateRootEnabled: true} for _, p := range s.dbft.Context.PreparationPayloads { if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType { pr := p.GetPrepareRequest().(*prepareRequest) data := pr.proposalStateRoot.GetSignedPart() sign, err := s.dbft.Priv.Sign(data) if err == nil { - var c commit copy(c.stateSig[:], sign) - return &c } break } } - return new(commit) + return c } func (s *service) validatePayload(p *Payload) bool { @@ -293,8 +313,8 @@ func (s *service) OnPayload(cp *Payload) { // decode payload data into message if cp.message == nil { - if err := cp.decodeData(); err != nil { - log.Debug("can't decode payload data") + if err := cp.decodeData(s.stateRootEnabled()); err != nil { + log.Debug("can't decode payload data", zap.Error(err)) return } } @@ -372,6 +392,9 @@ func (s *service) verifyBlock(b block.Block) bool { } func (s *service) verifyRequest(p payload.ConsensusPayload) error { + if !s.stateRootEnabled() { + return nil + } r, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) if err != nil { return fmt.Errorf("can't get local state root: %v", err) diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 925125fc2..c74ac9515 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -22,6 +22,8 @@ type ( Type messageType ViewNumber byte + stateRootEnabled bool + payload io.Serializable } @@ -283,15 +285,21 @@ func (m *message) DecodeBinary(r *io.BinReader) { cv.newViewNumber = m.ViewNumber + 1 m.payload = cv case prepareRequestType: - m.payload = new(prepareRequest) + m.payload = &prepareRequest{ + stateRootEnabled: m.stateRootEnabled, + } case prepareResponseType: m.payload = new(prepareResponse) case commitType: - m.payload = new(commit) + m.payload = &commit{ + stateRootEnabled: m.stateRootEnabled, + } case recoveryRequestType: m.payload = new(recoveryRequest) case recoveryMessageType: - m.payload = new(recoveryMessage) + m.payload = &recoveryMessage{ + stateRootEnabled: m.stateRootEnabled, + } default: r.Err = errors.Errorf("invalid type: 0x%02x", byte(m.Type)) return @@ -319,9 +327,9 @@ func (t messageType) String() string { } } -// decode data of payload into it's message -func (p *Payload) decodeData() error { - m := new(message) +// decodeData decodes data of payload into it's message. +func (p *Payload) decodeData(stateRootEnabled bool) error { + m := &message{stateRootEnabled: stateRootEnabled} br := io.NewBinReaderFromBuf(p.data) m.DecodeBinary(br) if br.Err != nil { diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index 423b6eaf6..f060ede5d 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -94,13 +94,13 @@ func TestConsensusPayload_Serializable(t *testing.T) { // message is nil after decoding as we didn't yet call decodeData require.Nil(t, actual.message) // message should now be decoded from actual.data byte array - assert.NoError(t, actual.decodeData()) + assert.NoError(t, actual.decodeData(false)) require.Equal(t, p, actual) data = p.MarshalUnsigned() pu := new(Payload) require.NoError(t, pu.UnmarshalUnsigned(data)) - assert.NoError(t, pu.decodeData()) + assert.NoError(t, pu.decodeData(false)) p.Witness = transaction.Witness{} require.Equal(t, p, pu) @@ -144,14 +144,14 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { p := new(Payload) require.NoError(t, testserdes.DecodeBinary(buf, p)) // decode `data` into `message` - assert.NoError(t, p.decodeData()) + assert.NoError(t, p.decodeData(false)) require.Equal(t, expected, p) // invalid type buf[typeIndex] = 0xFF actual := new(Payload) require.NoError(t, testserdes.DecodeBinary(buf, actual)) - require.Error(t, actual.decodeData()) + require.Error(t, actual.decodeData(false)) // invalid format buf[delimeterIndex] = 0 @@ -165,9 +165,16 @@ func TestConsensusPayload_DecodeBinaryInvalid(t *testing.T) { require.Error(t, testserdes.DecodeBinary(buf, new(Payload))) } +func testEncodeDecode(srEnabled bool, mt messageType, actual io.Serializable) func(t *testing.T) { + return func(t *testing.T) { + expected := randomMessage(t, mt, srEnabled) + testserdes.EncodeDecodeBinary(t, expected, actual) + } +} + func TestCommit_Serializable(t *testing.T) { - c := randomMessage(t, commitType) - testserdes.EncodeDecodeBinary(t, c, new(commit)) + t.Run("WithStateRoot", testEncodeDecode(true, commitType, &commit{stateRootEnabled: true})) + t.Run("NoStateRoot", testEncodeDecode(false, commitType, &commit{stateRootEnabled: false})) } func TestPrepareResponse_Serializable(t *testing.T) { @@ -176,8 +183,8 @@ func TestPrepareResponse_Serializable(t *testing.T) { } func TestPrepareRequest_Serializable(t *testing.T) { - req := randomMessage(t, prepareRequestType) - testserdes.EncodeDecodeBinary(t, req, new(prepareRequest)) + t.Run("WithStateRoot", testEncodeDecode(true, prepareRequestType, &prepareRequest{stateRootEnabled: true})) + t.Run("NoStateRoot", testEncodeDecode(false, prepareRequestType, &prepareRequest{stateRootEnabled: false})) } func TestRecoveryRequest_Serializable(t *testing.T) { @@ -186,8 +193,8 @@ func TestRecoveryRequest_Serializable(t *testing.T) { } func TestRecoveryMessage_Serializable(t *testing.T) { - msg := randomMessage(t, recoveryMessageType) - testserdes.EncodeDecodeBinary(t, msg, new(recoveryMessage)) + t.Run("WithStateRoot", testEncodeDecode(true, recoveryMessageType, &recoveryMessage{stateRootEnabled: true})) + t.Run("NoStateRoot", testEncodeDecode(false, recoveryMessageType, &recoveryMessage{stateRootEnabled: false})) } func randomPayload(t *testing.T, mt messageType) *Payload { @@ -215,32 +222,35 @@ func randomPayload(t *testing.T, mt messageType) *Payload { return p } -func randomMessage(t *testing.T, mt messageType) io.Serializable { +func randomMessage(t *testing.T, mt messageType, srEnabled ...bool) io.Serializable { switch mt { case changeViewType: return &changeView{ timestamp: rand.Uint32(), } case prepareRequestType: - return randomPrepareRequest(t) + return randomPrepareRequest(t, srEnabled...) case prepareResponseType: return &prepareResponse{preparationHash: random.Uint256()} case commitType: var c commit random.Fill(c.signature[:]) - random.Fill(c.stateSig[:]) + if len(srEnabled) > 0 && srEnabled[0] { + c.stateRootEnabled = true + random.Fill(c.stateSig[:]) + } return &c case recoveryRequestType: return &recoveryRequest{timestamp: rand.Uint32()} case recoveryMessageType: - return randomRecoveryMessage(t) + return randomRecoveryMessage(t, srEnabled...) default: require.Fail(t, "invalid type") return nil } } -func randomPrepareRequest(t *testing.T) *prepareRequest { +func randomPrepareRequest(t *testing.T, srEnabled ...bool) *prepareRequest { const txCount = 3 req := &prepareRequest{ @@ -256,15 +266,22 @@ func randomPrepareRequest(t *testing.T) *prepareRequest { } req.nextConsensus = random.Uint160() + if len(srEnabled) > 0 && srEnabled[0] { + req.stateRootEnabled = true + req.proposalStateRoot.Index = rand.Uint32() + req.proposalStateRoot.PrevHash = random.Uint256() + req.proposalStateRoot.Root = random.Uint256() + } + return req } -func randomRecoveryMessage(t *testing.T) *recoveryMessage { - result := randomMessage(t, prepareRequestType) +func randomRecoveryMessage(t *testing.T, srEnabled ...bool) *recoveryMessage { + result := randomMessage(t, prepareRequestType, srEnabled...) require.IsType(t, (*prepareRequest)(nil), result) prepReq := result.(*prepareRequest) - return &recoveryMessage{ + rec := &recoveryMessage{ preparationPayloads: []*preparationCompact{ { ValidatorIndex: 1, @@ -276,14 +293,12 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage { ViewNumber: 0, ValidatorIndex: 1, Signature: [64]byte{1, 2, 3}, - StateSignature: [64]byte{4, 5, 6}, InvocationScript: random.Bytes(20), }, { ViewNumber: 0, ValidatorIndex: 2, Signature: [64]byte{11, 3, 4, 98}, - StateSignature: [64]byte{4, 8, 15, 16, 23, 42}, InvocationScript: random.Bytes(10), }, }, @@ -300,6 +315,15 @@ func randomRecoveryMessage(t *testing.T) *recoveryMessage { payload: prepReq, }, } + if len(srEnabled) > 0 && srEnabled[0] { + rec.stateRootEnabled = true + rec.prepareRequest.stateRootEnabled = true + for _, c := range rec.commitPayloads { + c.stateRootEnabled = true + random.Fill(c.StateSignature[:]) + } + } + return rec } func TestPayload_Sign(t *testing.T) { diff --git a/pkg/consensus/prepare_request.go b/pkg/consensus/prepare_request.go index ff94ba213..8a28f3a5c 100644 --- a/pkg/consensus/prepare_request.go +++ b/pkg/consensus/prepare_request.go @@ -16,6 +16,8 @@ type prepareRequest struct { minerTx transaction.Transaction nextConsensus util.Uint160 proposalStateRoot state.MPTRootBase + + stateRootEnabled bool } var _ payload.PrepareRequest = (*prepareRequest)(nil) @@ -27,7 +29,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) { w.WriteBytes(p.nextConsensus[:]) w.WriteArray(p.transactionHashes) p.minerTx.EncodeBinary(w) - p.proposalStateRoot.EncodeBinary(w) + if p.stateRootEnabled { + p.proposalStateRoot.EncodeBinary(w) + } } // DecodeBinary implements io.Serializable interface. @@ -37,7 +41,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) { r.ReadBytes(p.nextConsensus[:]) r.ReadArray(&p.transactionHashes) p.minerTx.DecodeBinary(r) - p.proposalStateRoot.DecodeBinary(r) + if p.stateRootEnabled { + p.proposalStateRoot.DecodeBinary(r) + } } // Timestamp implements payload.PrepareRequest interface. diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index aa409b19a..17c7601f8 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -3,6 +3,7 @@ package consensus import ( "github.com/nspcc-dev/dbft/crypto" "github.com/nspcc-dev/dbft/payload" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/pkg/errors" @@ -16,6 +17,8 @@ type ( commitPayloads []*commitCompact changeViewPayloads []*changeViewCompact prepareRequest *message + + stateRootEnabled bool } changeViewCompact struct { @@ -31,6 +34,8 @@ type ( Signature [signatureSize]byte StateSignature [signatureSize]byte InvocationScript []byte + + stateRootEnabled bool } preparationCompact struct { @@ -47,7 +52,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { var hasReq = r.ReadBool() if hasReq { - m.prepareRequest = new(message) + m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled} m.prepareRequest.DecodeBinary(r) if r.Err == nil && m.prepareRequest.Type != prepareRequestType { r.Err = errors.New("recovery message PrepareRequest has wrong type") @@ -68,7 +73,16 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { } r.ReadArray(&m.preparationPayloads) - r.ReadArray(&m.commitPayloads) + lu := r.ReadVarUint() + if lu > state.MaxValidatorsVoted { + r.Err = errors.New("too many preparation payloads") + return + } + m.commitPayloads = make([]*commitCompact, lu) + for i := uint64(0); i < lu; i++ { + m.commitPayloads[i] = &commitCompact{stateRootEnabled: m.stateRootEnabled} + m.commitPayloads[i].DecodeBinary(r) + } } // EncodeBinary implements io.Serializable interface. @@ -113,7 +127,9 @@ func (p *commitCompact) DecodeBinary(r *io.BinReader) { p.ViewNumber = r.ReadB() p.ValidatorIndex = r.ReadU16LE() r.ReadBytes(p.Signature[:]) - r.ReadBytes(p.StateSignature[:]) + if p.stateRootEnabled { + r.ReadBytes(p.StateSignature[:]) + } p.InvocationScript = r.ReadVarBytes(1024) } @@ -122,7 +138,9 @@ func (p *commitCompact) EncodeBinary(w *io.BinWriter) { w.WriteB(p.ViewNumber) w.WriteU16LE(p.ValidatorIndex) w.WriteBytes(p.Signature[:]) - w.WriteBytes(p.StateSignature[:]) + if p.stateRootEnabled { + w.WriteBytes(p.StateSignature[:]) + } w.WriteVarBytes(p.InvocationScript) } @@ -146,6 +164,8 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) { Type: prepareRequestType, ViewNumber: p.ViewNumber(), payload: p.GetPrepareRequest().(*prepareRequest), + + stateRootEnabled: m.stateRootEnabled, } h := p.Hash() m.preparationHash = &h @@ -172,6 +192,7 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) { }) case payload.CommitType: m.commitPayloads = append(m.commitPayloads, &commitCompact{ + stateRootEnabled: m.stateRootEnabled, ValidatorIndex: p.ValidatorIndex(), ViewNumber: p.ViewNumber(), Signature: p.GetCommit().(*commit).signature, @@ -254,7 +275,12 @@ func (m *recoveryMessage) GetCommits(p payload.ConsensusPayload, validators []cr ps := make([]payload.ConsensusPayload, len(m.commitPayloads)) for i, c := range m.commitPayloads { - cc := fromPayload(commitType, p.(*Payload), &commit{signature: c.Signature, stateSig: c.StateSignature}) + cc := fromPayload(commitType, p.(*Payload), &commit{ + signature: c.Signature, + stateSig: c.StateSignature, + + stateRootEnabled: m.stateRootEnabled, + }) cc.SetValidatorIndex(c.ValidatorIndex) cc.Witness.InvocationScript = c.InvocationScript cc.Witness.VerificationScript = getVerificationScript(c.ValidatorIndex, validators) @@ -294,6 +320,8 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload { Type: t, ViewNumber: recovery.message.ViewNumber, payload: p, + + stateRootEnabled: recovery.stateRootEnabled, }, version: recovery.Version(), prevHash: recovery.PrevHash(), diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 528002435..d7cda96bb 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -231,8 +231,10 @@ func (bc *Blockchain) init() error { } bc.blockHeight = bHeight bc.persistedHeight = bHeight - if err = bc.dao.InitMPT(bHeight); err != nil { - return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + if bc.config.EnableStateRoot { + if err = bc.dao.InitMPT(bHeight); err != nil { + return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + } } hashes, err := bc.dao.GetHeaderHashes() @@ -558,12 +560,18 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 { // GetStateProof returns proof of having key in the MPT with the specified root. func (bc *Blockchain) GetStateProof(root util.Uint256, key []byte) ([][]byte, error) { + if !bc.config.EnableStateRoot { + return nil, errors.New("state root feature is not enabled") + } tr := mpt.NewTrie(mpt.NewHashNode(root), storage.NewMemCachedStore(bc.dao.Store)) return tr.GetProof(key) } // GetStateRoot returns state root for a given height. func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + if !bc.config.EnableStateRoot { + return nil, errors.New("state root feature is not enabled") + } return bc.dao.GetStateRoot(height) } @@ -835,24 +843,26 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } - root := bc.dao.MPT.StateRoot() - var prevHash util.Uint256 - if block.Index > 0 { - prev, err := bc.dao.GetStateRoot(block.Index - 1) - if err != nil { - return errors.WithMessagef(err, "can't get previous state root") + if bc.config.EnableStateRoot { + root := bc.dao.MPT.StateRoot() + var prevHash util.Uint256 + if block.Index > 0 { + prev, err := bc.dao.GetStateRoot(block.Index - 1) + if err != nil { + return errors.WithMessagef(err, "can't get previous state root") + } + prevHash = hash.DoubleSha256(prev.GetSignedPart()) + } + err := bc.AddStateRoot(&state.MPTRoot{ + MPTRootBase: state.MPTRootBase{ + Index: block.Index, + PrevHash: prevHash, + Root: root, + }, + }) + if err != nil { + return err } - prevHash = hash.DoubleSha256(prev.GetSignedPart()) - } - err := bc.AddStateRoot(&state.MPTRoot{ - MPTRootBase: state.MPTRootBase{ - Index: block.Index, - PrevHash: prevHash, - Root: root, - }, - }) - if err != nil { - return err } if bc.config.SaveStorageBatch { @@ -860,12 +870,14 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } bc.lock.Lock() - _, err = cache.Persist() + _, err := cache.Persist() if err != nil { bc.lock.Unlock() return err } - bc.dao.MPT.Flush() + if bc.config.EnableStateRoot { + bc.dao.MPT.Flush() + } // Every persist cycle we also compact our in-memory MPT. persistedHeight := atomic.LoadUint32(&bc.persistedHeight) if persistedHeight == block.Index-1 { @@ -1784,6 +1796,10 @@ func (bc *Blockchain) StateHeight() uint32 { // AddStateRoot add new (possibly unverified) state root to the blockchain. func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { + if !bc.config.EnableStateRoot { + bc.log.Warn("state root is being added but not enabled in config") + return nil + } our, err := bc.GetStateRoot(r.Index) if err == nil { if our.Flag == state.Verified { diff --git a/pkg/core/interops.go b/pkg/core/interops.go index 551c4c8cf..940d5928b 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -52,6 +52,9 @@ func (ic *interopContext) SpawnVM() *vm.VM { }) vm.RegisterInteropGetter(ic.getSystemInterop) vm.RegisterInteropGetter(ic.getNeoInterop) + if ic.bc != nil && ic.bc.GetConfig().EnableStateRoot { + vm.RegisterInteropGetter(ic.getNeoxInterop) + } return vm } @@ -77,6 +80,12 @@ func (ic *interopContext) getNeoInterop(id uint32) *vm.InteropFuncPrice { return ic.getInteropFromSlice(id, neoInterops) } +// getNeoxInterop returns matching interop function from the NeoX extension +// for a given id in the current context. +func (ic *interopContext) getNeoxInterop(id uint32) *vm.InteropFuncPrice { + return ic.getInteropFromSlice(id, neoxInterops) +} + // getInteropFromSlice returns matching interop function from the given slice of // interop functions in the current context. func (ic *interopContext) getInteropFromSlice(id uint32, slice []interopedFunction) *vm.InteropFuncPrice { @@ -166,8 +175,6 @@ var neoInterops = []interopedFunction{ {Name: "Neo.Contract.GetStorageContext", Func: (*interopContext).contractGetStorageContext, Price: 1}, {Name: "Neo.Contract.IsPayable", Func: (*interopContext).contractIsPayable, Price: 1}, {Name: "Neo.Contract.Migrate", Func: (*interopContext).contractMigrate, Price: 0}, - {Name: "Neo.Cryptography.Secp256k1Recover", Func: (*interopContext).secp256k1Recover, Price: 100}, - {Name: "Neo.Cryptography.Secp256r1Recover", Func: (*interopContext).secp256r1Recover, Price: 100}, {Name: "Neo.Enumerator.Concat", Func: (*interopContext).enumeratorConcat, Price: 1}, {Name: "Neo.Enumerator.Create", Func: (*interopContext).enumeratorCreate, Price: 1}, {Name: "Neo.Enumerator.Next", Func: (*interopContext).enumeratorNext, Price: 1}, @@ -278,6 +285,11 @@ var neoInterops = []interopedFunction{ {Name: "AntShares.Transaction.GetType", Func: (*interopContext).txGetType, Price: 1}, } +var neoxInterops = []interopedFunction{ + {Name: "Neo.Cryptography.Secp256k1Recover", Func: (*interopContext).secp256k1Recover, Price: 100}, + {Name: "Neo.Cryptography.Secp256r1Recover", Func: (*interopContext).secp256r1Recover, Price: 100}, +} + // initIDinInteropsSlice initializes IDs from names in one given // interopedFunction slice and then sorts it. func initIDinInteropsSlice(iops []interopedFunction) { @@ -293,4 +305,5 @@ func initIDinInteropsSlice(iops []interopedFunction) { func init() { initIDinInteropsSlice(systemInterops) initIDinInteropsSlice(neoInterops) + initIDinInteropsSlice(neoxInterops) } diff --git a/pkg/io/binaryReader.go b/pkg/io/binaryReader.go index bd62b53ee..b8c935c80 100644 --- a/pkg/io/binaryReader.go +++ b/pkg/io/binaryReader.go @@ -8,9 +8,9 @@ import ( "reflect" ) -// maxArraySize is a maximums size of an array which can be decoded. +// MaxArraySize is the maximum size of an array which can be decoded. // It is taken from https://github.com/neo-project/neo/blob/master/neo/IO/Helper.cs#L130 -const maxArraySize = 0x1000000 +const MaxArraySize = 0x1000000 // BinReader is a convenient wrapper around a io.Reader and err object. // Used to simplify error handling when reading into a struct with many fields. @@ -110,7 +110,7 @@ func (r *BinReader) ReadArray(t interface{}, maxSize ...int) { elemType := sliceType.Elem() isPtr := elemType.Kind() == reflect.Ptr - ms := maxArraySize + ms := MaxArraySize if len(maxSize) != 0 { ms = maxSize[0] } @@ -170,7 +170,7 @@ func (r *BinReader) ReadVarUint() uint64 { // ReadVarUInt() is used to determine how large that slice is func (r *BinReader) ReadVarBytes(maxSize ...int) []byte { n := r.ReadVarUint() - ms := maxArraySize + ms := MaxArraySize if len(maxSize) != 0 { ms = maxSize[0] } diff --git a/pkg/network/server.go b/pkg/network/server.go index 5ce2118aa..9364da476 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -602,6 +602,10 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { // handleGetRootsCmd processees `getroots` request. func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { + cfg := s.chain.GetConfig() + if !cfg.EnableStateRoot || gr.Start < cfg.StateRootEnableIndex { + return nil + } count := gr.Count if count > payload.MaxStateRootsAllowed { count = payload.MaxStateRootsAllowed @@ -621,7 +625,13 @@ func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { // handleStateRootsCmd processees `roots` request. func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error { + if !s.chain.GetConfig().EnableStateRoot { + return nil + } h := s.chain.StateHeight() + if h < s.chain.GetConfig().StateRootEnableIndex { + h = s.chain.GetConfig().StateRootEnableIndex + } for i := range rs.Roots { if rs.Roots[i].Index <= h { continue @@ -636,6 +646,13 @@ func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error { func (s *Server) requestStateRoot(p Peer) error { stateHeight := s.chain.StateHeight() hdrHeight := s.chain.BlockHeight() + enableIndex := s.chain.GetConfig().StateRootEnableIndex + if hdrHeight < enableIndex { + return nil + } + if stateHeight < enableIndex { + stateHeight = enableIndex - 1 + } count := uint32(payload.MaxStateRootsAllowed) if diff := hdrHeight - stateHeight; diff < count { count = diff @@ -652,6 +669,9 @@ func (s *Server) requestStateRoot(p Peer) error { // handleStateRootCmd processees `stateroot` request. func (s *Server) handleStateRootCmd(r *state.MPTRoot) error { + if !s.chain.GetConfig().EnableStateRoot { + return nil + } // we ignore error, because there is nothing wrong if we already have this state root err := s.chain.AddStateRoot(r) if err == nil && !s.stateCache.Has(r.Hash()) { diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 058faddf1..5643c335d 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -251,7 +251,7 @@ func (p *TCPPeer) StartProtocol() { if p.LastBlockIndex() > p.server.chain.BlockHeight() { err = p.server.requestBlocks(p) } - if err == nil { + if err == nil && p.server.chain.GetConfig().EnableStateRoot { err = p.server.requestStateRoot(p) } if err == nil {