mirror of
https://github.com/nspcc-dev/neo-go.git
synced 2024-11-23 13:38:35 +00:00
Merge pull request #1618 from nspcc-dev/contractcache
native: cache contract in Management contract
This commit is contained in:
commit
203f8adc9d
5 changed files with 235 additions and 29 deletions
|
@ -284,6 +284,11 @@ func (bc *Blockchain) init() error {
|
|||
return fmt.Errorf("can't init cache for NEO native contract: %w", err)
|
||||
}
|
||||
|
||||
err = bc.contracts.Management.InitializeCache(bc.dao)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't init cache for Management native contract: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -46,19 +46,33 @@ func newTestChain(t *testing.T) *Blockchain {
|
|||
}
|
||||
|
||||
func newTestChainWithCustomCfg(t *testing.T, f func(*config.Config)) *Blockchain {
|
||||
return newTestChainWithCustomCfgAndStore(t, nil, f)
|
||||
}
|
||||
|
||||
func newTestChainWithCustomCfgAndStore(t *testing.T, st storage.Store, f func(*config.Config)) *Blockchain {
|
||||
unitTestNetCfg, err := config.Load("../../config", testchain.Network())
|
||||
require.NoError(t, err)
|
||||
if f != nil {
|
||||
f(&unitTestNetCfg)
|
||||
}
|
||||
chain, err := NewBlockchain(storage.NewMemoryStore(), unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t))
|
||||
if st == nil {
|
||||
st = storage.NewMemoryStore()
|
||||
}
|
||||
chain, err := NewBlockchain(st, unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t))
|
||||
require.NoError(t, err)
|
||||
go chain.Run()
|
||||
return chain
|
||||
}
|
||||
|
||||
func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block {
|
||||
lastBlock := bc.topBlock.Load().(*block.Block)
|
||||
lastBlock, ok := bc.topBlock.Load().(*block.Block)
|
||||
if !ok {
|
||||
var err error
|
||||
lastBlock, err = bc.GetBlock(bc.GetHeaderHash(int(bc.BlockHeight())))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
if bc.config.StateRootInHeader {
|
||||
sr, err := bc.GetStateRoot(bc.BlockHeight())
|
||||
if err != nil {
|
||||
|
|
|
@ -6,13 +6,16 @@ import (
|
|||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"sync"
|
||||
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/dao"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
|
||||
|
@ -24,6 +27,9 @@ import (
|
|||
// Management is contract-managing native contract.
|
||||
type Management struct {
|
||||
interop.ContractMD
|
||||
|
||||
mtx sync.RWMutex
|
||||
contracts map[util.Uint160]*state.Contract
|
||||
}
|
||||
|
||||
// StoragePrice is the price to pay for 1 byte of storage.
|
||||
|
@ -43,7 +49,10 @@ func makeContractKey(h util.Uint160) []byte {
|
|||
|
||||
// newManagement creates new Management native contract.
|
||||
func newManagement() *Management {
|
||||
var m = &Management{ContractMD: *interop.NewContractMD(nativenames.Management)}
|
||||
var m = &Management{
|
||||
ContractMD: *interop.NewContractMD(nativenames.Management),
|
||||
contracts: make(map[util.Uint160]*state.Contract),
|
||||
}
|
||||
|
||||
desc := newDescriptor("getContract", smartcontract.ArrayType,
|
||||
manifest.NewParameter("hash", smartcontract.Hash160Type))
|
||||
|
@ -89,6 +98,18 @@ func (m *Management) getContract(ic *interop.Context, args []stackitem.Item) sta
|
|||
|
||||
// GetContract returns contract with given hash from given DAO.
|
||||
func (m *Management) GetContract(d dao.DAO, hash util.Uint160) (*state.Contract, error) {
|
||||
m.mtx.RLock()
|
||||
cs, ok := m.contracts[hash]
|
||||
m.mtx.RUnlock()
|
||||
if !ok {
|
||||
return nil, storage.ErrKeyNotFound
|
||||
} else if cs != nil {
|
||||
return cs, nil
|
||||
}
|
||||
return m.getContractFromDAO(d, hash)
|
||||
}
|
||||
|
||||
func (m *Management) getContractFromDAO(d dao.DAO, hash util.Uint160) (*state.Contract, error) {
|
||||
contract := new(state.Contract)
|
||||
key := makeContractKey(hash)
|
||||
err := getSerializableFromDAO(m.ContractID, d, key, contract)
|
||||
|
@ -173,7 +194,13 @@ func (m *Management) deploy(ic *interop.Context, args []stackitem.Item) stackite
|
|||
}
|
||||
callDeploy(ic, newcontract, false)
|
||||
return contractToStack(newcontract)
|
||||
}
|
||||
|
||||
func (m *Management) markUpdated(h util.Uint160) {
|
||||
m.mtx.Lock()
|
||||
// Just set it to nil, to refresh cache in `PostPersist`.
|
||||
m.contracts[h] = nil
|
||||
m.mtx.Unlock()
|
||||
}
|
||||
|
||||
// Deploy creates contract's hash/ID and saves new contract into the given DAO.
|
||||
|
@ -202,6 +229,7 @@ func (m *Management) Deploy(d dao.DAO, sender util.Uint160, neff *nef.File, mani
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.markUpdated(newcontract.Hash)
|
||||
return newcontract, nil
|
||||
}
|
||||
|
||||
|
@ -232,14 +260,16 @@ func (m *Management) Update(d dao.DAO, hash util.Uint160, neff *nef.File, manif
|
|||
}
|
||||
// if NEF was provided, update the contract script
|
||||
if neff != nil {
|
||||
m.markUpdated(hash)
|
||||
contract.Script = neff.Script
|
||||
}
|
||||
// if manifest was provided, update the contract manifest
|
||||
if manif != nil {
|
||||
contract.Manifest = *manif
|
||||
if !contract.Manifest.IsValid(contract.Hash) {
|
||||
if !manif.IsValid(contract.Hash) {
|
||||
return nil, errors.New("invalid manifest for this contract")
|
||||
}
|
||||
m.markUpdated(hash)
|
||||
contract.Manifest = *manif
|
||||
}
|
||||
contract.UpdateCounter++
|
||||
err = m.PutContractState(d, contract)
|
||||
|
@ -285,6 +315,7 @@ func (m *Management) Destroy(d dao.DAO, hash util.Uint160) error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
m.markUpdated(hash)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -334,13 +365,59 @@ func (m *Management) OnPersist(ic *interop.Context) error {
|
|||
if err := native.Initialize(ic); err != nil {
|
||||
return fmt.Errorf("initializing %s native contract: %w", md.Name, err)
|
||||
}
|
||||
m.mtx.Lock()
|
||||
m.contracts[md.Hash] = cs
|
||||
m.mtx.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeCache initializes contract cache with the proper values from storage.
|
||||
// Cache initialisation should be done apart from Initialize because Initialize is
|
||||
// called only when deploying native contracts.
|
||||
func (m *Management) InitializeCache(d dao.DAO) error {
|
||||
m.mtx.Lock()
|
||||
defer m.mtx.Unlock()
|
||||
|
||||
var initErr error
|
||||
d.Seek(m.ContractID, []byte{prefixContract}, func(_, v []byte) {
|
||||
var r = io.NewBinReaderFromBuf(v)
|
||||
var si state.StorageItem
|
||||
si.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
initErr = r.Err
|
||||
return
|
||||
}
|
||||
|
||||
var cs state.Contract
|
||||
r = io.NewBinReaderFromBuf(si.Value)
|
||||
cs.DecodeBinary(r)
|
||||
if r.Err != nil {
|
||||
initErr = r.Err
|
||||
return
|
||||
}
|
||||
m.contracts[cs.Hash] = &cs
|
||||
})
|
||||
return initErr
|
||||
}
|
||||
|
||||
// PostPersist implements Contract interface.
|
||||
func (m *Management) PostPersist(_ *interop.Context) error {
|
||||
func (m *Management) PostPersist(ic *interop.Context) error {
|
||||
m.mtx.Lock()
|
||||
for h, cs := range m.contracts {
|
||||
if cs != nil {
|
||||
continue
|
||||
}
|
||||
newCs, err := m.getContractFromDAO(ic.DAO, h)
|
||||
if err != nil {
|
||||
// Contract was destroyed.
|
||||
delete(m.contracts, h)
|
||||
continue
|
||||
}
|
||||
m.contracts[h] = newCs
|
||||
}
|
||||
m.mtx.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -355,6 +432,7 @@ func (m *Management) PutContractState(d dao.DAO, cs *state.Contract) error {
|
|||
if err := putSerializableToDAO(m.ContractID, d, key, cs); err != nil {
|
||||
return err
|
||||
}
|
||||
m.markUpdated(cs.Hash)
|
||||
if cs.UpdateCounter != 0 { // Update.
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -61,3 +61,17 @@ func TestDeployGetUpdateDestroyContract(t *testing.T) {
|
|||
_, err = mgmt.GetContract(d, h)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestManagement_Initialize(t *testing.T) {
|
||||
t.Run("good", func(t *testing.T) {
|
||||
d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)
|
||||
mgmt := newManagement()
|
||||
require.NoError(t, mgmt.InitializeCache(d))
|
||||
})
|
||||
t.Run("invalid contract state", func(t *testing.T) {
|
||||
d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)
|
||||
mgmt := newManagement()
|
||||
require.NoError(t, d.PutStorageItem(mgmt.ContractID, []byte{prefixContract}, &state.StorageItem{Value: []byte{0xFF}}))
|
||||
require.Error(t, mgmt.InitializeCache(d))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,17 +7,70 @@ import (
|
|||
|
||||
"github.com/nspcc-dev/neo-go/internal/testchain"
|
||||
"github.com/nspcc-dev/neo-go/pkg/config"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/chaindump"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/state"
|
||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||||
"github.com/nspcc-dev/neo-go/pkg/io"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
|
||||
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
|
||||
"github.com/nspcc-dev/neo-go/pkg/util"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/opcode"
|
||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// This is in a separate test because test test for long manifest
|
||||
// prevents chain from being dumped. In any real scenario
|
||||
// restrictions on tx script length will be applied before
|
||||
// restrictions on manifest size. In this test providing manifest of max size
|
||||
// leads to tx deserialization failure.
|
||||
func TestRestoreAfterDeploy(t *testing.T) {
|
||||
bc := newTestChain(t)
|
||||
defer bc.Close()
|
||||
|
||||
// nef.NewFile() cares about version a lot.
|
||||
config.Version = "0.90.0-test"
|
||||
mgmtHash := bc.ManagementContractHash()
|
||||
cs1, _ := getTestContractState(bc)
|
||||
cs1.ID = 1
|
||||
cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.Script)
|
||||
manif1, err := json.Marshal(cs1.Manifest)
|
||||
require.NoError(t, err)
|
||||
nef1, err := nef.NewFile(cs1.Script)
|
||||
require.NoError(t, err)
|
||||
nef1b, err := nef1.Bytes()
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...))
|
||||
require.NoError(t, err)
|
||||
checkFAULTState(t, res)
|
||||
}
|
||||
|
||||
type memoryStore struct {
|
||||
*storage.MemoryStore
|
||||
}
|
||||
|
||||
func (memoryStore) Close() error { return nil }
|
||||
|
||||
func TestStartFromHeight(t *testing.T) {
|
||||
st := memoryStore{storage.NewMemoryStore()}
|
||||
bc := newTestChainWithCustomCfgAndStore(t, st, nil)
|
||||
cs1, _ := getTestContractState(bc)
|
||||
func() {
|
||||
defer bc.Close()
|
||||
require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs1))
|
||||
checkContractState(t, bc, cs1.Hash, cs1)
|
||||
_, err := bc.dao.Store.Persist()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
bc2 := newTestChainWithCustomCfgAndStore(t, st, nil)
|
||||
checkContractState(t, bc2, cs1.Hash, cs1)
|
||||
}
|
||||
|
||||
func TestContractDeploy(t *testing.T) {
|
||||
bc := newTestChain(t)
|
||||
defer bc.Close()
|
||||
|
@ -70,11 +123,6 @@ func TestContractDeploy(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
checkFAULTState(t, res)
|
||||
})
|
||||
t.Run("too long manifest", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...))
|
||||
require.NoError(t, err)
|
||||
checkFAULTState(t, res)
|
||||
})
|
||||
t.Run("array for manifest", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, []interface{}{int64(1)})
|
||||
require.NoError(t, err)
|
||||
|
@ -99,17 +147,41 @@ func TestContractDeploy(t *testing.T) {
|
|||
checkFAULTState(t, res)
|
||||
})
|
||||
t.Run("positive", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
|
||||
tx1, err := prepareContractMethodInvoke(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, vm.HaltState, res.VMState)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
compareContractStates(t, cs1, res.Stack[0])
|
||||
tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
|
||||
require.NoError(t, err)
|
||||
|
||||
aers, err := persistBlock(bc, tx1, tx2)
|
||||
require.NoError(t, err)
|
||||
for _, res := range aers {
|
||||
require.Equal(t, vm.HaltState, res.VMState)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
compareContractStates(t, cs1, res.Stack[0])
|
||||
}
|
||||
|
||||
t.Run("_deploy called", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
require.Equal(t, []byte("create"), res.Stack[0].Value())
|
||||
})
|
||||
t.Run("get after deploy", func(t *testing.T) {
|
||||
checkContractState(t, bc, cs1.Hash, cs1)
|
||||
})
|
||||
t.Run("get after restore", func(t *testing.T) {
|
||||
w := io.NewBufBinWriter()
|
||||
require.NoError(t, chaindump.Dump(bc, w.BinWriter, 0, bc.BlockHeight()+1))
|
||||
require.NoError(t, w.Err)
|
||||
|
||||
r := io.NewBinReaderFromBuf(w.Bytes())
|
||||
bc2 := newTestChain(t)
|
||||
defer bc2.Close()
|
||||
|
||||
require.NoError(t, chaindump.Restore(bc2, r, 0, bc.BlockHeight()+1, nil))
|
||||
require.NoError(t, r.Err)
|
||||
checkContractState(t, bc2, cs1.Hash, cs1)
|
||||
})
|
||||
})
|
||||
t.Run("contract already exists", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
|
||||
|
@ -138,6 +210,11 @@ func TestContractDeploy(t *testing.T) {
|
|||
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD)
|
||||
require.NoError(t, err)
|
||||
checkFAULTState(t, res)
|
||||
|
||||
t.Run("get after failed deploy", func(t *testing.T) {
|
||||
h := state.CreateContractHash(neoOwner, deployScript)
|
||||
checkContractState(t, bc, h, nil)
|
||||
})
|
||||
})
|
||||
t.Run("bad _deploy", func(t *testing.T) { // invalid _deploy signature
|
||||
deployScript := []byte{byte(opcode.RET)}
|
||||
|
@ -162,9 +239,27 @@ func TestContractDeploy(t *testing.T) {
|
|||
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD)
|
||||
require.NoError(t, err)
|
||||
checkFAULTState(t, res)
|
||||
|
||||
t.Run("get after bad _deploy", func(t *testing.T) {
|
||||
h := state.CreateContractHash(neoOwner, deployScript)
|
||||
checkContractState(t, bc, h, nil)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func checkContractState(t *testing.T, bc *Blockchain, h util.Uint160, cs *state.Contract) {
|
||||
mgmtHash := bc.contracts.Management.Hash
|
||||
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", h.BytesBE())
|
||||
require.NoError(t, err)
|
||||
if cs == nil {
|
||||
require.Equal(t, vm.FaultState, res.VMState)
|
||||
return
|
||||
}
|
||||
require.Equal(t, vm.HaltState, res.VMState)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
compareContractStates(t, cs, res.Stack[0])
|
||||
}
|
||||
|
||||
func TestContractUpdate(t *testing.T) {
|
||||
bc := newTestChain(t)
|
||||
defer bc.Close()
|
||||
|
@ -231,9 +326,18 @@ func TestContractUpdate(t *testing.T) {
|
|||
cs1.UpdateCounter++
|
||||
|
||||
t.Run("update script, positive", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 10_00000000, cs1.Hash, "update", nef1b, nil)
|
||||
tx1, err := prepareContractMethodInvoke(bc, 10_00000000, cs1.Hash, "update", nef1b, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, vm.HaltState, res.VMState)
|
||||
tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
|
||||
require.NoError(t, err)
|
||||
|
||||
aers, err := persistBlock(bc, tx1, tx2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, vm.HaltState, aers[0].VMState)
|
||||
require.Equal(t, vm.HaltState, aers[1].VMState)
|
||||
require.Equal(t, 1, len(aers[1].Stack))
|
||||
compareContractStates(t, cs1, aers[1].Stack[0])
|
||||
|
||||
t.Run("_deploy called", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue")
|
||||
require.NoError(t, err)
|
||||
|
@ -241,10 +345,7 @@ func TestContractUpdate(t *testing.T) {
|
|||
require.Equal(t, []byte("update"), res.Stack[0].Value())
|
||||
})
|
||||
t.Run("check contract", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
compareContractStates(t, cs1, res.Stack[0])
|
||||
checkContractState(t, bc, cs1.Hash, cs1)
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -258,10 +359,7 @@ func TestContractUpdate(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, vm.HaltState, res.VMState)
|
||||
t.Run("check contract", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
compareContractStates(t, cs1, res.Stack[0])
|
||||
checkContractState(t, bc, cs1.Hash, cs1)
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -280,10 +378,7 @@ func TestContractUpdate(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, vm.HaltState, res.VMState)
|
||||
t.Run("check contract", func(t *testing.T) {
|
||||
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, len(res.Stack))
|
||||
compareContractStates(t, cs1, res.Stack[0])
|
||||
checkContractState(t, bc, cs1.Hash, cs1)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue