diff --git a/cli/server/server.go b/cli/server/server.go index b4a874963..ff538269c 100644 --- a/cli/server/server.go +++ b/cli/server/server.go @@ -281,7 +281,7 @@ func restoreDB(ctx *cli.Context) error { default: } bytes, err := readBlock(reader) - block := block.New(cfg.ProtocolConfiguration.Magic) + block := block.New(cfg.ProtocolConfiguration.Magic, cfg.ProtocolConfiguration.StateRootInHeader) newReader := io.NewBinReaderFromBuf(bytes) block.DecodeBinary(newReader) if err != nil { diff --git a/pkg/compiler/interop_test.go b/pkg/compiler/interop_test.go index 901217b96..c484f5db4 100644 --- a/pkg/compiler/interop_test.go +++ b/pkg/compiler/interop_test.go @@ -105,7 +105,7 @@ func TestAppCall(t *testing.T) { require.NoError(t, err) ih := hash.Hash160(inner) - ic := interop.NewContext(trigger.Application, nil, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil, nil, zaptest.NewLogger(t)) + ic := interop.NewContext(trigger.Application, nil, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false), nil, nil, nil, zaptest.NewLogger(t)) require.NoError(t, ic.DAO.PutContractState(&state.Contract{ Script: inner, Manifest: *m, diff --git a/pkg/config/protocol_config.go b/pkg/config/protocol_config.go index f9beb5100..c934a5a7a 100644 --- a/pkg/config/protocol_config.go +++ b/pkg/config/protocol_config.go @@ -24,7 +24,9 @@ type ( SecondsPerBlock int `yaml:"SecondsPerBlock"` SeedList []string `yaml:"SeedList"` StandbyCommittee []string `yaml:"StandbyCommittee"` - ValidatorsCount int `yaml:"ValidatorsCount"` + // StateRooInHeader enables storing state root in block header. + StateRootInHeader bool `yaml:"StateRootInHeader"` + ValidatorsCount int `yaml:"ValidatorsCount"` // Whether to verify received blocks. VerifyBlocks bool `yaml:"VerifyBlocks"` // Whether to verify transactions in received blocks. diff --git a/pkg/consensus/consensus.go b/pkg/consensus/consensus.go index 3cd091d0f..696b99e02 100644 --- a/pkg/consensus/consensus.go +++ b/pkg/consensus/consensus.go @@ -2,6 +2,7 @@ package consensus import ( "errors" + "fmt" "sort" "time" @@ -73,6 +74,8 @@ type service struct { lastProposal []util.Uint256 wallet *wallet.Wallet network netmode.Magic + // stateRootEnabled specifies if state root should be exchanged and checked during consensus. + stateRootEnabled bool // started is a flag set with Start method that runs an event handling // goroutine. started *atomic.Bool @@ -116,12 +119,13 @@ func NewService(cfg Config) (Service, error) { txx: newFIFOCache(cacheMaxCapacity), messages: make(chan Payload, 100), - transactions: make(chan *transaction.Transaction, 100), - blockEvents: make(chan *coreb.Block, 1), - network: cfg.Chain.GetConfig().Magic, - started: atomic.NewBool(false), - quit: make(chan struct{}), - finished: make(chan struct{}), + transactions: make(chan *transaction.Transaction, 100), + blockEvents: make(chan *coreb.Block, 1), + network: cfg.Chain.GetConfig().Magic, + stateRootEnabled: cfg.Chain.GetConfig().StateRootInHeader, + started: atomic.NewBool(false), + quit: make(chan struct{}), + finished: make(chan struct{}), } if cfg.Wallet == nil { @@ -168,12 +172,14 @@ func NewService(cfg Config) (Service, error) { dbft.WithGetConsensusAddress(srv.getConsensusAddress), dbft.WithNewConsensusPayload(srv.newPayload), - dbft.WithNewPrepareRequest(func() payload.PrepareRequest { return new(prepareRequest) }), + dbft.WithNewPrepareRequest(srv.newPrepareRequest), dbft.WithNewPrepareResponse(func() payload.PrepareResponse { return new(prepareResponse) }), dbft.WithNewChangeView(func() payload.ChangeView { return new(changeView) }), dbft.WithNewCommit(func() payload.Commit { return new(commit) }), dbft.WithNewRecoveryRequest(func() payload.RecoveryRequest { return new(recoveryRequest) }), - dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { return new(recoveryMessage) }), + dbft.WithNewRecoveryMessage(func() payload.RecoveryMessage { + return &recoveryMessage{stateRootEnabled: srv.stateRootEnabled} + }), dbft.WithVerifyPrepareRequest(srv.verifyRequest), dbft.WithVerifyPrepareResponse(func(_ payload.ConsensusPayload) error { return nil }), ) @@ -191,15 +197,30 @@ var ( ) // NewPayload creates new consensus payload for the provided network. -func NewPayload(m netmode.Magic) *Payload { +func NewPayload(m netmode.Magic, stateRootEnabled bool) *Payload { return &Payload{ network: m, - message: new(message), + message: &message{ + stateRootEnabled: stateRootEnabled, + }, } } func (s *service) newPayload() payload.ConsensusPayload { - return NewPayload(s.network) + return NewPayload(s.network, s.stateRootEnabled) +} + +func (s *service) newPrepareRequest() payload.PrepareRequest { + r := new(prepareRequest) + if s.stateRootEnabled { + r.stateRootEnabled = true + if sr, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1); err == nil { + r.stateRoot = sr.Root + } else { + panic(err) + } + } + return r } func (s *service) Start() { @@ -446,6 +467,14 @@ func (s *service) verifyBlock(b block.Block) bool { func (s *service) verifyRequest(p payload.ConsensusPayload) error { req := p.GetPrepareRequest().(*prepareRequest) + if s.stateRootEnabled { + sr, err := s.Chain.GetStateRoot(s.dbft.BlockIndex - 1) + if err != nil { + return err + } else if sr.Root != req.stateRoot { + return fmt.Errorf("state root mismatch: %s != %s", sr.Root, req.stateRoot) + } + } // Save lastProposal for getVerified(). s.lastProposal = req.transactionHashes @@ -584,6 +613,14 @@ func (s *service) newBlockFromContext(ctx *dbft.Context) block.Block { block.Block.Network = s.network block.Block.Timestamp = ctx.Timestamp / nsInMs block.Block.Index = ctx.BlockIndex + if s.stateRootEnabled { + sr, err := s.Chain.GetStateRoot(ctx.BlockIndex - 1) + if err != nil { + return nil + } + block.StateRootEnabled = true + block.PrevStateRoot = sr.Root + } var validators keys.PublicKeys var err error diff --git a/pkg/consensus/consensus_test.go b/pkg/consensus/consensus_test.go index ccb28e928..d269ad35a 100644 --- a/pkg/consensus/consensus_test.go +++ b/pkg/consensus/consensus_test.go @@ -285,6 +285,31 @@ func TestService_getTx(t *testing.T) { srv.Chain.Close() } +func TestService_PrepareRequest(t *testing.T) { + srv := newTestServiceWithState(t, true) + srv.dbft.Start() + defer srv.dbft.Timer.Stop() + + priv, _ := getTestValidator(1) + p := new(Payload) + p.message = &message{} + p.SetValidatorIndex(1) + + p.SetPayload(&prepareRequest{}) + require.NoError(t, p.Sign(priv)) + require.Error(t, srv.verifyRequest(p), "invalid stateroot setting") + + p.SetPayload(&prepareRequest{stateRootEnabled: true}) + require.NoError(t, p.Sign(priv)) + require.Error(t, srv.verifyRequest(p), "invalid state root") + + sr, err := srv.Chain.GetStateRoot(srv.dbft.BlockIndex - 1) + require.NoError(t, err) + p.SetPayload(&prepareRequest{stateRootEnabled: true, stateRoot: sr.Root}) + require.NoError(t, p.Sign(priv)) + require.NoError(t, srv.verifyRequest(p)) +} + func TestService_OnPayload(t *testing.T) { srv := newTestService(t) // This test directly reads things from srv.messages that normally @@ -407,8 +432,12 @@ func shouldNotReceive(t *testing.T, ch chan Payload) { } } +func newTestServiceWithState(t *testing.T, stateRootInHeader bool) *service { + return newTestServiceWithChain(t, newTestChain(t, stateRootInHeader)) +} + func newTestService(t *testing.T) *service { - return newTestServiceWithChain(t, newTestChain(t)) + return newTestServiceWithState(t, false) } func newTestServiceWithChain(t *testing.T, bc *core.Blockchain) *service { @@ -445,9 +474,10 @@ func newSingleTestChain(t *testing.T) *core.Blockchain { return chain } -func newTestChain(t *testing.T) *core.Blockchain { +func newTestChain(t *testing.T, stateRootInHeader bool) *core.Blockchain { unitTestNetCfg, err := config.Load("../../config", netmode.UnitTestNet) require.NoError(t, err) + unitTestNetCfg.ProtocolConfiguration.StateRootInHeader = stateRootInHeader chain, err := core.NewBlockchain(storage.NewMemoryStore(), unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t)) require.NoError(t, err) diff --git a/pkg/consensus/payload.go b/pkg/consensus/payload.go index 33c29161b..ce563e4eb 100644 --- a/pkg/consensus/payload.go +++ b/pkg/consensus/payload.go @@ -21,6 +21,8 @@ type ( ViewNumber byte payload io.Serializable + // stateRootEnabled specifies if state root is exchanged during consensus. + stateRootEnabled bool } // Payload is a type for consensus-related messages. @@ -302,7 +304,11 @@ func (m *message) DecodeBinary(r *io.BinReader) { cv.newViewNumber = m.ViewNumber + 1 m.payload = cv case prepareRequestType: - m.payload = new(prepareRequest) + r := new(prepareRequest) + if m.stateRootEnabled { + r.stateRootEnabled = true + } + m.payload = r case prepareResponseType: m.payload = new(prepareResponse) case commitType: @@ -310,7 +316,11 @@ func (m *message) DecodeBinary(r *io.BinReader) { case recoveryRequestType: m.payload = new(recoveryRequest) case recoveryMessageType: - m.payload = new(recoveryMessage) + r := new(recoveryMessage) + if m.stateRootEnabled { + r.stateRootEnabled = true + } + m.payload = r default: r.Err = fmt.Errorf("invalid type: 0x%02x", byte(m.Type)) return diff --git a/pkg/consensus/payload_test.go b/pkg/consensus/payload_test.go index e74803c57..9ebbc5201 100644 --- a/pkg/consensus/payload_test.go +++ b/pkg/consensus/payload_test.go @@ -104,12 +104,13 @@ func TestConsensusPayload_Serializable(t *testing.T) { require.Nil(t, actual.message) actual.message = new(message) // message should now be decoded from actual.data byte array + actual.message = new(message) assert.NoError(t, actual.decodeData()) assert.NotNil(t, actual.MarshalUnsigned()) require.Equal(t, p, actual) data = p.MarshalUnsigned() - pu := NewPayload(netmode.Magic(rand.Uint32())) + pu := NewPayload(netmode.Magic(rand.Uint32()), false) require.NoError(t, pu.UnmarshalUnsigned(data)) assert.NoError(t, pu.decodeData()) _ = pu.MarshalUnsigned() @@ -316,7 +317,7 @@ func TestPayload_Sign(t *testing.T) { p := randomPayload(t, prepareRequestType) h := priv.PublicKey().GetScriptHash() - bc := newTestChain(t) + bc := newTestChain(t, false) defer bc.Close() require.Error(t, bc.VerifyWitness(h, p, &p.Witness, payloadGasLimit)) require.NoError(t, p.Sign(priv)) diff --git a/pkg/consensus/prepare_request.go b/pkg/consensus/prepare_request.go index ad745f9c6..9099740fe 100644 --- a/pkg/consensus/prepare_request.go +++ b/pkg/consensus/prepare_request.go @@ -12,6 +12,8 @@ type prepareRequest struct { timestamp uint64 nonce uint64 transactionHashes []util.Uint256 + stateRootEnabled bool + stateRoot util.Uint256 } var _ payload.PrepareRequest = (*prepareRequest)(nil) @@ -21,6 +23,9 @@ func (p *prepareRequest) EncodeBinary(w *io.BinWriter) { w.WriteU64LE(p.timestamp) w.WriteU64LE(p.nonce) w.WriteArray(p.transactionHashes) + if p.stateRootEnabled { + w.WriteBytes(p.stateRoot[:]) + } } // DecodeBinary implements io.Serializable interface. @@ -28,6 +33,9 @@ func (p *prepareRequest) DecodeBinary(r *io.BinReader) { p.timestamp = r.ReadU64LE() p.nonce = r.ReadU64LE() r.ReadArray(&p.transactionHashes, block.MaxTransactionsPerBlock) + if p.stateRootEnabled { + r.ReadBytes(p.stateRoot[:]) + } } // Timestamp implements payload.PrepareRequest interface. diff --git a/pkg/consensus/recovery_message.go b/pkg/consensus/recovery_message.go index 44151aff4..c094c7f3c 100644 --- a/pkg/consensus/recovery_message.go +++ b/pkg/consensus/recovery_message.go @@ -16,6 +16,7 @@ type ( preparationPayloads []*preparationCompact commitPayloads []*commitCompact changeViewPayloads []*changeViewCompact + stateRootEnabled bool prepareRequest *message } @@ -47,7 +48,7 @@ func (m *recoveryMessage) DecodeBinary(r *io.BinReader) { var hasReq = r.ReadBool() if hasReq { - m.prepareRequest = new(message) + m.prepareRequest = &message{stateRootEnabled: m.stateRootEnabled} m.prepareRequest.DecodeBinary(r) if r.Err == nil && m.prepareRequest.Type != prepareRequestType { r.Err = errors.New("recovery message PrepareRequest has wrong type") @@ -143,9 +144,10 @@ func (m *recoveryMessage) AddPayload(p payload.ConsensusPayload) { switch p.Type() { case payload.PrepareRequestType: m.prepareRequest = &message{ - Type: prepareRequestType, - ViewNumber: p.ViewNumber(), - payload: p.GetPrepareRequest().(*prepareRequest), + Type: prepareRequestType, + ViewNumber: p.ViewNumber(), + payload: p.GetPrepareRequest().(*prepareRequest), + stateRootEnabled: m.stateRootEnabled, } h := p.Hash() m.preparationHash = &h @@ -291,9 +293,10 @@ func fromPayload(t messageType, recovery *Payload, p io.Serializable) *Payload { return &Payload{ network: recovery.network, message: &message{ - Type: t, - ViewNumber: recovery.message.ViewNumber, - payload: p, + Type: t, + ViewNumber: recovery.message.ViewNumber, + payload: p, + stateRootEnabled: recovery.stateRootEnabled, }, version: recovery.Version(), prevHash: recovery.PrevHash(), diff --git a/pkg/consensus/recovery_message_test.go b/pkg/consensus/recovery_message_test.go index 47471a0e6..b0fb4237b 100644 --- a/pkg/consensus/recovery_message_test.go +++ b/pkg/consensus/recovery_message_test.go @@ -12,8 +12,17 @@ import ( "github.com/stretchr/testify/require" ) -func TestRecoveryMessage_Setters(t *testing.T) { - srv := newTestService(t) +func TestRecoveryMessageSetters(t *testing.T) { + t.Run("NoStateRoot", func(t *testing.T) { + testRecoveryMessageSetters(t, false) + }) + t.Run("WithStateRoot", func(t *testing.T) { + testRecoveryMessageSetters(t, true) + }) +} + +func testRecoveryMessageSetters(t *testing.T, enableStateRoot bool) { + srv := newTestServiceWithState(t, enableStateRoot) defer srv.Chain.Close() privs := make([]*privateKey, testchain.Size()) pubs := make([]crypto.PublicKey, testchain.Size()) @@ -21,8 +30,8 @@ func TestRecoveryMessage_Setters(t *testing.T) { privs[i], pubs[i] = getTestValidator(i) } - r := &recoveryMessage{} - p := NewPayload(netmode.UnitTestNet) + r := &recoveryMessage{stateRootEnabled: enableStateRoot} + p := NewPayload(netmode.UnitTestNet, enableStateRoot) p.SetType(payload.RecoveryMessageType) p.SetPayload(r) // sign payload to have verification script @@ -32,15 +41,16 @@ func TestRecoveryMessage_Setters(t *testing.T) { timestamp: 87, nonce: 321, transactionHashes: []util.Uint256{{1}}, + stateRootEnabled: enableStateRoot, } - p1 := NewPayload(netmode.UnitTestNet) + p1 := NewPayload(netmode.UnitTestNet, enableStateRoot) p1.SetType(payload.PrepareRequestType) p1.SetPayload(req) p1.SetValidatorIndex(0) require.NoError(t, p1.Sign(privs[0])) t.Run("prepare response is added", func(t *testing.T) { - p2 := NewPayload(netmode.UnitTestNet) + p2 := NewPayload(netmode.UnitTestNet, enableStateRoot) p2.SetType(payload.PrepareResponseType) p2.SetPayload(&prepareResponse{ preparationHash: p1.Hash(), @@ -76,7 +86,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { }) t.Run("change view is added", func(t *testing.T) { - p3 := NewPayload(netmode.UnitTestNet) + p3 := NewPayload(netmode.UnitTestNet, enableStateRoot) p3.SetType(payload.ChangeViewType) p3.SetPayload(&changeView{ newViewNumber: 1, @@ -98,7 +108,7 @@ func TestRecoveryMessage_Setters(t *testing.T) { }) t.Run("commit is added", func(t *testing.T) { - p4 := NewPayload(netmode.UnitTestNet) + p4 := NewPayload(netmode.UnitTestNet, enableStateRoot) p4.SetType(payload.CommitType) p4.SetPayload(randomMessage(t, commitType)) p4.SetValidatorIndex(3) diff --git a/pkg/core/block/block.go b/pkg/core/block/block.go index 650df993e..e7fa4974a 100644 --- a/pkg/core/block/block.go +++ b/pkg/core/block/block.go @@ -77,10 +77,11 @@ func (b *Block) RebuildMerkleRoot() { // This is commonly used to create a block from stored data. // Blocks created from trimmed data will have their Trimmed field // set to true. -func NewBlockFromTrimmedBytes(network netmode.Magic, b []byte) (*Block, error) { +func NewBlockFromTrimmedBytes(network netmode.Magic, stateRootEnabled bool, b []byte) (*Block, error) { block := &Block{ Base: Base{ - Network: network, + Network: network, + StateRootEnabled: stateRootEnabled, }, Trimmed: true, } @@ -113,10 +114,11 @@ func NewBlockFromTrimmedBytes(network netmode.Magic, b []byte) (*Block, error) { } // New creates a new blank block tied to the specific network. -func New(network netmode.Magic) *Block { +func New(network netmode.Magic, stateRootEnabled bool) *Block { return &Block{ Base: Base{ - Network: network, + Network: network, + StateRootEnabled: stateRootEnabled, }, } } diff --git a/pkg/core/block/block_base.go b/pkg/core/block/block_base.go index 2fb7b2e01..483a62c01 100644 --- a/pkg/core/block/block_base.go +++ b/pkg/core/block/block_base.go @@ -43,6 +43,11 @@ type Base struct { // necessary for correct signing/verification. Network netmode.Magic + // StateRootEnabled specifies if header contains state root. + StateRootEnabled bool + // PrevStateRoot is state root of the previous block. + PrevStateRoot util.Uint256 + // Hash of this block, created when binary encoded (double SHA256). hash util.Uint256 @@ -61,6 +66,7 @@ type baseAux struct { Timestamp uint64 `json:"time"` Index uint32 `json:"index"` NextConsensus string `json:"nextconsensus"` + PrevStateRoot *util.Uint256 `json:"previousstateroot,omitempty"` Witnesses []transaction.Witness `json:"witnesses"` } @@ -130,6 +136,9 @@ func (b *Base) encodeHashableFields(bw *io.BinWriter) { bw.WriteU64LE(b.Timestamp) bw.WriteU32LE(b.Index) bw.WriteBytes(b.NextConsensus[:]) + if b.StateRootEnabled { + bw.WriteBytes(b.PrevStateRoot[:]) + } } // decodeHashableFields decodes the fields used for hashing. @@ -141,6 +150,9 @@ func (b *Base) decodeHashableFields(br *io.BinReader) { b.Timestamp = br.ReadU64LE() b.Index = br.ReadU32LE() br.ReadBytes(b.NextConsensus[:]) + if b.StateRootEnabled { + br.ReadBytes(b.PrevStateRoot[:]) + } // Make the hash of the block here so we dont need to do this // again. @@ -161,6 +173,9 @@ func (b Base) MarshalJSON() ([]byte, error) { NextConsensus: address.Uint160ToString(b.NextConsensus), Witnesses: []transaction.Witness{b.Script}, } + if b.StateRootEnabled { + aux.PrevStateRoot = &b.PrevStateRoot + } return json.Marshal(aux) } @@ -188,6 +203,12 @@ func (b *Base) UnmarshalJSON(data []byte) error { b.Index = aux.Index b.NextConsensus = nextC b.Script = aux.Witnesses[0] + if b.StateRootEnabled { + if aux.PrevStateRoot == nil { + return errors.New("'previousstateroot' is empty") + } + b.PrevStateRoot = *aux.PrevStateRoot + } if !aux.Hash.Equals(b.Hash()) { return errors.New("json 'hash' doesn't match block hash") } diff --git a/pkg/core/block/block_test.go b/pkg/core/block/block_test.go index ab1052944..a681056da 100644 --- a/pkg/core/block/block_test.go +++ b/pkg/core/block/block_test.go @@ -31,7 +31,7 @@ func TestDecodeBlock1(t *testing.T) { b, err := hex.DecodeString(data["raw"].(string)) require.NoError(t, err) - block := New(netmode.TestNet) + block := New(netmode.TestNet, false) assert.NoError(t, testserdes.DecodeBinary(b, block)) assert.Equal(t, uint32(data["index"].(float64)), block.Index) @@ -58,7 +58,7 @@ func TestTrimmedBlock(t *testing.T) { b, err := block.Trim() require.NoError(t, err) - trimmedBlock, err := NewBlockFromTrimmedBytes(netmode.TestNet, b) + trimmedBlock, err := NewBlockFromTrimmedBytes(netmode.TestNet, false, b) require.NoError(t, err) assert.True(t, trimmedBlock.Trimmed) @@ -114,7 +114,7 @@ func TestBinBlockDecodeEncode(t *testing.T) { rawtx := "0000000005440c786a66aaebf472aacb1d1db19d5b494c6a9226ea91bf5cf0e63a6605138cde5064efb81bc6539620b9e6d6d7c74f97d415b922c4fb4bb1833ce6a97a9d61f962fb7301000065f000005d12ac6c589d59f92e82d8bf60659cb716ffc1f101fd4a010c4011ff5d2138cf546d112ef712ee8a15277f7b6f1d5d2564b97497ac155782e6089cd3005dc9de81a8b22bb2f1c3a2edbac55e01581cb27980fdedf3a8bc57fa470c40657253c374a48da773fc653591f282a63a60695f29ab6c86300020ed505a019e5563e1be493efa71bdde37b16b4ec3f5f6dc2d2a2550151b020176b4dbe7afe40c403efdc559cb6bff135fd79138267db897c6fded01e3a0f15c0fb1c337359935d65e7ac49239f020951a74a96e11e73d225c9789953ffec40d5f7c9a84707b1d9a0c402804f24ab8034fa41223977ba48883eb94951184e31e5739872daf4f65461de3196ebf333f6d7dc4aff0b7b2143793179415f50a715484aba4e33b97dc636e150c40ed6b2ffeaef97eef746815ad16f5b8aed743892e93f7216bb744eb5c2f4cad91ae291919b61cd9a8d50fe85630d5e010c49a01ed687727c3ae5a7e17d4da213afdfd00150c2103009b7540e10f2562e5fd8fac9eaec25166a58b26e412348ff5a86927bfac22a20c21030205e9cefaea5a1dfc580af20c8d5aa2468bb0148f1a5e4605fc622c80e604ba0c210214baf0ceea3a66f17e7e1e839ea25fd8bed6cd82e6bb6e68250189065f44ff010c2103408dcd416396f64783ac587ea1e1593c57d9fea880c8a6a1920e92a2594778060c2102a7834be9b32e2981d157cb5bbd3acb42cfd11ea5c3b10224d7a44e98c5910f1b0c2102ba2c70f5996f357a43198705859fae2cfea13e1172962800772b3d588a9d4abd0c2102f889ecd43c5126ff1932d75fa87dea34fc95325fb724db93c8f79fe32cc3f180170b41138defaf0202c1353ed4e94d0cbc00be80024f7673890000000000261c130000000000e404210001f813c2cc8e18bbe4b3b87f8ef9105b50bb93918e01005d0300743ba40b0000000c14aa07cc3f2193a973904a09a6e60b87f1f96273970c14f813c2cc8e18bbe4b3b87f8ef9105b50bb93918e13c00c087472616e736665720c14bcaf41d684c7d4ad6ee0d99da9707b9d1f0c8e6641627d5b523801420c402360bbf64b9644c25f066dbd406454b07ab9f56e8e25d92d90c96c598f6c29d97eabdcf226f3575481662cfcdd064ee410978e5fae3f09a2f83129ba9cd82641290c2103caf763f91d3691cba5b5df3eb13e668fdace0295b37e2e259fd0fb152d354f900b4195440d78" rawtxBytes, _ := hex.DecodeString(rawtx) - b := New(netmode.TestNet) + b := New(netmode.TestNet, false) assert.NoError(t, testserdes.DecodeBinary(rawtxBytes, b)) expected := map[string]bool{ // 1 trans @@ -150,7 +150,7 @@ func TestBinBlockDecodeEncode(t *testing.T) { // update hidden hash value. _ = b.ConsensusData.Hash() - testserdes.MarshalUnmarshalJSON(t, b, New(netmode.TestNet)) + testserdes.MarshalUnmarshalJSON(t, b, New(netmode.TestNet, false)) } func TestBlockSizeCalculation(t *testing.T) { @@ -163,7 +163,7 @@ func TestBlockSizeCalculation(t *testing.T) { rawBlock := "0000000005440c786a66aaebf472aacb1d1db19d5b494c6a9226ea91bf5cf0e63a6605138cde5064efb81bc6539620b9e6d6d7c74f97d415b922c4fb4bb1833ce6a97a9d61f962fb7301000065f000005d12ac6c589d59f92e82d8bf60659cb716ffc1f101fd4a010c4011ff5d2138cf546d112ef712ee8a15277f7b6f1d5d2564b97497ac155782e6089cd3005dc9de81a8b22bb2f1c3a2edbac55e01581cb27980fdedf3a8bc57fa470c40657253c374a48da773fc653591f282a63a60695f29ab6c86300020ed505a019e5563e1be493efa71bdde37b16b4ec3f5f6dc2d2a2550151b020176b4dbe7afe40c403efdc559cb6bff135fd79138267db897c6fded01e3a0f15c0fb1c337359935d65e7ac49239f020951a74a96e11e73d225c9789953ffec40d5f7c9a84707b1d9a0c402804f24ab8034fa41223977ba48883eb94951184e31e5739872daf4f65461de3196ebf333f6d7dc4aff0b7b2143793179415f50a715484aba4e33b97dc636e150c40ed6b2ffeaef97eef746815ad16f5b8aed743892e93f7216bb744eb5c2f4cad91ae291919b61cd9a8d50fe85630d5e010c49a01ed687727c3ae5a7e17d4da213afdfd00150c2103009b7540e10f2562e5fd8fac9eaec25166a58b26e412348ff5a86927bfac22a20c21030205e9cefaea5a1dfc580af20c8d5aa2468bb0148f1a5e4605fc622c80e604ba0c210214baf0ceea3a66f17e7e1e839ea25fd8bed6cd82e6bb6e68250189065f44ff010c2103408dcd416396f64783ac587ea1e1593c57d9fea880c8a6a1920e92a2594778060c2102a7834be9b32e2981d157cb5bbd3acb42cfd11ea5c3b10224d7a44e98c5910f1b0c2102ba2c70f5996f357a43198705859fae2cfea13e1172962800772b3d588a9d4abd0c2102f889ecd43c5126ff1932d75fa87dea34fc95325fb724db93c8f79fe32cc3f180170b41138defaf0202c1353ed4e94d0cbc00be80024f7673890000000000261c130000000000e404210001f813c2cc8e18bbe4b3b87f8ef9105b50bb93918e01005d0300743ba40b0000000c14aa07cc3f2193a973904a09a6e60b87f1f96273970c14f813c2cc8e18bbe4b3b87f8ef9105b50bb93918e13c00c087472616e736665720c14bcaf41d684c7d4ad6ee0d99da9707b9d1f0c8e6641627d5b523801420c402360bbf64b9644c25f066dbd406454b07ab9f56e8e25d92d90c96c598f6c29d97eabdcf226f3575481662cfcdd064ee410978e5fae3f09a2f83129ba9cd82641290c2103caf763f91d3691cba5b5df3eb13e668fdace0295b37e2e259fd0fb152d354f900b4195440d78" rawBlockBytes, _ := hex.DecodeString(rawBlock) - b := New(netmode.TestNet) + b := New(netmode.TestNet, false) assert.NoError(t, testserdes.DecodeBinary(rawBlockBytes, b)) expected := []struct { diff --git a/pkg/core/block/header_test.go b/pkg/core/block/header_test.go index 989920fbf..4fc865757 100644 --- a/pkg/core/block/header_test.go +++ b/pkg/core/block/header_test.go @@ -6,12 +6,13 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/internal/random" "github.com/nspcc-dev/neo-go/pkg/internal/testserdes" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/stretchr/testify/assert" ) -func TestHeaderEncodeDecode(t *testing.T) { +func testHeaderEncodeDecode(t *testing.T, stateRootEnabled bool) { header := Header{Base: Base{ Version: 0, PrevHash: hash.Sha256([]byte("prevhash")), @@ -24,9 +25,13 @@ func TestHeaderEncodeDecode(t *testing.T) { VerificationScript: []byte{0x11}, }, }} + if stateRootEnabled { + header.StateRootEnabled = stateRootEnabled + header.PrevStateRoot = random.Uint256() + } _ = header.Hash() - headerDecode := &Header{} + headerDecode := &Header{Base: Base{StateRootEnabled: stateRootEnabled}} testserdes.EncodeDecodeBinary(t, &header, headerDecode) assert.Equal(t, header.Version, headerDecode.Version, "expected both versions to be equal") @@ -36,4 +41,14 @@ func TestHeaderEncodeDecode(t *testing.T) { assert.Equal(t, header.NextConsensus, headerDecode.NextConsensus, "expected both next consensus fields to be equal") assert.Equal(t, header.Script.InvocationScript, headerDecode.Script.InvocationScript, "expected equal invocation scripts") assert.Equal(t, header.Script.VerificationScript, headerDecode.Script.VerificationScript, "expected equal verification scripts") + assert.Equal(t, header.PrevStateRoot, headerDecode.PrevStateRoot, "expected equal state roots") +} + +func TestHeaderEncodeDecode(t *testing.T) { + t.Run("NoStateRoot", func(t *testing.T) { + testHeaderEncodeDecode(t, false) + }) + t.Run("WithStateRoot", func(t *testing.T) { + testHeaderEncodeDecode(t, true) + }) } diff --git a/pkg/core/block/helper_test.go b/pkg/core/block/helper_test.go index 98f1abf23..cefea3959 100644 --- a/pkg/core/block/helper_test.go +++ b/pkg/core/block/helper_test.go @@ -19,7 +19,7 @@ func getDecodedBlock(t *testing.T, i int) *Block { b, err := hex.DecodeString(data["raw"].(string)) require.NoError(t, err) - block := New(netmode.TestNet) + block := New(netmode.TestNet, false) require.NoError(t, testserdes.DecodeBinary(b, block)) return block diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index d26377d45..d7247f973 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -160,7 +160,7 @@ func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration, log *zap.L } bc := &Blockchain{ config: cfg, - dao: dao.NewSimple(s, cfg.Magic), + dao: dao.NewSimple(s, cfg.Magic, cfg.StateRootInHeader), stopCh: make(chan struct{}), runToExitCh: make(chan struct{}), memPool: mempool.New(cfg.MemPoolSize), @@ -429,6 +429,16 @@ func (bc *Blockchain) AddBlock(block *block.Block) error { if expectedHeight != block.Index { return fmt.Errorf("expected %d, got %d: %w", expectedHeight, block.Index, ErrInvalidBlockIndex) } + if bc.config.StateRootInHeader != block.StateRootEnabled { + return fmt.Errorf("%w: %v != %v", + ErrHdrStateRootSetting, bc.config.StateRootInHeader, block.StateRootEnabled) + } + if bc.config.StateRootInHeader { + if sr := bc.dao.MPT.StateRoot(); block.PrevStateRoot != sr { + return fmt.Errorf("%w: %s != %s", + ErrHdrInvalidStateRoot, block.PrevStateRoot.StringLE(), sr.StringLE()) + } + } if block.Index == bc.HeaderHeight()+1 { err := bc.addHeaders(bc.config.VerifyBlocks, block.Header()) @@ -1218,6 +1228,8 @@ var ( ErrHdrHashMismatch = errors.New("previous header hash doesn't match") ErrHdrIndexMismatch = errors.New("previous header index doesn't match") ErrHdrInvalidTimestamp = errors.New("block is not newer than the previous one") + ErrHdrStateRootSetting = errors.New("state root setting mismatch") + ErrHdrInvalidStateRoot = errors.New("state root for previous block is invalid") ) func (bc *Blockchain) verifyHeader(currHeader, prevHeader *block.Header) error { diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 93d18c442..15276ddcf 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -122,6 +122,33 @@ func TestAddBlock(t *testing.T) { assert.Equal(t, lastBlock.Hash(), bc.CurrentHeaderHash()) } +func TestAddBlockStateRoot(t *testing.T) { + bc := newTestChainWithStateRoot(t, true) + defer bc.Close() + + sr, err := bc.GetStateRoot(bc.BlockHeight()) + require.NoError(t, err) + + tx := newNEP5Transfer(bc.contracts.NEO.Hash, neoOwner, util.Uint160{}, 1) + tx.ValidUntilBlock = bc.BlockHeight() + 1 + addSigners(tx) + require.NoError(t, signTx(bc, tx)) + + lastBlock := bc.topBlock.Load().(*block.Block) + b := newBlock(bc.config, lastBlock.Index+1, lastBlock.Hash(), tx) + err = bc.AddBlock(b) + require.True(t, errors.Is(err, ErrHdrStateRootSetting), "got: %v", err) + + u := sr.Root + u[0] ^= 0xFF + b = newBlockWithState(bc.config, lastBlock.Index+1, lastBlock.Hash(), &u, tx) + err = bc.AddBlock(b) + require.True(t, errors.Is(err, ErrHdrInvalidStateRoot), "got: %v", err) + + b = bc.newBlock(tx) + require.NoError(t, bc.AddBlock(b)) +} + func TestAddBadBlock(t *testing.T) { bc := newTestChain(t) defer bc.Close() @@ -500,7 +527,7 @@ func TestVerifyTx(t *testing.T) { InvocationScript: testchain.SignCommittee(txSetOracle.GetSignedPart()), VerificationScript: testchain.CommitteeVerificationScript(), }} - bl := block.New(netmode.UnitTestNet) + bl := block.New(netmode.UnitTestNet, bc.config.StateRootInHeader) bl.Index = bc.BlockHeight() + 1 ic := bc.newInteropContext(trigger.All, bc.dao, bl, txSetOracle) ic.SpawnVM() diff --git a/pkg/core/dao/cacheddao_test.go b/pkg/core/dao/cacheddao_test.go index 7170a46ed..6a58debaa 100644 --- a/pkg/core/dao/cacheddao_test.go +++ b/pkg/core/dao/cacheddao_test.go @@ -15,7 +15,7 @@ import ( func TestCachedDaoContracts(t *testing.T) { store := storage.NewMemoryStore() - pdao := NewSimple(store, netmode.UnitTestNet) + pdao := NewSimple(store, netmode.UnitTestNet, false) dao := NewCached(pdao) script := []byte{0xde, 0xad, 0xbe, 0xef} @@ -54,7 +54,7 @@ func TestCachedDaoContracts(t *testing.T) { func TestCachedCachedDao(t *testing.T) { store := storage.NewMemoryStore() // Persistent DAO to check for backing storage. - pdao := NewSimple(store, netmode.UnitTestNet) + pdao := NewSimple(store, netmode.UnitTestNet, false) assert.NotEqual(t, store, pdao.Store) // Cached DAO. cdao := NewCached(pdao) diff --git a/pkg/core/dao/dao.go b/pkg/core/dao/dao.go index d2a4f9ebe..ed687d6ac 100644 --- a/pkg/core/dao/dao.go +++ b/pkg/core/dao/dao.go @@ -76,12 +76,14 @@ type Simple struct { MPT *mpt.Trie Store *storage.MemCachedStore network netmode.Magic + // stateRootInHeader specifies if block header contains state root. + stateRootInHeader bool } // NewSimple creates new simple dao using provided backend store. -func NewSimple(backend storage.Store, network netmode.Magic) *Simple { +func NewSimple(backend storage.Store, network netmode.Magic, stateRootInHeader bool) *Simple { st := storage.NewMemCachedStore(backend) - return &Simple{Store: st, network: network} + return &Simple{Store: st, network: network, stateRootInHeader: stateRootInHeader} } // GetBatch returns currently accumulated DB changeset. @@ -92,7 +94,7 @@ func (dao *Simple) GetBatch() *storage.MemBatch { // GetWrapped returns new DAO instance with another layer of wrapped // MemCachedStore around the current DAO Store. func (dao *Simple) GetWrapped() DAO { - d := NewSimple(dao.Store, dao.network) + d := NewSimple(dao.Store, dao.network, dao.stateRootInHeader) d.MPT = dao.MPT return d } @@ -514,7 +516,7 @@ func (dao *Simple) GetBlock(hash util.Uint256) (*block.Block, error) { return nil, err } - block, err := block.NewBlockFromTrimmedBytes(dao.network, b) + block, err := block.NewBlockFromTrimmedBytes(dao.network, dao.stateRootInHeader, b) if err != nil { return nil, err } diff --git a/pkg/core/dao/dao_test.go b/pkg/core/dao/dao_test.go index bf6ba00b1..74f76a1a7 100644 --- a/pkg/core/dao/dao_test.go +++ b/pkg/core/dao/dao_test.go @@ -18,7 +18,7 @@ import ( ) func TestPutGetAndDecode(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) serializable := &TestSerializable{field: random.String(4)} hash := []byte{1} err := dao.Put(serializable, hash) @@ -43,7 +43,7 @@ func (t *TestSerializable) DecodeBinary(reader *io.BinReader) { } func TestPutAndGetContractState(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) contractState := &state.Contract{Script: []byte{}} hash := contractState.ScriptHash() err := dao.PutContractState(contractState) @@ -54,7 +54,7 @@ func TestPutAndGetContractState(t *testing.T) { } func TestDeleteContractState(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) contractState := &state.Contract{Script: []byte{}} hash := contractState.ScriptHash() err := dao.PutContractState(contractState) @@ -67,7 +67,7 @@ func TestDeleteContractState(t *testing.T) { } func TestSimple_GetAndUpdateNextContractID(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) id, err := dao.GetAndUpdateNextContractID() require.NoError(t, err) require.EqualValues(t, 0, id) @@ -80,7 +80,7 @@ func TestSimple_GetAndUpdateNextContractID(t *testing.T) { } func TestPutGetAppExecResult(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) hash := random.Uint256() appExecResult := &state.AppExecResult{ Container: hash, @@ -98,7 +98,7 @@ func TestPutGetAppExecResult(t *testing.T) { } func TestPutGetStorageItem(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) id := int32(random.Int(0, 1024)) key := []byte{0} storageItem := &state.StorageItem{Value: []uint8{}} @@ -109,7 +109,7 @@ func TestPutGetStorageItem(t *testing.T) { } func TestDeleteStorageItem(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) id := int32(random.Int(0, 1024)) key := []byte{0} storageItem := &state.StorageItem{Value: []uint8{}} @@ -122,7 +122,7 @@ func TestDeleteStorageItem(t *testing.T) { } func TestGetBlock_NotExists(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) hash := random.Uint256() block, err := dao.GetBlock(hash) require.Error(t, err) @@ -130,7 +130,7 @@ func TestGetBlock_NotExists(t *testing.T) { } func TestPutGetBlock(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) b := &block.Block{ Base: block.Base{ Script: transaction.Witness{ @@ -148,14 +148,14 @@ func TestPutGetBlock(t *testing.T) { } func TestGetVersion_NoVersion(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) version, err := dao.GetVersion() require.Error(t, err) require.Equal(t, "", version) } func TestGetVersion(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) err := dao.PutVersion("testVersion") require.NoError(t, err) version, err := dao.GetVersion() @@ -164,14 +164,14 @@ func TestGetVersion(t *testing.T) { } func TestGetCurrentHeaderHeight_NoHeader(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) height, err := dao.GetCurrentBlockHeight() require.Error(t, err) require.Equal(t, uint32(0), height) } func TestGetCurrentHeaderHeight_Store(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) b := &block.Block{ Base: block.Base{ Script: transaction.Witness{ @@ -188,7 +188,7 @@ func TestGetCurrentHeaderHeight_Store(t *testing.T) { } func TestStoreAsTransaction(t *testing.T) { - dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet) + dao := NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) tx := transaction.New(netmode.UnitTestNet, []byte{byte(opcode.PUSH1)}, 1) hash := tx.Hash() err := dao.StoreAsTransaction(tx, 0, nil) diff --git a/pkg/core/helper_test.go b/pkg/core/helper_test.go index ad8205e81..63119ab2e 100644 --- a/pkg/core/helper_test.go +++ b/pkg/core/helper_test.go @@ -38,8 +38,13 @@ var neoOwner = testchain.MultisigScriptHash() // newTestChain should be called before newBlock invocation to properly setup // global state. func newTestChain(t *testing.T) *Blockchain { + return newTestChainWithStateRoot(t, false) +} + +func newTestChainWithStateRoot(t *testing.T, stateRootInHeader bool) *Blockchain { unitTestNetCfg, err := config.Load("../../config", testchain.Network()) require.NoError(t, err) + unitTestNetCfg.ProtocolConfiguration.StateRootInHeader = stateRootInHeader chain, err := NewBlockchain(storage.NewMemoryStore(), unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t)) require.NoError(t, err) go chain.Run() @@ -48,10 +53,22 @@ func newTestChain(t *testing.T) *Blockchain { func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block { lastBlock := bc.topBlock.Load().(*block.Block) + if bc.config.StateRootInHeader { + sr, err := bc.GetStateRoot(bc.BlockHeight()) + if err != nil { + panic(err) + } + return newBlockWithState(bc.config, lastBlock.Index+1, lastBlock.Hash(), &sr.Root, txs...) + } return newBlock(bc.config, lastBlock.Index+1, lastBlock.Hash(), txs...) } func newBlock(cfg config.ProtocolConfiguration, index uint32, prev util.Uint256, txs ...*transaction.Transaction) *block.Block { + return newBlockWithState(cfg, index, prev, nil, txs...) +} + +func newBlockWithState(cfg config.ProtocolConfiguration, index uint32, prev util.Uint256, + prevState *util.Uint256, txs ...*transaction.Transaction) *block.Block { validators, _ := validatorsFromConfig(cfg) valScript, _ := smartcontract.CreateDefaultMultiSigRedeemScript(validators) witness := transaction.Witness{ @@ -73,6 +90,10 @@ func newBlock(cfg config.ProtocolConfiguration, index uint32, prev util.Uint256, }, Transactions: txs, } + if prevState != nil { + b.StateRootEnabled = true + b.PrevStateRoot = *prevState + } b.RebuildMerkleRoot() b.Script.InvocationScript = testchain.Sign(b.GetSignedPart()) return b @@ -99,7 +120,7 @@ func getDecodedBlock(t *testing.T, i int) *block.Block { b, err := hex.DecodeString(data["raw"].(string)) require.NoError(t, err) - block := block.New(testchain.Network()) + block := block.New(testchain.Network(), false) require.NoError(t, testserdes.DecodeBinary(b, block)) return block diff --git a/pkg/core/interop_neo_test.go b/pkg/core/interop_neo_test.go index a04fd3f12..47b6c34b8 100644 --- a/pkg/core/interop_neo_test.go +++ b/pkg/core/interop_neo_test.go @@ -144,7 +144,7 @@ func TestECDSAVerify(t *testing.T) { chain := newTestChain(t) defer chain.Close() - ic := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil) + ic := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false), nil, nil) runCase := func(t *testing.T, isErr bool, result interface{}, args ...interface{}) { ic.SpawnVM() for i := range args { @@ -266,7 +266,7 @@ func TestRuntimeEncodeDecode(t *testing.T) { func createVM(t *testing.T) (*vm.VM, *interop.Context, *Blockchain) { chain := newTestChain(t) context := chain.newInteropContext(trigger.Application, - dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil) + dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, chain.config.StateRootInHeader), nil, nil) v := context.SpawnVM() return v, context, chain } @@ -280,7 +280,8 @@ func createVMAndPushBlock(t *testing.T) (*vm.VM, *block.Block, *interop.Context, func createVMAndBlock(t *testing.T) (*vm.VM, *block.Block, *interop.Context, *Blockchain) { block := newDumbBlock() chain := newTestChain(t) - context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), block, nil) + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, chain.GetConfig().StateRootInHeader) + context := chain.newInteropContext(trigger.Application, d, block, nil) v := context.SpawnVM() return v, block, context, chain } @@ -301,7 +302,8 @@ func createVMAndContractState(t *testing.T) (*vm.VM, *state.Contract, *interop.C } chain := newTestChain(t) - context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil) + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, chain.config.StateRootInHeader) + context := chain.newInteropContext(trigger.Application, d, nil, nil) v := context.SpawnVM() return v, contractState, context, chain } @@ -312,7 +314,8 @@ func createVMAndTX(t *testing.T) (*vm.VM, *transaction.Transaction, *interop.Con tx.Signers = []transaction.Signer{{Account: util.Uint160{1, 2, 3, 4}}} chain := newTestChain(t) - context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, tx) + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, chain.config.StateRootInHeader) + context := chain.newInteropContext(trigger.Application, d, nil, tx) v := context.SpawnVM() return v, tx, context, chain } diff --git a/pkg/core/interops_test.go b/pkg/core/interops_test.go index 748744e0b..a8391ef9d 100644 --- a/pkg/core/interops_test.go +++ b/pkg/core/interops_test.go @@ -19,7 +19,8 @@ func testNonInterop(t *testing.T, value interface{}, f func(*interop.Context) er v.Estack().PushVal(value) chain := newTestChain(t) defer chain.Close() - context := chain.newInteropContext(trigger.Application, dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil) + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, chain.config.StateRootInHeader) + context := chain.newInteropContext(trigger.Application, d, nil, nil) context.VM = v require.Error(t, f(context)) } diff --git a/pkg/core/native_contract_test.go b/pkg/core/native_contract_test.go index 49061db49..3a438c365 100644 --- a/pkg/core/native_contract_test.go +++ b/pkg/core/native_contract_test.go @@ -225,8 +225,8 @@ func TestNativeContract_InvokeInternal(t *testing.T) { }) require.NoError(t, err) - ic := chain.newInteropContext(trigger.Application, - dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet), nil, nil) + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, chain.config.StateRootInHeader) + ic := chain.newInteropContext(trigger.Application, d, nil, nil) v := ic.SpawnVM() t.Run("fail, bad current script hash", func(t *testing.T) { diff --git a/pkg/core/native_designate_test.go b/pkg/core/native_designate_test.go index f12fff91a..dcc75c17e 100644 --- a/pkg/core/native_designate_test.go +++ b/pkg/core/native_designate_test.go @@ -118,7 +118,7 @@ func TestDesignate_DesignateAsRole(t *testing.T) { des := bc.contracts.Designate tx := transaction.New(netmode.UnitTestNet, []byte{}, 0) - bl := block.New(netmode.UnitTestNet) + bl := block.New(netmode.UnitTestNet, bc.config.StateRootInHeader) bl.Index = bc.BlockHeight() + 1 ic := bc.newInteropContext(trigger.OnPersist, bc.dao, bl, tx) ic.SpawnVM() diff --git a/pkg/core/native_oracle_test.go b/pkg/core/native_oracle_test.go index 2d603e2b2..3df0c9b55 100644 --- a/pkg/core/native_oracle_test.go +++ b/pkg/core/native_oracle_test.go @@ -142,7 +142,7 @@ func TestOracle_Request(t *testing.T) { pub := priv.PublicKey() tx := transaction.New(netmode.UnitTestNet, []byte{}, 0) - bl := block.New(netmode.UnitTestNet) + bl := block.New(netmode.UnitTestNet, bc.config.StateRootInHeader) bl.Index = bc.BlockHeight() + 1 setSigner(tx, testchain.CommitteeScriptHash()) ic := bc.newInteropContext(trigger.Application, bc.dao, bl, tx) diff --git a/pkg/network/message.go b/pkg/network/message.go index 368d69cc9..efa58ad29 100644 --- a/pkg/network/message.go +++ b/pkg/network/message.go @@ -34,6 +34,9 @@ type Message struct { // Network this message comes from, it has to be set upon Message // creation for correct decoding. Network netmode.Magic + // StateRootInHeader specifies if state root is included in block header. + // This is needed for correct decoding. + StateRootInHeader bool } // MessageFlag represents compression level of message payload @@ -106,7 +109,7 @@ func (m *Message) Decode(br *io.BinReader) error { case CMDFilterClear, CMDGetAddr, CMDMempool, CMDVerack: m.Payload = payload.NewNullPayload() default: - return errors.New("unexpected empty payload") + return fmt.Errorf("unexpected empty payload: %s", m.Command) } return nil } @@ -142,9 +145,9 @@ func (m *Message) decodePayload() error { case CMDAddr: p = &payload.AddressList{} case CMDBlock: - p = block.New(m.Network) + p = block.New(m.Network, m.StateRootInHeader) case CMDConsensus: - p = consensus.NewPayload(m.Network) + p = consensus.NewPayload(m.Network, m.StateRootInHeader) case CMDGetBlocks: p = &payload.GetBlocks{} case CMDGetHeaders: @@ -152,7 +155,7 @@ func (m *Message) decodePayload() error { case CMDGetBlockByIndex: p = &payload.GetBlockByIndex{} case CMDHeaders: - p = &payload.Headers{Network: m.Network} + p = &payload.Headers{Network: m.Network, StateRootInHeader: m.StateRootInHeader} case CMDTX: p = &transaction.Transaction{Network: m.Network} case CMDMerkleBlock: diff --git a/pkg/network/payload/headers.go b/pkg/network/payload/headers.go index ba68bbc78..61b27053e 100644 --- a/pkg/network/payload/headers.go +++ b/pkg/network/payload/headers.go @@ -12,6 +12,8 @@ import ( type Headers struct { Hdrs []*block.Header Network netmode.Magic + // StateRootInHeader specifies whether header contains state root. + StateRootInHeader bool } // Users can at most request 2k header. @@ -38,6 +40,7 @@ func (p *Headers) DecodeBinary(br *io.BinReader) { for i := 0; i < int(lenHeaders); i++ { header := &block.Header{} header.Network = p.Network + header.StateRootEnabled = p.StateRootInHeader header.DecodeBinary(br) p.Hdrs[i] = header } diff --git a/pkg/network/server.go b/pkg/network/server.go index 6dd91b0ba..fef50fe27 100644 --- a/pkg/network/server.go +++ b/pkg/network/server.go @@ -55,6 +55,8 @@ type ( // Network's magic number for correct message decoding. network netmode.Magic + // stateRootInHeader specifies if block header contain state root. + stateRootInHeader bool transport Transporter discovery Discoverer @@ -95,17 +97,18 @@ func NewServer(config ServerConfig, chain blockchainer.Blockchainer, log *zap.Lo } s := &Server{ - ServerConfig: config, - chain: chain, - id: randomID(), - network: chain.GetConfig().Magic, - quit: make(chan struct{}), - register: make(chan Peer), - unregister: make(chan peerDrop), - peers: make(map[Peer]bool), - consensusStarted: atomic.NewBool(false), - log: log, - transactions: make(chan *transaction.Transaction, 64), + ServerConfig: config, + chain: chain, + id: randomID(), + network: chain.GetConfig().Magic, + stateRootInHeader: chain.GetConfig().StateRootInHeader, + quit: make(chan struct{}), + register: make(chan Peer), + unregister: make(chan peerDrop), + peers: make(map[Peer]bool), + consensusStarted: atomic.NewBool(false), + log: log, + transactions: make(chan *transaction.Transaction, 64), } s.bQueue = newBlockQueue(maxBlockBatch, chain, log, func(b *block.Block) { if !s.consensusStarted.Load() { diff --git a/pkg/network/tcp_peer.go b/pkg/network/tcp_peer.go index 64d884feb..ea383f5d9 100644 --- a/pkg/network/tcp_peer.go +++ b/pkg/network/tcp_peer.go @@ -150,7 +150,7 @@ func (p *TCPPeer) handleConn() { if err == nil { r := io.NewBinReaderFromIO(p.conn) for { - msg := &Message{Network: p.server.network} + msg := &Message{Network: p.server.network, StateRootInHeader: p.server.stateRootInHeader} err = msg.Decode(r) if err == payload.ErrTooManyHeaders { diff --git a/pkg/rpc/client/client.go b/pkg/rpc/client/client.go index e72becf7f..c9042a468 100644 --- a/pkg/rpc/client/client.go +++ b/pkg/rpc/client/client.go @@ -28,14 +28,15 @@ const ( // Client represents the middleman for executing JSON RPC calls // to remote NEO RPC nodes. type Client struct { - cli *http.Client - endpoint *url.URL - network netmode.Magic - initDone bool - ctx context.Context - opts Options - requestF func(*request.Raw) (*response.Raw, error) - cache cache + cli *http.Client + endpoint *url.URL + network netmode.Magic + stateRootInHeader bool + initDone bool + ctx context.Context + opts Options + requestF func(*request.Raw) (*response.Raw, error) + cache cache } // Options defines options for the RPC client. @@ -115,6 +116,7 @@ func (c *Client) Init() error { return fmt.Errorf("failed to get network magic: %w", err) } c.network = version.Magic + c.stateRootInHeader = version.StateRootInHeader neoContractHash, err := c.GetContractStateByAddressOrName("neo") if err != nil { return fmt.Errorf("failed to get NEO contract scripthash: %w", err) diff --git a/pkg/rpc/client/rpc.go b/pkg/rpc/client/rpc.go index 075de911f..446ca9b3d 100644 --- a/pkg/rpc/client/rpc.go +++ b/pkg/rpc/client/rpc.go @@ -83,7 +83,7 @@ func (c *Client) getBlock(params request.RawParams) (*block.Block, error) { return nil, err } r := io.NewBinReaderFromBuf(resp) - b = block.New(c.GetNetwork()) + b = block.New(c.GetNetwork(), c.StateRootInHeader()) b.DecodeBinary(r) if r.Err != nil { return nil, r.Err @@ -606,6 +606,11 @@ func (c *Client) GetNetwork() netmode.Magic { return c.network } +// StateRootInHeader returns true if state root is contained in block header. +func (c *Client) StateRootInHeader() bool { + return c.stateRootInHeader +} + // GetNativeContractHash returns native contract hash by its name. It is not case-sensitive. func (c *Client) GetNativeContractHash(name string) (util.Uint160, error) { lowercasedName := strings.ToLower(name) diff --git a/pkg/rpc/client/rpc_test.go b/pkg/rpc/client/rpc_test.go index 0ca265f13..b82f3f74c 100644 --- a/pkg/rpc/client/rpc_test.go +++ b/pkg/rpc/client/rpc_test.go @@ -61,7 +61,7 @@ func getResultBlock1() *result.Block { if err != nil { panic(err) } - b := block.New(netmode.UnitTestNet) + b := block.New(netmode.UnitTestNet, false) err = testserdes.DecodeBinary(binB, b) if err != nil { panic(err) diff --git a/pkg/rpc/client/wsclient.go b/pkg/rpc/client/wsclient.go index 266257ef1..7f79b6711 100644 --- a/pkg/rpc/client/wsclient.go +++ b/pkg/rpc/client/wsclient.go @@ -139,7 +139,7 @@ readloop: var val interface{} switch event { case response.BlockEventID: - val = block.New(c.GetNetwork()) + val = block.New(c.GetNetwork(), c.StateRootInHeader()) case response.TransactionEventID: val = &transaction.Transaction{Network: c.GetNetwork()} case response.NotificationEventID: diff --git a/pkg/rpc/response/result/version.go b/pkg/rpc/response/result/version.go index 4d0fd76d0..5a1f9c5ce 100644 --- a/pkg/rpc/response/result/version.go +++ b/pkg/rpc/response/result/version.go @@ -11,5 +11,7 @@ type ( WSPort uint16 `json:"wsport,omitempty"` Nonce uint32 `json:"nonce"` UserAgent string `json:"useragent"` + // StateRootInHeader is true if state root is contained in block header. + StateRootInHeader bool `json:"staterootinheader,omitempty"` } ) diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 8afba33b3..5e3083d74 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -39,13 +39,14 @@ type ( // Server represents the JSON-RPC 2.0 server. Server struct { *http.Server - chain blockchainer.Blockchainer - config rpc.Config - network netmode.Magic - coreServer *network.Server - log *zap.Logger - https *http.Server - shutdown chan struct{} + chain blockchainer.Blockchainer + config rpc.Config + network netmode.Magic + stateRootEnabled bool + coreServer *network.Server + log *zap.Logger + https *http.Server + shutdown chan struct{} subsLock sync.RWMutex subscribers map[*subscriber]bool @@ -138,14 +139,15 @@ func New(chain blockchainer.Blockchainer, conf rpc.Config, coreServer *network.S } return Server{ - Server: httpServer, - chain: chain, - config: conf, - network: chain.GetConfig().Magic, - coreServer: coreServer, - log: log, - https: tlsServer, - shutdown: make(chan struct{}), + Server: httpServer, + chain: chain, + config: conf, + network: chain.GetConfig().Magic, + stateRootEnabled: chain.GetConfig().StateRootInHeader, + coreServer: coreServer, + log: log, + https: tlsServer, + shutdown: make(chan struct{}), subscribers: make(map[*subscriber]bool), // These are NOT buffered to preserve original order of events. @@ -1039,7 +1041,7 @@ func (s *Server) submitBlock(reqParams request.Params) (interface{}, *response.E if err != nil { return nil, response.ErrInvalidParams } - b := block.New(s.network) + b := block.New(s.network, s.stateRootEnabled) r := io.NewBinReaderFromBuf(blockBytes) b.DecodeBinary(r) if r.Err != nil { diff --git a/pkg/rpc/server/server_helper_test.go b/pkg/rpc/server/server_helper_test.go index ee41e0be4..a62c70820 100644 --- a/pkg/rpc/server/server_helper_test.go +++ b/pkg/rpc/server/server_helper_test.go @@ -52,7 +52,7 @@ func getTestBlocks(t *testing.T) []*block.Block { blocks := make([]*block.Block, 0, int(nBlocks)) for i := 0; i < int(nBlocks); i++ { _ = br.ReadU32LE() - b := block.New(netmode.UnitTestNet) + b := block.New(netmode.UnitTestNet, false) b.DecodeBinary(br) require.Nil(t, br.Err) blocks = append(blocks, b)