diff --git a/pkg/consensus/commit.go b/pkg/consensus/commit.go index 492a1a156..372448576 100644 --- a/pkg/consensus/commit.go +++ b/pkg/consensus/commit.go @@ -8,6 +8,7 @@ import ( // commit represents dBFT Commit message. type commit struct { signature [signatureSize]byte + stateSig [signatureSize]byte } // signatureSize is an rfc6989 signature size in bytes @@ -19,11 +20,13 @@ var _ payload.Commit = (*commit)(nil) // EncodeBinary implements io.Serializable interface. func (c *commit) EncodeBinary(w *io.BinWriter) { w.WriteBytes(c.signature[:]) + w.WriteBytes(c.stateSig[:]) } // DecodeBinary implements io.Serializable interface. func (c *commit) DecodeBinary(r *io.BinReader) { r.ReadBytes(c.signature[:]) + r.ReadBytes(c.stateSig[:]) } // Signature implements payload.Commit interface. diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index fbb84a3e1..742e8d23f 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -14,6 +14,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core" coreb "github.com/nspcc-dev/neo-go/pkg/core/block" "github.com/nspcc-dev/neo-go/pkg/core/mempool" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/smartcontract" @@ -136,10 +137,10 @@ func NewService(cfg Config) (Service, error) { dbft.WithGetConsensusAddress(srv.getConsensusAddress), dbft.WithNewConsensusPayload(func() payload.ConsensusPayload { p := new(Payload); p.message = &message{}; return p }), - dbft.WithNewPrepareRequest(func() payload.PrepareRequest { return new(prepareRequest) }), + dbft.WithNewPrepareRequest(srv.newPrepareRequest), dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }), dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }), - dbft.WithNewCommit(func() payload.Commit { return new(commit) }), + dbft.WithNewCommit(srv.newCommit), dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }), dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }), ) @@ -210,6 +211,33 @@ func (s *service) eventLoop() { } } +func (s *service) newPrepareRequest() payload.PrepareRequest { + sr, err := s.Chain.GetStateRoot(s.Chain.BlockHeight()) + if err != nil { + return new(prepareRequest) + } + return &prepareRequest{ + proposalStateRoot: sr.MPTRootBase, + } +} + +func (s *service) newCommit() payload.Commit { + for _, p := range s.dbft.Context.PreparationPayloads { + if p != nil && p.ViewNumber() == s.dbft.ViewNumber && p.Type() == payload.PrepareRequestType { + pr := p.GetPrepareRequest().(*prepareRequest) + data := pr.proposalStateRoot.GetSignedPart() + sign, err := s.dbft.Priv.Sign(data) + if err == nil { + var c commit + copy(c.stateSig[:], sign) + return &c + } + break + } + } + return new(commit) +} + func (s *service) validatePayload(p *Payload) bool { validators := s.getValidators() if int(p.validatorIndex) >= len(validators) { @@ -351,16 +379,35 @@ func (s *service) processBlock(b block.Block) { s.log.Warn("error on add block", zap.Error(err)) } } + + var rb *state.MPTRootBase + for _, p := range s.dbft.PreparationPayloads { + if p != nil && p.Type() == payload.PrepareRequestType { + rb = &p.GetPrepareRequest().(*prepareRequest).proposalStateRoot + } + } + w := s.getWitness(func(p payload.Commit) []byte { return p.(*commit).stateSig[:] }) + r := &state.MPTRoot{ + MPTRootBase: *rb, + Witness: w, + } + if err := s.Chain.AddStateRoot(r); err != nil { + s.log.Warn("errors while adding state root", zap.Error(err)) + } } -func (s *service) getBlockWitness(b *coreb.Block) *transaction.Witness { +func (s *service) getBlockWitness(_ *coreb.Block) *transaction.Witness { + return s.getWitness(func(p payload.Commit) []byte { return p.Signature() }) +} + +func (s *service) getWitness(f func(p payload.Commit) []byte) *transaction.Witness { dctx := s.dbft.Context pubs := convertKeys(dctx.Validators) sigs := make(map[*keys.PublicKey][]byte) for i := range pubs { if p := dctx.CommitPayloads[i]; p != nil && p.ViewNumber() == dctx.ViewNumber { - sigs[pubs[i]] = p.GetCommit().Signature() + sigs[pubs[i]] = f(p.GetCommit()) } } diff --git a/pkg/consensus/prepare_request.go b/pkg/consensus/prepare_request.go index f40b74ab0..ff94ba213 100644 --- a/pkg/consensus/prepare_request.go +++ b/pkg/consensus/prepare_request.go @@ -2,6 +2,7 @@ package consensus import ( "github.com/nspcc-dev/dbft/payload" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" @@ -14,6 +15,7 @@ type prepareRequest struct { transactionHashes []util.Uint256 minerTx transaction.Transaction nextConsensus util.Uint160 + proposalStateRoot state.MPTRootBase } var _ payload.PrepareRequest = (*prepareRequest)(nil) @@ -25,6 +27,7 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) { w.WriteBytes(p.nextConsensus[:]) w.WriteArray(p.transactionHashes) p.minerTx.EncodeBinary(w) + p.proposalStateRoot.EncodeBinary(w) } // DecodeBinary implements io.Serializable interface. @@ -34,6 +37,7 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) { r.ReadBytes(p.nextConsensus[:]) r.ReadArray(&p.transactionHashes) p.minerTx.DecodeBinary(r) + p.proposalStateRoot.DecodeBinary(r) } // Timestamp implements payload.PrepareRequest interface. diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index aa452bf86..88bb71db9 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -16,6 +16,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" @@ -229,6 +230,9 @@ func (bc *Blockchain) init() error { } bc.blockHeight = bHeight bc.persistedHeight = bHeight + if err = bc.dao.InitMPT(bHeight); err != nil { + return errors.Wrapf(err, "can't init MPT at height %d", bHeight) + } hashes, err := bc.dao.GetHeaderHashes() if err != nil { @@ -551,6 +555,11 @@ func (bc *Blockchain) getSystemFeeAmount(h util.Uint256) uint32 { return sf } +// GetStateRoot returns state root for a given height. +func (bc *Blockchain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + return bc.dao.GetStateRoot(height) +} + // TODO: storeBlock needs some more love, its implemented as in the original // project. This for the sake of development speed and understanding of what // is happening here, quite allot as you can see :). If things are wired together @@ -819,16 +828,37 @@ func (bc *Blockchain) storeBlock(block *block.Block) error { } } + root := bc.dao.MPT.StateRoot() + var prevHash util.Uint256 + if block.Index > 0 { + prev, err := bc.dao.GetStateRoot(block.Index - 1) + if err != nil { + return errors.WithMessagef(err, "can't get previous state root") + } + prevHash = prev.Root + } + err := bc.AddStateRoot(&state.MPTRoot{ + MPTRootBase: state.MPTRootBase{ + Index: block.Index, + PrevHash: prevHash, + Root: root, + }, + }) + if err != nil { + return err + } + if bc.config.SaveStorageBatch { bc.lastBatch = cache.DAO.GetBatch() } bc.lock.Lock() - _, err := cache.Persist() + _, err = cache.Persist() if err != nil { bc.lock.Unlock() return err } + bc.dao.MPT.Flush() bc.topBlock.Store(block) atomic.StoreUint32(&bc.blockHeight, block.Index) bc.memPool.RemoveStale(bc.isTxStillRelevant) @@ -1732,6 +1762,65 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { } +// AddStateRoot add new (possibly unverified) state root to the blockchain. +func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { + our, err := bc.GetStateRoot(r.Index) + if err == nil { + if our.Flag == state.Verified { + return nil + } else if r.Witness == nil && our.Witness != nil { + r.Witness = our.Witness + } + } + if err := bc.verifyStateRoot(r); err != nil { + return errors.WithMessage(err, "invalid state root") + } + if r.Index > bc.BlockHeight() { // just put it into the store for future checks + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: state.Unverified, + }) + } + + flag := state.Unverified + if r.Witness != nil { + if err := bc.verifyStateRootWitness(r); err != nil { + return errors.WithMessage(err, "can't verify signature") + } + flag = state.Verified + } + return bc.dao.PutStateRoot(&state.MPTRootState{ + MPTRoot: *r, + Flag: flag, + }) +} + +// verifyStateRoot checks if state root is valid. +func (bc *Blockchain) verifyStateRoot(r *state.MPTRoot) error { + if r.Index == 0 { + return nil + } + prev, err := bc.GetStateRoot(r.Index - 1) + if err != nil { + return errors.New("can't get previous state root") + } else if !prev.Root.Equals(r.PrevHash) { + return errors.New("previous hash mismatch") + } else if prev.Version != r.Version { + return errors.New("version mismatch") + } + return nil +} + +// verifyStateRootWitness verifies that state root signature is correct. +func (bc *Blockchain) verifyStateRootWitness(r *state.MPTRoot) error { + b, err := bc.GetBlock(bc.GetHeaderHash(int(r.Index))) + if err != nil { + return err + } + interopCtx := bc.newInteropContext(trigger.Verification, bc.dao, nil, nil) + return bc.verifyHashAgainstScript(b.NextConsensus, r.Witness, hash.Sha256(r.GetSignedPart()), interopCtx, true) +} + // VerifyTx verifies whether a transaction is bonafide or not. Block parameter // is used for easy interop access and can be omitted for transactions that are // not yet added into any block. diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index d3e0309de..eac6e4edc 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -18,6 +18,7 @@ type Blockchainer interface { GetConfig() config.ProtocolConfiguration AddHeaders(...*block.Header) error AddBlock(*block.Block) error + AddStateRoot(r *state.MPTRoot) error BlockHeight() uint32 CalculateClaimable(value util.Fixed8, startHeight, endHeight uint32) (util.Fixed8, util.Fixed8, error) Close() @@ -38,6 +39,7 @@ type Blockchainer interface { GetNEP5Balances(util.Uint160) *state.NEP5Balances GetValidators(txes ...*transaction.Transaction) ([]*keys.PublicKey, error) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem GetStorageItems(hash util.Uint160) (map[string]*state.StorageItem, error) GetTestVM() *vm.VM diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index 7126969ac..b6fb6023b 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -7,6 +7,7 @@ import ( "sort" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/mpt" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/transaction" @@ -34,6 +35,8 @@ type DAO interface { GetHeaderHashes() ([]util.Uint256, error) GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) + GetStateRoot(height uint32) (*state.MPTRootState, error) + PutStateRoot(root *state.MPTRootState) error GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem GetStorageItems(hash util.Uint160, prefix []byte) ([]StorageItemWithKey, error) GetTransaction(hash util.Uint256) (*transaction.Transaction, uint32, error) @@ -70,12 +73,14 @@ type DAO interface { // Simple is memCached wrapper around DB, simple DAO implementation. type Simple struct { + MPT *mpt.Trie Store *storage.MemCachedStore } // NewSimple creates new simple dao using provided backend store. func NewSimple(backend storage.Store) *Simple { - return &Simple{Store: storage.NewMemCachedStore(backend)} + st := storage.NewMemCachedStore(backend) + return &Simple{Store: st, MPT: mpt.NewTrie(nil, st)} } // GetBatch returns currently accumulated DB changeset. @@ -86,7 +91,9 @@ func (dao *Simple) GetBatch() *storage.MemBatch { // GetWrapped returns new DAO instance with another layer of wrapped // MemCachedStore around the current DAO Store. func (dao *Simple) GetWrapped() DAO { - return NewSimple(dao.Store) + d := NewSimple(dao.Store) + d.MPT = dao.MPT + return d } // GetAndDecode performs get operation and decoding with serializable structures. @@ -406,6 +413,42 @@ func (dao *Simple) PutAppExecResult(aer *state.AppExecResult) error { // -- start storage item. +func makeStateRootKey(height uint32) []byte { + key := make([]byte, 5) + key[0] = byte(storage.DataMPT) + binary.LittleEndian.PutUint32(key[1:], height) + return key +} + +// InitMPT initializes MPT at the given height. +func (dao *Simple) InitMPT(height uint32) error { + if height == 0 { + dao.MPT = mpt.NewTrie(nil, dao.Store) + return nil + } + r, err := dao.GetStateRoot(height) + if err != nil { + return err + } + dao.MPT = mpt.NewTrie(mpt.NewHashNode(r.Root), dao.Store) + return nil +} + +// GetStateRoot returns state root of a given height. +func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { + r := new(state.MPTRootState) + err := dao.GetAndDecode(r, makeStateRootKey(height)) + if err != nil { + return nil, err + } + return r, nil +} + +// PutStateRoot puts state root of a given height into the store. +func (dao *Simple) PutStateRoot(r *state.MPTRootState) error { + return dao.Put(r, makeStateRootKey(r.Index)) +} + // GetStorageItem returns StorageItem if it exists in the given store. func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem { b, err := dao.Store.Get(makeStorageItemKey(scripthash, key)) @@ -426,13 +469,24 @@ func (dao *Simple) GetStorageItem(scripthash util.Uint160, key []byte) *state.St // PutStorageItem puts given StorageItem for given script with given // key into the given store. func (dao *Simple) PutStorageItem(scripthash util.Uint160, key []byte, si *state.StorageItem) error { - return dao.Put(si, makeStorageItemKey(scripthash, key)) + stKey := makeStorageItemKey(scripthash, key) + k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix + v := mpt.ToNeoStorageValue(si) + if err := dao.MPT.Put(k, v); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Put(si, stKey) } // DeleteStorageItem drops storage item for the given script with the // given key from the store. func (dao *Simple) DeleteStorageItem(scripthash util.Uint160, key []byte) error { - return dao.Store.Delete(makeStorageItemKey(scripthash, key)) + stKey := makeStorageItemKey(scripthash, key) + k := mpt.ToNeoStorageKey(stKey[1:]) // strip STStorage prefix + if err := dao.MPT.Delete(k); err != nil && err != mpt.ErrNotFound { + return err + } + return dao.Store.Delete(stKey) } // StorageItemWithKey is a Key-Value pair together with possible const modifier. diff --git a/pkg/core/mpt/helpers.go b/pkg/core/mpt/helpers.go index fe59b2917..4f508445d 100644 --- a/pkg/core/mpt/helpers.go +++ b/pkg/core/mpt/helpers.go @@ -1,6 +1,10 @@ package mpt -import "github.com/nspcc-dev/neo-go/pkg/util" +import ( + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) // lcp returns longest common prefix of a and b. // Note: it does no allocations. @@ -69,3 +73,14 @@ func ToNeoStorageKey(key []byte) []byte { } return append(nkey, byte(padding)) } + +// ToNeoStorageValue serializes si to a C# neo node's format. +// It has additional version (0x00) byte at the beginning. +func ToNeoStorageValue(si *state.StorageItem) []byte { + const version = 0 + + buf := io.NewBufBinWriter() + buf.BinWriter.WriteB(version) + si.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} diff --git a/pkg/core/state/mpt_root.go b/pkg/core/state/mpt_root.go new file mode 100644 index 000000000..facf3da45 --- /dev/null +++ b/pkg/core/state/mpt_root.go @@ -0,0 +1,105 @@ +package state + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/io" + "github.com/nspcc-dev/neo-go/pkg/util" +) + +// MPTRootBase represents storage state root. +type MPTRootBase struct { + Version byte + Index uint32 + PrevHash util.Uint256 + Root util.Uint256 +} + +// MPTRoot represents storage state root together with sign info. +type MPTRoot struct { + MPTRootBase + Witness *transaction.Witness +} + +// MPTRootStateFlag represents verification state of the state root. +type MPTRootStateFlag byte + +// Possible verification states of MPTRoot. +const ( + Unverified MPTRootStateFlag = 0x00 + Verified MPTRootStateFlag = 0x01 + Invalid MPTRootStateFlag = 0x03 +) + +// MPTRootState represents state root together with its verification state. +type MPTRootState struct { + MPTRoot + Flag MPTRootStateFlag +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootState) EncodeBinary(w *io.BinWriter) { + w.WriteB(byte(s.Flag)) + s.MPTRoot.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootState) DecodeBinary(r *io.BinReader) { + s.Flag = MPTRootStateFlag(r.ReadB()) + s.MPTRoot.DecodeBinary(r) +} + +// GetSignedPart returns part of MPTRootBase which needs to be signed. +func (s *MPTRootBase) GetSignedPart() []byte { + buf := io.NewBufBinWriter() + s.EncodeBinary(buf.BinWriter) + return buf.Bytes() +} + +// Equals checks if s == other. +func (s *MPTRootBase) Equals(other *MPTRootBase) bool { + return s.Version == other.Version && s.Index == other.Index && + s.PrevHash.Equals(other.PrevHash) && s.Root.Equals(other.Root) +} + +// Hash returns hash of s. +func (s *MPTRootBase) Hash() util.Uint256 { + return hash.DoubleSha256(s.GetSignedPart()) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRootBase) DecodeBinary(r *io.BinReader) { + s.Version = r.ReadB() + s.Index = r.ReadU32LE() + s.PrevHash.DecodeBinary(r) + s.Root.DecodeBinary(r) +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRootBase) EncodeBinary(w *io.BinWriter) { + w.WriteB(s.Version) + w.WriteU32LE(s.Index) + s.PrevHash.EncodeBinary(w) + s.Root.EncodeBinary(w) +} + +// DecodeBinary implements io.Serializable. +func (s *MPTRoot) DecodeBinary(r *io.BinReader) { + s.MPTRootBase.DecodeBinary(r) + + var ws []transaction.Witness + r.ReadArray(&ws, 1) + if len(ws) == 1 { + s.Witness = &ws[0] + } +} + +// EncodeBinary implements io.Serializable. +func (s *MPTRoot) EncodeBinary(w *io.BinWriter) { + s.MPTRootBase.EncodeBinary(w) + if s.Witness == nil { + w.WriteVarUint(0) + } else { + w.WriteArray([]*transaction.Witness{s.Witness}) + } +} diff --git a/pkg/core/state/mpt_root_test.go b/pkg/core/state/mpt_root_test.go new file mode 100644 index 000000000..15a3ca043 --- /dev/null +++ b/pkg/core/state/mpt_root_test.go @@ -0,0 +1,61 @@ +package state + +import ( + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" + "github.com/stretchr/testify/require" +) + +func testStateRoot() *MPTRoot { + return &MPTRoot{ + MPTRootBase: MPTRootBase{ + Version: byte(rand.Uint32()), + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + } +} + +func TestStateRoot_Serializable(t *testing.T) { + r := testStateRoot() + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + + t.Run("WithWitness", func(t *testing.T) { + r.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, r, new(MPTRoot)) + }) +} + +func TestStateRootEquals(t *testing.T) { + r1 := testStateRoot() + r2 := *r1 + require.True(t, r1.Equals(&r2.MPTRootBase)) + + r2.MPTRootBase.Index++ + require.False(t, r1.Equals(&r2.MPTRootBase)) +} + +func TestMPTRootState_Serializable(t *testing.T) { + rs := &MPTRootState{ + MPTRoot: *testStateRoot(), + Flag: 0x04, + } + rs.MPTRoot.Witness = &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + } + testserdes.EncodeDecodeBinary(t, rs, new(MPTRootState)) +} + +func TestMPTRootStateUnverifiedByDefault(t *testing.T) { + var r MPTRootState + require.Equal(t, Unverified, r.Flag) +} diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index a719d012d..157ebdba0 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -59,6 +59,9 @@ func (chain *testChain) AddBlock(block *block.Block) error { } return nil } +func (chain *testChain) AddStateRoot(r *state.MPTRoot) error { + panic("TODO") +} func (chain *testChain) BlockHeight() uint32 { return atomic.LoadUint32(&chain.blockheight) } @@ -105,6 +108,9 @@ func (chain testChain) GetEnrollments() ([]*state.Validator, error) { func (chain testChain) GetScriptHashesForVerifying(*transaction.Transaction) ([]util.Uint160, error) { panic("TODO") } +func (chain testChain) GetStateRoot(height uint32) (*state.MPTRootState, error) { + panic("TODO") +} func (chain testChain) GetStorageItem(scripthash util.Uint160, key []byte) *state.StorageItem { panic("TODO") } diff --git a/pkg/network/message.go b/pkg/network/message.go index a8bedc96c..f17b62658 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" @@ -59,12 +60,15 @@ const ( CMDGetBlocks CommandType = "getblocks" CMDGetData CommandType = "getdata" CMDGetHeaders CommandType = "getheaders" + CMDGetRoots CommandType = "getroots" CMDHeaders CommandType = "headers" CMDInv CommandType = "inv" CMDMempool CommandType = "mempool" CMDMerkleBlock CommandType = "merkleblock" CMDPing CommandType = "ping" CMDPong CommandType = "pong" + CMDRoots CommandType = "roots" + CMDStateRoot CommandType = "stateroot" CMDTX CommandType = "tx" CMDUnknown CommandType = "unknown" CMDVerack CommandType = "verack" @@ -124,6 +128,8 @@ func (m *Message) CommandType() CommandType { return CMDGetData case "getheaders": return CMDGetHeaders + case "getroots": + return CMDGetRoots case "headers": return CMDHeaders case "inv": @@ -136,6 +142,10 @@ func (m *Message) CommandType() CommandType { return CMDPing case "pong": return CMDPong + case "roots": + return CMDRoots + case "stateroot": + return CMDStateRoot case "tx": return CMDTX case "verack": @@ -191,6 +201,8 @@ func (m *Message) decodePayload(br *io.BinReader) error { fallthrough case CMDGetHeaders: p = &payload.GetBlocks{} + case CMDGetRoots: + p = &payload.GetStateRoots{} case CMDHeaders: p = &payload.Headers{} case CMDTX: @@ -199,6 +211,10 @@ func (m *Message) decodePayload(br *io.BinReader) error { p = &payload.MerkleBlock{} case CMDPing, CMDPong: p = &payload.Ping{} + case CMDRoots: + p = &payload.StateRoots{} + case CMDStateRoot: + p = &state.MPTRoot{} default: return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command)) } diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index d582e0486..fd5f9ed71 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -18,6 +18,8 @@ func (i InventoryType) String() string { return "TX" case 0x02: return "block" + case StateRootType: + return "stateroot" case 0xe0: return "consensus" default: @@ -27,13 +29,14 @@ func (i InventoryType) String() string { // Valid returns true if the inventory (type) is known. func (i InventoryType) Valid() bool { - return i == BlockType || i == TXType || i == ConsensusType + return i == BlockType || i == TXType || i == ConsensusType || i == StateRootType } // List of valid InventoryTypes. const ( TXType InventoryType = 0x01 // 1 BlockType InventoryType = 0x02 // 2 + StateRootType InventoryType = 0x03 // 3 ConsensusType InventoryType = 0xe0 // 224 ) diff --git a/pkg/network/payload/state_root.go b/pkg/network/payload/state_root.go new file mode 100644 index 000000000..f43584375 --- /dev/null +++ b/pkg/network/payload/state_root.go @@ -0,0 +1,43 @@ +package payload + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MaxStateRootsAllowed is a maxumum amount of state roots +// which can be sent in a single payload. +const MaxStateRootsAllowed = 2000 + +// StateRoots contains multiple StateRoots. +type StateRoots struct { + Roots []state.MPTRoot +} + +// GetStateRoots represents request for state roots. +type GetStateRoots struct { + Start uint32 + Count uint32 +} + +// EncodeBinary implements io.Serializable. +func (s *StateRoots) EncodeBinary(w *io.BinWriter) { + w.WriteArray(s.Roots) +} + +// DecodeBinary implements io.Serializable. +func (s *StateRoots) DecodeBinary(r *io.BinReader) { + r.ReadArray(&s.Roots, MaxStateRootsAllowed) +} + +// DecodeBinary implements io.Serializable. +func (g *GetStateRoots) DecodeBinary(r *io.BinReader) { + g.Start = r.ReadU32LE() + g.Count = r.ReadU32LE() +} + +// EncodeBinary implements io.Serializable. +func (g *GetStateRoots) EncodeBinary(w *io.BinWriter) { + w.WriteU32LE(g.Start) + w.WriteU32LE(g.Count) +} diff --git a/pkg/network/payload/state_root_test.go b/pkg/network/payload/state_root_test.go new file mode 100644 index 000000000..a3f670713 --- /dev/null +++ b/pkg/network/payload/state_root_test.go @@ -0,0 +1,51 @@ +package payload + +import ( + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" +) + +func TestStateRoots_Serializable(t *testing.T) { + expected := &StateRoots{ + Roots: []state.MPTRoot{ + { + MPTRootBase: state.MPTRootBase{ + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + Witness: &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + }, + { + MPTRootBase: state.MPTRootBase{ + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + Witness: &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + }, + }, + } + + testserdes.EncodeDecodeBinary(t, expected, new(StateRoots)) +} + +func TestGetStateRoots_Serializable(t *testing.T) { + expected := &GetStateRoots{ + Start: rand.Uint32(), + Count: rand.Uint32(), + } + + testserdes.EncodeDecodeBinary(t, expected, new(GetStateRoots)) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index 1836cdf92..a9559eba7 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" @@ -507,6 +508,8 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { if err == nil { msg = s.MkMsg(CMDBlock, b) } + case payload.StateRootType: + return nil // do nothing case payload.ConsensusType: if cp := s.consensus.GetPayload(hash); cp != nil { msg = s.MkMsg(CMDConsensus, cp) @@ -589,6 +592,35 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { return p.EnqueueP2PMessage(msg) } +// handleGetRootsCmd processees `getroots` request. +func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { + count := gr.Count + if count > payload.MaxStateRootsAllowed { + count = payload.MaxStateRootsAllowed + } + var rs payload.StateRoots + for height := gr.Start; height < gr.Start+gr.Count; height++ { + r, err := s.chain.GetStateRoot(height) + if err != nil { + return err + } else if r.Flag == state.Verified { + rs.Roots = append(rs.Roots, r.MPTRoot) + } + } + msg := s.MkMsg(CMDRoots, &rs) + return p.EnqueueP2PMessage(msg) +} + +// handleStateRootsCmd processees `roots` request. +func (s *Server) handleRootsCmd(rs *payload.StateRoots) error { + return nil // TODO +} + +// handleStateRootCmd processees `stateroot` request. +func (s *Server) handleStateRootCmd(r *state.MPTRoot) error { + return nil // TODO +} + // handleConsensusCmd processes received consensus payload. // It never returns an error. func (s *Server) handleConsensusCmd(cp *consensus.Payload) error { @@ -697,6 +729,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetHeaders: gh := msg.Payload.(*payload.GetBlocks) return s.handleGetHeadersCmd(peer, gh) + case CMDGetRoots: + gr := msg.Payload.(*payload.GetStateRoots) + return s.handleGetRootsCmd(peer, gr) case CMDHeaders: headers := msg.Payload.(*payload.Headers) go s.handleHeadersCmd(peer, headers) @@ -718,6 +753,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDPong: pong := msg.Payload.(*payload.Ping) return s.handlePong(peer, pong) + case CMDRoots: + rs := msg.Payload.(*payload.StateRoots) + return s.handleRootsCmd(rs) + case CMDStateRoot: + r := msg.Payload.(*state.MPTRoot) + return s.handleStateRootCmd(r) case CMDVersion, CMDVerack: return fmt.Errorf("received '%s' after the handshake", msg.CommandType()) }