From 07cbe4d253ba384a51cde53a605d66968c03f443 Mon Sep 17 00:00:00 2001 From: Anna Shaleva Date: Thu, 7 Oct 2021 14:27:55 +0300 Subject: [PATCH] core: add finalizer functions to interop context These functions are aimed to free the resources occupied by storage iterator by the end of script execution or whenever Finilize is called. --- internal/fakechain/fakechain.go | 2 +- pkg/core/blockchain.go | 12 ++++++------ pkg/core/blockchainer/blockchainer.go | 2 +- pkg/core/interop/context.go | 24 ++++++++++++++++++++++++ pkg/core/interop/storage/find.go | 12 +----------- pkg/core/interop_system.go | 3 ++- pkg/core/interop_system_test.go | 3 +++ pkg/rpc/response/result/invoke.go | 13 ++++++++++++- pkg/rpc/server/client_test.go | 2 +- pkg/rpc/server/server.go | 9 +++++---- pkg/services/oracle/response.go | 7 ++++--- 11 files changed, 60 insertions(+), 29 deletions(-) diff --git a/internal/fakechain/fakechain.go b/internal/fakechain/fakechain.go index ae08e4ad3..73d0d9746 100644 --- a/internal/fakechain/fakechain.go +++ b/internal/fakechain/fakechain.go @@ -327,7 +327,7 @@ func (chain *FakeChain) GetStorageItem(id int32, key []byte) state.StorageItem { } // GetTestVM implements Blockchainer interface. -func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { +func (chain *FakeChain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) { panic("TODO") } diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index fa0d8870c..77543467e 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1036,7 +1036,7 @@ func (bc *Blockchain) storeBlock(block *block.Block, txpool *mempool.Pool) error v.LoadToken = contract.LoadToken(systemInterop) v.GasLimit = tx.SystemFee - err := v.Run() + err := systemInterop.Exec() var faultException string if !v.HasFailed() { _, err := systemInterop.DAO.Persist() @@ -1223,7 +1223,7 @@ func (bc *Blockchain) runPersist(script []byte, block *block.Block, cache dao.DA v := systemInterop.SpawnVM() v.LoadScriptWithFlags(script, callflag.All) v.SetPriceGetter(systemInterop.GetPrice) - if err := v.Run(); err != nil { + if err := systemInterop.Exec(); err != nil { return nil, fmt.Errorf("VM has failed: %w", err) } else if _, err := systemInterop.DAO.Persist(); err != nil { return nil, fmt.Errorf("can't save changes: %w", err) @@ -2052,14 +2052,14 @@ func (bc *Blockchain) GetEnrollments() ([]state.Validator, error) { return bc.contracts.NEO.GetCandidates(bc.dao) } -// GetTestVM returns a VM and a Store setup for a test run of some sort of code. -func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM { +// GetTestVM returns a VM setup for a test run of some sort of code and finalizer function. +func (bc *Blockchain) GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) { d := bc.dao.GetWrapped().(*dao.Simple) systemInterop := bc.newInteropContext(t, d, b, tx) vm := systemInterop.SpawnVM() vm.SetPriceGetter(systemInterop.GetPrice) vm.LoadToken = contract.LoadToken(systemInterop) - return vm + return vm, systemInterop.Finalize } // Various witness verification errors. @@ -2138,7 +2138,7 @@ func (bc *Blockchain) verifyHashAgainstScript(hash util.Uint160, witness *transa if err := bc.InitVerificationVM(vm, interopCtx.GetContract, hash, witness); err != nil { return 0, err } - err := vm.Run() + err := interopCtx.Exec() if vm.HasFailed() { return 0, fmt.Errorf("%w: vm execution has failed: %v", ErrVerificationFailed, err) } diff --git a/pkg/core/blockchainer/blockchainer.go b/pkg/core/blockchainer/blockchainer.go index e6619843c..b63bca04c 100644 --- a/pkg/core/blockchainer/blockchainer.go +++ b/pkg/core/blockchainer/blockchainer.go @@ -59,7 +59,7 @@ type Blockchainer interface { GetStateSyncModule() StateSync GetStorageItem(id int32, key []byte) state.StorageItem GetStorageItems(id int32) ([]state.StorageItemWithKey, error) - GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) *vm.VM + GetTestVM(t trigger.Type, tx *transaction.Transaction, b *block.Block) (*vm.VM, func()) GetTransaction(util.Uint256) (*transaction.Transaction, uint32, error) SetOracle(service services.Oracle) mempool.Feer // fee interface diff --git a/pkg/core/interop/context.go b/pkg/core/interop/context.go index 908b8a6ac..1c7a3ad6e 100644 --- a/pkg/core/interop/context.go +++ b/pkg/core/interop/context.go @@ -1,6 +1,7 @@ package interop import ( + "context" "encoding/binary" "errors" "fmt" @@ -47,6 +48,7 @@ type Context struct { Log *zap.Logger VM *vm.VM Functions []Function + cancelFuncs []context.CancelFunc getContract func(dao.DAO, util.Uint160) (*state.Contract, error) baseExecFee int64 } @@ -285,3 +287,25 @@ func (ic *Context) SpawnVM() *vm.VM { ic.VM = v return v } + +// RegisterCancelFunc adds given function to the list of functions to be called after VM +// finishes script execution. +func (ic *Context) RegisterCancelFunc(f context.CancelFunc) { + if f != nil { + ic.cancelFuncs = append(ic.cancelFuncs, f) + } +} + +// Finalize calls all registered cancel functions to release the occupied resources. +func (ic *Context) Finalize() { + for _, f := range ic.cancelFuncs { + f() + } + ic.cancelFuncs = nil +} + +// Exec executes loaded VM script and calls registered finalizers to release the occupied resources. +func (ic *Context) Exec() error { + defer ic.Finalize() + return ic.VM.Run() +} diff --git a/pkg/core/interop/storage/find.go b/pkg/core/interop/storage/find.go index e09e569b3..35b6621f4 100644 --- a/pkg/core/interop/storage/find.go +++ b/pkg/core/interop/storage/find.go @@ -1,8 +1,6 @@ package storage import ( - "context" - "github.com/nspcc-dev/neo-go/pkg/core/storage" "github.com/nspcc-dev/neo-go/pkg/util/slice" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" @@ -25,7 +23,6 @@ const ( // Iterator is an iterator state representation. type Iterator struct { seekCh chan storage.KeyValue - cancel context.CancelFunc curr storage.KeyValue next bool opts int64 @@ -33,10 +30,9 @@ type Iterator struct { } // NewIterator creates a new Iterator with given options for a given channel of store.Seek results. -func NewIterator(seekCh chan storage.KeyValue, cancel context.CancelFunc, prefix []byte, opts int64) *Iterator { +func NewIterator(seekCh chan storage.KeyValue, prefix []byte, opts int64) *Iterator { return &Iterator{ seekCh: seekCh, - cancel: cancel, opts: opts, prefix: slice.Copy(prefix), } @@ -84,9 +80,3 @@ func (s *Iterator) Value() stackitem.Item { value, }) } - -// Close releases resources occupied by the Iterator. -// TODO: call this method on program unloading. -func (s *Iterator) Close() { - s.cancel() -} diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index a21044057..fd473b65b 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -191,8 +191,9 @@ func storageFind(ic *interop.Context) error { // sorted items, so no need to sort them one more time. ctx, cancel := context.WithCancel(context.Background()) seekres := ic.DAO.SeekAsync(ctx, stc.ID, prefix) - item := istorage.NewIterator(seekres, cancel, prefix, opts) + item := istorage.NewIterator(seekres, prefix, opts) ic.VM.Estack().PushItem(stackitem.NewInterop(item)) + ic.RegisterCancelFunc(cancel) return nil } diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 60a95990d..f86b73858 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -366,6 +366,8 @@ func BenchmarkStorageFind(b *testing.B) { if err != nil { b.FailNow() } + b.StopTimer() + context.Finalize() } }) } @@ -425,6 +427,7 @@ func BenchmarkStorageFindIteratorNext(b *testing.B) { } else { require.True(b, actual) } + context.Finalize() } }) } diff --git a/pkg/rpc/response/result/invoke.go b/pkg/rpc/response/result/invoke.go index 83685a7ac..ebd7b699c 100644 --- a/pkg/rpc/response/result/invoke.go +++ b/pkg/rpc/response/result/invoke.go @@ -20,10 +20,11 @@ type Invoke struct { FaultException string Transaction *transaction.Transaction maxIteratorResultItems int + finalize func() } // NewInvoke returns new Invoke structure with the given fields set. -func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResultItems int) *Invoke { +func NewInvoke(vm *vm.VM, finalize func(), script []byte, faultException string, maxIteratorResultItems int) *Invoke { return &Invoke{ State: vm.State().String(), GasConsumed: vm.GasConsumed(), @@ -31,6 +32,7 @@ func NewInvoke(vm *vm.VM, script []byte, faultException string, maxIteratorResul Stack: vm.Estack().ToArray(), FaultException: faultException, maxIteratorResultItems: maxIteratorResultItems, + finalize: finalize, } } @@ -55,8 +57,17 @@ type Iterator struct { Truncated bool } +// Finalize releases resources occupied by Iterators created at the script invocation. +// This method will be called automatically on Invoke marshalling. +func (r *Invoke) Finalize() { + if r.finalize != nil { + r.finalize() + } +} + // MarshalJSON implements json.Marshaler. func (r Invoke) MarshalJSON() ([]byte, error) { + defer r.Finalize() var st json.RawMessage arr := make([]json.RawMessage, len(r.Stack)) for i := range arr { diff --git a/pkg/rpc/server/client_test.go b/pkg/rpc/server/client_test.go index 4bff819b0..cd3b3a291 100644 --- a/pkg/rpc/server/client_test.go +++ b/pkg/rpc/server/client_test.go @@ -714,7 +714,7 @@ func TestCreateNEP17TransferTx(t *testing.T) { require.NoError(t, err) require.NoError(t, acc.SignTx(testchain.Network(), tx)) require.NoError(t, chain.VerifyTx(tx)) - v := chain.GetTestVM(trigger.Application, tx, nil) + v, _ := chain.GetTestVM(trigger.Application, tx, nil) v.LoadScriptWithFlags(tx.Script, callflag.All) require.NoError(t, v.Run()) } diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index ab9ac6ce1..125b6ebe5 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -624,6 +624,7 @@ func (s *Server) calculateNetworkFee(reqParams request.Params) (interface{}, *re if respErr != nil { return 0, respErr } + res.Finalize() if res.State != "HALT" { cause := fmt.Errorf("invalid VM state %s due to an error: %s", res.State, res.FaultException) return 0, response.NewRPCError(verificationErr, cause.Error(), cause) @@ -742,7 +743,8 @@ func (s *Server) getNEP17Balance(h util.Uint160, acc util.Uint160, bw *io.BufBin } script := bw.Bytes() tx := &transaction.Transaction{Script: script} - v := s.chain.GetTestVM(trigger.Application, tx, nil) + v, finalize := s.chain.GetTestVM(trigger.Application, tx, nil) + defer finalize() v.GasLimit = core.HeaderVerificationGasLimit v.LoadScriptWithFlags(script, callflag.All) err := v.Run() @@ -1490,7 +1492,6 @@ func (s *Server) invokeContractVerify(reqParams request.Params) (interface{}, *r tx.Signers = []transaction.Signer{{Account: scriptHash}} tx.Scripts = []transaction.Witness{{InvocationScript: invocationScript, VerificationScript: []byte{}}} } - return s.runScriptInVM(trigger.Verification, invocationScript, scriptHash, tx) } @@ -1511,7 +1512,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash } b.Timestamp = hdr.Timestamp + uint64(s.chain.GetConfig().SecondsPerBlock*int(time.Second/time.Millisecond)) - vm := s.chain.GetTestVM(t, tx, b) + vm, finalize := s.chain.GetTestVM(t, tx, b) vm.GasLimit = int64(s.config.MaxGasInvoke) if t == trigger.Verification { // We need this special case because witnesses verification is not the simple System.Contract.Call, @@ -1539,7 +1540,7 @@ func (s *Server) runScriptInVM(t trigger.Type, script []byte, contractScriptHash if err != nil { faultException = err.Error() } - return result.NewInvoke(vm, script, faultException, s.config.MaxIteratorResultItems), nil + return result.NewInvoke(vm, finalize, script, faultException, s.config.MaxIteratorResultItems), nil } // submitBlock broadcasts a raw block over the NEO network. diff --git a/pkg/services/oracle/response.go b/pkg/services/oracle/response.go index 6b5f9d541..38a69550e 100644 --- a/pkg/services/oracle/response.go +++ b/pkg/services/oracle/response.go @@ -134,16 +134,17 @@ func (o *Oracle) CreateResponseTx(gasForResponse int64, vub uint32, resp *transa } func (o *Oracle) testVerify(tx *transaction.Transaction) (int64, bool) { - v := o.Chain.GetTestVM(trigger.Verification, tx, nil) + v, finalize := o.Chain.GetTestVM(trigger.Verification, tx, nil) v.GasLimit = o.Chain.GetPolicer().GetMaxVerificationGAS() v.LoadScriptWithHash(o.oracleScript, o.oracleHash, callflag.ReadOnly) v.Jump(v.Context(), o.verifyOffset) - ok := isVerifyOk(v) + ok := isVerifyOk(v, finalize) return v.GasConsumed(), ok } -func isVerifyOk(v *vm.VM) bool { +func isVerifyOk(v *vm.VM, finalize func()) bool { + defer finalize() if err := v.Run(); err != nil { return false }