diff --git a/pkg/network/message.go b/pkg/network/message.go index a8bedc96c..f17b62658 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -8,6 +8,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" "github.com/nspcc-dev/neo-go/pkg/io" @@ -59,12 +60,15 @@ const ( CMDGetBlocks CommandType = "getblocks" CMDGetData CommandType = "getdata" CMDGetHeaders CommandType = "getheaders" + CMDGetRoots CommandType = "getroots" CMDHeaders CommandType = "headers" CMDInv CommandType = "inv" CMDMempool CommandType = "mempool" CMDMerkleBlock CommandType = "merkleblock" CMDPing CommandType = "ping" CMDPong CommandType = "pong" + CMDRoots CommandType = "roots" + CMDStateRoot CommandType = "stateroot" CMDTX CommandType = "tx" CMDUnknown CommandType = "unknown" CMDVerack CommandType = "verack" @@ -124,6 +128,8 @@ func (m *Message) CommandType() CommandType { return CMDGetData case "getheaders": return CMDGetHeaders + case "getroots": + return CMDGetRoots case "headers": return CMDHeaders case "inv": @@ -136,6 +142,10 @@ func (m *Message) CommandType() CommandType { return CMDPing case "pong": return CMDPong + case "roots": + return CMDRoots + case "stateroot": + return CMDStateRoot case "tx": return CMDTX case "verack": @@ -191,6 +201,8 @@ func (m *Message) decodePayload(br *io.BinReader) error { fallthrough case CMDGetHeaders: p = &payload.GetBlocks{} + case CMDGetRoots: + p = &payload.GetStateRoots{} case CMDHeaders: p = &payload.Headers{} case CMDTX: @@ -199,6 +211,10 @@ func (m *Message) decodePayload(br *io.BinReader) error { p = &payload.MerkleBlock{} case CMDPing, CMDPong: p = &payload.Ping{} + case CMDRoots: + p = &payload.StateRoots{} + case CMDStateRoot: + p = &state.MPTRoot{} default: return fmt.Errorf("can't decode command %s", cmdByteArrayToString(m.Command)) } diff --git a/pkg/network/payload/inventory.go b/pkg/network/payload/inventory.go index d582e0486..fd5f9ed71 100644 --- a/pkg/network/payload/inventory.go +++ b/pkg/network/payload/inventory.go @@ -18,6 +18,8 @@ func (i InventoryType) String() string { return "TX" case 0x02: return "block" + case StateRootType: + return "stateroot" case 0xe0: return "consensus" default: @@ -27,13 +29,14 @@ func (i InventoryType) String() string { // Valid returns true if the inventory (type) is known. func (i InventoryType) Valid() bool { - return i == BlockType || i == TXType || i == ConsensusType + return i == BlockType || i == TXType || i == ConsensusType || i == StateRootType } // List of valid InventoryTypes. const ( TXType InventoryType = 0x01 // 1 BlockType InventoryType = 0x02 // 2 + StateRootType InventoryType = 0x03 // 3 ConsensusType InventoryType = 0xe0 // 224 ) diff --git a/pkg/network/payload/state_root.go b/pkg/network/payload/state_root.go new file mode 100644 index 000000000..f43584375 --- /dev/null +++ b/pkg/network/payload/state_root.go @@ -0,0 +1,43 @@ +package payload + +import ( + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/io" +) + +// MaxStateRootsAllowed is a maxumum amount of state roots +// which can be sent in a single payload. +const MaxStateRootsAllowed = 2000 + +// StateRoots contains multiple StateRoots. +type StateRoots struct { + Roots []state.MPTRoot +} + +// GetStateRoots represents request for state roots. +type GetStateRoots struct { + Start uint32 + Count uint32 +} + +// EncodeBinary implements io.Serializable. +func (s *StateRoots) EncodeBinary(w *io.BinWriter) { + w.WriteArray(s.Roots) +} + +// DecodeBinary implements io.Serializable. +func (s *StateRoots) DecodeBinary(r *io.BinReader) { + r.ReadArray(&s.Roots, MaxStateRootsAllowed) +} + +// DecodeBinary implements io.Serializable. +func (g *GetStateRoots) DecodeBinary(r *io.BinReader) { + g.Start = r.ReadU32LE() + g.Count = r.ReadU32LE() +} + +// EncodeBinary implements io.Serializable. +func (g *GetStateRoots) EncodeBinary(w *io.BinWriter) { + w.WriteU32LE(g.Start) + w.WriteU32LE(g.Count) +} diff --git a/pkg/network/payload/state_root_test.go b/pkg/network/payload/state_root_test.go new file mode 100644 index 000000000..a3f670713 --- /dev/null +++ b/pkg/network/payload/state_root_test.go @@ -0,0 +1,51 @@ +package payload + +import ( + "math/rand" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/transaction" + "github.com/nspcc-dev/neo-go/pkg/internal/random" + "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" +) + +func TestStateRoots_Serializable(t *testing.T) { + expected := &StateRoots{ + Roots: []state.MPTRoot{ + { + MPTRootBase: state.MPTRootBase{ + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + Witness: &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + }, + { + MPTRootBase: state.MPTRootBase{ + Index: rand.Uint32(), + PrevHash: random.Uint256(), + Root: random.Uint256(), + }, + Witness: &transaction.Witness{ + InvocationScript: random.Bytes(10), + VerificationScript: random.Bytes(11), + }, + }, + }, + } + + testserdes.EncodeDecodeBinary(t, expected, new(StateRoots)) +} + +func TestGetStateRoots_Serializable(t *testing.T) { + expected := &GetStateRoots{ + Start: rand.Uint32(), + Count: rand.Uint32(), + } + + testserdes.EncodeDecodeBinary(t, expected, new(GetStateRoots)) +} diff --git a/pkg/network/server.go b/pkg/network/server.go index 1836cdf92..a9559eba7 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -13,6 +13,7 @@ import ( "github.com/nspcc-dev/neo-go/pkg/consensus" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/block" + "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/network/payload" "github.com/nspcc-dev/neo-go/pkg/util" @@ -507,6 +508,8 @@ func (s *Server) handleGetDataCmd(p Peer, inv *payload.Inventory) error { if err == nil { msg = s.MkMsg(CMDBlock, b) } + case payload.StateRootType: + return nil // do nothing case payload.ConsensusType: if cp := s.consensus.GetPayload(hash); cp != nil { msg = s.MkMsg(CMDConsensus, cp) @@ -589,6 +592,35 @@ func (s *Server) handleGetHeadersCmd(p Peer, gh *payload.GetBlocks) error { return p.EnqueueP2PMessage(msg) } +// handleGetRootsCmd processees `getroots` request. +func (s *Server) handleGetRootsCmd(p Peer, gr *payload.GetStateRoots) error { + count := gr.Count + if count > payload.MaxStateRootsAllowed { + count = payload.MaxStateRootsAllowed + } + var rs payload.StateRoots + for height := gr.Start; height < gr.Start+gr.Count; height++ { + r, err := s.chain.GetStateRoot(height) + if err != nil { + return err + } else if r.Flag == state.Verified { + rs.Roots = append(rs.Roots, r.MPTRoot) + } + } + msg := s.MkMsg(CMDRoots, &rs) + return p.EnqueueP2PMessage(msg) +} + +// handleStateRootsCmd processees `roots` request. +func (s *Server) handleRootsCmd(rs *payload.StateRoots) error { + return nil // TODO +} + +// handleStateRootCmd processees `stateroot` request. +func (s *Server) handleStateRootCmd(r *state.MPTRoot) error { + return nil // TODO +} + // handleConsensusCmd processes received consensus payload. // It never returns an error. func (s *Server) handleConsensusCmd(cp *consensus.Payload) error { @@ -697,6 +729,9 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDGetHeaders: gh := msg.Payload.(*payload.GetBlocks) return s.handleGetHeadersCmd(peer, gh) + case CMDGetRoots: + gr := msg.Payload.(*payload.GetStateRoots) + return s.handleGetRootsCmd(peer, gr) case CMDHeaders: headers := msg.Payload.(*payload.Headers) go s.handleHeadersCmd(peer, headers) @@ -718,6 +753,12 @@ func (s *Server) handleMessage(peer Peer, msg *Message) error { case CMDPong: pong := msg.Payload.(*payload.Ping) return s.handlePong(peer, pong) + case CMDRoots: + rs := msg.Payload.(*payload.StateRoots) + return s.handleRootsCmd(rs) + case CMDStateRoot: + r := msg.Payload.(*state.MPTRoot) + return s.handleStateRootCmd(r) case CMDVersion, CMDVerack: return fmt.Errorf("received '%s' after the handshake", msg.CommandType()) }