From 0e29382035514dd75ee8b29b831f2ff8c1fa6452 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Fri, 29 May 2020 17:20:00 +0300 Subject: [PATCH] core: update MPT during block processing Signed-off-by: Evgenii Stratonikov --- pkg/core/blockchain.go | 92 +++++++++++++++++++++- pkg/core/blockchainer/blockchainer.go | 2 + pkg/core/dao/dao.go | 65 +++++++++++++++- pkg/core/state/mpt_root.go | 105 ++++++++++++++++++++++++++ pkg/core/state/mpt_root_test.go | 61 +++++++++++++++ pkg/network/helper_test.go | 6 ++ 6 files changed, 326 insertions(+), 5 deletions(-) create mode 100644 pkg/core/state/mpt_root.go create mode 100644 pkg/core/state/mpt_root_test.go diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 7ac605bed..d592eb085 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -221,6 +221,9 @@ func (bc *Blockchain) init() error { } bc.blockHeight = bHeight bc.persistedHeight = bHeight + if err = bc.dao.InitMPT(bHeight); err != nil { + return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + } hashes, err := bc.dao.GetHeaderHashes() if err != nil { @@ -550,6 +553,11 @@ func (bc *Blockchain) processHeader(h *block.Header, batch storage.Batch, header return nil } +// GetStateRoot returns state root for a given height. +func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + return bc.dao.GetStateRoot(height) +} + // storeBlock performs chain update using the block given, it executes all // transactions with all appropriate side-effects and updates Blockchain state. // This is the only way to change Blockchain state. @@ -633,17 +641,38 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } + root := bc.dao.MPT.StateRoot() + var prevHash util.Uint256 + if block.Index > 0 { + prev, err := bc.dao.GetStateRoot(block.Index - 1) + if err != nil { + return errors.WithMessagef(err, "can't get previous state root") + } + prevHash = prev.Root + } + err := bc.AddStateRoot(&state.MPTRoot{ + MPTRootBase: state.MPTRootBase{ + Index: block.Index, + PrevHash: prevHash, + Root: root, + }, + }) + if err != nil { + return err + } + if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() } bc.lock.Lock() - _, err := cache.Persist() + _, err = cache.Persist() if err != nil { bc.lock.Unlock() return err } bc.contracts.Policy.OnPersistEnd(bc.dao) + bc.dao.MPT.Flush() bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant, bc) @@ -1194,6 +1223,67 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { } +// AddStateRoot add new (possibly unverified) state root to the blockchain. +func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { + our, err := bc.GetStateRoot(r.Index) + if err == nil { + if our.Flag == state.Verified { + return nil + } else if r.Witness == nil && our.Witness != nil { + r.Witness = our.Witness + } + } + if err := bc.verifyStateRoot(r); err != nil { + return errors.WithMessage(err, "invalid state root") + } + if r.Index > bc.BlockHeight() { // just put it into the store for future checks + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: state.Unverified, + }) + } + + flag := state.Unverified + if r.Witness != nil { + if err := bc.verifyStateRootWitness(r); err != nil { + return errors.WithMessage(err, "can't verify signature") + } + flag = state.Verified + } + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: flag, + }) +} + +// verifyStateRoot checks if state root is valid. +func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error { + if r.Index == 0 { + return nil + } + prev, err := bc.GetStateRoot(r.Index - 1) + if err != nil { + return errors.New("can't get previous state root") + } else if !prev.Root.Equals(r.PrevHash) { + return errors.New("previous hash mismatch") + } else if prev.Version != r.Version { + return errors.New("version mismatch") + } + return nil +} + +// verifyStateRootWitness verifies that state root signature is correct. +func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error { + b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index))) + if err != nil { + return err + } + interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil) + interopCtx.Container = r + return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, interopCtx, true, + bc.contracts.Policy.GetMaxVerificationGas(interopCtx.DAO)) +} + // VerifyTx verifies whether a transaction is bonafide or not. Block parameter // is used for easy interop access and can be omitted for transactions that are // not yet added into any block. diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index 9dcac9e33..1086c6bee 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -20,6 +20,7 @@ type Blockchainer interface { GetConfig() config.ProtocolConfiguration AddHeaders(...*block.Header) error AddBlock(*block.Block) error + AddStateRoot(r *state.MPTRoot) error BlockHeight() uint32 CalculateClaimable(value *big.Int, startHeight, endHeight uint32) *big.Int Close() @@ -42,6 +43,7 @@ type Blockchainer interface { GetValidators() ([]*keys.PublicKey, error) GetStandByValidators() keys.PublicKeys GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(id int32, key []byte) *state.StorageItem GetStorageItems(id int32) (map[string]*state.StorageItem, error) GetTestVM(tx *transaction.Transaction) *vm.VM diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index f24c7267c..f55fd68e1 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config/netmode" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -33,6 +34,8 @@ type DAO interface { GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) GetAndUpdateNextContractID() (int32, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) + PutStateRoot(root *state.MPTRootState) error GetStorageItem(id int32, key []byte) *state.StorageItem GetStorageItems(id int32) (map[string]*state.StorageItem, error) GetStorageItemsWithPrefix(id int32, prefix []byte) (map[string]*state.StorageItem, error) @@ -58,13 +61,15 @@ type DAO interface { // Simple is memCached wrapper around DB, simple DAO implementation. type Simple struct { + MPT *mpt.Trie Store *storage.MemCachedStore network netmode.Magic } // NewSimple creates new simple dao using provided backend store. func NewSimple(backend storage.Store, network netmode.Magic) *Simple { - return &Simple{Store: storage.NewMemCachedStore(backend), network: network} + st := storage.NewMemCachedStore(backend) + return &Simple{Store: st, network: network, MPT: mpt.NewTrie(nil, st)} } // GetBatch returns currently accumulated DB changeset. @@ -75,7 +80,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch { // GetWrapped returns new DAO instance with another layer of wrapped // MemCachedStore around the current DAO Store. func (dao *Simple) GetWrapped() DAO { - return NewSimple(dao.Store, dao.network) + d := NewSimple(dao.Store, dao.network) + d.MPT = dao.MPT + return d } // GetAndDecode performs get operation and decoding with serializable structures. @@ -288,6 +295,42 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error { // -- start storage item. +func makeStateRootKey(height uint32) []byte { + key := make([]byte, 5) + key[0] = byte(storage.DataMPT) + binary.LittleEndian.PutUint32(key[1:], height) + return key +} + +// InitMPT initializes MPT at the given height. +func (dao *Simple) InitMPT(height uint32) error { + if height == 0 { + dao.MPT = mpt.NewTrie(nil, dao.Store) + return nil + } + r, err := dao.GetStateRoot(height) + if err != nil { + return err + } + dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) + return nil +} + +// GetStateRoot returns state root of a given height. +func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { + r := new(state.MPTRootState) + err := dao.GetAndDecode(r, makeStateRootKey(height)) + if err != nil { + return nil, err + } + return r, nil +} + +// PutStateRoot puts state root of a given height into the store. +func (dao *Simple) PutStateRoot(r *state.MPTRootState) error { + return dao.Put(r, makeStateRootKey(r.Index)) +} + // GetStorageItem returns StorageItem if it exists in the given store. func (dao *Simple) GetStorageItem(id int32, key []byte) *state.StorageItem { b, err := dao.Store.Get(makeStorageItemKey(id, key)) @@ -308,13 +351,27 @@ func (dao *Simple) GetStorageItem(id int32, key []byte) *state.StorageItem { // PutStorageItem puts given StorageItem for given id with given // key into the given store. func (dao *Simple) PutStorageItem(id int32, key []byte, si *state.StorageItem) error { - return dao.Put(si, makeStorageItemKey(id, key)) + stKey := makeStorageItemKey(id, key) + buf := io.NewBufBinWriter() + si.EncodeBinary(buf.BinWriter) + if buf.Err != nil { + return buf.Err + } + v := buf.Bytes() + if err := dao.MPT.Put(stKey[1:], v); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Put(stKey, v) } // DeleteStorageItem drops storage item for the given id with the // given key from the store. func (dao *Simple) DeleteStorageItem(id int32, key []byte) error { - return dao.Store.Delete(makeStorageItemKey(id, key)) + stKey := makeStorageItemKey(id, key) + if err := dao.MPT.Delete(stKey[1:]); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Delete(stKey) } // GetStorageItems returns all storage items for a given id. diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go new file mode 100644 index 000000000..facf3da45 --- /dev/null +++ b/pkg/core/state/mpt_root.go @@ -0,0 +1,105 @@ +package state + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MPTRootBase represents storage state root. +type MPTRootBase struct { + Version byte + Index uint32 + PrevHash util.Uint256 + Root util.Uint256 +} + +// MPTRoot represents storage state root together with sign info. +type MPTRoot struct { + MPTRootBase + Witness *transaction.Witness +} + +// MPTRootStateFlag represents verification state of the state root. +type MPTRootStateFlag byte + +// Possible verification states of MPTRoot. +const ( + Unverified MPTRootStateFlag = 0x00 + Verified MPTRootStateFlag = 0x01 + Invalid MPTRootStateFlag = 0x03 +) + +// MPTRootState represents state root together with its verification state. +type MPTRootState struct { + MPTRoot + Flag MPTRootStateFlag +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootState) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(s.Flag)) + s.MPTRoot.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootState) DecodeBinary(r *io.BinReader) { + s.Flag = MPTRootStateFlag(r.ReadB()) + s.MPTRoot.DecodeBinary(r) +} + +// GetSignedPart returns part of MPTRootBase which needs to be signed. +func (s *MPTRootBase) GetSignedPart() []byte { + buf := io.NewBufBinWriter() + s.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} + +// Equals checks if s == other. +func (s *MPTRootBase) Equals(other *MPTRootBase) bool { + return s.Version == other.Version && s.Index == other.Index && + s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root) +} + +// Hash returns hash of s. +func (s *MPTRootBase) Hash() util.Uint256 { + return hash.DoubleSha256(s.GetSignedPart()) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootBase) DecodeBinary(r *io.BinReader) { + s.Version = r.ReadB() + s.Index = r.ReadU32LE() + s.PrevHash.DecodeBinary(r) + s.Root.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) { + w.WriteB(s.Version) + w.WriteU32LE(s.Index) + s.PrevHash.EncodeBinary(w) + s.Root.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRoot) DecodeBinary(r *io.BinReader) { + s.MPTRootBase.DecodeBinary(r) + + var ws []transaction.Witness + r.ReadArray(&ws, 1) + if len(ws) == 1 { + s.Witness = &ws[0] + } +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { + s.MPTRootBase.EncodeBinary(w) + if s.Witness == nil { + w.WriteVarUint(0) + } else { + w.WriteArray([]*transaction.Witness{s.Witness}) + } +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go new file mode 100644 index 000000000..15a3ca043 --- /dev/null +++ b/pkg/core/state/mpt_root_test.go @@ -0,0 +1,61 @@ +package state + +import ( + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func testStateRoot() *MPTRoot { + return &MPTRoot{ + MPTRootBase: MPTRootBase{ + Version: byte(rand.Uint32()), + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + } +} + +func TestStateRoot_Serializable(t *testing.T) { + r := testStateRoot() + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + + t.Run("WithWitness", func(t *testing.T) { + r.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + }) +} + +func TestStateRootEquals(t *testing.T) { + r1 := testStateRoot() + r2 := *r1 + require.True(t, r1.Equals(&r2.MPTRootBase)) + + r2.MPTRootBase.Index++ + require.False(t, r1.Equals(&r2.MPTRootBase)) +} + +func TestMPTRootState_Serializable(t *testing.T) { + rs := &MPTRootState{ + MPTRoot: *testStateRoot(), + Flag: 0x04, + } + rs.MPTRoot.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState)) +} + +func TestMPTRootStateUnverifiedByDefault(t *testing.T) { + var r MPTRootState + require.Equal(t, Unverified, r.Flag) +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 6004c44a4..61cf1939e 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -49,6 +49,9 @@ func (chain *testChain) AddBlock(block *block.Block) error { } return nil } +func (chain *testChain) AddStateRoot(r *state.MPTRoot) error { + panic("TODO") +} func (chain *testChain) BlockHeight() uint32 { return atomic.LoadUint32(&chain.blockheight) } @@ -98,6 +101,9 @@ func (chain testChain) GetEnrollments() ([]state.Validator, error) { func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) { panic("TODO") } +func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + panic("TODO") +} func (chain testChain) GetStorageItem(id int32, key []byte) *state.StorageItem { panic("TODO") }