core: add finalizer functions to interop context

These functions are aimed to free the resources occupied by storage
iterator by the end of script execution or whenever Finilize is called.
This commit is contained in:
Anna Shaleva 2021-10-07 14:27:55 +03:00
parent 0a4f45c9b0
commit 07cbe4d253
11 changed files with 60 additions and 29 deletions

View file

@ -327,7 +327,7 @@ func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
}
// GetTestVM implements Blockchainer interface.
func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM {
func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
panic("TODO")
}

View file

@ -1036,7 +1036,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
v.LoadToken = contract.LoadToken(systemInterop)
v.GasLimit = tx.SystemFee
err := v.Run()
err := systemInterop.Exec()
var faultException string
if !v.HasFailed() {
_, err := systemInterop.DAO.Persist()
@ -1223,7 +1223,7 @@ func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache dao.DA
v := systemInterop.SpawnVM()
v.LoadScriptWithFlags(script, callflag.All)
v.SetPriceGetter(systemInterop.GetPrice)
if err := v.Run(); err != nil {
if err := systemInterop.Exec(); err != nil {
return nil, fmt.Errorf("VM has failed: %w", err)
} else if _, err := systemInterop.DAO.Persist(); err != nil {
return nil, fmt.Errorf("can't save changes: %w", err)
@ -2052,14 +2052,14 @@ func (bc *Blockchain) GetEnrollments() ([]state.Validator, error) {
return bc.contracts.NEO.GetCandidates(bc.dao)
}
// GetTestVM returns a VM and a Store setup for a test run of some sort of code.
func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM {
// GetTestVM returns a VM setup for a test run of some sort of code and finalizer function.
func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
d := bc.dao.GetWrapped().(*dao.Simple)
systemInterop := bc.newInteropContext(t, d, b, tx)
vm := systemInterop.SpawnVM()
vm.SetPriceGetter(systemInterop.GetPrice)
vm.LoadToken = contract.LoadToken(systemInterop)
return vm
return vm, systemInterop.Finalize
}
// Various witness verification errors.
@ -2138,7 +2138,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil {
return 0, err
}
err := vm.Run()
err := interopCtx.Exec()
if vm.HasFailed() {
return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err)
}

View file

@ -59,7 +59,7 @@ type Blockchainer interface {
GetStateSyncModule() StateSync
GetStorageItem(id int32, key []byte) state.StorageItem
GetStorageItems(id int32) ([]state.StorageItemWithKey, error)
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func())
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
SetOracle(service services.Oracle)
mempool.Feer // fee interface

View file

@ -1,6 +1,7 @@
package interop
import (
"context"
"encoding/binary"
"errors"
"fmt"
@ -47,6 +48,7 @@ type Context struct {
Log *zap.Logger
VM *vm.VM
Functions []Function
cancelFuncs []context.CancelFunc
getContract func(dao.DAO, util.Uint160) (*state.Contract, error)
baseExecFee int64
}
@ -285,3 +287,25 @@ func (ic *Context) SpawnVM() *vm.VM {
ic.VM = v
return v
}
// RegisterCancelFunc adds given function to the list of functions to be called after VM
// finishes script execution.
func (ic *Context) RegisterCancelFunc(f context.CancelFunc) {
if f != nil {
ic.cancelFuncs = append(ic.cancelFuncs, f)
}
}
// Finalize calls all registered cancel functions to release the occupied resources.
func (ic *Context) Finalize() {
for _, f := range ic.cancelFuncs {
f()
}
ic.cancelFuncs = nil
}
// Exec executes loaded VM script and calls registered finalizers to release the occupied resources.
func (ic *Context) Exec() error {
defer ic.Finalize()
return ic.VM.Run()
}

View file

@ -1,8 +1,6 @@
package storage
import (
"context"
"github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -25,7 +23,6 @@ const (
// Iterator is an iterator state representation.
type Iterator struct {
seekCh chan storage.KeyValue
cancel context.CancelFunc
curr storage.KeyValue
next bool
opts int64
@ -33,10 +30,9 @@ type Iterator struct {
}
// NewIterator creates a new Iterator with given options for a given channel of store.Seek results.
func NewIterator(seekCh chan storage.KeyValue, cancel context.CancelFunc, prefix []byte, opts int64) *Iterator {
func NewIterator(seekCh chan storage.KeyValue, prefix []byte, opts int64) *Iterator {
return &Iterator{
seekCh: seekCh,
cancel: cancel,
opts: opts,
prefix: slice.Copy(prefix),
}
@ -84,9 +80,3 @@ func (s *Iterator) Value() stackitem.Item {
value,
})
}
// Close releases resources occupied by the Iterator.
// TODO: call this method on program unloading.
func (s *Iterator) Close() {
s.cancel()
}

View file

@ -191,8 +191,9 @@ func storageFind(ic *interop.Context) error {
// sorted items, so no need to sort them one more time.
ctx, cancel := context.WithCancel(context.Background())
seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix)
item := istorage.NewIterator(seekres, cancel, prefix, opts)
item := istorage.NewIterator(seekres, prefix, opts)
ic.VM.Estack().PushItem(stackitem.NewInterop(item))
ic.RegisterCancelFunc(cancel)
return nil
}

View file

@ -366,6 +366,8 @@ func BenchmarkStorageFind(b *testing.B) {
if err != nil {
b.FailNow()
}
b.StopTimer()
context.Finalize()
}
})
}
@ -425,6 +427,7 @@ func BenchmarkStorageFindIteratorNext(b *testing.B) {
} else {
require.True(b, actual)
}
context.Finalize()
}
})
}

