forked from TrueCloudLab/neoneo-go
core: add finalizer functions to interop context
These functions are aimed to free the resources occupied by storage iterator by the end of script execution or whenever Finilize is called.
This commit is contained in:
parent
0a4f45c9b0
commit
07cbe4d253
11 changed files with 60 additions and 29 deletions
|
@ -327,7 +327,7 @@ func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTestVM implements Blockchainer interface.
|
// GetTestVM implements Blockchainer interface.
|
||||||
func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM {
|
func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
|
||||||
panic("TODO")
|
panic("TODO")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1036,7 +1036,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error
|
||||||
v.LoadToken = contract.LoadToken(systemInterop)
|
v.LoadToken = contract.LoadToken(systemInterop)
|
||||||
v.GasLimit = tx.SystemFee
|
v.GasLimit = tx.SystemFee
|
||||||
|
|
||||||
err := v.Run()
|
err := systemInterop.Exec()
|
||||||
var faultException string
|
var faultException string
|
||||||
if !v.HasFailed() {
|
if !v.HasFailed() {
|
||||||
_, err := systemInterop.DAO.Persist()
|
_, err := systemInterop.DAO.Persist()
|
||||||
|
@ -1223,7 +1223,7 @@ func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache dao.DA
|
||||||
v := systemInterop.SpawnVM()
|
v := systemInterop.SpawnVM()
|
||||||
v.LoadScriptWithFlags(script, callflag.All)
|
v.LoadScriptWithFlags(script, callflag.All)
|
||||||
v.SetPriceGetter(systemInterop.GetPrice)
|
v.SetPriceGetter(systemInterop.GetPrice)
|
||||||
if err := v.Run(); err != nil {
|
if err := systemInterop.Exec(); err != nil {
|
||||||
return nil, fmt.Errorf("VM has failed: %w", err)
|
return nil, fmt.Errorf("VM has failed: %w", err)
|
||||||
} else if _, err := systemInterop.DAO.Persist(); err != nil {
|
} else if _, err := systemInterop.DAO.Persist(); err != nil {
|
||||||
return nil, fmt.Errorf("can't save changes: %w", err)
|
return nil, fmt.Errorf("can't save changes: %w", err)
|
||||||
|
@ -2052,14 +2052,14 @@ func (bc *Blockchain) GetEnrollments() ([]state.Validator, error) {
|
||||||
return bc.contracts.NEO.GetCandidates(bc.dao)
|
return bc.contracts.NEO.GetCandidates(bc.dao)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTestVM returns a VM and a Store setup for a test run of some sort of code.
|
// GetTestVM returns a VM setup for a test run of some sort of code and finalizer function.
|
||||||
func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM {
|
func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) {
|
||||||
d := bc.dao.GetWrapped().(*dao.Simple)
|
d := bc.dao.GetWrapped().(*dao.Simple)
|
||||||
systemInterop := bc.newInteropContext(t, d, b, tx)
|
systemInterop := bc.newInteropContext(t, d, b, tx)
|
||||||
vm := systemInterop.SpawnVM()
|
vm := systemInterop.SpawnVM()
|
||||||
vm.SetPriceGetter(systemInterop.GetPrice)
|
vm.SetPriceGetter(systemInterop.GetPrice)
|
||||||
vm.LoadToken = contract.LoadToken(systemInterop)
|
vm.LoadToken = contract.LoadToken(systemInterop)
|
||||||
return vm
|
return vm, systemInterop.Finalize
|
||||||
}
|
}
|
||||||
|
|
||||||
// Various witness verification errors.
|
// Various witness verification errors.
|
||||||
|
@ -2138,7 +2138,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa
|
||||||
if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil {
|
if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
err := vm.Run()
|
err := interopCtx.Exec()
|
||||||
if vm.HasFailed() {
|
if vm.HasFailed() {
|
||||||
return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err)
|
return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ type Blockchainer interface {
|
||||||
GetStateSyncModule() StateSync
|
GetStateSyncModule() StateSync
|
||||||
GetStorageItem(id int32, key []byte) state.StorageItem
|
GetStorageItem(id int32, key []byte) state.StorageItem
|
||||||
GetStorageItems(id int32) ([]state.StorageItemWithKey, error)
|
GetStorageItems(id int32) ([]state.StorageItemWithKey, error)
|
||||||
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM
|
GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func())
|
||||||
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
|
GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error)
|
||||||
SetOracle(service services.Oracle)
|
SetOracle(service services.Oracle)
|
||||||
mempool.Feer // fee interface
|
mempool.Feer // fee interface
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package interop
|
package interop
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -47,6 +48,7 @@ type Context struct {
|
||||||
Log *zap.Logger
|
Log *zap.Logger
|
||||||
VM *vm.VM
|
VM *vm.VM
|
||||||
Functions []Function
|
Functions []Function
|
||||||
|
cancelFuncs []context.CancelFunc
|
||||||
getContract func(dao.DAO, util.Uint160) (*state.Contract, error)
|
getContract func(dao.DAO, util.Uint160) (*state.Contract, error)
|
||||||
baseExecFee int64
|
baseExecFee int64
|
||||||
}
|
}
|
||||||
|
@ -285,3 +287,25 @@ func (ic *Context) SpawnVM() *vm.VM {
|
||||||
ic.VM = v
|
ic.VM = v
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RegisterCancelFunc adds given function to the list of functions to be called after VM
|
||||||
|
// finishes script execution.
|
||||||
|
func (ic *Context) RegisterCancelFunc(f context.CancelFunc) {
|
||||||
|
if f != nil {
|
||||||
|
ic.cancelFuncs = append(ic.cancelFuncs, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize calls all registered cancel functions to release the occupied resources.
|
||||||
|
func (ic *Context) Finalize() {
|
||||||
|
for _, f := range ic.cancelFuncs {
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
ic.cancelFuncs = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec executes loaded VM script and calls registered finalizers to release the occupied resources.
|
||||||
|
func (ic *Context) Exec() error {
|
||||||
|
defer ic.Finalize()
|
||||||
|
return ic.VM.Run()
|
||||||
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
"github.com/nspcc-dev/neo-go/pkg/core/storage"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
"github.com/nspcc-dev/neo-go/pkg/util/slice"
|
||||||
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
"github.com/nspcc-dev/neo-go/pkg/vm/stackitem"
|
||||||
|
@ -25,7 +23,6 @@ const (
|
||||||
// Iterator is an iterator state representation.
|
// Iterator is an iterator state representation.
|
||||||
type Iterator struct {
|
type Iterator struct {
|
||||||
seekCh chan storage.KeyValue
|
seekCh chan storage.KeyValue
|
||||||
cancel context.CancelFunc
|
|
||||||
curr storage.KeyValue
|
curr storage.KeyValue
|
||||||
next bool
|
next bool
|
||||||
opts int64
|
opts int64
|
||||||
|
@ -33,10 +30,9 @@ type Iterator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewIterator creates a new Iterator with given options for a given channel of store.Seek results.
|
// NewIterator creates a new Iterator with given options for a given channel of store.Seek results.
|
||||||
func NewIterator(seekCh chan storage.KeyValue, cancel context.CancelFunc, prefix []byte, opts int64) *Iterator {
|
func NewIterator(seekCh chan storage.KeyValue, prefix []byte, opts int64) *Iterator {
|
||||||
return &Iterator{
|
return &Iterator{
|
||||||
seekCh: seekCh,
|
seekCh: seekCh,
|
||||||
cancel: cancel,
|
|
||||||
opts: opts,
|
opts: opts,
|
||||||
prefix: slice.Copy(prefix),
|
prefix: slice.Copy(prefix),
|
||||||
}
|
}
|
||||||
|
@ -84,9 +80,3 @@ func (s *Iterator) Value() stackitem.Item {
|
||||||
value,
|
value,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close releases resources occupied by the Iterator.
|
|
||||||
// TODO: call this method on program unloading.
|
|
||||||
func (s *Iterator) Close() {
|
|
||||||
s.cancel()
|
|
||||||
}
|
|
||||||
|
|
|
@ -191,8 +191,9 @@ func storageFind(ic *interop.Context) error {
|
||||||
// sorted items, so no need to sort them one more time.
|
// sorted items, so no need to sort them one more time.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix)
|
seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix)
|
||||||
item := istorage.NewIterator(seekres, cancel, prefix, opts)
|
item := istorage.NewIterator(seekres, prefix, opts)
|
||||||
ic.VM.Estack().PushItem(stackitem.NewInterop(item))
|
ic.VM.Estack().PushItem(stackitem.NewInterop(item))
|
||||||
|
ic.RegisterCancelFunc(cancel)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -366,6 +366,8 @@ func BenchmarkStorageFind(b *testing.B) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.FailNow()
|
b.FailNow()
|
||||||
}
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
context.Finalize()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -425,6 +427,7 @@ func BenchmarkStorageFindIteratorNext(b *testing.B) {
|
||||||
} else {
|
} else {
|
||||||
require.True(b, actual)
|
require.True(b, actual)
|
||||||
}
|
}
|
||||||
|
context.Finalize()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,11 @@ type Invoke struct {
|
||||||
FaultException string
|
FaultException string
|
||||||
Transaction *transaction.Transaction
|
Transaction *transaction.Transaction
|
||||||
maxIteratorResultItems int
|
maxIteratorResultItems int
|
||||||
|
finalize func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInvoke returns new Invoke structure with the given fields set.
|
// NewInvoke returns new Invoke structure with the given fields set.
|
||||||
func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResultItems int) *Invoke {
|
func NewInvoke(vm *vm.VM, finalize func(), script []byte, faultException string, maxIteratorResultItems int) *Invoke {
|
||||||
return &Invoke{
|
return &Invoke{
|
||||||
State: vm.State().String(),
|
State: vm.State().String(),
|
||||||
GasConsumed: vm.GasConsumed(),
|
GasConsumed: vm.GasConsumed(),
|
||||||
|
@ -31,6 +32,7 @@ func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResul
|
||||||
Stack: vm.Estack().ToArray(),
|
Stack: vm.Estack().ToArray(),
|
||||||
FaultException: faultException,
|
FaultException: faultException,
|
||||||
maxIteratorResultItems: maxIteratorResultItems,
|
maxIteratorResultItems: maxIteratorResultItems,
|
||||||
|
finalize: finalize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,8 +57,17 @@ type Iterator struct {
|
||||||
Truncated bool
|
Truncated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Finalize releases resources occupied by Iterators created at the script invocation.
|
||||||
|
// This method will be called automatically on Invoke marshalling.
|
||||||
|
func (r *Invoke) Finalize() {
|
||||||
|
if r.finalize != nil {
|
||||||
|
r.finalize()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON implements json.Marshaler.
|
// MarshalJSON implements json.Marshaler.
|
||||||
func (r Invoke) MarshalJSON() ([]byte, error) {
|
func (r Invoke) MarshalJSON() ([]byte, error) {
|
||||||
|
defer r.Finalize()
|
||||||
var st json.RawMessage
|
var st json.RawMessage
|
||||||
arr := make([]json.RawMessage, len(r.Stack))
|
arr := make([]json.RawMessage, len(r.Stack))
|
||||||
for i := range arr {
|
for i := range arr {
|
||||||
|
|
|
@ -714,7 +714,7 @@ func TestCreateNEP17TransferTx(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, acc.SignTx(testchain.Network(), tx))
|
require.NoError(t, acc.SignTx(testchain.Network(), tx))
|
||||||
require.NoError(t, chain.VerifyTx(tx))
|
require.NoError(t, chain.VerifyTx(tx))
|
||||||
v := chain.GetTestVM(trigger.Application, tx, nil)
|
v, _ := chain.GetTestVM(trigger.Application, tx, nil)
|
||||||
v.LoadScriptWithFlags(tx.Script, callflag.All)
|
v.LoadScriptWithFlags(tx.Script, callflag.All)
|
||||||
require.NoError(t, v.Run())
|
require.NoError(t, v.Run())
|
||||||
}
|
}
|
||||||
|
|
|
@ -624,6 +624,7 @@ func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *re
|
||||||
if respErr != nil {
|
if respErr != nil {
|
||||||
return 0, respErr
|
return 0, respErr
|
||||||
}
|
}
|
||||||
|
res.Finalize()
|
||||||
if res.State != "HALT" {
|
if res.State != "HALT" {
|
||||||
cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException)
|
cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException)
|
||||||
return 0, response.NewRPCError(verificationErr, cause.Error(), cause)
|
return 0, response.NewRPCError(verificationErr, cause.Error(), cause)
|
||||||
|
@ -742,7 +743,8 @@ func (s *Server) getNEP17Balance(h util.Uint160, acc util.Uint160, bw *io.BufBin
|
||||||
}
|
}
|
||||||
script := bw.Bytes()
|
script := bw.Bytes()
|
||||||
tx := &transaction.Transaction{Script: script}
|
tx := &transaction.Transaction{Script: script}
|
||||||
v := s.chain.GetTestVM(trigger.Application, tx, nil)
|
v, finalize := s.chain.GetTestVM(trigger.Application, tx, nil)
|
||||||
|
defer finalize()
|
||||||
v.GasLimit = core.HeaderVerificationGasLimit
|
v.GasLimit = core.HeaderVerificationGasLimit
|
||||||
v.LoadScriptWithFlags(script, callflag.All)
|
v.LoadScriptWithFlags(script, callflag.All)
|
||||||
err := v.Run()
|
err := v.Run()
|
||||||
|
@ -1490,7 +1492,6 @@ func (s *Server) invokeContractVerify(reqParams request.Params) (interface{}, *r
|
||||||
tx.Signers = []transaction.Signer{{Account: scriptHash}}
|
tx.Signers = []transaction.Signer{{Account: scriptHash}}
|
||||||
tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}}
|
tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}}
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx)
|
return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1511,7 +1512,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
|
||||||
}
|
}
|
||||||
b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond))
|
b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond))
|
||||||
|
|
||||||
vm := s.chain.GetTestVM(t, tx, b)
|
vm, finalize := s.chain.GetTestVM(t, tx, b)
|
||||||
vm.GasLimit = int64(s.config.MaxGasInvoke)
|
vm.GasLimit = int64(s.config.MaxGasInvoke)
|
||||||
if t == trigger.Verification {
|
if t == trigger.Verification {
|
||||||
// We need this special case because witnesses verification is not the simple System.Contract.Call,
|
// We need this special case because witnesses verification is not the simple System.Contract.Call,
|
||||||
|
@ -1539,7 +1540,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash
|
||||||
if err != nil {
|
if err != nil {
|
||||||
faultException = err.Error()
|
faultException = err.Error()
|
||||||
}
|
}
|
||||||
return result.NewInvoke(vm, script, faultException, s.config.MaxIteratorResultItems), nil
|
return result.NewInvoke(vm, finalize, script, faultException, s.config.MaxIteratorResultItems), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// submitBlock broadcasts a raw block over the NEO network.
|
// submitBlock broadcasts a raw block over the NEO network.
|
||||||
|
|
|
@ -134,16 +134,17 @@ func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transa
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) {
|
func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) {
|
||||||
v := o.Chain.GetTestVM(trigger.Verification, tx, nil)
|
v, finalize := o.Chain.GetTestVM(trigger.Verification, tx, nil)
|
||||||
v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS()
|
v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS()
|
||||||
v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly)
|
v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly)
|
||||||
v.Jump(v.Context(), o.verifyOffset)
|
v.Jump(v.Context(), o.verifyOffset)
|
||||||
|
|
||||||
ok := isVerifyOk(v)
|
ok := isVerifyOk(v, finalize)
|
||||||
return v.GasConsumed(), ok
|
return v.GasConsumed(), ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func isVerifyOk(v *vm.VM) bool {
|
func isVerifyOk(v *vm.VM, finalize func()) bool {
|
||||||
|
defer finalize()
|
||||||
if err := v.Run(); err != nil {
|
if err := v.Run(); err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue