core: add finalizer functions to interop context

These functions are aimed to free the resources occupied by storage
iterator by the end of script execution or whenever Finilize is called.
This commit is contained in:
Anna Shaleva 2021-10-07 14:27:55 +03:00
parent 0a4f45c9b0
commit 07cbe4d253
11 changed files with 60 additions and 29 deletions

View file

@ -327,7 +327,7 @@ func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
} }
// GetTestVM implements Blockchainer interface. // GetTestVM implements Blockchainer interface.
func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
panic("TODO") panic("TODO")
} }

View file

@ -1036,7 +1036,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
v.LoadToken = contract.LoadToken(systemInterop) v.LoadToken = contract.LoadToken(systemInterop)
v.GasLimit = tx.SystemFee v.GasLimit = tx.SystemFee
err := v.Run() err := systemInterop.Exec()
var faultException string var faultException string
if !v.HasFailed() { if !v.HasFailed() {
_, err := systemInterop.DAO.Persist() _, err := systemInterop.DAO.Persist()
@ -1223,7 +1223,7 @@ func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache dao.DA
v := systemInterop.SpawnVM() v := systemInterop.SpawnVM()
v.LoadScriptWithFlags(script, callflag.All) v.LoadScriptWithFlags(script, callflag.All)
v.SetPriceGetter(systemInterop.GetPrice) v.SetPriceGetter(systemInterop.GetPrice)
if err := v.Run(); err != nil { if err := systemInterop.Exec(); err != nil {
return nil, fmt.Errorf("VM has failed: %w", err) return nil, fmt.Errorf("VM has failed: %w", err)
} else if _, err := systemInterop.DAO.Persist(); err != nil { } else if _, err := systemInterop.DAO.Persist(); err != nil {
return nil, fmt.Errorf("can't save changes: %w", err) return nil, fmt.Errorf("can't save changes: %w", err)
@ -2052,14 +2052,14 @@ func (bc *Blockchain) GetEnrollments() ([]state.Validator, error) {
return bc.contracts.NEO.GetCandidates(bc.dao) return bc.contracts.NEO.GetCandidates(bc.dao)
} }
// GetTestVM returns a VM and a Store setup for a test run of some sort of code. // GetTestVM returns a VM setup for a test run of some sort of code and finalizer function.
func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
d := bc.dao.GetWrapped().(*dao.Simple) d := bc.dao.GetWrapped().(*dao.Simple)
systemInterop := bc.newInteropContext(t, d, b, tx) systemInterop := bc.newInteropContext(t, d, b, tx)
vm := systemInterop.SpawnVM() vm := systemInterop.SpawnVM()
vm.SetPriceGetter(systemInterop.GetPrice) vm.SetPriceGetter(systemInterop.GetPrice)
vm.LoadToken = contract.LoadToken(systemInterop) vm.LoadToken = contract.LoadToken(systemInterop)
return vm return vm, systemInterop.Finalize
} }
// Various witness verification errors. // Various witness verification errors.
@ -2138,7 +2138,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil { if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil {
return 0, err return 0, err
} }
err := vm.Run() err := interopCtx.Exec()
if vm.HasFailed() { if vm.HasFailed() {
return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err) return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err)
} }

View file

@ -59,7 +59,7 @@ type Blockchainer interface {
GetStateSyncModule() StateSync GetStateSyncModule() StateSync
GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItem(id int32, key []byte) state.StorageItem
GetStorageItems(id int32) ([]state.StorageItemWithKey, error) GetStorageItems(id int32) ([]state.StorageItemWithKey, error)
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func())
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
SetOracle(service services.Oracle) SetOracle(service services.Oracle)
mempool.Feer // fee interface mempool.Feer // fee interface

View file

@ -1,6 +1,7 @@
package interop package interop
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -47,6 +48,7 @@ type Context struct {
Log *zap.Logger Log *zap.Logger
VM *vm.VM VM *vm.VM
Functions []Function Functions []Function
cancelFuncs []context.CancelFunc
getContract func(dao.DAO, util.Uint160) (*state.Contract, error) getContract func(dao.DAO, util.Uint160) (*state.Contract, error)
baseExecFee int64 baseExecFee int64
} }
@ -285,3 +287,25 @@ func (ic *Context) SpawnVM() *vm.VM {
ic.VM = v ic.VM = v
return v return v
} }
// RegisterCancelFunc adds given function to the list of functions to be called after VM
// finishes script execution.
func (ic *Context) RegisterCancelFunc(f context.CancelFunc) {
if f != nil {
ic.cancelFuncs = append(ic.cancelFuncs, f)
}
}
// Finalize calls all registered cancel functions to release the occupied resources.
func (ic *Context) Finalize() {
for _, f := range ic.cancelFuncs {
f()
}
ic.cancelFuncs = nil
}
// Exec executes loaded VM script and calls registered finalizers to release the occupied resources.
func (ic *Context) Exec() error {
defer ic.Finalize()
return ic.VM.Run()
}

