diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index be97d38bc..528002435 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1776,12 +1776,18 @@ func (bc *Blockchain) isTxStillRelevant(t *transaction.Transaction) bool { } +// StateHeight returns height of the verified state root. +func (bc *Blockchain) StateHeight() uint32 { + h, _ := bc.dao.GetCurrentStateRootHeight() + return h +} + // AddStateRoot add new (possibly unverified) state root to the blockchain. func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { our, err := bc.GetStateRoot(r.Index) if err == nil { if our.Flag == state.Verified { - return nil + return bc.updateStateHeight(r.Index) } else if r.Witness == nil && our.Witness != nil { r.Witness = our.Witness } @@ -1803,10 +1809,24 @@ func (bc *Blockchain) AddStateRoot(r *state.MPTRoot) error { } flag = state.Verified } - return bc.dao.PutStateRoot(&state.MPTRootState{ + err = bc.dao.PutStateRoot(&state.MPTRootState{ MPTRoot: *r, Flag: flag, }) + if err != nil { + return err + } + return bc.updateStateHeight(r.Index) +} + +func (bc *Blockchain) updateStateHeight(newHeight uint32) error { + h, err := bc.dao.GetCurrentStateRootHeight() + if err != nil { + return errors.WithMessage(err, "can't get current state root height") + } else if newHeight == h+1 { + return bc.dao.PutCurrentStateRootHeight(h + 1) + } + return nil } // verifyStateRoot checks if state root is valid. diff --git a/pkg/core/blockchainer.go b/pkg/core/blockchainer.go index db2d11abe..d32a3eb5e 100644 --- a/pkg/core/blockchainer.go +++ b/pkg/core/blockchainer.go @@ -49,6 +49,7 @@ type Blockchainer interface { References(t *transaction.Transaction) ([]transaction.InOut, error) mempool.Feer // fee interface PoolTx(*transaction.Transaction) error + StateHeight() uint32 SubscribeForBlocks(ch chan<- *block.Block) SubscribeForExecutions(ch chan<- *state.AppExecResult) SubscribeForNotifications(ch chan<- *state.NotificationEvent) diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index a1865517c..262584d47 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -32,6 +32,7 @@ type DAO interface { GetContractState(hash util.Uint160) (*state.Contract, error) GetCurrentBlockHeight() (uint32, error) GetCurrentHeaderHeight() (i uint32, h util.Uint256, err error) + GetCurrentStateRootHeight() (uint32, error) GetHeaderHashes() ([]util.Uint256, error) GetNEP5Balances(acc util.Uint160) (*state.NEP5Balances, error) GetNEP5TransferLog(acc util.Uint160, index uint32) (*state.NEP5TransferLog, error) @@ -434,6 +435,27 @@ func (dao *Simple) InitMPT(height uint32) error { return nil } +// GetCurrentStateRootHeight returns current state root height. +func (dao *Simple) GetCurrentStateRootHeight() (uint32, error) { + key := []byte{byte(storage.DataMPT)} + val, err := dao.Store.Get(key) + if err != nil { + if err == storage.ErrKeyNotFound { + err = nil + } + return 0, err + } + return binary.LittleEndian.Uint32(val), nil +} + +// PutCurrentStateRootHeight updates current state root height. +func (dao *Simple) PutCurrentStateRootHeight(height uint32) error { + key := []byte{byte(storage.DataMPT)} + val := make([]byte, 4) + binary.LittleEndian.PutUint32(val, height) + return dao.Store.Put(key, val) +} + // GetStateRoot returns state root of a given height. func (dao *Simple) GetStateRoot(height uint32) (*state.MPTRootState, error) { r := new(state.MPTRootState) diff --git a/pkg/network/helper_test.go b/pkg/network/helper_test.go index 2e2b697ad..0bc0b6637 100644 --- a/pkg/network/helper_test.go +++ b/pkg/network/helper_test.go @@ -154,7 +154,9 @@ func (chain testChain) IsLowPriority(util.Fixed8) bool { func (chain testChain) PoolTx(*transaction.Transaction) error { panic("TODO") } - +func (chain testChain) StateHeight() uint32 { + panic("TODO") +} func (chain testChain) SubscribeForBlocks(ch chan<- *block.Block) { panic("TODO") } diff --git a/pkg/network/server.go b/pkg/network/server.go index 1b0c3076c..5ce2118aa 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -620,11 +620,34 @@ func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { } // handleStateRootsCmd processees `roots` request. -func (s *Server) handleRootsCmd(rs *payload.StateRoots) error { +func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error { + h := s.chain.StateHeight() for i := range rs.Roots { + if rs.Roots[i].Index <= h { + continue + } _ = s.chain.AddStateRoot(&rs.Roots[i]) } - return nil + // request more state roots from peer if needed + return s.requestStateRoot(p) +} + +// requestStateRoot sends `getroots` message to get verified state roots. +func (s *Server) requestStateRoot(p Peer) error { + stateHeight := s.chain.StateHeight() + hdrHeight := s.chain.BlockHeight() + count := uint32(payload.MaxStateRootsAllowed) + if diff := hdrHeight - stateHeight; diff < count { + count = diff + } + if count == 0 { + return nil + } + gr := &payload.GetStateRoots{ + Start: stateHeight + 1, + Count: count, + } + return p.EnqueueP2PMessage(s.MkMsg(CMDGetRoots, gr)) } // handleStateRootCmd processees `stateroot` request. @@ -772,7 +795,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return s.handlePong(peer, pong) case CMDRoots: rs := msg.Payload.(*payload.StateRoots) - return s.handleRootsCmd(rs) + return s.handleRootsCmd(peer, rs) case CMDStateRoot: r := msg.Payload.(*state.MPTRoot) return s.handleStateRootCmd(r) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index db1c13bc4..058faddf1 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -251,6 +251,9 @@ func (p *TCPPeer) StartProtocol() { if p.LastBlockIndex() > p.server.chain.BlockHeight() { err = p.server.requestBlocks(p) } + if err == nil { + err = p.server.requestStateRoot(p) + } if err == nil { timer.Reset(p.server.ProtoTickInterval) } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 6912c4e24..675efc9b7 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -745,10 +745,9 @@ func (s *Server) verifyProof(ps request.Params) (interface{}, *response.Error) { } func (s *Server) getStateHeight(_ request.Params) (interface{}, *response.Error) { - height := s.chain.BlockHeight() return &result.StateHeight{ - BlockHeight: height, - StateHeight: height, + BlockHeight: s.chain.BlockHeight(), + StateHeight: s.chain.StateHeight(), }, nil } diff --git a/pkg/rpc/server/server_test.go b/pkg/rpc/server/server_test.go index e53687baf..509867121 100644 --- a/pkg/rpc/server/server_test.go +++ b/pkg/rpc/server/server_test.go @@ -246,9 +246,8 @@ var rpcTestCases = map[string][]rpcTestCase{ sh, ok := res.(*result.StateHeight) require.True(t, ok) - h := e.chain.BlockHeight() - require.Equal(t, h, sh.BlockHeight) - require.Equal(t, h, sh.StateHeight) + require.Equal(t, e.chain.BlockHeight(), sh.BlockHeight) + require.Equal(t, e.chain.StateHeight(), sh.StateHeight) }, }, },