diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index c9fe95e5e..6ac1425e6 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -782,6 +782,15 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error // because changes applied are the ones from HALTed transactions. return fmt.Errorf("error while trying to apply MPT changes: %w", err) } + if bc.config.StateRootInHeader && bc.HeaderHeight() > sr.Index { + h, err := bc.GetHeader(bc.GetHeaderHash(int(sr.Index) + 1)) + if err != nil { + return fmt.Errorf("failed to get next header: %w", err) + } + if h.PrevStateRoot != sr.Root { + return fmt.Errorf("local stateroot and next header's PrevStateRoot mismatch: %s vs %s", sr.Root.StringBE(), h.PrevStateRoot.StringBE()) + } + } if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() @@ -1430,9 +1439,11 @@ var ( func (bc *Blockchain) verifyHeader(currHeader, prevHeader *block.Header) error { if bc.config.StateRootInHeader { - if sr := bc.stateRoot.CurrentLocalStateRoot(); currHeader.PrevStateRoot != sr { - return fmt.Errorf("%w: %s != %s", - ErrHdrInvalidStateRoot, currHeader.PrevStateRoot.StringLE(), sr.StringLE()) + if bc.stateRoot.CurrentLocalHeight() == prevHeader.Index { + if sr := bc.stateRoot.CurrentLocalStateRoot(); currHeader.PrevStateRoot != sr { + return fmt.Errorf("%w: %s != %s", + ErrHdrInvalidStateRoot, currHeader.PrevStateRoot.StringLE(), sr.StringLE()) + } } } if prevHeader.Hash() != currHeader.PrevHash { diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 6393016d1..302b9c447 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -158,6 +158,28 @@ func TestAddBlockStateRoot(t *testing.T) { require.NoError(t, bc.AddBlock(b)) } +func TestAddHeadersStateRoot(t *testing.T) { + bc := newTestChainWithCustomCfg(t, func(c *config.Config) { + c.ProtocolConfiguration.StateRootInHeader = true + }) + + r := bc.stateRoot.CurrentLocalStateRoot() + h1 := bc.newBlock().Header + + // invalid stateroot + h1.PrevStateRoot[0] ^= 0xFF + require.True(t, errors.Is(bc.AddHeaders(&h1), ErrHdrInvalidStateRoot)) + + // valid stateroot + h1.PrevStateRoot = r + require.NoError(t, bc.AddHeaders(&h1)) + + // unable to verify stateroot (stateroot is computed for block #0 only => can + // verify stateroot of header #1 only) => just store the header + h2 := newBlockWithState(bc.config, 2, h1.Hash(), nil).Header + require.NoError(t, bc.AddHeaders(&h2)) +} + func TestAddBadBlock(t *testing.T) { bc := newTestChain(t) // It has ValidUntilBlock == 0, which is wrong diff --git a/pkg/core/stateroot/module.go b/pkg/core/stateroot/module.go index ec0aa6fb6..668cef8af 100644 --- a/pkg/core/stateroot/module.go +++ b/pkg/core/stateroot/module.go @@ -70,6 +70,11 @@ func (s *Module) CurrentLocalStateRoot() util.Uint256 { return s.currentLocal.Load().(util.Uint256) } +// CurrentLocalHeight returns height of the local state root. +func (s *Module) CurrentLocalHeight() uint32 { + return s.localHeight.Load() +} + // CurrentValidatedHeight returns current state root validated height. func (s *Module) CurrentValidatedHeight() uint32 { return s.validatedHeight.Load()