native: cache contract in Management contract

This commit is contained in:
Evgenii Stratonikov 2020-12-15 13:53:35 +03:00
parent c13d6ecc55
commit 3397f2c9be
5 changed files with 235 additions and 29 deletions

View file

@ -284,6 +284,11 @@ func (bc *Blockchain) init() error {
return fmt.Errorf("can't init cache for NEO native contract: %w", err) return fmt.Errorf("can't init cache for NEO native contract: %w", err)
} }
err = bc.contracts.Management.InitializeCache(bc.dao)
if err != nil {
return fmt.Errorf("can't init cache for Management native contract: %w", err)
}
return nil return nil
} }

View file

@ -46,19 +46,33 @@ func newTestChain(t *testing.T) *Blockchain {
} }
func newTestChainWithCustomCfg(t *testing.T, f func(*config.Config)) *Blockchain { func newTestChainWithCustomCfg(t *testing.T, f func(*config.Config)) *Blockchain {
return newTestChainWithCustomCfgAndStore(t, nil, f)
}
func newTestChainWithCustomCfgAndStore(t *testing.T, st storage.Store, f func(*config.Config)) *Blockchain {
unitTestNetCfg, err := config.Load("../../config", testchain.Network()) unitTestNetCfg, err := config.Load("../../config", testchain.Network())
require.NoError(t, err) require.NoError(t, err)
if f != nil { if f != nil {
f(&unitTestNetCfg) f(&unitTestNetCfg)
} }
chain, err := NewBlockchain(storage.NewMemoryStore(), unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t)) if st == nil {
st = storage.NewMemoryStore()
}
chain, err := NewBlockchain(st, unitTestNetCfg.ProtocolConfiguration, zaptest.NewLogger(t))
require.NoError(t, err) require.NoError(t, err)
go chain.Run() go chain.Run()
return chain return chain
} }
func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block { func (bc *Blockchain) newBlock(txs ...*transaction.Transaction) *block.Block {
lastBlock := bc.topBlock.Load().(*block.Block) lastBlock, ok := bc.topBlock.Load().(*block.Block)
if !ok {
var err error
lastBlock, err = bc.GetBlock(bc.GetHeaderHash(int(bc.BlockHeight())))
if err != nil {
panic(err)
}
}
if bc.config.StateRootInHeader { if bc.config.StateRootInHeader {
sr, err := bc.GetStateRoot(bc.BlockHeight()) sr, err := bc.GetStateRoot(bc.BlockHeight())
if err != nil { if err != nil {

View file

@ -6,13 +6,16 @@ import (
"fmt" "fmt"
"math" "math"
"math/big" "math/big"
"sync"
"github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/dao"
"github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/interop"
"github.com/nspcc-dev/neo-go/pkg/core/interop/contract" "github.com/nspcc-dev/neo-go/pkg/core/interop/contract"
"github.com/nspcc-dev/neo-go/pkg/core/native/nativenames" "github.com/nspcc-dev/neo-go/pkg/core/native/nativenames"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/encoding/bigint" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
@ -24,6 +27,9 @@ import (
// Management is contract-managing native contract. // Management is contract-managing native contract.
type Management struct { type Management struct {
interop.ContractMD interop.ContractMD
mtx sync.RWMutex
contracts map[util.Uint160]*state.Contract
} }
// StoragePrice is the price to pay for 1 byte of storage. // StoragePrice is the price to pay for 1 byte of storage.
@ -43,7 +49,10 @@ func makeContractKey(h util.Uint160) []byte {
// newManagement creates new Management native contract. // newManagement creates new Management native contract.
func newManagement() *Management { func newManagement() *Management {
var m = &Management{ContractMD: *interop.NewContractMD(nativenames.Management)} var m = &Management{
ContractMD: *interop.NewContractMD(nativenames.Management),
contracts: make(map[util.Uint160]*state.Contract),
}
desc := newDescriptor("getContract", smartcontract.ArrayType, desc := newDescriptor("getContract", smartcontract.ArrayType,
manifest.NewParameter("hash", smartcontract.Hash160Type)) manifest.NewParameter("hash", smartcontract.Hash160Type))
@ -89,6 +98,18 @@ func (m *Management) getContract(ic *interop.Context, args []stackitem.Item) sta
// GetContract returns contract with given hash from given DAO. // GetContract returns contract with given hash from given DAO.
func (m *Management) GetContract(d dao.DAO, hash util.Uint160) (*state.Contract, error) { func (m *Management) GetContract(d dao.DAO, hash util.Uint160) (*state.Contract, error) {
m.mtx.RLock()
cs, ok := m.contracts[hash]
m.mtx.RUnlock()
if !ok {
return nil, storage.ErrKeyNotFound
} else if cs != nil {
return cs, nil
}
return m.getContractFromDAO(d, hash)
}
func (m *Management) getContractFromDAO(d dao.DAO, hash util.Uint160) (*state.Contract, error) {
contract := new(state.Contract) contract := new(state.Contract)
key := makeContractKey(hash) key := makeContractKey(hash)
err := getSerializableFromDAO(m.ContractID, d, key, contract) err := getSerializableFromDAO(m.ContractID, d, key, contract)
@ -173,7 +194,13 @@ func (m *Management) deploy(ic *interop.Context, args []stackitem.Item) stackite
} }
callDeploy(ic, newcontract, false) callDeploy(ic, newcontract, false)
return contractToStack(newcontract) return contractToStack(newcontract)
}
func (m *Management) markUpdated(h util.Uint160) {
m.mtx.Lock()
// Just set it to nil, to refresh cache in `PostPersist`.
m.contracts[h] = nil
m.mtx.Unlock()
} }
// Deploy creates contract's hash/ID and saves new contract into the given DAO. // Deploy creates contract's hash/ID and saves new contract into the given DAO.
@ -202,6 +229,7 @@ func (m *Management) Deploy(d dao.DAO, sender util.Uint160, neff *nef.File, mani
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.markUpdated(newcontract.Hash)
return newcontract, nil return newcontract, nil
} }
@ -232,14 +260,16 @@ func (m *Management) Update(d dao.DAO, hash util.Uint160, neff *nef.File, manif
} }
// if NEF was provided, update the contract script // if NEF was provided, update the contract script
if neff != nil { if neff != nil {
m.markUpdated(hash)
contract.Script = neff.Script contract.Script = neff.Script
} }
// if manifest was provided, update the contract manifest // if manifest was provided, update the contract manifest
if manif != nil { if manif != nil {
contract.Manifest = *manif if !manif.IsValid(contract.Hash) {
if !contract.Manifest.IsValid(contract.Hash) {
return nil, errors.New("invalid manifest for this contract") return nil, errors.New("invalid manifest for this contract")
} }
m.markUpdated(hash)
contract.Manifest = *manif
} }
contract.UpdateCounter++ contract.UpdateCounter++
err = m.PutContractState(d, contract) err = m.PutContractState(d, contract)
@ -285,6 +315,7 @@ func (m *Management) Destroy(d dao.DAO, hash util.Uint160) error {
return err return err
} }
} }
m.markUpdated(hash)
return nil return nil
} }
@ -334,13 +365,59 @@ func (m *Management) OnPersist(ic *interop.Context) error {
if err := native.Initialize(ic); err != nil { if err := native.Initialize(ic); err != nil {
return fmt.Errorf("initializing %s native contract: %w", md.Name, err) return fmt.Errorf("initializing %s native contract: %w", md.Name, err)
} }
m.mtx.Lock()
m.contracts[md.Hash] = cs
m.mtx.Unlock()
} }
return nil return nil
} }
// InitializeCache initializes contract cache with the proper values from storage.
// Cache initialisation should be done apart from Initialize because Initialize is
// called only when deploying native contracts.
func (m *Management) InitializeCache(d dao.DAO) error {
m.mtx.Lock()
defer m.mtx.Unlock()
var initErr error
d.Seek(m.ContractID, []byte{prefixContract}, func(_, v []byte) {
var r = io.NewBinReaderFromBuf(v)
var si state.StorageItem
si.DecodeBinary(r)
if r.Err != nil {
initErr = r.Err
return
}
var cs state.Contract
r = io.NewBinReaderFromBuf(si.Value)
cs.DecodeBinary(r)
if r.Err != nil {
initErr = r.Err
return
}
m.contracts[cs.Hash] = &cs
})
return initErr
}
// PostPersist implements Contract interface. // PostPersist implements Contract interface.
func (m *Management) PostPersist(_ *interop.Context) error { func (m *Management) PostPersist(ic *interop.Context) error {
m.mtx.Lock()
for h, cs := range m.contracts {
if cs != nil {
continue
}
newCs, err := m.getContractFromDAO(ic.DAO, h)
if err != nil {
// Contract was destroyed.
delete(m.contracts, h)
continue
}
m.contracts[h] = newCs
}
m.mtx.Unlock()
return nil return nil
} }
@ -355,6 +432,7 @@ func (m *Management) PutContractState(d dao.DAO, cs *state.Contract) error {
if err := putSerializableToDAO(m.ContractID, d, key, cs); err != nil { if err := putSerializableToDAO(m.ContractID, d, key, cs); err != nil {
return err return err
} }
m.markUpdated(cs.Hash)
if cs.UpdateCounter != 0 { // Update. if cs.UpdateCounter != 0 { // Update.
return nil return nil
} }