View file

@ -1,8 +1,6 @@
package storage package storage
import ( import (
"context"
"github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/core/storage"
"github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/nspcc-dev/neo-go/pkg/util/slice"
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
@ -25,7 +23,6 @@ const (
// Iterator is an iterator state representation. // Iterator is an iterator state representation.
type Iterator struct { type Iterator struct {
seekCh chan storage.KeyValue seekCh chan storage.KeyValue
cancel context.CancelFunc
curr storage.KeyValue curr storage.KeyValue
next bool next bool
opts int64 opts int64
@ -33,10 +30,9 @@ type Iterator struct {
} }
// NewIterator creates a new Iterator with given options for a given channel of store.Seek results. // NewIterator creates a new Iterator with given options for a given channel of store.Seek results.
func NewIterator(seekCh chan storage.KeyValue, cancel context.CancelFunc, prefix []byte, opts int64) *Iterator { func NewIterator(seekCh chan storage.KeyValue, prefix []byte, opts int64) *Iterator {
return &Iterator{ return &Iterator{
seekCh: seekCh, seekCh: seekCh,
cancel: cancel,
opts: opts, opts: opts,
prefix: slice.Copy(prefix), prefix: slice.Copy(prefix),
} }
@ -84,9 +80,3 @@ func (s *Iterator) Value() stackitem.Item {
value, value,
}) })
} }
// Close releases resources occupied by the Iterator.
// TODO: call this method on program unloading.
func (s *Iterator) Close() {
s.cancel()
}

View file

@ -191,8 +191,9 @@ func storageFind(ic *interop.Context) error {
// sorted items, so no need to sort them one more time. // sorted items, so no need to sort them one more time.
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix) seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix)
item := istorage.NewIterator(seekres, cancel, prefix, opts) item := istorage.NewIterator(seekres, prefix, opts)
ic.VM.Estack().PushItem(stackitem.NewInterop(item)) ic.VM.Estack().PushItem(stackitem.NewInterop(item))
ic.RegisterCancelFunc(cancel)
return nil return nil
} }

View file

@ -366,6 +366,8 @@ func BenchmarkStorageFind(b *testing.B) {
if err != nil { if err != nil {
b.FailNow() b.FailNow()
} }
b.StopTimer()
context.Finalize()
} }
}) })
} }
@ -425,6 +427,7 @@ func BenchmarkStorageFindIteratorNext(b *testing.B) {
} else { } else {
require.True(b, actual) require.True(b, actual)
} }
context.Finalize()
} }
}) })
} }

View file

@ -20,10 +20,11 @@ type Invoke struct {
FaultException string FaultException string
Transaction *transaction.Transaction Transaction *transaction.Transaction
maxIteratorResultItems int maxIteratorResultItems int
finalize func()
} }
// NewInvoke returns new Invoke structure with the given fields set. // NewInvoke returns new Invoke structure with the given fields set.
func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResultItems int) *Invoke { func NewInvoke(vm *vm.VM, finalize func(), script []byte, faultException string, maxIteratorResultItems int) *Invoke {
return &Invoke{ return &Invoke{
State: vm.State().String(), State: vm.State().String(),
GasConsumed: vm.GasConsumed(), GasConsumed: vm.GasConsumed(),
@ -31,6 +32,7 @@ func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResul
Stack: vm.Estack().ToArray(), Stack: vm.Estack().ToArray(),
FaultException: faultException, FaultException: faultException,
maxIteratorResultItems: maxIteratorResultItems, maxIteratorResultItems: maxIteratorResultItems,
finalize: finalize,
} }
} }
@ -55,8 +57,17 @@ type Iterator struct {
Truncated bool Truncated bool
} }
// Finalize releases resources occupied by Iterators created at the script invocation.
// This method will be called automatically on Invoke marshalling.
func (r *Invoke) Finalize() {
if r.finalize != nil {
r.finalize()
}
}
// MarshalJSON implements json.Marshaler. // MarshalJSON implements json.Marshaler.
func (r Invoke) MarshalJSON() ([]byte, error) { func (r Invoke) MarshalJSON() ([]byte, error) {
defer r.Finalize()
var st json.RawMessage var st json.RawMessage
arr := make([]json.RawMessage, len(r.Stack)) arr := make([]json.RawMessage, len(r.Stack))
for i := range arr { for i := range arr {

View file

@ -714,7 +714,7 @@ func TestCreateNEP17TransferTx(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, acc.SignTx(testchain.Network(), tx)) require.NoError(t, acc.SignTx(testchain.Network(), tx))
require.NoError(t, chain.VerifyTx(tx)) require.NoError(t, chain.VerifyTx(tx))
v := chain.GetTestVM(trigger.Application, tx, nil) v, _ := chain.GetTestVM(trigger.Application, tx, nil)
v.LoadScriptWithFlags(tx.Script, callflag.All) v.LoadScriptWithFlags(tx.Script, callflag.All)
require.NoError(t, v.Run()) require.NoError(t, v.Run())
} }

View file

@ -624,6 +624,7 @@ func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *re
if respErr != nil { if respErr != nil {
return 0, respErr return 0, respErr
} }
res.Finalize()
if res.State != "HALT" { if res.State != "HALT" {
cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException) cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException)
return 0, response.NewRPCError(verificationErr, cause.Error(), cause) return 0, response.NewRPCError(verificationErr, cause.Error(), cause)
@ -742,7 +743,8 @@ func (s *Server) getNEP17Balance(h util.Uint160, acc util.Uint160, bw *io.BufBin
} }
script := bw.Bytes() script := bw.Bytes()
tx := &transaction.Transaction{Script: script} tx := &transaction.Transaction{Script: script}
v := s.chain.GetTestVM(trigger.Application, tx, nil) v, finalize := s.chain.GetTestVM(trigger.Application, tx, nil)
defer finalize()
v.GasLimit = core.HeaderVerificationGasLimit v.GasLimit = core.HeaderVerificationGasLimit
v.LoadScriptWithFlags(script, callflag.All) v.LoadScriptWithFlags(script, callflag.All)
err := v.Run() err := v.Run()
@ -1490,7 +1492,6 @@ func (s *Server) invokeContractVerify(reqParams request.Params) (interface{}, *r
tx.Signers = []transaction.Signer{{Account: scriptHash}} tx.Signers = []transaction.Signer{{Account: scriptHash}}
tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}} tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}}
} }
return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx) return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx)
} }
@ -1511,7 +1512,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
} }
b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond)) b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond))
vm := s.chain.GetTestVM(t, tx, b) vm, finalize := s.chain.GetTestVM(t, tx, b)
vm.GasLimit = int64(s.config.MaxGasInvoke) vm.GasLimit = int64(s.config.MaxGasInvoke)
if t == trigger.Verification { if t == trigger.Verification {
// We need this special case because witnesses verification is not the simple System.Contract.Call, // We need this special case because witnesses verification is not the simple System.Contract.Call,
@ -1539,7 +1540,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
if err != nil { if err != nil {
faultException = err.Error() faultException = err.Error()
} }
return result.NewInvoke(vm, script, faultException, s.config.MaxIteratorResultItems), nil return result.NewInvoke(vm, finalize, script, faultException, s.config.MaxIteratorResultItems), nil
} }
// submitBlock broadcasts a raw block over the NEO network. // submitBlock broadcasts a raw block over the NEO network.

View file

@ -134,16 +134,17 @@ func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transa
} }
func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) { func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) {
v := o.Chain.GetTestVM(trigger.Verification, tx, nil) v, finalize := o.Chain.GetTestVM(trigger.Verification, tx, nil)
v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS() v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS()
v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly) v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly)
v.Jump(v.Context(), o.verifyOffset) v.Jump(v.Context(), o.verifyOffset)
ok := isVerifyOk(v) ok := isVerifyOk(v, finalize)
return v.GasConsumed(), ok return v.GasConsumed(), ok
} }
func isVerifyOk(v *vm.VM) bool { func isVerifyOk(v *vm.VM, finalize func()) bool {
defer finalize()
if err := v.Run(); err != nil { if err := v.Run(); err != nil {
return false return false
} }