diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 8098eef19..81ced67a1 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -284,6 +284,11 @@ func (bc *Blockchain) init() error { return fmt.Errorf("can't init cache for NEO native contract: %w", err) } + err = bc.contracts.Management.InitializeCache(bc.dao) + if err != nil { + return fmt.Errorf("can't init cache for Management native contract: %w", err) + } + return nil } diff --git a/pkg/core/helper_test.go b/pkg/core/helper_test.go index 003fea1f4..f570f4cd9 100644 --- a/pkg/core/helper_test.go +++ b/pkg/core/helper_test.go @@ -46,19 +46,33 @@ func newTestChain(t *testing.T) *Blockchain { } func newTestChainWithCustomCfg(t *testing.T, f func(*config.Config)) *Blockchain { + return newTestChainWithCustomCfgAndStore(t, nil, f) +} + +func newTestChainWithCustomCfgAndStore(t *testing.T, st storage.Store, f func(*config.Config)) *Blockchain { unitTestNetCfg, err := config.Load("../../config", testchain.Network()) require.NoError(t, err) if f != nil { f(&unitTestNetCfg) } - chain, err := NewBlockchain(storage.NewMemoryStore(), unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t)) + if st == nil { + st = storage.NewMemoryStore() + } + chain, err := NewBlockchain(st, unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t)) require.NoError(t, err) go chain.Run() return chain } func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block { - lastBlock := bc.topBlock.Load().(*block.Block) + lastBlock, ok := bc.topBlock.Load().(*block.Block) + if !ok { + var err error + lastBlock, err = bc.GetBlock(bc.GetHeaderHash(int(bc.BlockHeight()))) + if err != nil { + panic(err) + } + } if bc.config.StateRootInHeader { sr, err := bc.GetStateRoot(bc.BlockHeight()) if err != nil { diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index 92be3e91f..39d9a8213 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -6,13 +6,16 @@ import ( "fmt" "math" "math/big" + "sync" "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop/contract" "github.com/nspcc-dev/neo-go/pkg/core/native/nativenames" "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" @@ -24,6 +27,9 @@ import ( // Management is contract-managing native contract. type Management struct { interop.ContractMD + + mtx sync.RWMutex + contracts map[util.Uint160]*state.Contract } // StoragePrice is the price to pay for 1 byte of storage. @@ -43,7 +49,10 @@ func makeContractKey(h util.Uint160) []byte { // newManagement creates new Management native contract. func newManagement() *Management { - var m = &Management{ContractMD: *interop.NewContractMD(nativenames.Management)} + var m = &Management{ + ContractMD: *interop.NewContractMD(nativenames.Management), + contracts: make(map[util.Uint160]*state.Contract), + } desc := newDescriptor("getContract", smartcontract.ArrayType, manifest.NewParameter("hash", smartcontract.Hash160Type)) @@ -89,6 +98,18 @@ func (m *Management) getContract(ic *interop.Context, args []stackitem.Item) sta // GetContract returns contract with given hash from given DAO. func (m *Management) GetContract(d dao.DAO, hash util.Uint160) (*state.Contract, error) { + m.mtx.RLock() + cs, ok := m.contracts[hash] + m.mtx.RUnlock() + if !ok { + return nil, storage.ErrKeyNotFound + } else if cs != nil { + return cs, nil + } + return m.getContractFromDAO(d, hash) +} + +func (m *Management) getContractFromDAO(d dao.DAO, hash util.Uint160) (*state.Contract, error) { contract := new(state.Contract) key := makeContractKey(hash) err := getSerializableFromDAO(m.ContractID, d, key, contract) @@ -173,7 +194,13 @@ func (m *Management) deploy(ic *interop.Context, args []stackitem.Item) stackite } callDeploy(ic, newcontract, false) return contractToStack(newcontract) +} +func (m *Management) markUpdated(h util.Uint160) { + m.mtx.Lock() + // Just set it to nil, to refresh cache in `PostPersist`. + m.contracts[h] = nil + m.mtx.Unlock() } // Deploy creates contract's hash/ID and saves new contract into the given DAO. @@ -202,6 +229,7 @@ func (m *Management) Deploy(d dao.DAO, sender util.Uint160, neff *nef.File, mani if err != nil { return nil, err } + m.markUpdated(newcontract.Hash) return newcontract, nil } @@ -232,14 +260,16 @@ func (m *Management) Update(d dao.DAO, hash util.Uint160, neff *nef.File, manif } // if NEF was provided, update the contract script if neff != nil { + m.markUpdated(hash) contract.Script = neff.Script } // if manifest was provided, update the contract manifest if manif != nil { - contract.Manifest = *manif - if !contract.Manifest.IsValid(contract.Hash) { + if !manif.IsValid(contract.Hash) { return nil, errors.New("invalid manifest for this contract") } + m.markUpdated(hash) + contract.Manifest = *manif } contract.UpdateCounter++ err = m.PutContractState(d, contract) @@ -285,6 +315,7 @@ func (m *Management) Destroy(d dao.DAO, hash util.Uint160) error { return err } } + m.markUpdated(hash) return nil } @@ -334,13 +365,59 @@ func (m *Management) OnPersist(ic *interop.Context) error { if err := native.Initialize(ic); err != nil { return fmt.Errorf("initializing %s native contract: %w", md.Name, err) } + m.mtx.Lock() + m.contracts[md.Hash] = cs + m.mtx.Unlock() } return nil } +// InitializeCache initializes contract cache with the proper values from storage. +// Cache initialisation should be done apart from Initialize because Initialize is +// called only when deploying native contracts. +func (m *Management) InitializeCache(d dao.DAO) error { + m.mtx.Lock() + defer m.mtx.Unlock() + + var initErr error + d.Seek(m.ContractID, []byte{prefixContract}, func(_, v []byte) { + var r = io.NewBinReaderFromBuf(v) + var si state.StorageItem + si.DecodeBinary(r) + if r.Err != nil { + initErr = r.Err + return + } + + var cs state.Contract + r = io.NewBinReaderFromBuf(si.Value) + cs.DecodeBinary(r) + if r.Err != nil { + initErr = r.Err + return + } + m.contracts[cs.Hash] = &cs + }) + return initErr +} + // PostPersist implements Contract interface. -func (m *Management) PostPersist(_ *interop.Context) error { +func (m *Management) PostPersist(ic *interop.Context) error { + m.mtx.Lock() + for h, cs := range m.contracts { + if cs != nil { + continue + } + newCs, err := m.getContractFromDAO(ic.DAO, h) + if err != nil { + // Contract was destroyed. + delete(m.contracts, h) + continue + } + m.contracts[h] = newCs + } + m.mtx.Unlock() return nil } @@ -355,6 +432,7 @@ func (m *Management) PutContractState(d dao.DAO, cs *state.Contract) error { if err := putSerializableToDAO(m.ContractID, d, key, cs); err != nil { return err } + m.markUpdated(cs.Hash) if cs.UpdateCounter != 0 { // Update. return nil } diff --git a/pkg/core/native/management_test.go b/pkg/core/native/management_test.go index e708654a4..0b652d0ed 100644 --- a/pkg/core/native/management_test.go +++ b/pkg/core/native/management_test.go @@ -61,3 +61,17 @@ func TestDeployGetUpdateDestroyContract(t *testing.T) { _, err = mgmt.GetContract(d, h) require.Error(t, err) } + +func TestManagement_Initialize(t *testing.T) { + t.Run("good", func(t *testing.T) { + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) + mgmt := newManagement() + require.NoError(t, mgmt.InitializeCache(d)) + }) + t.Run("invalid contract state", func(t *testing.T) { + d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false) + mgmt := newManagement() + require.NoError(t, d.PutStorageItem(mgmt.ContractID, []byte{prefixContract}, &state.StorageItem{Value: []byte{0xFF}})) + require.Error(t, mgmt.InitializeCache(d)) + }) +} diff --git a/pkg/core/native_management_test.go b/pkg/core/native_management_test.go index 43204e451..786f5655d 100644 --- a/pkg/core/native_management_test.go +++ b/pkg/core/native_management_test.go @@ -7,17 +7,70 @@ import ( "github.com/nspcc-dev/neo-go/internal/testchain" "github.com/nspcc-dev/neo-go/pkg/config" + "github.com/nspcc-dev/neo-go/pkg/core/chaindump" "github.com/nspcc-dev/neo-go/pkg/core/state" + "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" + "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/require" ) +// This is in a separate test because test test for long manifest +// prevents chain from being dumped. In any real scenario +// restrictions on tx script length will be applied before +// restrictions on manifest size. In this test providing manifest of max size +// leads to tx deserialization failure. +func TestRestoreAfterDeploy(t *testing.T) { + bc := newTestChain(t) + defer bc.Close() + + // nef.NewFile() cares about version a lot. + config.Version = "0.90.0-test" + mgmtHash := bc.ManagementContractHash() + cs1, _ := getTestContractState(bc) + cs1.ID = 1 + cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.Script) + manif1, err := json.Marshal(cs1.Manifest) + require.NoError(t, err) + nef1, err := nef.NewFile(cs1.Script) + require.NoError(t, err) + nef1b, err := nef1.Bytes() + require.NoError(t, err) + + res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...)) + require.NoError(t, err) + checkFAULTState(t, res) +} + +type memoryStore struct { + *storage.MemoryStore +} + +func (memoryStore) Close() error { return nil } + +func TestStartFromHeight(t *testing.T) { + st := memoryStore{storage.NewMemoryStore()} + bc := newTestChainWithCustomCfgAndStore(t, st, nil) + cs1, _ := getTestContractState(bc) + func() { + defer bc.Close() + require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs1)) + checkContractState(t, bc, cs1.Hash, cs1) + _, err := bc.dao.Store.Persist() + require.NoError(t, err) + }() + + bc2 := newTestChainWithCustomCfgAndStore(t, st, nil) + checkContractState(t, bc2, cs1.Hash, cs1) +} + func TestContractDeploy(t *testing.T) { bc := newTestChain(t) defer bc.Close() @@ -70,11 +123,6 @@ func TestContractDeploy(t *testing.T) { require.NoError(t, err) checkFAULTState(t, res) }) - t.Run("too long manifest", func(t *testing.T) { - res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...)) - require.NoError(t, err) - checkFAULTState(t, res) - }) t.Run("array for manifest", func(t *testing.T) { res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, []interface{}{int64(1)}) require.NoError(t, err) @@ -99,17 +147,41 @@ func TestContractDeploy(t *testing.T) { checkFAULTState(t, res) }) t.Run("positive", func(t *testing.T) { - res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1) + tx1, err := prepareContractMethodInvoke(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1) require.NoError(t, err) - require.Equal(t, vm.HaltState, res.VMState) - require.Equal(t, 1, len(res.Stack)) - compareContractStates(t, cs1, res.Stack[0]) + tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) + require.NoError(t, err) + + aers, err := persistBlock(bc, tx1, tx2) + require.NoError(t, err) + for _, res := range aers { + require.Equal(t, vm.HaltState, res.VMState) + require.Equal(t, 1, len(res.Stack)) + compareContractStates(t, cs1, res.Stack[0]) + } + t.Run("_deploy called", func(t *testing.T) { res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue") require.NoError(t, err) require.Equal(t, 1, len(res.Stack)) require.Equal(t, []byte("create"), res.Stack[0].Value()) }) + t.Run("get after deploy", func(t *testing.T) { + checkContractState(t, bc, cs1.Hash, cs1) + }) + t.Run("get after restore", func(t *testing.T) { + w := io.NewBufBinWriter() + require.NoError(t, chaindump.Dump(bc, w.BinWriter, 0, bc.BlockHeight()+1)) + require.NoError(t, w.Err) + + r := io.NewBinReaderFromBuf(w.Bytes()) + bc2 := newTestChain(t) + defer bc2.Close() + + require.NoError(t, chaindump.Restore(bc2, r, 0, bc.BlockHeight()+1, nil)) + require.NoError(t, r.Err) + checkContractState(t, bc2, cs1.Hash, cs1) + }) }) t.Run("contract already exists", func(t *testing.T) { res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1) @@ -138,6 +210,11 @@ func TestContractDeploy(t *testing.T) { res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD) require.NoError(t, err) checkFAULTState(t, res) + + t.Run("get after failed deploy", func(t *testing.T) { + h := state.CreateContractHash(neoOwner, deployScript) + checkContractState(t, bc, h, nil) + }) }) t.Run("bad _deploy", func(t *testing.T) { // invalid _deploy signature deployScript := []byte{byte(opcode.RET)} @@ -162,9 +239,27 @@ func TestContractDeploy(t *testing.T) { res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD) require.NoError(t, err) checkFAULTState(t, res) + + t.Run("get after bad _deploy", func(t *testing.T) { + h := state.CreateContractHash(neoOwner, deployScript) + checkContractState(t, bc, h, nil) + }) }) } +func checkContractState(t *testing.T, bc *Blockchain, h util.Uint160, cs *state.Contract) { + mgmtHash := bc.contracts.Management.Hash + res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", h.BytesBE()) + require.NoError(t, err) + if cs == nil { + require.Equal(t, vm.FaultState, res.VMState) + return + } + require.Equal(t, vm.HaltState, res.VMState) + require.Equal(t, 1, len(res.Stack)) + compareContractStates(t, cs, res.Stack[0]) +} + func TestContractUpdate(t *testing.T) { bc := newTestChain(t) defer bc.Close() @@ -231,9 +326,18 @@ func TestContractUpdate(t *testing.T) { cs1.UpdateCounter++ t.Run("update script, positive", func(t *testing.T) { - res, err := invokeContractMethod(bc, 10_00000000, cs1.Hash, "update", nef1b, nil) + tx1, err := prepareContractMethodInvoke(bc, 10_00000000, cs1.Hash, "update", nef1b, nil) require.NoError(t, err) - require.Equal(t, vm.HaltState, res.VMState) + tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) + require.NoError(t, err) + + aers, err := persistBlock(bc, tx1, tx2) + require.NoError(t, err) + require.Equal(t, vm.HaltState, aers[0].VMState) + require.Equal(t, vm.HaltState, aers[1].VMState) + require.Equal(t, 1, len(aers[1].Stack)) + compareContractStates(t, cs1, aers[1].Stack[0]) + t.Run("_deploy called", func(t *testing.T) { res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue") require.NoError(t, err) @@ -241,10 +345,7 @@ func TestContractUpdate(t *testing.T) { require.Equal(t, []byte("update"), res.Stack[0].Value()) }) t.Run("check contract", func(t *testing.T) { - res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) - require.NoError(t, err) - require.Equal(t, 1, len(res.Stack)) - compareContractStates(t, cs1, res.Stack[0]) + checkContractState(t, bc, cs1.Hash, cs1) }) }) @@ -258,10 +359,7 @@ func TestContractUpdate(t *testing.T) { require.NoError(t, err) require.Equal(t, vm.HaltState, res.VMState) t.Run("check contract", func(t *testing.T) { - res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) - require.NoError(t, err) - require.Equal(t, 1, len(res.Stack)) - compareContractStates(t, cs1, res.Stack[0]) + checkContractState(t, bc, cs1.Hash, cs1) }) }) @@ -280,10 +378,7 @@ func TestContractUpdate(t *testing.T) { require.NoError(t, err) require.Equal(t, vm.HaltState, res.VMState) t.Run("check contract", func(t *testing.T) { - res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) - require.NoError(t, err) - require.Equal(t, 1, len(res.Stack)) - compareContractStates(t, cs1, res.Stack[0]) + checkContractState(t, bc, cs1.Hash, cs1) }) }) }