View file

@ -61,3 +61,17 @@ func TestDeployGetUpdateDestroyContract(t *testing.T) {
_, err = mgmt.GetContract(d, h) _, err = mgmt.GetContract(d, h)
require.Error(t, err) require.Error(t, err)
} }
func TestManagement_Initialize(t *testing.T) {
t.Run("good", func(t *testing.T) {
d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)
mgmt := newManagement()
require.NoError(t, mgmt.InitializeCache(d))
})
t.Run("invalid contract state", func(t *testing.T) {
d := dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)
mgmt := newManagement()
require.NoError(t, d.PutStorageItem(mgmt.ContractID, []byte{prefixContract}, &state.StorageItem{Value: []byte{0xFF}}))
require.Error(t, mgmt.InitializeCache(d))
})
}

View file

@ -7,17 +7,70 @@ import (
"github.com/nspcc-dev/neo-go/internal/testchain" "github.com/nspcc-dev/neo-go/internal/testchain"
"github.com/nspcc-dev/neo-go/pkg/config" "github.com/nspcc-dev/neo-go/pkg/config"
"github.com/nspcc-dev/neo-go/pkg/core/chaindump"
"github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/state"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/crypto/keys"
"github.com/nspcc-dev/neo-go/pkg/io"
"github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest"
"github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm"
"github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/opcode"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// This is in a separate test because test test for long manifest
// prevents chain from being dumped. In any real scenario
// restrictions on tx script length will be applied before
// restrictions on manifest size. In this test providing manifest of max size
// leads to tx deserialization failure.
func TestRestoreAfterDeploy(t *testing.T) {
bc := newTestChain(t)
defer bc.Close()
// nef.NewFile() cares about version a lot.
config.Version = "0.90.0-test"
mgmtHash := bc.ManagementContractHash()
cs1, _ := getTestContractState(bc)
cs1.ID = 1
cs1.Hash = state.CreateContractHash(testchain.MultisigScriptHash(), cs1.Script)
manif1, err := json.Marshal(cs1.Manifest)
require.NoError(t, err)
nef1, err := nef.NewFile(cs1.Script)
require.NoError(t, err)
nef1b, err := nef1.Bytes()
require.NoError(t, err)
res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...))
require.NoError(t, err)
checkFAULTState(t, res)
}
type memoryStore struct {
*storage.MemoryStore
}
func (memoryStore) Close() error { return nil }
func TestStartFromHeight(t *testing.T) {
st := memoryStore{storage.NewMemoryStore()}
bc := newTestChainWithCustomCfgAndStore(t, st, nil)
cs1, _ := getTestContractState(bc)
func() {
defer bc.Close()
require.NoError(t, bc.contracts.Management.PutContractState(bc.dao, cs1))
checkContractState(t, bc, cs1.Hash, cs1)
_, err := bc.dao.Store.Persist()
require.NoError(t, err)
}()
bc2 := newTestChainWithCustomCfgAndStore(t, st, nil)
checkContractState(t, bc2, cs1.Hash, cs1)
}
func TestContractDeploy(t *testing.T) { func TestContractDeploy(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer bc.Close() defer bc.Close()
@ -70,11 +123,6 @@ func TestContractDeploy(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
checkFAULTState(t, res) checkFAULTState(t, res)
}) })
t.Run("too long manifest", func(t *testing.T) {
res, err := invokeContractMethod(bc, 100_00000000, mgmtHash, "deploy", nef1b, append(manif1, make([]byte, manifest.MaxManifestSize)...))
require.NoError(t, err)
checkFAULTState(t, res)
})
t.Run("array for manifest", func(t *testing.T) { t.Run("array for manifest", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, []interface{}{int64(1)}) res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, []interface{}{int64(1)})
require.NoError(t, err) require.NoError(t, err)
@ -99,17 +147,41 @@ func TestContractDeploy(t *testing.T) {
checkFAULTState(t, res) checkFAULTState(t, res)
}) })
t.Run("positive", func(t *testing.T) { t.Run("positive", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1) tx1, err := prepareContractMethodInvoke(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState) tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.Equal(t, 1, len(res.Stack)) require.NoError(t, err)
compareContractStates(t, cs1, res.Stack[0])
aers, err := persistBlock(bc, tx1, tx2)
require.NoError(t, err)
for _, res := range aers {
require.Equal(t, vm.HaltState, res.VMState)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
}
t.Run("_deploy called", func(t *testing.T) { t.Run("_deploy called", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue") res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(res.Stack)) require.Equal(t, 1, len(res.Stack))
require.Equal(t, []byte("create"), res.Stack[0].Value()) require.Equal(t, []byte("create"), res.Stack[0].Value())
}) })
t.Run("get after deploy", func(t *testing.T) {
checkContractState(t, bc, cs1.Hash, cs1)
})
t.Run("get after restore", func(t *testing.T) {
w := io.NewBufBinWriter()
require.NoError(t, chaindump.Dump(bc, w.BinWriter, 0, bc.BlockHeight()+1))
require.NoError(t, w.Err)
r := io.NewBinReaderFromBuf(w.Bytes())
bc2 := newTestChain(t)
defer bc2.Close()
require.NoError(t, chaindump.Restore(bc2, r, 0, bc.BlockHeight()+1, nil))
require.NoError(t, r.Err)
checkContractState(t, bc2, cs1.Hash, cs1)
})
}) })
t.Run("contract already exists", func(t *testing.T) { t.Run("contract already exists", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1) res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nef1b, manif1)
@ -138,6 +210,11 @@ func TestContractDeploy(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD) res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD)
require.NoError(t, err) require.NoError(t, err)
checkFAULTState(t, res) checkFAULTState(t, res)
t.Run("get after failed deploy", func(t *testing.T) {
h := state.CreateContractHash(neoOwner, deployScript)
checkContractState(t, bc, h, nil)
})
}) })
t.Run("bad _deploy", func(t *testing.T) { // invalid _deploy signature t.Run("bad _deploy", func(t *testing.T) { // invalid _deploy signature
deployScript := []byte{byte(opcode.RET)} deployScript := []byte{byte(opcode.RET)}
@ -162,9 +239,27 @@ func TestContractDeploy(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD) res, err := invokeContractMethod(bc, 10_00000000, mgmtHash, "deploy", nefDb, manifD)
require.NoError(t, err) require.NoError(t, err)
checkFAULTState(t, res) checkFAULTState(t, res)
t.Run("get after bad _deploy", func(t *testing.T) {
h := state.CreateContractHash(neoOwner, deployScript)
checkContractState(t, bc, h, nil)
})
}) })
} }
func checkContractState(t *testing.T, bc *Blockchain, h util.Uint160, cs *state.Contract) {
mgmtHash := bc.contracts.Management.Hash
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", h.BytesBE())
require.NoError(t, err)
if cs == nil {
require.Equal(t, vm.FaultState, res.VMState)
return
}
require.Equal(t, vm.HaltState, res.VMState)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs, res.Stack[0])
}
func TestContractUpdate(t *testing.T) { func TestContractUpdate(t *testing.T) {
bc := newTestChain(t) bc := newTestChain(t)
defer bc.Close() defer bc.Close()
@ -231,9 +326,18 @@ func TestContractUpdate(t *testing.T) {
cs1.UpdateCounter++ cs1.UpdateCounter++
t.Run("update script, positive", func(t *testing.T) { t.Run("update script, positive", func(t *testing.T) {
res, err := invokeContractMethod(bc, 10_00000000, cs1.Hash, "update", nef1b, nil) tx1, err := prepareContractMethodInvoke(bc, 10_00000000, cs1.Hash, "update", nef1b, nil)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState) tx2, err := prepareContractMethodInvoke(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE())
require.NoError(t, err)
aers, err := persistBlock(bc, tx1, tx2)
require.NoError(t, err)
require.Equal(t, vm.HaltState, aers[0].VMState)
require.Equal(t, vm.HaltState, aers[1].VMState)
require.Equal(t, 1, len(aers[1].Stack))
compareContractStates(t, cs1, aers[1].Stack[0])
t.Run("_deploy called", func(t *testing.T) { t.Run("_deploy called", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue") res, err := invokeContractMethod(bc, 1_00000000, cs1.Hash, "getValue")
require.NoError(t, err) require.NoError(t, err)
@ -241,10 +345,7 @@ func TestContractUpdate(t *testing.T) {
require.Equal(t, []byte("update"), res.Stack[0].Value()) require.Equal(t, []byte("update"), res.Stack[0].Value())
}) })
t.Run("check contract", func(t *testing.T) { t.Run("check contract", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) checkContractState(t, bc, cs1.Hash, cs1)
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
}) })
}) })
@ -258,10 +359,7 @@ func TestContractUpdate(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState) require.Equal(t, vm.HaltState, res.VMState)
t.Run("check contract", func(t *testing.T) { t.Run("check contract", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) checkContractState(t, bc, cs1.Hash, cs1)
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
}) })
}) })
@ -280,10 +378,7 @@ func TestContractUpdate(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, vm.HaltState, res.VMState) require.Equal(t, vm.HaltState, res.VMState)
t.Run("check contract", func(t *testing.T) { t.Run("check contract", func(t *testing.T) {
res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "getContract", cs1.Hash.BytesBE()) checkContractState(t, bc, cs1.Hash, cs1)
require.NoError(t, err)
require.Equal(t, 1, len(res.Stack))
compareContractStates(t, cs1, res.Stack[0])
}) })
}) })
} }