native: cache contract in Management contract

This commit is contained in:
Evgenii Stratonikov 2020-12-15 13:53:35 +03:00
parent c13d6ecc55
commit 3397f2c9be
5 changed files with 235 additions and 29 deletions

View file

@ -284,6 +284,11 @@ func (bc *Blockchain) init() error {
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
}
err = bc.contracts.Management.InitializeCache(bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for Management native contract: %w", err)
}
return nil
}

View file

@ -46,19 +46,33 @@ func newTestChain(t *testing.T) *Blockchain {
}
func newTestChainWithCustomCfg(t *testing.T, f func(*config.Config)) *Blockchain {
return newTestChainWithCustomCfgAndStore(t, nil, f)
}
func newTestChainWithCustomCfgAndStore(t *testing.T, st storage.Store, f func(*config.Config)) *Blockchain {
unitTestNetCfg, err := config.Load("../../config", testchain.Network())
require.NoError(t, err)
if f != nil {
f(&unitTestNetCfg)
}
chain, err := NewBlockchain(storage.NewMemoryStore(), unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t))
if st == nil {
st = storage.NewMemoryStore()
}
chain, err := NewBlockchain(st, unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t))
require.NoError(t, err)
go chain.Run()
return chain
}
func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block {
lastBlock := bc.topBlock.Load().(*block.Block)
lastBlock, ok := bc.topBlock.Load().(*block.Block)
if !ok {
var err error
lastBlock, err = bc.GetBlock(bc.GetHeaderHash(int(bc.BlockHeight())))
if err != nil {
panic(err)
}
}
if bc.config.StateRootInHeader {
sr, err := bc.GetStateRoot(bc.BlockHeight())
if err != nil {

View file

@ -6,13 +6,16 @@ import (
"fmt"
"math"
"math/big"
"sync"
"github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
@ -24,6 +27,9 @@ import (
// Management is contract-managing native contract.
type Management struct {
interop.ContractMD
mtx sync.RWMutex
contracts map[util.Uint160]*state.Contract
}
// StoragePrice is the price to pay for 1 byte of storage.
@ -43,7 +49,10 @@ func makeContractKey(h util.Uint160) []byte {
// newManagement creates new Management native contract.
func newManagement() *Management {
var m = &Management{ContractMD: *interop.NewContractMD(nativenames.Management)}
var m = &Management{
ContractMD: *interop.NewContractMD(nativenames.Management),
contracts: make(map[util.Uint160]*state.Contract),
}
desc := newDescriptor("getContract", smartcontract.ArrayType,
manifest.NewParameter("hash", smartcontract.Hash160Type))
@ -89,6 +98,18 @@ func (m *Management) getContract(ic *interop.Context, args []stackitem.Item) sta
// GetContract returns contract with given hash from given DAO.
func (m *Management) GetContract(d dao.DAO, hash util.Uint160) (*state.Contract, error) {
m.mtx.RLock()
cs, ok := m.contracts[hash]
m.mtx.RUnlock()
if !ok {
return nil, storage.ErrKeyNotFound
} else if cs != nil {
return cs, nil
}
return m.getContractFromDAO(d, hash)
}
func (m *Management) getContractFromDAO(d dao.DAO, hash util.Uint160) (*state.Contract, error) {
contract := new(state.Contract)
key := makeContractKey(hash)
err := getSerializableFromDAO(m.ContractID, d, key, contract)
@ -173,7 +194,13 @@ func (m *Management) deploy(ic *interop.Context, args []stackitem.Item) stackite
}
callDeploy(ic, newcontract, false)
return contractToStack(newcontract)
}
func (m *Management) markUpdated(h util.Uint160) {
m.mtx.Lock()
// Just set it to nil, to refresh cache in `PostPersist`.
m.contracts[h] = nil
m.mtx.Unlock()
}
// Deploy creates contract's hash/ID and saves new contract into the given DAO.
@ -202,6 +229,7 @@ func (m *Management) Deploy(d dao.DAO, sender util.Uint160, neff *nef.File, mani
if err != nil {
return nil, err
}
m.markUpdated(newcontract.Hash)
return newcontract, nil
}
@ -232,14 +260,16 @@ func (m *Management) Update(d dao.DAO, hash util.Uint160, neff *nef.File, manif
}
// if NEF was provided, update the contract script
if neff != nil {
m.markUpdated(hash)
contract.Script = neff.Script
}
// if manifest was provided, update the contract manifest
if manif != nil {
contract.Manifest = *manif
if !contract.Manifest.IsValid(contract.Hash) {
if !manif.IsValid(contract.Hash) {
return nil, errors.New("invalid manifest for this contract")
}
m.markUpdated(hash)
contract.Manifest = *manif
}
contract.UpdateCounter++
err = m.PutContractState(d, contract)
@ -285,6 +315,7 @@ func (m *Management) Destroy(d dao.DAO, hash util.Uint160) error {
return err
}
}
m.markUpdated(hash)
return nil
}
@ -334,13 +365,59 @@ func (m *Management) OnPersist(ic *interop.Context) error {
if err := native.Initialize(ic); err != nil {
return fmt.Errorf("initializing %s native contract: %w", md.Name, err)
}
m.mtx.Lock()
m.contracts[md.Hash] = cs
m.mtx.Unlock()
}
return nil
}
// InitializeCache initializes contract cache with the proper values from storage.
// Cache initialisation should be done apart from Initialize because Initialize is
// called only when deploying native contracts.
func (m *Management) InitializeCache(d dao.DAO) error {
m.mtx.Lock()
defer m.mtx.Unlock()
var initErr error
d.Seek(m.ContractID, []byte{prefixContract}, func(_, v []byte) {
var r = io.NewBinReaderFromBuf(v)
var si state.StorageItem
si.DecodeBinary(r)
if r.Err != nil {
initErr = r.Err
return
}
var cs state.Contract
r = io.NewBinReaderFromBuf(si.Value)
cs.DecodeBinary(r)
if r.Err != nil {
initErr = r.Err
return
}
m.contracts[cs.Hash] = &cs
})
return initErr
}
// PostPersist implements Contract interface.
func (m *Management) PostPersist(_ *interop.Context) error {
func (m *Management) PostPersist(ic *interop.Context) error {
m.mtx.Lock()
for h, cs := range m.contracts {
if cs != nil {
continue
}
newCs, err := m.getContractFromDAO(ic.DAO, h)
if err != nil {
// Contract was destroyed.
delete(m.contracts, h)
continue
}
m.contracts[h] = newCs
}
m.mtx.Unlock()
return nil
}
@ -355,6 +432,7 @@ func (m *Management) PutContractState(d dao.DAO, cs *state.Contract) error {
if err := putSerializableToDAO(m.ContractID, d, key, cs); err != nil {
return err
}
m.markUpdated(cs.Hash)
if cs.UpdateCounter != 0 { // Update.
return nil
}

View file

@ -61,3 +61,17 @@ func TestDeployGetUpdateDestroyContract(t *testing.T) {
_, err = mgmt.GetContract(d, h)
require.Error(t, err)
}
func TestManagement_Initialize(t *testing.T) {
t.Run("good", func(t *testing.T) {
d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)
mgmt := newManagement()
require.NoError(t, mgmt.InitializeCache(d))
})
t.Run("invalid contract state", func(t *testing.T) {
d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)
mgmt := newManagement()
require.NoError(t, d.PutStorageItem(mgmt.ContractID, []byte{prefixContract}, &state.StorageItem{Value: []byte{0xFF}}))
require.Error(t, mgmt.InitializeCache(d))
})
}

View file

@ -7,17 +7,70 @@ import (
"github.com/nspcc-dev/neo-go/internal/testchain"
"github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/core/chaindump"
"github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require"
)
// This is in a separate test because test test for long manifest
// prevents chain from being dumped. In any real scenario
// restrictions on tx script length will be applied before
// restrictions on manifest size. In this test providing manifest of max size
// leads to tx deserialization failure.
func TestRestoreAfterDeploy(t *testing.T) {
bc := newTestChain(t)
defer bc.Close()
// nef.NewFile() cares about version a lot.
config.Version = "0.90.0-test"
mgmtHash := bc.ManagementContractHash()
cs1, _ := getTestContractState(bc)
cs1.ID = 1
cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.Script)
manif1, err := json.Marshal(cs1.Manifest)
require.NoError(t, err)
nef1, err := nef.NewFile(cs1.Script)
require.NoError(t, err)
nef1b, err := nef1.Bytes()
require.NoError(t, err)
res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...))
require.NoError(t, err)
checkFAULTState(t, res)
}
type memoryStore struct {
*storage.MemoryStore
}
func (memoryStore) Close() error { return nil }
func TestStartFromHeight(t *testing.T) {
st := memoryStore{storage.NewMemoryStore()}
bc := newTestChainWithCustomCfgAndStore(t, st, nil)
cs1, _ := getTestContractState(bc)
func() {
defer bc.Close()
require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs1))
checkContractState(t, bc, cs1.Hash, cs1)
_, err := bc.dao.Store.Persist()
require.NoError(t, err)
}()
bc2 := newTestChainWithCustomCfgAndStore(t, st, nil)
checkContractState(t, bc2, cs1.Hash, cs1)
}
func TestContractDeploy(t *testing.T) {
bc := newTestChain(t)
defer bc.Close()
@ -70,11 +123,6 @@ func TestContractDeploy(t *testing.T) {
require.NoError(t, err)
checkFAULTState(t, res)
})
t.Run("too long manifest", func(t *testing.T) {
res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...))
require.NoError(t, err)
checkFAULTState(t, res)
})
t.Run("array for manifest", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, []interface{}{int64(1)})
require.NoError(t, err)
@ -99,17 +147,41 @@ func TestContractDeploy(t *testing.T) {
checkFAULTState(t, res)
})
t.Run("positive", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
tx1, err := prepareContractMethodInvoke(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.NoError(t, err)
aers, err := persistBlock(bc, tx1, tx2)
require.NoError(t, err)
for _, res := range aers {
require.Equal(t, vm.HaltState, res.VMState)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
}
t.Run("_deploy called", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue")
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
require.Equal(t, []byte("create"), res.Stack[0].Value())
})
t.Run("get after deploy", func(t *testing.T) {
checkContractState(t, bc, cs1.Hash, cs1)
})
t.Run("get after restore", func(t *testing.T) {
w := io.NewBufBinWriter()
require.NoError(t, chaindump.Dump(bc, w.BinWriter, 0, bc.BlockHeight()+1))
require.NoError(t, w.Err)
r := io.NewBinReaderFromBuf(w.Bytes())
bc2 := newTestChain(t)
defer bc2.Close()
require.NoError(t, chaindump.Restore(bc2, r, 0, bc.BlockHeight()+1, nil))
require.NoError(t, r.Err)
checkContractState(t, bc2, cs1.Hash, cs1)
})
})
t.Run("contract already exists", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
@ -138,6 +210,11 @@ func TestContractDeploy(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD)
require.NoError(t, err)
checkFAULTState(t, res)
t.Run("get after failed deploy", func(t *testing.T) {
h := state.CreateContractHash(neoOwner, deployScript)
checkContractState(t, bc, h, nil)
})
})
t.Run("bad _deploy", func(t *testing.T) { // invalid _deploy signature
deployScript := []byte{byte(opcode.RET)}
@ -162,9 +239,27 @@ func TestContractDeploy(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD)
require.NoError(t, err)
checkFAULTState(t, res)
t.Run("get after bad _deploy", func(t *testing.T) {
h := state.CreateContractHash(neoOwner, deployScript)
checkContractState(t, bc, h, nil)
})
})
}
func checkContractState(t *testing.T, bc *Blockchain, h util.Uint160, cs *state.Contract) {
mgmtHash := bc.contracts.Management.Hash
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", h.BytesBE())
require.NoError(t, err)
if cs == nil {
require.Equal(t, vm.FaultState, res.VMState)
return
}
require.Equal(t, vm.HaltState, res.VMState)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs, res.Stack[0])
}
func TestContractUpdate(t *testing.T) {
bc := newTestChain(t)
defer bc.Close()
@ -231,9 +326,18 @@ func TestContractUpdate(t *testing.T) {
cs1.UpdateCounter++
t.Run("update script, positive", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, cs1.Hash, "update", nef1b, nil)
tx1, err := prepareContractMethodInvoke(bc, 10_00000000, cs1.Hash, "update", nef1b, nil)
require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState)
tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.NoError(t, err)
aers, err := persistBlock(bc, tx1, tx2)
require.NoError(t, err)
require.Equal(t, vm.HaltState, aers[0].VMState)
require.Equal(t, vm.HaltState, aers[1].VMState)
require.Equal(t, 1, len(aers[1].Stack))
compareContractStates(t, cs1, aers[1].Stack[0])
t.Run("_deploy called", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue")
require.NoError(t, err)
@ -241,10 +345,7 @@ func TestContractUpdate(t *testing.T) {
require.Equal(t, []byte("update"), res.Stack[0].Value())
})
t.Run("check contract", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
checkContractState(t, bc, cs1.Hash, cs1)
})
})
@ -258,10 +359,7 @@ func TestContractUpdate(t *testing.T) {
require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState)
t.Run("check contract", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
checkContractState(t, bc, cs1.Hash, cs1)
})
})
@ -280,10 +378,7 @@ func TestContractUpdate(t *testing.T) {
require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState)
t.Run("check contract", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
checkContractState(t, bc, cs1.Hash, cs1)
})
})
}