View file

@ -20,10 +20,11 @@ type Invoke struct {
FaultException string
Transaction *transaction.Transaction
maxIteratorResultItems int
finalize func()
}
// NewInvoke returns new Invoke structure with the given fields set.
func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResultItems int) *Invoke {
func NewInvoke(vm *vm.VM, finalize func(), script []byte, faultException string, maxIteratorResultItems int) *Invoke {
return &Invoke{
State: vm.State().String(),
GasConsumed: vm.GasConsumed(),
@ -31,6 +32,7 @@ func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResul
Stack: vm.Estack().ToArray(),
FaultException: faultException,
maxIteratorResultItems: maxIteratorResultItems,
finalize: finalize,
}
}
@ -55,8 +57,17 @@ type Iterator struct {
Truncated bool
}
// Finalize releases resources occupied by Iterators created at the script invocation.
// This method will be called automatically on Invoke marshalling.
func (r *Invoke) Finalize() {
if r.finalize != nil {
r.finalize()
}
}
// MarshalJSON implements json.Marshaler.
func (r Invoke) MarshalJSON() ([]byte, error) {
defer r.Finalize()
var st json.RawMessage
arr := make([]json.RawMessage, len(r.Stack))
for i := range arr {

View file

@ -714,7 +714,7 @@ func TestCreateNEP17TransferTx(t *testing.T) {
require.NoError(t, err)
require.NoError(t, acc.SignTx(testchain.Network(), tx))
require.NoError(t, chain.VerifyTx(tx))
v := chain.GetTestVM(trigger.Application, tx, nil)
v, _ := chain.GetTestVM(trigger.Application, tx, nil)
v.LoadScriptWithFlags(tx.Script, callflag.All)
require.NoError(t, v.Run())
}

View file

@ -624,6 +624,7 @@ func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *re
if respErr != nil {
return 0, respErr
}
res.Finalize()
if res.State != "HALT" {
cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException)
return 0, response.NewRPCError(verificationErr, cause.Error(), cause)
@ -742,7 +743,8 @@ func (s *Server) getNEP17Balance(h util.Uint160, acc util.Uint160, bw *io.BufBin
}
script := bw.Bytes()
tx := &transaction.Transaction{Script: script}
v := s.chain.GetTestVM(trigger.Application, tx, nil)
v, finalize := s.chain.GetTestVM(trigger.Application, tx, nil)
defer finalize()
v.GasLimit = core.HeaderVerificationGasLimit
v.LoadScriptWithFlags(script, callflag.All)
err := v.Run()
@ -1490,7 +1492,6 @@ func (s *Server) invokeContractVerify(reqParams request.Params) (interface{}, *r
tx.Signers = []transaction.Signer{{Account: scriptHash}}
tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}}
}
return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx)
}
@ -1511,7 +1512,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
}
b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond))
vm := s.chain.GetTestVM(t, tx, b)
vm, finalize := s.chain.GetTestVM(t, tx, b)
vm.GasLimit = int64(s.config.MaxGasInvoke)
if t == trigger.Verification {
// We need this special case because witnesses verification is not the simple System.Contract.Call,
@ -1539,7 +1540,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
if err != nil {
faultException = err.Error()
}
return result.NewInvoke(vm, script, faultException, s.config.MaxIteratorResultItems), nil
return result.NewInvoke(vm, finalize, script, faultException, s.config.MaxIteratorResultItems), nil
}
// submitBlock broadcasts a raw block over the NEO network.

View file

@ -134,16 +134,17 @@ func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transa
}
func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) {
v := o.Chain.GetTestVM(trigger.Verification, tx, nil)
v, finalize := o.Chain.GetTestVM(trigger.Verification, tx, nil)
v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS()
v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly)
v.Jump(v.Context(), o.verifyOffset)
ok := isVerifyOk(v)
ok := isVerifyOk(v, finalize)
return v.GasConsumed(), ok
}
func isVerifyOk(v *vm.VM) bool {
func isVerifyOk(v *vm.VM, finalize func()) bool {
defer finalize()
if err := v.Run(); err != nil {
return false
}