diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index e8ae39980..41c2758e2 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -731,7 +731,8 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error d := cache.DAO.(*dao.Simple) b := d.GetMPTBatch() - if err := bc.stateRoot.AddMPTBatch(block.Index, b); err != nil { + mpt, sr, err := bc.stateRoot.AddMPTBatch(block.Index, b, d.Store) + if err != nil { // Here MPT can be left in a half-applied state. // However if this error occurs, this is a bug somewhere in code // because changes applied are the ones from HALTed transactions. @@ -761,6 +762,8 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error return err } + mpt.Store = bc.dao.Store + bc.stateRoot.UpdateCurrentLocal(mpt, sr) bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(func(tx *transaction.Transaction) bool { return bc.IsTxStillRelevant(tx, txpool, false) }, bc) diff --git a/pkg/core/native_management_test.go b/pkg/core/native_management_test.go index 790deb469..b0f441e68 100644 --- a/pkg/core/native_management_test.go +++ b/pkg/core/native_management_test.go @@ -557,7 +557,8 @@ func TestContractDestroy(t *testing.T) { err = bc.dao.PutStorageItem(cs1.ID, []byte{1, 2, 3}, state.StorageItem{3, 2, 1}) require.NoError(t, err) b := bc.dao.GetMPTBatch() - require.NoError(t, bc.GetStateModule().(*stateroot.Module).AddMPTBatch(bc.BlockHeight(), b)) + _, _, err = bc.GetStateModule().(*stateroot.Module).AddMPTBatch(bc.BlockHeight(), b, bc.dao.Store) + require.NoError(t, err) t.Run("no contract", func(t *testing.T) { res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "destroy") diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go index fa6f5127d..ec0aa6fb6 100644 --- a/pkg/core/stateroot/module.go +++ b/pkg/core/stateroot/module.go @@ -110,20 +110,33 @@ func (s *Module) Init(height uint32, enableRefCount bool) error { } // AddMPTBatch updates using provided batch. -func (s *Module) AddMPTBatch(index uint32, b mpt.Batch) error { - if _, err := s.mpt.PutBatch(b); err != nil { - return err +func (s *Module) AddMPTBatch(index uint32, b mpt.Batch, cache *storage.MemCachedStore) (*mpt.Trie, *state.MPTRoot, error) { + mpt := *s.mpt + mpt.Store = cache + if _, err := mpt.PutBatch(b); err != nil { + return nil, nil, err } - s.mpt.Flush() - err := s.addLocalStateRoot(&state.MPTRoot{ + mpt.Flush() + sr := &state.MPTRoot{ Index: index, - Root: s.mpt.StateRoot(), - }) - if err != nil { - return err + Root: mpt.StateRoot(), + } + err := s.addLocalStateRoot(cache, sr) + if err != nil { + return nil, nil, err + } + return &mpt, sr, err +} + +// UpdateCurrentLocal updates local caches using provided state root. +func (s *Module) UpdateCurrentLocal(mpt *mpt.Trie, sr *state.MPTRoot) { + s.mpt = mpt + s.currentLocal.Store(sr.Root) + s.localHeight.Store(sr.Index) + if s.bc.GetConfig().StateRootInHeader { + s.validatedHeight.Store(sr.Index) + updateStateHeightMetric(sr.Index) } - _, err = s.Store.Persist() - return err } // VerifyStateRoot checks if state root is valid. diff --git a/pkg/core/stateroot/store.go b/pkg/core/stateroot/store.go index 865073159..ff3b1a6cb 100644 --- a/pkg/core/stateroot/store.go +++ b/pkg/core/stateroot/store.go @@ -14,30 +14,21 @@ const ( prefixValidated = 0x03 ) -func (s *Module) addLocalStateRoot(sr *state.MPTRoot) error { +func (s *Module) addLocalStateRoot(store *storage.MemCachedStore, sr *state.MPTRoot) error { key := makeStateRootKey(sr.Index) - if err := s.putStateRoot(key, sr); err != nil { + if err := putStateRoot(store, key, sr); err != nil { return err } data := make([]byte, 4) binary.LittleEndian.PutUint32(data, sr.Index) - if err := s.Store.Put([]byte{byte(storage.DataMPT), prefixLocal}, data); err != nil { - return err - } - s.currentLocal.Store(sr.Root) - s.localHeight.Store(sr.Index) - if s.bc.GetConfig().StateRootInHeader { - s.validatedHeight.Store(sr.Index) - updateStateHeightMetric(sr.Index) - } - return nil + return store.Put([]byte{byte(storage.DataMPT), prefixLocal}, data) } -func (s *Module) putStateRoot(key []byte, sr *state.MPTRoot) error { +func putStateRoot(store *storage.MemCachedStore, key []byte, sr *state.MPTRoot) error { w := io.NewBufBinWriter() sr.EncodeBinary(w.BinWriter) - return s.Store.Put(key, w.Bytes()) + return store.Put(key, w.Bytes()) } func (s *Module) getStateRoot(key []byte) (*state.MPTRoot, error) { @@ -72,7 +63,7 @@ func (s *Module) AddStateRoot(sr *state.MPTRoot) error { if len(local.Witness) != 0 { return nil } - if err := s.putStateRoot(key, sr); err != nil { + if err := putStateRoot(s.Store, key, sr); err != nil { return err }