From e2e1bd09aee66494045f358943e6b1fc3d47c6d3 Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 22 Jun 2020 11:41:31 +0300 Subject: [PATCH] network: request state roots if needed --- pkg/network/server.go | 25 ++++++++++++++++++++++--- pkg/network/tcp_peer.go | 3 +++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/pkg/network/server.go b/pkg/network/server.go index 1b0c3076c..69fedadf0 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -620,11 +620,30 @@ func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { } // handleStateRootsCmd processees `roots` request. -func (s *Server) handleRootsCmd(rs *payload.StateRoots) error { +func (s *Server) handleRootsCmd(p Peer, rs *payload.StateRoots) error { for i := range rs.Roots { _ = s.chain.AddStateRoot(&rs.Roots[i]) } - return nil + // request more state roots from peer if needed + return s.requestStateRoot(p) +} + +// requestStateRoot sends `getroots` message to get verified state roots. +func (s *Server) requestStateRoot(p Peer) error { + stateHeight := s.chain.StateHeight() + hdrHeight := s.chain.BlockHeight() + count := uint32(payload.MaxStateRootsAllowed) + if diff := hdrHeight - stateHeight; diff < count { + count = diff + } + if count == 0 { + return nil + } + gr := &payload.GetStateRoots{ + Start: stateHeight + 1, + Count: count, + } + return p.EnqueueP2PMessage(s.MkMsg(CMDGetRoots, gr)) } // handleStateRootCmd processees `stateroot` request. @@ -772,7 +791,7 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { return s.handlePong(peer, pong) case CMDRoots: rs := msg.Payload.(*payload.StateRoots) - return s.handleRootsCmd(rs) + return s.handleRootsCmd(peer, rs) case CMDStateRoot: r := msg.Payload.(*state.MPTRoot) return s.handleStateRootCmd(r) diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index db1c13bc4..058faddf1 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -251,6 +251,9 @@ func (p *TCPPeer) StartProtocol() { if p.LastBlockIndex() > p.server.chain.BlockHeight() { err = p.server.requestBlocks(p) } + if err == nil { + err = p.server.requestStateRoot(p) + } if err == nil { timer.Reset(p.server.ProtoTickInterval) }